Spaces:
Sleeping
Sleeping
Mohammed Altaf commited on
Commit ·
8ab6a5f
0
Parent(s):
first commit
Browse files- .gitignore +18 -0
- .python-version +1 -0
- README.md +143 -0
- __init__.py +15 -0
- baseline.py +187 -0
- client.py +63 -0
- datasets/sales.csv +0 -0
- models.py +58 -0
- openenv.yaml +8 -0
- pyproject.toml +18 -0
- server/Dockerfile +58 -0
- server/__init__.py +0 -0
- server/app.py +23 -0
- server/data_analysis_env.py +270 -0
- tasks/__init__.py +14 -0
- tasks/base_task.py +62 -0
- tasks/task_easy.py +55 -0
- tasks/task_hard.py +111 -0
- tasks/task_medium.py +85 -0
- uv.lock +0 -0
.gitignore
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python-generated files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[oc]
|
| 4 |
+
build/
|
| 5 |
+
dist/
|
| 6 |
+
wheels/
|
| 7 |
+
*.egg-info
|
| 8 |
+
|
| 9 |
+
# Virtual environments
|
| 10 |
+
.venv
|
| 11 |
+
OpenEnv/
|
| 12 |
+
*.ipynb
|
| 13 |
+
personal/
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# avoid claude stuff
|
| 17 |
+
CLAUDE.md
|
| 18 |
+
.claude
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.13
|
README.md
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Data Analysis Agent Environment
|
| 2 |
+
|
| 3 |
+
An OpenEnv-compliant RL environment for training and evaluating data analysis agents. Agents execute pandas code against a business dataset to answer analytical questions, graded by deterministic programmatic graders.
|
| 4 |
+
|
| 5 |
+
## Motivation
|
| 6 |
+
|
| 7 |
+
Data analysis is a universal real-world task. Every business needs analysts who can query datasets, compute metrics, and extract insights. This environment lets RL agents practice that exact workflow — explore a dataset with code, then submit a precise answer — with automatic scoring.
|
| 8 |
+
|
| 9 |
+
## Action & Observation Spaces
|
| 10 |
+
|
| 11 |
+
### Action (`DataAction`)
|
| 12 |
+
|
| 13 |
+
| Field | Type | Description |
|
| 14 |
+
|---|---|---|
|
| 15 |
+
| `action_type` | `"execute_code"` or `"submit_answer"` | What the agent wants to do |
|
| 16 |
+
| `code` | `str` (optional) | Python/pandas code to execute |
|
| 17 |
+
| `answer` | `str` (optional) | Final answer to submit for grading |
|
| 18 |
+
|
| 19 |
+
### Observation (`DataObservation`)
|
| 20 |
+
|
| 21 |
+
| Field | Type | Description |
|
| 22 |
+
|---|---|---|
|
| 23 |
+
| `output` | `str` | Stdout from code execution or environment messages |
|
| 24 |
+
| `success` | `bool` | Whether the action succeeded |
|
| 25 |
+
| `error` | `str` (optional) | Error message if action failed |
|
| 26 |
+
| `task_description` | `str` | The question to answer (set on reset) |
|
| 27 |
+
| `dataset_info` | `str` | Dataset schema summary (set on reset) |
|
| 28 |
+
| `done` | `bool` | Whether the episode is over |
|
| 29 |
+
| `reward` | `float` | Step reward |
|
| 30 |
+
|
| 31 |
+
### State (`DataState`)
|
| 32 |
+
|
| 33 |
+
| Field | Type | Description |
|
| 34 |
+
|---|---|---|
|
| 35 |
+
| `episode_id` | `str` | Unique episode identifier |
|
| 36 |
+
| `step_count` | `int` | Current step number |
|
| 37 |
+
| `task_id` | `int` | Active task (1, 2, or 3) |
|
| 38 |
+
| `answer_submitted` | `bool` | Whether final answer was submitted |
|
| 39 |
+
| `final_score` | `float` | Graded score after submission |
|
| 40 |
+
|
| 41 |
+
## Tasks
|
| 42 |
+
|
| 43 |
+
All tasks use a synthetic e-commerce dataset (~2000 orders) with columns: `order_id`, `customer_id`, `product_name`, `category`, `quantity`, `unit_price`, `total_price`, `order_date`, `city`, `country`.
|
| 44 |
+
|
| 45 |
+
### Task 1 — Easy: Top Revenue Category
|
| 46 |
+
- **Question**: What is the top-selling product category by total revenue?
|
| 47 |
+
- **Grading**: Exact match (case-insensitive) → 1.0 or 0.0
|
| 48 |
+
- **Expected difficulty**: Single groupby + sum + argmax
|
| 49 |
+
|
| 50 |
+
### Task 2 — Medium: City Revenue Share
|
| 51 |
+
- **Question**: Which city generates the most revenue? What percentage of total revenue does it represent?
|
| 52 |
+
- **Grading**: 0.5 for correct city + 0.5 for percentage within ±0.1%
|
| 53 |
+
- **Expected difficulty**: Groupby + percentage calculation + formatting
|
| 54 |
+
|
| 55 |
+
### Task 3 — Hard: Repeat Customer Cohort Analysis
|
| 56 |
+
- **Question**: How many unique customers ordered in both January and December? Compare their average order value to all other customers.
|
| 57 |
+
- **Grading**: 0.33 per correct field (count, cohort AOV, other AOV)
|
| 58 |
+
- **Expected difficulty**: Temporal filtering, set intersection, conditional aggregation
|
| 59 |
+
|
| 60 |
+
## Reward Function
|
| 61 |
+
|
| 62 |
+
| Event | Reward |
|
| 63 |
+
|---|---|
|
| 64 |
+
| Successful code execution | +0.05 |
|
| 65 |
+
| Code execution error | -0.05 |
|
| 66 |
+
| Final answer (graded) | 0.0 — 1.0 based on task grader |
|
| 67 |
+
| Max steps (20) exceeded | 0.0 |
|
| 68 |
+
|
| 69 |
+
## Setup & Usage
|
| 70 |
+
|
| 71 |
+
### Prerequisites
|
| 72 |
+
- Python 3.13+
|
| 73 |
+
- [uv](https://docs.astral.sh/uv/) package manager
|
| 74 |
+
|
| 75 |
+
### Install
|
| 76 |
+
```bash
|
| 77 |
+
uv sync
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
### Run the server
|
| 81 |
+
```bash
|
| 82 |
+
uv run uvicorn server.app:app --host 0.0.0.0 --port 8000
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
### Run the baseline
|
| 86 |
+
```bash
|
| 87 |
+
OPENAI_API_KEY=sk-... uv run python baseline.py
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
### Docker
|
| 91 |
+
```bash
|
| 92 |
+
docker build -t data-analysis-env -f server/Dockerfile .
|
| 93 |
+
docker run -p 8000:8000 data-analysis-env
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
### Client usage (Python)
|
| 97 |
+
```python
|
| 98 |
+
from client import DataAnalysisClient
|
| 99 |
+
from models import DataAction
|
| 100 |
+
|
| 101 |
+
# Async
|
| 102 |
+
async with DataAnalysisClient(base_url="http://localhost:8000") as client:
|
| 103 |
+
result = await client.reset(task_id=1)
|
| 104 |
+
result = await client.step(DataAction(action_type="execute_code", code="print(df.head())"))
|
| 105 |
+
result = await client.step(DataAction(action_type="submit_answer", answer="Electronics"))
|
| 106 |
+
|
| 107 |
+
# Sync
|
| 108 |
+
with DataAnalysisClient(base_url="http://localhost:8000").sync() as client:
|
| 109 |
+
result = client.reset(task_id=2)
|
| 110 |
+
result = client.step(DataAction(action_type="execute_code", code="print(df.groupby('city')['total_price'].sum())"))
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
## Baseline Scores
|
| 114 |
+
|
| 115 |
+
| Task | Difficulty | gpt-4o-mini Score |
|
| 116 |
+
|---|---|---|
|
| 117 |
+
| 1 | Easy | TBD |
|
| 118 |
+
| 2 | Medium | TBD |
|
| 119 |
+
| 3 | Hard | TBD |
|
| 120 |
+
| **Average** | | **TBD** |
|
| 121 |
+
|
| 122 |
+
*(Run the baseline script to populate these scores)*
|
| 123 |
+
|
| 124 |
+
## Project Structure
|
| 125 |
+
|
| 126 |
+
```
|
| 127 |
+
├── models.py # DataAction, DataObservation, DataState
|
| 128 |
+
├── client.py # DataAnalysisClient (EnvClient subclass)
|
| 129 |
+
├── baseline.py # OpenAI baseline inference script
|
| 130 |
+
├── tasks/
|
| 131 |
+
│ ├── base_task.py # Task ABC with grade() interface
|
| 132 |
+
│ ├── task_easy.py # Task 1: Top revenue category
|
| 133 |
+
│ ├── task_medium.py # Task 2: City revenue share
|
| 134 |
+
│ └── task_hard.py # Task 3: Repeat customer cohort
|
| 135 |
+
├── datasets/
|
| 136 |
+
│ └── sales.csv # Synthetic e-commerce dataset
|
| 137 |
+
├── server/
|
| 138 |
+
│ ├── app.py # FastAPI app (create_app)
|
| 139 |
+
│ ├── data_analysis_env.py # Environment implementation
|
| 140 |
+
│ └── Dockerfile # Container build
|
| 141 |
+
├── openenv.yaml # OpenEnv spec metadata
|
| 142 |
+
└── pyproject.toml # Dependencies and project config
|
| 143 |
+
```
|
__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data Analysis Agent Environment for OpenEnv.
|
| 2 |
+
|
| 3 |
+
An RL environment where agents execute pandas code against a business
|
| 4 |
+
dataset to answer analytical questions with programmatic grading.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from client import DataAnalysisClient
|
| 8 |
+
from models import DataAction, DataObservation, DataState
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"DataAnalysisClient",
|
| 12 |
+
"DataAction",
|
| 13 |
+
"DataObservation",
|
| 14 |
+
"DataState",
|
| 15 |
+
]
|
baseline.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Baseline inference script for the Data Analysis Agent environment.
|
| 2 |
+
|
| 3 |
+
Uses the OpenAI API to run a model (gpt-4o-mini) against all 3 tasks
|
| 4 |
+
and produces reproducible baseline scores.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
OPENAI_API_KEY=sk-... uv run python baseline.py
|
| 8 |
+
OPENAI_API_KEY=sk-... uv run python baseline.py --base-url http://localhost:8000
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import json
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
|
| 16 |
+
import requests
|
| 17 |
+
from openai import OpenAI
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
SYSTEM_PROMPT = """You are a data analyst. You are given a dataset loaded as a pandas DataFrame called `df`.
|
| 21 |
+
You can execute Python/pandas code to explore the dataset and answer the question.
|
| 22 |
+
|
| 23 |
+
Rules:
|
| 24 |
+
- Use `print()` to see results of your code
|
| 25 |
+
- The DataFrame `df` is pre-loaded with pandas as `pd` and numpy as `np`
|
| 26 |
+
- When you have the answer, submit it in the exact format requested
|
| 27 |
+
- Be precise with numbers and formatting
|
| 28 |
+
|
| 29 |
+
Respond with JSON in one of these formats:
|
| 30 |
+
1. To execute code: {{"action": "execute_code", "code": "your python code here"}}
|
| 31 |
+
2. To submit answer: {{"action": "submit_answer", "answer": "your answer here"}}
|
| 32 |
+
|
| 33 |
+
Respond with ONLY the JSON, no other text."""
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def run_task(client: OpenAI, base_url: str, task_id: int, max_steps: int = 15) -> float:
|
| 37 |
+
"""Run a single task using the OpenAI API as the agent.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
client: The OpenAI client instance.
|
| 41 |
+
base_url: The environment server base URL.
|
| 42 |
+
task_id: Which task to run (1, 2, or 3).
|
| 43 |
+
max_steps: Maximum agent steps before giving up.
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
The final score for this task (0.0 to 1.0).
|
| 47 |
+
"""
|
| 48 |
+
# Reset environment with the specified task
|
| 49 |
+
reset_resp = requests.post(
|
| 50 |
+
f"{base_url}/reset",
|
| 51 |
+
json={"task_id": task_id},
|
| 52 |
+
timeout=30,
|
| 53 |
+
)
|
| 54 |
+
reset_data = reset_resp.json()
|
| 55 |
+
obs = reset_data.get("observation", reset_data)
|
| 56 |
+
|
| 57 |
+
task_desc = obs.get("task_description", "")
|
| 58 |
+
dataset_info = obs.get("dataset_info", "")
|
| 59 |
+
|
| 60 |
+
messages = [
|
| 61 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 62 |
+
{
|
| 63 |
+
"role": "user",
|
| 64 |
+
"content": f"Task: {task_desc}\n\nDataset Info:\n{dataset_info}",
|
| 65 |
+
},
|
| 66 |
+
]
|
| 67 |
+
|
| 68 |
+
print(f"\n--- Task {task_id} ---")
|
| 69 |
+
print(f"Question: {task_desc}")
|
| 70 |
+
|
| 71 |
+
for step in range(max_steps):
|
| 72 |
+
response = client.chat.completions.create(
|
| 73 |
+
model="gpt-4o-mini",
|
| 74 |
+
messages=messages,
|
| 75 |
+
temperature=0.0,
|
| 76 |
+
)
|
| 77 |
+
assistant_msg = response.choices[0].message.content.strip()
|
| 78 |
+
|
| 79 |
+
# Parse the agent's JSON response
|
| 80 |
+
try:
|
| 81 |
+
# Handle markdown code blocks if present
|
| 82 |
+
if assistant_msg.startswith("```"):
|
| 83 |
+
assistant_msg = assistant_msg.split("```")[1]
|
| 84 |
+
if assistant_msg.startswith("json"):
|
| 85 |
+
assistant_msg = assistant_msg[4:]
|
| 86 |
+
assistant_msg = assistant_msg.strip()
|
| 87 |
+
action = json.loads(assistant_msg)
|
| 88 |
+
except json.JSONDecodeError:
|
| 89 |
+
messages.append({"role": "assistant", "content": assistant_msg})
|
| 90 |
+
messages.append({
|
| 91 |
+
"role": "user",
|
| 92 |
+
"content": "Invalid JSON. Please respond with valid JSON only.",
|
| 93 |
+
})
|
| 94 |
+
continue
|
| 95 |
+
|
| 96 |
+
action_type = action.get("action", "")
|
| 97 |
+
|
| 98 |
+
if action_type == "execute_code":
|
| 99 |
+
# Send code execution to environment
|
| 100 |
+
step_resp = requests.post(
|
| 101 |
+
f"{base_url}/step",
|
| 102 |
+
json={
|
| 103 |
+
"action_type": "execute_code",
|
| 104 |
+
"code": action.get("code", ""),
|
| 105 |
+
},
|
| 106 |
+
timeout=30,
|
| 107 |
+
)
|
| 108 |
+
step_data = step_resp.json()
|
| 109 |
+
step_obs = step_data.get("observation", step_data)
|
| 110 |
+
|
| 111 |
+
output = step_obs.get("output", "")
|
| 112 |
+
error = step_obs.get("error", "")
|
| 113 |
+
result_text = f"Output: {output}" if not error else f"Error: {error}"
|
| 114 |
+
print(f" Step {step + 1}: execute_code -> {result_text[:100]}")
|
| 115 |
+
|
| 116 |
+
messages.append({"role": "assistant", "content": assistant_msg})
|
| 117 |
+
messages.append({"role": "user", "content": result_text})
|
| 118 |
+
|
| 119 |
+
elif action_type == "submit_answer":
|
| 120 |
+
# Submit final answer
|
| 121 |
+
step_resp = requests.post(
|
| 122 |
+
f"{base_url}/step",
|
| 123 |
+
json={
|
| 124 |
+
"action_type": "submit_answer",
|
| 125 |
+
"answer": action.get("answer", ""),
|
| 126 |
+
},
|
| 127 |
+
timeout=30,
|
| 128 |
+
)
|
| 129 |
+
step_data = step_resp.json()
|
| 130 |
+
step_obs = step_data.get("observation", step_data)
|
| 131 |
+
|
| 132 |
+
score = step_obs.get("metadata", {}).get("score", 0.0)
|
| 133 |
+
print(f" Step {step + 1}: submit_answer -> '{action.get('answer', '')}'")
|
| 134 |
+
print(f" Score: {score:.2f}")
|
| 135 |
+
return score
|
| 136 |
+
else:
|
| 137 |
+
messages.append({"role": "assistant", "content": assistant_msg})
|
| 138 |
+
messages.append({
|
| 139 |
+
"role": "user",
|
| 140 |
+
"content": f"Unknown action '{action_type}'. Use 'execute_code' or 'submit_answer'.",
|
| 141 |
+
})
|
| 142 |
+
|
| 143 |
+
print(" Max steps reached without submitting an answer.")
|
| 144 |
+
return 0.0
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def main():
|
| 148 |
+
"""Run baseline inference across all 3 tasks and report scores."""
|
| 149 |
+
parser = argparse.ArgumentParser(description="Baseline inference for Data Analysis Env")
|
| 150 |
+
parser.add_argument(
|
| 151 |
+
"--base-url",
|
| 152 |
+
default="http://localhost:8000",
|
| 153 |
+
help="Environment server URL (default: http://localhost:8000)",
|
| 154 |
+
)
|
| 155 |
+
args = parser.parse_args()
|
| 156 |
+
|
| 157 |
+
api_key = os.environ.get("OPENAI_API_KEY")
|
| 158 |
+
if not api_key:
|
| 159 |
+
print("Error: OPENAI_API_KEY environment variable is required.")
|
| 160 |
+
sys.exit(1)
|
| 161 |
+
|
| 162 |
+
client = OpenAI(api_key=api_key)
|
| 163 |
+
|
| 164 |
+
print("=" * 50)
|
| 165 |
+
print("Data Analysis Agent - Baseline Inference")
|
| 166 |
+
print(f"Server: {args.base_url}")
|
| 167 |
+
print(f"Model: gpt-4o-mini")
|
| 168 |
+
print("=" * 50)
|
| 169 |
+
|
| 170 |
+
scores = {}
|
| 171 |
+
for task_id in [1, 2, 3]:
|
| 172 |
+
score = run_task(client, args.base_url, task_id)
|
| 173 |
+
scores[task_id] = score
|
| 174 |
+
|
| 175 |
+
print("\n" + "=" * 50)
|
| 176 |
+
print("RESULTS")
|
| 177 |
+
print("=" * 50)
|
| 178 |
+
difficulties = {1: "Easy", 2: "Medium", 3: "Hard"}
|
| 179 |
+
for task_id, score in scores.items():
|
| 180 |
+
print(f" Task {task_id} ({difficulties[task_id]}): {score:.2f}")
|
| 181 |
+
avg = sum(scores.values()) / len(scores)
|
| 182 |
+
print(f"\n Average Score: {avg:.2f}")
|
| 183 |
+
print("=" * 50)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
if __name__ == "__main__":
|
| 187 |
+
main()
|
client.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Client for the Data Analysis Agent environment.
|
| 2 |
+
|
| 3 |
+
Provides a typed async/sync client for interacting with the
|
| 4 |
+
data analysis environment server over HTTP/WebSocket.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from openenv.core.env_client import EnvClient
|
| 8 |
+
from openenv.core.client_types import StepResult
|
| 9 |
+
|
| 10 |
+
from models import DataAction, DataObservation, DataState
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class DataAnalysisClient(EnvClient[DataAction, DataObservation, DataState]):
|
| 14 |
+
"""Client for interacting with the Data Analysis environment server.
|
| 15 |
+
|
| 16 |
+
Supports both async and sync usage patterns:
|
| 17 |
+
- Async: ``async with DataAnalysisClient(base_url=...) as client:``
|
| 18 |
+
- Sync: ``with DataAnalysisClient(base_url=...).sync() as client:``
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def _step_payload(self, action: DataAction) -> dict:
|
| 22 |
+
"""Convert a DataAction into a JSON-serializable payload.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
action: The action to send to the server.
|
| 26 |
+
|
| 27 |
+
Returns:
|
| 28 |
+
A dictionary representation of the action.
|
| 29 |
+
"""
|
| 30 |
+
payload = {"action_type": action.action_type}
|
| 31 |
+
if action.code is not None:
|
| 32 |
+
payload["code"] = action.code
|
| 33 |
+
if action.answer is not None:
|
| 34 |
+
payload["answer"] = action.answer
|
| 35 |
+
return payload
|
| 36 |
+
|
| 37 |
+
def _parse_result(self, payload: dict) -> StepResult[DataObservation]:
|
| 38 |
+
"""Parse the server's JSON response into a StepResult.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
payload: The raw JSON response from the server.
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
A StepResult containing the parsed observation, reward, and done flag.
|
| 45 |
+
"""
|
| 46 |
+
obs_data = payload.get("observation", payload)
|
| 47 |
+
obs = DataObservation(**obs_data)
|
| 48 |
+
return StepResult(
|
| 49 |
+
observation=obs,
|
| 50 |
+
reward=payload.get("reward", obs.reward),
|
| 51 |
+
done=payload.get("done", obs.done),
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
def _parse_state(self, payload: dict) -> DataState:
|
| 55 |
+
"""Parse the server's state response into a DataState.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
payload: The raw JSON state response from the server.
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
A DataState object reflecting the current episode state.
|
| 62 |
+
"""
|
| 63 |
+
return DataState(**payload)
|
datasets/sales.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pydantic models for the Data Analysis Agent environment.
|
| 2 |
+
|
| 3 |
+
Defines the action, observation, and state types used for communication
|
| 4 |
+
between the RL agent and the environment server.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from typing import Literal, Optional
|
| 8 |
+
|
| 9 |
+
from openenv.core.env_server import Action, Observation, State
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DataAction(Action):
|
| 13 |
+
"""Agent action for the data analysis environment.
|
| 14 |
+
|
| 15 |
+
The agent can either execute pandas code against the loaded dataset
|
| 16 |
+
or submit a final answer to be graded.
|
| 17 |
+
|
| 18 |
+
Attributes:
|
| 19 |
+
action_type: Whether to execute code or submit an answer.
|
| 20 |
+
code: Python/pandas code to execute (required when action_type is "execute_code").
|
| 21 |
+
answer: Final answer string (required when action_type is "submit_answer").
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
action_type: Literal["execute_code", "submit_answer"]
|
| 25 |
+
code: Optional[str] = None
|
| 26 |
+
answer: Optional[str] = None
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class DataObservation(Observation):
|
| 30 |
+
"""Observation returned after each step or reset.
|
| 31 |
+
|
| 32 |
+
Attributes:
|
| 33 |
+
output: String output from code execution or environment messages.
|
| 34 |
+
success: Whether the last action executed without errors.
|
| 35 |
+
error: Error message if the last action failed.
|
| 36 |
+
task_description: The task question, populated on reset.
|
| 37 |
+
dataset_info: Column names and dtypes summary, populated on reset.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
output: str = ""
|
| 41 |
+
success: bool = True
|
| 42 |
+
error: Optional[str] = None
|
| 43 |
+
task_description: str = ""
|
| 44 |
+
dataset_info: str = ""
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class DataState(State):
|
| 48 |
+
"""Episode state for the data analysis environment.
|
| 49 |
+
|
| 50 |
+
Attributes:
|
| 51 |
+
task_id: The current task being evaluated (1, 2, or 3).
|
| 52 |
+
answer_submitted: Whether the agent has submitted a final answer.
|
| 53 |
+
final_score: The graded score after answer submission (0.0 to 1.0).
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
task_id: int = 1
|
| 57 |
+
answer_submitted: bool = False
|
| 58 |
+
final_score: float = 0.0
|
openenv.yaml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: data_analysis_env
|
| 3 |
+
version: "0.1.0"
|
| 4 |
+
description: "RL environment for training data analysis agents on business datasets"
|
| 5 |
+
type: space
|
| 6 |
+
runtime: fastapi
|
| 7 |
+
app: server.app:app
|
| 8 |
+
port: 8000
|
pyproject.toml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "openenv-data-analysis-env"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "RL environment for training data analysis agents on business datasets"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.13"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"openenv-core>=0.2.3",
|
| 9 |
+
"fastapi>=0.115.0",
|
| 10 |
+
"pydantic>=2.0.0",
|
| 11 |
+
"uvicorn>=0.24.0",
|
| 12 |
+
"pandas>=2.0.0",
|
| 13 |
+
"numpy>=1.24.0",
|
| 14 |
+
"openai>=1.0.0",
|
| 15 |
+
]
|
| 16 |
+
|
| 17 |
+
[project.scripts]
|
| 18 |
+
server = "server.app:main"
|
server/Dockerfile
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Multi-stage build for the Data Analysis Agent environment
|
| 2 |
+
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 3 |
+
FROM ${BASE_IMAGE} AS builder
|
| 4 |
+
|
| 5 |
+
WORKDIR /app
|
| 6 |
+
|
| 7 |
+
# Copy environment code
|
| 8 |
+
COPY .. /app/env
|
| 9 |
+
|
| 10 |
+
WORKDIR /app/env
|
| 11 |
+
|
| 12 |
+
# Ensure uv is available
|
| 13 |
+
RUN if ! command -v uv >/dev/null 2>&1; then \
|
| 14 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 15 |
+
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 16 |
+
mv /root/.local/bin/uvx /usr/local/bin/uvx; \
|
| 17 |
+
fi
|
| 18 |
+
|
| 19 |
+
# Install git for build-time dependencies
|
| 20 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 21 |
+
git \
|
| 22 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 23 |
+
|
| 24 |
+
# Install dependencies with cache
|
| 25 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 26 |
+
if [ -f uv.lock ]; then \
|
| 27 |
+
uv sync --frozen --no-install-project --no-editable; \
|
| 28 |
+
else \
|
| 29 |
+
uv sync --no-install-project --no-editable; \
|
| 30 |
+
fi
|
| 31 |
+
|
| 32 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 33 |
+
if [ -f uv.lock ]; then \
|
| 34 |
+
uv sync --frozen --no-editable; \
|
| 35 |
+
else \
|
| 36 |
+
uv sync --no-editable; \
|
| 37 |
+
fi
|
| 38 |
+
|
| 39 |
+
# Final runtime stage
|
| 40 |
+
FROM ${BASE_IMAGE}
|
| 41 |
+
|
| 42 |
+
WORKDIR /app
|
| 43 |
+
|
| 44 |
+
# Copy virtual environment and code
|
| 45 |
+
COPY --from=builder /app/env/.venv /app/.venv
|
| 46 |
+
COPY --from=builder /app/env /app/env
|
| 47 |
+
|
| 48 |
+
ENV PATH="/app/.venv/bin:$PATH"
|
| 49 |
+
ENV PYTHONPATH="/app/env:$PYTHONPATH"
|
| 50 |
+
|
| 51 |
+
# Health check
|
| 52 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 53 |
+
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
|
| 54 |
+
|
| 55 |
+
EXPOSE 8000
|
| 56 |
+
|
| 57 |
+
# Run server
|
| 58 |
+
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
server/__init__.py
ADDED
|
File without changes
|
server/app.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI application for the Data Analysis Agent environment.
|
| 2 |
+
|
| 3 |
+
Creates the OpenEnv-compliant HTTP/WebSocket server that wraps
|
| 4 |
+
the DataAnalysisEnv environment.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from openenv.core.env_server import create_app
|
| 8 |
+
|
| 9 |
+
from models import DataAction, DataObservation
|
| 10 |
+
from server.data_analysis_env import DataAnalysisEnv
|
| 11 |
+
|
| 12 |
+
app = create_app(DataAnalysisEnv, DataAction, DataObservation, env_name="data_analysis_env")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def main():
|
| 16 |
+
"""Run the environment server with uvicorn."""
|
| 17 |
+
import uvicorn
|
| 18 |
+
|
| 19 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
if __name__ == "__main__":
|
| 23 |
+
main()
|
server/data_analysis_env.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data Analysis Agent environment implementation.
|
| 2 |
+
|
| 3 |
+
Provides an RL environment where an agent executes pandas code against
|
| 4 |
+
a business dataset to answer analytical questions. Each episode presents
|
| 5 |
+
a task with a programmatic grader that scores performance 0.0-1.0.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import io
|
| 9 |
+
import sys
|
| 10 |
+
import uuid
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Any, Optional
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import pandas as pd
|
| 16 |
+
|
| 17 |
+
from openenv.core.env_server import Environment
|
| 18 |
+
|
| 19 |
+
from models import DataAction, DataObservation, DataState
|
| 20 |
+
from tasks import TASKS
|
| 21 |
+
|
| 22 |
+
DATASET_PATH = Path(__file__).resolve().parent.parent / "datasets" / "sales.csv"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class DataAnalysisEnv(Environment):
|
| 26 |
+
"""Environment for training data analysis agents on business datasets.
|
| 27 |
+
|
| 28 |
+
The agent receives a task question and can execute pandas code against
|
| 29 |
+
a pre-loaded DataFrame. The episode ends when the agent submits an answer
|
| 30 |
+
or exceeds the maximum number of steps.
|
| 31 |
+
|
| 32 |
+
Attributes:
|
| 33 |
+
MAX_STEPS: Maximum steps before forced episode termination.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
MAX_STEPS = 20
|
| 37 |
+
|
| 38 |
+
def __init__(self):
|
| 39 |
+
"""Initialize the environment with default state."""
|
| 40 |
+
super().__init__()
|
| 41 |
+
self._source_df = pd.read_csv(DATASET_PATH)
|
| 42 |
+
self._df = self._source_df.copy()
|
| 43 |
+
self._state = DataState()
|
| 44 |
+
self._task = None
|
| 45 |
+
self._exec_namespace = {}
|
| 46 |
+
|
| 47 |
+
def _build_namespace(self) -> dict:
|
| 48 |
+
"""Build a restricted execution namespace for agent code.
|
| 49 |
+
|
| 50 |
+
The namespace includes only pandas, numpy, and the dataset copy.
|
| 51 |
+
Dangerous builtins like open, exec, eval, and __import__ are removed.
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
A dictionary to use as the globals for exec().
|
| 55 |
+
"""
|
| 56 |
+
safe_builtins = {
|
| 57 |
+
k: v for k, v in __builtins__.items()
|
| 58 |
+
if k not in ("open", "exec", "eval", "__import__", "compile", "exit", "quit")
|
| 59 |
+
} if isinstance(__builtins__, dict) else {
|
| 60 |
+
k: getattr(__builtins__, k) for k in dir(__builtins__)
|
| 61 |
+
if k not in ("open", "exec", "eval", "__import__", "compile", "exit", "quit")
|
| 62 |
+
and not k.startswith("_")
|
| 63 |
+
}
|
| 64 |
+
return {
|
| 65 |
+
"__builtins__": safe_builtins,
|
| 66 |
+
"df": self._df.copy(),
|
| 67 |
+
"pd": pd,
|
| 68 |
+
"np": np,
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
def _dataset_info(self) -> str:
|
| 72 |
+
"""Generate a summary of the dataset schema for the agent.
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
A string describing column names, dtypes, row count, and a sample.
|
| 76 |
+
"""
|
| 77 |
+
buf = io.StringIO()
|
| 78 |
+
self._df.info(buf=buf)
|
| 79 |
+
info_str = buf.getvalue()
|
| 80 |
+
sample = self._df.head(3).to_string()
|
| 81 |
+
return f"Dataset shape: {self._df.shape}\n\n{info_str}\nSample rows:\n{sample}"
|
| 82 |
+
|
| 83 |
+
def reset(
|
| 84 |
+
self,
|
| 85 |
+
seed: Optional[int] = None,
|
| 86 |
+
episode_id: Optional[str] = None,
|
| 87 |
+
**kwargs: Any,
|
| 88 |
+
) -> DataObservation:
|
| 89 |
+
"""Reset the environment for a new episode.
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
seed: Optional random seed (unused, kept for interface compliance).
|
| 93 |
+
episode_id: Optional episode identifier; generated if not provided.
|
| 94 |
+
**kwargs: Additional keyword arguments. Supports 'task_id' (int, 1-3).
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
An initial observation with the task description and dataset info.
|
| 98 |
+
"""
|
| 99 |
+
task_id = kwargs.get("task_id", 1)
|
| 100 |
+
eid = episode_id or str(uuid.uuid4())
|
| 101 |
+
|
| 102 |
+
self._df = self._source_df.copy()
|
| 103 |
+
self._state = DataState(episode_id=eid, step_count=0, task_id=task_id)
|
| 104 |
+
self._exec_namespace = self._build_namespace()
|
| 105 |
+
|
| 106 |
+
task_cls = TASKS.get(task_id)
|
| 107 |
+
if task_cls is None:
|
| 108 |
+
return DataObservation(
|
| 109 |
+
done=True,
|
| 110 |
+
reward=0.0,
|
| 111 |
+
success=False,
|
| 112 |
+
error=f"Invalid task_id: {task_id}. Must be 1, 2, or 3.",
|
| 113 |
+
)
|
| 114 |
+
self._task = task_cls(self._df)
|
| 115 |
+
|
| 116 |
+
return DataObservation(
|
| 117 |
+
done=False,
|
| 118 |
+
reward=0.0,
|
| 119 |
+
output="Environment ready. Use 'execute_code' actions to explore the dataset, then 'submit_answer' with your result.",
|
| 120 |
+
task_description=self._task.description,
|
| 121 |
+
dataset_info=self._dataset_info(),
|
| 122 |
+
metadata={"task_id": task_id, "difficulty": self._task.difficulty},
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
def step(
|
| 126 |
+
self,
|
| 127 |
+
action: DataAction,
|
| 128 |
+
timeout_s: Optional[float] = None,
|
| 129 |
+
**kwargs: Any,
|
| 130 |
+
) -> DataObservation:
|
| 131 |
+
"""Execute one step in the environment.
|
| 132 |
+
|
| 133 |
+
Handles two action types:
|
| 134 |
+
- execute_code: runs pandas code in a sandboxed namespace
|
| 135 |
+
- submit_answer: grades the agent's final answer and ends the episode
|
| 136 |
+
|
| 137 |
+
Args:
|
| 138 |
+
action: The agent's action (execute_code or submit_answer).
|
| 139 |
+
timeout_s: Optional timeout in seconds (unused).
|
| 140 |
+
**kwargs: Additional keyword arguments.
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
An observation with execution output, reward, and done flag.
|
| 144 |
+
"""
|
| 145 |
+
self._state.step_count += 1
|
| 146 |
+
|
| 147 |
+
if self._state.answer_submitted:
|
| 148 |
+
return DataObservation(
|
| 149 |
+
done=True,
|
| 150 |
+
reward=0.0,
|
| 151 |
+
output="Episode is already finished. Call reset() to start a new one.",
|
| 152 |
+
success=False,
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
# Check max steps
|
| 156 |
+
if self._state.step_count >= self.MAX_STEPS and action.action_type != "submit_answer":
|
| 157 |
+
self._state.answer_submitted = True
|
| 158 |
+
return DataObservation(
|
| 159 |
+
done=True,
|
| 160 |
+
reward=0.0,
|
| 161 |
+
output=f"Maximum steps ({self.MAX_STEPS}) exceeded without submitting an answer.",
|
| 162 |
+
success=False,
|
| 163 |
+
metadata={"reason": "max_steps_exceeded"},
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
if action.action_type == "execute_code":
|
| 167 |
+
return self._handle_execute_code(action)
|
| 168 |
+
elif action.action_type == "submit_answer":
|
| 169 |
+
return self._handle_submit_answer(action)
|
| 170 |
+
else:
|
| 171 |
+
return DataObservation(
|
| 172 |
+
done=False,
|
| 173 |
+
reward=-0.05,
|
| 174 |
+
success=False,
|
| 175 |
+
error=f"Unknown action_type: {action.action_type}",
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
def _handle_execute_code(self, action: DataAction) -> DataObservation:
|
| 179 |
+
"""Execute pandas code in the sandboxed namespace.
|
| 180 |
+
|
| 181 |
+
Args:
|
| 182 |
+
action: The action containing the code to execute.
|
| 183 |
+
|
| 184 |
+
Returns:
|
| 185 |
+
An observation with stdout output or error message.
|
| 186 |
+
"""
|
| 187 |
+
if not action.code:
|
| 188 |
+
return DataObservation(
|
| 189 |
+
done=False,
|
| 190 |
+
reward=-0.05,
|
| 191 |
+
success=False,
|
| 192 |
+
error="No code provided for execute_code action.",
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
stdout_capture = io.StringIO()
|
| 196 |
+
old_stdout = sys.stdout
|
| 197 |
+
try:
|
| 198 |
+
sys.stdout = stdout_capture
|
| 199 |
+
exec(action.code, self._exec_namespace)
|
| 200 |
+
sys.stdout = old_stdout
|
| 201 |
+
output = stdout_capture.getvalue()
|
| 202 |
+
|
| 203 |
+
# If code produced no print output, try to get the last expression value
|
| 204 |
+
if not output.strip():
|
| 205 |
+
try:
|
| 206 |
+
result = eval(action.code.strip().split("\n")[-1], self._exec_namespace)
|
| 207 |
+
if result is not None:
|
| 208 |
+
output = str(result)
|
| 209 |
+
except Exception:
|
| 210 |
+
output = "(Code executed successfully with no output)"
|
| 211 |
+
|
| 212 |
+
return DataObservation(
|
| 213 |
+
done=False,
|
| 214 |
+
reward=0.05,
|
| 215 |
+
output=output[:5000],
|
| 216 |
+
success=True,
|
| 217 |
+
metadata={"steps_remaining": self.MAX_STEPS - self._state.step_count},
|
| 218 |
+
)
|
| 219 |
+
except Exception as e:
|
| 220 |
+
sys.stdout = old_stdout
|
| 221 |
+
return DataObservation(
|
| 222 |
+
done=False,
|
| 223 |
+
reward=-0.05,
|
| 224 |
+
success=False,
|
| 225 |
+
error=f"{type(e).__name__}: {e}",
|
| 226 |
+
output="",
|
| 227 |
+
metadata={"steps_remaining": self.MAX_STEPS - self._state.step_count},
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
def _handle_submit_answer(self, action: DataAction) -> DataObservation:
|
| 231 |
+
"""Grade the agent's submitted answer and end the episode.
|
| 232 |
+
|
| 233 |
+
Args:
|
| 234 |
+
action: The action containing the answer to grade.
|
| 235 |
+
|
| 236 |
+
Returns:
|
| 237 |
+
An observation with the final score and done=True.
|
| 238 |
+
"""
|
| 239 |
+
if not action.answer:
|
| 240 |
+
return DataObservation(
|
| 241 |
+
done=False,
|
| 242 |
+
reward=-0.05,
|
| 243 |
+
success=False,
|
| 244 |
+
error="No answer provided for submit_answer action.",
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
self._state.answer_submitted = True
|
| 248 |
+
score = self._task.grade(action.answer)
|
| 249 |
+
self._state.final_score = score
|
| 250 |
+
|
| 251 |
+
return DataObservation(
|
| 252 |
+
done=True,
|
| 253 |
+
reward=score,
|
| 254 |
+
output=f"Answer submitted. Score: {score:.2f}/1.00",
|
| 255 |
+
success=True,
|
| 256 |
+
metadata={
|
| 257 |
+
"score": score,
|
| 258 |
+
"expected_answer": self._task.expected_answer(),
|
| 259 |
+
"submitted_answer": action.answer,
|
| 260 |
+
},
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
@property
|
| 264 |
+
def state(self) -> DataState:
|
| 265 |
+
"""Return the current episode state.
|
| 266 |
+
|
| 267 |
+
Returns:
|
| 268 |
+
The current DataState with episode_id, step_count, task_id, etc.
|
| 269 |
+
"""
|
| 270 |
+
return self._state
|
tasks/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Task definitions for the Data Analysis Agent environment."""
|
| 2 |
+
|
| 3 |
+
from tasks.base_task import BaseTask
|
| 4 |
+
from tasks.task_easy import TopRevenueCategoryTask
|
| 5 |
+
from tasks.task_medium import CityRevenueShareTask
|
| 6 |
+
from tasks.task_hard import RepeatCustomerCohortTask
|
| 7 |
+
|
| 8 |
+
TASKS = {
|
| 9 |
+
1: TopRevenueCategoryTask,
|
| 10 |
+
2: CityRevenueShareTask,
|
| 11 |
+
3: RepeatCustomerCohortTask,
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
__all__ = ["BaseTask", "TASKS", "TopRevenueCategoryTask", "CityRevenueShareTask", "RepeatCustomerCohortTask"]
|
tasks/base_task.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Abstract base class for data analysis tasks.
|
| 2 |
+
|
| 3 |
+
Each task defines a question, computes the expected answer from the dataset,
|
| 4 |
+
and provides a grader that scores agent responses from 0.0 to 1.0.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from abc import ABC, abstractmethod
|
| 8 |
+
|
| 9 |
+
import pandas as pd
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class BaseTask(ABC):
|
| 13 |
+
"""Base class for all data analysis tasks.
|
| 14 |
+
|
| 15 |
+
Subclasses must implement the question, compute the expected answer
|
| 16 |
+
from the dataset, and provide a grading function.
|
| 17 |
+
|
| 18 |
+
Attributes:
|
| 19 |
+
df: The pandas DataFrame containing the dataset.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(self, df: pd.DataFrame):
|
| 23 |
+
"""Initialize the task with a dataset.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
df: The pandas DataFrame to analyze.
|
| 27 |
+
"""
|
| 28 |
+
self.df = df
|
| 29 |
+
|
| 30 |
+
@property
|
| 31 |
+
@abstractmethod
|
| 32 |
+
def task_id(self) -> int:
|
| 33 |
+
"""Return the unique task identifier."""
|
| 34 |
+
|
| 35 |
+
@property
|
| 36 |
+
@abstractmethod
|
| 37 |
+
def difficulty(self) -> str:
|
| 38 |
+
"""Return the difficulty level: 'easy', 'medium', or 'hard'."""
|
| 39 |
+
|
| 40 |
+
@property
|
| 41 |
+
@abstractmethod
|
| 42 |
+
def description(self) -> str:
|
| 43 |
+
"""Return the task question shown to the agent."""
|
| 44 |
+
|
| 45 |
+
@abstractmethod
|
| 46 |
+
def expected_answer(self) -> str:
|
| 47 |
+
"""Compute and return the ground-truth answer from the dataset.
|
| 48 |
+
|
| 49 |
+
Returns:
|
| 50 |
+
The expected answer as a formatted string.
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
@abstractmethod
|
| 54 |
+
def grade(self, answer: str) -> float:
|
| 55 |
+
"""Grade the agent's submitted answer.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
answer: The agent's submitted answer string.
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
A score between 0.0 and 1.0.
|
| 62 |
+
"""
|
tasks/task_easy.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Task 1 (Easy): Identify the top-selling product category by total revenue.
|
| 2 |
+
|
| 3 |
+
Requires a single groupby + sum + idxmax operation.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
+
from tasks.base_task import BaseTask
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TopRevenueCategoryTask(BaseTask):
|
| 12 |
+
"""Easy task: find the product category with the highest total revenue.
|
| 13 |
+
|
| 14 |
+
The agent must group the dataset by category, sum the total_price column,
|
| 15 |
+
and identify which category has the highest revenue.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
@property
|
| 19 |
+
def task_id(self) -> int:
|
| 20 |
+
"""Return the task identifier."""
|
| 21 |
+
return 1
|
| 22 |
+
|
| 23 |
+
@property
|
| 24 |
+
def difficulty(self) -> str:
|
| 25 |
+
"""Return the difficulty level."""
|
| 26 |
+
return "easy"
|
| 27 |
+
|
| 28 |
+
@property
|
| 29 |
+
def description(self) -> str:
|
| 30 |
+
"""Return the task question."""
|
| 31 |
+
return (
|
| 32 |
+
"What is the top-selling product category by total revenue? "
|
| 33 |
+
"Submit just the category name as your answer."
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
def expected_answer(self) -> str:
|
| 37 |
+
"""Compute the top revenue category from the dataset.
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
The name of the category with the highest total_price sum.
|
| 41 |
+
"""
|
| 42 |
+
return self.df.groupby("category")["total_price"].sum().idxmax()
|
| 43 |
+
|
| 44 |
+
def grade(self, answer: str) -> float:
|
| 45 |
+
"""Grade the answer by case-insensitive string match.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
answer: The agent's submitted category name.
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
1.0 if the answer matches the expected category, 0.0 otherwise.
|
| 52 |
+
"""
|
| 53 |
+
expected = self.expected_answer().strip().lower()
|
| 54 |
+
submitted = answer.strip().lower()
|
| 55 |
+
return 1.0 if submitted == expected else 0.0
|
tasks/task_hard.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Task 3 (Hard): Analyze repeat customers who ordered in both January and December.
|
| 2 |
+
|
| 3 |
+
Requires temporal filtering, set intersection, and conditional aggregation.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import re
|
| 7 |
+
|
| 8 |
+
import pandas as pd
|
| 9 |
+
|
| 10 |
+
from tasks.base_task import BaseTask
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class RepeatCustomerCohortTask(BaseTask):
|
| 14 |
+
"""Hard task: find customers who ordered in both January and December.
|
| 15 |
+
|
| 16 |
+
The agent must identify customers present in both months, count them,
|
| 17 |
+
and compare their average order value to all other customers.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
@property
|
| 21 |
+
def task_id(self) -> int:
|
| 22 |
+
"""Return the task identifier."""
|
| 23 |
+
return 3
|
| 24 |
+
|
| 25 |
+
@property
|
| 26 |
+
def difficulty(self) -> str:
|
| 27 |
+
"""Return the difficulty level."""
|
| 28 |
+
return "hard"
|
| 29 |
+
|
| 30 |
+
@property
|
| 31 |
+
def description(self) -> str:
|
| 32 |
+
"""Return the task question."""
|
| 33 |
+
return (
|
| 34 |
+
"How many unique customers placed orders in BOTH January and December? "
|
| 35 |
+
"What is their average order value compared to all other customers? "
|
| 36 |
+
"Submit your answer in the format: "
|
| 37 |
+
"'Cohort: N customers, Cohort AOV: $X.XX, Other AOV: $X.XX'"
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
def _compute_cohort(self) -> tuple[set, float, float]:
|
| 41 |
+
"""Compute the cohort of customers ordering in both January and December.
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
A tuple of (cohort_customer_ids, cohort_aov, other_aov).
|
| 45 |
+
"""
|
| 46 |
+
df = self.df.copy()
|
| 47 |
+
df["order_date"] = pd.to_datetime(df["order_date"])
|
| 48 |
+
jan_customers = set(df[df["order_date"].dt.month == 1]["customer_id"])
|
| 49 |
+
dec_customers = set(df[df["order_date"].dt.month == 12]["customer_id"])
|
| 50 |
+
cohort = jan_customers & dec_customers
|
| 51 |
+
|
| 52 |
+
cohort_aov = df[df["customer_id"].isin(cohort)]["total_price"].mean()
|
| 53 |
+
other_aov = df[~df["customer_id"].isin(cohort)]["total_price"].mean()
|
| 54 |
+
return cohort, round(cohort_aov, 2), round(other_aov, 2)
|
| 55 |
+
|
| 56 |
+
def expected_answer(self) -> str:
|
| 57 |
+
"""Compute the expected cohort analysis answer.
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
Formatted string like 'Cohort: 57 customers, Cohort AOV: $126.57, Other AOV: $122.94'.
|
| 61 |
+
"""
|
| 62 |
+
cohort, cohort_aov, other_aov = self._compute_cohort()
|
| 63 |
+
return f"Cohort: {len(cohort)} customers, Cohort AOV: ${cohort_aov}, Other AOV: ${other_aov}"
|
| 64 |
+
|
| 65 |
+
def grade(self, answer: str) -> float:
|
| 66 |
+
"""Grade the answer with partial credit for each of the three fields.
|
| 67 |
+
|
| 68 |
+
Scoring:
|
| 69 |
+
- 0.33 for correct customer count (exact match)
|
| 70 |
+
- 0.33 for cohort AOV within ±0.5% of expected
|
| 71 |
+
- 0.34 for other AOV within ±0.5% of expected
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
answer: The agent's submitted answer string.
|
| 75 |
+
|
| 76 |
+
Returns:
|
| 77 |
+
A score between 0.0 and 1.0.
|
| 78 |
+
"""
|
| 79 |
+
cohort, expected_cohort_aov, expected_other_aov = self._compute_cohort()
|
| 80 |
+
expected_count = len(cohort)
|
| 81 |
+
score = 0.0
|
| 82 |
+
|
| 83 |
+
# Check customer count
|
| 84 |
+
count_match = re.search(r"Cohort:\s*(\d+)\s*customers?", answer, re.IGNORECASE)
|
| 85 |
+
if count_match:
|
| 86 |
+
if int(count_match.group(1)) == expected_count:
|
| 87 |
+
score += 0.33
|
| 88 |
+
|
| 89 |
+
# Check cohort AOV
|
| 90 |
+
cohort_aov_match = re.search(r"Cohort\s+AOV:\s*\$?([\d.]+)", answer, re.IGNORECASE)
|
| 91 |
+
if cohort_aov_match:
|
| 92 |
+
try:
|
| 93 |
+
submitted = float(cohort_aov_match.group(1))
|
| 94 |
+
tolerance = expected_cohort_aov * 0.005
|
| 95 |
+
if abs(submitted - expected_cohort_aov) <= tolerance:
|
| 96 |
+
score += 0.33
|
| 97 |
+
except ValueError:
|
| 98 |
+
pass
|
| 99 |
+
|
| 100 |
+
# Check other AOV
|
| 101 |
+
other_aov_match = re.search(r"Other\s+AOV:\s*\$?([\d.]+)", answer, re.IGNORECASE)
|
| 102 |
+
if other_aov_match:
|
| 103 |
+
try:
|
| 104 |
+
submitted = float(other_aov_match.group(1))
|
| 105 |
+
tolerance = expected_other_aov * 0.005
|
| 106 |
+
if abs(submitted - expected_other_aov) <= tolerance:
|
| 107 |
+
score += 0.34
|
| 108 |
+
except ValueError:
|
| 109 |
+
pass
|
| 110 |
+
|
| 111 |
+
return score
|
tasks/task_medium.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Task 2 (Medium): Find the top revenue city and its share of total revenue.
|
| 2 |
+
|
| 3 |
+
Requires groupby + aggregation + percentage calculation + formatting.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import re
|
| 7 |
+
|
| 8 |
+
import pandas as pd
|
| 9 |
+
|
| 10 |
+
from tasks.base_task import BaseTask
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class CityRevenueShareTask(BaseTask):
|
| 14 |
+
"""Medium task: identify the city with the highest revenue and its percentage share.
|
| 15 |
+
|
| 16 |
+
The agent must group by city, compute total revenue per city,
|
| 17 |
+
find the top city, and calculate what percentage of overall revenue it represents.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
@property
|
| 21 |
+
def task_id(self) -> int:
|
| 22 |
+
"""Return the task identifier."""
|
| 23 |
+
return 2
|
| 24 |
+
|
| 25 |
+
@property
|
| 26 |
+
def difficulty(self) -> str:
|
| 27 |
+
"""Return the difficulty level."""
|
| 28 |
+
return "medium"
|
| 29 |
+
|
| 30 |
+
@property
|
| 31 |
+
def description(self) -> str:
|
| 32 |
+
"""Return the task question."""
|
| 33 |
+
return (
|
| 34 |
+
"Which city generates the most revenue? What percentage of total revenue "
|
| 35 |
+
"does it represent? Round to 2 decimal places. "
|
| 36 |
+
"Submit your answer in the format: 'City: <name>, Percentage: <X.XX>%'"
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
def expected_answer(self) -> str:
|
| 40 |
+
"""Compute the top city and its revenue share.
|
| 41 |
+
|
| 42 |
+
Returns:
|
| 43 |
+
Formatted string like 'City: London, Percentage: 10.81%'.
|
| 44 |
+
"""
|
| 45 |
+
city_rev = self.df.groupby("city")["total_price"].sum()
|
| 46 |
+
top_city = city_rev.idxmax()
|
| 47 |
+
pct = round(city_rev[top_city] / city_rev.sum() * 100, 2)
|
| 48 |
+
return f"City: {top_city}, Percentage: {pct}%"
|
| 49 |
+
|
| 50 |
+
def grade(self, answer: str) -> float:
|
| 51 |
+
"""Grade the answer with partial credit for city and percentage.
|
| 52 |
+
|
| 53 |
+
Scoring:
|
| 54 |
+
- 0.5 for correct city name (case-insensitive)
|
| 55 |
+
- 0.5 for percentage within ±0.1 of expected
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
answer: The agent's submitted answer string.
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
A score between 0.0 and 1.0.
|
| 62 |
+
"""
|
| 63 |
+
score = 0.0
|
| 64 |
+
city_rev = self.df.groupby("city")["total_price"].sum()
|
| 65 |
+
expected_city = city_rev.idxmax()
|
| 66 |
+
expected_pct = round(city_rev[expected_city] / city_rev.sum() * 100, 2)
|
| 67 |
+
|
| 68 |
+
# Check city
|
| 69 |
+
city_match = re.search(r"City:\s*([^,]+)", answer, re.IGNORECASE)
|
| 70 |
+
if city_match:
|
| 71 |
+
submitted_city = city_match.group(1).strip()
|
| 72 |
+
if submitted_city.lower() == expected_city.lower():
|
| 73 |
+
score += 0.5
|
| 74 |
+
|
| 75 |
+
# Check percentage
|
| 76 |
+
pct_match = re.search(r"Percentage:\s*([\d.]+)%?", answer, re.IGNORECASE)
|
| 77 |
+
if pct_match:
|
| 78 |
+
try:
|
| 79 |
+
submitted_pct = float(pct_match.group(1))
|
| 80 |
+
if abs(submitted_pct - expected_pct) <= 0.1:
|
| 81 |
+
score += 0.5
|
| 82 |
+
except ValueError:
|
| 83 |
+
pass
|
| 84 |
+
|
| 85 |
+
return score
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|