# envs/finqa_env/server/finqa_environment.py
"""
FinQA Environment Implementation.

A financial question-answering environment that evaluates LLMs on their ability
to answer complex financial questions using tool calls on SEC 10-K filing data.
"""

import logging
import os
import random
import uuid
from typing import Any, Dict, List, Optional

import pandas as pd
from fastmcp import FastMCP

from openenv.core.env_server.mcp_environment import MCPEnvironment
from openenv.core.env_server.mcp_types import CallToolAction
from openenv.core.env_server.types import Action, Observation

from ..models import AVAILABLE_TOOLS, FinQAState
from .rewards import compute_reward
from .tools import FinQATools

logger = logging.getLogger(__name__)


class FinQAEnvironment(MCPEnvironment):
    """
    Financial QA environment for RL training.

    Evaluates agents on their ability to answer financial questions by:
    - Exploring available tables for a company
    - Querying table metadata and executing SQL queries
    - Performing calculations
    - Submitting final answers

    An episode ends when the agent calls the ``submit_answer`` tool (reward is
    computed against the ground truth) or when ``max_steps`` calls to
    :meth:`step` have been made (reward 0.0).

    Args:
        data_path: Path to the data directory containing benchmark_questions/
            and input_companies/
        max_steps: Maximum number of tool calls per episode (default: 50)
        task: Task name - currently only 'finqa' supported (default: 'finqa')

    Raises:
        ValueError: If ``task`` is not ``'finqa'``.
        FileNotFoundError: If the benchmark CSV for the task does not exist.
    """

    def __init__(
        self,
        data_path: str = "./data",
        max_steps: int = 50,
        task: str = "finqa",
    ):
        # Validate with an explicit exception: an ``assert`` would be stripped
        # when Python runs with -O, silently accepting unsupported tasks.
        if task != "finqa":
            raise ValueError("Only finqa task is supported")

        # Create MCP server and define tools inline
        mcp = FastMCP("finqa_env")

        self.data_path = data_path
        self.max_steps = max_steps
        self.task = task

        self.questions = self._load_questions()
        logger.info("Loaded %d questions for task '%s'", len(self.questions), task)

        self._finqa_tools = FinQATools(data_path)

        # Register tools with FastMCP.
        # NOTE(review): the tool docstrings below are presumably surfaced to MCP
        # clients as tool descriptions, so their wording is kept verbatim.
        @mcp.tool
        def get_descriptions(company_name: str) -> str:
            """
            Get a list of available table names for a company.

            Args:
                company_name: The name of the company

            Returns:
                JSON list of table names
            """
            return self._finqa_tools.get_descriptions(company_name)

        @mcp.tool
        def get_table_info(company_name: str, table_name: str) -> str:
            """
            Get table metadata: description, columns, types, unique values.

            Args:
                company_name: The name of the company
                table_name: The name of the table

            Returns:
                JSON string with table metadata
            """
            return self._finqa_tools.get_table_info(company_name, table_name)

        @mcp.tool
        def sql_query(company_name: str, table_name: str, query: str) -> str:
            """
            Execute a SQL query on a table. Select * not allowed.
            Filters are required: WHERE, HAVING, IN, NOT IN, EXISTS, NOT EXISTS, ANY, SOME, ALL, LIKE, NOT LIKE, BETWEEN, NOT BETWEEN, IS NULL, IS NOT NULL, CASE, FILTER.

            Args:
                company_name: The name of the company
                table_name: The name of the table
                query: SQL query to execute (must include filters)

            Returns:
                JSON string with query results
            """
            return self._finqa_tools.sql_query(company_name, table_name, query)

        @mcp.tool
        def submit_answer(answer: str) -> str:
            """
            Submit a final answer for the question.

            Args:
                answer: The final answer to submit

            Returns:
                Confirmation message
            """
            return self._finqa_tools.submit_answer(answer)

        # Pass the MCP server to the base class
        super().__init__(mcp)

        # Shuffle dataset for sequential selection
        self._shuffled_questions = self.questions.copy()
        random.shuffle(self._shuffled_questions)
        self._question_index = 0

        self._state = FinQAState()
        self._history: List[Dict[str, Any]] = []

    @staticmethod
    def _optional_field(row: Dict[str, Any], key: str, default: Any = "") -> Any:
        """Return ``row[key]``, or ``default`` if the column is absent or the cell is empty."""
        value = row.get(key, default)
        # pandas represents empty CSV cells as NaN; without this check an empty
        # id would become the string "nan" and other fields a float NaN.
        if pd.isna(value):
            return default
        return value

    def _load_questions(self) -> List[Dict[str, Any]]:
        """
        Load questions from the benchmark CSV.

        Reads ``<data_path>/benchmark_questions/<task>.csv``. The columns
        ``user_query``, ``company``, ``question`` and ``answer`` are required;
        ``id``, ``question_type`` and ``explanation`` are optional and default
        to an empty string when the column or the cell is missing.

        Returns:
            One dict per CSV row, in file order.

        Raises:
            FileNotFoundError: If the CSV file does not exist.
            KeyError: If a required column is missing (and the file has rows).
        """
        csv_path = os.path.join(
            self.data_path, "benchmark_questions", f"{self.task}.csv"
        )
        if not os.path.isfile(csv_path):
            raise FileNotFoundError(f"Benchmark file not found: {csv_path}")

        df = pd.read_csv(csv_path)
        optional = self._optional_field
        return [
            {
                "id": str(optional(row, "id")),
                "user_query": row["user_query"],
                "company": row["company"],
                "question": row["question"],
                "answer": row["answer"],
                "question_type": optional(row, "question_type"),
                "explanation": optional(row, "explanation"),
            }
            for row in df.to_dict(orient="records")
        ]

    def _get_next_question(self) -> Dict[str, Any]:
        """
        Get the next question using sequential shuffle selection.

        Walks through the shuffled question list; once every question has been
        served, the list is reshuffled and iteration restarts from the beginning.

        Raises:
            IndexError: If no questions were loaded.
        """
        if not self._shuffled_questions:
            # Fail with an actionable message instead of a bare
            # "list index out of range" from the lookup below.
            raise IndexError(
                "No benchmark questions loaded; cannot select a question."
            )

        if self._question_index >= len(self._shuffled_questions):
            random.shuffle(self._shuffled_questions)
            self._question_index = 0

        question = self._shuffled_questions[self._question_index]
        self._question_index += 1
        return question

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> Observation:
        """
        Reset the environment for a new episode.

        Selects the next question, replaces the episode state and clears the
        history.

        Args:
            seed: Accepted for interface compatibility.
                NOTE(review): currently unused - question order comes from the
                module-level ``random`` state; confirm whether it should be honoured.
            episode_id: Identifier for the new episode; a random UUID4 string is
                generated when omitted.
            **kwargs: Ignored.

        Returns:
            Initial observation with the question
        """
        question = self._get_next_question()

        self._state = FinQAState(
            episode_id=episode_id or str(uuid.uuid4()),
            step_count=0,
            current_question=question["user_query"],
            current_company=question["company"],
            ground_truth=question["answer"],
            question_id=question["id"],
        )
        self._history = []

        # str() guards the slice against a non-string cell (e.g. NaN from an
        # empty CSV field), which would otherwise raise TypeError here.
        logger.info(
            "Reset episode %s with question: %s...",
            self._state.episode_id,
            str(question["question"])[:200],
        )

        return Observation(
            done=False,
            reward=0.0,
            metadata={
                "question": question["user_query"],
                "company": question["company"],
                "tool_result": "",
                "history": [],
                "step_count": 0,
                "available_tools": AVAILABLE_TOOLS.copy(),
            },
        )

    def _step_impl(
        self,
        action: Action,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> Observation:
        """
        Handle non-MCP actions.

        Returns an error since this env is MCP-only.
        """
        return Observation(
            done=False,
            reward=0.0,
            metadata={
                "error": f"Unknown action type: {type(action).__name__}. "
                "Use ListToolsAction or CallToolAction for MCP interactions."
            },
        )

    def step(
        self,
        action: Action,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> Observation:
        """
        Execute a step in the environment.

        Delegates to base class for MCP actions.
        Handles submit_answer reward computation and max-step termination.

        The step counter is incremented on every call, whatever the action
        type. A ``submit_answer`` call always ends the episode, even when it
        lands on the final allowed step.

        Args:
            action: The action to execute.
            timeout_s: Optional timeout forwarded to the base class.
            **kwargs: Extra arguments forwarded to the base class.

        Returns:
            The base-class observation, or a terminal observation
            (``done=True``) when an answer was submitted or the step budget
            was exhausted.
        """
        self._state.step_count += 1

        # Let the base class handle MCP actions
        obs = super().step(action, timeout_s=timeout_s, **kwargs)
        base_metadata = obs.metadata or {}

        # Check if submit_answer was called
        if isinstance(action, CallToolAction) and action.tool_name == "submit_answer":
            submitted_answer = (action.arguments or {}).get("answer", "")
            reward = compute_reward(submitted_answer, self._state.ground_truth)

            logger.info(
                "Episode %s ended: submitted='%s', truth='%s', reward=%s",
                self._state.episode_id,
                submitted_answer,
                self._state.ground_truth,
                reward,
            )

            return Observation(
                done=True,
                reward=reward,
                metadata={
                    **base_metadata,
                    "ground_truth": self._state.ground_truth,
                    "submitted_answer": submitted_answer,
                },
            )

        # Check for max steps
        if self._state.step_count >= self.max_steps:
            logger.info(
                "Episode %s terminated: max steps reached", self._state.episode_id
            )
            return Observation(
                done=True,
                reward=0.0,
                metadata={
                    **base_metadata,
                    "error": f"Max steps ({self.max_steps}) reached without submitting answer.",
                },
            )

        return obs

    @property
    def state(self) -> FinQAState:
        """Get the current environment state."""
        return self._state