Upload folder using huggingface_hub
Browse files- Dockerfile +83 -0
- README.md +194 -5
- __init__.py +24 -0
- client.py +33 -0
- download_data.sh +71 -0
- models.py +38 -0
- openenv.yaml +6 -0
- openenv_finqa_env.egg-info/PKG-INFO +15 -0
- openenv_finqa_env.egg-info/SOURCES.txt +16 -0
- openenv_finqa_env.egg-info/dependency_links.txt +1 -0
- openenv_finqa_env.egg-info/entry_points.txt +2 -0
- openenv_finqa_env.egg-info/requires.txt +11 -0
- openenv_finqa_env.egg-info/top_level.txt +1 -0
- pyproject.toml +34 -0
- server/__init__.py +13 -0
- server/app.py +35 -0
- server/finqa_environment.py +277 -0
- server/rewards.py +282 -0
- server/tools.py +218 -0
- uv.lock +0 -0
Dockerfile
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Multi-stage build using openenv-base
|
| 2 |
+
# This Dockerfile is flexible and works for both:
|
| 3 |
+
# - In-repo environments (with local src/core)
|
| 4 |
+
# - Standalone environments (with openenv-core from pip)
|
| 5 |
+
# The build script (openenv build) handles context detection and sets appropriate build args.
|
| 6 |
+
|
| 7 |
+
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 8 |
+
FROM ${BASE_IMAGE} AS builder
|
| 9 |
+
|
| 10 |
+
WORKDIR /app
|
| 11 |
+
|
| 12 |
+
# Build argument to control whether we're building standalone or in-repo
|
| 13 |
+
ARG BUILD_MODE=in-repo
|
| 14 |
+
|
| 15 |
+
# Copy environment code (always at root of build context)
|
| 16 |
+
COPY . /app/env
|
| 17 |
+
|
| 18 |
+
# For in-repo builds, openenv-core is already in the pyproject.toml dependencies
|
| 19 |
+
# For standalone builds, openenv-core will be installed from pip via pyproject.toml
|
| 20 |
+
WORKDIR /app/env
|
| 21 |
+
|
| 22 |
+
# Ensure uv is available (for local builds where base image lacks it)
|
| 23 |
+
RUN if ! command -v uv >/dev/null 2>&1; then \
|
| 24 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 25 |
+
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 26 |
+
mv /root/.local/bin/uvx /usr/local/bin/uvx; \
|
| 27 |
+
fi
|
| 28 |
+
|
| 29 |
+
# Install git for building from git repos (build-time only)
|
| 30 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 31 |
+
git \
|
| 32 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 33 |
+
|
| 34 |
+
# Install dependencies using uv sync
|
| 35 |
+
# First pass: install dependencies without the project (for better caching)
|
| 36 |
+
# Second pass: install the project itself
|
| 37 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 38 |
+
if [ -f uv.lock ]; then \
|
| 39 |
+
uv sync --frozen --no-install-project --no-editable; \
|
| 40 |
+
else \
|
| 41 |
+
uv sync --no-install-project --no-editable; \
|
| 42 |
+
fi
|
| 43 |
+
|
| 44 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 45 |
+
if [ -f uv.lock ]; then \
|
| 46 |
+
uv sync --frozen --no-editable; \
|
| 47 |
+
else \
|
| 48 |
+
uv sync --no-editable; \
|
| 49 |
+
fi
|
| 50 |
+
|
| 51 |
+
# Final runtime stage
|
| 52 |
+
FROM ${BASE_IMAGE}
|
| 53 |
+
|
| 54 |
+
WORKDIR /app
|
| 55 |
+
|
| 56 |
+
# Copy the virtual environment from builder
|
| 57 |
+
COPY --from=builder /app/env/.venv /app/.venv
|
| 58 |
+
|
| 59 |
+
# Copy the environment code
|
| 60 |
+
COPY --from=builder /app/env /app/env
|
| 61 |
+
|
| 62 |
+
# Set PATH to use the virtual environment
|
| 63 |
+
ENV PATH="/app/.venv/bin:$PATH"
|
| 64 |
+
|
| 65 |
+
# Set PYTHONPATH so imports work correctly
|
| 66 |
+
ENV PYTHONPATH="/app/env:$PYTHONPATH"
|
| 67 |
+
|
| 68 |
+
# Environment variables with defaults
|
| 69 |
+
ENV FINQA_DATA_PATH="/app/env/data"
|
| 70 |
+
ENV FINQA_MAX_STEPS="50"
|
| 71 |
+
ENV FINQA_TASK="finqa"
|
| 72 |
+
|
| 73 |
+
# Download data from HuggingFace at build time (requires network)
|
| 74 |
+
RUN pip install --no-cache-dir huggingface_hub[cli] && \
|
| 75 |
+
bash /app/env/download_data.sh snorkelai/finqa-data /app/env/data
|
| 76 |
+
|
| 77 |
+
# Health check using Python (more portable than curl/wget)
|
| 78 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 79 |
+
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
|
| 80 |
+
|
| 81 |
+
# Run the FastAPI server
|
| 82 |
+
ENV ENABLE_WEB_INTERFACE=true
|
| 83 |
+
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
README.md
CHANGED
|
@@ -1,10 +1,199 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: FinQA Environment Server
|
| 3 |
+
emoji: 🔊
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: gray
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
app_port: 8000
|
| 9 |
+
base_path: /web
|
| 10 |
+
datasets:
|
| 11 |
+
- snorkelai/finqa-data
|
| 12 |
+
tags:
|
| 13 |
+
- openenv
|
| 14 |
---
|
| 15 |
|
| 16 |
+
# FinQA Environment
|
| 17 |
+
|
| 18 |
+
A financial question-answering environment for RL training. Evaluates LLMs on their ability to answer complex financial questions using tool calls on SEC 10-K filing data.
|
| 19 |
+
|
| 20 |
+
Based on [FinQABenchmark](https://github.com/snorkel-ai/FinQABenchmark) from Snorkel AI.
|
| 21 |
+
|
| 22 |
+
## Overview
|
| 23 |
+
|
| 24 |
+
FinQA tests an agent's ability to:
|
| 25 |
+
- Explore available financial tables for a company
|
| 26 |
+
- Query table metadata and execute SQL queries
|
| 27 |
+
- Perform calculations on extracted data
|
| 28 |
+
- Submit final answers to financial questions
|
| 29 |
+
|
| 30 |
+
**Dataset**: 290 questions from SEC 10-K filings across multiple companies (Alphabet, Amazon, Apple, AT&T, etc.)
|
| 31 |
+
|
| 32 |
+
**Reward**: Binary (1.0 for correct answer, 0.0 for incorrect) using fuzzy numerical matching with 1% tolerance.
|
| 33 |
+
|
| 34 |
+
> **Note**: This dataset is for evaluation only. Do not train on it.
|
| 35 |
+
|
| 36 |
+
## Quick Start
|
| 37 |
+
|
| 38 |
+
### Using Docker
|
| 39 |
+
|
| 40 |
+
```bash
|
| 41 |
+
# Build the image (from OpenEnv repo root)
|
| 42 |
+
docker build -t finqa-env:latest -f envs/finqa_env/server/Dockerfile .
|
| 43 |
+
|
| 44 |
+
# Run the server
|
| 45 |
+
docker run -p 8000:8000 finqa-env:latest
|
| 46 |
+
|
| 47 |
+
# To run evaluation script (example model gpt-5)
|
| 48 |
+
API_BASE_URL=https://api.openai.com/v1 API_KEY=$OPENAI_API_KEY MODEL=gpt-5 python examples/finqa_inference.py
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
### Local Development
|
| 52 |
+
|
| 53 |
+
```bash
|
| 54 |
+
# Install dependencies
|
| 55 |
+
uv pip install pandas
|
| 56 |
+
|
| 57 |
+
# Download data from HuggingFace
|
| 58 |
+
cd envs/finqa_env
|
| 59 |
+
./download_data.sh
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
### Using the Client
|
| 63 |
+
|
| 64 |
+
The client uses the MCP protocol and is async by default:
|
| 65 |
+
|
| 66 |
+
```python
|
| 67 |
+
import asyncio
|
| 68 |
+
from envs.finqa_env import FinQAEnv, CallToolAction
|
| 69 |
+
|
| 70 |
+
async def main():
|
| 71 |
+
async with FinQAEnv(base_url="http://localhost:8000") as env:
|
| 72 |
+
# Reset to get a question
|
| 73 |
+
obs = await env.reset()
|
| 74 |
+
question = obs.metadata["question"]
|
| 75 |
+
company = obs.metadata["company"]
|
| 76 |
+
print(f"Question: {question}")
|
| 77 |
+
print(f"Company: {company}")
|
| 78 |
+
|
| 79 |
+
# Discover available tools
|
| 80 |
+
tools = await env.list_tools()
|
| 81 |
+
print([t.name for t in tools])
|
| 82 |
+
|
| 83 |
+
# Use tools via call_tool (convenience method)
|
| 84 |
+
result = await env.call_tool("get_descriptions", company_name=company)
|
| 85 |
+
print(f"Available tables: {result}")
|
| 86 |
+
|
| 87 |
+
# Or use step() with CallToolAction for full observation access
|
| 88 |
+
step_result = await env.step(CallToolAction(
|
| 89 |
+
tool_name="sql_query",
|
| 90 |
+
arguments={
|
| 91 |
+
"company_name": "alphabet",
|
| 92 |
+
"table_name": "us_gaap_ScheduleOfIncomeBeforeIncomeTaxDomesticAndForeignTableTextBlock",
|
| 93 |
+
"query": "SELECT * FROM data WHERE year = '2022'"
|
| 94 |
+
}
|
| 95 |
+
))
|
| 96 |
+
print(f"Done: {step_result.done}, Reward: {step_result.reward}")
|
| 97 |
+
|
| 98 |
+
# Submit answer
|
| 99 |
+
result = await env.call_tool("submit_answer", answer="6.118")
|
| 100 |
+
|
| 101 |
+
asyncio.run(main())
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
## Available Tools
|
| 105 |
+
|
| 106 |
+
Tools are auto-discovered via MCP. Use `await env.list_tools()` to see all available tools at runtime.
|
| 107 |
+
|
| 108 |
+
| Tool | Description | Arguments |
|
| 109 |
+
|------|-------------|-----------|
|
| 110 |
+
| `get_descriptions` | Get list of available table names for a company | `company_name: str` |
|
| 111 |
+
| `get_table_info` | Get table metadata (columns, dtypes, unique values) | `company_name: str, table_name: str` |
|
| 112 |
+
| `sql_query` | Execute SQL query on a table (requires filters) | `company_name: str, table_name: str, query: str` |
|
| 113 |
+
| `submit_answer` | Submit final answer (ends episode) | `answer: str` |
|
| 114 |
+
|
| 115 |
+
### Tool Constraints
|
| 116 |
+
|
| 117 |
+
- **sql_query**: Must include filters (`WHERE`, `HAVING`, etc.). `SELECT *` is not allowed.
|
| 118 |
+
|
| 119 |
+
## Environment Variables
|
| 120 |
+
|
| 121 |
+
| Variable | Default | Description |
|
| 122 |
+
|----------|---------|-------------|
|
| 123 |
+
| `FINQA_DATA_PATH` | `/app/env/data` | Path to data directory |
|
| 124 |
+
| `FINQA_MAX_STEPS` | `50` | Maximum tool calls per episode |
|
| 125 |
+
| `FINQA_TASK` | `finqa` | Task name |
|
| 126 |
+
|
| 127 |
+
## Reward Computation
|
| 128 |
+
|
| 129 |
+
Rewards use fuzzy numerical matching:
|
| 130 |
+
|
| 131 |
+
- Extracts numbers from `\boxed{...}` format
|
| 132 |
+
- Handles percentages, fractions, and decimals
|
| 133 |
+
- 1% relative tolerance or 0.01 absolute tolerance
|
| 134 |
+
- Returns `1.0` for correct, `0.0` for incorrect
|
| 135 |
+
|
| 136 |
+
## Local Development
|
| 137 |
+
|
| 138 |
+
```bash
|
| 139 |
+
# From OpenEnv repo root
|
| 140 |
+
cd envs/finqa_env
|
| 141 |
+
|
| 142 |
+
# Run server locally
|
| 143 |
+
FINQA_DATA_PATH=./data uvicorn server.app:app --reload --port 8000
|
| 144 |
+
|
| 145 |
+
# Test with curl
|
| 146 |
+
curl http://localhost:8000/health
|
| 147 |
+
curl -X POST http://localhost:8000/reset
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
## Integration with RL Frameworks
|
| 151 |
+
|
| 152 |
+
### TRL (GRPO)
|
| 153 |
+
|
| 154 |
+
```python
|
| 155 |
+
import asyncio
|
| 156 |
+
from trl import GRPOTrainer
|
| 157 |
+
from envs.finqa_env import FinQAEnv
|
| 158 |
+
|
| 159 |
+
async def rollout_func(prompts, trainer):
|
| 160 |
+
async with FinQAEnv(base_url="http://localhost:8000") as env:
|
| 161 |
+
obs = await env.reset()
|
| 162 |
+
# Your agent logic here using await env.call_tool(...)
|
| 163 |
+
return {"reward": obs.reward, "completion": completion}
|
| 164 |
+
|
| 165 |
+
trainer = GRPOTrainer(
|
| 166 |
+
model=model,
|
| 167 |
+
rollout_func=rollout_func,
|
| 168 |
+
...
|
| 169 |
+
)
|
| 170 |
+
```
|
| 171 |
+
|
| 172 |
+
## Project Structure
|
| 173 |
+
|
| 174 |
+
```
|
| 175 |
+
finqa_env/
|
| 176 |
+
├── __init__.py # Exports FinQAEnv, CallToolAction, ListToolsAction
|
| 177 |
+
├── models.py # FinQAState and tool name constants
|
| 178 |
+
├── client.py # MCP client (subclasses MCPToolClient)
|
| 179 |
+
├── pyproject.toml # Dependencies
|
| 180 |
+
├── README.md # This file
|
| 181 |
+
├── data/ # Benchmark data (run download_data.sh)
|
| 182 |
+
│ ├── benchmark_questions/
|
| 183 |
+
│ │ └── finqa.csv
|
| 184 |
+
│ └── input_companies/
|
| 185 |
+
│ └── [company folders]
|
| 186 |
+
├── download_data.sh # Downloads data from HuggingFace
|
| 187 |
+
└── server/
|
| 188 |
+
├── __init__.py
|
| 189 |
+
├── finqa_environment.py # MCPEnvironment subclass with @mcp.tool decorators
|
| 190 |
+
├── tools.py # Tool implementations
|
| 191 |
+
├── rewards.py # Reward computation
|
| 192 |
+
├── app.py # FastAPI server
|
| 193 |
+
└── Dockerfile
|
| 194 |
+
```
|
| 195 |
+
|
| 196 |
+
## References
|
| 197 |
+
|
| 198 |
+
- [HuggingFace Dataset](https://huggingface.co/datasets/snorkelai/agent-finance-reasoning)
|
| 199 |
+
- [Leaderboard](https://leaderboard.snorkel.ai/category/snorkelfinance)
|
__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# envs/finqa_env/__init__.py
|
| 2 |
+
"""
|
| 3 |
+
FinQA Environment for OpenEnv.
|
| 4 |
+
|
| 5 |
+
A financial question-answering environment that evaluates LLMs on their ability
|
| 6 |
+
to answer complex financial questions using tool calls on SEC 10-K filing data.
|
| 7 |
+
|
| 8 |
+
Example:
|
| 9 |
+
>>> from envs.finqa_env import FinQAEnv
|
| 10 |
+
>>>
|
| 11 |
+
>>> async with FinQAEnv(base_url="http://localhost:8000") as env:
|
| 12 |
+
... await env.reset()
|
| 13 |
+
... tools = await env.list_tools()
|
| 14 |
+
... result = await env.call_tool("get_descriptions", company_name="alphabet")
|
| 15 |
+
... result = await env.call_tool("submit_answer", answer="6.118")
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from .client import FinQAEnv
|
| 19 |
+
from .models import FinQAState
|
| 20 |
+
|
| 21 |
+
# Re-export MCP types for convenience
|
| 22 |
+
from openenv.core.env_server.mcp_types import CallToolAction, ListToolsAction
|
| 23 |
+
|
| 24 |
+
__all__ = ["FinQAEnv", "FinQAState", "CallToolAction", "ListToolsAction"]
|
client.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# envs/finqa_env/client.py
|
| 2 |
+
"""
|
| 3 |
+
Client for the FinQA environment.
|
| 4 |
+
|
| 5 |
+
This client connects to a running FinQA environment server and provides
|
| 6 |
+
a Python interface for interacting with it via MCP tools. Async by default.
|
| 7 |
+
|
| 8 |
+
Example:
|
| 9 |
+
>>> from envs.finqa_env import FinQAEnv
|
| 10 |
+
>>>
|
| 11 |
+
>>> async with FinQAEnv(base_url="http://localhost:8000") as env:
|
| 12 |
+
... await env.reset()
|
| 13 |
+
... tools = await env.list_tools()
|
| 14 |
+
... result = await env.call_tool("get_descriptions", company_name="alphabet")
|
| 15 |
+
... print(result)
|
| 16 |
+
... result = await env.call_tool("submit_answer", answer="6.118")
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from openenv.core.mcp_client import MCPToolClient
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class FinQAEnv(MCPToolClient):
|
| 23 |
+
"""
|
| 24 |
+
Client for the FinQA environment.
|
| 25 |
+
|
| 26 |
+
Inherits all functionality from MCPToolClient:
|
| 27 |
+
- list_tools(): Discover available tools
|
| 28 |
+
- call_tool(name, **kwargs): Call a tool by name
|
| 29 |
+
- reset(**kwargs): Reset the environment
|
| 30 |
+
- step(action): Execute an action
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
pass # MCPToolClient provides all needed functionality
|
download_data.sh
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Download FinQA data from HuggingFace
|
| 3 |
+
#
|
| 4 |
+
# This script downloads all FinQA data from HuggingFace:
|
| 5 |
+
# 1. Benchmark questions CSV
|
| 6 |
+
# 2. Company financial documents (preprocessed SEC 10-K filings)
|
| 7 |
+
#
|
| 8 |
+
# Usage:
|
| 9 |
+
# ./download_data.sh <hf_repo_or_url> [output_dir]
|
| 10 |
+
|
| 11 |
+
set -e
|
| 12 |
+
|
| 13 |
+
HF_REPO_OR_URL="${1}"
|
| 14 |
+
OUTPUT_DIR="${2:-./data}"
|
| 15 |
+
|
| 16 |
+
if [ -z "$HF_REPO_OR_URL" ]; then
|
| 17 |
+
echo "Usage: $0 <hf_repo_or_url> [output_dir]"
|
| 18 |
+
echo "Example: $0 snorkelai/finqa-data ./data"
|
| 19 |
+
exit 1
|
| 20 |
+
fi
|
| 21 |
+
|
| 22 |
+
echo "========================================"
|
| 23 |
+
echo "FinQA Data Download"
|
| 24 |
+
echo "========================================"
|
| 25 |
+
echo "Output directory: $OUTPUT_DIR"
|
| 26 |
+
echo ""
|
| 27 |
+
|
| 28 |
+
# Create output directory
|
| 29 |
+
mkdir -p "$OUTPUT_DIR"
|
| 30 |
+
|
| 31 |
+
# Check if data already exists
|
| 32 |
+
if [ -f "$OUTPUT_DIR/benchmark_questions/finqa.csv" ] && [ -d "$OUTPUT_DIR/input_companies" ]; then
|
| 33 |
+
echo "Data already exists in $OUTPUT_DIR, skipping download."
|
| 34 |
+
exit 0
|
| 35 |
+
fi
|
| 36 |
+
|
| 37 |
+
# Check for huggingface-cli
|
| 38 |
+
if ! command -v huggingface-cli &> /dev/null; then
|
| 39 |
+
echo "Error: huggingface-cli not found"
|
| 40 |
+
echo "Install it with: uv pip install huggingface_hub[cli]"
|
| 41 |
+
exit 1
|
| 42 |
+
fi
|
| 43 |
+
|
| 44 |
+
# Download from HuggingFace
|
| 45 |
+
echo "Downloading from HuggingFace: $HF_REPO_OR_URL"
|
| 46 |
+
if ! huggingface-cli download "$HF_REPO_OR_URL" --repo-type dataset --local-dir "$OUTPUT_DIR"; then
|
| 47 |
+
echo "Error: Failed to download dataset"
|
| 48 |
+
exit 1
|
| 49 |
+
fi
|
| 50 |
+
|
| 51 |
+
# Verify downloaded data
|
| 52 |
+
if [ ! -f "$OUTPUT_DIR/benchmark_questions/finqa.csv" ]; then
|
| 53 |
+
echo "Error: benchmark_questions/finqa.csv not found in downloaded data"
|
| 54 |
+
exit 1
|
| 55 |
+
fi
|
| 56 |
+
|
| 57 |
+
if [ ! -d "$OUTPUT_DIR/input_companies" ]; then
|
| 58 |
+
echo "Error: input_companies/ directory not found in downloaded data"
|
| 59 |
+
exit 1
|
| 60 |
+
fi
|
| 61 |
+
|
| 62 |
+
echo ""
|
| 63 |
+
echo "========================================"
|
| 64 |
+
echo "Download complete!"
|
| 65 |
+
echo "========================================"
|
| 66 |
+
echo "Data location: $OUTPUT_DIR"
|
| 67 |
+
echo ""
|
| 68 |
+
|
| 69 |
+
# Export data path
|
| 70 |
+
export FINQA_DATA_PATH="$OUTPUT_DIR"
|
| 71 |
+
echo "Exported: FINQA_DATA_PATH=$FINQA_DATA_PATH"
|
models.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# envs/finqa_env/models.py
|
| 2 |
+
"""
|
| 3 |
+
State types for the FinQA environment.
|
| 4 |
+
|
| 5 |
+
FinQA is a financial question-answering benchmark that evaluates LLMs on their
|
| 6 |
+
ability to answer complex financial questions using tool calls (SQL queries,
|
| 7 |
+
calculations, etc.) on SEC 10-K filing data.
|
| 8 |
+
|
| 9 |
+
This environment uses the MCP protocol for tool interactions. Use
|
| 10 |
+
``CallToolAction`` and ``ListToolsAction`` from ``openenv.core.env_server.mcp_types``
|
| 11 |
+
to interact with the environment.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from openenv.core.env_server import State
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Tool names - defined statically to avoid circular imports
|
| 18 |
+
AVAILABLE_TOOLS = ["get_descriptions", "get_table_info", "sql_query", "submit_answer"]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class FinQAState(State):
|
| 22 |
+
"""
|
| 23 |
+
Internal environment state for tracking the current episode.
|
| 24 |
+
|
| 25 |
+
All fields are set during reset() and are essential for episode tracking.
|
| 26 |
+
|
| 27 |
+
Attributes:
|
| 28 |
+
current_question: The question being asked
|
| 29 |
+
current_company: The company the question is about
|
| 30 |
+
ground_truth: The expected answer for reward computation
|
| 31 |
+
question_id: Identifier for the current question
|
| 32 |
+
# Inherited from State: episode_id, step_count
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
current_question: str = ""
|
| 36 |
+
current_company: str = ""
|
| 37 |
+
ground_truth: str = ""
|
| 38 |
+
question_id: str = ""
|
openenv.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: finqa_env
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
openenv_finqa_env.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: openenv-finqa-env
|
| 3 |
+
Version: 0.1.0
|
| 4 |
+
Summary: FinQA Environment for OpenEnv - financial question-answering on SEC 10-K filing data
|
| 5 |
+
Requires-Python: >=3.10
|
| 6 |
+
Requires-Dist: openenv-core[core]>=0.2.1
|
| 7 |
+
Requires-Dist: fastapi>=0.115.0
|
| 8 |
+
Requires-Dist: fastmcp>=2.0.0
|
| 9 |
+
Requires-Dist: pydantic>=2.0.0
|
| 10 |
+
Requires-Dist: uvicorn>=0.24.0
|
| 11 |
+
Requires-Dist: requests>=2.31.0
|
| 12 |
+
Requires-Dist: pandas>=2.0.0
|
| 13 |
+
Provides-Extra: dev
|
| 14 |
+
Requires-Dist: pytest>=8.0.0; extra == "dev"
|
| 15 |
+
Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
|
openenv_finqa_env.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
README.md
|
| 2 |
+
pyproject.toml
|
| 3 |
+
./__init__.py
|
| 4 |
+
./client.py
|
| 5 |
+
./models.py
|
| 6 |
+
openenv_finqa_env.egg-info/PKG-INFO
|
| 7 |
+
openenv_finqa_env.egg-info/SOURCES.txt
|
| 8 |
+
openenv_finqa_env.egg-info/dependency_links.txt
|
| 9 |
+
openenv_finqa_env.egg-info/entry_points.txt
|
| 10 |
+
openenv_finqa_env.egg-info/requires.txt
|
| 11 |
+
openenv_finqa_env.egg-info/top_level.txt
|
| 12 |
+
server/__init__.py
|
| 13 |
+
server/app.py
|
| 14 |
+
server/finqa_environment.py
|
| 15 |
+
server/rewards.py
|
| 16 |
+
server/tools.py
|
openenv_finqa_env.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
openenv_finqa_env.egg-info/entry_points.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[console_scripts]
|
| 2 |
+
server = finqa_env.server.app:main
|
openenv_finqa_env.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core[core]>=0.2.1
|
| 2 |
+
fastapi>=0.115.0
|
| 3 |
+
fastmcp>=2.0.0
|
| 4 |
+
pydantic>=2.0.0
|
| 5 |
+
uvicorn>=0.24.0
|
| 6 |
+
requests>=2.31.0
|
| 7 |
+
pandas>=2.0.0
|
| 8 |
+
|
| 9 |
+
[dev]
|
| 10 |
+
pytest>=8.0.0
|
| 11 |
+
pytest-cov>=4.0.0
|
openenv_finqa_env.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
finqa_env
|
pyproject.toml
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=45", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "openenv-finqa-env"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "FinQA Environment for OpenEnv - financial question-answering on SEC 10-K filing data"
|
| 9 |
+
requires-python = ">=3.10"
|
| 10 |
+
dependencies = [
|
| 11 |
+
# Core OpenEnv dependencies (required for server functionality)
|
| 12 |
+
"openenv-core[core]>=0.2.1",
|
| 13 |
+
"fastapi>=0.115.0",
|
| 14 |
+
"fastmcp>=2.0.0",
|
| 15 |
+
"pydantic>=2.0.0",
|
| 16 |
+
"uvicorn>=0.24.0",
|
| 17 |
+
"requests>=2.31.0",
|
| 18 |
+
# FinQA environment specific dependencies
|
| 19 |
+
"pandas>=2.0.0",
|
| 20 |
+
]
|
| 21 |
+
|
| 22 |
+
[project.optional-dependencies]
|
| 23 |
+
dev = [
|
| 24 |
+
"pytest>=8.0.0",
|
| 25 |
+
"pytest-cov>=4.0.0",
|
| 26 |
+
]
|
| 27 |
+
|
| 28 |
+
[project.scripts]
|
| 29 |
+
server = "finqa_env.server.app:main"
|
| 30 |
+
|
| 31 |
+
[tool.setuptools]
|
| 32 |
+
packages = ["finqa_env", "finqa_env.server"]
|
| 33 |
+
package-dir = { "finqa_env" = ".", "finqa_env.server" = "server" }
|
| 34 |
+
|
server/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# envs/finqa_env/server/__init__.py
|
| 2 |
+
"""Server-side components for the FinQA environment."""
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def __getattr__(name):
|
| 6 |
+
if name == "FinQAEnvironment":
|
| 7 |
+
from .finqa_environment import FinQAEnvironment
|
| 8 |
+
|
| 9 |
+
return FinQAEnvironment
|
| 10 |
+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
__all__ = ["FinQAEnvironment"]
|
server/app.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# envs/finqa_env/server/app.py
|
| 2 |
+
"""
|
| 3 |
+
FastAPI server for the FinQA environment.
|
| 4 |
+
|
| 5 |
+
Environment Variables:
|
| 6 |
+
FINQA_DATA_PATH: Path to data directory (default: /app/env/data)
|
| 7 |
+
FINQA_MAX_STEPS: Maximum tool calls per episode (default: 50)
|
| 8 |
+
FINQA_TASK: Task name (default: finqa)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
|
| 13 |
+
from openenv.core.env_server.http_server import create_app
|
| 14 |
+
from openenv.core.env_server.mcp_types import CallToolAction, CallToolObservation
|
| 15 |
+
from .finqa_environment import FinQAEnvironment
|
| 16 |
+
|
| 17 |
+
DATA_PATH = os.environ.get("FINQA_DATA_PATH", "/app/env/data")
|
| 18 |
+
MAX_STEPS = int(os.environ.get("FINQA_MAX_STEPS", "50"))
|
| 19 |
+
TASK = os.environ.get("FINQA_TASK", "finqa")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _env_factory():
|
| 23 |
+
"""Create a new FinQAEnvironment instance for each session."""
|
| 24 |
+
return FinQAEnvironment(
|
| 25 |
+
data_path=DATA_PATH,
|
| 26 |
+
max_steps=MAX_STEPS,
|
| 27 |
+
task=TASK,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Pass the class (factory) instead of instance for WebSocket session support
|
| 32 |
+
# Use MCP types for action/observation since this is a pure MCP environment
|
| 33 |
+
app = create_app(
|
| 34 |
+
_env_factory, CallToolAction, CallToolObservation, env_name="finqa_env"
|
| 35 |
+
)
|
server/finqa_environment.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# envs/finqa_env/server/finqa_environment.py
|
| 2 |
+
"""
|
| 3 |
+
FinQA Environment Implementation.
|
| 4 |
+
|
| 5 |
+
A financial question-answering environment that evaluates LLMs on their ability
|
| 6 |
+
to answer complex financial questions using tool calls on SEC 10-K filing data.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
import os
|
| 11 |
+
import random
|
| 12 |
+
import uuid
|
| 13 |
+
from typing import Any, Dict, List, Optional
|
| 14 |
+
|
| 15 |
+
import pandas as pd
|
| 16 |
+
from fastmcp import FastMCP
|
| 17 |
+
|
| 18 |
+
from openenv.core.env_server.mcp_environment import MCPEnvironment
|
| 19 |
+
from openenv.core.env_server.mcp_types import CallToolAction
|
| 20 |
+
from openenv.core.env_server.types import Action, Observation
|
| 21 |
+
|
| 22 |
+
from ..models import FinQAState, AVAILABLE_TOOLS
|
| 23 |
+
from .rewards import compute_reward
|
| 24 |
+
from .tools import FinQATools
|
| 25 |
+
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class FinQAEnvironment(MCPEnvironment):
    """
    Financial QA environment for RL training.

    Evaluates agents on their ability to answer financial questions by:
    - Exploring available tables for a company
    - Querying table metadata and executing SQL queries
    - Performing calculations
    - Submitting final answers

    Episodes terminate either when the agent calls the ``submit_answer``
    tool (reward computed against the ground truth) or when ``max_steps``
    actions have been taken (reward 0.0).

    Args:
        data_path: Path to the data directory containing benchmark_questions/ and input_companies/
        max_steps: Maximum number of tool calls per episode (default: 50)
        task: Task name - currently only 'finqa' supported (default: 'finqa')
    """

    def __init__(
        self,
        data_path: str = "./data",
        max_steps: int = 50,
        task: str = "finqa",
    ) -> None:
        # Create MCP server and define tools inline. The tool closures
        # capture `self`, so they must be defined after the FinQATools
        # helper is constructed below.
        mcp = FastMCP("finqa_env")

        self.data_path = data_path
        self.max_steps = max_steps
        self.task = task

        assert task == "finqa", "Only finqa task is supported"

        self.questions = self._load_questions()
        logger.info(f"Loaded {len(self.questions)} questions for task '{task}'")

        self._finqa_tools = FinQATools(data_path)

        # Register tools with FastMCP.
        # NOTE: these docstrings are exposed to the agent as the MCP tool
        # descriptions, so they are part of the environment's interface.
        @mcp.tool
        def get_descriptions(company_name: str) -> str:
            """
            Get a list of available table names for a company.

            Args:
                company_name: The name of the company

            Returns:
                JSON list of table names
            """
            return self._finqa_tools.get_descriptions(company_name)

        @mcp.tool
        def get_table_info(company_name: str, table_name: str) -> str:
            """
            Get table metadata: description, columns, types, unique values.

            Args:
                company_name: The name of the company
                table_name: The name of the table

            Returns:
                JSON string with table metadata
            """
            return self._finqa_tools.get_table_info(company_name, table_name)

        @mcp.tool
        def sql_query(company_name: str, table_name: str, query: str) -> str:
            """
            Execute a SQL query on a table. Select * not allowed.

            Filters are required: WHERE, HAVING, IN, NOT IN, EXISTS, NOT EXISTS,
            ANY, SOME, ALL, LIKE, NOT LIKE, BETWEEN, NOT BETWEEN, IS NULL,
            IS NOT NULL, CASE, FILTER.

            Args:
                company_name: The name of the company
                table_name: The name of the table
                query: SQL query to execute (must include filters)

            Returns:
                JSON string with query results
            """
            return self._finqa_tools.sql_query(company_name, table_name, query)

        @mcp.tool
        def submit_answer(answer: str) -> str:
            """
            Submit a final answer for the question.

            Args:
                answer: The final answer to submit

            Returns:
                Confirmation message
            """
            return self._finqa_tools.submit_answer(answer)

        # Pass the MCP server to the base class
        super().__init__(mcp)

        # Shuffle dataset for sequential selection so consecutive episodes
        # see questions in random order without repeats until exhaustion.
        self._shuffled_questions = self.questions.copy()
        random.shuffle(self._shuffled_questions)
        self._question_index = 0

        self._state = FinQAState()
        self._history: List[Dict[str, Any]] = []

    def _load_questions(self) -> List[Dict[str, Any]]:
        """Load questions from the benchmark CSV as a list of plain dicts.

        Raises:
            FileNotFoundError: if ``{data_path}/benchmark_questions/{task}.csv``
                does not exist.
        """
        csv_path = os.path.join(self.data_path, "benchmark_questions", f"{self.task}.csv")

        if not os.path.isfile(csv_path):
            raise FileNotFoundError(f"Benchmark file not found: {csv_path}")

        df = pd.read_csv(csv_path)

        questions = []
        for _, row in df.iterrows():
            # `id`, `question_type`, and `explanation` are optional columns;
            # the rest will raise KeyError if missing from the CSV.
            questions.append({
                "id": str(row.get("id", "")),
                "user_query": row["user_query"],
                "company": row["company"],
                "question": row["question"],
                "answer": row["answer"],
                "question_type": row.get("question_type", ""),
                "explanation": row.get("explanation", ""),
            })

        return questions

    def _get_next_question(self) -> Dict[str, Any]:
        """Get the next question using sequential shuffle selection."""
        # Reshuffle and restart once all questions have been served.
        if self._question_index >= len(self._shuffled_questions):
            random.shuffle(self._shuffled_questions)
            self._question_index = 0

        question = self._shuffled_questions[self._question_index]
        self._question_index += 1
        return question

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> Observation:
        """
        Reset the environment for a new episode.

        Note: `seed` is accepted for API compatibility but is not used to
        seed question selection here.

        Returns:
            Initial observation with the question
        """
        question = self._get_next_question()
        self._state = FinQAState(
            episode_id=episode_id or str(uuid.uuid4()),
            step_count=0,
            current_question=question["user_query"],
            current_company=question["company"],
            ground_truth=question["answer"],
            question_id=question["id"],
        )
        self._history = []

        # Logs the raw `question` column; the observation exposes `user_query`.
        logger.info(f"Reset episode {self._state.episode_id} with question: {question['question'][:200]}...")

        return Observation(
            done=False,
            reward=0.0,
            metadata={
                "question": question["user_query"],
                "company": question["company"],
                "tool_result": "",
                "history": [],
                "step_count": 0,
                "available_tools": AVAILABLE_TOOLS.copy(),
            },
        )

    def _step_impl(
        self,
        action: Action,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> Observation:
        """
        Handle non-MCP actions. Returns an error since this env is MCP-only.
        """
        return Observation(
            done=False,
            reward=0.0,
            metadata={
                "error": f"Unknown action type: {type(action).__name__}. "
                "Use ListToolsAction or CallToolAction for MCP interactions."
            },
        )

    def step(
        self,
        action: Action,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> Observation:
        """
        Execute a step in the environment.

        Delegates to base class for MCP actions. Handles submit_answer
        reward computation and max-step termination.
        """
        # NOTE(review): every action counts against max_steps, including
        # non-tool-call actions handled by the base class — confirm intended.
        self._state.step_count += 1

        # Let the base class handle MCP actions
        obs = super().step(action, timeout_s=timeout_s, **kwargs)

        # Check if submit_answer was called — this ends the episode with a
        # reward computed against the stored ground truth.
        if isinstance(action, CallToolAction) and action.tool_name == "submit_answer":
            submitted_answer = action.arguments.get("answer", "")
            reward = compute_reward(submitted_answer, self._state.ground_truth)
            logger.info(
                f"Episode {self._state.episode_id} ended: "
                f"submitted='{submitted_answer}', truth='{self._state.ground_truth}', reward={reward}"
            )
            return Observation(
                done=True,
                reward=reward,
                metadata={
                    **obs.metadata,
                    "ground_truth": self._state.ground_truth,
                    "submitted_answer": submitted_answer,
                },
            )

        # Check for max steps — episode ends with zero reward.
        if self._state.step_count >= self.max_steps:
            logger.info(f"Episode {self._state.episode_id} terminated: max steps reached")
            return Observation(
                done=True,
                reward=0.0,
                metadata={
                    **obs.metadata,
                    "error": f"Max steps ({self.max_steps}) reached without submitting answer.",
                },
            )

        return obs

    @property
    def state(self) -> FinQAState:
        """Get the current environment state."""
        return self._state
|
server/rewards.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# envs/finqa_env/server/rewards.py
|
| 2 |
+
"""
|
| 3 |
+
Reward computation for the FinQA environment.
|
| 4 |
+
|
| 5 |
+
Uses fuzzy numerical matching to compare predicted answers against ground truth.
|
| 6 |
+
Handles various formats: \boxed{}, percentages, fractions, decimals.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import re
|
| 10 |
+
from fractions import Fraction
|
| 11 |
+
from typing import Optional, Tuple
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def extract_boxed_answer(text: str) -> Optional[str]:
    r"""
    Extract the answer from the first \boxed{...} occurrence.

    Args:
        text: Text potentially containing \boxed{answer}

    Returns:
        The stripped answer payload, or None when no \boxed{} is present.
    """
    found = re.search(r"\\boxed\{([^}]+)\}", text)
    return found.group(1).strip() if found else None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def extract_all_boxed_answers(text: str) -> list:
    r"""
    Extract every \boxed{...} payload from *text*, in order of appearance.

    Args:
        text: Text potentially containing multiple \boxed{answer}

    Returns:
        List of stripped answer strings (empty when none are found).
    """
    return [payload.strip() for payload in re.findall(r"\\boxed\{([^}]+)\}", text)]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def parse_number(text: str, convert_percent: bool = True) -> Optional[float]:
    r"""
    Parse a string into a float, tolerating common financial formats.

    Accepted inputs include plain numbers ("6.118", "-3.14"), percentages
    ("20.9%", "20.9 %"), fractions ("1/2", "-3/4"), thousands separators
    ("1,234.56"), parenthesized negatives ("(100)"), currency symbols, and
    LaTeX annotations such as \text{million} or \%.

    Args:
        text: String to parse.
        convert_percent: When True, percentages are divided by 100;
            when False, the % sign is merely stripped.

    Returns:
        The parsed float, or None when the string is not a number.
    """
    if text is None:
        return None

    text = text.strip()
    if not text:
        return None

    try:
        # Drop LaTeX \text{...} annotations and currency markers first.
        text = re.sub(r"\\text\{[^}]*\}", "", text)
        text = text.replace("\\$", "").replace("$", "").strip()

        # Percentages (plain or LaTeX-escaped): strip the sign, then
        # optionally rescale by 1/100.
        if "%" in text or "\\%" in text:
            stripped = text.replace("\\%", "").replace("%", "").strip()
            value = float(stripped.replace(",", ""))
            return value / 100 if convert_percent else value

        # Accounting-style negatives: "(100)" means -100.
        if text.startswith("(") and text.endswith(")"):
            text = "-" + text[1:-1]

        # Positive fractions such as "1/2" or "3/4".
        if "/" in text and not text.startswith("-"):
            try:
                return float(Fraction(text))
            except (ValueError, ZeroDivisionError):
                pass

        # Negative fractions such as "-3/4".
        if text.startswith("-") and "/" in text:
            try:
                return -float(Fraction(text[1:]))
            except (ValueError, ZeroDivisionError):
                pass

        # Plain number, possibly with thousands separators.
        return float(text.replace(",", ""))

    except (ValueError, TypeError):
        return None
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def normalize_answer(answer: str, convert_percent: bool = True) -> Tuple[Optional[float], str]:
    r"""
    Normalize an answer string to a comparable (number, string) pair.

    If the answer wraps its payload in \boxed{...}, only the first boxed
    payload is considered.

    Args:
        answer: Raw answer string (may be None).
        convert_percent: When True, percentages are divided by 100;
            when False, the % sign is merely stripped.

    Returns:
        Tuple of (parsed_number_or_None, lowercased_cleaned_string).
    """
    if answer is None:
        return None, ""

    # Prefer the \boxed{} payload when present.
    boxed = extract_boxed_answer(answer)
    candidate = (boxed if boxed else answer).strip()

    return parse_number(candidate, convert_percent), candidate.lower()
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def extract_numbers_from_multi_value(text: str) -> list:
    """
    Extract all numbers from a comma/semicolon separated string.

    Handles formats like "2022: 0.933, 2023: 0.930" or "0.933, 0.931, 0.930";
    the optional year labels are discarded, only the numeric values remain.
    """
    return [value for _, value in _split_multi_value(text)]
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def _split_multi_value(text: str) -> list:
    """
    Extract (key, number) pairs from a comma/semicolon separated string.

    Returns list of (key, float) tuples. Key is a year string like "2022"
    (or a range like "2022 to 2023") if found, otherwise None. Parts that
    do not contain a parseable number are silently dropped.
    """
    # Split by comma or semicolon (with optional LaTeX spacing like \; or \ )
    parts = re.split(r'[,;]\s*|\\[;,]\s*', text)
    results = []
    for part in parts:
        # Strip LaTeX whitespace commands (\ , \;, \,)
        part = re.sub(r'\\[;, ]', ' ', part).strip()
        if not part:
            continue
        # Try to extract a year label (e.g. "2022:", "2022 to 2023:", "2022→2023:")
        # Normalize \rightarrow and similar arrows to "to" before matching,
        # so ranges written either way produce the same key.
        part_normalized = re.sub(r'\\rightarrow|→|->|−>', ' to ', part)
        year_match = re.search(r'(20\d{2}(?:\s*to\s*20\d{2})?)', part_normalized)
        key = year_match.group(1) if year_match else None
        # Remove label prefix like "2022:" or "2022:\" — everything up to
        # the first colon (plus an optional trailing backslash) is the label.
        cleaned = re.sub(r'^[^:]*:\s*\\?\s*', '', part)
        num = parse_number(cleaned)
        if num is not None:
            results.append((key, num))
    return results
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def compare_single_values(pred_num: Optional[float], truth_num: Optional[float],
                          pred_str: str, truth_str: str,
                          tolerance: float = 0.01, max_absolute_diff: float = 1.0) -> bool:
    """
    Compare one predicted value against one ground-truth value.

    Numeric vs numeric uses a dual tolerance: the prediction must pass
    BOTH the relative (``tolerance``) and absolute (``max_absolute_diff``)
    checks. A zero ground truth accepts predictions within 1e-3 absolute.
    A numeric value never matches a non-numeric one; two non-numbers fall
    back to exact string equality.
    """
    pred_is_number = pred_num is not None
    truth_is_number = truth_num is not None

    if pred_is_number and truth_is_number:
        if truth_num == 0:
            # Relative error is undefined at zero — use a tight absolute band.
            return abs(pred_num) < 0.001

        gap = abs(pred_num - truth_num)
        return gap / abs(truth_num) <= tolerance and gap <= max_absolute_diff

    # Exactly one side numeric → mismatch.
    if pred_is_number != truth_is_number:
        return False

    # Neither side numeric → exact string comparison.
    return pred_str == truth_str
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def compute_reward(predicted: str, ground_truth: str, tolerance: float = 0.01, max_absolute_diff: float = 1.0) -> float:
    r"""
    Compute reward based on answer correctness.

    Uses fuzzy numerical matching with BOTH relative and absolute tolerance
    checks. A prediction is correct only if it passes BOTH conditions.

    Handles multiple values (e.g., ground truth with multiple \boxed{} values),
    and falls back to a "percentage points" interpretation (4.5% vs 4.5)
    when the scaled comparison fails.

    Args:
        predicted: The predicted answer from the agent
        ground_truth: The expected correct answer
        tolerance: Relative tolerance for numerical comparison (default 1%)
        max_absolute_diff: Maximum absolute difference allowed (default 1.0)

    Returns:
        1.0 if correct, 0.0 if incorrect
    """
    # Check for multiple boxed answers in ground truth
    truth_boxed = extract_all_boxed_answers(ground_truth)

    if len(truth_boxed) > 1:
        # Multiple ground truth values - split prediction by comma/semicolon
        pred_values = re.split(r'[,;]\s*', predicted.strip())

        if len(pred_values) != len(truth_boxed):
            return 0.0  # Different number of values

        # Compare each pair positionally (order matters in this branch)
        for pred_val, truth_val in zip(pred_values, truth_boxed):
            # Strip year/label prefix (e.g. "2024: -4" -> "-4")
            pred_val_cleaned = re.sub(r'^[^:]*:\s*', '', pred_val) if ':' in pred_val else pred_val
            pred_num, pred_str = normalize_answer(pred_val_cleaned)
            truth_num, truth_str = normalize_answer(truth_val)

            if not compare_single_values(pred_num, truth_num, pred_str, truth_str, tolerance, max_absolute_diff):
                # Fallback: try without % conversion (for percentage points like "4.5%" vs "4.5")
                pred_num_no_pct, _ = normalize_answer(pred_val, convert_percent=False)
                if not compare_single_values(pred_num_no_pct, truth_num, pred_str, truth_str, tolerance, max_absolute_diff):
                    return 0.0

        return 1.0  # All values matched

    # Single value comparison
    pred_num, pred_str = normalize_answer(predicted)
    truth_num, truth_str = normalize_answer(ground_truth)

    if compare_single_values(pred_num, truth_num, pred_str, truth_str, tolerance, max_absolute_diff):
        return 1.0

    # Percentage-points fallback for the single-value case as well.
    pred_num_no_pct, _ = normalize_answer(predicted, convert_percent=False)
    if compare_single_values(pred_num_no_pct, truth_num, pred_str, truth_str, tolerance, max_absolute_diff):
        return 1.0

    # Fallback: multi-value inside single \boxed{} (only if truth didn't parse as single number)
    if len(truth_boxed) == 1 and truth_num is None:
        truth_pairs = _split_multi_value(truth_boxed[0])
        pred_pairs = _split_multi_value(predicted)
        if len(truth_pairs) > 1 and len(pred_pairs) == len(truth_pairs):
            # If both sides have year keys, match by key (order-independent)
            truth_keys = {k for k, _ in truth_pairs if k is not None}
            pred_keys = {k for k, _ in pred_pairs if k is not None}
            if truth_keys and pred_keys and truth_keys == pred_keys:
                truth_map = {k: v for k, v in truth_pairs}
                pred_map = {k: v for k, v in pred_pairs}
                for key in truth_map:
                    p, t = pred_map[key], truth_map[key]
                    abs_diff = abs(p - t)
                    # rel_err mirrors compare_single_values but inline, with
                    # an exact-zero special case for t == 0.
                    rel_err = abs_diff / abs(t) if t != 0 else (0 if p == 0 else float('inf'))
                    if not (rel_err <= tolerance and abs_diff <= max_absolute_diff):
                        return 0.0
                return 1.0

            # Otherwise fall back to positional matching
            for (_, p), (_, t) in zip(pred_pairs, truth_pairs):
                abs_diff = abs(p - t)
                rel_err = abs_diff / abs(t) if t != 0 else (0 if p == 0 else float('inf'))
                if not (rel_err <= tolerance and abs_diff <= max_absolute_diff):
                    return 0.0
            return 1.0

    return 0.0
|
server/tools.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# envs/finqa_env/server/tools.py
|
| 2 |
+
"""
|
| 3 |
+
Tool implementations for the FinQA environment.
|
| 4 |
+
|
| 5 |
+
Ported from FinQABenchmark with simplifications:
|
| 6 |
+
- Removed LangChain dependencies
|
| 7 |
+
- Added submit_answer tool for episode termination
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import json
|
| 11 |
+
import os
|
| 12 |
+
import re
|
| 13 |
+
import sqlite3
|
| 14 |
+
from typing import Any, Dict, List, Tuple
|
| 15 |
+
|
| 16 |
+
import pandas as pd
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class FinQATools:
    """
    Tool implementations for financial QA tasks.

    Tables live on disk under ``{data_path}/input_companies/{company}/*.json``
    with shared metadata in ``tables_cleaned_all_companies.json``; SQL queries
    run against an in-memory SQLite database populated from the table JSON.

    Args:
        data_path: Path to the data directory containing benchmark_questions/ and input_companies/
    """

    def __init__(self, data_path: str):
        self.data_path = data_path
        self.companies_path = os.path.join(data_path, "input_companies")
        # Metadata is loaded lazily on first access (see `tables_cleaned`).
        self._tables_cleaned = None

    @property
    def tables_cleaned(self) -> Dict:
        """Lazy load the cleaned tables metadata (cached after first read)."""
        if self._tables_cleaned is None:
            tables_path = os.path.join(self.companies_path, "tables_cleaned_all_companies.json")
            with open(tables_path, "r") as f:
                self._tables_cleaned = json.load(f)
        return self._tables_cleaned

    def get_available_companies(self) -> List[str]:
        """Get list of available company names (subdirectories of input_companies/)."""
        return [
            d for d in os.listdir(self.companies_path)
            if os.path.isdir(os.path.join(self.companies_path, d))
        ]

    def execute_tool(self, tool_name: str, tool_args: Dict[str, Any]) -> Tuple[str, bool]:
        """
        Execute a tool by name and return its result.

        Args:
            tool_name: Name of the tool to execute
            tool_args: Arguments for the tool

        Returns:
            Tuple of (result_string, is_final_answer); is_final_answer is
            True only for submit_answer.
        """
        if tool_name == "get_descriptions":
            return self.get_descriptions(**tool_args), False
        elif tool_name == "get_table_info":
            return self.get_table_info(**tool_args), False
        elif tool_name == "sql_query":
            return self.sql_query(**tool_args), False
        elif tool_name == "submit_answer":
            return self.submit_answer(**tool_args), True
        else:
            return f"Error: Unknown tool '{tool_name}'", False

    def get_descriptions(self, company_name: str) -> str:
        """
        Get a list of available table names for a company.

        Args:
            company_name: The name of the company

        Returns:
            JSON list of table names, or an error string listing available
            companies when the company directory does not exist.
        """
        company_path = os.path.join(self.companies_path, company_name)

        if not os.path.isdir(company_path):
            available = self.get_available_companies()
            return f"Error: '{company_name}' not found. Available companies: {available}"

        # Each *.json file in the company directory is one table.
        tables = []
        for f in os.listdir(company_path):
            if f.endswith(".json"):
                tables.append(f.replace(".json", ""))

        return json.dumps(tables)

    def get_table_info(self, company_name: str, table_name: str) -> str:
        """
        Get table metadata: description, columns, types, unique values.

        Args:
            company_name: The name of the company
            table_name: The name of the table

        Returns:
            JSON string with table metadata (description, columns, dtypes,
            unique values), or an error string when company/table is unknown.
        """
        company_path = os.path.join(self.companies_path, company_name)

        if not os.path.isdir(company_path):
            available = self.get_available_companies()
            return f"Error: '{company_name}' not found. Available companies: {available}"

        # Clean table name (remove .json or .txt if present)
        cleaned_table_name = table_name.replace(".json", "").replace(".txt", "")
        table_key = f"{company_name}/{cleaned_table_name}"

        if table_key not in self.tables_cleaned:
            return f"Error: Table '{table_name}' not found for company '{company_name}'"

        table_info = self.tables_cleaned[table_key].copy()

        # Load the actual table to get column info
        cleaned_table = pd.DataFrame(json.loads(table_info["table"]))

        # Identify all-numeric columns: listing their unique values would be
        # noise, so they are excluded from the unique-values hint below.
        cols_to_drop = []
        for col in cleaned_table.columns.tolist()[1:]:  # Skip first (label) column
            vals = cleaned_table[col].tolist()[1:]
            cleaned_vals = [
                "".join(char for char in str(x) if char.isalnum()).strip()
                for x in vals
            ]
            all_numeric = all(
                v.isnumeric() or len(v) == 0 for v in cleaned_vals
            )
            if all_numeric:
                cols_to_drop.append(col)

        table_info["column_dtypes"] = {
            col: str(cleaned_table[col].dtype)
            for col in cleaned_table.columns.tolist()
        }

        # Only show unique values for non-numeric columns
        cleaned_table_filtered = cleaned_table.drop(cols_to_drop, axis=1)
        table_info["unique_vals_per_col"] = {
            col: list(cleaned_table_filtered[col].unique())
            for col in cleaned_table_filtered.columns.tolist()
        }

        # Remove the raw table data from the response to keep it compact.
        del table_info["table"]

        return json.dumps(table_info, indent=0).replace("\n", "")

    def sql_query(self, company_name: str, table_name: str, query: str) -> str:
        """
        Execute a SQL query on a table. Select * not allowed (too inefficient).

        Filters are required to query: WHERE, HAVING, IN, NOT IN, EXISTS, NOT EXISTS, ANY, SOME, ALL, LIKE, NOT LIKE, BETWEEN, NOT BETWEEN, IS NULL, IS NOT NULL, CASE, FILTER.

        Args:
            company_name: The name of the company
            table_name: The name of the table
            query: SQL query to execute (must include filters)

        Returns:
            JSON string with query results (records orientation), or an
            error string on validation/executions failures.
        """
        # Validate query has filters (prevent full table scans)
        if "select *" in query.lower():
            return "Error: SELECT * is not allowed (too inefficient)"

        sql_filters = [
            "WHERE", "HAVING", "IN", "NOT IN", "EXISTS", "NOT EXISTS",
            "ANY", "SOME", "ALL", "LIKE", "NOT LIKE", "BETWEEN",
            "NOT BETWEEN", "IS NULL", "IS NOT NULL", "CASE", "FILTER"
        ]

        query_upper = re.sub(r"(\r|\n|\t)+", " ", query).upper()
        # Word-boundary-ish match that also rejects bracketed identifiers
        # like [IN] so column names cannot satisfy the filter requirement.
        pattern = r"(?<!\w|\[)(" + "|".join([re.escape(f) for f in sql_filters]) + r")(?!\w|\])"

        has_filter = (
            any(f" {filt} " in query_upper for filt in sql_filters) or
            len(re.findall(pattern, query_upper)) > 0
        )

        if not has_filter:
            return "Error: Query must include filters (WHERE, HAVING, etc.)"

        # Clean table name
        cleaned_table_name = table_name.replace(".txt", "").replace(".json", "")
        table_path = os.path.join(self.companies_path, company_name, f"{cleaned_table_name}.json")

        if not os.path.isfile(table_path):
            return f"Error: Table file not found at {table_path}"

        # Load table into an in-memory SQLite DB and execute the query.
        # The connection is closed in `finally` so it cannot leak when
        # read_json/to_sql/read_sql_query raises.
        conn = sqlite3.connect(":memory:")
        try:
            df = pd.read_json(table_path)
            df.to_sql(cleaned_table_name, conn, index=False, if_exists="replace")
            result = pd.read_sql_query(query, conn)
            return result.to_json(orient="records")
        except Exception as e:
            return f"Error executing query: {str(e)}"
        finally:
            conn.close()

    def submit_answer(self, answer: str) -> str:
        """
        Submit a final answer for the question.

        This is a pure acknowledgement; reward computation happens in the
        environment's step handler, not here.

        Args:
            answer: The final answer to submit

        Returns:
            Confirmation message
        """
        return f"Answer submitted: {answer}"
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|