HimanshuSardana2 committed on
Commit abb357f · verified · 1 Parent(s): 8c8a964

Upload folder using huggingface_hub

Dockerfile ADDED
@@ -0,0 +1,18 @@
+ FROM python:3.10-slim
+
+ WORKDIR /app
+
+ RUN pip install --no-cache-dir uv
+
+ COPY server/requirements.txt /tmp/requirements.txt
+ RUN pip install --no-cache-dir -r /tmp/requirements.txt
+
+ COPY . /app/
+
+ ENV PYTHONPATH=/app
+ ENV DATA_DIR=/app/server/data
+
+ EXPOSE 8000
+
+ ENV ENABLE_WEB_INTERFACE=true
+ CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000"]
README.md CHANGED
@@ -1,10 +1,156 @@
  ---
  title: Data Analysis Env
- emoji: 🐨
- colorFrom: gray
+ emoji: 📊
+ colorFrom: blue
  colorTo: green
  sdk: docker
- pinned: false
+ app_file: inference.py
+ pytorch: false
+ python_version: "3.10"
+ tags:
+ - data-analysis
+ - pandas
+ - openenv
+ - ai-agents
+ license: mit
+ base_path: /web
  ---
  
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Data Analysis OpenEnv Environment
+
+ A real-world OpenEnv environment for training and evaluating AI agents on pandas data analysis tasks.
+
+ ## Environment Description
+
+ This environment simulates real-world data analysis workflows that humans perform daily:
+ - Loading and exploring CSV data
+ - Cleaning dirty data (handling missing values, removing duplicates)
+ - Transforming data (filtering, sorting, selecting columns)
+ - Merging multiple datasets
+ - Computing statistics and aggregations
+
+ ## Task Descriptions
+
+ ### Task 1: Basic Statistics (Easy)
+ - **Objective**: Load `simple.csv` and calculate the mean of the `price` column
+ - **Difficulty**: Easy
+ - **Expected Score**: 0.7+ for correct mean calculation
+
+ ### Task 2: Data Cleaning (Medium)
+ - **Objective**: Load `dirty.csv`, fill missing values (mean), remove duplicates, calculate median of `age`
+ - **Difficulty**: Medium
+ - **Expected Score**: 0.7+ for correct cleaning and median calculation
+
+ ### Task 3: Multi-table Analysis (Hard)
+ - **Objective**: Load `sales.csv` and `products.csv`, merge on `product_id`, calculate total sales per category
+ - **Difficulty**: Hard
+ - **Expected Score**: 0.7+ for correct merge and aggregation
+
+ ## Action Space
+
+ ```python
+ DataAnalysisAction(
+     tool: str,         # Tool name: load_csv, show_data, show_columns, fill_missing,
+                        # remove_duplicates, filter_rows, select_columns, group_by,
+                        # calculate, sort_by, get_result, merge_datasets
+     parameters: dict   # Tool parameters
+ )
+ ```
+
+ ## Observation Space
+
+ ```python
+ DataAnalysisObservation(
+     done: bool,                    # Episode done flag
+     reward: float,                 # Reward (0.0-1.0)
+     success: bool,                 # Tool executed successfully
+     output: str,                   # Tool output
+     data_shape: tuple[int, int],   # (rows, columns)
+     columns: list[str],            # Column names
+     tools_used: list[str],         # History of tools called
+     error: Optional[str]           # Error message if any
+ )
+ ```
+
+ ## Reward Function
+
+ - **+0.1**: Each successful tool execution
+ - **+0.5 × score**: Final result grading (score based on accuracy)
+ - **-0.1**: Failed tool execution or invalid tool
+ - **0.0**: Episode ends without meaningful progress
+
+ ## Setup Instructions
+
+ ### Local Development
+
+ ```bash
+ # Install dependencies
+ cd data_analysis_env
+ pip install -r server/requirements.txt
+
+ # Run the server
+ python -m server.app
+
+ # Or use uvicorn
+ uvicorn server.app:app --host 0.0.0.0 --port 8000
+ ```
+
+ ### Docker
+
+ ```bash
+ # Build the image
+ docker build -t data_analysis_env .
+
+ # Run the container
+ docker run -p 8000:8000 data_analysis_env
+ ```
+
+ ### Running Inference
+
+ ```bash
+ # Set environment variables
+ export HF_TOKEN=your_token
+ export API_BASE_URL=https://router.huggingface.co/v1
+ export MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
+ export ENV_URL=http://localhost:8000
+
+ # Run inference
+ python inference.py
+ ```
+
+ ## Baseline Scores
+
+ | Task | Expected Score |
+ |------|----------------|
+ | task_1 (Easy) | 0.7-1.0 |
+ | task_2 (Medium) | 0.5-0.8 |
+ | task_3 (Hard) | 0.3-0.7 |
+
+ ## API Endpoints
+
+ - `POST /reset` - Reset environment with task name
+ - `POST /step` - Execute action
+ - `GET /state` - Get current state
+
+ ## Files
+
+ ```
+ data_analysis_env/
+ ├── __init__.py                 # Package init
+ ├── models.py                   # Pydantic models
+ ├── client.py                   # Client implementation
+ ├── inference.py                # Inference script
+ ├── openenv.yaml                # OpenEnv spec
+ ├── Dockerfile                  # Docker configuration
+ ├── server/
+ │   ├── app.py                  # FastAPI app
+ │   ├── quantum_openenv_env_environment.py  # Environment implementation
+ │   ├── Dockerfile              # Server Dockerfile
+ │   ├── requirements.txt
+ │   └── data/
+ │       ├── simple.csv
+ │       ├── dirty.csv
+ │       ├── sales.csv
+ │       └── products.csv
+ └── README.md
+ ```
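Review note on the README's reward function: the per-step part of the rule can be sketched as a small pure function. This is a minimal illustration, assuming the reward is cumulative and clamped to the documented 0.0-1.0 range; `update_reward` is a hypothetical helper name, and the `+0.5 × score` final grading term is left out because its exact form is not shown here.

```python
def update_reward(current: float, success: bool) -> float:
    """Return the cumulative reward after one tool call, clamped to [0.0, 1.0]."""
    # +0.1 for a successful tool execution, -0.1 for a failed or invalid one.
    delta = 0.1 if success else -0.1
    return min(1.0, max(0.0, current + delta))


# Replay a short, made-up episode: two successes, one failure, one success.
reward = 0.0
for ok in [True, True, False, True]:
    reward = update_reward(reward, ok)
print(round(reward, 2))
```

Clamping at both ends means an early failure cannot push the reward below zero and a long run of successful calls cannot exceed 1.0.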
__init__.py ADDED
@@ -0,0 +1,49 @@
+ from .models import (
+     DataAnalysisAction,
+     DataAnalysisObservation,
+     DataAnalysisState,
+     AVAILABLE_TOOLS,
+ )
+ from .client import DataAnalysisEnv
+
+
+ __all__ = [
+     "DataAnalysisAction",
+     "DataAnalysisObservation",
+     "DataAnalysisState",
+     "DataAnalysisEnv",
+     "AVAILABLE_TOOLS",
+ ]
+
+
+ TASKS = {
+     "task_1": {
+         "name": "Basic Statistics",
+         "description": "Load simple.csv and calculate the mean of the 'price' column",
+         "datafile": "simple.csv",
+         "target_column": "price",
+         "target_operation": "mean",
+         "expected_answer": None,
+         "difficulty": "easy",
+     },
+     "task_2": {
+         "name": "Data Cleaning",
+         "description": "Load dirty.csv, fill missing values, remove duplicates, then calculate median of 'age'",
+         "datafile": "dirty.csv",
+         "target_column": "age",
+         "target_operation": "median",
+         "expected_answer": None,
+         "difficulty": "medium",
+     },
+     "task_3": {
+         "name": "Multi-table Analysis",
+         "description": "Load sales.csv and products.csv, merge on product_id, calculate total sales per category",
+         "datafile": "sales.csv",
+         "secondary_datafile": "products.csv",
+         "target_column": "sales",
+         "group_by_column": "category",
+         "target_operation": "sum",
+         "expected_answer": None,
+         "difficulty": "hard",
+     },
+ }
client.py ADDED
@@ -0,0 +1,66 @@
+ from typing import Optional
+
+ import httpx
+ from openenv.core.env_client import EnvClient, StepResult
+
+ from models import DataAnalysisAction, DataAnalysisObservation, DataAnalysisState
+
+
+ class DataAnalysisEnv(EnvClient):
+     def __init__(self, base_url: str = "http://localhost:8000"):
+         self._base_url = base_url.rstrip("/")
+         if self._base_url.startswith(("ws://", "wss://")):
+             self._base_url = "http" + self._base_url[2:]  # ws:// -> http://, wss:// -> https://
+         elif not self._base_url.startswith(("http://", "https://")):
+             self._base_url = "http://" + self._base_url  # bare host:port
+         self._client: Optional[httpx.AsyncClient] = None
+
+     def _get_client(self) -> httpx.AsyncClient:
+         if self._client is None:
+             self._client = httpx.AsyncClient(base_url=self._base_url, timeout=60.0)
+         return self._client
+
+     async def reset(self, task: str = "task_1", **kwargs) -> StepResult:
+         client = self._get_client()
+         response = await client.post("/reset", json={"task": task})
+         response.raise_for_status()
+         data = response.json()
+         return self._parse_result(data)
+
+     async def step(self, action: DataAnalysisAction) -> StepResult:
+         payload = {
+             "action": {
+                 "tool": action.tool,
+                 "parameters": action.parameters,
+             }
+         }
+         client = self._get_client()
+         response = await client.post("/step", json=payload)
+         response.raise_for_status()
+         data = response.json()
+         return self._parse_result(data)
+
+     async def state(self) -> DataAnalysisState:
+         client = self._get_client()
+         response = await client.get("/state")
+         response.raise_for_status()
+         data = response.json()
+         return DataAnalysisState(**data)
+
+     async def close(self):
+         if self._client:
+             await self._client.aclose()
+             self._client = None
+
+     @staticmethod
+     def _parse_result(payload: dict) -> StepResult:
+         obs = DataAnalysisObservation(**payload.get("observation", {}))
+         return StepResult(
+             observation=obs,
+             reward=payload.get("reward", 0.0),
+             done=payload.get("done", False),
+         )
+
+     @staticmethod
+     def _parse_state(payload: dict) -> DataAnalysisState:
+         return DataAnalysisState(**payload)
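Review note on the client's base-URL handling: the constructor maps WebSocket-style URLs and bare host names onto HTTP URLs. The sketch below isolates that normalization as a standalone function so it can be tested without a server. `normalize_base_url` is a hypothetical helper, not part of the commit, and treating `wss://` as `https://` is an assumption about the intended behavior.

```python
def normalize_base_url(url: str) -> str:
    """Map ws://, wss://, and scheme-less inputs onto an http(s) base URL."""
    url = url.rstrip("/")
    if url.startswith("wss://"):
        return "https://" + url[len("wss://"):]
    if url.startswith("ws://"):
        return "http://" + url[len("ws://"):]
    if url.startswith(("http://", "https://")):
        # Already an HTTP(S) URL: leave the scheme untouched.
        return url
    # Bare "host:port" input: assume plain HTTP.
    return "http://" + url


for raw in ["ws://localhost:8000/", "https://example.hf.space", "localhost:8000"]:
    print(raw, "->", normalize_base_url(raw))
```

Checking for both `http://` and `https://` matters for hosted deployments, where the environment URL is typically served over TLS.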
inference.py ADDED
@@ -0,0 +1,244 @@
+ import asyncio
+ import os
+ import sys
+ import textwrap
+ from typing import List, Optional
+
+ from openai import OpenAI
+ from openenv.core.env_client import StepResult
+
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
+ BENCHMARK = "data_analysis_env"
+ MAX_STEPS = 20
+ SUCCESS_SCORE_THRESHOLD = 0.7
+
+
+ TASK_INSTRUCTIONS = {
+     "task_1": textwrap.dedent(
+         """You are a data analysis assistant. Your task is to: 1. Load the CSV file 'simple.csv' 2. Calculate the mean of the 'price' column. Available tools: load_csv(filename='filename.csv'), show_data(), show_columns(), calculate(column='column_name', operation='mean|median|sum|count|std|min|max'). Start by loading the data, then calculate the mean of the price column."""
+     ),
+     "task_2": textwrap.dedent(
+         """You are a data analysis assistant. Your task is to: 1. Load the CSV file 'dirty.csv' 2. Fill missing values (use mean) 3. Remove duplicate rows 4. Calculate the median of the 'age' column. Available tools: load_csv(filename='filename.csv'), fill_missing(value='mean|median|zero|value'), remove_duplicates(), show_data(), show_columns(), calculate(column='column_name', operation='mean|median|sum|count|std|min|max'). Start by loading the data, then clean it, then calculate the median."""
+     ),
+     "task_3": textwrap.dedent(
+         """You are a data analysis assistant. Your task is to: 1. Load 'sales.csv' and 'products.csv' 2. Merge them on 'product_id' 3. Group by 'category' and sum the 'sales' column 4. Get the final result. Available tools: load_csv(filename='filename.csv'), merge_datasets(filename='filename.csv', on='column_name'), show_data(), show_columns(), group_by(group_column='column_name', agg_column='column_name', operation='sum|mean|count'), calculate(column='column_name', operation='sum|mean|count'), get_result(). Start by loading both files, then merge, then group and aggregate."""
+     ),
+ }
+
+
+ def get_action_from_response(response: str):
+     from data_analysis_env import DataAnalysisAction
+
+     response = response.strip()
+
+     if response.lower() in ["done", "get_result()"]:
+         return DataAnalysisAction(tool="get_result", parameters={})
+
+     if "(" not in response or ")" not in response:
+         return None
+
+     try:
+         tool_name = response.split("(")[0].strip()
+         params_str = response.split("(")[1].split(")")[0].strip()
+
+         parameters = {}
+         if params_str:
+             for param in params_str.split(","):
+                 param = param.strip()
+                 if "=" in param:
+                     key, value = param.split("=", 1)
+                     key = key.strip()
+                     value = value.strip().strip("'\"")
+
+                     if value.lower() == "none":
+                         value = None
+                     elif value.lower() == "true":
+                         value = True
+                     elif value.lower() == "false":
+                         value = False
+                     else:
+                         try:
+                             if "." in value:
+                                 value = float(value)
+                             else:
+                                 value = int(value)
+                         except ValueError:
+                             pass
+
+                     parameters[key] = value
+
+         return DataAnalysisAction(tool=tool_name, parameters=parameters)
+
+     except Exception as e:
+         print(f"Error parsing action: {e}", file=sys.stderr)
+         return None
+
+
+ def log_start(task: str, env: str, model: str) -> None:
+     print(f"[START] task={task} env={env} model={model}", flush=True)
+
+
+ def log_step(
+     step: int, action: str, reward: float, done: bool, error: Optional[str]
+ ) -> None:
+     error_val = error if error else "null"
+     done_val = str(done).lower()
+     print(
+         f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
+         flush=True,
+     )
+
+
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
+     rewards_str = ",".join(f"{r:.2f}" for r in rewards)
+     print(
+         f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
+         flush=True,
+     )
+
+
+ async def run_task(client: OpenAI, env, task_name: str):
+     from data_analysis_env import DataAnalysisAction
+
+     log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
+
+     instruction = TASK_INSTRUCTIONS.get(task_name, "")
+     messages = [
+         {"role": "system", "content": instruction},
+         {"role": "user", "content": "Begin the analysis task."},
+     ]
+
+     step = 0
+     rewards = []
+     last_error = None
+
+     result = await env.reset(task=task_name)
+     obs = result.observation
+     reward_val = obs.reward if obs.reward is not None else 0.0
+
+     print(
+         f"[STEP] step={step} action=reset reward={reward_val:.2f} done={result.done} error=null",
+         flush=True,
+     )
+
+     while not result.done and step < MAX_STEPS:
+         step += 1
+
+         response = (
+             client.chat.completions.create(
+                 model=MODEL_NAME,
+                 messages=messages
+                 + [{"role": "assistant", "content": f"Previous output: {obs.output}"}],
+                 temperature=0.1,
+                 max_tokens=500,
+             )
+             .choices[0]
+             .message.content
+         )
+
+         action = get_action_from_response(response)
+
+         if action is None:
+             last_error = "Could not parse action"
+             print(
+                 f"[STEP] step={step} action='{response}' reward={reward_val:.2f} done=false error={last_error}",
+                 flush=True,
+             )
+             messages.append(
+                 {
+                     "role": "user",
+                     "content": f"Invalid action format. Please use tool_name(param1=value1, param2=value2). Error: {last_error}",
+                 }
+             )
+             continue
+
+         result = await env.step(action)
+         obs = result.observation
+         reward_val = obs.reward if obs.reward is not None else 0.0
+         rewards.append(reward_val)
+
+         error_str = obs.error if obs.error else "null"
+         print(
+             f"[STEP] step={step} action={action.tool}({action.parameters}) reward={reward_val:.2f} done={result.done} error={error_str}",
+             flush=True,
+         )
+
+         if obs.error:
+             last_error = obs.error
+             messages.append(
+                 {
+                     "role": "user",
+                     "content": f"Error: {obs.error}. Please try a different tool or correct parameters.",
+                 }
+             )
+         else:
+             messages.append(
+                 {
+                     "role": "user",
+                     "content": f"Tool executed successfully. Output: {obs.output}",
+                 }
+             )
+
+         if result.done:
+             break
+
+     score = obs.reward if obs.reward is not None else 0.0
+     success = score >= SUCCESS_SCORE_THRESHOLD
+
+     # log_end formats the per-step rewards itself.
+     log_end(success=success, steps=step, score=score, rewards=rewards)
+
+     return {
+         "task": task_name,
+         "success": success,
+         "steps": step,
+         "score": score,
+         "rewards": rewards,
+     }
+
+
+ async def main():
+     from data_analysis_env import DataAnalysisEnv
+
+     if not API_KEY:
+         print(
+             "Error: HF_TOKEN or API_KEY environment variable not set", file=sys.stderr
+         )
+         sys.exit(1)
+
+     client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
+
+     base_url = os.getenv("ENV_URL", "http://localhost:8000")
+     env = DataAnalysisEnv(base_url=base_url)
+
+     results = []
+
+     for task_name in ["task_1", "task_2", "task_3"]:
+         try:
+             result = await run_task(client, env, task_name)
+             results.append(result)
+         except Exception as e:
+             print(f"Error running {task_name}: {e}", file=sys.stderr)
+             results.append(
+                 {
+                     "task": task_name,
+                     "success": False,
+                     "steps": 0,
+                     "score": 0.0,
+                     "rewards": [],
+                 }
+             )
+
+     await env.close()
+
+     avg_score = sum(r["score"] for r in results) / len(results)
+     print("\n=== Summary ===")
+     print(f"Average Score: {avg_score:.2f}")
+     for r in results:
+         print(f"  {r['task']}: {r['score']:.2f} ({'PASS' if r['success'] else 'FAIL'})")
+
+
+ if __name__ == "__main__":
+     asyncio.run(main())
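Review note on `get_action_from_response`: splitting the argument string on commas breaks when a quoted value itself contains a comma, for example a list of column names. One alternative, sketched here under the assumption that the model replies with a single Python-style call using keyword arguments only, is to let the standard `ast` module do the parsing. `parse_tool_call` is a hypothetical name, and it returns a plain `(tool, parameters)` tuple rather than a `DataAnalysisAction` so the sketch stays self-contained.

```python
import ast
from typing import Any, Optional


def parse_tool_call(text: str) -> Optional[tuple[str, dict[str, Any]]]:
    """Parse "tool(key='value', n=5)" into (tool, parameters), or None if invalid."""
    try:
        node = ast.parse(text.strip(), mode="eval").body
    except SyntaxError:
        return None
    # Accept only a bare function call such as calculate(...), with no positional args.
    if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Name):
        return None
    if node.args:
        return None
    try:
        # literal_eval accepts only literals (str, int, float, bool, None, ...).
        params = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords}
    except ValueError:
        return None
    return node.func.id, params


print(parse_tool_call("calculate(column='price', operation='mean')"))
print(parse_tool_call("select_columns(columns='name,age')"))
```

Because `ast.literal_eval` evaluates literals only, nothing in the model's reply is executed. Replies with surrounding prose would still need to be trimmed before parsing.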
models.py ADDED
@@ -0,0 +1,70 @@
+ from typing import Any, Literal, Optional
+ from pydantic import BaseModel, Field, field_validator
+
+
+ class DataAnalysisAction(BaseModel):
+     tool: str = Field(..., description="Tool name to execute")
+     parameters: dict[str, Any] = Field(
+         default_factory=dict, description="Tool parameters"
+     )
+
+     @field_validator("tool", mode="before")
+     @classmethod
+     def _coerce_tool(cls, value):
+         if isinstance(value, dict):
+             return value.get("tool", "")
+         return str(value)
+
+
+ class DataAnalysisObservation(BaseModel):
+     done: bool = Field(default=False, description="Whether episode is done")
+     reward: float = Field(default=0.0, description="Reward for this step")
+     success: bool = Field(
+         default=True, description="Whether tool executed successfully"
+     )
+     output: str = Field(default="", description="Tool output or error message")
+     data_shape: Optional[tuple[int, int]] = Field(
+         default=None, description="(rows, columns) of current data"
+     )
+     columns: list[str] = Field(
+         default_factory=list, description="Column names of current data"
+     )
+     tools_used: list[str] = Field(
+         default_factory=list, description="History of tools called"
+     )
+     error: Optional[str] = Field(
+         default=None, description="Error message if tool failed"
+     )
+
+     @field_validator("data_shape", mode="before")
+     @classmethod
+     def _coerce_shape(cls, value):
+         if isinstance(value, list) and len(value) == 2:
+             return tuple(value)
+         return value
+
+
+ class DataAnalysisState(BaseModel):
+     episode_id: Optional[str] = Field(
+         default=None, description="Unique episode identifier"
+     )
+     task_name: str = Field(default="", description="Current task name")
+     step_count: int = Field(default=0, description="Number of steps taken")
+     max_steps: int = Field(default=20, description="Maximum steps allowed per episode")
+     data_loaded: bool = Field(default=False, description="Whether data has been loaded")
+
+
+ AVAILABLE_TOOLS = [
+     "load_csv",
+     "show_data",
+     "show_columns",
+     "fill_missing",
+     "remove_duplicates",
+     "filter_rows",
+     "select_columns",
+     "group_by",
+     "calculate",
+     "sort_by",
+     "get_result",
+     "merge_datasets",
+ ]
openenv.yaml ADDED
@@ -0,0 +1,20 @@
+ spec_version: 1
+ name: data_analysis_env
+ type: environment
+ runtime: fastapi
+ app: server.app:app
+ port: 8000
+ metadata:
+   title: Data Analysis Env
+   description: Real-world data analysis tasks using pandas - load, clean, transform, and analyze CSV data
+   difficulty:
+     - easy
+     - medium
+     - hard
+   tags:
+     - data-analysis
+     - pandas
+     - openenv
+     - ai-agents
+   author: Meta Hackathon
+   version: "1.0.0"
pyproject.toml ADDED
@@ -0,0 +1,29 @@
+ [project]
+ name = "data_analysis_env"
+ version = "0.1.0"
+ description = "Data Analysis Environment for OpenEnv - A real-world RL task for teaching agents pandas data analysis"
+ readme = "README.md"
+ requires-python = ">=3.10"
+ dependencies = [
+     "openenv-core>=0.1.0",
+     "pandas>=2.0.0",
+     "fastapi>=0.100.0",
+     "uvicorn>=0.23.0",
+ ]
+
+ [project.scripts]
+ data_analysis_env = "server.app:main"
+
+ [project.optional-dependencies]
+ dev = [
+     "pytest>=7.0.0",
+     "black>=23.0.0",
+     "mypy>=1.0.0",
+ ]
+
+ [build-system]
+ requires = ["hatchling"]
+ build-backend = "hatchling.build"
+
+ [tool.hatch.build.targets.wheel]
+ packages = ["data_analysis_env"]
server/__init__.py ADDED
@@ -0,0 +1 @@
+ # Server package
server/app.py ADDED
@@ -0,0 +1,30 @@
+ import os
+ from pathlib import Path
+
+ from openenv.core.env_server import create_app
+
+ from server.quantum_openenv_env_environment import DataAnalysisEnvironment
+ from models import DataAnalysisAction, DataAnalysisObservation
+
+
+ def create_data_analysis_environment():
+     data_dir = os.getenv("DATA_DIR", str(Path(__file__).parent / "data"))  # defaults to server/data
+     return DataAnalysisEnvironment(data_dir=data_dir)
+
+
+ app = create_app(
+     create_data_analysis_environment,
+     DataAnalysisAction,
+     DataAnalysisObservation,
+     env_name="data_analysis_env",
+ )
+
+
+ def main():
+     import uvicorn
+
+     uvicorn.run(app, host="0.0.0.0", port=8000)
+
+
+ if __name__ == "__main__":
+     main()
server/data/dirty.csv ADDED
@@ -0,0 +1,18 @@
+ name,age,salary,city
+ John,25,50000,New York
+ Jane,30,60000,Los Angeles
+ Bob,,55000,Chicago
+ Alice,28,52000,Houston
+ John,25,50000,New York
+ Charlie,35,70000,Phoenix
+ Jane,30,60000,Los Angeles
+ David,,58000,San Diego
+ Eve,32,,Philadelphia
+ Frank,29,54000,Dallas
+ Bob,35,55000,Chicago
+ Grace,27,51000,Austin
+ Henry,,62000,Seattle
+ Ivy,31,56000,Denver
+ John,25,50000,New York
+ Jack,33,59000,Boston
+ Kelly,26,,Portland
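Review note on `dirty.csv`: the Task 2 answer depends on which fill strategy is used for the three missing `age` values, so it is worth checking both. The sketch below uses only the standard library and is a simplification, not the environment's pandas code: it fills `age` only, then drops rows that are exact duplicates, then takes the median. `median_age` is a hypothetical helper.

```python
import csv
import io
import statistics

DIRTY_CSV = """name,age,salary,city
John,25,50000,New York
Jane,30,60000,Los Angeles
Bob,,55000,Chicago
Alice,28,52000,Houston
John,25,50000,New York
Charlie,35,70000,Phoenix
Jane,30,60000,Los Angeles
David,,58000,San Diego
Eve,32,,Philadelphia
Frank,29,54000,Dallas
Bob,35,55000,Chicago
Grace,27,51000,Austin
Henry,,62000,Seattle
Ivy,31,56000,Denver
John,25,50000,New York
Jack,33,59000,Boston
Kelly,26,,Portland
"""


def median_age(fill) -> float:
    """Fill missing ages with fill(known_ages), drop duplicate rows, return the median age."""
    rows = list(csv.DictReader(io.StringIO(DIRTY_CSV)))
    known = [float(r["age"]) for r in rows if r["age"]]
    fill_value = fill(known)
    seen, ages = set(), []
    for r in rows:
        age = float(r["age"]) if r["age"] else fill_value
        key = (r["name"], age, r["salary"], r["city"])
        if key not in seen:  # keep the first occurrence of each identical row
            seen.add(key)
            ages.append(age)
    return statistics.median(ages)


mean_fill = median_age(statistics.mean)
median_fill = median_age(statistics.median)
print(f"mean fill: {mean_fill:.3f}, median fill: {median_fill:.3f}")
```

The two strategies give different medians on this file, so the grader's expected answer and the task instructions need to agree on one of them.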
server/data/products.csv ADDED
@@ -0,0 +1,5 @@
+ product_id,product_name,category,unit_price
+ P001,Widget Alpha,Electronics,50.00
+ P002,Widget Beta,Electronics,50.00
+ P003,Widget Gamma,Home,50.00
+ P004,Widget Delta,Home,50.00
server/data/sales.csv ADDED
@@ -0,0 +1,13 @@
+ transaction_id,product_id,quantity,sales,date
+ 1,P001,5,250.00,2024-01-15
+ 2,P002,3,150.00,2024-01-16
+ 3,P001,2,100.00,2024-01-17
+ 4,P003,4,200.00,2024-01-18
+ 5,P002,6,300.00,2024-01-19
+ 6,P001,3,150.00,2024-01-20
+ 7,P003,2,100.00,2024-01-21
+ 8,P002,5,250.00,2024-01-22
+ 9,P001,4,200.00,2024-01-23
+ 10,P003,3,150.00,2024-01-24
+ 11,P002,2,100.00,2024-01-25
+ 12,P001,5,250.00,2024-01-26
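Review note on `sales.csv` and `products.csv`: the Task 3 target (total sales per category after merging on `product_id`) can be cross-checked without pandas. This is a minimal standard-library sketch with inner-join semantics; `total_sales_per_category` is a hypothetical helper and the CSV text is copied from the two files above.

```python
import csv
import io
from collections import defaultdict

SALES_CSV = """transaction_id,product_id,quantity,sales,date
1,P001,5,250.00,2024-01-15
2,P002,3,150.00,2024-01-16
3,P001,2,100.00,2024-01-17
4,P003,4,200.00,2024-01-18
5,P002,6,300.00,2024-01-19
6,P001,3,150.00,2024-01-20
7,P003,2,100.00,2024-01-21
8,P002,5,250.00,2024-01-22
9,P001,4,200.00,2024-01-23
10,P003,3,150.00,2024-01-24
11,P002,2,100.00,2024-01-25
12,P001,5,250.00,2024-01-26
"""

PRODUCTS_CSV = """product_id,product_name,category,unit_price
P001,Widget Alpha,Electronics,50.00
P002,Widget Beta,Electronics,50.00
P003,Widget Gamma,Home,50.00
P004,Widget Delta,Home,50.00
"""


def total_sales_per_category(sales_csv: str, products_csv: str) -> dict[str, float]:
    """Join sales to products on product_id and sum the sales column per category."""
    category_of = {
        row["product_id"]: row["category"]
        for row in csv.DictReader(io.StringIO(products_csv))
    }
    totals: dict[str, float] = defaultdict(float)
    for row in csv.DictReader(io.StringIO(sales_csv)):
        category = category_of.get(row["product_id"])
        if category is not None:  # inner join: skip sales with no matching product
            totals[category] += float(row["sales"])
    return dict(totals)


totals = total_sales_per_category(SALES_CSV, PRODUCTS_CSV)
print(totals)
```

Product `P004` has no sales rows, so it contributes nothing to the `Home` total.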
server/data/simple.csv ADDED
@@ -0,0 +1,11 @@
+ product,price,category
+ Widget A,29.99,Electronics
+ Widget B,49.99,Electronics
+ Widget C,19.99,Electronics
+ Widget D,39.99,Electronics
+ Widget E,59.99,Electronics
+ Widget F,24.99,Electronics
+ Widget G,34.99,Electronics
+ Widget H,44.99,Electronics
+ Widget I,54.99,Electronics
+ Widget J,64.99,Electronics
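Review note on `simple.csv`: the Task 1 target is the mean of the `price` column, which is easy to verify by hand. A minimal standard-library check, with the prices copied from the file above:

```python
import statistics

# Prices from simple.csv, in file order (Widget A through Widget J).
PRICES = [29.99, 49.99, 19.99, 39.99, 59.99, 24.99, 34.99, 44.99, 54.99, 64.99]

mean_price = statistics.mean(PRICES)
print(f"{mean_price:.2f}")
```

Any grader that compares floats should allow a small tolerance, since a pandas mean and a hand computation can differ in the last few bits.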
server/quantum_openenv_env_environment.py ADDED
@@ -0,0 +1,533 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Any, Optional
4
+
5
+ import pandas as pd
6
+ import uuid
7
+
8
+ from openenv.core.env_server import Environment
9
+
10
+ from models import (
11
+ DataAnalysisAction,
12
+ DataAnalysisObservation,
13
+ DataAnalysisState,
14
+ AVAILABLE_TOOLS,
15
+ )
16
+
17
+
18
+ TASKS = {
19
+ "task_1": {
20
+ "name": "Basic Statistics",
21
+ "description": "Load simple.csv and calculate the mean of the 'price' column",
22
+ "datafile": "simple.csv",
23
+ "target_column": "price",
24
+ "target_operation": "mean",
25
+ "expected_answer": None,
26
+ "difficulty": "easy",
27
+ },
28
+ "task_2": {
29
+ "name": "Data Cleaning",
30
+ "description": "Load dirty.csv, fill missing values, remove duplicates, then calculate median of 'age'",
31
+ "datafile": "dirty.csv",
32
+ "target_column": "age",
33
+ "target_operation": "median",
34
+ "expected_answer": None,
35
+ "difficulty": "medium",
36
+ },
37
+ "task_3": {
38
+ "name": "Multi-table Analysis",
39
+ "description": "Load sales.csv and products.csv, merge on product_id, calculate total sales per category",
40
+ "datafile": "sales.csv",
41
+ "secondary_datafile": "products.csv",
42
+ "target_column": "sales",
43
+ "group_by_column": "category",
44
+ "target_operation": "sum",
45
+ "expected_answer": None,
46
+ "difficulty": "hard",
47
+ },
48
+ }
49
+
50
+
51
+ class DataAnalysisEnvironment(Environment):
52
+ def __init__(self, data_dir: Optional[str] = None):
53
+ super().__init__()
54
+ self._data_dir = data_dir or str(Path(__file__).parent / "data")
55
+ self._state = DataAnalysisState()
56
+ self._df: Optional[pd.DataFrame] = None
57
+ self._secondary_df: Optional[pd.DataFrame] = None
58
+ self._last_result: Any = None
59
+ self._reward = 0.0
60
+ self._tools_used: list[str] = []
61
+
62
+ def reset(
63
+ self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs
64
+ ) -> DataAnalysisObservation:
65
+ task_name = kwargs.get("task", "task_1")
66
+
67
+ self._state = DataAnalysisState(
68
+ episode_id=episode_id or str(uuid.uuid4()),
69
+ task_name=task_name,
70
+ step_count=0,
71
+ max_steps=20,
72
+ data_loaded=False,
73
+ )
74
+
75
+ self._df = None
76
+ self._secondary_df = None
77
+ self._last_result = None
78
+ self._reward = 0.0
79
+ self._tools_used = []
80
+
81
+ task = TASKS.get(task_name, TASKS["task_1"])
82
+ datafile = os.path.join(self._data_dir, task.get("datafile", "simple.csv"))
83
+
84
+ if os.path.exists(datafile):
85
+ self._df = pd.read_csv(datafile)
86
+ self._state.data_loaded = True
87
+
88
+ if task_name == "task_1":
89
+ task["expected_answer"] = float(self._df[task["target_column"]].mean())
90
+ elif task_name == "task_2":
91
+ df_clean = self._df.fillna(
92
+ self._df.median(numeric_only=True)
93
+ ).drop_duplicates()
94
+ task["expected_answer"] = float(
95
+ df_clean[task["target_column"]].median()
96
+ )
97
+ elif task_name == "task_3":
98
+ secondary = os.path.join(
99
+ self._data_dir, task.get("secondary_datafile", "products.csv")
100
+ )
101
+ if os.path.exists(secondary):
102
+ self._secondary_df = pd.read_csv(secondary)
103
+ merged = self._df.merge(self._secondary_df, on="product_id")
104
+ task["expected_answer"] = (
105
+ merged.groupby(task["group_by_column"])[task["target_column"]]
106
+ .sum()
107
+ .to_dict()
108
+ )
109
+
110
+ return DataAnalysisObservation(
111
+ done=False,
112
+ reward=0.0,
113
+ success=True,
114
+ output=f"Ready. Task: {task['name']}. {task['description']}",
115
+ data_shape=tuple(self._df.shape) if self._df is not None else None,
116
+ columns=list(self._df.columns) if self._df is not None else [],
117
+ tools_used=[],
118
+ )
119
+
120
+ def step(self, action: DataAnalysisAction) -> DataAnalysisObservation:
121
+ self._state.step_count += 1
122
+
123
+ tool = action.tool
124
+ params = action.parameters
125
+
126
+ self._tools_used.append(f"{tool}({params})")
127
+
128
+ if tool not in AVAILABLE_TOOLS:
129
+ self._reward = max(0, self._reward - 0.1)
130
+ return DataAnalysisObservation(
131
+ done=False,
132
+ reward=self._reward,
133
+ success=False,
134
+ output=f"Unknown tool: {tool}",
135
+ data_shape=tuple(self._df.shape) if self._df is not None else None,
136
+ columns=list(self._df.columns) if self._df is not None else [],
137
+ tools_used=self._tools_used,
138
+ error=f"Tool '{tool}' not found. Available: {AVAILABLE_TOOLS}",
139
+ )
140
+
141
+ try:
142
+ result = self._execute_tool(tool, params)
143
+
144
+ if result["success"]:
145
+ self._reward = min(1.0, self._reward + 0.1)
146
+ else:
147
+ self._reward = max(0, self._reward - 0.1)
148
+
149
+ done = self._state.step_count >= self._state.max_steps
150
+ if done and self._reward < 0.5:
151
+ self._reward = 0.0
152
+
153
+ return DataAnalysisObservation(
154
+ done=done,
155
+ reward=self._reward,
156
+ success=result["success"],
157
+ output=result["output"],
158
+ data_shape=tuple(self._df.shape) if self._df is not None else None,
159
+ columns=list(self._df.columns) if self._df is not None else [],
160
+ tools_used=self._tools_used,
161
+ error=result.get("error"),
162
+ )
163
+
164
+ except Exception as e:
165
+ self._reward = max(0, self._reward - 0.1)
166
+ return DataAnalysisObservation(
167
+ done=False,
168
+ reward=self._reward,
169
+ success=False,
170
+ output=f"Error executing {tool}: {str(e)}",
171
+ data_shape=tuple(self._df.shape) if self._df is not None else None,
172
+ columns=list(self._df.columns) if self._df is not None else [],
173
+ tools_used=self._tools_used,
174
+ error=str(e),
175
+ )
176
+
177
+ def _execute_tool(self, tool: str, params: dict) -> dict:
178
+ if tool == "load_csv":
179
+ return self._tool_load_csv(params)
180
+ elif tool == "show_data":
181
+ return self._tool_show_data(params)
182
+ elif tool == "show_columns":
183
+ return self._tool_show_columns(params)
184
+ elif tool == "fill_missing":
185
+ return self._tool_fill_missing(params)
186
+ elif tool == "remove_duplicates":
187
+ return self._tool_remove_duplicates(params)
188
+ elif tool == "filter_rows":
189
+ return self._tool_filter_rows(params)
190
+ elif tool == "select_columns":
191
+ return self._tool_select_columns(params)
192
+ elif tool == "group_by":
193
+ return self._tool_group_by(params)
194
+ elif tool == "calculate":
195
+ return self._tool_calculate(params)
196
+ elif tool == "sort_by":
197
+ return self._tool_sort_by(params)
198
+ elif tool == "get_result":
199
+ return self._tool_get_result(params)
200
+ elif tool == "merge_datasets":
201
+ return self._tool_merge_datasets(params)
202
+ return {"success": False, "output": f"Unknown tool: {tool}"}
203
+
204
+ def _tool_load_csv(self, params: dict) -> dict:
205
+ filename = params.get("filename", "")
206
+ filepath = os.path.join(self._data_dir, filename)
207
+
208
+ if not filename or not os.path.isfile(filepath):
209
+ return {
210
+ "success": False,
211
+ "output": f"File not found: {filename}",
212
+ "error": "FileNotFound",
213
+ }
214
+
215
+ self._df = pd.read_csv(filepath)
216
+ self._state.data_loaded = True
217
+
218
+ return {
219
+ "success": True,
220
+ "output": f"Loaded {filename}: {self._df.shape[0]} rows, {self._df.shape[1]} columns. Columns: {list(self._df.columns)}",
221
+ }
222
+
223
+ def _tool_show_data(self, params: dict) -> dict:
224
+ if self._df is None:
225
+ return {"success": False, "output": "No data loaded", "error": "NoData"}
226
+
227
+ n = params.get("n", 5)
228
+ head = self._df.head(n).to_string()
229
+
230
+ return {
231
+ "success": True,
232
+ "output": f"Data shape: {self._df.shape}\n{head}",
233
+ }
234
+
235
+ def _tool_show_columns(self, params: dict) -> dict:
236
+ if self._df is None:
237
+ return {"success": False, "output": "No data loaded", "error": "NoData"}
238
+
239
+ cols = [(col, str(self._df[col].dtype)) for col in self._df.columns]
240
+ output = "Columns:\n" + "\n".join([f" {c}: {t}" for c, t in cols])
241
+
242
+ return {"success": True, "output": output}
243
+
244
+ def _tool_fill_missing(self, params: dict) -> dict:
245
+ if self._df is None:
246
+ return {"success": False, "output": "No data loaded", "error": "NoData"}
247
+
248
+ method = params.get("value", "mean")
249
+
250
+ if method == "mean":
251
+ self._df = self._df.fillna(self._df.mean(numeric_only=True))
252
+ elif method == "median":
253
+ self._df = self._df.fillna(self._df.median(numeric_only=True))
254
+ elif method == "zero":
255
+ self._df = self._df.fillna(0)
256
+ else:
257
+ self._df = self._df.fillna(method)
258
+
259
+ return {
260
+ "success": True,
261
+ "output": f"Filled missing values with {method}. Shape: {self._df.shape}",
262
+ }
263
+
264
+ def _tool_remove_duplicates(self, params: dict) -> dict:
265
+ if self._df is None:
266
+ return {"success": False, "output": "No data loaded", "error": "NoData"}
267
+
268
+ before = len(self._df)
269
+ self._df = self._df.drop_duplicates()
270
+ removed = before - len(self._df)
271
+
272
+ return {
273
+ "success": True,
274
+ "output": f"Removed {removed} duplicate rows. Remaining: {len(self._df)} rows",
275
+ }
276
+
277
+ def _tool_filter_rows(self, params: dict) -> dict:
278
+ if self._df is None:
279
+ return {"success": False, "output": "No data loaded", "error": "NoData"}
280
+
281
+ column = params.get("column", "")
282
+ operator = params.get("operator", "==")
283
+ value = params.get("value", None)
284
+
285
+ if column not in self._df.columns:
286
+ return {
287
+ "success": False,
288
+ "output": f"Column not found: {column}",
289
+ "error": "ColumnNotFound",
290
+ }
291
+
292
+ try:
293
+ if operator == "==":
294
+ mask = self._df[column] == value
295
+ elif operator == "!=":
296
+ mask = self._df[column] != value
297
+ elif operator == ">":
298
+ mask = self._df[column] > value
299
+ elif operator == ">=":
300
+ mask = self._df[column] >= value
301
+ elif operator == "<":
302
+ mask = self._df[column] < value
303
+ elif operator == "<=":
304
+ mask = self._df[column] <= value
305
+ else:
306
+ return {
307
+ "success": False,
308
+ "output": f"Unknown operator: {operator}",
309
+ "error": "InvalidOperator",
310
+ }
311
+
312
+ self._df = self._df[mask]
313
+ return {"success": True, "output": f"Filtered to {len(self._df)} rows"}
314
+
315
+ except Exception as e:
316
+ return {
317
+ "success": False,
318
+ "output": f"Filter error: {str(e)}",
319
+ "error": str(e),
320
+ }
321
+
322
+ def _tool_select_columns(self, params: dict) -> dict:
323
+ if self._df is None:
324
+ return {"success": False, "output": "No data loaded", "error": "NoData"}
325
+
326
+ columns = params.get("columns", [])
327
+ missing = [c for c in columns if c not in self._df.columns]
328
+
329
+ if missing:
330
+ return {
331
+ "success": False,
332
+ "output": f"Columns not found: {missing}",
333
+ "error": "ColumnNotFound",
334
+ }
335
+
336
+ self._df = self._df[columns]
337
+ return {
338
+ "success": True,
339
+ "output": f"Selected columns: {columns}. Shape: {self._df.shape}",
340
+ }
341
+
342
+ def _tool_group_by(self, params: dict) -> dict:
343
+ if self._df is None:
344
+ return {"success": False, "output": "No data loaded", "error": "NoData"}
345
+
346
+ group_column = params.get("group_column", "")
347
+ agg_column = params.get("agg_column", "")
348
+ operation = params.get("operation", "mean")
349
+
350
+ if group_column not in self._df.columns or agg_column not in self._df.columns:
351
+ return {
352
+ "success": False,
353
+ "output": "Columns not found",
354
+ "error": "ColumnNotFound",
355
+ }
356
+
357
+ result = self._df.groupby(group_column)[agg_column].agg(operation)
358
+ self._last_result = result.to_dict()
359
+
360
+ return {
361
+ "success": True,
362
+ "output": f"Grouped by {group_column}, aggregated {agg_column} with {operation}:\n{result.to_string()}",
363
+ }
364
+
365
+ def _tool_calculate(self, params: dict) -> dict:
366
+ if self._df is None:
367
+ return {"success": False, "output": "No data loaded", "error": "NoData"}
368
+
369
+ column = params.get("column", "")
370
+ operation = params.get("operation", "mean")
371
+
372
+ if column not in self._df.columns:
373
+ return {
374
+ "success": False,
375
+ "output": f"Column not found: {column}",
376
+ "error": "ColumnNotFound",
377
+ }
378
+
379
+ try:
380
+ if operation == "mean":
381
+ result = self._df[column].mean()
382
+ elif operation == "median":
383
+ result = self._df[column].median()
384
+ elif operation == "sum":
385
+ result = self._df[column].sum()
386
+ elif operation == "count":
387
+ result = self._df[column].count()
388
+ elif operation == "std":
389
+ result = self._df[column].std()
390
+ elif operation == "min":
391
+ result = self._df[column].min()
392
+ elif operation == "max":
393
+ result = self._df[column].max()
394
+ else:
395
+ return {
396
+ "success": False,
397
+ "output": f"Unknown operation: {operation}",
398
+ "error": "InvalidOperation",
399
+ }
400
+
401
+ self._last_result = float(result)
402
+ return {"success": True, "output": f"{operation}({column}) = {result}"}
403
+
404
+ except Exception as e:
405
+ return {
406
+ "success": False,
407
+ "output": f"Calculation error: {str(e)}",
408
+ "error": str(e),
409
+ }
410
+
411
+ def _tool_sort_by(self, params: dict) -> dict:
412
+ if self._df is None:
413
+ return {"success": False, "output": "No data loaded", "error": "NoData"}
414
+
415
+ column = params.get("column", "")
416
+ ascending = params.get("ascending", True)
417
+
418
+ if column not in self._df.columns:
419
+ return {
420
+ "success": False,
421
+ "output": f"Column not found: {column}",
422
+ "error": "ColumnNotFound",
423
+ }
424
+
425
+ self._df = self._df.sort_values(by=column, ascending=ascending)
426
+ return {
427
+ "success": True,
428
+ "output": f"Sorted by {column} (ascending={ascending})",
429
+ }
430
+
431
+ def _tool_get_result(self, params: dict) -> dict:
432
+ task = TASKS.get(self._state.task_name, TASKS["task_1"])
433
+
434
+ if self._last_result is not None:
435
+ score = self._grade_result(self._last_result, task)
436
+ self._reward = min(1.0, self._reward + 0.5 * score)
437
+ return {
438
+ "success": True,
439
+ "output": f"Final result: {self._last_result}",
440
+ "score": score,
441
+ }
442
+
443
+ return {"success": False, "output": "No result available", "error": "NoResult"}
444
+
445
+ def _tool_merge_datasets(self, params: dict) -> dict:
+ if self._df is None:
+ return {"success": False, "output": "No data loaded", "error": "NoData"}
+
446
+ filename = params.get("filename", "")
447
+ on = params.get("on", "")
448
+
449
+ filepath = os.path.join(self._data_dir, filename)
450
+
451
+ if not filename or not os.path.isfile(filepath):
452
+ return {
453
+ "success": False,
454
+ "output": f"File not found: {filename}",
455
+ "error": "FileNotFound",
456
+ }
457
+
458
+ other_df = pd.read_csv(filepath)
459
+
460
+ if on not in self._df.columns or on not in other_df.columns:
461
+ return {
462
+ "success": False,
463
+ "output": f"Merge column not found: {on}",
464
+ "error": "ColumnNotFound",
465
+ }
466
+
467
+ self._df = self._df.merge(other_df, on=on)
468
+
469
+ return {
470
+ "success": True,
471
+ "output": f"Merged with {filename} on {on}. Shape: {self._df.shape}",
472
+ }
473
+
474
+ def _grade_result(self, result: Any, task: dict) -> float:
475
+ task_name = self._state.task_name
476
+
477
+ if task_name == "task_1":
478
+ expected = task.get("expected_answer", 0)
479
+ if expected is None:
480
+ return 0.0
481
+ try:
482
+ actual = float(result)
483
+ if abs(actual - expected) < 0.01:
484
+ return 1.0
485
+ elif abs(actual - expected) < abs(expected) * 0.1:
486
+ return 0.7
487
+ else:
488
+ return 0.3
489
+ except (TypeError, ValueError):
490
+ return 0.0
491
+
492
+ elif task_name == "task_2":
493
+ expected = task.get("expected_answer", 0)
494
+ if expected is None:
495
+ return 0.0
496
+ try:
497
+ actual = float(result)
498
+ if abs(actual - expected) < 0.01:
499
+ return 1.0
500
+ elif abs(actual - expected) < abs(expected) * 0.1:
501
+ return 0.7
502
+ else:
503
+ return 0.3
504
+ except (TypeError, ValueError):
505
+ return 0.0
506
+
507
+ elif task_name == "task_3":
508
+ expected = task.get("expected_answer", {})
509
+ if expected is None or not isinstance(expected, dict):
510
+ return 0.0
511
+ try:
512
+ actual = dict(result) if hasattr(result, "items") else result
513
+
514
+ if isinstance(actual, dict) and isinstance(expected, dict):
515
+ if set(actual.keys()) == set(expected.keys()):
516
+ total_error = sum(
517
+ abs(actual.get(k, 0) - expected.get(k, 0)) for k in expected
518
+ )
519
+ if total_error < 0.01:
520
+ return 1.0
521
+ elif total_error < 50:
522
+ return 0.7
523
+ else:
524
+ return 0.3
525
+ return 0.5
526
+ except (TypeError, ValueError):
527
+ return 0.2
528
+
529
+ return 0.0
530
+
531
+ @property
532
+ def state(self) -> DataAnalysisState:
533
+ return self._state
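Review note: the `task_1` and `task_2` branches of `_grade_result` apply the same tolerance tiers (1.0 within 0.01, 0.7 within 10% of the expected value, 0.3 otherwise). A minimal sketch of a shared helper follows, assuming both tasks keep a numeric `expected_answer`. The name `grade_scalar` is hypothetical and does not appear in this commit; it is an illustration, not the author's implementation.

```python
from typing import Any, Optional


def grade_scalar(result: Any, expected: Optional[float]) -> float:
    """Score a numeric answer with the tiers used by the task_1/task_2 branches.

    Hypothetical helper: `grade_scalar` is not part of the commit.
    """
    if expected is None:
        return 0.0
    try:
        actual = float(result)
    except (TypeError, ValueError):
        # Non-numeric results (for example a dict from group_by) score zero,
        # matching what the bare `except` in the original branches returned.
        return 0.0

    error = abs(actual - expected)
    if error < 0.01:
        return 1.0  # effectively exact
    if error < abs(expected) * 0.1:
        return 0.7  # within 10% of the expected value
    return 0.3  # numeric, but far from the expected value


if __name__ == "__main__":
    print(grade_scalar(105, 100.0))
```

With a helper like this, both scalar branches could reduce to a single `return grade_scalar(result, task.get("expected_answer", 0))`, leaving only the dict comparison for `task_3`.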
server/requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ openenv-core
2
+ pandas
3
+ fastapi
4
+ uvicorn
uv.lock ADDED
The diff for this file is too large to render. See raw diff