Spaces:
Sleeping
Sleeping
Env code
#1
by Mohammed-Altaf - opened
- .env.example +0 -21
- .gitattributes +35 -0
- .gitignore +0 -13
- .python-version +0 -1
- Dockerfile +0 -32
- README.md +0 -176
- __init__.py +0 -15
- baseline.py +0 -163
- client.py +0 -56
- datasets/sales.csv +0 -0
- datasets/store_data.db +0 -0
- helpers/__init__.py +0 -0
- helpers/constants.py +0 -9
- helpers/logging.py +0 -44
- helpers/prompts.py +0 -27
- helpers/response_parser.py +0 -181
- inference.py +0 -172
- models.py +0 -52
- openenv.yaml +0 -21
- pyproject.toml +0 -25
- server/Dockerfile +0 -58
- server/__init__.py +0 -0
- server/app.py +0 -15
- server/data_analysis_env.py +0 -296
- tasks/__init__.py +0 -29
- tasks/base_task.py +0 -51
- tasks/task_easy.py +0 -53
- tasks/task_hard.py +0 -103
- tasks/task_hard_2.py +0 -103
- tasks/task_hard_3.py +0 -107
- tasks/task_medium.py +0 -77
- tasks/task_medium_2.py +0 -88
- uv.lock +0 -0
.env.example
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
# Copy this file to .env and fill in your values.
|
| 2 |
-
# .env is gitignored — never commit actual keys.
|
| 3 |
-
#
|
| 4 |
-
# Usage:
|
| 5 |
-
# cp .env.example .env
|
| 6 |
-
# # edit .env with your values
|
| 7 |
-
# uv run python inference.py
|
| 8 |
-
|
| 9 |
-
# OpenAI-compatible LLM API endpoint
|
| 10 |
-
API_BASE_URL=https://router.huggingface.co/v1
|
| 11 |
-
|
| 12 |
-
# Model identifier
|
| 13 |
-
MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
|
| 14 |
-
|
| 15 |
-
# API key (Hugging Face token or other provider key)
|
| 16 |
-
HF_TOKEN=hf_...
|
| 17 |
-
|
| 18 |
-
# (Optional) Override the environment server URL
|
| 19 |
-
# Default is the deployed HF Space: https://mohammed-altaf-dataanalysis-env.hf.space
|
| 20 |
-
# Override for local testing:
|
| 21 |
-
# ENV_SERVER_URL=http://localhost:8000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
__pycache__/
|
| 2 |
-
*.py[oc]
|
| 3 |
-
build/
|
| 4 |
-
dist/
|
| 5 |
-
wheels/
|
| 6 |
-
*.egg-info
|
| 7 |
-
.venv
|
| 8 |
-
OpenEnv/
|
| 9 |
-
*.ipynb
|
| 10 |
-
personal/
|
| 11 |
-
.env
|
| 12 |
-
CLAUDE.md
|
| 13 |
-
.claude
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.python-version
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
3.13
|
|
|
|
|
|
Dockerfile
DELETED
|
@@ -1,32 +0,0 @@
|
|
| 1 |
-
ARG BASE_IMAGE=python:3.13-slim
|
| 2 |
-
FROM ${BASE_IMAGE}
|
| 3 |
-
|
| 4 |
-
WORKDIR /app
|
| 5 |
-
|
| 6 |
-
# Install uv
|
| 7 |
-
RUN pip install uv --no-cache-dir
|
| 8 |
-
|
| 9 |
-
# Copy project files
|
| 10 |
-
COPY pyproject.toml uv.lock* ./
|
| 11 |
-
COPY models.py client.py __init__.py baseline.py inference.py openenv.yaml ./
|
| 12 |
-
COPY server/ ./server/
|
| 13 |
-
COPY tasks/ ./tasks/
|
| 14 |
-
COPY datasets/ ./datasets/
|
| 15 |
-
COPY helpers/ ./helpers/
|
| 16 |
-
|
| 17 |
-
# Install dependencies into the uv-managed venv
|
| 18 |
-
RUN uv sync --frozen --no-dev
|
| 19 |
-
|
| 20 |
-
# Make the venv's python/pip the default so `python inference.py` works
|
| 21 |
-
# without needing `uv run` as a prefix
|
| 22 |
-
ENV PATH="/app/.venv/bin:$PATH"
|
| 23 |
-
|
| 24 |
-
# Ensure local modules (client, models, helpers, tasks) are always importable
|
| 25 |
-
# regardless of the working directory the evaluator uses
|
| 26 |
-
ENV PYTHONPATH="/app:$PYTHONPATH"
|
| 27 |
-
|
| 28 |
-
# HF Spaces runs containers as a non-root user on port 7860
|
| 29 |
-
ENV PORT=7860
|
| 30 |
-
EXPOSE 7860
|
| 31 |
-
|
| 32 |
-
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
DELETED
|
@@ -1,176 +0,0 @@
|
|
| 1 |
-
---
|
| 2 |
-
title: Data Analysis Agent Environment
|
| 3 |
-
emoji: 📊
|
| 4 |
-
colorFrom: blue
|
| 5 |
-
colorTo: indigo
|
| 6 |
-
sdk: docker
|
| 7 |
-
pinned: false
|
| 8 |
-
---
|
| 9 |
-
|
| 10 |
-
# Data Analysis Agent Environment
|
| 11 |
-
|
| 12 |
-
An OpenEnv-compliant RL environment for training and evaluating data analysis agents. Agents execute pandas code against a business dataset to answer analytical questions, graded by deterministic programmatic graders.
|
| 13 |
-
|
| 14 |
-
## Motivation
|
| 15 |
-
|
| 16 |
-
Data analysis is a universal real-world task. Every business needs analysts who can query datasets, compute metrics, and extract insights. This environment lets RL agents practice that exact workflow — explore a dataset with code, then submit a precise answer — with automatic scoring.
|
| 17 |
-
|
| 18 |
-
## Action & Observation Spaces
|
| 19 |
-
|
| 20 |
-
### Action (`DataAction`)
|
| 21 |
-
|
| 22 |
-
| Field | Type | Description |
|
| 23 |
-
|---|---|---|
|
| 24 |
-
| `action_type` | `"execute_code"` or `"submit_answer"` | What the agent wants to do |
|
| 25 |
-
| `code` | `str` (optional) | Python/pandas code to execute |
|
| 26 |
-
| `answer` | `str` (optional) | Final answer to submit for grading |
|
| 27 |
-
|
| 28 |
-
### Observation (`DataObservation`)
|
| 29 |
-
|
| 30 |
-
| Field | Type | Description |
|
| 31 |
-
|---|---|---|
|
| 32 |
-
| `output` | `str` | Stdout from code execution or environment messages |
|
| 33 |
-
| `success` | `bool` | Whether the action succeeded |
|
| 34 |
-
| `error` | `str` (optional) | Error message if action failed |
|
| 35 |
-
| `task_description` | `str` | The question to answer (set on reset) |
|
| 36 |
-
| `dataset_info` | `str` | Dataset schema summary (set on reset) |
|
| 37 |
-
| `done` | `bool` | Whether the episode is over |
|
| 38 |
-
| `reward` | `float` | Step reward |
|
| 39 |
-
|
| 40 |
-
### State (`DataState`)
|
| 41 |
-
|
| 42 |
-
| Field | Type | Description |
|
| 43 |
-
|---|---|---|
|
| 44 |
-
| `episode_id` | `str` | Unique episode identifier |
|
| 45 |
-
| `step_count` | `int` | Current step number |
|
| 46 |
-
| `task_id` | `int` | Active task (1–6) |
|
| 47 |
-
| `answer_submitted` | `bool` | Whether final answer was submitted |
|
| 48 |
-
| `final_score` | `float` | Graded score after submission |
|
| 49 |
-
|
| 50 |
-
## Tasks
|
| 51 |
-
|
| 52 |
-
Tasks use two data sources:
|
| 53 |
-
- **`df`** — synthetic e-commerce sales CSV (~2000 orders): `order_id`, `customer_id`, `product_name`, `category`, `quantity`, `unit_price`, `total_price`, `order_date`, `city`, `country`
|
| 54 |
-
- **SQLite DB** (`store_data.db`) — additional tables for cross-source tasks: `customer_profiles` (300 rows), `product_catalog` (25 rows)
|
| 55 |
-
|
| 56 |
-
### Task 1 — Easy: Top Revenue Category
|
| 57 |
-
- **Question**: What is the top-selling product category by total revenue?
|
| 58 |
-
- **Grading**: Containment match (case-insensitive) → 1.0 or 0.0
|
| 59 |
-
- **Expected difficulty**: Single groupby + sum + argmax
|
| 60 |
-
|
| 61 |
-
### Task 2 — Medium: City Revenue Share
|
| 62 |
-
- **Question**: Which city generates the most revenue? What percentage of total revenue does it represent?
|
| 63 |
-
- **Grading**: 0.5 for correct city + 0.5 for percentage within ±0.1%
|
| 64 |
-
- **Expected difficulty**: Groupby + percentage calculation + formatting
|
| 65 |
-
|
| 66 |
-
### Task 3 — Medium: Repeat Customer Cohort Analysis
|
| 67 |
-
- **Question**: How many unique customers ordered in both January and December? Compare their average order value to all other customers.
|
| 68 |
-
- **Grading**: 0.33 per correct field (count, cohort AOV, other AOV)
|
| 69 |
-
- **Expected difficulty**: Temporal filtering, set intersection, conditional aggregation
|
| 70 |
-
|
| 71 |
-
### Task 4 — Hard: Monthly Revenue Ratio
|
| 72 |
-
- **Question**: Which month had the highest vs. lowest total revenue? What is the ratio between them?
|
| 73 |
-
- **Grading**: 0.33 for best month + 0.33 for worst month + 0.34 for ratio within ±0.01
|
| 74 |
-
- **Expected difficulty**: Monthly resample/groupby, min/max comparison, ratio formatting
|
| 75 |
-
|
| 76 |
-
### Task 5 — Hard: Customer Loyalty Tier Revenue (cross-source)
|
| 77 |
-
- **Question**: Which customer loyalty tier generates the highest total revenue and what percentage does it represent?
|
| 78 |
-
- **Data**: Requires joining `df` with `customer_profiles` table from SQLite on `customer_id`
|
| 79 |
-
- **Grading**: 0.33 for tier name + 0.33 for revenue within ±0.5% + 0.34 for percentage within ±0.1
|
| 80 |
-
- **Expected difficulty**: SQLite query → pandas merge → groupby aggregation
|
| 81 |
-
|
| 82 |
-
### Task 6 — Hard: Supplier Profitability (cross-source)
|
| 83 |
-
- **Question**: Which supplier has the highest total profit? What is their average profit margin?
|
| 84 |
-
- **Data**: Requires joining `df` with `product_catalog` table from SQLite on `product_name`
|
| 85 |
-
- **Grading**: 0.33 for supplier name + 0.34 for total profit within ±0.5% + 0.33 for avg margin within ±0.1
|
| 86 |
-
- **Expected difficulty**: SQLite query → pandas merge → per-order profit/margin calculation → group aggregation
|
| 87 |
-
|
| 88 |
-
## Reward Function
|
| 89 |
-
|
| 90 |
-
| Event | Reward |
|
| 91 |
-
|---|---|
|
| 92 |
-
| Successful code execution | +0.05 |
|
| 93 |
-
| Code execution error | -0.05 |
|
| 94 |
-
| Final answer (graded) | 0.0 — 1.0 based on task grader |
|
| 95 |
-
| Max steps (20) exceeded | 0.0 |
|
| 96 |
-
|
| 97 |
-
## Setup & Usage
|
| 98 |
-
|
| 99 |
-
### Prerequisites
|
| 100 |
-
- Python 3.13+
|
| 101 |
-
- [uv](https://docs.astral.sh/uv/) package manager
|
| 102 |
-
|
| 103 |
-
### Install
|
| 104 |
-
```bash
|
| 105 |
-
uv sync
|
| 106 |
-
```
|
| 107 |
-
|
| 108 |
-
### Run the server
|
| 109 |
-
```bash
|
| 110 |
-
uv run uvicorn server.app:app --host 0.0.0.0 --port 8000
|
| 111 |
-
```
|
| 112 |
-
|
| 113 |
-
### Run the inference
|
| 114 |
-
- First export all the required env variables mentioned in the .env.example. Then run below command
|
| 115 |
-
```bash
|
| 116 |
-
uv run python inference.py
|
| 117 |
-
```
|
| 118 |
-
|
| 119 |
-
### Run the baseline
|
| 120 |
-
```bash
|
| 121 |
-
OPENAI_API_KEY=sk-... uv run python baseline.py
|
| 122 |
-
# Against a deployed HF Space:
|
| 123 |
-
OPENAI_API_KEY=sk-... uv run python baseline.py --base-url https://<your-username>-<space-name>.hf.space
|
| 124 |
-
```
|
| 125 |
-
|
| 126 |
-
### Docker (local)
|
| 127 |
-
```bash
|
| 128 |
-
docker build -t data-analysis-env .
|
| 129 |
-
docker run -p 7860:7860 data-analysis-env
|
| 130 |
-
```
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
### Client usage (Python)
|
| 134 |
-
```python
|
| 135 |
-
from client import DataAnalysisClient
|
| 136 |
-
from models import DataAction
|
| 137 |
-
|
| 138 |
-
# Async
|
| 139 |
-
async with DataAnalysisClient(base_url="http://localhost:8000") as client:
|
| 140 |
-
result = await client.reset(task_id=1)
|
| 141 |
-
result = await client.step(DataAction(action_type="execute_code", code="print(df.head())"))
|
| 142 |
-
result = await client.step(DataAction(action_type="submit_answer", answer="Electronics"))
|
| 143 |
-
|
| 144 |
-
# Sync
|
| 145 |
-
with DataAnalysisClient(base_url="http://localhost:8000").sync() as client:
|
| 146 |
-
result = client.reset(task_id=2)
|
| 147 |
-
result = client.step(DataAction(action_type="execute_code", code="print(df.groupby('city')['total_price'].sum())"))
|
| 148 |
-
```
|
| 149 |
-
|
| 150 |
-
## Project Structure
|
| 151 |
-
|
| 152 |
-
```
|
| 153 |
-
├── models.py # DataAction, DataObservation, DataState
|
| 154 |
-
├── client.py # DataAnalysisClient (EnvClient subclass)
|
| 155 |
-
├── inference.py # HF inference script (uses HF Inference API)
|
| 156 |
-
├── baseline.py # OpenAI baseline inference script
|
| 157 |
-
├── helpers/
|
| 158 |
-
│ └── response_parser.py # Robust LLM JSON response parser
|
| 159 |
-
├── tasks/
|
| 160 |
-
│ ├── base_task.py # Task ABC with grade() interface
|
| 161 |
-
│ ├── task_easy.py # Task 1 (Easy): Top revenue category
|
| 162 |
-
│ ├── task_medium.py # Task 2 (Medium): City revenue share
|
| 163 |
-
│ ├── task_medium_2.py # Task 4 (Hard): Monthly revenue ratio
|
| 164 |
-
│ ├── task_hard.py # Task 3 (Medium): Repeat customer cohort
|
| 165 |
-
│ ├── task_hard_2.py # Task 5 (Hard): Customer loyalty tier revenue
|
| 166 |
-
│ └── task_hard_3.py # Task 6 (Hard): Supplier profitability
|
| 167 |
-
├── datasets/
|
| 168 |
-
│ ├── sales.csv # Synthetic e-commerce sales dataset
|
| 169 |
-
│ └── store_data.db # SQLite DB: customer_profiles, product_catalog
|
| 170 |
-
├── server/
|
| 171 |
-
│ ├── app.py # FastAPI app entry point
|
| 172 |
-
│ └── data_analysis_env.py # Environment implementation
|
| 173 |
-
├── Dockerfile # HF Spaces Docker build (port 7860)
|
| 174 |
-
├── openenv.yaml # OpenEnv spec metadata
|
| 175 |
-
└── pyproject.toml # Dependencies and project config
|
| 176 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__init__.py
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
"""Data Analysis Agent Environment for OpenEnv.
|
| 2 |
-
|
| 3 |
-
An RL environment where agents execute pandas code against a business
|
| 4 |
-
dataset to answer analytical questions with programmatic grading.
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
from client import DataAnalysisClient
|
| 8 |
-
from models import DataAction, DataObservation, DataState
|
| 9 |
-
|
| 10 |
-
__all__ = [
|
| 11 |
-
"DataAnalysisClient",
|
| 12 |
-
"DataAction",
|
| 13 |
-
"DataObservation",
|
| 14 |
-
"DataState",
|
| 15 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
baseline.py
DELETED
|
@@ -1,163 +0,0 @@
|
|
| 1 |
-
"""Baseline inference script for the Data Analysis Agent environment.
|
| 2 |
-
|
| 3 |
-
Uses the OpenAI API to run a model (gpt-4o-mini) against all 6 tasks
|
| 4 |
-
and produces reproducible baseline scores.
|
| 5 |
-
|
| 6 |
-
The script uses DataAnalysisClient (WebSocket) because the HTTP endpoints
|
| 7 |
-
are stateless — each request gets a fresh env instance. State (namespace,
|
| 8 |
-
task, dataset) only persists within a WebSocket session.
|
| 9 |
-
|
| 10 |
-
Tasks 1-3 use only the pandas DataFrame (df). Tasks 4-6 are cross-source:
|
| 11 |
-
they also require querying a SQLite database via sqlite3.connect(db_path).
|
| 12 |
-
|
| 13 |
-
Usage:
|
| 14 |
-
OPENAI_API_KEY=sk-... uv run python baseline.py
|
| 15 |
-
OPENAI_API_KEY=sk-... uv run python baseline.py --base-url http://localhost:8000
|
| 16 |
-
"""
|
| 17 |
-
|
| 18 |
-
import argparse
|
| 19 |
-
import json
|
| 20 |
-
import os
|
| 21 |
-
import sys
|
| 22 |
-
|
| 23 |
-
from openai import OpenAI
|
| 24 |
-
|
| 25 |
-
from client import DataAnalysisClient
|
| 26 |
-
from helpers.prompts import SYSTEM_PROMPT
|
| 27 |
-
from models import DataAction
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def run_task(openai_client: OpenAI, env_client: DataAnalysisClient, task_id: int, max_steps: int = 15) -> float:
|
| 31 |
-
"""Run a single task using the OpenAI API as the agent.
|
| 32 |
-
|
| 33 |
-
Args:
|
| 34 |
-
openai_client: The OpenAI client instance.
|
| 35 |
-
env_client: The connected DataAnalysisClient (sync wrapper).
|
| 36 |
-
task_id: Which task to run (1–6).
|
| 37 |
-
max_steps: Maximum agent steps before giving up.
|
| 38 |
-
|
| 39 |
-
Returns:
|
| 40 |
-
The final score for this task (0.0 to 1.0).
|
| 41 |
-
"""
|
| 42 |
-
result = env_client.reset(task_id=task_id)
|
| 43 |
-
obs = result.observation
|
| 44 |
-
|
| 45 |
-
messages = [
|
| 46 |
-
{"role": "system", "content": SYSTEM_PROMPT},
|
| 47 |
-
{
|
| 48 |
-
"role": "user",
|
| 49 |
-
"content": f"Task: {obs.task_description}\n\nDataset Info:\n{obs.dataset_info}",
|
| 50 |
-
},
|
| 51 |
-
]
|
| 52 |
-
|
| 53 |
-
print(f"\n--- Task {task_id} ---")
|
| 54 |
-
print(f"Question: {obs.task_description}")
|
| 55 |
-
|
| 56 |
-
for step in range(max_steps):
|
| 57 |
-
response = openai_client.chat.completions.create(
|
| 58 |
-
model="gpt-4o-mini",
|
| 59 |
-
messages=messages,
|
| 60 |
-
temperature=0.0,
|
| 61 |
-
)
|
| 62 |
-
assistant_msg = response.choices[0].message.content.strip()
|
| 63 |
-
|
| 64 |
-
# Parse the agent's JSON response
|
| 65 |
-
try:
|
| 66 |
-
# Handle markdown code blocks if present
|
| 67 |
-
if assistant_msg.startswith("```"):
|
| 68 |
-
assistant_msg = assistant_msg.split("```")[1]
|
| 69 |
-
if assistant_msg.startswith("json"):
|
| 70 |
-
assistant_msg = assistant_msg[4:]
|
| 71 |
-
assistant_msg = assistant_msg.strip()
|
| 72 |
-
action = json.loads(assistant_msg)
|
| 73 |
-
except json.JSONDecodeError:
|
| 74 |
-
messages.append({"role": "assistant", "content": assistant_msg})
|
| 75 |
-
messages.append(
|
| 76 |
-
{
|
| 77 |
-
"role": "user",
|
| 78 |
-
"content": "Invalid JSON. Please respond with valid JSON only.",
|
| 79 |
-
}
|
| 80 |
-
)
|
| 81 |
-
continue
|
| 82 |
-
|
| 83 |
-
action_type = action.get("action", "")
|
| 84 |
-
|
| 85 |
-
if action_type == "execute_code":
|
| 86 |
-
result = env_client.step(DataAction(action_type="execute_code", code=action.get("code", "")))
|
| 87 |
-
obs = result.observation
|
| 88 |
-
result_text = f"Output: {obs.output}" if not obs.error else f"Error: {obs.error}"
|
| 89 |
-
print(f" Step {step + 1}: execute_code -> {result_text[:120]}")
|
| 90 |
-
messages.append({"role": "assistant", "content": assistant_msg})
|
| 91 |
-
messages.append({"role": "user", "content": result_text})
|
| 92 |
-
|
| 93 |
-
elif action_type == "submit_answer":
|
| 94 |
-
result = env_client.step(DataAction(action_type="submit_answer", answer=action.get("answer", "")))
|
| 95 |
-
obs = result.observation
|
| 96 |
-
score = obs.metadata.get("score", 0.0) if obs.metadata else result.reward
|
| 97 |
-
print(f" Step {step + 1}: submit_answer -> '{action.get('answer', '')}'")
|
| 98 |
-
print(f" Score: {score:.2f}")
|
| 99 |
-
return score
|
| 100 |
-
|
| 101 |
-
else:
|
| 102 |
-
messages.append({"role": "assistant", "content": assistant_msg})
|
| 103 |
-
messages.append(
|
| 104 |
-
{
|
| 105 |
-
"role": "user",
|
| 106 |
-
"content": f"Unknown action '{action_type}'. Use 'execute_code' or 'submit_answer'.",
|
| 107 |
-
}
|
| 108 |
-
)
|
| 109 |
-
|
| 110 |
-
print(" Max steps reached without submitting an answer.")
|
| 111 |
-
return 0.0
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
def main():
|
| 115 |
-
"""Run baseline inference across all 6 tasks and report scores."""
|
| 116 |
-
parser = argparse.ArgumentParser(description="Baseline inference for Data Analysis Env")
|
| 117 |
-
parser.add_argument(
|
| 118 |
-
"--base-url",
|
| 119 |
-
default="http://localhost:8000",
|
| 120 |
-
help="Environment server URL (default: http://localhost:8000)",
|
| 121 |
-
)
|
| 122 |
-
args = parser.parse_args()
|
| 123 |
-
|
| 124 |
-
api_key = os.environ.get("OPENAI_API_KEY")
|
| 125 |
-
if not api_key:
|
| 126 |
-
print("Error: OPENAI_API_KEY environment variable is required.")
|
| 127 |
-
sys.exit(1)
|
| 128 |
-
|
| 129 |
-
openai_client = OpenAI(api_key=api_key)
|
| 130 |
-
|
| 131 |
-
print("=" * 55)
|
| 132 |
-
print("Data Analysis Agent - Baseline Inference")
|
| 133 |
-
print(f"Server: {args.base_url}")
|
| 134 |
-
print("Model: gpt-4o-mini")
|
| 135 |
-
print("=" * 55)
|
| 136 |
-
|
| 137 |
-
scores = {}
|
| 138 |
-
difficulties = {
|
| 139 |
-
1: "Easy",
|
| 140 |
-
2: "Medium",
|
| 141 |
-
3: "Medium",
|
| 142 |
-
4: "Hard",
|
| 143 |
-
5: "Hard",
|
| 144 |
-
6: "Hard",
|
| 145 |
-
}
|
| 146 |
-
|
| 147 |
-
with DataAnalysisClient(base_url=args.base_url).sync() as env_client:
|
| 148 |
-
for task_id in [1, 2, 3, 4, 5, 6]:
|
| 149 |
-
score = run_task(openai_client, env_client, task_id)
|
| 150 |
-
scores[task_id] = score
|
| 151 |
-
|
| 152 |
-
print("\n" + "=" * 55)
|
| 153 |
-
print("RESULTS")
|
| 154 |
-
print("=" * 55)
|
| 155 |
-
for task_id, score in scores.items():
|
| 156 |
-
print(f" Task {task_id} ({difficulties[task_id]:6s}): {score:.2f}")
|
| 157 |
-
avg = sum(scores.values()) / len(scores)
|
| 158 |
-
print(f"\n Average Score: {avg:.2f}")
|
| 159 |
-
print("=" * 55)
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
if __name__ == "__main__":
|
| 163 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
client.py
DELETED
|
@@ -1,56 +0,0 @@
|
|
| 1 |
-
from models import DataAction, DataObservation, DataState
|
| 2 |
-
from openenv.core.client_types import StepResult
|
| 3 |
-
from openenv.core.env_client import EnvClient
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
class DataAnalysisClient(EnvClient[DataAction, DataObservation, DataState]):
|
| 7 |
-
"""Client for interacting with the Data Analysis environment server.
|
| 8 |
-
|
| 9 |
-
Supports both async and sync usage patterns:
|
| 10 |
-
- Async: ``async with DataAnalysisClient(base_url=...) as client:``
|
| 11 |
-
- Sync: ``with DataAnalysisClient(base_url=...).sync() as client:``
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
def _step_payload(self, action: DataAction) -> dict:
|
| 15 |
-
"""Convert a DataAction into a JSON-serializable payload.
|
| 16 |
-
|
| 17 |
-
Args:
|
| 18 |
-
action: The action to send to the server.
|
| 19 |
-
|
| 20 |
-
Returns:
|
| 21 |
-
A dictionary representation of the action.
|
| 22 |
-
"""
|
| 23 |
-
payload = {"action_type": action.action_type}
|
| 24 |
-
if action.code is not None:
|
| 25 |
-
payload["code"] = action.code
|
| 26 |
-
if action.answer is not None:
|
| 27 |
-
payload["answer"] = action.answer
|
| 28 |
-
return payload
|
| 29 |
-
|
| 30 |
-
def _parse_result(self, payload: dict) -> StepResult[DataObservation]:
|
| 31 |
-
"""Parse the server's JSON response into a StepResult.
|
| 32 |
-
|
| 33 |
-
Args:
|
| 34 |
-
payload: The raw JSON response from the server.
|
| 35 |
-
|
| 36 |
-
Returns:
|
| 37 |
-
A StepResult containing the parsed observation, reward, and done flag.
|
| 38 |
-
"""
|
| 39 |
-
obs_data = payload.get("observation", payload)
|
| 40 |
-
obs = DataObservation(**obs_data)
|
| 41 |
-
return StepResult(
|
| 42 |
-
observation=obs,
|
| 43 |
-
reward=payload.get("reward", obs.reward),
|
| 44 |
-
done=payload.get("done", obs.done),
|
| 45 |
-
)
|
| 46 |
-
|
| 47 |
-
def _parse_state(self, payload: dict) -> DataState:
|
| 48 |
-
"""Parse the server's state response into a DataState.
|
| 49 |
-
|
| 50 |
-
Args:
|
| 51 |
-
payload: The raw JSON state response from the server.
|
| 52 |
-
|
| 53 |
-
Returns:
|
| 54 |
-
A DataState object reflecting the current episode state.
|
| 55 |
-
"""
|
| 56 |
-
return DataState(**payload)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
datasets/sales.csv
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
datasets/store_data.db
DELETED
|
Binary file (24.6 kB)
|
|
|
helpers/__init__.py
DELETED
|
File without changes
|
helpers/constants.py
DELETED
|
@@ -1,9 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
|
| 3 |
-
TEMPERATURE = 0.0
|
| 4 |
-
MAX_TOKENS = 1024
|
| 5 |
-
MAX_STEPS = 15
|
| 6 |
-
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
|
| 7 |
-
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
|
| 8 |
-
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
| 9 |
-
ENV_SERVER_URL = os.getenv("ENV_SERVER_URL") or "https://mohammed-altaf-dataanalysis-env.hf.space"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
helpers/logging.py
DELETED
|
@@ -1,44 +0,0 @@
|
|
| 1 |
-
from typing import List, Optional
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
def safe_score(raw: float) -> float:
|
| 5 |
-
"""Clamp a raw score to the strictly-open interval (0.05, 0.95).
|
| 6 |
-
|
| 7 |
-
Args:
|
| 8 |
-
raw: Unclamped score value.
|
| 9 |
-
|
| 10 |
-
Returns:
|
| 11 |
-
Score guaranteed to be in [0.05, 0.95].
|
| 12 |
-
"""
|
| 13 |
-
return max(0.05, min(0.95, float(raw)))
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def log_start(task: str, env: str, model: str) -> None:
|
| 17 |
-
"""Emit the [START] line at episode begin."""
|
| 18 |
-
print(f"[START] task={task} env={env} model={model}", flush=True)
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
|
| 22 |
-
"""Emit one [STEP] line immediately after env.step() returns.
|
| 23 |
-
|
| 24 |
-
Args:
|
| 25 |
-
step: 1-based step number.
|
| 26 |
-
action: Compact single-line action label (e.g. 'execute_code').
|
| 27 |
-
reward: Step reward, formatted to 2 decimal places.
|
| 28 |
-
done: Whether the episode ended after this step.
|
| 29 |
-
error: Raw error string from the env, or None.
|
| 30 |
-
"""
|
| 31 |
-
error_val = error.replace("\n", " ") if error else "null"
|
| 32 |
-
done_val = str(done).lower()
|
| 33 |
-
print(f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", flush=True)
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
def log_end(task_id: int, score: float, steps: int) -> None:
|
| 37 |
-
"""Emit the [END] line after the episode completes.
|
| 38 |
-
|
| 39 |
-
Args:
|
| 40 |
-
task_id: The task number that just ran.
|
| 41 |
-
score: Final clamped score in [0.05, 0.95].
|
| 42 |
-
steps: Total number of steps taken.
|
| 43 |
-
"""
|
| 44 |
-
print(f"[END] task={task_id} score={score:.2f} steps={steps}", flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
helpers/prompts.py
DELETED
|
@@ -1,27 +0,0 @@
|
|
| 1 |
-
SYSTEM_PROMPT = """
|
| 2 |
-
<ROLE>
|
| 3 |
-
You are a data analyst. You have two data sources available:
|
| 4 |
-
1. `df` — a pandas DataFrame (sales CSV, pre-loaded)
|
| 5 |
-
2. A SQLite database at `db_path` — contains additional tables (e.g. customer_profiles, product_catalog)
|
| 6 |
-
</ROLE>
|
| 7 |
-
|
| 8 |
-
<RULES>
|
| 9 |
-
- Use `print()` to output results
|
| 10 |
-
- `pd`, `np`, `sqlite3`, and `db_path` are already in scope — NEVER use import statements (they will fail)
|
| 11 |
-
- `df` is a pandas DataFrame — use pandas operations on it, NEVER SQL
|
| 12 |
-
- To query the SQLite database use: `conn = sqlite3.connect(db_path)` then `pd.read_sql(query, conn)`
|
| 13 |
-
- For cross-source tasks: query SQLite for the extra data, then merge with `df` using pandas
|
| 14 |
-
- When you have the answer, submit it in the exact format requested
|
| 15 |
-
- Be precise with numbers and formatting
|
| 16 |
-
</RULES>
|
| 17 |
-
|
| 18 |
-
<RESPONSE>
|
| 19 |
-
Respond with JSON in one of these formats:
|
| 20 |
-
1. To execute code: {"action": "execute_code", "code": "your python code here"}
|
| 21 |
-
2. To submit answer: {"action": "submit_answer", "answer": "your answer here"}
|
| 22 |
-
</RESPONSE>
|
| 23 |
-
|
| 24 |
-
<NOTE>
|
| 25 |
-
Respond with ONLY the JSON, no other text.
|
| 26 |
-
</NOTE>
|
| 27 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
helpers/response_parser.py
DELETED
|
@@ -1,181 +0,0 @@
|
|
| 1 |
-
import json
|
| 2 |
-
import re
|
| 3 |
-
from typing import Any
|
| 4 |
-
|
| 5 |
-
FALLBACK_ACTION = json.dumps({"action": "submit_answer", "answer": "unknown"})
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
def _sanitize_string_value(match: re.Match) -> str:
|
| 9 |
-
"""
|
| 10 |
-
Receives a regex match of ("key": "value") and cleans only the value part.
|
| 11 |
-
Escapes unescaped newlines, tabs, carriage returns, and inner double quotes.
|
| 12 |
-
NOTE: This is the core trick LangChain uses in _replace_new_line / _custom_parser.
|
| 13 |
-
"""
|
| 14 |
-
opening = match.group(1)
|
| 15 |
-
value = match.group(2)
|
| 16 |
-
closing = match.group(3)
|
| 17 |
-
|
| 18 |
-
value = re.sub(r"\n", r"\\n", value)
|
| 19 |
-
value = re.sub(r"\r", r"\\r", value)
|
| 20 |
-
value = re.sub(r"\t", r"\\t", value)
|
| 21 |
-
value = re.sub(r'(?<!\\)"', r'\\"', value) # escape unescaped inner quotes
|
| 22 |
-
|
| 23 |
-
return opening + value + closing
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def _sanitize_all_string_values(text: str) -> str:
|
| 27 |
-
"""
|
| 28 |
-
Apply _sanitize_string_value to every JSON string value in the text.
|
| 29 |
-
Uses re.DOTALL so values that span multiple lines are handled correctly.
|
| 30 |
-
NOTE: Generalised version of LangChain's _custom_parser (which only targeted action_input).
|
| 31 |
-
"""
|
| 32 |
-
return re.sub(
|
| 33 |
-
r'("[\w]+"\s*:\s*")(.*?)(")',
|
| 34 |
-
_sanitize_string_value,
|
| 35 |
-
text,
|
| 36 |
-
flags=re.DOTALL,
|
| 37 |
-
)
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
def _preprocess(text: str) -> str:
|
| 41 |
-
"""Fix common LLM response quirks before attempting JSON parsing."""
|
| 42 |
-
|
| 43 |
-
# Strip markdown code fences (```json ... ``` or ``` ... ```)
|
| 44 |
-
match = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL)
|
| 45 |
-
if match:
|
| 46 |
-
text = match.group(1).strip()
|
| 47 |
-
|
| 48 |
-
# Double curly braces {{"k": "v"}} → {"k": "v"}
|
| 49 |
-
text = text.replace("{{", "{").replace("}}", "}")
|
| 50 |
-
text = re.sub(r"\bTrue\b", "true", text)
|
| 51 |
-
text = re.sub(r"\bFalse\b", "false", text)
|
| 52 |
-
text = re.sub(r"\bNone\b", "null", text)
|
| 53 |
-
text = re.sub(r",\s*([}\]])", r"\1", text)
|
| 54 |
-
|
| 55 |
-
# Outer single-quote wrap '{"k": "v"}' → {"k": "v"}
|
| 56 |
-
if text.startswith("'") and text.endswith("'"):
|
| 57 |
-
text = text[1:-1].replace("\\'", "'")
|
| 58 |
-
|
| 59 |
-
return text.strip()
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
def _extract_json_blob(text: str) -> str:
|
| 63 |
-
"""
|
| 64 |
-
Pull out the first {...} or [...] blob from text that has prose around it.
|
| 65 |
-
Inspired by LangChain's _json_markdown_re fallback in parse_json_markdown.
|
| 66 |
-
"""
|
| 67 |
-
match = re.search(r"(\{.*\}|\[.*\])", text, re.DOTALL)
|
| 68 |
-
return match.group(1) if match else text
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
def _parse_partial_json(s: str) -> Any:
|
| 72 |
-
"""
|
| 73 |
-
Parse JSON that may be truncated / missing closing brackets.
|
| 74 |
-
Adapted from LangChain's parse_partial_json (originally from open-interpreter).
|
| 75 |
-
Uses a stack to track open containers and closes them before parsing.
|
| 76 |
-
"""
|
| 77 |
-
s = s.strip()
|
| 78 |
-
try:
|
| 79 |
-
return json.loads(s)
|
| 80 |
-
except json.JSONDecodeError:
|
| 81 |
-
pass
|
| 82 |
-
|
| 83 |
-
stack = []
|
| 84 |
-
is_inside = False
|
| 85 |
-
position = 0
|
| 86 |
-
|
| 87 |
-
for i, char in enumerate(s):
|
| 88 |
-
if is_inside:
|
| 89 |
-
if char == '"' and s[i - 1] != "\\":
|
| 90 |
-
is_inside = False
|
| 91 |
-
else:
|
| 92 |
-
if char == '"':
|
| 93 |
-
is_inside = True
|
| 94 |
-
stack.append('"')
|
| 95 |
-
elif char in "{[":
|
| 96 |
-
stack.append(char)
|
| 97 |
-
elif char in "}]":
|
| 98 |
-
if stack and stack[-1] in "{[":
|
| 99 |
-
stack.pop()
|
| 100 |
-
position = i
|
| 101 |
-
|
| 102 |
-
completed = s[: position + 1]
|
| 103 |
-
for bracket in reversed(stack):
|
| 104 |
-
if bracket == '"':
|
| 105 |
-
completed += '"'
|
| 106 |
-
elif bracket == "{":
|
| 107 |
-
completed += "}"
|
| 108 |
-
elif bracket == "[":
|
| 109 |
-
completed += "]"
|
| 110 |
-
|
| 111 |
-
return json.loads(completed)
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
def _extract_fields_direct(text: str) -> dict:
|
| 115 |
-
"""Extract action fields using greedy regex anchored to the last closing quote.
|
| 116 |
-
|
| 117 |
-
Handles the case where the model emits unescaped double-quote characters inside
|
| 118 |
-
a "code" or "answer" value (e.g. df["col"]). The non-greedy `(.*?)` in
|
| 119 |
-
_sanitize_all_string_values stops at the *first* inner quote and corrupts the
|
| 120 |
-
output. By using a greedy `(.*)` anchored with a lookahead for the last `"}`
|
| 121 |
-
boundary we capture the full value regardless of inner quotes.
|
| 122 |
-
|
| 123 |
-
Args:
|
| 124 |
-
text: Pre-processed JSON-like string.
|
| 125 |
-
|
| 126 |
-
Returns:
|
| 127 |
-
Dict with 'action' and 'code'/'answer' keys.
|
| 128 |
-
|
| 129 |
-
Raises:
|
| 130 |
-
ValueError: If the action field cannot be found or the value cannot be
|
| 131 |
-
extracted for the detected action type.
|
| 132 |
-
"""
|
| 133 |
-
action_match = re.search(r'"action"\s*:\s*"(\w+)"', text)
|
| 134 |
-
if not action_match:
|
| 135 |
-
raise ValueError("No 'action' field found")
|
| 136 |
-
action_type = action_match.group(1)
|
| 137 |
-
|
| 138 |
-
if action_type == "execute_code":
|
| 139 |
-
m = re.search(r'"code"\s*:\s*"(.*)"(?=\s*})', text, re.DOTALL)
|
| 140 |
-
if m:
|
| 141 |
-
return {"action": "execute_code", "code": m.group(1)}
|
| 142 |
-
elif action_type == "submit_answer":
|
| 143 |
-
m = re.search(r'"answer"\s*:\s*"(.*)"(?=\s*})', text, re.DOTALL)
|
| 144 |
-
if m:
|
| 145 |
-
return {"action": "submit_answer", "answer": m.group(1)}
|
| 146 |
-
|
| 147 |
-
raise ValueError(f"Could not extract value for action_type={action_type!r}")
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
def parse_model_action(response_text: str) -> dict:
|
| 151 |
-
"""
|
| 152 |
-
Parse a raw LLM response into an action dict.
|
| 153 |
-
|
| 154 |
-
Pipeline (mirrors LangChain's JsonOutputParser internals):
|
| 155 |
-
1. _preprocess – fix markdown fences, double braces, Python literals …
|
| 156 |
-
2. _sanitize_all_string_values – escape unescaped quotes/newlines inside values
|
| 157 |
-
3. _extract_json_blob – strip surrounding prose
|
| 158 |
-
4. _parse_partial_json – close truncated JSON with a stack algorithm
|
| 159 |
-
|
| 160 |
-
Each strategy is tried independently so a failure in one doesn't block others.
|
| 161 |
-
"""
|
| 162 |
-
text = response_text.strip()
|
| 163 |
-
|
| 164 |
-
strategies = [
|
| 165 |
-
lambda t: _parse_partial_json(t),
|
| 166 |
-
lambda t: _parse_partial_json(_sanitize_all_string_values(_preprocess(t))),
|
| 167 |
-
lambda t: _parse_partial_json(_sanitize_all_string_values(_preprocess(_extract_json_blob(t)))),
|
| 168 |
-
lambda t: _parse_partial_json(_sanitize_all_string_values(_extract_json_blob(_preprocess(t)))),
|
| 169 |
-
lambda t: _parse_partial_json(_sanitize_all_string_values(t)),
|
| 170 |
-
lambda t: _extract_fields_direct(_preprocess(_extract_json_blob(t))),
|
| 171 |
-
lambda t: _extract_fields_direct(_extract_json_blob(t)),
|
| 172 |
-
]
|
| 173 |
-
|
| 174 |
-
for strategy in strategies:
|
| 175 |
-
try:
|
| 176 |
-
return strategy(text)
|
| 177 |
-
except (json.JSONDecodeError, ValueError):
|
| 178 |
-
continue
|
| 179 |
-
|
| 180 |
-
print(f"JSON Decoding Error while parsing action in response text: {response_text}")
|
| 181 |
-
return json.loads(FALLBACK_ACTION)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inference.py
DELETED
|
@@ -1,172 +0,0 @@
|
|
| 1 |
-
from typing import Any, List
|
| 2 |
-
|
| 3 |
-
try:
|
| 4 |
-
from dotenv import load_dotenv
|
| 5 |
-
|
| 6 |
-
load_dotenv()
|
| 7 |
-
except ImportError:
|
| 8 |
-
pass
|
| 9 |
-
|
| 10 |
-
from openai import OpenAI
|
| 11 |
-
|
| 12 |
-
from client import DataAnalysisClient
|
| 13 |
-
from helpers.constants import *
|
| 14 |
-
from helpers.logging import log_end, log_start, log_step, safe_score
|
| 15 |
-
from helpers.prompts import SYSTEM_PROMPT
|
| 16 |
-
from helpers.response_parser import FALLBACK_ACTION, parse_model_action
|
| 17 |
-
from models import DataAction
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def run_task(openai_client: OpenAI, env_client: Any, task_id: int) -> float:
|
| 21 |
-
"""Run a single task episode using the language model as the agent.
|
| 22 |
-
|
| 23 |
-
Args:
|
| 24 |
-
openai_client: Configured OpenAI-compatible client.
|
| 25 |
-
env_client: Connected DataAnalysisClient (sync wrapper).
|
| 26 |
-
task_id: Task to evaluate (1 - 6)
|
| 27 |
-
|
| 28 |
-
Returns:
|
| 29 |
-
Final clamped score for this task in [0.05, 0.95].
|
| 30 |
-
"""
|
| 31 |
-
try:
|
| 32 |
-
result = env_client.reset(task_id=task_id)
|
| 33 |
-
except Exception as exc:
|
| 34 |
-
print(f"[DEBUG] env reset failed: {exc}", flush=True)
|
| 35 |
-
log_start(task=str(task_id), env=ENV_SERVER_URL, model=MODEL_NAME)
|
| 36 |
-
log_end(task_id=task_id, score=safe_score(0.0), steps=0)
|
| 37 |
-
return safe_score(0.0)
|
| 38 |
-
|
| 39 |
-
obs = result.observation
|
| 40 |
-
rewards: List[float] = []
|
| 41 |
-
|
| 42 |
-
messages = [
|
| 43 |
-
{"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
|
| 44 |
-
{
|
| 45 |
-
"role": "user",
|
| 46 |
-
"content": [
|
| 47 |
-
{
|
| 48 |
-
"type": "text",
|
| 49 |
-
"text": f"Task: {obs.task_description}\n\nDataset Info:\n{obs.dataset_info}",
|
| 50 |
-
}
|
| 51 |
-
],
|
| 52 |
-
},
|
| 53 |
-
]
|
| 54 |
-
|
| 55 |
-
log_start(task=str(task_id), env=ENV_SERVER_URL, model=MODEL_NAME)
|
| 56 |
-
|
| 57 |
-
for step in range(MAX_STEPS):
|
| 58 |
-
try:
|
| 59 |
-
completion = openai_client.chat.completions.create(
|
| 60 |
-
model=MODEL_NAME,
|
| 61 |
-
messages=messages,
|
| 62 |
-
temperature=TEMPERATURE,
|
| 63 |
-
max_tokens=MAX_TOKENS,
|
| 64 |
-
stream=False,
|
| 65 |
-
)
|
| 66 |
-
response_text = completion.choices[0].message.content or ""
|
| 67 |
-
except Exception as exc:
|
| 68 |
-
print(f"[DEBUG] Model request failed: {exc}", flush=True)
|
| 69 |
-
response_text = FALLBACK_ACTION
|
| 70 |
-
|
| 71 |
-
action = parse_model_action(response_text)
|
| 72 |
-
action_type = action.get("action", "")
|
| 73 |
-
|
| 74 |
-
if action_type == "execute_code":
|
| 75 |
-
try:
|
| 76 |
-
exec_result = env_client.step(DataAction(action_type="execute_code", code=action.get("code", "")))
|
| 77 |
-
exec_obs = exec_result.observation
|
| 78 |
-
reward = exec_result.reward or 0.0
|
| 79 |
-
done = exec_result.done
|
| 80 |
-
except Exception as exc:
|
| 81 |
-
print(f"[DEBUG] env step failed: {exc}", flush=True)
|
| 82 |
-
log_step(step=step + 1, action=action_type, reward=0.0, done=False, error=str(exc))
|
| 83 |
-
rewards.append(0.0)
|
| 84 |
-
continue
|
| 85 |
-
|
| 86 |
-
rewards.append(reward)
|
| 87 |
-
error = exec_obs.error if not exec_obs.success else None
|
| 88 |
-
result_text = f"Output: {exec_obs.output}" if not exec_obs.error else f"Error: {exec_obs.error}"
|
| 89 |
-
log_step(step=step + 1, action=action_type, reward=reward, done=done, error=error)
|
| 90 |
-
|
| 91 |
-
messages.append({"role": "assistant", "content": response_text})
|
| 92 |
-
messages.append({"role": "user", "content": [{"type": "text", "text": result_text}]})
|
| 93 |
-
|
| 94 |
-
elif action_type == "submit_answer":
|
| 95 |
-
try:
|
| 96 |
-
submit_result = env_client.step(
|
| 97 |
-
DataAction(action_type="submit_answer", answer=action.get("answer", ""))
|
| 98 |
-
)
|
| 99 |
-
submit_obs = submit_result.observation
|
| 100 |
-
raw_score = float(submit_obs.metadata.get("score", 0.0) if submit_obs.metadata else submit_result.reward)
|
| 101 |
-
except Exception as exc:
|
| 102 |
-
print(f"[DEBUG] env step failed: {exc}", flush=True)
|
| 103 |
-
log_step(step=step + 1, action=action_type, reward=0.0, done=True, error=str(exc))
|
| 104 |
-
final_score = safe_score(sum(rewards) / len(rewards)) if rewards else safe_score(0.0)
|
| 105 |
-
log_end(task_id=task_id, score=final_score, steps=step + 1)
|
| 106 |
-
return final_score
|
| 107 |
-
|
| 108 |
-
clamped = safe_score(raw_score)
|
| 109 |
-
rewards.append(clamped)
|
| 110 |
-
log_step(step=step + 1, action=action_type, reward=clamped, done=True, error=None)
|
| 111 |
-
final_score = safe_score(sum(rewards) / len(rewards))
|
| 112 |
-
log_end(task_id=task_id, score=final_score, steps=step + 1)
|
| 113 |
-
return final_score
|
| 114 |
-
|
| 115 |
-
else:
|
| 116 |
-
log_step(
|
| 117 |
-
step=step + 1,
|
| 118 |
-
action=action_type or "unknown",
|
| 119 |
-
reward=0.0,
|
| 120 |
-
done=False,
|
| 121 |
-
error=f"unknown action '{action_type}'",
|
| 122 |
-
)
|
| 123 |
-
messages.append({"role": "assistant", "content": response_text})
|
| 124 |
-
messages.append(
|
| 125 |
-
{
|
| 126 |
-
"role": "user",
|
| 127 |
-
"content": [
|
| 128 |
-
{
|
| 129 |
-
"type": "text",
|
| 130 |
-
"text": f"Unknown action '{action_type}'. Use 'execute_code' or 'submit_answer'.",
|
| 131 |
-
}
|
| 132 |
-
],
|
| 133 |
-
}
|
| 134 |
-
)
|
| 135 |
-
|
| 136 |
-
# Max steps reached without submission
|
| 137 |
-
final_score = safe_score(sum(rewards) / len(rewards)) if rewards else safe_score(0.0)
|
| 138 |
-
log_end(task_id=task_id, score=final_score, steps=MAX_STEPS)
|
| 139 |
-
return final_score
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
def main():
|
| 143 |
-
"""Run inference across all 6 tasks and report scores."""
|
| 144 |
-
print("Executing Data Analysis Environment")
|
| 145 |
-
openai_client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
|
| 146 |
-
scores = {}
|
| 147 |
-
difficulties = {
|
| 148 |
-
1: "Easy_TopRevenueCategoryTask",
|
| 149 |
-
2: "Medium_CityRevenueShareTask",
|
| 150 |
-
3: "Medium_RepeatCustomerCohortTask",
|
| 151 |
-
4: "Hard_MonthlyRevenueRatioTask",
|
| 152 |
-
5: "Hard_CustomerLoyaltyRevenueTask",
|
| 153 |
-
6: "Hard_SupplierProfitabilityTask",
|
| 154 |
-
}
|
| 155 |
-
|
| 156 |
-
with DataAnalysisClient(base_url=ENV_SERVER_URL).sync() as env_client:
|
| 157 |
-
for task_id in difficulties.keys():
|
| 158 |
-
score = run_task(openai_client=openai_client, env_client=env_client, task_id=task_id)
|
| 159 |
-
scores[task_id] = score
|
| 160 |
-
|
| 161 |
-
print("\n" + "=" * 55)
|
| 162 |
-
print("RESULTS")
|
| 163 |
-
print("=" * 55)
|
| 164 |
-
for task_id, score in scores.items():
|
| 165 |
-
print(f" Task {task_id} ({difficulties[task_id]:6s}): {score:.2f}")
|
| 166 |
-
avg = sum(scores.values()) / len(scores)
|
| 167 |
-
print(f"\n Average Score : {avg:.2f}")
|
| 168 |
-
print("=" * 55)
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
if __name__ == "__main__":
|
| 172 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models.py
DELETED
|
@@ -1,52 +0,0 @@
|
|
| 1 |
-
from typing import Literal, Optional
|
| 2 |
-
|
| 3 |
-
from openenv.core.env_server import Action, Observation, State
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
class DataAction(Action):
|
| 7 |
-
"""Agent action for the data analysis environment.
|
| 8 |
-
|
| 9 |
-
The agent can either execute pandas code against the loaded dataset
|
| 10 |
-
or submit a final answer to be graded.
|
| 11 |
-
|
| 12 |
-
Attributes:
|
| 13 |
-
action_type: Whether to execute code or submit an answer.
|
| 14 |
-
code: Python/pandas code to execute (required when action_type is "execute_code").
|
| 15 |
-
answer: Final answer string (required when action_type is "submit_answer").
|
| 16 |
-
"""
|
| 17 |
-
|
| 18 |
-
action_type: Literal["execute_code", "submit_answer"]
|
| 19 |
-
code: Optional[str] = None
|
| 20 |
-
answer: Optional[str] = None
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
class DataObservation(Observation):
|
| 24 |
-
"""Observation returned after each step or reset.
|
| 25 |
-
|
| 26 |
-
Attributes:
|
| 27 |
-
output: String output from code execution or environment messages.
|
| 28 |
-
success: Whether the last action executed without errors.
|
| 29 |
-
error: Error message if the last action failed.
|
| 30 |
-
task_description: The task question, populated on reset.
|
| 31 |
-
dataset_info: Column names and dtypes summary, populated on reset.
|
| 32 |
-
"""
|
| 33 |
-
|
| 34 |
-
output: str = ""
|
| 35 |
-
success: bool = True
|
| 36 |
-
error: Optional[str] = None
|
| 37 |
-
task_description: str = ""
|
| 38 |
-
dataset_info: str = ""
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
class DataState(State):
|
| 42 |
-
"""Episode state for the data analysis environment.
|
| 43 |
-
|
| 44 |
-
Attributes:
|
| 45 |
-
task_id: The current task being evaluated (1, 2, or 3).
|
| 46 |
-
answer_submitted: Whether the agent has submitted a final answer.
|
| 47 |
-
final_score: The graded score after answer submission (0.0 to 1.0).
|
| 48 |
-
"""
|
| 49 |
-
|
| 50 |
-
task_id: int = 1
|
| 51 |
-
answer_submitted: bool = False
|
| 52 |
-
final_score: float = 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
openenv.yaml
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
spec_version: 1
|
| 2 |
-
name: data_analysis_env
|
| 3 |
-
version: "0.1.0"
|
| 4 |
-
description: "RL environment for training data analysis agents on business datasets"
|
| 5 |
-
type: space
|
| 6 |
-
runtime: fastapi
|
| 7 |
-
app: server.app:app
|
| 8 |
-
port: 8000
|
| 9 |
-
tasks:
|
| 10 |
-
- id: task_easy
|
| 11 |
-
grader: tasks.task_easy:TopRevenueCategoryTask
|
| 12 |
-
- id: task_medium
|
| 13 |
-
grader: tasks.task_medium:CityRevenueShareTask
|
| 14 |
-
- id: task_medium_2
|
| 15 |
-
grader: tasks.task_medium_2:MonthlyRevenueRatioTask
|
| 16 |
-
- id: task_hard
|
| 17 |
-
grader: tasks.task_hard:RepeatCustomerCohortTask
|
| 18 |
-
- id: task_hard_2
|
| 19 |
-
grader: tasks.task_hard_2:CustomerLoyaltyRevenueTask
|
| 20 |
-
- id: task_hard_3
|
| 21 |
-
grader: tasks.task_hard_3:SupplierProfitabilityTask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pyproject.toml
DELETED
|
@@ -1,25 +0,0 @@
|
|
| 1 |
-
[project]
|
| 2 |
-
name = "openenv-data-analysis-env"
|
| 3 |
-
version = "0.1.0"
|
| 4 |
-
description = "RL environment for training data analysis agents on business datasets"
|
| 5 |
-
readme = "README.md"
|
| 6 |
-
requires-python = ">=3.13"
|
| 7 |
-
dependencies = [
|
| 8 |
-
"openenv-core>=0.2.3",
|
| 9 |
-
"fastapi>=0.115.0",
|
| 10 |
-
"pydantic>=2.0.0",
|
| 11 |
-
"uvicorn>=0.24.0",
|
| 12 |
-
"pandas>=2.0.0",
|
| 13 |
-
"numpy>=1.24.0",
|
| 14 |
-
"openai>=1.0.0",
|
| 15 |
-
"black>=26.3.1",
|
| 16 |
-
"isort>=8.0.1",
|
| 17 |
-
"python-dotenv>=1.2.2",
|
| 18 |
-
]
|
| 19 |
-
|
| 20 |
-
[project.scripts]
|
| 21 |
-
server = "server.app:main"
|
| 22 |
-
|
| 23 |
-
[tool.black]
|
| 24 |
-
line-length = 120
|
| 25 |
-
target-version = ["py313"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
server/Dockerfile
DELETED
|
@@ -1,58 +0,0 @@
|
|
| 1 |
-
# Multi-stage build for the Data Analysis Agent environment
|
| 2 |
-
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 3 |
-
FROM ${BASE_IMAGE} AS builder
|
| 4 |
-
|
| 5 |
-
WORKDIR /app
|
| 6 |
-
|
| 7 |
-
# Copy environment code
|
| 8 |
-
COPY .. /app/env
|
| 9 |
-
|
| 10 |
-
WORKDIR /app/env
|
| 11 |
-
|
| 12 |
-
# Ensure uv is available
|
| 13 |
-
RUN if ! command -v uv >/dev/null 2>&1; then \
|
| 14 |
-
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 15 |
-
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 16 |
-
mv /root/.local/bin/uvx /usr/local/bin/uvx; \
|
| 17 |
-
fi
|
| 18 |
-
|
| 19 |
-
# Install git for build-time dependencies
|
| 20 |
-
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 21 |
-
git \
|
| 22 |
-
&& rm -rf /var/lib/apt/lists/*
|
| 23 |
-
|
| 24 |
-
# Install dependencies with cache
|
| 25 |
-
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 26 |
-
if [ -f uv.lock ]; then \
|
| 27 |
-
uv sync --frozen --no-install-project --no-editable; \
|
| 28 |
-
else \
|
| 29 |
-
uv sync --no-install-project --no-editable; \
|
| 30 |
-
fi
|
| 31 |
-
|
| 32 |
-
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 33 |
-
if [ -f uv.lock ]; then \
|
| 34 |
-
uv sync --frozen --no-editable; \
|
| 35 |
-
else \
|
| 36 |
-
uv sync --no-editable; \
|
| 37 |
-
fi
|
| 38 |
-
|
| 39 |
-
# Final runtime stage
|
| 40 |
-
FROM ${BASE_IMAGE}
|
| 41 |
-
|
| 42 |
-
WORKDIR /app
|
| 43 |
-
|
| 44 |
-
# Copy virtual environment and code
|
| 45 |
-
COPY --from=builder /app/env/.venv /app/.venv
|
| 46 |
-
COPY --from=builder /app/env /app/env
|
| 47 |
-
|
| 48 |
-
ENV PATH="/app/.venv/bin:$PATH"
|
| 49 |
-
ENV PYTHONPATH="/app/env:$PYTHONPATH"
|
| 50 |
-
|
| 51 |
-
# Health check — uses /docs which FastAPI always exposes
|
| 52 |
-
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 53 |
-
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/docs')" || exit 1
|
| 54 |
-
|
| 55 |
-
EXPOSE 8000
|
| 56 |
-
|
| 57 |
-
# Run server
|
| 58 |
-
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
server/__init__.py
DELETED
|
File without changes
|
server/app.py
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
from models import DataAction, DataObservation
|
| 2 |
-
from openenv.core.env_server import create_app
|
| 3 |
-
from server.data_analysis_env import DataAnalysisEnv
|
| 4 |
-
|
| 5 |
-
app = create_app(DataAnalysisEnv, DataAction, DataObservation, env_name="data_analysis_env", max_concurrent_envs=3)
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
def main():
|
| 9 |
-
import uvicorn
|
| 10 |
-
|
| 11 |
-
uvicorn.run(app, host="0.0.0.0", port=8000)
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
if __name__ == "__main__":
|
| 15 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
server/data_analysis_env.py
DELETED
|
@@ -1,296 +0,0 @@
|
|
| 1 |
-
import io
|
| 2 |
-
import sqlite3
|
| 3 |
-
import sys
|
| 4 |
-
import uuid
|
| 5 |
-
from pathlib import Path
|
| 6 |
-
from typing import Any, Optional
|
| 7 |
-
|
| 8 |
-
import numpy as np
|
| 9 |
-
import pandas as pd
|
| 10 |
-
|
| 11 |
-
from models import DataAction, DataObservation, DataState
|
| 12 |
-
from openenv.core.env_server import Environment
|
| 13 |
-
from tasks import TASKS
|
| 14 |
-
|
| 15 |
-
DATASET_PATH = Path(__file__).resolve().parent.parent / "datasets" / "sales.csv"
|
| 16 |
-
DB_PATH = Path(__file__).resolve().parent.parent / "datasets" / "store_data.db"
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
class DataAnalysisEnv(Environment):
|
| 20 |
-
"""Environment for training data analysis agents on business datasets.
|
| 21 |
-
|
| 22 |
-
The agent receives a task question and can execute pandas code against
|
| 23 |
-
a pre-loaded DataFrame. The episode ends when the agent submits an answer
|
| 24 |
-
or exceeds the maximum number of steps.
|
| 25 |
-
|
| 26 |
-
Attributes:
|
| 27 |
-
MAX_STEPS: Maximum steps before forced episode termination.
|
| 28 |
-
"""
|
| 29 |
-
|
| 30 |
-
MAX_STEPS = 20
|
| 31 |
-
SUPPORTS_CONCURRENT_SESSIONS = True
|
| 32 |
-
|
| 33 |
-
def __init__(self):
|
| 34 |
-
super().__init__()
|
| 35 |
-
self._source_df = pd.read_csv(DATASET_PATH)
|
| 36 |
-
self._df = self._source_df.copy()
|
| 37 |
-
self._state = DataState()
|
| 38 |
-
self._task = None
|
| 39 |
-
self._exec_namespace = {}
|
| 40 |
-
|
| 41 |
-
def _build_namespace(self) -> dict:
|
| 42 |
-
"""Build a restricted execution namespace for agent code.
|
| 43 |
-
|
| 44 |
-
The namespace includes only pandas, numpy, and the dataset copy.
|
| 45 |
-
Dangerous builtins like open, exec, eval, and __import__ are removed.
|
| 46 |
-
|
| 47 |
-
Returns:
|
| 48 |
-
A dictionary to use as the globals for exec().
|
| 49 |
-
"""
|
| 50 |
-
safe_builtins = (
|
| 51 |
-
{
|
| 52 |
-
k: v
|
| 53 |
-
for k, v in __builtins__.items()
|
| 54 |
-
if k not in ("open", "exec", "eval", "__import__", "compile", "exit", "quit")
|
| 55 |
-
}
|
| 56 |
-
if isinstance(__builtins__, dict)
|
| 57 |
-
else {
|
| 58 |
-
k: getattr(__builtins__, k)
|
| 59 |
-
for k in dir(__builtins__)
|
| 60 |
-
if k not in ("open", "exec", "eval", "__import__", "compile", "exit", "quit") and not k.startswith("_")
|
| 61 |
-
}
|
| 62 |
-
)
|
| 63 |
-
return {
|
| 64 |
-
"__builtins__": safe_builtins,
|
| 65 |
-
"df": self._df.copy(),
|
| 66 |
-
"pd": pd,
|
| 67 |
-
"np": np,
|
| 68 |
-
"sqlite3": sqlite3,
|
| 69 |
-
"db_path": str(DB_PATH),
|
| 70 |
-
}
|
| 71 |
-
|
| 72 |
-
def _dataset_info(self) -> str:
|
| 73 |
-
"""Generate a summary of the dataset schema for the agent.
|
| 74 |
-
|
| 75 |
-
Includes the sales DataFrame schema plus the SQLite database table schemas
|
| 76 |
-
so the agent knows what data is available and where to find it.
|
| 77 |
-
|
| 78 |
-
Returns:
|
| 79 |
-
A string describing column names, dtypes, row count, a sample for df,
|
| 80 |
-
and table schemas for the SQLite database.
|
| 81 |
-
"""
|
| 82 |
-
buf = io.StringIO()
|
| 83 |
-
self._df.info(buf=buf)
|
| 84 |
-
info_str = buf.getvalue()
|
| 85 |
-
sample = self._df.head(3).to_string()
|
| 86 |
-
df_section = f"=== df (pandas DataFrame, pre-loaded from sales CSV) ===\nShape: {self._df.shape}\n{info_str}\nSample rows:\n{sample}"
|
| 87 |
-
|
| 88 |
-
try:
|
| 89 |
-
conn = sqlite3.connect(DB_PATH)
|
| 90 |
-
cursor = conn.cursor()
|
| 91 |
-
cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
|
| 92 |
-
tables = [row[0] for row in cursor.fetchall()]
|
| 93 |
-
db_lines = ["\n=== SQLite database (accessible via sqlite3.connect(db_path)) ==="]
|
| 94 |
-
for table in tables:
|
| 95 |
-
cursor.execute(f"PRAGMA table_info({table})")
|
| 96 |
-
cols = [(row[1], row[2]) for row in cursor.fetchall()]
|
| 97 |
-
cursor.execute(f"SELECT COUNT(*) FROM {table}")
|
| 98 |
-
count = cursor.fetchone()[0]
|
| 99 |
-
col_str = ", ".join(f"{c} ({t})" for c, t in cols)
|
| 100 |
-
db_lines.append(f" Table '{table}' ({count} rows): {col_str}")
|
| 101 |
-
conn.close()
|
| 102 |
-
db_section = "\n".join(db_lines)
|
| 103 |
-
except Exception:
|
| 104 |
-
db_section = "\n=== SQLite database: schema unavailable ==="
|
| 105 |
-
|
| 106 |
-
return f"{df_section}\n{db_section}"
|
| 107 |
-
|
| 108 |
-
def reset(
|
| 109 |
-
self,
|
| 110 |
-
seed: Optional[int] = None,
|
| 111 |
-
episode_id: Optional[str] = None,
|
| 112 |
-
**kwargs: Any,
|
| 113 |
-
) -> DataObservation:
|
| 114 |
-
"""Reset the environment for a new episode.
|
| 115 |
-
|
| 116 |
-
Args:
|
| 117 |
-
seed: Optional random seed (unused, kept for interface compliance).
|
| 118 |
-
episode_id: Optional episode identifier; generated if not provided.
|
| 119 |
-
**kwargs: Additional keyword arguments. Supports 'task_id' (int, 1-6).
|
| 120 |
-
|
| 121 |
-
Returns:
|
| 122 |
-
An initial observation with the task description and dataset info.
|
| 123 |
-
"""
|
| 124 |
-
task_id = kwargs.get("task_id", 1)
|
| 125 |
-
eid = episode_id or str(uuid.uuid4())
|
| 126 |
-
|
| 127 |
-
self._df = self._source_df.copy()
|
| 128 |
-
self._state = DataState(episode_id=eid, step_count=0, task_id=task_id)
|
| 129 |
-
self._exec_namespace = self._build_namespace()
|
| 130 |
-
|
| 131 |
-
task_cls = TASKS.get(task_id)
|
| 132 |
-
if task_cls is None:
|
| 133 |
-
return DataObservation(
|
| 134 |
-
done=True,
|
| 135 |
-
reward=0.0,
|
| 136 |
-
success=False,
|
| 137 |
-
error=f"Invalid task_id: {task_id}. Must be 1–6.",
|
| 138 |
-
)
|
| 139 |
-
self._task = task_cls(self._df)
|
| 140 |
-
|
| 141 |
-
return DataObservation(
|
| 142 |
-
done=False,
|
| 143 |
-
reward=0.0,
|
| 144 |
-
output="Environment ready. Use 'execute_code' actions to explore the dataset, then 'submit_answer' with your result.",
|
| 145 |
-
task_description=self._task.description,
|
| 146 |
-
dataset_info=self._dataset_info(),
|
| 147 |
-
metadata={"task_id": task_id, "difficulty": self._task.difficulty},
|
| 148 |
-
)
|
| 149 |
-
|
| 150 |
-
def step(
|
| 151 |
-
self,
|
| 152 |
-
action: DataAction,
|
| 153 |
-
timeout_s: Optional[float] = None,
|
| 154 |
-
**kwargs: Any,
|
| 155 |
-
) -> DataObservation:
|
| 156 |
-
"""Execute one step in the environment.
|
| 157 |
-
|
| 158 |
-
Handles two action types:
|
| 159 |
-
- execute_code: runs pandas code in a sandboxed namespace
|
| 160 |
-
- submit_answer: grades the agent's final answer and ends the episode
|
| 161 |
-
|
| 162 |
-
Args:
|
| 163 |
-
action: The agent's action (execute_code or submit_answer).
|
| 164 |
-
timeout_s: Optional timeout in seconds (unused).
|
| 165 |
-
**kwargs: Additional keyword arguments.
|
| 166 |
-
|
| 167 |
-
Returns:
|
| 168 |
-
An observation with execution output, reward, and done flag.
|
| 169 |
-
"""
|
| 170 |
-
self._state.step_count += 1
|
| 171 |
-
|
| 172 |
-
if self._state.answer_submitted:
|
| 173 |
-
return DataObservation(
|
| 174 |
-
done=True,
|
| 175 |
-
reward=0.0,
|
| 176 |
-
output="Episode is already finished. Call reset() to start a new one.",
|
| 177 |
-
success=False,
|
| 178 |
-
)
|
| 179 |
-
|
| 180 |
-
# Check max steps
|
| 181 |
-
if self._state.step_count >= self.MAX_STEPS and action.action_type != "submit_answer":
|
| 182 |
-
self._state.answer_submitted = True
|
| 183 |
-
return DataObservation(
|
| 184 |
-
done=True,
|
| 185 |
-
reward=0.0,
|
| 186 |
-
output=f"Maximum steps ({self.MAX_STEPS}) exceeded without submitting an answer.",
|
| 187 |
-
success=False,
|
| 188 |
-
metadata={"reason": "max_steps_exceeded"},
|
| 189 |
-
)
|
| 190 |
-
|
| 191 |
-
if action.action_type == "execute_code":
|
| 192 |
-
return self._handle_execute_code(action)
|
| 193 |
-
elif action.action_type == "submit_answer":
|
| 194 |
-
return self._handle_submit_answer(action)
|
| 195 |
-
else:
|
| 196 |
-
return DataObservation(
|
| 197 |
-
done=False,
|
| 198 |
-
reward=-0.05,
|
| 199 |
-
success=False,
|
| 200 |
-
error=f"Unknown action_type: {action.action_type}",
|
| 201 |
-
)
|
| 202 |
-
|
| 203 |
-
def _handle_execute_code(self, action: DataAction) -> DataObservation:
    """Execute pandas code in the sandboxed namespace.

    The code is split with the AST: all statements are exec'd once and, if
    the final statement is a bare expression, its value is captured so it
    can be echoed when nothing was printed. (The previous implementation
    re-ran the last line under eval() after exec() had already executed
    it, duplicating any side effects of that line.)

    Args:
        action: The action containing the code to execute.

    Returns:
        An observation with stdout output / expression value, or an error.
    """
    import ast  # local import: the file-level import block is outside this view

    if not action.code:
        return DataObservation(
            done=False,
            reward=-0.05,
            success=False,
            error="No code provided for execute_code action.",
        )

    stdout_capture = io.StringIO()
    old_stdout = sys.stdout
    try:
        tree = ast.parse(action.code)
        # Peel off a trailing bare expression so it runs exactly once
        # and its value can still be echoed back.
        last_expr = None
        if tree.body and isinstance(tree.body[-1], ast.Expr):
            last_expr = ast.Expression(tree.body.pop(-1).value)

        sys.stdout = stdout_capture
        exec(compile(tree, "<agent_code>", "exec"), self._exec_namespace)
        result = None
        if last_expr is not None:
            result = eval(compile(last_expr, "<agent_code>", "eval"), self._exec_namespace)

        output = stdout_capture.getvalue()
        if not output.strip():
            if result is not None:
                output = str(result)
            else:
                output = "(Code executed successfully with no output)"

        return DataObservation(
            done=False,
            reward=0.05,
            output=output[:5000],
            success=True,
            metadata={"steps_remaining": self.MAX_STEPS - self._state.step_count},
        )
    except Exception as e:
        return DataObservation(
            done=False,
            reward=-0.05,
            success=False,
            error=f"{type(e).__name__}: {e}",
            output="",
            metadata={"steps_remaining": self.MAX_STEPS - self._state.step_count},
        )
    finally:
        # Always restore stdout, even if parse/exec/eval raised.
        sys.stdout = old_stdout
|
| 254 |
-
|
| 255 |
-
def _handle_submit_answer(self, action: DataAction) -> DataObservation:
    """Grade the agent's submitted answer and terminate the episode.

    Args:
        action: The action carrying the answer to grade.

    Returns:
        A terminal observation (done=True) with the clamped score as reward.
    """
    if not action.answer:
        return DataObservation(
            done=False,
            reward=-0.05,
            success=False,
            error="No answer provided for submit_answer action.",
        )

    self._state.answer_submitted = True
    # Clamp into [0.05, 0.95] so no episode yields an extreme reward.
    clamped = min(0.95, max(0.05, self._task.grade(action.answer)))
    self._state.final_score = clamped

    return DataObservation(
        done=True,
        reward=clamped,
        output=f"Answer submitted. Score: {clamped:.2f}/1.00",
        success=True,
        metadata={
            "score": clamped,
            "expected_answer": self._task.expected_answer(),
            "submitted_answer": action.answer,
        },
    )
|
| 288 |
-
|
| 289 |
-
@property
def state(self) -> DataState:
    """Return the current episode state.

    Returns:
        The current DataState with episode_id, step_count, task_id, etc.
    """
    # Returns the live state object, not a copy — callers share mutations.
    return self._state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tasks/__init__.py
DELETED
|
@@ -1,29 +0,0 @@
|
|
| 1 |
-
"""Task definitions for the Data Analysis Agent environment."""
|
| 2 |
-
|
| 3 |
-
from tasks.base_task import BaseTask
|
| 4 |
-
from tasks.task_easy import TopRevenueCategoryTask
|
| 5 |
-
from tasks.task_hard import RepeatCustomerCohortTask
|
| 6 |
-
from tasks.task_hard_2 import CustomerLoyaltyRevenueTask
|
| 7 |
-
from tasks.task_hard_3 import SupplierProfitabilityTask
|
| 8 |
-
from tasks.task_medium import CityRevenueShareTask
|
| 9 |
-
from tasks.task_medium_2 import MonthlyRevenueRatioTask
|
| 10 |
-
|
| 11 |
-
# Registry mapping task_id -> task class. Keys should match each class's
# task_id property so lookups and self-reported ids stay consistent.
TASKS = {
    1: TopRevenueCategoryTask,
    2: CityRevenueShareTask,
    3: RepeatCustomerCohortTask,
    4: MonthlyRevenueRatioTask,
    5: CustomerLoyaltyRevenueTask,
    6: SupplierProfitabilityTask,
}

# Public API of the tasks package.
__all__ = [
    "BaseTask",
    "TASKS",
    "TopRevenueCategoryTask",
    "CityRevenueShareTask",
    "RepeatCustomerCohortTask",
    "MonthlyRevenueRatioTask",
    "CustomerLoyaltyRevenueTask",
    "SupplierProfitabilityTask",
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tasks/base_task.py
DELETED
|
@@ -1,51 +0,0 @@
|
|
| 1 |
-
from abc import ABC, abstractmethod
|
| 2 |
-
|
| 3 |
-
import pandas as pd
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
class BaseTask(ABC):
    """Abstract interface shared by every data analysis task.

    A concrete task supplies the question text, derives the ground-truth
    answer from its dataset, and scores submitted answers.

    Attributes:
        df: The pandas DataFrame holding the dataset the task operates on.
    """

    def __init__(self, df: pd.DataFrame):
        self.df = df

    @property
    @abstractmethod
    def task_id(self) -> int:
        """Unique integer identifier for this task."""

    @property
    @abstractmethod
    def difficulty(self) -> str:
        """Difficulty label: 'easy', 'medium', or 'hard'."""

    @property
    @abstractmethod
    def description(self) -> str:
        """Question text shown to the agent."""

    @abstractmethod
    def expected_answer(self) -> str:
        """Derive and return the ground-truth answer.

        Returns:
            The expected answer as a formatted string.
        """

    @abstractmethod
    def grade(self, answer: str) -> float:
        """Score a submitted answer.

        Args:
            answer: The agent's submitted answer string.

        Returns:
            A score between 0.0 and 1.0.
        """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tasks/task_easy.py
DELETED
|
@@ -1,53 +0,0 @@
|
|
| 1 |
-
from tasks.base_task import BaseTask
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
class TopRevenueCategoryTask(BaseTask):
    """Easy task: identify the product category with the most revenue.

    The dataset is grouped by ``category``, ``total_price`` is summed,
    and the category with the largest sum is the expected answer.
    """

    @property
    def task_id(self) -> int:
        """Return the task identifier."""
        return 1

    @property
    def difficulty(self) -> str:
        """Return the difficulty level."""
        return "easy"

    @property
    def description(self) -> str:
        """Return the task question."""
        return (
            "What is the top-selling product category by total revenue? "
            "Submit just the category name as your answer."
        )

    def expected_answer(self) -> str:
        """Return the category whose summed total_price is largest."""
        revenue_by_category = self.df.groupby("category")["total_price"].sum()
        return revenue_by_category.idxmax()

    def grade(self, answer: str) -> float:
        """Grade via case-insensitive substring containment.

        Full credit whenever the expected category name occurs anywhere in
        the submission, so responses like 'The top category is Clothing' or
        'Clothing ($74,792.74)' still pass. The raw 0/1 result is then
        clamped into [0.05, 0.95].

        Args:
            answer: The agent's submitted category name.

        Returns:
            0.95 when the expected category appears in the answer, 0.05 otherwise.
        """
        target = self.expected_answer().strip().lower()
        hit = target in answer.strip().lower()
        return min(0.95, max(0.05, 1.0 if hit else 0.0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tasks/task_hard.py
DELETED
|
@@ -1,103 +0,0 @@
|
|
| 1 |
-
import re
|
| 2 |
-
|
| 3 |
-
import pandas as pd
|
| 4 |
-
|
| 5 |
-
from tasks.base_task import BaseTask
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
class RepeatCustomerCohortTask(BaseTask):
    """Hard task: analyze customers who ordered in both January and December.

    The agent must identify customers present in both months, count them,
    and compare their average order value against all other customers.
    """

    @property
    def task_id(self) -> int:
        return 3

    @property
    def difficulty(self) -> str:
        return "hard"

    @property
    def description(self) -> str:
        return (
            "How many unique customers placed orders in BOTH January and December? "
            "What is their average order value compared to all other customers? "
            "Submit your answer in the format: "
            "'Cohort: N customers, Cohort AOV: $X.XX, Other AOV: $X.XX'"
        )

    def _compute_cohort(self) -> tuple[set, float, float]:
        """Compute the January-and-December cohort and the two AOVs.

        Returns:
            A tuple of (cohort_customer_ids, cohort_aov, other_aov),
            with both AOVs rounded to cents.
        """
        data = self.df.copy()
        data["order_date"] = pd.to_datetime(data["order_date"])
        months = data["order_date"].dt.month
        cohort = set(data.loc[months == 1, "customer_id"]) & set(
            data.loc[months == 12, "customer_id"]
        )

        in_cohort = data["customer_id"].isin(cohort)
        cohort_aov = data.loc[in_cohort, "total_price"].mean()
        other_aov = data.loc[~in_cohort, "total_price"].mean()
        return cohort, round(cohort_aov, 2), round(other_aov, 2)

    def expected_answer(self) -> str:
        """Compute the expected cohort analysis answer.

        Returns:
            Formatted string like 'Cohort: 57 customers, Cohort AOV: $126.57, Other AOV: $122.94'.
        """
        cohort, cohort_aov, other_aov = self._compute_cohort()
        return f"Cohort: {len(cohort)} customers, Cohort AOV: ${cohort_aov}, Other AOV: ${other_aov}"

    def grade(self, answer: str) -> float:
        """Grade with partial credit for each of the three fields.

        Scoring:
            - 0.33 for the correct customer count (exact match)
            - 0.33 for cohort AOV within ±0.5% of expected
            - 0.34 for other AOV within ±0.5% of expected
        The total is clamped into [0.05, 0.95].

        Args:
            answer: The agent's submitted answer string.

        Returns:
            A score between 0.05 and 0.95.
        """
        cohort, expected_cohort_aov, expected_other_aov = self._compute_cohort()
        score = 0.0

        count_match = re.search(r"Cohort:\s*(\d+)\s*customers?", answer, re.IGNORECASE)
        if count_match and int(count_match.group(1)) == len(cohort):
            score += 0.33

        # Both AOV fields share the same tolerance rule; check them uniformly.
        aov_checks = (
            (r"Cohort\s+AOV:\s*\$?([\d.]+)", expected_cohort_aov, 0.33),
            (r"Other\s+AOV:\s*\$?([\d.]+)", expected_other_aov, 0.34),
        )
        for pattern, expected, points in aov_checks:
            match = re.search(pattern, answer, re.IGNORECASE)
            if not match:
                continue
            try:
                submitted = float(match.group(1))
            except ValueError:
                continue
            if abs(submitted - expected) <= expected * 0.005:
                score += points

        return min(0.95, max(0.05, score))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tasks/task_hard_2.py
DELETED
|
@@ -1,103 +0,0 @@
|
|
| 1 |
-
import re
|
| 2 |
-
import sqlite3
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
|
| 5 |
-
import pandas as pd
|
| 6 |
-
|
| 7 |
-
from tasks.base_task import BaseTask
|
| 8 |
-
|
| 9 |
-
DB_PATH = Path(__file__).resolve().parent.parent / "datasets" / "store_data.db"


class CustomerLoyaltyRevenueTask(BaseTask):
    """Hard task: find the highest-revenue customer loyalty tier using cross-source data.

    The agent must query the customer_profiles table from the SQLite database,
    join it with the sales DataFrame on customer_id, and compute revenue by tier.
    The database is accessible via sqlite3.connect(db_path) in the sandbox.
    """

    @property
    def task_id(self) -> int:
        return 5

    @property
    def difficulty(self) -> str:
        return "hard"

    @property
    def description(self) -> str:
        return (
            "Using the customer profiles database (connect with sqlite3.connect(db_path)), "
            "which customer loyalty tier generates the highest total revenue? "
            "What percentage of total revenue does it represent? "
            "Round percentage to 2 decimal places. "
            "Submit your answer in the format: "
            "'Top tier: <name>, Revenue: $X.XX, Percentage: X.XX%'"
        )

    def _load_profiles(self) -> pd.DataFrame:
        """Read (customer_id, loyalty_tier) rows, always closing the connection.

        The previous implementation leaked the sqlite3 connection whenever
        pd.read_sql raised; the finally block guarantees cleanup.
        """
        conn = sqlite3.connect(DB_PATH)
        try:
            return pd.read_sql(
                "SELECT customer_id, loyalty_tier FROM customer_profiles", conn
            )
        finally:
            conn.close()

    def _compute(self) -> tuple:
        """Compute the top loyalty tier and its revenue share.

        Returns:
            A tuple of (top_tier, tier_revenue, percentage).
        """
        profiles = self._load_profiles()
        merged = self.df.merge(profiles, on="customer_id", how="left")
        tier_rev = merged.groupby("loyalty_tier")["total_price"].sum()
        total = merged["total_price"].sum()
        top = tier_rev.idxmax()
        rev = tier_rev[top]
        pct = rev / total * 100
        return top, rev, pct

    def expected_answer(self) -> str:
        """Compute the expected formatted answer.

        Returns:
            Formatted string like 'Top tier: Bronze, Revenue: $97210.91, Percentage: 39.28%'.
        """
        top, rev, pct = self._compute()
        return f"Top tier: {top}, Revenue: ${rev:.2f}, Percentage: {round(pct, 2)}%"

    def grade(self, answer: str) -> float:
        """Grade with partial credit for each of the three fields.

        Scoring:
            - 0.33 for correct tier name (case-insensitive)
            - 0.33 for revenue within ±0.5% of expected
            - 0.34 for percentage within ±0.1 of expected

        Args:
            answer: The agent's submitted answer string.

        Returns:
            A score between 0.05 and 0.95 (clamped).
        """
        top, expected_rev, expected_pct = self._compute()
        score = 0.0

        tier_match = re.search(r"Top tier:\s*([^,]+)", answer, re.IGNORECASE)
        if tier_match and tier_match.group(1).strip().lower() == top.lower():
            score += 0.33

        rev_match = re.search(r"Revenue:\s*\$?([\d.]+)", answer, re.IGNORECASE)
        if rev_match:
            try:
                submitted = float(rev_match.group(1))
                if abs(submitted - expected_rev) <= expected_rev * 0.005:
                    score += 0.33
            except ValueError:
                pass

        pct_match = re.search(r"Percentage:\s*([\d.]+)%?", answer, re.IGNORECASE)
        if pct_match:
            try:
                if abs(float(pct_match.group(1)) - expected_pct) <= 0.1:
                    score += 0.34
            except ValueError:
                pass

        return max(0.05, min(0.95, score))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tasks/task_hard_3.py
DELETED
|
@@ -1,107 +0,0 @@
|
|
| 1 |
-
import re
|
| 2 |
-
import sqlite3
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
|
| 5 |
-
import pandas as pd
|
| 6 |
-
|
| 7 |
-
from tasks.base_task import BaseTask
|
| 8 |
-
|
| 9 |
-
DB_PATH = Path(__file__).resolve().parent.parent / "datasets" / "store_data.db"


class SupplierProfitabilityTask(BaseTask):
    """Hard task: find the most profitable supplier using cross-source data.

    The agent must query the product_catalog table from the SQLite database,
    join it with the sales DataFrame on product_name, compute per-order profit
    and margin, then aggregate by supplier.
    The database is accessible via sqlite3.connect(db_path) in the sandbox.
    """

    @property
    def task_id(self) -> int:
        return 6

    @property
    def difficulty(self) -> str:
        return "hard"

    @property
    def description(self) -> str:
        return (
            "Using the product catalog database (connect with sqlite3.connect(db_path)), "
            "which supplier has the highest total profit from orders? "
            "(profit per order = (unit_price - cost_price) * quantity) "
            "What is their total profit and average profit margin? "
            "(margin % = (unit_price - cost_price) / unit_price * 100, "
            "averaged across all their orders) "
            "Round total profit to 2 decimal places and avg margin to 2 decimal places. "
            "Submit your answer in the format: "
            "'Supplier: <name>, Total profit: $X.XX, Avg margin: X.XX%'"
        )

    def _load_catalog(self) -> pd.DataFrame:
        """Read (product_name, supplier, cost_price) rows, always closing the connection.

        The previous implementation leaked the sqlite3 connection whenever
        pd.read_sql raised; the finally block guarantees cleanup.
        """
        conn = sqlite3.connect(DB_PATH)
        try:
            return pd.read_sql(
                "SELECT product_name, supplier, cost_price FROM product_catalog", conn
            )
        finally:
            conn.close()

    def _compute(self) -> tuple:
        """Compute the top supplier by profit and their average margin.

        Returns:
            A tuple of (supplier_name, total_profit, avg_margin_pct).
        """
        catalog = self._load_catalog()
        merged = self.df.merge(catalog, on="product_name", how="left")
        merged["profit"] = (merged["unit_price"] - merged["cost_price"]) * merged["quantity"]
        merged["margin"] = (merged["unit_price"] - merged["cost_price"]) / merged["unit_price"] * 100
        sup_profit = merged.groupby("supplier")["profit"].sum()
        sup_margin = merged.groupby("supplier")["margin"].mean()
        top = sup_profit.idxmax()
        return top, sup_profit[top], sup_margin[top]

    def expected_answer(self) -> str:
        """Compute the expected formatted answer.

        Returns:
            Formatted string like 'Supplier: FashionWorld, Total profit: $38292.08, Avg margin: 52.08%'.
        """
        top, profit, margin = self._compute()
        return f"Supplier: {top}, Total profit: ${profit:.2f}, Avg margin: {round(margin, 2)}%"

    def grade(self, answer: str) -> float:
        """Grade with partial credit for each of the three fields.

        Scoring:
            - 0.33 for correct supplier name (case-insensitive)
            - 0.34 for total profit within ±0.5% of expected
            - 0.33 for avg margin within ±0.1 of expected

        Args:
            answer: The agent's submitted answer string.

        Returns:
            A score between 0.05 and 0.95 (clamped).
        """
        top, expected_profit, expected_margin = self._compute()
        score = 0.0

        sup_match = re.search(r"Supplier:\s*([^,]+)", answer, re.IGNORECASE)
        if sup_match and sup_match.group(1).strip().lower() == top.lower():
            score += 0.33

        profit_match = re.search(r"Total profit:\s*\$?([\d.]+)", answer, re.IGNORECASE)
        if profit_match:
            try:
                submitted = float(profit_match.group(1))
                if abs(submitted - expected_profit) <= expected_profit * 0.005:
                    score += 0.34
            except ValueError:
                pass

        margin_match = re.search(r"Avg margin:\s*([\d.]+)%?", answer, re.IGNORECASE)
        if margin_match:
            try:
                if abs(float(margin_match.group(1)) - expected_margin) <= 0.1:
                    score += 0.33
            except ValueError:
                pass

        return max(0.05, min(0.95, score))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tasks/task_medium.py
DELETED
|
@@ -1,77 +0,0 @@
|
|
| 1 |
-
import re
|
| 2 |
-
|
| 3 |
-
import pandas as pd
|
| 4 |
-
|
| 5 |
-
from tasks.base_task import BaseTask
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
class CityRevenueShareTask(BaseTask):
    """Medium task: identify the city with the highest revenue and its percentage share.

    Revenue is summed per city; the top contributor and its share of overall
    revenue (rounded to two decimals) form the expected answer.
    """

    @property
    def task_id(self) -> int:
        return 2

    @property
    def difficulty(self) -> str:
        return "medium"

    @property
    def description(self) -> str:
        return (
            "Which city generates the most revenue? What percentage of total revenue "
            "does it represent? Round to 2 decimal places. "
            "Submit your answer in the format: 'City: <name>, Percentage: <X.XX>%'"
        )

    def _top_city_and_share(self) -> tuple:
        """Return (top_city, rounded_percentage_of_total_revenue)."""
        revenue = self.df.groupby("city")["total_price"].sum()
        winner = revenue.idxmax()
        share = round(revenue[winner] / revenue.sum() * 100, 2)
        return winner, share

    def expected_answer(self) -> str:
        """Compute the top city and its revenue share.

        Returns:
            Formatted string like 'City: London, Percentage: 10.81%'.
        """
        winner, share = self._top_city_and_share()
        return f"City: {winner}, Percentage: {share}%"

    def grade(self, answer: str) -> float:
        """Grade with partial credit for the city and the percentage.

        Scoring:
            - 0.5 for the correct city name (case-insensitive)
            - 0.5 for a percentage within ±0.1 of expected
        The total is clamped into [0.05, 0.95].

        Args:
            answer: The agent's submitted answer string.

        Returns:
            A score between 0.05 and 0.95.
        """
        expected_city, expected_pct = self._top_city_and_share()
        score = 0.0

        city_match = re.search(r"City:\s*([^,]+)", answer, re.IGNORECASE)
        if city_match and city_match.group(1).strip().lower() == expected_city.lower():
            score += 0.5

        pct_match = re.search(r"Percentage:\s*([\d.]+)%?", answer, re.IGNORECASE)
        if pct_match:
            try:
                if abs(float(pct_match.group(1)) - expected_pct) <= 0.1:
                    score += 0.5
            except ValueError:
                pass

        return min(0.95, max(0.05, score))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tasks/task_medium_2.py
DELETED
|
@@ -1,88 +0,0 @@
|
|
| 1 |
-
import re
|
| 2 |
-
|
| 3 |
-
import pandas as pd
|
| 4 |
-
|
| 5 |
-
from tasks.base_task import BaseTask
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
class MonthlyRevenueRatioTask(BaseTask):
    """Medium task: find the best and worst months by revenue and compute their ratio.

    The agent must parse order_date, group by month, find the extremes,
    and compute how many times larger the best month is versus the worst.
    """

    @property
    def task_id(self) -> int:
        return 4

    @property
    def difficulty(self) -> str:
        return "medium"

    @property
    def description(self) -> str:
        return (
            "What is the best and worst performing month by total revenue in 2024? "
            "What is the ratio of best to worst month revenue? Round ratio to 2 decimal places. "
            "Submit your answer in the format: "
            "'Best: YYYY-MM ($X.XX), Worst: YYYY-MM ($X.XX), Ratio: X.XX'"
        )

    def _compute(self) -> tuple:
        """Compute the best month, worst month, and their revenue ratio.

        NOTE(review): the question says 'in 2024' but no year filter is
        applied here — this assumes the dataset only contains 2024 orders;
        confirm against the dataset.

        Returns:
            A tuple of (best_month_str, best_rev, worst_month_str, worst_rev, ratio).
        """
        data = self.df.copy()
        data["order_date"] = pd.to_datetime(data["order_date"])
        monthly = data.groupby(data["order_date"].dt.to_period("M"))["total_price"].sum()
        best_period = monthly.idxmax()
        worst_period = monthly.idxmin()
        ratio = round(monthly[best_period] / monthly[worst_period], 2)
        return (
            str(best_period),
            monthly[best_period],
            str(worst_period),
            monthly[worst_period],
            ratio,
        )

    def expected_answer(self) -> str:
        """Compute the expected formatted answer.

        Returns:
            Formatted string like 'Best: 2024-12 ($23025.82), Worst: 2024-05 ($16871.48), Ratio: 1.36'.
        """
        best, best_rev, worst, worst_rev, ratio = self._compute()
        return f"Best: {best} (${best_rev:.2f}), Worst: {worst} (${worst_rev:.2f}), Ratio: {ratio}"

    def grade(self, answer: str) -> float:
        """Grade with partial credit for each of the three fields.

        Scoring:
            - 0.33 for the correct best month (exact YYYY-MM match)
            - 0.33 for the correct worst month (exact YYYY-MM match)
            - 0.34 for a ratio within ±0.01 of expected
        The total is clamped into [0.05, 0.95].

        Args:
            answer: The agent's submitted answer string.

        Returns:
            A score between 0.05 and 0.95.
        """
        best, _, worst, _, expected_ratio = self._compute()
        score = 0.0

        # Best/worst month checks share the same shape; run them uniformly.
        for label, expected_month in (("Best", best), ("Worst", worst)):
            month_match = re.search(
                rf"{label}:\s*([\d]{{4}}-[\d]{{2}})", answer, re.IGNORECASE
            )
            if month_match and month_match.group(1).strip() == expected_month:
                score += 0.33

        ratio_match = re.search(r"Ratio:\s*([\d.]+)", answer, re.IGNORECASE)
        if ratio_match:
            try:
                if abs(float(ratio_match.group(1)) - expected_ratio) <= 0.01:
                    score += 0.34
            except ValueError:
                pass

        return min(0.95, max(0.05, score))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
uv.lock
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|