bhavishya2895 committed on
Commit 7b0eff4 · verified · 1 Parent(s): df8da83

Upload folder using huggingface_hub
Dockerfile ADDED
@@ -0,0 +1,83 @@
+# Multi-stage build using openenv-base
+# This Dockerfile is flexible and works for both:
+#   - In-repo environments (with local src/core)
+#   - Standalone environments (with openenv-core from pip)
+# The build script (openenv build) handles context detection and sets appropriate build args.
+
+ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
+FROM ${BASE_IMAGE} AS builder
+
+WORKDIR /app
+
+# Build argument to control whether we're building standalone or in-repo
+ARG BUILD_MODE=in-repo
+
+# Copy environment code (always at root of build context)
+COPY . /app/env
+
+# For in-repo builds, openenv-core is already in the pyproject.toml dependencies
+# For standalone builds, openenv-core will be installed from pip via pyproject.toml
+WORKDIR /app/env
+
+# Ensure uv is available (for local builds where the base image lacks it)
+RUN if ! command -v uv >/dev/null 2>&1; then \
+        curl -LsSf https://astral.sh/uv/install.sh | sh && \
+        mv /root/.local/bin/uv /usr/local/bin/uv && \
+        mv /root/.local/bin/uvx /usr/local/bin/uvx; \
+    fi
+
+# Install git for building from git repos (build-time only)
+RUN apt-get update && apt-get install -y --no-install-recommends \
+        git \
+    && rm -rf /var/lib/apt/lists/*
+
+# Install dependencies using uv sync
+# First pass: install dependencies without the project (for better caching)
+# Second pass: install the project itself
+RUN --mount=type=cache,target=/root/.cache/uv \
+    if [ -f uv.lock ]; then \
+        uv sync --frozen --no-install-project --no-editable; \
+    else \
+        uv sync --no-install-project --no-editable; \
+    fi
+
+RUN --mount=type=cache,target=/root/.cache/uv \
+    if [ -f uv.lock ]; then \
+        uv sync --frozen --no-editable; \
+    else \
+        uv sync --no-editable; \
+    fi
+
+# Final runtime stage
+FROM ${BASE_IMAGE}
+
+WORKDIR /app
+
+# Copy the virtual environment from builder
+COPY --from=builder /app/env/.venv /app/.venv
+
+# Copy the environment code
+COPY --from=builder /app/env /app/env
+
+# Set PATH to use the virtual environment
+ENV PATH="/app/.venv/bin:$PATH"
+
+# Set PYTHONPATH so imports work correctly
+ENV PYTHONPATH="/app/env:$PYTHONPATH"
+
+# Environment variables with defaults
+ENV FINQA_DATA_PATH="/app/env/data"
+ENV FINQA_MAX_STEPS="50"
+ENV FINQA_TASK="finqa"
+
+# Download data from HuggingFace at build time (requires network)
+RUN pip install --no-cache-dir huggingface_hub[cli] && \
+    bash /app/env/download_data.sh snorkelai/finqa-data /app/env/data
+
+# Health check using Python (more portable than curl/wget)
+HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
+    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
+
+# Run the FastAPI server
+ENV ENABLE_WEB_INTERFACE=true
+CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
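The HEALTHCHECK above probes the server with Python's stdlib rather than curl or wget; the same probe can be run by hand against a container. This is a minimal sketch (the function name is illustrative; the URL and timeout are the Dockerfile defaults):

```python
import urllib.request


def is_healthy(url: str = "http://localhost:8000/health", timeout: float = 3.0) -> bool:
    """Mirror the Dockerfile HEALTHCHECK: any error or non-200 means unhealthy."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return resp.status == 200
    except Exception:
        return False
```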
README.md CHANGED
@@ -1,10 +1,199 @@
 ---
-title: Finqa Env
-emoji: 📊
-colorFrom: pink
-colorTo: purple
+title: FinQA Environment Server
+emoji: 🔊
+colorFrom: blue
+colorTo: gray
 sdk: docker
 pinned: false
+app_port: 8000
+base_path: /web
+datasets:
+  - snorkelai/finqa-data
+tags:
+  - openenv
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# FinQA Environment
+
+A financial question-answering environment for RL training. It evaluates LLMs on their ability to answer complex financial questions using tool calls on SEC 10-K filing data.
+
+Based on [FinQABenchmark](https://github.com/snorkel-ai/FinQABenchmark) from Snorkel AI.
+
+## Overview
+
+FinQA tests an agent's ability to:
+- Explore available financial tables for a company
+- Query table metadata and execute SQL queries
+- Perform calculations on extracted data
+- Submit final answers to financial questions
+
+**Dataset**: 290 questions from SEC 10-K filings across multiple companies (Alphabet, Amazon, Apple, AT&T, etc.)
+
+**Reward**: Binary (1.0 for a correct answer, 0.0 for an incorrect one), using fuzzy numerical matching with 1% tolerance.
+
+> **Note**: This dataset is for evaluation only. Do not train on it.
+
+## Quick Start
+
+### Using Docker
+
+```bash
+# Build the image (from the OpenEnv repo root)
+docker build -t finqa-env:latest -f envs/finqa_env/server/Dockerfile .
+
+# Run the server
+docker run -p 8000:8000 finqa-env:latest
+
+# Run the evaluation script (example model: gpt-5)
+API_BASE_URL=https://api.openai.com/v1 API_KEY=$OPENAI_API_KEY MODEL=gpt-5 python examples/finqa_inference.py
+```
+
+### Local Development
+
+```bash
+# Install dependencies
+uv pip install pandas
+
+# Download data from HuggingFace
+cd envs/finqa_env
+./download_data.sh snorkelai/finqa-data ./data
+```
+
+### Using the Client
+
+The client uses the MCP protocol and is async by default:
+
+```python
+import asyncio
+from envs.finqa_env import FinQAEnv, CallToolAction
+
+async def main():
+    async with FinQAEnv(base_url="http://localhost:8000") as env:
+        # Reset to get a question
+        obs = await env.reset()
+        question = obs.metadata["question"]
+        company = obs.metadata["company"]
+        print(f"Question: {question}")
+        print(f"Company: {company}")
+
+        # Discover available tools
+        tools = await env.list_tools()
+        print([t.name for t in tools])
+
+        # Use tools via call_tool (convenience method)
+        result = await env.call_tool("get_descriptions", company_name=company)
+        print(f"Available tables: {result}")
+
+        # Or use step() with CallToolAction for full observation access
+        step_result = await env.step(CallToolAction(
+            tool_name="sql_query",
+            arguments={
+                "company_name": "alphabet",
+                "table_name": "us_gaap_ScheduleOfIncomeBeforeIncomeTaxDomesticAndForeignTableTextBlock",
+                "query": "SELECT * FROM data WHERE year = '2022'",
+            },
+        ))
+        print(f"Done: {step_result.done}, Reward: {step_result.reward}")
+
+        # Submit answer
+        result = await env.call_tool("submit_answer", answer="6.118")
+
+asyncio.run(main())
+```
+
+## Available Tools
+
+Tools are auto-discovered via MCP. Use `await env.list_tools()` to see all available tools at runtime.
+
+| Tool | Description | Arguments |
+|------|-------------|-----------|
+| `get_descriptions` | Get list of available table names for a company | `company_name: str` |
+| `get_table_info` | Get table metadata (columns, dtypes, unique values) | `company_name: str, table_name: str` |
+| `sql_query` | Execute SQL query on a table (requires filters) | `company_name: str, table_name: str, query: str` |
+| `submit_answer` | Submit final answer (ends episode) | `answer: str` |
+
+### Tool Constraints
+
+- **sql_query**: Must include filters (`WHERE`, `HAVING`, etc.). `SELECT *` is not allowed.
+
+## Environment Variables
+
+| Variable | Default | Description |
+|----------|---------|-------------|
+| `FINQA_DATA_PATH` | `/app/env/data` | Path to data directory |
+| `FINQA_MAX_STEPS` | `50` | Maximum tool calls per episode |
+| `FINQA_TASK` | `finqa` | Task name |
+
+## Reward Computation
+
+Rewards use fuzzy numerical matching:
+
+- Extracts numbers from `\boxed{...}` format
+- Handles percentages, fractions, and decimals
+- 1% relative tolerance or 0.01 absolute tolerance
+- Returns `1.0` for correct, `0.0` for incorrect
+
+## Local Development
+
+```bash
+# From the OpenEnv repo root
+cd envs/finqa_env
+
+# Run the server locally
+FINQA_DATA_PATH=./data uvicorn server.app:app --reload --port 8000
+
+# Test with curl
+curl http://localhost:8000/health
+curl -X POST http://localhost:8000/reset
+```
+
+## Integration with RL Frameworks
+
+### TRL (GRPO)
+
+```python
+import asyncio
+from trl import GRPOTrainer
+from envs.finqa_env import FinQAEnv
+
+async def rollout_func(prompts, trainer):
+    async with FinQAEnv(base_url="http://localhost:8000") as env:
+        obs = await env.reset()
+        # Your agent logic here, using await env.call_tool(...)
+        # to explore tables and produce `completion` and a final observation
+        return {"reward": obs.reward, "completion": completion}
+
+trainer = GRPOTrainer(
+    model=model,
+    rollout_func=rollout_func,
+    ...
+)
+```
+
+## Project Structure
+
+```
+finqa_env/
+├── __init__.py            # Exports FinQAEnv, CallToolAction, ListToolsAction
+├── models.py              # FinQAState and tool name constants
+├── client.py              # MCP client (subclasses MCPToolClient)
+├── pyproject.toml         # Dependencies
+├── README.md              # This file
+├── data/                  # Benchmark data (run download_data.sh)
+│   ├── benchmark_questions/
+│   │   └── finqa.csv
+│   └── input_companies/
+│       └── [company folders]
+├── download_data.sh       # Downloads data from HuggingFace
+└── server/
+    ├── __init__.py
+    ├── finqa_environment.py  # MCPEnvironment subclass with @mcp.tool decorators
+    ├── tools.py              # Tool implementations
+    ├── rewards.py            # Reward computation
+    ├── app.py                # FastAPI server
+    └── Dockerfile
+```
+
+## References
+
+- [HuggingFace Dataset](https://huggingface.co/datasets/snorkelai/agent-finance-reasoning)
+- [Leaderboard](https://leaderboard.snorkel.ai/category/snorkelfinance)
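The README's reward rule ("1% relative tolerance or 0.01 absolute tolerance") maps directly onto `math.isclose`. A minimal sketch of the matching behavior, assuming both answers have already been parsed to floats (the function name is illustrative, not part of the environment API):

```python
import math


def answers_match(pred: float, truth: float) -> bool:
    # 1% relative tolerance OR 0.01 absolute tolerance, per the README
    return math.isclose(pred, truth, rel_tol=0.01, abs_tol=0.01)


assert answers_match(6.118, 6.12)       # within 1% relative tolerance
assert not answers_match(100.0, 102.0)  # off by 2%, rejected
```

The absolute tolerance matters for answers near zero, where a pure relative check would reject even tiny discrepancies.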
__init__.py ADDED
@@ -0,0 +1,24 @@
+# envs/finqa_env/__init__.py
+"""
+FinQA Environment for OpenEnv.
+
+A financial question-answering environment that evaluates LLMs on their ability
+to answer complex financial questions using tool calls on SEC 10-K filing data.
+
+Example:
+    >>> from envs.finqa_env import FinQAEnv
+    >>>
+    >>> async with FinQAEnv(base_url="http://localhost:8000") as env:
+    ...     await env.reset()
+    ...     tools = await env.list_tools()
+    ...     result = await env.call_tool("get_descriptions", company_name="alphabet")
+    ...     result = await env.call_tool("submit_answer", answer="6.118")
+"""
+
+from .client import FinQAEnv
+from .models import FinQAState
+
+# Re-export MCP types for convenience
+from openenv.core.env_server.mcp_types import CallToolAction, ListToolsAction
+
+__all__ = ["FinQAEnv", "FinQAState", "CallToolAction", "ListToolsAction"]
client.py ADDED
@@ -0,0 +1,33 @@
+# envs/finqa_env/client.py
+"""
+Client for the FinQA environment.
+
+This client connects to a running FinQA environment server and provides
+a Python interface for interacting with it via MCP tools. Async by default.
+
+Example:
+    >>> from envs.finqa_env import FinQAEnv
+    >>>
+    >>> async with FinQAEnv(base_url="http://localhost:8000") as env:
+    ...     await env.reset()
+    ...     tools = await env.list_tools()
+    ...     result = await env.call_tool("get_descriptions", company_name="alphabet")
+    ...     print(result)
+    ...     result = await env.call_tool("submit_answer", answer="6.118")
+"""
+
+from openenv.core.mcp_client import MCPToolClient
+
+
+class FinQAEnv(MCPToolClient):
+    """
+    Client for the FinQA environment.
+
+    Inherits all functionality from MCPToolClient:
+    - list_tools(): Discover available tools
+    - call_tool(name, **kwargs): Call a tool by name
+    - reset(**kwargs): Reset the environment
+    - step(action): Execute an action
+    """
+
+    pass  # MCPToolClient provides all needed functionality
download_data.sh ADDED
@@ -0,0 +1,71 @@
+#!/bin/bash
+# Download FinQA data from HuggingFace
+#
+# This script downloads all FinQA data from HuggingFace:
+#   1. Benchmark questions CSV
+#   2. Company financial documents (preprocessed SEC 10-K filings)
+#
+# Usage:
+#   ./download_data.sh <hf_repo_or_url> [output_dir]
+
+set -e
+
+HF_REPO_OR_URL="${1}"
+OUTPUT_DIR="${2:-./data}"
+
+if [ -z "$HF_REPO_OR_URL" ]; then
+    echo "Usage: $0 <hf_repo_or_url> [output_dir]"
+    echo "Example: $0 snorkelai/finqa-data ./data"
+    exit 1
+fi
+
+echo "========================================"
+echo "FinQA Data Download"
+echo "========================================"
+echo "Output directory: $OUTPUT_DIR"
+echo ""
+
+# Create output directory
+mkdir -p "$OUTPUT_DIR"
+
+# Check if data already exists
+if [ -f "$OUTPUT_DIR/benchmark_questions/finqa.csv" ] && [ -d "$OUTPUT_DIR/input_companies" ]; then
+    echo "Data already exists in $OUTPUT_DIR, skipping download."
+    exit 0
+fi
+
+# Check for huggingface-cli
+if ! command -v huggingface-cli &> /dev/null; then
+    echo "Error: huggingface-cli not found"
+    echo "Install it with: uv pip install huggingface_hub[cli]"
+    exit 1
+fi
+
+# Download from HuggingFace
+echo "Downloading from HuggingFace: $HF_REPO_OR_URL"
+if ! huggingface-cli download "$HF_REPO_OR_URL" --repo-type dataset --local-dir "$OUTPUT_DIR"; then
+    echo "Error: Failed to download dataset"
+    exit 1
+fi
+
+# Verify downloaded data
+if [ ! -f "$OUTPUT_DIR/benchmark_questions/finqa.csv" ]; then
+    echo "Error: benchmark_questions/finqa.csv not found in downloaded data"
+    exit 1
+fi
+
+if [ ! -d "$OUTPUT_DIR/input_companies" ]; then
+    echo "Error: input_companies/ directory not found in downloaded data"
+    exit 1
+fi
+
+echo ""
+echo "========================================"
+echo "Download complete!"
+echo "========================================"
+echo "Data location: $OUTPUT_DIR"
+echo ""
+
+# Export data path
+export FINQA_DATA_PATH="$OUTPUT_DIR"
+echo "Exported: FINQA_DATA_PATH=$FINQA_DATA_PATH"
models.py ADDED
@@ -0,0 +1,38 @@
+# envs/finqa_env/models.py
+"""
+State types for the FinQA environment.
+
+FinQA is a financial question-answering benchmark that evaluates LLMs on their
+ability to answer complex financial questions using tool calls (SQL queries,
+calculations, etc.) on SEC 10-K filing data.
+
+This environment uses the MCP protocol for tool interactions. Use
+``CallToolAction`` and ``ListToolsAction`` from ``openenv.core.env_server.mcp_types``
+to interact with the environment.
+"""
+
+from openenv.core.env_server import State
+
+
+# Tool names - defined statically to avoid circular imports
+AVAILABLE_TOOLS = ["get_descriptions", "get_table_info", "sql_query", "submit_answer"]
+
+
+class FinQAState(State):
+    """
+    Internal environment state for tracking the current episode.
+
+    All fields are set during reset() and are essential for episode tracking.
+
+    Attributes:
+        current_question: The question being asked
+        current_company: The company the question is about
+        ground_truth: The expected answer for reward computation
+        question_id: Identifier for the current question
+        (episode_id and step_count are inherited from State)
+    """
+
+    current_question: str = ""
+    current_company: str = ""
+    ground_truth: str = ""
+    question_id: str = ""
openenv.yaml ADDED
@@ -0,0 +1,6 @@
+spec_version: 1
+name: finqa_env
+type: space
+runtime: fastapi
+app: server.app:app
+port: 8000
openenv_finqa_env.egg-info/PKG-INFO ADDED
@@ -0,0 +1,15 @@
+Metadata-Version: 2.4
+Name: openenv-finqa-env
+Version: 0.1.0
+Summary: FinQA Environment for OpenEnv - financial question-answering on SEC 10-K filing data
+Requires-Python: >=3.10
+Requires-Dist: openenv-core[core]>=0.2.1
+Requires-Dist: fastapi>=0.115.0
+Requires-Dist: fastmcp>=2.0.0
+Requires-Dist: pydantic>=2.0.0
+Requires-Dist: uvicorn>=0.24.0
+Requires-Dist: requests>=2.31.0
+Requires-Dist: pandas>=2.0.0
+Provides-Extra: dev
+Requires-Dist: pytest>=8.0.0; extra == "dev"
+Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
openenv_finqa_env.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,16 @@
+README.md
+pyproject.toml
+./__init__.py
+./client.py
+./models.py
+openenv_finqa_env.egg-info/PKG-INFO
+openenv_finqa_env.egg-info/SOURCES.txt
+openenv_finqa_env.egg-info/dependency_links.txt
+openenv_finqa_env.egg-info/entry_points.txt
+openenv_finqa_env.egg-info/requires.txt
+openenv_finqa_env.egg-info/top_level.txt
+server/__init__.py
+server/app.py
+server/finqa_environment.py
+server/rewards.py
+server/tools.py
openenv_finqa_env.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
openenv_finqa_env.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
+[console_scripts]
+server = finqa_env.server.app:main
openenv_finqa_env.egg-info/requires.txt ADDED
@@ -0,0 +1,11 @@
+openenv-core[core]>=0.2.1
+fastapi>=0.115.0
+fastmcp>=2.0.0
+pydantic>=2.0.0
+uvicorn>=0.24.0
+requests>=2.31.0
+pandas>=2.0.0
+
+[dev]
+pytest>=8.0.0
+pytest-cov>=4.0.0
openenv_finqa_env.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
+finqa_env
pyproject.toml ADDED
@@ -0,0 +1,34 @@
+[build-system]
+requires = ["setuptools>=45", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "openenv-finqa-env"
+version = "0.1.0"
+description = "FinQA Environment for OpenEnv - financial question-answering on SEC 10-K filing data"
+requires-python = ">=3.10"
+dependencies = [
+    # Core OpenEnv dependencies (required for server functionality)
+    "openenv-core[core]>=0.2.1",
+    "fastapi>=0.115.0",
+    "fastmcp>=2.0.0",
+    "pydantic>=2.0.0",
+    "uvicorn>=0.24.0",
+    "requests>=2.31.0",
+    # FinQA environment specific dependencies
+    "pandas>=2.0.0",
+]
+
+[project.optional-dependencies]
+dev = [
+    "pytest>=8.0.0",
+    "pytest-cov>=4.0.0",
+]
+
+[project.scripts]
+server = "finqa_env.server.app:main"
+
+[tool.setuptools]
+packages = ["finqa_env", "finqa_env.server"]
+package-dir = { "finqa_env" = ".", "finqa_env.server" = "server" }
server/__init__.py ADDED
@@ -0,0 +1,13 @@
+# envs/finqa_env/server/__init__.py
+"""Server-side components for the FinQA environment."""
+
+
+def __getattr__(name):
+    if name == "FinQAEnvironment":
+        from .finqa_environment import FinQAEnvironment
+
+        return FinQAEnvironment
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+__all__ = ["FinQAEnvironment"]
server/app.py ADDED
@@ -0,0 +1,35 @@
+# envs/finqa_env/server/app.py
+"""
+FastAPI server for the FinQA environment.
+
+Environment Variables:
+    FINQA_DATA_PATH: Path to data directory (default: /app/env/data)
+    FINQA_MAX_STEPS: Maximum tool calls per episode (default: 50)
+    FINQA_TASK: Task name (default: finqa)
+"""
+
+import os
+
+from openenv.core.env_server.http_server import create_app
+from openenv.core.env_server.mcp_types import CallToolAction, CallToolObservation
+from .finqa_environment import FinQAEnvironment
+
+DATA_PATH = os.environ.get("FINQA_DATA_PATH", "/app/env/data")
+MAX_STEPS = int(os.environ.get("FINQA_MAX_STEPS", "50"))
+TASK = os.environ.get("FINQA_TASK", "finqa")
+
+
+def _env_factory():
+    """Create a new FinQAEnvironment instance for each session."""
+    return FinQAEnvironment(
+        data_path=DATA_PATH,
+        max_steps=MAX_STEPS,
+        task=TASK,
+    )
+
+
+# Pass the factory instead of an instance for WebSocket session support.
+# Use MCP types for action/observation since this is a pure MCP environment.
+app = create_app(
+    _env_factory, CallToolAction, CallToolObservation, env_name="finqa_env"
+)
server/finqa_environment.py ADDED
@@ -0,0 +1,277 @@
+# envs/finqa_env/server/finqa_environment.py
+"""
+FinQA Environment Implementation.
+
+A financial question-answering environment that evaluates LLMs on their ability
+to answer complex financial questions using tool calls on SEC 10-K filing data.
+"""
+
+import logging
+import os
+import random
+import uuid
+from typing import Any, Dict, List, Optional
+
+import pandas as pd
+from fastmcp import FastMCP
+
+from openenv.core.env_server.mcp_environment import MCPEnvironment
+from openenv.core.env_server.mcp_types import CallToolAction
+from openenv.core.env_server.types import Action, Observation
+
+from ..models import FinQAState, AVAILABLE_TOOLS
+from .rewards import compute_reward
+from .tools import FinQATools
+
+logger = logging.getLogger(__name__)
+
+
+class FinQAEnvironment(MCPEnvironment):
+    """
+    Financial QA environment for RL training.
+
+    Evaluates agents on their ability to answer financial questions by:
+    - Exploring available tables for a company
+    - Querying table metadata and executing SQL queries
+    - Performing calculations
+    - Submitting final answers
+
+    Args:
+        data_path: Path to the data directory containing benchmark_questions/ and input_companies/
+        max_steps: Maximum number of tool calls per episode (default: 50)
+        task: Task name - currently only 'finqa' is supported (default: 'finqa')
+    """
+
+    def __init__(
+        self,
+        data_path: str = "./data",
+        max_steps: int = 50,
+        task: str = "finqa",
+    ):
+        # Create MCP server and define tools inline
+        mcp = FastMCP("finqa_env")
+
+        self.data_path = data_path
+        self.max_steps = max_steps
+        self.task = task
+
+        assert task == "finqa", "Only finqa task is supported"
+
+        self.questions = self._load_questions()
+        logger.info(f"Loaded {len(self.questions)} questions for task '{task}'")
+
+        self._finqa_tools = FinQATools(data_path)
+
+        # Register tools with FastMCP
+        @mcp.tool
+        def get_descriptions(company_name: str) -> str:
+            """
+            Get a list of available table names for a company.
+
+            Args:
+                company_name: The name of the company
+
+            Returns:
+                JSON list of table names
+            """
+            return self._finqa_tools.get_descriptions(company_name)
+
+        @mcp.tool
+        def get_table_info(company_name: str, table_name: str) -> str:
+            """
+            Get table metadata: description, columns, types, unique values.
+
+            Args:
+                company_name: The name of the company
+                table_name: The name of the table
+
+            Returns:
+                JSON string with table metadata
+            """
+            return self._finqa_tools.get_table_info(company_name, table_name)
+
+        @mcp.tool
+        def sql_query(company_name: str, table_name: str, query: str) -> str:
+            """
+            Execute a SQL query on a table. `SELECT *` is not allowed.
+
+            Filters are required: WHERE, HAVING, IN, NOT IN, EXISTS, NOT EXISTS,
+            ANY, SOME, ALL, LIKE, NOT LIKE, BETWEEN, NOT BETWEEN, IS NULL,
+            IS NOT NULL, CASE, FILTER.
+
+            Args:
+                company_name: The name of the company
+                table_name: The name of the table
+                query: SQL query to execute (must include filters)
+
+            Returns:
+                JSON string with query results
+            """
+            return self._finqa_tools.sql_query(company_name, table_name, query)
+
+        @mcp.tool
+        def submit_answer(answer: str) -> str:
+            """
+            Submit a final answer for the question.
+
+            Args:
+                answer: The final answer to submit
+
+            Returns:
+                Confirmation message
+            """
+            return self._finqa_tools.submit_answer(answer)
+
+        # Pass the MCP server to the base class
+        super().__init__(mcp)
+
+        # Shuffle dataset for sequential selection
+        self._shuffled_questions = self.questions.copy()
+        random.shuffle(self._shuffled_questions)
+        self._question_index = 0
+
+        self._state = FinQAState()
+        self._history: List[Dict[str, Any]] = []
+
+    def _load_questions(self) -> List[Dict[str, Any]]:
+        """Load questions from the benchmark CSV."""
+        csv_path = os.path.join(self.data_path, "benchmark_questions", f"{self.task}.csv")
+
+        if not os.path.isfile(csv_path):
+            raise FileNotFoundError(f"Benchmark file not found: {csv_path}")
+
+        df = pd.read_csv(csv_path)
+
+        questions = []
+        for _, row in df.iterrows():
+            questions.append({
+                "id": str(row.get("id", "")),
+                "user_query": row["user_query"],
+                "company": row["company"],
+                "question": row["question"],
+                "answer": row["answer"],
+                "question_type": row.get("question_type", ""),
+                "explanation": row.get("explanation", ""),
+            })
+
+        return questions
+
+    def _get_next_question(self) -> Dict[str, Any]:
+        """Get the next question using sequential shuffle selection."""
+        if self._question_index >= len(self._shuffled_questions):
+            random.shuffle(self._shuffled_questions)
+            self._question_index = 0
+
+        question = self._shuffled_questions[self._question_index]
+        self._question_index += 1
+        return question
+
+    def reset(
+        self,
+        seed: Optional[int] = None,
+        episode_id: Optional[str] = None,
+        **kwargs: Any,
+    ) -> Observation:
+        """
+        Reset the environment for a new episode.
+
+        Returns:
+            Initial observation with the question
+        """
+        question = self._get_next_question()
+        self._state = FinQAState(
+            episode_id=episode_id or str(uuid.uuid4()),
+            step_count=0,
+            current_question=question["user_query"],
+            current_company=question["company"],
+            ground_truth=question["answer"],
+            question_id=question["id"],
+        )
+        self._history = []
+
+        logger.info(f"Reset episode {self._state.episode_id} with question: {question['question'][:200]}...")
+
+        return Observation(
+            done=False,
+            reward=0.0,
+            metadata={
+                "question": question["user_query"],
+                "company": question["company"],
+                "tool_result": "",
+                "history": [],
+                "step_count": 0,
+                "available_tools": AVAILABLE_TOOLS.copy(),
+            },
+        )
+
+    def _step_impl(
+        self,
+        action: Action,
+        timeout_s: Optional[float] = None,
+        **kwargs: Any,
+    ) -> Observation:
+        """Handle non-MCP actions. Returns an error since this env is MCP-only."""
+        return Observation(
+            done=False,
+            reward=0.0,
+            metadata={
+                "error": f"Unknown action type: {type(action).__name__}. "
+                "Use ListToolsAction or CallToolAction for MCP interactions."
+            },
+        )
+
+    def step(
+        self,
+        action: Action,
+        timeout_s: Optional[float] = None,
+        **kwargs: Any,
+    ) -> Observation:
+        """
+        Execute a step in the environment.
+
+        Delegates to the base class for MCP actions. Handles submit_answer
+        reward computation and max-step termination.
+        """
+        self._state.step_count += 1
+
+        # Let the base class handle MCP actions
+        obs = super().step(action, timeout_s=timeout_s, **kwargs)
+
+        # Check if submit_answer was called
+        if isinstance(action, CallToolAction) and action.tool_name == "submit_answer":
+            submitted_answer = action.arguments.get("answer", "")
+            reward = compute_reward(submitted_answer, self._state.ground_truth)
+            logger.info(
+                f"Episode {self._state.episode_id} ended: "
+                f"submitted='{submitted_answer}', truth='{self._state.ground_truth}', reward={reward}"
+            )
+            return Observation(
+                done=True,
+                reward=reward,
+                metadata={
+                    **obs.metadata,
+                    "ground_truth": self._state.ground_truth,
+                    "submitted_answer": submitted_answer,
+                },
+            )
+
+        # Check for max steps
+        if self._state.step_count >= self.max_steps:
+            logger.info(f"Episode {self._state.episode_id} terminated: max steps reached")
+            return Observation(
+                done=True,
+                reward=0.0,
+                metadata={
+                    **obs.metadata,
+                    "error": f"Max steps ({self.max_steps}) reached without submitting answer.",
+                },
+            )
+
+        return obs
+
+    @property
+    def state(self) -> FinQAState:
+        """Get the current environment state."""
+        return self._state
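The `_get_next_question` pattern above (shuffle once, walk sequentially, reshuffle when the list is exhausted) guarantees every question appears exactly once per pass before any repeats. A standalone sketch of the same idea (the class name is illustrative; the environment itself uses the module-level `random` state rather than a seeded generator):

```python
import random


class ShuffledCycler:
    """Yield items in shuffled order, reshuffling after each full pass."""

    def __init__(self, items, seed=None):
        self._items = list(items)
        self._rng = random.Random(seed)
        self._rng.shuffle(self._items)
        self._index = 0

    def next(self):
        # Reshuffle once the current pass is exhausted, as in _get_next_question
        if self._index >= len(self._items):
            self._rng.shuffle(self._items)
            self._index = 0
        item = self._items[self._index]
        self._index += 1
        return item
```

Compared with independent random sampling per reset, this keeps question coverage uniform across an evaluation run.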
server/rewards.py ADDED
@@ -0,0 +1,282 @@
+# envs/finqa_env/server/rewards.py
+r"""
+Reward computation for the FinQA environment.
+
+Uses fuzzy numerical matching to compare predicted answers against ground truth.
+Handles various formats: \boxed{}, percentages, fractions, decimals.
+"""
+
+import re
+from fractions import Fraction
+from typing import Optional, Tuple
+
+
+def extract_boxed_answer(text: str) -> Optional[str]:
+    r"""
+    Extract answer from \boxed{...} format.
+
+    Args:
+        text: Text potentially containing \boxed{answer}
+
+    Returns:
+        The extracted answer or None if not found
+    """
+    match = re.search(r"\\boxed\{([^}]+)\}", text)
+    if match:
+        return match.group(1).strip()
+    return None
+
+
+def extract_all_boxed_answers(text: str) -> list:
+    r"""
+    Extract all answers from \boxed{...} format.
+
+    Args:
+        text: Text potentially containing multiple \boxed{answer}
+
+    Returns:
+        List of extracted answers
+    """
+    matches = re.findall(r"\\boxed\{([^}]+)\}", text)
+    return [m.strip() for m in matches]
+
+
+def parse_number(text: str, convert_percent: bool = True) -> Optional[float]:
+    """
+    Parse a string into a float, handling various formats.
+
+    Handles:
+    - Plain numbers: "6.118", "-3.14"
+    - Percentages: "20.9%", "20.9 %"
+    - Fractions: "1/2", "3/4"
+    - Thousands separators: "1,234.56"
+    - Negative numbers in parens: "(100)"
+
+    Args:
+        text: String to parse
+        convert_percent: If True, divide percentages by 100. If False, just strip the % sign.
+
+    Returns:
+        Float value or None if parsing fails
+    """
+    if text is None:
+        return None
+
+    text = text.strip()
+
+    if not text:
+        return None
+
+    try:
+        # Remove LaTeX annotations like \text{million}, \text{%}, etc.
+        text = re.sub(r"\\text\{[^}]*\}", "", text)
+
+        # Remove currency symbols ($ and \$)
+        text = text.replace("\\$", "").replace("$", "").strip()
+
+        # Handle percentage (including LaTeX escaped \%)
+        if "%" in text or "\\%" in text:
+            text = text.replace("\\%", "").replace("%", "").strip()
+            if convert_percent:
+                return float(text.replace(",", "")) / 100
+            else:
+                return float(text.replace(",", ""))
+
+        # Handle parentheses for negative numbers
+        if text.startswith("(") and text.endswith(")"):
+            text = "-" + text[1:-1]
+
+        # Handle fractions (e.g., "1/2", "3/4")
+        if "/" in text and not text.startswith("-"):
+            try:
+                return float(Fraction(text))
+            except (ValueError, ZeroDivisionError):
+                pass
+
+        # Handle negative fractions
+        if text.startswith("-") and "/" in text:
+            try:
+                return -float(Fraction(text[1:]))
+            except (ValueError, ZeroDivisionError):
+                pass
+
+        # Remove thousands separators and parse
+        text = text.replace(",", "")
+        return float(text)
+
+    except (ValueError, TypeError):
+        return None
+
+
+def normalize_answer(answer: str, convert_percent: bool = True) -> Tuple[Optional[float], str]:
+    """
+    Normalize an answer string to a comparable format.
+
+    Args:
+        answer: Raw answer string
+        convert_percent: If True, divide percentages by 100. If False, just strip the % sign.
+
+    Returns:
+        Tuple of (parsed_number, cleaned_string)
+    """
+    if answer is None:
+        return None, ""
+
+    # Try to extract from \boxed{} first
+    boxed = extract_boxed_answer(answer)
+    if boxed:
+        answer = boxed
+
+    # Clean up whitespace
+    answer = answer.strip()
+
+    # Try to parse as number
+    num = parse_number(answer, convert_percent)
+
+    return num, answer.lower()
+
+
+def extract_numbers_from_multi_value(text: str) -> list:
+    """
+    Extract all numbers from a comma/semicolon separated string.
+    Handles formats like "2022: 0.933, 2023: 0.930" or "0.933, 0.931, 0.930".
+    """
+    parts = _split_multi_value(text)
+    return [num for _, num in parts]
+
+
+def _split_multi_value(text: str) -> list:
+    """
+    Extract (key, number) pairs from a comma/semicolon separated string.
+
+    Returns list of (key, float) tuples. Key is a year string like "2022"
+    if found, otherwise None.
+    """
+    # Split by comma or semicolon (with optional LaTeX spacing like \; or \ )
+    parts = re.split(r'[,;]\s*|\\[;,]\s*', text)
+    results = []
+    for part in parts:
+        # Strip LaTeX whitespace commands (\ , \;, \,)
+        part = re.sub(r'\\[;, ]', ' ', part).strip()
+        if not part:
+            continue
+        # Try to extract a year label (e.g. "2022:", "2022 to 2023:", "2022→2023:")
+        # Normalize \rightarrow and similar to "to" before matching
+        part_normalized = re.sub(r'\\rightarrow|→|->|−>', ' to ', part)
+        year_match = re.search(r'(20\d{2}(?:\s*to\s*20\d{2})?)', part_normalized)
+        key = year_match.group(1) if year_match else None
+        # Remove label prefix like "2022:" or "2022:\"
+        cleaned = re.sub(r'^[^:]*:\s*\\?\s*', '', part)
+        num = parse_number(cleaned)
+        if num is not None:
+            results.append((key, num))
+    return results
+
+
+def compare_single_values(pred_num: Optional[float], truth_num: Optional[float],
+                          pred_str: str, truth_str: str,
+                          tolerance: float = 0.01, max_absolute_diff: float = 1.0) -> bool:
+    """Compare two single values."""
+    # If both are numbers, compare numerically with tolerance
+    if pred_num is not None and truth_num is not None:
+        # Handle zero case
+        if truth_num == 0:
+            return abs(pred_num) < 0.001
+
+        # Calculate both errors
+        abs_diff = abs(pred_num - truth_num)
+        relative_error = abs_diff / abs(truth_num)
+
+        # BOTH conditions must pass
+        return relative_error <= tolerance and abs_diff <= max_absolute_diff
+
+    # If one is a number and the other isn't, not equal
+    if (pred_num is None) != (truth_num is None):
+        return False
+
+    # Fall back to string comparison
+    return pred_str == truth_str
+
+
+def compute_reward(predicted: str, ground_truth: str, tolerance: float = 0.01, max_absolute_diff: float = 1.0) -> float:
+    r"""
+    Compute reward based on answer correctness.
+
+    Uses fuzzy numerical matching with BOTH relative and absolute tolerance checks.
+    A prediction is correct only if it passes BOTH conditions.
+
+    Handles multiple values (e.g., ground truth with multiple \boxed{} values).
+
+    Args:
+        predicted: The predicted answer from the agent
+        ground_truth: The expected correct answer
+        tolerance: Relative tolerance for numerical comparison (default 1%)
+        max_absolute_diff: Maximum absolute difference allowed (default 1.0)
+
+    Returns:
+        1.0 if correct, 0.0 if incorrect
+    """
+    # Check for multiple boxed answers in ground truth
+    truth_boxed = extract_all_boxed_answers(ground_truth)
+
+    if len(truth_boxed) > 1:
+        # Multiple ground truth values - split prediction by comma/semicolon
+        pred_values = re.split(r'[,;]\s*', predicted.strip())
+
+        if len(pred_values) != len(truth_boxed):
+            return 0.0  # Different number of values
+
+        # Compare each pair
+        for pred_val, truth_val in zip(pred_values, truth_boxed):
+            # Strip year/label prefix (e.g. "2024: -4" -> "-4")
+            pred_val_cleaned = re.sub(r'^[^:]*:\s*', '', pred_val) if ':' in pred_val else pred_val
+            pred_num, pred_str = normalize_answer(pred_val_cleaned)
+            truth_num, truth_str = normalize_answer(truth_val)
+
+            if not compare_single_values(pred_num, truth_num, pred_str, truth_str, tolerance, max_absolute_diff):
+                # Fallback: try without % conversion (for percentage points like "4.5%" vs "4.5")
+                pred_num_no_pct, _ = normalize_answer(pred_val, convert_percent=False)
+                if not compare_single_values(pred_num_no_pct, truth_num, pred_str, truth_str, tolerance, max_absolute_diff):
+                    return 0.0
+
+        return 1.0  # All values matched
+
+    # Single value comparison
+    pred_num, pred_str = normalize_answer(predicted)
+    truth_num, truth_str = normalize_answer(ground_truth)
+
+    if compare_single_values(pred_num, truth_num, pred_str, truth_str, tolerance, max_absolute_diff):
+        return 1.0
+
+    pred_num_no_pct, _ = normalize_answer(predicted, convert_percent=False)
+    if compare_single_values(pred_num_no_pct, truth_num, pred_str, truth_str, tolerance, max_absolute_diff):
+        return 1.0
+
+    # Fallback: multi-value inside single \boxed{} (only if truth didn't parse as single number)
+    if len(truth_boxed) == 1 and truth_num is None:
+        truth_pairs = _split_multi_value(truth_boxed[0])
+        pred_pairs = _split_multi_value(predicted)
+        if len(truth_pairs) > 1 and len(pred_pairs) == len(truth_pairs):
+            # If both sides have year keys, match by key (order-independent)
+            truth_keys = {k for k, _ in truth_pairs if k is not None}
+            pred_keys = {k for k, _ in pred_pairs if k is not None}
+            if truth_keys and pred_keys and truth_keys == pred_keys:
+                truth_map = {k: v for k, v in truth_pairs}
+                pred_map = {k: v for k, v in pred_pairs}
+                for key in truth_map:
+                    p, t = pred_map[key], truth_map[key]
+                    abs_diff = abs(p - t)
+                    rel_err = abs_diff / abs(t) if t != 0 else (0 if p == 0 else float('inf'))
+                    if not (rel_err <= tolerance and abs_diff <= max_absolute_diff):
+                        return 0.0
+                return 1.0
+
+            # Otherwise fall back to positional matching
+            for (_, p), (_, t) in zip(pred_pairs, truth_pairs):
+                abs_diff = abs(p - t)
+                rel_err = abs_diff / abs(t) if t != 0 else (0 if p == 0 else float('inf'))
+                if not (rel_err <= tolerance and abs_diff <= max_absolute_diff):
+                    return 0.0
+            return 1.0
+
+    return 0.0
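The core of rewards.py is the pairing of a format-tolerant number parser with a dual-tolerance match (relative error AND absolute difference must both pass). The sketch below re-implements just those two pieces for illustration; `parse_number` here is deliberately trimmed (no LaTeX or multi-value handling) and `is_match` is a hypothetical stand-in for the numeric branch of `compare_single_values`, not an import of the module.

```python
# Trimmed sketch of the FinQA reward rule: parse common answer formats,
# then require BOTH relative tolerance and an absolute cap to pass.
import re
from fractions import Fraction
from typing import Optional


def parse_number(text: str, convert_percent: bool = True) -> Optional[float]:
    """Simplified parser: percentages, accounting negatives, fractions, commas."""
    text = text.strip().replace("$", "")
    if "%" in text:
        value = float(text.replace("%", "").replace(",", "").strip())
        return value / 100 if convert_percent else value
    if text.startswith("(") and text.endswith(")"):  # "(100)" means -100
        text = "-" + text[1:-1]
    if "/" in text:
        return float(Fraction(text))
    return float(text.replace(",", ""))


def is_match(pred: float, truth: float, tolerance: float = 0.01, max_abs: float = 1.0) -> bool:
    """Numeric branch: relative error <= 1% AND absolute diff <= 1.0."""
    if truth == 0:
        return abs(pred) < 0.001
    diff = abs(pred - truth)
    return diff / abs(truth) <= tolerance and diff <= max_abs


print(parse_number("(1,234.5)"))     # -1234.5
print(parse_number("3/4"))           # 0.75
print(is_match(101.0, 100.0))        # True: both conditions hold
print(is_match(1010.0, 1000.0))      # False: 1% relative, but |diff| = 10 > 1.0
```

The last two calls show why the absolute cap matters: a 1% relative error on a large figure can still be a material miss in financial QA, so the rule rejects it even though the relative check alone would pass.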
server/tools.py ADDED
@@ -0,0 +1,218 @@
+# envs/finqa_env/server/tools.py
+"""
+Tool implementations for the FinQA environment.
+
+Ported from FinQABenchmark with simplifications:
+- Removed LangChain dependencies
+- Added submit_answer tool for episode termination
+"""
+
+import json
+import os
+import re
+import sqlite3
+from typing import Any, Dict, List, Tuple
+
+import pandas as pd
+
+
+class FinQATools:
+    """
+    Tool implementations for financial QA tasks.
+
+    Args:
+        data_path: Path to the data directory containing benchmark_questions/ and input_companies/
+    """
+
+    def __init__(self, data_path: str):
+        self.data_path = data_path
+        self.companies_path = os.path.join(data_path, "input_companies")
+        self._tables_cleaned = None
+
+    @property
+    def tables_cleaned(self) -> Dict:
+        """Lazily load the cleaned tables metadata."""
+        if self._tables_cleaned is None:
+            tables_path = os.path.join(self.companies_path, "tables_cleaned_all_companies.json")
+            with open(tables_path, "r") as f:
+                self._tables_cleaned = json.load(f)
+        return self._tables_cleaned
+
+    def get_available_companies(self) -> List[str]:
+        """Get the list of available company names."""
+        return [
+            d for d in os.listdir(self.companies_path)
+            if os.path.isdir(os.path.join(self.companies_path, d))
+        ]
+
+    def execute_tool(self, tool_name: str, tool_args: Dict[str, Any]) -> Tuple[str, bool]:
+        """
+        Execute a tool and return its result.
+
+        Args:
+            tool_name: Name of the tool to execute
+            tool_args: Arguments for the tool
+
+        Returns:
+            Tuple of (result_string, is_final_answer)
+        """
+        if tool_name == "get_descriptions":
+            return self.get_descriptions(**tool_args), False
+        elif tool_name == "get_table_info":
+            return self.get_table_info(**tool_args), False
+        elif tool_name == "sql_query":
+            return self.sql_query(**tool_args), False
+        elif tool_name == "submit_answer":
+            return self.submit_answer(**tool_args), True
+        else:
+            return f"Error: Unknown tool '{tool_name}'", False
+
+    def get_descriptions(self, company_name: str) -> str:
+        """
+        Get a list of available table names for a company.
+
+        Args:
+            company_name: The name of the company
+
+        Returns:
+            JSON list of table names
+        """
+        company_path = os.path.join(self.companies_path, company_name)
+
+        if not os.path.isdir(company_path):
+            available = self.get_available_companies()
+            return f"Error: '{company_name}' not found. Available companies: {available}"
+
+        # Get all JSON files (tables) for this company
+        tables = []
+        for f in os.listdir(company_path):
+            if f.endswith(".json"):
+                tables.append(f.replace(".json", ""))
+
+        return json.dumps(tables)
+
+    def get_table_info(self, company_name: str, table_name: str) -> str:
+        """
+        Get table metadata: description, columns, types, unique values.
+
+        Args:
+            company_name: The name of the company
+            table_name: The name of the table
+
+        Returns:
+            JSON string with table metadata (description, columns, dtypes, unique values)
+        """
+        company_path = os.path.join(self.companies_path, company_name)
+
+        if not os.path.isdir(company_path):
+            available = self.get_available_companies()
+            return f"Error: '{company_name}' not found. Available companies: {available}"
+
+        # Clean table name (remove .json or .txt if present)
+        cleaned_table_name = table_name.replace(".json", "").replace(".txt", "")
+        table_key = f"{company_name}/{cleaned_table_name}"
+
+        if table_key not in self.tables_cleaned:
+            return f"Error: Table '{table_name}' not found for company '{company_name}'"
+
+        table_info = self.tables_cleaned[table_key].copy()
+
+        # Load the actual table to get column info
+        cleaned_table = pd.DataFrame(json.loads(table_info["table"]))
+
+        # Drop numeric columns (keep only structure columns for querying hints)
+        cols_to_drop = []
+        for col in cleaned_table.columns.tolist()[1:]:  # Skip first column
+            vals = cleaned_table[col].tolist()[1:]
+            cleaned_vals = [
+                "".join(char for char in str(x) if char.isalnum()).strip()
+                for x in vals
+            ]
+            all_numeric = all(
+                v.isnumeric() or len(v) == 0 for v in cleaned_vals
+            )
+            if all_numeric:
+                cols_to_drop.append(col)
+
+        table_info["column_dtypes"] = {
+            col: str(cleaned_table[col].dtype)
+            for col in cleaned_table.columns.tolist()
+        }
+
+        # Only show unique values for non-numeric columns
+        cleaned_table_filtered = cleaned_table.drop(cols_to_drop, axis=1)
+        table_info["unique_vals_per_col"] = {
+            col: list(cleaned_table_filtered[col].unique())
+            for col in cleaned_table_filtered.columns.tolist()
+        }
+
+        # Remove the raw table data from the response
+        del table_info["table"]
+
+        return json.dumps(table_info, indent=0).replace("\n", "")
+
+    def sql_query(self, company_name: str, table_name: str, query: str) -> str:
+        """
+        Execute a SQL query on a table. SELECT * is not allowed (too inefficient).
+
+        Filters are required to query: WHERE, HAVING, IN, NOT IN, EXISTS, NOT EXISTS, ANY, SOME, ALL, LIKE, NOT LIKE, BETWEEN, NOT BETWEEN, IS NULL, IS NOT NULL, CASE, FILTER.
+
+        Args:
+            company_name: The name of the company
+            table_name: The name of the table
+            query: SQL query to execute (must include filters)
+
+        Returns:
+            JSON string with query results
+        """
+        # Validate that the query has filters (prevent full table scans)
+        if "select *" in query.lower():
+            return "Error: SELECT * is not allowed (too inefficient)"
+
+        sql_filters = [
+            "WHERE", "HAVING", "IN", "NOT IN", "EXISTS", "NOT EXISTS",
+            "ANY", "SOME", "ALL", "LIKE", "NOT LIKE", "BETWEEN",
+            "NOT BETWEEN", "IS NULL", "IS NOT NULL", "CASE", "FILTER"
+        ]
+
+        query_upper = re.sub(r"(\r|\n|\t)+", " ", query).upper()
+        pattern = r"(?<!\w|\[)(" + "|".join([re.escape(f) for f in sql_filters]) + r")(?!\w|\])"
+
+        has_filter = (
+            any(f" {filt} " in query_upper for filt in sql_filters) or
+            len(re.findall(pattern, query_upper)) > 0
+        )
+
+        if not has_filter:
+            return "Error: Query must include filters (WHERE, HAVING, etc.)"
+
+        # Clean table name
+        cleaned_table_name = table_name.replace(".txt", "").replace(".json", "")
+        table_path = os.path.join(self.companies_path, company_name, f"{cleaned_table_name}.json")
+
+        if not os.path.isfile(table_path):
+            return f"Error: Table file not found at {table_path}"
+
+        try:
+            # Load the table into an in-memory SQLite database and execute the query
+            conn = sqlite3.connect(":memory:")
+            df = pd.read_json(table_path)
+            df.to_sql(cleaned_table_name, conn, index=False, if_exists="replace")
+            result = pd.read_sql_query(query, conn)
+            conn.close()
+
+            return result.to_json(orient="records")
+        except Exception as e:
+            return f"Error executing query: {str(e)}"

+    def submit_answer(self, answer: str) -> str:
+        """
+        Submit a final answer for the question.
+
+        Args:
+            answer: The final answer to submit
+
+        Returns:
+            Confirmation message
+        """
+        return f"Answer submitted: {answer}"
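The filter guard in `sql_query` is worth isolating: it rejects `SELECT *` outright, then demands at least one filtering construct so an agent cannot dump whole tables. Below is a self-contained sketch of that guard using the same keyword list and the same word-boundary pattern (the lookbehind/lookahead keep `IN` from matching inside identifiers like `revenue`); `validate_query` is a hypothetical name, not a function exported by tools.py.

```python
# Sketch of the sql_query filter guard: block SELECT * and require at
# least one filtering keyword, matched only as a standalone word.
import re

SQL_FILTERS = [
    "WHERE", "HAVING", "IN", "NOT IN", "EXISTS", "NOT EXISTS",
    "ANY", "SOME", "ALL", "LIKE", "NOT LIKE", "BETWEEN",
    "NOT BETWEEN", "IS NULL", "IS NOT NULL", "CASE", "FILTER",
]


def validate_query(query: str) -> bool:
    """Return True only if the query is filtered enough to run."""
    if "select *" in query.lower():
        return False
    # Collapse newlines/tabs so keywords split across lines still match
    query_upper = re.sub(r"[\r\n\t]+", " ", query).upper()
    # Word boundaries: keyword must not touch \w or brackets on either side
    pattern = r"(?<!\w|\[)(" + "|".join(re.escape(f) for f in SQL_FILTERS) + r")(?!\w|\])"
    return re.search(pattern, query_upper) is not None


print(validate_query("SELECT revenue FROM t WHERE year = 2022"))  # True
print(validate_query("SELECT * FROM t"))                          # False
print(validate_query("SELECT revenue FROM t"))                    # False
```

One design note: in tools.py the matched table is loaded into an in-memory SQLite database per call, so the guard is about token cost and response size rather than database load; an unfiltered aggregate still fails unless it uses one of the listed constructs.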
uv.lock ADDED
The diff for this file is too large to render. See raw diff