| |
| """ |
| Tool implementations for the FinQA environment. |
| |
| Ported from FinQABenchmark with simplifications: |
| - Removed LangChain dependencies |
| - Added submit_answer tool for episode termination |
| """ |
|
|
| import json |
| import os |
| import re |
| import sqlite3 |
| from typing import Any, Dict, List, Tuple |
|
|
| import pandas as pd |
|
|
|
|
| class FinQATools: |
| """ |
| Tool implementations for financial QA tasks. |
| |
| Args: |
| data_path: Path to the data directory containing benchmark_questions/ and input_companies/ |
| """ |
|
|
| def __init__(self, data_path: str): |
| self.data_path = data_path |
| self.companies_path = os.path.join(data_path, "input_companies") |
| self._tables_cleaned = None |
|
|
| @property |
| def tables_cleaned(self) -> Dict: |
| """Lazy load the cleaned tables metadata.""" |
| if self._tables_cleaned is None: |
| tables_path = os.path.join( |
| self.companies_path, "tables_cleaned_all_companies.json" |
| ) |
| with open(tables_path, "r") as f: |
| self._tables_cleaned = json.load(f) |
| return self._tables_cleaned |
|
|
| def get_available_companies(self) -> List[str]: |
| """Get list of available company names.""" |
| return [ |
| d |
| for d in os.listdir(self.companies_path) |
| if os.path.isdir(os.path.join(self.companies_path, d)) |
| ] |
|
|
| def execute_tool( |
| self, tool_name: str, tool_args: Dict[str, Any] |
| ) -> Tuple[str, bool]: |
| """ |
| Execute a tool and return its result. |
| |
| Args: |
| tool_name: Name of the tool to execute |
| tool_args: Arguments for the tool |
| |
| Returns: |
| Tuple of (result_string, is_final_answer) |
| """ |
| if tool_name == "get_descriptions": |
| return self.get_descriptions(**tool_args), False |
| elif tool_name == "get_table_info": |
| return self.get_table_info(**tool_args), False |
| elif tool_name == "sql_query": |
| return self.sql_query(**tool_args), False |
| elif tool_name == "submit_answer": |
| return self.submit_answer(**tool_args), True |
| else: |
| return f"Error: Unknown tool '{tool_name}'", False |
|
|
| def get_descriptions(self, company_name: str) -> str: |
| """ |
| Get a list of available table names for a company. |
| |
| Args: |
| company_name: The name of the company |
| |
| Returns: |
| JSON list of table names |
| """ |
| company_path = os.path.join(self.companies_path, company_name) |
|
|
| if not os.path.isdir(company_path): |
| available = self.get_available_companies() |
| return ( |
| f"Error: '{company_name}' not found. Available companies: {available}" |
| ) |
|
|
| |
| tables = [] |
| for f in os.listdir(company_path): |
| if f.endswith(".json"): |
| tables.append(f.replace(".json", "")) |
|
|
| return json.dumps(tables) |
|
|
| def get_table_info(self, company_name: str, table_name: str) -> str: |
| """ |
| Get table metadata: description, columns, types, unique values. |
| |
| Args: |
| company_name: The name of the company |
| table_name: The name of the table |
| |
| Returns: |
| JSON string with table metadata (description, columns, dtypes, unique values) |
| """ |
| company_path = os.path.join(self.companies_path, company_name) |
|
|
| if not os.path.isdir(company_path): |
| available = self.get_available_companies() |
| return ( |
| f"Error: '{company_name}' not found. Available companies: {available}" |
| ) |
|
|
| |
| cleaned_table_name = table_name.replace(".json", "").replace(".txt", "") |
| table_key = f"{company_name}/{cleaned_table_name}" |
|
|
| if table_key not in self.tables_cleaned: |
| return f"Error: Table '{table_name}' not found for company '{company_name}'" |
|
|
| table_info = self.tables_cleaned[table_key].copy() |
|
|
| |
| cleaned_table = pd.DataFrame(json.loads(table_info["table"])) |
|
|
| |
| cols_to_drop = [] |
| for col in cleaned_table.columns.tolist()[1:]: |
| vals = cleaned_table[col].tolist()[1:] |
| cleaned_vals = [ |
| "".join(char for char in str(x) if char.isalnum()).strip() for x in vals |
| ] |
| all_numeric = all(v.isnumeric() or len(v) == 0 for v in cleaned_vals) |
| if all_numeric: |
| cols_to_drop.append(col) |
|
|
| table_info["column_dtypes"] = { |
| col: str(cleaned_table[col].dtype) for col in cleaned_table.columns.tolist() |
| } |
|
|
| |
| cleaned_table_filtered = cleaned_table.drop(cols_to_drop, axis=1) |
| table_info["unique_vals_per_col"] = { |
| col: list(cleaned_table_filtered[col].unique()) |
| for col in cleaned_table_filtered.columns.tolist() |
| } |
|
|
| |
| del table_info["table"] |
|
|
| return json.dumps(table_info, indent=0).replace("\n", "") |
|
|
| def sql_query(self, company_name: str, table_name: str, query: str) -> str: |
| """ |
| Execute a SQL query on a table. Select * not allowed (too inefficient). |
| |
| Filters are required to query: WHERE, HAVING, IN, NOT IN, EXISTS, NOT EXISTS, ANY, SOME, ALL, LIKE, NOT LIKE, BETWEEN, NOT BETWEEN, IS NULL, IS NOT NULL, CASE, FILTER. |
| |
| Args: |
| company_name: The name of the company |
| table_name: The name of the table |
| query: SQL query to execute (must include filters) |
| |
| Returns: |
| JSON string with query results |
| """ |
| |
| if "select *" in query.lower(): |
| return "Error: SELECT * is not allowed (too inefficient)" |
|
|
| sql_filters = [ |
| "WHERE", |
| "HAVING", |
| "IN", |
| "NOT IN", |
| "EXISTS", |
| "NOT EXISTS", |
| "ANY", |
| "SOME", |
| "ALL", |
| "LIKE", |
| "NOT LIKE", |
| "BETWEEN", |
| "NOT BETWEEN", |
| "IS NULL", |
| "IS NOT NULL", |
| "CASE", |
| "FILTER", |
| ] |
|
|
| query_upper = re.sub(r"(\r|\n|\t)+", " ", query).upper() |
| pattern = ( |
| r"(?<!\w|\[)(" |
| + "|".join([re.escape(f) for f in sql_filters]) |
| + r")(?!\w|\])" |
| ) |
|
|
| has_filter = ( |
| any(f" {filt} " in query_upper for filt in sql_filters) |
| or len(re.findall(pattern, query_upper)) > 0 |
| ) |
|
|
| if not has_filter: |
| return "Error: Query must include filters (WHERE, HAVING, etc.)" |
|
|
| |
| cleaned_table_name = table_name.replace(".txt", "").replace(".json", "") |
| table_path = os.path.join( |
| self.companies_path, company_name, f"{cleaned_table_name}.json" |
| ) |
|
|
| if not os.path.isfile(table_path): |
| return f"Error: Table file not found at {table_path}" |
|
|
| try: |
| |
| conn = sqlite3.connect(":memory:") |
| df = pd.read_json(table_path) |
| df.to_sql(cleaned_table_name, conn, index=False, if_exists="replace") |
| result = pd.read_sql_query(query, conn) |
| conn.close() |
|
|
| return result.to_json(orient="records") |
| except Exception as e: |
| return f"Error executing query: {str(e)}" |
|
|
| def submit_answer(self, answer: str) -> str: |
| """ |
| Submit a final answer for the question. |
| |
| Args: |
| answer: The final answer to submit |
| |
| Returns: |
| Confirmation message |
| """ |
| return f"Answer submitted: {answer}" |
|
|