| |
| """ |
| FastAPI server for the FinQA environment with rich Gradio web UI. |
| |
| Environment Variables: |
| FINQA_DATA_PATH: Path to data directory (default: /app/env/data) |
| FINQA_MAX_STEPS: Maximum tool calls per episode (default: 50) |
| FINQA_TASK: Task name (default: finqa) |
| """ |
|
|
| import json |
| import os |
| from typing import Any, Dict |
|
|
| import gradio as gr |
| from pydantic import field_validator |
|
|
| from openenv.core.env_server.http_server import create_app |
| from openenv.core.env_server.mcp_types import CallToolAction, CallToolObservation |
|
|
| from openenv.core.env_server.gradio_theme import OPENENV_GRADIO_THEME, OPENENV_GRADIO_CSS |
| from .finqa_environment import FinQAEnvironment |
|
|
# Runtime configuration, read once at import time. Each value falls back to
# the default documented in the module docstring when the variable is unset.
DATA_PATH = os.getenv("FINQA_DATA_PATH", "/app/env/data")
MAX_STEPS = int(os.getenv("FINQA_MAX_STEPS", "50"))  # env values are strings
TASK = os.getenv("FINQA_TASK", "finqa")
|
|
|
|
def _env_factory():
    """Return a fresh FinQAEnvironment built from the module-level settings.

    Every call constructs a new instance, so callers that need an isolated
    environment (one per session) can simply call this again.
    """
    settings = {
        "data_path": DATA_PATH,
        "max_steps": MAX_STEPS,
        "task": TASK,
    }
    return FinQAEnvironment(**settings)
|
|
|
|
class FinQACallToolAction(CallToolAction):
    """CallToolAction that accepts JSON strings for arguments (web UI sends strings).

    The ``arguments`` field is pre-processed before normal validation:
    a string is decoded as JSON, anything else is passed through untouched.
    """

    @field_validator("arguments", mode="before")
    @classmethod
    def parse_arguments(cls, v: Any) -> Dict[str, Any]:
        """Decode ``arguments`` when it arrives as a JSON string.

        Args:
            v: Raw value supplied for ``arguments``.

        Returns:
            ``{}`` for an empty or whitespace-only string, the decoded JSON
            value for any other string, and ``v`` itself for non-strings.

        Raises:
            json.JSONDecodeError: If a non-blank string is not valid JSON.
        """
        if not isinstance(v, str):
            return v
        # A blank form field means "no arguments". The Gradio step handler in
        # this module applies the same rule, and json.loads("") would raise.
        if not v.strip():
            return {}
        return json.loads(v)
|
|
|
|
| |
| |
| |
|
|
# Logo shown in the header of the web UI description.
SNORKEL_LOGO_URL = "https://cdn-avatars.huggingface.co/v1/production/uploads/1641327454607-608849cadf398c3b285ce95b.png"


# Markdown/HTML rendered at the top of the Gradio page. This is an f-string,
# so literal braces in the JSON examples are written as "{{" / "}}".
# NOTE(review): the "Max **50** tool calls" figure below is hard-coded, while
# the real limit comes from FINQA_MAX_STEPS — confirm they should stay in sync.
DESCRIPTION_MD = f"""\
<div style="display: flex; align-items: center; gap: 16px; margin-bottom: 8px;">
<img src="{SNORKEL_LOGO_URL}" alt="Snorkel AI" style="height: 48px; width: 48px; border-radius: 8px;">
<div>
<h1 style="margin: 0; font-size: 1.8em;">FinQA Environment</h1>
<p style="margin: 0; opacity: 0.7;">by <a href="https://snorkel.ai" target="_blank">Snorkel AI</a></p>
</div>
</div>

A financial question-answering environment for evaluating RL agents on SEC 10-K
filing data. Based on [FinQA Benchmark](https://github.com/snorkel-ai/FinQABenchmark),
built on the [OpenEnv](https://github.com/snorkel-ai/openenv) framework with
[MCP](https://modelcontextprotocol.io/) tool-use protocol.

**290 benchmark questions** from SEC 10-K filings across Alphabet, Amazon, Apple,
AT&T, Bank of America, Disney, and more.

---

## How It Works

1. **Reset** — receive a financial question and target company
2. **Explore** — discover tables, inspect schemas, run SQL queries
3. **Reason** — compute the answer from query results
4. **Submit** — the environment scores it with fuzzy numerical matching

---

## Available Tools

| Tool | Description | Parameters |
|------|-------------|------------|
| `get_descriptions` | Lists all financial data tables for a company | `company_name` |
| `get_table_info` | Returns table metadata: columns, types, and sample values | `company_name`, `table_name` |
| `sql_query` | Executes SQL on an in-memory SQLite database (filters required) | `company_name`, `table_name`, `query` |
| `submit_answer` | Submits the final answer and ends the episode | `answer` |

---

## Example Walkthrough

After clicking **Reset Environment**, you'll get a question and company. Here's how to
use each tool step by step:

**Step 1 — Discover tables** for the company:
- **Tool Name:** `get_descriptions`
- **Arguments:** `{{"company_name": "alphabet"}}`

**Step 2 — Inspect a table's schema** before querying:
- **Tool Name:** `get_table_info`
- **Arguments:** `{{"company_name": "alphabet", "table_name": "goog_AssetsAndLiabilitiesLesseeTableTextBlock"}}`

**Step 3 — Run a SQL query** (must include a filter clause):
- **Tool Name:** `sql_query`
- **Arguments:** `{{"company_name": "alphabet", "table_name": "goog_AssetsAndLiabilitiesLesseeTableTextBlock", "query": "SELECT metric, [2024] FROM goog_AssetsAndLiabilitiesLesseeTableTextBlock WHERE metric LIKE '%operating%'"}}`

**Step 4 — Submit your answer** once you've computed it:
- **Tool Name:** `submit_answer`
- **Arguments:** `{{"answer": "6.118"}}`

**Available companies:** `alphabet`, `amazon`, `apple`, `at_t`, `berkshire`, `boa`,
`boeing`, `caterpillar`, `chubb`, `citibank`, `disney`, `fedex`, `ford`, `gm`,
`gs`, `meta`, `microsoft`, `nvidia`

---

## Reward

**1.0** for correct, **0.0** for incorrect — fuzzy numerical matching with
≤ 1% relative error and ≤ 1.0 absolute difference. Supports percentages,
fractions, decimals, and `\\boxed{{}}` format. Max **50** tool calls per episode.
"""
|
|
|
|
| |
| |
| |
|
|
def _build_gradio_app(env: FinQAEnvironment) -> gr.Blocks:
    """Build a Gradio Blocks app with descriptive content and interactive form.

    Args:
        env: Environment instance that every button handler of the returned UI
            operates on; all handlers close over this single object.

    Returns:
        The assembled ``gr.Blocks`` page: the description markdown followed by
        reset/state buttons, a tool-call form, a status box and a raw JSON view.
    """

    def _tool_result_text(obs):
        """Extract the tool output from a step observation.

        Prefers ``obs.result.data``, then the joined ``text`` of the items in
        ``obs.result.content``, and finally ``obs.metadata["tool_result"]``.
        """
        text = ""
        result = getattr(obs, "result", None)
        if result is not None:
            if getattr(result, "data", None):
                text = str(result.data)
            elif getattr(result, "content", None):
                text = "\n".join(
                    item.text for item in result.content if hasattr(item, "text")
                )
        return text or obs.metadata.get("tool_result", "")

    def reset_env():
        """Start a new episode; returns (observation markdown, raw JSON, status)."""
        try:
            observation = env.reset()
            meta = observation.metadata
            summary = (
                f"**Question:** {meta.get('question', '')}"
                f"\n\n**Company:** {meta.get('company', '')}"
            )
            dumped = json.dumps(meta, indent=2, default=str)
            return summary, dumped, "Environment reset successfully."
        except Exception as exc:
            # UI boundary: report the failure in the status box instead of raising.
            return "", "", f"Error: {exc}"

    def step_env(tool_name: str, arguments: str):
        """Run one tool call; returns (observation markdown, raw JSON, status)."""
        name = tool_name.strip() if tool_name else ""
        if not name:
            return "", "", "Please enter a tool name."

        # A blank arguments box means "call the tool with no arguments".
        try:
            if arguments and arguments.strip():
                parsed_args = json.loads(arguments)
            else:
                parsed_args = {}
        except json.JSONDecodeError as exc:
            return "", "", f"Invalid JSON arguments: {exc}"

        try:
            obs = env.step(FinQACallToolAction(tool_name=name, arguments=parsed_args))

            sections = []
            if obs.reward is not None:
                sections.append(f"**Reward:** `{obs.reward}`")
            if obs.done is not None:
                sections.append(f"**Done:** `{obs.done}`")

            tool_result = _tool_result_text(obs)
            if tool_result:
                sections.append(f"**Tool Result:**\n```\n{tool_result}\n```")

            if sections:
                rendered = "\n\n".join(sections)
            else:
                rendered = "*No observation data*"

            dumped = json.dumps(
                {
                    "done": obs.done,
                    "reward": obs.reward,
                    "tool_name": getattr(obs, "tool_name", None),
                    "result": tool_result,
                    "metadata": obs.metadata,
                },
                indent=2,
                default=str,
            )
            return rendered, dumped, "Step complete."
        except Exception as exc:
            # UI boundary: report the failure in the status box instead of raising.
            return "", "", f"Error: {exc}"

    def get_state():
        """Return the environment state serialized as indented JSON text."""
        try:
            state = env.state
            if hasattr(state, "model_dump"):
                payload = state.model_dump()
            else:
                payload = state.__dict__
            return json.dumps(payload, indent=2, default=str)
        except Exception as exc:
            return f"Error: {exc}"

    with gr.Blocks(
        title="FinQA Environment",
        theme=OPENENV_GRADIO_THEME,
        css=OPENENV_GRADIO_CSS,
    ) as page:
        # Static documentation first, then the interactive section.
        gr.Markdown(DESCRIPTION_MD)
        gr.Markdown("---\n## Try It Out")

        with gr.Row():
            reset_button = gr.Button("Reset Environment", variant="primary")
            state_button = gr.Button("Get State", variant="secondary")

        observation_view = gr.Markdown(
            value="Click **Reset Environment** to start a new episode.",
            label="Current Observation",
        )

        with gr.Group():
            tool_name_box = gr.Textbox(
                label="Tool Name",
                placeholder="e.g. get_descriptions, get_table_info, sql_query, submit_answer",
            )
            arguments_box = gr.Textbox(
                label="Arguments (JSON)",
                placeholder='e.g. {"company_name": "alphabet"}',
                lines=3,
            )
            step_button = gr.Button("Step", variant="primary")

        status_box = gr.Textbox(label="Status", interactive=False)
        raw_json_view = gr.Code(
            label="Raw JSON Response",
            language="json",
            interactive=False,
        )

        # Reset and Step refresh all three output widgets; Get State only
        # replaces the raw JSON view.
        reset_button.click(
            fn=reset_env,
            outputs=[observation_view, raw_json_view, status_box],
        )
        step_button.click(
            fn=step_env,
            inputs=[tool_name_box, arguments_box],
            outputs=[observation_view, raw_json_view, status_box],
        )
        state_button.click(
            fn=get_state,
            outputs=[raw_json_view],
        )

    return page
|
|
|
|
| |
| |
| |
|
|
# HTTP application for the environment. create_app receives the factory (not
# an instance) together with the action and observation types.
app = create_app(
    _env_factory,
    FinQACallToolAction,
    CallToolObservation,
    env_name="finqa_env",
)

# The Gradio UI works on one dedicated environment created at import time.
_web_env = _env_factory()
_gradio_app = _build_gradio_app(_web_env)

# Mount the same UI under "/web" first and then under "/".
for _mount_path in ("/web", "/"):
    app = gr.mount_gradio_app(app, _gradio_app, path=_mount_path)
|
|