| |
| """ |
| FinQA Environment Implementation. |
| |
| A financial question-answering environment that evaluates LLMs on their ability |
| to answer complex financial questions using tool calls on SEC 10-K filing data. |
| """ |
|
|
| import logging |
| import os |
| import random |
| import uuid |
| from typing import Any, Dict, List, Optional |
|
|
| import pandas as pd |
| from fastmcp import FastMCP |
| from openenv.core.env_server.mcp_environment import MCPEnvironment |
| from openenv.core.env_server.mcp_types import CallToolAction |
| from openenv.core.env_server.types import Action, Observation |
|
|
| from ..models import AVAILABLE_TOOLS, FinQAState |
| from .rewards import compute_reward |
| from .tools import FinQATools |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class FinQAEnvironment(MCPEnvironment):
    """
    Financial QA environment for RL training.

    Evaluates agents on their ability to answer financial questions by:
    - Exploring available tables for a company
    - Querying table metadata and executing SQL queries
    - Performing calculations
    - Submitting final answers

    Questions are served one per episode from a shuffled copy of the
    benchmark. When every question has been served the copy is reshuffled
    and iteration starts over.

    Args:
        data_path: Path to the data directory containing benchmark_questions/ and input_companies/
        max_steps: Maximum number of steps per episode (default: 50)
        task: Task name - currently only 'finqa' supported (default: 'finqa')

    Raises:
        ValueError: If ``task`` is anything other than 'finqa'.
        FileNotFoundError: If the benchmark CSV for ``task`` does not exist.
    """

    def __init__(
        self,
        data_path: str = "./data",
        max_steps: int = 50,
        task: str = "finqa",
    ) -> None:
        # Validate with an explicit raise rather than ``assert``: assertions
        # are removed under ``python -O`` and would let a bad task through.
        if task != "finqa":
            raise ValueError("Only finqa task is supported")

        mcp = FastMCP("finqa_env")

        self.data_path = data_path
        self.max_steps = max_steps
        self.task = task

        self.questions = self._load_questions()
        logger.info("Loaded %d questions for task '%s'", len(self.questions), task)

        self._finqa_tools = FinQATools(data_path)

        # Register the tools on the MCP server. Each one is a thin closure
        # over ``self._finqa_tools``. NOTE(review): the docstrings below are
        # presumably surfaced to agents as tool descriptions, so their wording
        # is kept verbatim - confirm before editing them.
        @mcp.tool
        def get_descriptions(company_name: str) -> str:
            """
            Get a list of available table names for a company.

            Args:
                company_name: The name of the company

            Returns:
                JSON list of table names
            """
            return self._finqa_tools.get_descriptions(company_name)

        @mcp.tool
        def get_table_info(company_name: str, table_name: str) -> str:
            """
            Get table metadata: description, columns, types, unique values.

            Args:
                company_name: The name of the company
                table_name: The name of the table

            Returns:
                JSON string with table metadata
            """
            return self._finqa_tools.get_table_info(company_name, table_name)

        @mcp.tool
        def sql_query(company_name: str, table_name: str, query: str) -> str:
            """
            Execute a SQL query on a table. Select * not allowed.

            Filters are required: WHERE, HAVING, IN, NOT IN, EXISTS, NOT EXISTS,
            ANY, SOME, ALL, LIKE, NOT LIKE, BETWEEN, NOT BETWEEN, IS NULL,
            IS NOT NULL, CASE, FILTER.

            Args:
                company_name: The name of the company
                table_name: The name of the table
                query: SQL query to execute (must include filters)

            Returns:
                JSON string with query results
            """
            return self._finqa_tools.sql_query(company_name, table_name, query)

        @mcp.tool
        def submit_answer(answer: str) -> str:
            """
            Submit a final answer for the question.

            Args:
                answer: The final answer to submit

            Returns:
                Confirmation message
            """
            return self._finqa_tools.submit_answer(answer)

        # Hand the fully configured MCP server to the base environment.
        super().__init__(mcp)

        # Shuffle a copy so ``self.questions`` keeps the order of the CSV.
        self._shuffled_questions = self.questions.copy()
        random.shuffle(self._shuffled_questions)
        self._question_index = 0

        self._state = FinQAState()
        self._history: List[Dict[str, Any]] = []

    def _load_questions(self) -> List[Dict[str, Any]]:
        """
        Load questions from the benchmark CSV.

        The file is read from ``<data_path>/benchmark_questions/<task>.csv``.
        The columns ``user_query``, ``company``, ``question`` and ``answer``
        are accessed directly; ``id``, ``question_type`` and ``explanation``
        fall back to an empty string when the column is absent.

        Returns:
            One dict per CSV row, in file order.

        Raises:
            FileNotFoundError: If the CSV file does not exist.
        """
        csv_path = os.path.join(
            self.data_path, "benchmark_questions", f"{self.task}.csv"
        )

        if not os.path.isfile(csv_path):
            raise FileNotFoundError(f"Benchmark file not found: {csv_path}")

        df = pd.read_csv(csv_path)

        return [
            {
                "id": str(row.get("id", "")),
                "user_query": row["user_query"],
                "company": row["company"],
                "question": row["question"],
                "answer": row["answer"],
                "question_type": row.get("question_type", ""),
                "explanation": row.get("explanation", ""),
            }
            for _, row in df.iterrows()
        ]

    def _get_next_question(self) -> Dict[str, Any]:
        """
        Get the next question using sequential shuffle selection.

        Returns:
            The question dict at the current position of the shuffled list.

        Raises:
            IndexError: If no questions were loaded.
        """
        if not self._shuffled_questions:
            # Same exception type the bare list lookup would raise, but with
            # a message that points at the actual cause.
            raise IndexError(f"No questions loaded for task '{self.task}'")

        # Every question has been served: reshuffle and start another pass.
        if self._question_index >= len(self._shuffled_questions):
            random.shuffle(self._shuffled_questions)
            self._question_index = 0

        question = self._shuffled_questions[self._question_index]
        self._question_index += 1
        return question

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> Observation:
        """
        Reset the environment for a new episode.

        Args:
            seed: Optional seed. When given, the question order is rebuilt
                deterministically from it and iteration restarts, so the same
                seed always yields the same sequence of questions. When
                omitted, the current order simply continues.
            episode_id: Optional identifier for the episode; a random UUID
                is generated when omitted.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            Initial observation with the question
        """
        if seed is not None:
            # Start again from file order so the outcome depends on the seed
            # alone, and use a private RNG so the global one is left untouched.
            self._shuffled_questions = self.questions.copy()
            random.Random(seed).shuffle(self._shuffled_questions)
            self._question_index = 0

        question = self._get_next_question()
        self._state = FinQAState(
            episode_id=episode_id or str(uuid.uuid4()),
            step_count=0,
            current_question=question["user_query"],
            current_company=question["company"],
            ground_truth=question["answer"],
            question_id=question["id"],
        )
        self._history = []

        logger.info(
            "Reset episode %s with question: %s...",
            self._state.episode_id,
            question["question"][:200],
        )

        return Observation(
            done=False,
            reward=0.0,
            metadata={
                "question": question["user_query"],
                "company": question["company"],
                "tool_result": "",
                "history": [],
                "step_count": 0,
                "available_tools": AVAILABLE_TOOLS.copy(),
            },
        )

    def _step_impl(
        self,
        action: Action,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> Observation:
        """
        Handle non-MCP actions. Returns an error since this env is MCP-only.

        Args:
            action: The action that was not recognised as an MCP action.
            timeout_s: Unused.
            **kwargs: Unused.

        Returns:
            A non-terminal, zero-reward observation whose metadata carries
            an ``error`` message naming the action's type.
        """
        return Observation(
            done=False,
            reward=0.0,
            metadata={
                "error": f"Unknown action type: {type(action).__name__}. "
                "Use ListToolsAction or CallToolAction for MCP interactions."
            },
        )

    def step(
        self,
        action: Action,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> Observation:
        """
        Execute a step in the environment.

        Delegates to base class for MCP actions. Handles submit_answer
        reward computation and max-step termination.

        Args:
            action: The action to execute.
            timeout_s: Optional timeout forwarded to the base class.
            **kwargs: Extra arguments forwarded to the base class.

        Returns:
            The base-class observation, or a terminal observation when the
            action is a ``submit_answer`` call or the step limit is reached.
        """
        # Every call counts towards the limit, whatever the action type.
        self._state.step_count += 1

        obs = super().step(action, timeout_s=timeout_s, **kwargs)

        # Submitting an answer always ends the episode and is the only way
        # to earn a reward.
        if isinstance(action, CallToolAction) and action.tool_name == "submit_answer":
            submitted_answer = action.arguments.get("answer", "")
            reward = compute_reward(submitted_answer, self._state.ground_truth)
            logger.info(
                "Episode %s ended: submitted='%s', truth='%s', reward=%s",
                self._state.episode_id,
                submitted_answer,
                self._state.ground_truth,
                reward,
            )
            return Observation(
                done=True,
                reward=reward,
                metadata={
                    **obs.metadata,
                    "ground_truth": self._state.ground_truth,
                    "submitted_answer": submitted_answer,
                },
            )

        # Checked after the submit branch, so an answer submitted on the
        # final allowed step is still scored.
        if self._state.step_count >= self.max_steps:
            logger.info(
                "Episode %s terminated: max steps reached",
                self._state.episode_id,
            )
            return Observation(
                done=True,
                reward=0.0,
                metadata={
                    **obs.metadata,
                    "error": f"Max steps ({self.max_steps}) reached without submitting answer.",
                },
            )

        return obs

    @property
    def state(self) -> FinQAState:
        """Get the current environment state."""
        return self._state
|
|