varb15 committed on
Commit
f5583f9
·
verified ·
1 Parent(s): 0c216ef

Upload folder using huggingface_hub

Dockerfile CHANGED
@@ -26,11 +26,11 @@ RUN uv sync --no-editable 2>/dev/null || pip install -e .
26
  ENV PATH="/app/.venv/bin:$PATH"
27
  ENV PYTHONPATH="/app:$PYTHONPATH"
28
 
29
- # Health check
30
  HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
31
- CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
32
 
33
- EXPOSE 8000
34
 
35
  ENV ENABLE_WEB_INTERFACE=true
36
- CMD ["uvicorn", "dataqa_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
 
26
  ENV PATH="/app/.venv/bin:$PATH"
27
  ENV PYTHONPATH="/app:$PYTHONPATH"
28
 
29
+ # Health check — HF Spaces uses port 7860
30
  HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
31
+ CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:7860/health')" || exit 1
32
 
33
+ EXPOSE 7860
34
 
35
  ENV ENABLE_WEB_INTERFACE=true
36
+ CMD ["uvicorn", "dataqa_env.server.app:app", "--host", "0.0.0.0", "--port", "7860"]
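The `HEALTHCHECK` above probes `/health` on port 7860 using `urllib`. The same probe can be sketched as a small Python helper for checking the container from the host. The function name `is_healthy` and the injectable `opener` parameter are illustrative assumptions, not part of the repository.

```python
import urllib.error
import urllib.request


def is_healthy(url="http://localhost:7860/health", timeout=3.0,
               opener=urllib.request.urlopen):
    """Return True if GET `url` answers with a 2xx status, else False.

    `opener` is a hypothetical injection point so the probe can be
    exercised without a running server.
    """
    try:
        with opener(url, timeout=timeout) as response:
            return 200 <= response.status < 300
    except (urllib.error.URLError, OSError):
        # Connection refused, timeouts, and HTTP 4xx/5xx all count as unhealthy.
        return False
```

Calling `is_healthy()` with no arguments targets the port the Dockerfile now exposes; pass a different `url` when the container port is mapped elsewhere.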
README.md CHANGED
@@ -1,65 +1,238 @@
1
  ---
2
  title: DataQA Environment Server
3
- emoji: 🔍
4
  colorFrom: blue
5
  colorTo: gray
6
  sdk: docker
7
  pinned: false
8
- app_port: 8000
9
- base_path: /web
10
  tags:
11
  - openenv
 
12
  ---
13
 
14
  # DataQA Environment
15
 
16
- An OpenEnv environment for **Data Quality Assurance** — an LLM agent inspects datasets with planted quality issues and must identify them all.
17
 
18
- ## Overview
19
 
20
- DataQA simulates the real-world task of validating datasets before they enter ML training pipelines or production databases. The agent receives a corrupted dataset along with its schema and validation rules, then must identify all planted data quality issues.
21
 
22
- ### Why Data QA?
 
 
23
 
24
- Every ML engineer and data scientist spends significant time debugging data quality issues — missing values, type mismatches, inconsistencies, and subtle statistical anomalies. This environment turns that task into a structured, gradable challenge.
25
 
26
  ## Environment API
27
 
28
- | Endpoint | Description |
29
- |----------|-------------|
30
- | `reset(task_id)` | Start a new episode with a corrupted dataset |
31
- | `step(issues)` | Submit identified issues, receive F1-scored feedback |
32
- | `state()` | Get current episode state |
 
33
 
34
  ## Tasks
35
 
36
- | Task | Issues | Difficulty | Description |
37
- |------|--------|-----------|-------------|
38
- | `easy` | 4 | Beginner | Employee directory nulls, wrong types, duplicates, out-of-range |
39
- | `medium` | 6 | Intermediate | E-commerce orders format violations, inconsistent totals, duplicate keys |
40
- | `hard` | 8 | Advanced | ML experiment metadata data leakage signals, unreasonable GPU usage, timestamp ordering |
41
 
42
  ## Reward Function
43
 
44
- Scoring uses **F1 score** (harmonic mean of precision and recall):
45
 
46
- - **Precision**: What fraction of reported issues are real?
47
- - **Recall**: What fraction of planted issues did the agent find?
48
- - **F1**: `2 * precision * recall / (precision + recall)`
49
 
50
- Issues are matched by `row:<N>,col:<column>,issue:<type>` keys.
51
 
52
- The agent gets up to 3 attempts per task with feedback on each attempt (true positives, false positives, missed count).
53
 
54
- ## Action/Observation Space
55
 
56
- **Action**: List of issue strings in format `row:<row_number>,col:<column_name>,issue:<issue_type>`
57
 
58
- **Observation**: Dataset CSV + schema + validation rules + feedback from previous attempt
59
 
60
- **Issue Types**: `missing_value`, `wrong_type`, `duplicate_row`, `out_of_range`, `format_violation`, `inconsistent_value`, `statistical_outlier`, `referential_integrity`
61
 
62
- ## Quick Start
63
 
64
  ```bash
65
  # Install
@@ -68,42 +241,76 @@ pip install -e .
68
  # Run server locally
69
  uvicorn dataqa_env.server.app:app --host 0.0.0.0 --port 8000
70
 
71
- # Run inference
72
- API_BASE_URL=https://api.groq.com/openai/v1 \
73
- MODEL_NAME=llama-3.3-70b-versatile \
74
- LLM_API_KEY=your-key \
75
  python inference.py
76
  ```
77
 
78
  ## Docker
79
 
80
  ```bash
81
- docker build -t dataqa-env -f dataqa_env/server/Dockerfile .
82
  docker run -p 8000:8000 dataqa-env
83
  ```
84
 
85
  ## Environment Variables
86
 
87
  | Variable | Description | Default |
88
  |----------|-------------|---------|
89
- | `API_BASE_URL` | LLM API endpoint | `https://api.groq.com/openai/v1` |
90
- | `MODEL_NAME` | Model identifier | `llama-3.3-70b-versatile` |
91
- | `HF_TOKEN` | HuggingFace token | - |
92
  | `ENV_URL` | Environment server URL | `http://localhost:8000` |
93
- | `LLM_API_KEY` | API key for LLM provider | Falls back to HF_TOKEN |
94
 
95
  ## Architecture
96
 
97
  ```
98
  dataqa_env/
99
- ├── models.py # Pydantic: DataQAAction, DataQAObservation, DataQAState
 
100
  ├── client.py # EnvClient for WebSocket connections
101
  ├── server/
102
- │ ├── environment.py # Core DataQAEnvironment (reset/step/state)
103
- │ ├── tasks.py # Task definitions + data corruption + grading
104
- │ ├── app.py # FastAPI server
105
  │ └── Dockerfile
106
- ├── openenv.yaml
107
- ├── pyproject.toml
108
- └── inference.py # LLM agent using OpenAI client
109
  ```
 
1
  ---
2
  title: DataQA Environment Server
3
+ emoji: "\U0001F50D"
4
  colorFrom: blue
5
  colorTo: gray
6
  sdk: docker
7
  pinned: false
8
+ app_port: 7860
 
9
  tags:
10
  - openenv
11
+ base_path: /web
12
  ---
13
 
14
  # DataQA Environment
15
 
16
+ A two-phase OpenEnv RL environment for **Data Quality Assurance** — an LLM agent inspects corrupted datasets, identifies all planted quality issues, and proposes data repairs.
17
+
18
+ ### Demo: Agent Trajectory Replay
19
+
20
+ ```
21
+ EASY TASK (Step 2) — All 6 issues found + 5 fixes proposed
22
+ Reward: 0.87 | Identify: 1.00 | Fix: 0.67
23
+ ✓ row:4 name: empty → "David Kim" (fix correct)
24
+ ✓ row:7 salary: "seventy-five thousand" → "75000" (fix correct)
25
+ ✓ row:9 salary: "5000" → "73000" (fix correct)
26
+ ✓ row:15 email: mismatch → "oscar.rivera@company.com" (fix correct)
27
+ ✓ row:18 start_date: "2027-06-15" → "2022-01-19" (fix correct)
28
+ ✓ row:21 duplicate row detected
29
+
30
+ HARD TASK (Step 1 → Step 2)
31
+ Step 1: Found 5/10, missed hard issues → Reward: 0.69
32
+ Step 2: Found 10/10 + 5 fixes proposed → Reward: 0.77
33
+ Issues requiring ML knowledge:
34
+ • val_loss < train_loss (data leakage signal)
35
+ • resnet18 using 42.5GB GPU (impossible)
36
+ • 350 epochs on ImageNet in 30 min (impossible)
37
+ • wav2vec2 at 98.5% accuracy (exceeds SOTA)
38
+ ```
39
+
40
+ > The interactive replay UI with color-coded dataset visualization is available on the HF Space.
41
 
42
+ ## Motivation
43
 
44
+ Every ML engineer and data scientist spends significant time debugging data quality issues (missing values, type mismatches, logical inconsistencies, and subtle statistical anomalies) before data enters ML pipelines or production databases. This is a genuine, high-frequency human task that directly impacts model quality and business outcomes.
45
 
46
+ DataQA turns this into a **two-phase RL challenge**:
47
+ 1. **Identify** — systematically inspect corrupted data and pinpoint every planted issue
48
+ 2. **Fix** — propose corrected values by reasoning about schema, constraints, and context
49
 
50
+ This creates a rich multi-step decision problem where agents must explore datasets strategically, distinguish subtle anomalies from noise, and reason about what the correct data should be.
51
 
52
  ## Environment API
53
 
54
+ | Endpoint | Method | Description |
55
+ |----------|--------|-------------|
56
+ | `/reset` | POST | Start a new episode with a corrupted dataset |
57
+ | `/step` | POST | Submit identified issues + proposed fixes |
58
+ | `/state` | GET | Get current episode state |
59
+ | `/health` | GET | Health check |
60
 
61
  ## Tasks
62
 
63
+ | Task | Issues | Difficulty | Domain | Description |
64
+ |------|--------|-----------|--------|-------------|
65
+ | `easy` | 6 | Beginner | HR/Employee data (21 rows) | Nulls, wrong types, duplicates, out-of-range, email-name mismatch, future dates |
66
+ | `medium` | 8 | Intermediate | E-commerce orders (31 rows) | Inconsistent totals, invalid categories, duplicate keys, wrong date formats, invalid country codes, future-date deliveries |
67
+ | `hard` | 10 | Advanced | ML experiment metadata (31 rows) | Data leakage signals, unreasonable GPU memory, impossibly fast training, SOTA-exceeding accuracy, timestamp ordering, whitespace-only fields |
68
+
69
+ **Difficulty progression**: Easy issues are individually obvious (empty fields, text in numeric columns). Medium issues require cross-column reasoning (total != qty * price) and set membership checks. Hard issues require ML domain knowledge (val_loss < train_loss = data leakage) and multi-row temporal reasoning.
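To make the medium tier's cross-column rule concrete, here is a minimal sketch of a `total != qty * price` check. The column names `quantity`, `unit_price`, and `total`, the tolerance, and the sample data are assumptions for illustration; the real task's schema may differ.

```python
import csv
import io


def find_inconsistent_totals(csv_text, tolerance=0.01):
    """Return issue strings for rows where total != quantity * unit_price.

    Column names are hypothetical. Rows are 1-indexed after the header,
    matching the issue format the environment expects.
    """
    issues = []
    reader = csv.DictReader(io.StringIO(csv_text))
    for row_number, row in enumerate(reader, start=1):
        try:
            expected = float(row["quantity"]) * float(row["unit_price"])
            actual = float(row["total"])
        except (KeyError, TypeError, ValueError):
            continue  # unparseable cells belong to other issue types
        if abs(expected - actual) > tolerance:
            issues.append(f"row:{row_number},col:total,issue:inconsistent_value")
    return issues


sample = "order_id,quantity,unit_price,total\n1,2,10.00,20.00\n2,3,5.00,18.00\n"
print(find_inconsistent_totals(sample))
```

The second sample row is flagged because 3 × 5.00 is 15.00, not 18.00.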
70
+
71
+ ## Two-Phase Action Space
72
+
73
+ ### Phase 1: Identify Issues
74
+
75
+ Submit issues in format: `row:<row_number>,col:<column_name>,issue:<issue_type>`
76
+
77
+ - `row_number`: 1-indexed data row position (after header)
78
+ - `column_name`: Exact column header name, lowercase
79
+ - `issue_type`: One of the supported types below
80
+
81
+ ### Phase 2: Propose Fixes
82
+
83
+ Submit fixes in format: `row:<row_number>,col:<column_name>,fix:<corrected_value>`
84
+
85
+ The agent proposes the **correct value** that should replace the corrupted data. Fixes are graded against the original clean dataset.
86
+
87
+ Both phases can be submitted in the same step or across multiple steps.
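As a rough illustration of the two string formats, the sketch below builds an action carrying both phases in one step. The helper names are hypothetical; the `issues`, `fixes`, and `task_id` fields follow `DataQAAction` in `dataqa_env/models.py`, and the exact HTTP envelope around them is not shown here.

```python
def format_issue(row, col, issue_type):
    """Phase 1 string: row is 1-indexed, column name is lowercased."""
    return f"row:{row},col:{col.lower()},issue:{issue_type}"


def format_fix(row, col, value):
    """Phase 2 string: proposes the corrected value for one cell."""
    return f"row:{row},col:{col.lower()},fix:{value}"


def build_action(issues, fixes=(), task_id="easy"):
    """Assemble the action fields; fixes are optional."""
    return {"issues": list(issues), "fixes": list(fixes), "task_id": task_id}


action = build_action(
    issues=[format_issue(7, "salary", "wrong_type")],
    fixes=[format_fix(7, "salary", "75000")],
)
print(action)
```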
88
+
89
+ **Supported Issue Types:**
90
+
91
+ | Type | Description | Example |
92
+ |------|-------------|---------|
93
+ | `missing_value` | Null, empty, or whitespace-only | Empty name field |
94
+ | `wrong_type` | Value doesn't match expected type | Salary as "seventy-five thousand" |
95
+ | `duplicate_row` | Exact duplicate or duplicate key | Two rows with same employee_id |
96
+ | `out_of_range` | Value outside valid range | Salary of 5000 when min is 50000 |
97
+ | `format_violation` | Wrong format or invalid enum | Date as DD/MM/YYYY instead of YYYY-MM-DD |
98
+ | `inconsistent_value` | Computed field mismatch, logical inconsistency | total != qty * price |
99
+ | `statistical_outlier` | Unreasonable value given context | resnet18 using 42.5GB GPU |
100
+ | `referential_integrity` | Foreign key violation | (available for custom tasks) |
101
+
102
+ ## Observation Space
103
+
104
+ | Field | Type | Description |
105
+ |-------|------|-------------|
106
+ | `dataset_csv` | str | The corrupted dataset in CSV format |
107
+ | `schema_description` | str | Column types, ranges, and constraints |
108
+ | `validation_rules` | str | Business rules the data must satisfy |
109
+ | `task_description` | str | Task context and instructions |
110
+ | `feedback` | str | Per-step results: TP/FP/FN, precision/recall, fix scores |
111
+ | `num_issues_hint` | int | Exact count of planted issues |
112
+ | `max_steps` | int | Maximum attempts allowed |
113
+ | `done` | bool | Whether episode has terminated |
114
+ | `reward` | float | Best combined reward so far (0.0-1.0) |
115
+
116
+ **Observation Metadata** (per step):
117
+ - Identify: `identify_f1`, `identify_score`, `precision`, `recall`, `tp`, `fp`, `fn`
118
+ - Fix: `fix_score`, `fixes_correct`, `fixes_partial`, `fixes_wrong`, `fixes_attempted`
119
+ - Combined: `combined_reward`, `difficulty_found`, `difficulty_missed`
120
 
121
  ## Reward Function
122
 
123
+ ### Combined Reward
124
+
125
+ ```
126
+ combined_reward = 0.6 * identify_score + 0.4 * fix_score
127
+ ```
128
+
129
+ If no fixes are submitted, `combined_reward = identify_score` (no penalty — backward compatible).
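A minimal sketch of the weighting, using the constants defined in `environment.py`. Representing "no fixes submitted" as `fix_score=None` is an assumption made for this example; the environment may detect that case differently.

```python
IDENTIFY_WEIGHT = 0.6
FIX_WEIGHT = 0.4


def combined_reward(identify_score, fix_score=None):
    """Blend the two phase scores; fall back to identify-only if no fixes."""
    if fix_score is None:
        return identify_score
    return IDENTIFY_WEIGHT * identify_score + FIX_WEIGHT * fix_score


# Easy-task demo values: Identify 1.00, Fix 0.67
print(round(combined_reward(1.00, 0.67), 2))  # 0.87
```

This reproduces the 0.87 reward reported for the easy task in the trajectory replay.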
130
+
131
+ ### Identify Score (Difficulty-Weighted F1)
132
+
133
+ Each planted issue has a **difficulty weight** (1.0-3.0):
134
+
135
+ | Weight | Category | Examples |
136
+ |--------|----------|----------|
137
+ | 1.0 | Easy | Missing values, obvious out-of-range, wrong type |
138
+ | 1.5-2.0 | Medium | Duplicate keys, format violations, cross-column checks |
139
+ | 2.5-3.0 | Hard | Data leakage, statistical outliers, whitespace-only |
140
 
141
+ - **Weighted Recall** = (difficulty of found issues) / (total difficulty)
142
+ - **Weighted Precision** = penalizes false positives proportional to average difficulty
143
+ - **Weighted F1** = harmonic mean
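The three quantities above can be sketched as follows, simplified from `compute_weighted_reward` in `environment.py`. The function name and its list-based inputs are assumptions for illustration, and the edge cases for empty inputs are omitted.

```python
def weighted_f1(found, missed, false_positive_count):
    """Difficulty-weighted F1.

    `found` and `missed` are lists of difficulty weights for planted
    issues the agent did and did not report.
    """
    difficulty_found = sum(found)
    total_weight = difficulty_found + sum(missed)
    if total_weight <= 0:
        raise ValueError("need at least one planted issue with positive weight")

    recall = difficulty_found / total_weight

    # Each false positive costs the average difficulty of a planted issue.
    avg_difficulty = total_weight / (len(found) + len(missed))
    fp_penalty = false_positive_count * avg_difficulty
    denominator = difficulty_found + fp_penalty
    precision = difficulty_found / denominator if denominator > 0 else 0.0

    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


score = weighted_f1(found=[1.0, 3.0], missed=[1.0, 2.0], false_positive_count=1)
print(round(score, 4))
```

In this example recall is 4/7, the single false positive costs 1.75 (the average difficulty), and precision is therefore 4/5.75.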
144
 
145
+ ### Fix Score (Difficulty-Weighted Quality)
146
 
147
+ Each proposed fix is compared against the original clean value:
148
 
149
+ | Fix Quality | Score | Description |
150
+ |-------------|-------|-------------|
151
+ | Exact match | 1.0 | Case-insensitive, whitespace-stripped match |
152
+ | Numeric close | 0.8 | Within 1% of correct numeric value |
153
+ | Correct cell | 0.1 | Right location, wrong value |
154
+ | Non-issue cell | 0.0 | Fix targets a cell with no issue |
155
 
156
+ Fix score = (sum of best fix score per issue × difficulty weight) / (total difficulty weight)
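A minimal sketch of the per-fix grading table and the weighted aggregate, assuming string-valued cells. The function names and the `targets_issue_cell` flag are hypothetical; the repository's `grade_fixes` may handle details differently.

```python
def score_fix(proposed, clean, targets_issue_cell=True):
    """Score one proposed fix against the clean value (see table above)."""
    if not targets_issue_cell:
        return 0.0  # fix aimed at a cell with no planted issue
    if proposed.strip().lower() == clean.strip().lower():
        return 1.0  # exact match, ignoring case and surrounding whitespace
    try:
        proposed_num, clean_num = float(proposed), float(clean)
    except ValueError:
        return 0.1  # right cell, wrong non-numeric value
    if clean_num != 0 and abs(proposed_num - clean_num) / abs(clean_num) <= 0.01:
        return 0.8  # numeric value within 1% of the clean value
    return 0.1


def fix_score(best_scores, weights):
    """Weighted average of the best fix score per planted issue.

    best_scores[i] is 0.0 for issues with no fix attempted.
    """
    return sum(s * w for s, w in zip(best_scores, weights)) / sum(weights)


print(score_fix("73500", "73000"))         # numeric close match
print(fix_score([1.0, 0.0], [1.0, 3.0]))   # only the easy issue fixed
```

Because the aggregate divides by the total difficulty weight, leaving a hard issue unfixed costs more than leaving an easy one unfixed.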
157
 
158
+ ### Reward Properties
159
 
160
+ - **Per-step partial progress**: reward increases as more issues are found/fixed
161
+ - **Difficulty-aware**: finding subtle issues earns more than obvious ones
162
+ - **Penalizes bad behavior**: false positives reduce score, fixing non-issues earns nothing
163
+ - **Monotonically non-decreasing**: best score across all steps is the final reward
164
+ - **Always in [0.0, 1.0]**: meets hackathon requirement
165
 
166
+ ### Episode Boundaries
167
+
168
+ - Each task allows up to 3 steps (attempts)
169
+ - Episode ends when F1 >= 0.999 (perfect identification) or max steps reached
170
+ - Agent receives detailed feedback after each step to improve on next attempt
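The termination rule and the best-score tracking can be sketched together. This is a simulation over precomputed per-step scores, not the environment's own loop; the function name and the sample numbers are illustrative assumptions.

```python
def run_episode(step_results, max_steps=3, done_threshold=0.999):
    """Simulate episode boundaries.

    step_results: iterable of (identify_f1, combined_reward) per attempt.
    Returns (best_reward, steps_used).
    """
    best = 0.0
    steps_used = 0
    for identify_f1, reward in step_results:
        steps_used += 1
        best = max(best, reward)  # final reward never decreases
        if identify_f1 >= done_threshold or steps_used >= max_steps:
            break
    return best, steps_used


# Second attempt reaches perfect identification, so the episode ends early.
print(run_episode([(0.55, 0.69), (1.0, 0.77)]))
```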
171
+
172
+ ## Baseline Scores
173
+
174
+ Baseline agent uses Qwen2.5-72B-Instruct via HuggingFace Router:
175
+
176
+ | Task | Identify Score | Fix Score | Combined | Notes |
177
+ |------|---------------|-----------|----------|-------|
178
+ | `easy` | 0.7-1.0 | 0.5-0.9 | 0.6-1.0 | Most LLMs find obvious issues reliably |
179
+ | `medium` | 0.5-0.8 | 0.3-0.6 | 0.4-0.7 | Cross-column reasoning challenges models |
180
+ | `hard` | 0.3-0.6 | 0.2-0.4 | 0.3-0.5 | ML domain knowledge and subtle patterns |
181
+
182
+ Scores vary by model. The hard task is designed to challenge frontier models.
183
+
184
+ ## Extensibility
185
+
186
+ ### Custom Contamination Rules
187
+
188
+ ```python
189
+ from dataqa_env import register_contamination_rule
190
+ from dataqa_env.server.tasks import PlantedIssue
191
+
192
+ def swap_digits(rows, header, col_idx, row_idx, rng):
193
+ val = rows[row_idx][col_idx]
194
+ corrupted = val[::-1]
195
+ issue = PlantedIssue(
196
+ row=row_idx + 1, col=header[col_idx],
197
+ issue_type="format_violation",
198
+ description=f"Digits swapped in {header[col_idx]}",
199
+ difficulty=2.0,
200
+ )
201
+ return corrupted, issue
202
+
203
+ register_contamination_rule("swap_digits", swap_digits)
204
+ ```
205
+
206
+ ### Custom Tasks from Config
207
+
208
+ ```python
209
+ from dataqa_env import create_task_from_config, register_task
210
+
211
+ task = create_task_from_config(
212
+ task_id="custom",
213
+ name="Custom Validation",
214
+ description="Find quality issues in this dataset.",
215
+ schema_description="id: int, name: str, score: int (0-100)",
216
+ validation_rules="No missing values. Scores must be 0-100.",
217
+ clean_csv="id,name,score\n1,Alice,95\n2,Bob,87\n3,Carol,92",
218
+ contaminations=[
219
+ {"rule": "missing_value", "row": 0, "col": 1, "difficulty": 1.0},
220
+ {"rule": "negative_value", "row": 2, "col": 2, "difficulty": 1.5},
221
+ ],
222
+ )
223
+ register_task("custom", lambda seed: task)
224
+ ```
225
+
226
+ ### Built-in Contamination Rules
227
+
228
+ | Rule | Effect | Default Difficulty |
229
+ |------|--------|--------------------|
230
+ | `missing_value` | Sets field to empty string | 1.0 |
231
+ | `whitespace_value` | Sets field to single space | 2.5 |
232
+ | `wrong_type_text` | Replaces with random text | 1.0 |
233
+ | `negative_value` | Negates numeric value | 1.0 |
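For orientation, here is a rough sketch of how two of the built-in rules might look if they follow the same callable signature as the `swap_digits` example. The `PlantedIssue` dataclass is a stand-in for the real class, and the `out_of_range` issue type chosen for `negative_value` is an assumption; the actual implementations live in `dataqa_env/server/tasks.py` and may differ.

```python
import random
from dataclasses import dataclass


@dataclass
class PlantedIssue:
    """Stand-in for dataqa_env.server.tasks.PlantedIssue (fields assumed)."""
    row: int
    col: str
    issue_type: str
    description: str
    difficulty: float


def missing_value(rows, header, col_idx, row_idx, rng):
    """Replace the cell with an empty string."""
    issue = PlantedIssue(
        row=row_idx + 1, col=header[col_idx],
        issue_type="missing_value",
        description=f"Empty {header[col_idx]}",
        difficulty=1.0,
    )
    return "", issue


def negative_value(rows, header, col_idx, row_idx, rng):
    """Negate a numeric cell, keeping integers free of a trailing '.0'."""
    negated = -float(rows[row_idx][col_idx])
    corrupted = str(int(negated)) if negated.is_integer() else str(negated)
    issue = PlantedIssue(
        row=row_idx + 1, col=header[col_idx],
        issue_type="out_of_range",  # assumed issue type for this sketch
        description=f"Negative {header[col_idx]}",
        difficulty=1.0,
    )
    return corrupted, issue


header = ["id", "name", "score"]
rows = [["1", "Alice", "95"], ["2", "Bob", "87"]]
corrupted, issue = negative_value(rows, header, 2, 1, random.Random(0))
print(corrupted, issue.row, issue.col)
```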
234
+
235
+ ## Setup & Quick Start
236
 
237
  ```bash
238
  # Install
 
241
  # Run server locally
242
  uvicorn dataqa_env.server.app:app --host 0.0.0.0 --port 8000
243
 
244
+ # Run inference (set your API credentials)
245
+ API_BASE_URL=https://router.huggingface.co/v1 \
246
+ MODEL_NAME=Qwen/Qwen2.5-72B-Instruct \
247
+ HF_TOKEN=your-token \
248
  python inference.py
249
  ```
250
 
251
  ## Docker
252
 
253
  ```bash
254
+ docker build -t dataqa-env .
255
  docker run -p 8000:8000 dataqa-env
256
  ```
257
 
258
+ ## Testing
259
+
260
+ ```bash
261
+ pip install -e ".[dev]"
262
+ pytest tests/ -v
263
+ ```
264
+
265
+ 118 tests covering:
266
+ - Task creation, corruption, and difficulty weights
267
+ - Issue key and fix parsing (standard, lenient, edge cases)
268
+ - F1, weighted reward, and fix quality computation
269
+ - Full environment lifecycle (identify-only and identify+fix)
270
+ - Combined reward calculation and weight verification
271
+ - Inference script parsing and prompt building
272
+ - Structured log format ([START], [STEP], [END])
273
+ - Score bounds (0.0-1.0), best-score monotonicity
274
+ - Extensibility API (custom rules, custom tasks)
275
+
276
+ ## Validation
277
+
278
+ ```bash
279
+ # OpenEnv spec validation
280
+ openenv validate .
281
+
282
+ # Pre-submission validation (requires HF Space URL)
283
+ ./prevalidation_script.sh https://your-space.hf.space
284
+ ```
285
+
286
  ## Environment Variables
287
 
288
  | Variable | Description | Default |
289
  |----------|-------------|---------|
290
+ | `API_BASE_URL` | LLM API endpoint | `https://router.huggingface.co/v1` |
291
+ | `MODEL_NAME` | Model identifier | `Qwen/Qwen2.5-72B-Instruct` |
292
+ | `HF_TOKEN` | HuggingFace token / API key | - |
293
  | `ENV_URL` | Environment server URL | `http://localhost:8000` |
 
294
 
295
  ## Architecture
296
 
297
  ```
298
  dataqa_env/
299
+ ├── __init__.py # Public API + extensibility exports
300
+ ├── models.py # Pydantic: DataQAAction (issues + fixes), DataQAObservation, DataQAState
301
  ├── client.py # EnvClient for WebSocket connections
302
  ├── server/
303
+ │ ├── environment.py # Two-phase DataQAEnvironment (identify + fix + combined reward)
304
+ │ ├── tasks.py # Task definitions + contamination rules + extensibility API
305
+ │ ├── app.py # FastAPI server (via openenv-core create_app)
306
  │ └── Dockerfile
307
+ tests/
308
+ ├── test_tasks.py # Task creation, corruption, difficulty weights
309
+ ├── test_environment.py # Identify scoring, fix grading, combined reward, lifecycle
310
+ ├── test_inference.py # LLM response parsing, fix parsing, prompt building, log format
311
+ └── test_extensibility.py # Custom rules, custom tasks, registration API
312
+ inference.py # Two-phase baseline agent (identify → fix)
313
+ openenv.yaml # OpenEnv/HF Spaces spec
314
+ pyproject.toml # Package metadata and dependencies
315
+ Dockerfile # Production container
316
  ```
__init__.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from dataqa_env import DataQAEnv, DataQAAction, DataQAObservation, DataQAState
2
 
3
  __all__ = ["DataQAEnv", "DataQAAction", "DataQAObservation", "DataQAState"]
 
1
+ """Root-level package for OpenEnv compatibility."""
2
  from dataqa_env import DataQAEnv, DataQAAction, DataQAObservation, DataQAState
3
 
4
  __all__ = ["DataQAEnv", "DataQAAction", "DataQAObservation", "DataQAState"]
dataqa_env/__init__.py CHANGED
@@ -1,4 +1,19 @@
1
  from .client import DataQAEnv
2
  from .models import DataQAAction, DataQAObservation, DataQAState
3
 
4
- __all__ = ["DataQAEnv", "DataQAAction", "DataQAObservation", "DataQAState"]
1
  from .client import DataQAEnv
2
  from .models import DataQAAction, DataQAObservation, DataQAState
3
+ from .server.tasks import (
4
+ create_task_from_config,
5
+ register_task,
6
+ register_contamination_rule,
7
+ CONTAMINATION_RULES,
8
+ )
9
 
10
+ __all__ = [
11
+ "DataQAEnv",
12
+ "DataQAAction",
13
+ "DataQAObservation",
14
+ "DataQAState",
15
+ "create_task_from_config",
16
+ "register_task",
17
+ "register_contamination_rule",
18
+ "CONTAMINATION_RULES",
19
+ ]
dataqa_env/models.py CHANGED
@@ -16,21 +16,23 @@ from openenv.core.env_server.interfaces import Action, Observation, State
16
 
17
  class DataQAAction(Action):
18
  """
19
- Agent submits a list of identified data quality issues.
20
 
21
- Each issue is a string in the format: "row:<row_idx>,col:<col_name>,issue:<issue_type>"
22
  Supported issue types:
23
- - missing_value
24
- - wrong_type
25
- - duplicate_row
26
- - out_of_range
27
- - format_violation
28
- - inconsistent_value
29
- - statistical_outlier
30
- - referential_integrity
31
  """
32
 
33
  issues: List[str]
 
34
  # Include task_id so step() can reconstruct context in stateless HTTP mode
35
  task_id: str = "easy"
36
 
 
16
 
17
  class DataQAAction(Action):
18
  """
19
+ Agent submits identified issues AND optional proposed fixes.
20
+
21
+ Two-phase action space:
22
+ Phase 1 (Identify): List issues in format "row:<N>,col:<name>,issue:<type>"
23
+ Phase 2 (Fix): List fixes in format "row:<N>,col:<name>,fix:<proposed_value>"
24
+
25
+ The agent can submit both in the same step or across multiple steps.
26
+ Combined reward = 0.6 * identify_score + 0.4 * fix_score
27
 
 
28
  Supported issue types:
29
+ missing_value, wrong_type, duplicate_row, out_of_range,
30
+ format_violation, inconsistent_value, statistical_outlier,
31
+ referential_integrity
 
 
 
 
 
32
  """
33
 
34
  issues: List[str]
35
+ fixes: List[str] = []
36
  # Include task_id so step() can reconstruct context in stateless HTTP mode
37
  task_id: str = "easy"
38
 
dataqa_env/server/app.py CHANGED
@@ -19,9 +19,20 @@ app = create_app(
19
  )
20
 
21
 
22
  def main():
23
  import uvicorn
24
- uvicorn.run(app, host="0.0.0.0", port=8000)
25
 
26
 
27
  if __name__ == "__main__":
 
19
  )
20
 
21
 
22
+ @app.get("/")
23
+ def root():
24
+ """Root endpoint — environment info."""
25
+ return {
26
+ "name": "DataQA Environment",
27
+ "description": "Two-phase data quality assurance environment: identify issues + propose fixes",
28
+ "tasks": ["easy", "medium", "hard"],
29
+ "endpoints": ["/health", "/reset", "/step", "/state"],
30
+ }
31
+
32
+
33
  def main():
34
  import uvicorn
35
+ uvicorn.run(app, host="0.0.0.0", port=7860)
36
 
37
 
38
  if __name__ == "__main__":
dataqa_env/server/environment.py CHANGED
@@ -3,8 +3,12 @@ DataQA Environment
3
  ------------------
4
  Server-side environment for data quality assurance tasks.
5
 
6
- The agent receives corrupted datasets and must identify planted quality issues.
7
- Scoring is based on F1 (precision-recall) of correctly matched issues.
8
  """
9
 
10
  from __future__ import annotations
@@ -18,6 +22,10 @@ from openenv.core.env_server.interfaces import Action, Environment, Observation
18
  from ..models import DataQAAction, DataQAObservation, DataQAState
19
  from .tasks import PlantedIssue, Task, get_task, list_tasks
20
 
 
 
 
 
21
 
22
  def parse_issue_key(raw: str) -> Optional[str]:
23
  """
@@ -26,7 +34,6 @@ def parse_issue_key(raw: str) -> Optional[str]:
26
  Returns normalized key or None if unparseable.
27
  """
28
  raw = raw.strip().lower()
29
- # Be lenient with formatting
30
  row_match = re.search(r"row\s*[:=]\s*(\d+)", raw)
31
  col_match = re.search(r"col\s*[:=]\s*([\w_]+)", raw)
32
  issue_match = re.search(r"issue\s*[:=]\s*([\w_]+)", raw)
@@ -36,6 +43,22 @@ def parse_issue_key(raw: str) -> Optional[str]:
36
  return None
37
 
38
 
39
  def compute_f1(reported_keys: Set[str], planted_keys: Set[str]) -> dict:
40
  """Compute precision, recall, and F1 score."""
41
  if not reported_keys and not planted_keys:
@@ -58,12 +81,185 @@ def compute_f1(reported_keys: Set[str], planted_keys: Set[str]) -> dict:
58
  return {"precision": precision, "recall": recall, "f1": f1, "tp": tp, "fp": fp, "fn": fn}
59
 
60
 
61
  class DataQAEnvironment(Environment):
62
  """
63
- Data Quality Assurance environment.
64
 
65
- The agent inspects corrupted datasets and reports quality issues.
66
- Reward is F1 score of correctly identified issues vs planted ground truth.
67
  """
68
 
69
  SUPPORTS_CONCURRENT_SESSIONS = True
@@ -103,7 +299,11 @@ class DataQAEnvironment(Environment):
103
  schema_description=self._current_task.schema_description,
104
  validation_rules=self._current_task.validation_rules,
105
  task_description=self._current_task.description,
106
- feedback="Environment reset. Inspect the dataset and report all quality issues.",
 
 
 
 
107
  task_id=task_id,
108
  num_issues_hint=len(self._current_task.planted_issues),
109
  max_steps=self._current_task.max_steps,
@@ -120,15 +320,14 @@ class DataQAEnvironment(Environment):
120
  if not isinstance(action, DataQAAction):
121
  raise ValueError(f"Expected DataQAAction, got {type(action)}")
122
 
123
- # In stateless HTTP mode, each request creates a fresh env instance.
124
- # Auto-reset using the task_id from the action so step() works standalone.
125
  if self._current_task is None:
126
  self.reset(task_id=action.task_id)
127
 
128
  self._state.step_count += 1
129
  self._state.current_step += 1
130
 
131
- # Parse reported issues
132
  reported_keys: Set[str] = set()
133
  parse_errors: list[str] = []
134
  for raw_issue in action.issues:
@@ -136,44 +335,148 @@ class DataQAEnvironment(Environment):
136
  if key:
137
  reported_keys.add(key)
138
  else:
139
- parse_errors.append(f"Could not parse: '{raw_issue}'")
140
 
141
- # Compute score
142
  metrics = compute_f1(reported_keys, self._planted_keys)
143
- score = metrics["f1"]
144
- self._best_score = max(self._best_score, score)
145
  self._state.best_score = self._best_score
146
 
147
- # Check if done
148
  is_done = (
149
- score >= 0.999 # Perfect score
150
  or self._state.current_step >= self._state.max_steps
151
  )
152
 
153
- # Build feedback
154
  feedback_lines = [
155
  f"Step {self._state.current_step}/{self._state.max_steps}",
 
 
156
  f"Issues reported: {len(reported_keys)}",
157
  f"True positives: {metrics['tp']}, False positives: {metrics['fp']}, Missed: {metrics['fn']}",
158
- f"Precision: {metrics['precision']:.3f}, Recall: {metrics['recall']:.3f}, F1: {score:.3f}",
 
159
  ]
160
 
161
  if parse_errors:
162
- feedback_lines.append(f"Parse errors ({len(parse_errors)}): {'; '.join(parse_errors[:3])}")
163
 
164
  if not is_done:
165
- # Give hints about what was missed without revealing exact answers
166
  if metrics["fn"] > 0:
167
  feedback_lines.append(
168
- f"You missed {metrics['fn']} issue(s). Review the dataset carefully."
169
  )
170
  if metrics["fp"] > 0:
171
  feedback_lines.append(
172
- f"{metrics['fp']} of your reported issues were incorrect."
173
  )
174
- feedback_lines.append("You can submit again with an updated list of issues.")
175
  else:
176
- feedback_lines.append(f"Task complete! Final best F1 score: {self._best_score:.3f}")
177
 
178
  return DataQAObservation(
179
  dataset_csv=self._current_task.corrupted_csv,
@@ -186,6 +489,25 @@ class DataQAEnvironment(Environment):
186
  max_steps=self._state.max_steps,
187
  done=is_done,
188
  reward=self._best_score,
189
  )
190
 
191
  @property
 
3
  ------------------
4
  Server-side environment for data quality assurance tasks.
5
 
6
+ Two-phase RL environment:
7
+ Phase 1 (Identify): Agent inspects corrupted datasets and reports quality issues.
8
+ Phase 2 (Fix): Agent proposes corrections for identified issues.
9
+
10
+ Combined reward = 0.6 * identify_score + 0.4 * fix_score
11
+ Both phases scored with difficulty-weighted metrics for rich per-step signal.
12
  """
13
 
14
  from __future__ import annotations
 
22
  from ..models import DataQAAction, DataQAObservation, DataQAState
23
  from .tasks import PlantedIssue, Task, get_task, list_tasks
24
 
25
+ # Reward weights for the two phases
26
+ IDENTIFY_WEIGHT = 0.6
27
+ FIX_WEIGHT = 0.4
28
+
29
 
30
  def parse_issue_key(raw: str) -> Optional[str]:
31
  """
 
34
  Returns normalized key or None if unparseable.
35
  """
36
  raw = raw.strip().lower()
 
37
  row_match = re.search(r"row\s*[:=]\s*(\d+)", raw)
38
  col_match = re.search(r"col\s*[:=]\s*([\w_]+)", raw)
39
  issue_match = re.search(r"issue\s*[:=]\s*([\w_]+)", raw)
 
43
  return None
44
 
45
 
46
+ def parse_fix(raw: str) -> Optional[tuple[int, str, str]]:
47
+ """
48
+ Parse an agent-proposed fix into (row, col, proposed_value).
49
+ Expected format: row:<N>,col:<name>,fix:<value>
50
+ Returns (row, col, value) or None if unparseable.
51
+ """
52
+ raw = raw.strip()
53
+ row_match = re.search(r"row\s*[:=]\s*(\d+)", raw, re.IGNORECASE)
54
+ col_match = re.search(r"col(?:umn)?\s*[:=]\s*([\w_]+)", raw, re.IGNORECASE)
55
+ fix_match = re.search(r"fix\s*[:=]\s*(.+?)$", raw, re.IGNORECASE)
56
+
57
+ if row_match and col_match and fix_match:
58
+ return (int(row_match.group(1)), col_match.group(1).lower(), fix_match.group(1).strip())
59
+ return None
60
+
61
+
62
  def compute_f1(reported_keys: Set[str], planted_keys: Set[str]) -> dict:
63
  """Compute precision, recall, and F1 score."""
64
  if not reported_keys and not planted_keys:
 
81
  return {"precision": precision, "recall": recall, "f1": f1, "tp": tp, "fp": fp, "fn": fn}
82
 
83
 
84
+ def compute_weighted_reward(
85
+ reported_keys: Set[str],
86
+ planted_issues: list,
87
+ ) -> dict:
88
+ """
89
+ Compute difficulty-weighted reward for richer per-step signal.
90
+
91
+ Each planted issue has a difficulty weight (1.0-3.0). Finding harder issues
92
+ earns more reward. False positives incur a penalty scaled by average difficulty.
93
+
94
+ Returns dict with weighted_reward (0.0-1.0), plus per-issue breakdown.
95
+ """
96
+ if not planted_issues and not reported_keys:
97
+ return {"weighted_reward": 1.0, "difficulty_found": 0.0, "difficulty_missed": 0.0}
98
+
99
+ planted_by_key = {issue.to_key(): issue for issue in planted_issues}
100
+ planted_keys = set(planted_by_key.keys())
101
+
102
+ if not reported_keys:
103
+ total_weight = sum(i.difficulty for i in planted_issues)
104
+ return {"weighted_reward": 0.0, "difficulty_found": 0.0, "difficulty_missed": total_weight}
105
+
106
+ if not planted_keys:
107
+ return {"weighted_reward": 0.0, "difficulty_found": 0.0, "difficulty_missed": 0.0}
108
+
109
+ found_keys = reported_keys & planted_keys
110
+ missed_keys = planted_keys - reported_keys
111
+ false_positive_count = len(reported_keys - planted_keys)
112
+
113
+ difficulty_found = sum(planted_by_key[k].difficulty for k in found_keys)
114
+ difficulty_missed = sum(planted_by_key[k].difficulty for k in missed_keys)
115
+ total_weight = sum(i.difficulty for i in planted_issues)
116
+
117
+ weighted_recall = difficulty_found / total_weight if total_weight > 0 else 0.0
118
+
119
+ avg_difficulty = total_weight / len(planted_issues)
120
+ fp_penalty_weight = false_positive_count * avg_difficulty
121
+ weighted_precision = difficulty_found / (difficulty_found + fp_penalty_weight) if (difficulty_found + fp_penalty_weight) > 0 else 0.0
122
+
123
+ if (weighted_precision + weighted_recall) > 0:
124
+ weighted_reward = 2 * weighted_precision * weighted_recall / (weighted_precision + weighted_recall)
125
+ else:
126
+ weighted_reward = 0.0
127
+
128
+ return {
129
+ "weighted_reward": round(weighted_reward, 4),
130
+ "difficulty_found": round(difficulty_found, 2),
131
+ "difficulty_missed": round(difficulty_missed, 2),
132
+ }
133
+
134
+
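The difficulty-weighted reward above is an F1 score in which recall is weighted by issue difficulty and each false positive costs one average-difficulty unit of precision. The sketch below reproduces that arithmetic on a toy input. The `Issue` dataclass and `weighted_f1` are hypothetical stand-ins, and the sketch simplifies the empty-input special cases that the real function handles separately.

```python
from dataclasses import dataclass
from typing import List, Set


@dataclass(frozen=True)
class Issue:
    """Hypothetical stand-in for a planted issue with a difficulty weight."""
    key: str
    difficulty: float


def weighted_f1(reported: Set[str], planted: List[Issue]) -> float:
    """Difficulty-weighted F1; empty-input special cases are simplified here."""
    weights = {issue.key: issue.difficulty for issue in planted}
    planted_keys = set(weights)
    found = sum(weights[k] for k in reported & planted_keys)
    if not planted or found == 0:
        return 0.0
    total = sum(weights.values())
    recall = found / total
    # Each false positive is charged the average difficulty of a planted issue.
    fp_penalty = len(reported - planted_keys) * (total / len(planted))
    precision = found / (found + fp_penalty)
    return 2 * precision * recall / (precision + recall)


planted = [Issue("a", 1.0), Issue("b", 3.0)]
easy_only = weighted_f1({"a"}, planted)       # finds only the easy issue
hard_only = weighted_f1({"b"}, planted)       # finds only the hard issue
hard_plus_fp = weighted_f1({"b", "x"}, planted)  # hard issue plus one false positive
```

In this toy case, finding only the hard issue scores higher than finding only the easy one, and adding a false positive lowers the score through the precision term.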
135
+ def grade_fixes(
136
+ fixes: list[tuple[int, str, str]],
137
+ task: Task,
138
+ ) -> dict:
139
+ """
140
+ Grade proposed fixes against the clean dataset.
141
+
142
+ For each fix (row, col, proposed_value), compare to the original clean value.
143
+ Scoring per fix:
144
+ - Exact match (case-insensitive, whitespace-stripped): 1.0
145
+ - Numeric close match (within 1%): 0.8
146
+ - Correct cell but wrong value: 0.1
146
+ - Targets a non-issue cell: 0.0 (no credit, counted as wrong)
148
+
149
+ Returns dict with fix_score (0.0-1.0), details per fix, and counts.
150
+ """
151
+ if not fixes and not task.planted_issues:
152
+ return {"fix_score": 1.0, "fixes_correct": 0, "fixes_partial": 0,
153
+ "fixes_wrong": 0, "fixes_attempted": 0, "fix_details": []}
154
+
155
+ if not fixes:
156
+ return {"fix_score": 0.0, "fixes_correct": 0, "fixes_partial": 0,
157
+ "fixes_wrong": 0, "fixes_attempted": 0, "fix_details": []}
158
+
159
+ issue_map = task.get_planted_issue_map()
160
+ # Build set of (row, col) that are actual issues
161
+ issue_cells = {(issue.row, issue.col) for issue in task.planted_issues}
162
+
163
+ total_weight = sum(i.difficulty for i in task.planted_issues) if task.planted_issues else 1.0
164
+ earned_weight = 0.0
165
+ fixes_correct = 0
166
+ fixes_partial = 0
167
+ fixes_wrong = 0
168
+ fix_details = []
169
+
170
+ # Track which issues have been fixed (best fix wins)
171
+ fixed_issues: dict[tuple[int, str], float] = {}
172
+
173
+ for row, col, proposed in fixes:
174
+ clean_value = task.get_clean_value(row, col)
175
+ cell_key = (row, col)
176
+
177
+ if cell_key not in issue_cells:
178
+ # Fix targets a non-issue cell — no credit
179
+ fix_details.append({"row": row, "col": col, "score": 0.0, "reason": "not an issue cell"})
180
+ fixes_wrong += 1
181
+ continue
182
+
183
+ if clean_value is None:
184
+ fix_details.append({"row": row, "col": col, "score": 0.0, "reason": "cell not found"})
185
+ fixes_wrong += 1
186
+ continue
187
+
188
+ # Find the planted issue for this cell to get its difficulty weight
189
+ matching_issue = None
190
+ for issue in task.planted_issues:
191
+ if issue.row == row and issue.col == col:
192
+ matching_issue = issue
193
+ break
194
+
195
+ difficulty = matching_issue.difficulty if matching_issue else 1.0
196
+
197
+ # Score the fix
198
+ score = 0.0
199
+ reason = "wrong value"
200
+
201
+ # Exact match (case-insensitive, whitespace-stripped)
202
+ if proposed.strip().lower() == clean_value.strip().lower():
203
+ score = 1.0
204
+ reason = "exact match"
205
+ fixes_correct += 1
206
+ else:
207
+ # Try numeric close match
208
+ try:
209
+ proposed_num = float(proposed.strip())
210
+ clean_num = float(clean_value)
211
+ if proposed_num == clean_num:  # check equality first so "75000.0" vs "75000" earns full credit
212
+ score = 1.0
213
+ reason = "exact numeric match"
214
+ fixes_correct += 1
215
+ elif clean_num != 0 and abs(proposed_num - clean_num) / abs(clean_num) <= 0.01:
216
+ score = 0.8
217
+ reason = "numeric close match"
218
+ fixes_partial += 1
219
+ else:
220
+ score = 0.1
221
+ reason = "correct cell, wrong value"
222
+ fixes_partial += 1
223
+ except (ValueError, ZeroDivisionError):
224
+ # Not numeric — just a wrong value but at least right cell
225
+ score = 0.1
226
+ reason = "correct cell, wrong value"
227
+ fixes_partial += 1
228
+
229
+ # Keep best fix per cell
230
+ if cell_key not in fixed_issues or score > fixed_issues[cell_key]:
231
+ fixed_issues[cell_key] = score
232
+
233
+ fix_details.append({"row": row, "col": col, "score": score, "reason": reason})
234
+
235
+ # Compute fix score: weighted sum of best fix per issue / total weight
236
+ for issue in task.planted_issues:
237
+ cell_key = (issue.row, issue.col)
238
+ if cell_key in fixed_issues:
239
+ earned_weight += issue.difficulty * fixed_issues[cell_key]
240
+
241
+ fix_score = earned_weight / total_weight if total_weight > 0 else 0.0
242
+ fix_score = min(max(fix_score, 0.0), 1.0)
243
+
244
+ return {
245
+ "fix_score": round(fix_score, 4),
246
+ "fixes_correct": fixes_correct,
247
+ "fixes_partial": fixes_partial,
248
+ "fixes_wrong": fixes_wrong,
249
+ "fixes_attempted": len(fixes),
250
+ "fix_details": fix_details,
251
+ }
252
+
253
+
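The per-fix tiers in the `grade_fixes` docstring and the 0.6/0.4 weighting used by the environment can be summarized in a few lines. The following is a sketch under stated assumptions, not the repository's implementation: `score_fix` and `combined_reward` are hypothetical names, and the sketch checks numeric equality before the 1% tolerance so that values such as `75000.0` and `75000` earn full credit.

```python
IDENTIFY_WEIGHT = 0.6  # weights as stated in the environment docstring
FIX_WEIGHT = 0.4


def score_fix(proposed: str, clean_value: str) -> float:
    """Score one proposed value for a cell that is known to be an issue cell."""
    if proposed.strip().lower() == clean_value.strip().lower():
        return 1.0  # exact match, case-insensitive and whitespace-stripped
    try:
        proposed_num = float(proposed.strip())
        clean_num = float(clean_value)
    except ValueError:
        return 0.1  # right cell, non-numeric wrong value
    if proposed_num == clean_num:
        return 1.0  # same number written differently
    if clean_num != 0 and abs(proposed_num - clean_num) / abs(clean_num) <= 0.01:
        return 0.8  # within 1% of the clean value
    return 0.1  # right cell, wrong value


def combined_reward(identify_score: float, fix_score: float, has_fixes: bool) -> float:
    """Blend the two phases; identification alone is used when no fixes are sent."""
    if not has_fixes:
        return identify_score
    return IDENTIFY_WEIGHT * identify_score + FIX_WEIGHT * fix_score
```

One consequence of this blend is that submitting fixes switches the reward from identify-only to the weighted sum, so a poor set of fixes can score lower than submitting none.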
254
  class DataQAEnvironment(Environment):
255
  """
256
+ Data Quality Assurance environment — two-phase identify + fix.
257
 
258
+ Phase 1 (Identify): Agent inspects corrupted datasets and reports quality issues.
259
+ Phase 2 (Fix): Agent proposes corrections for identified issues.
260
+
261
+ Combined reward = 0.6 * identify_score + 0.4 * fix_score
262
+ Both phases use difficulty-weighted scoring for rich per-step reward signals.
263
  """
264
 
265
  SUPPORTS_CONCURRENT_SESSIONS = True
 
299
  schema_description=self._current_task.schema_description,
300
  validation_rules=self._current_task.validation_rules,
301
  task_description=self._current_task.description,
302
+ feedback=(
303
+ "Environment reset. Inspect the dataset and report all quality issues.\n"
304
+ "You can also propose fixes in format: row:<N>,col:<name>,fix:<corrected_value>\n"
305
+ "Combined reward = 0.6 * identify_score + 0.4 * fix_score"
306
+ ),
307
  task_id=task_id,
308
  num_issues_hint=len(self._current_task.planted_issues),
309
  max_steps=self._current_task.max_steps,
 
320
  if not isinstance(action, DataQAAction):
321
  raise ValueError(f"Expected DataQAAction, got {type(action)}")
322
 
323
+ # Auto-reset in stateless HTTP mode
 
324
  if self._current_task is None:
325
  self.reset(task_id=action.task_id)
326
 
327
  self._state.step_count += 1
328
  self._state.current_step += 1
329
 
330
+ # ── Phase 1: Parse and score issue identification ──
331
  reported_keys: Set[str] = set()
332
  parse_errors: list[str] = []
333
  for raw_issue in action.issues:
 
335
  if key:
336
  reported_keys.add(key)
337
  else:
338
+ parse_errors.append(f"Could not parse issue: '{raw_issue}'")
339
 
 
340
  metrics = compute_f1(reported_keys, self._planted_keys)
341
+ identify_f1 = metrics["f1"]
342
+
343
+ weighted = compute_weighted_reward(reported_keys, self._current_task.planted_issues)
344
+ identify_score = weighted["weighted_reward"]
345
+
346
+ # ── Phase 2: Parse and score proposed fixes ──
347
+ parsed_fixes: list[tuple[int, str, str]] = []
348
+ for raw_fix in action.fixes:
349
+ fix = parse_fix(raw_fix)
350
+ if fix:
351
+ parsed_fixes.append(fix)
352
+ else:
353
+ parse_errors.append(f"Could not parse fix: '{raw_fix}'")
354
+
355
+ fix_result = grade_fixes(parsed_fixes, self._current_task)
356
+ fix_score = fix_result["fix_score"]
357
+
358
+ # ── Combined reward ──
359
+ # If no fixes submitted, score is identify-only (no penalty for not fixing)
360
+ if action.fixes:
361
+ combined_reward = IDENTIFY_WEIGHT * identify_score + FIX_WEIGHT * fix_score
362
+ else:
363
+ combined_reward = identify_score # backward compatible
364
+
365
+ self._best_score = max(self._best_score, combined_reward)
366
  self._state.best_score = self._best_score
367
 
368
+ # ── Check if done ──
369
  is_done = (
370
+ identify_f1 >= 0.999 # Perfect identification
371
  or self._state.current_step >= self._state.max_steps
372
  )
373
 
374
+ # ── Build feedback with actionable diagnostics ──
375
+ # Show the agent exactly which reported issues were correct (TP) and which were wrong (FP)
376
+ tp_keys = reported_keys & self._planted_keys
377
+ fp_keys = reported_keys - self._planted_keys
378
+
379
  feedback_lines = [
380
  f"Step {self._state.current_step}/{self._state.max_steps}",
381
+ "",
382
+ "--- Identification ---",
383
  f"Issues reported: {len(reported_keys)}",
384
  f"True positives: {metrics['tp']}, False positives: {metrics['fp']}, Missed: {metrics['fn']}",
385
+ f"Precision: {metrics['precision']:.3f}, Recall: {metrics['recall']:.3f}, F1: {identify_f1:.3f}",
386
+ f"Identify score (weighted): {identify_score:.3f}",
387
  ]
388
 
389
+ # Show which reported issues were correct vs wrong (helps agent self-correct)
390
+ if tp_keys:
391
+ feedback_lines.append(f"Correct issues: {', '.join(sorted(tp_keys))}")
392
+ if fp_keys:
393
+ feedback_lines.append(f"Incorrect issues (false positives): {', '.join(sorted(fp_keys))}")
394
+
395
+ if action.fixes:
396
+ feedback_lines += [
397
+ "",
398
+ "--- Fix Proposals ---",
399
+ f"Fixes attempted: {fix_result['fixes_attempted']}",
400
+ f"Correct: {fix_result['fixes_correct']}, Partial: {fix_result['fixes_partial']}, Wrong: {fix_result['fixes_wrong']}",
401
+ f"Fix score: {fix_score:.3f}",
402
+ ]
403
+ # Show per-fix feedback so agent knows which fixes worked
404
+ for detail in fix_result["fix_details"]:
405
+ status = "correct" if detail["score"] >= 0.99 else ("partial" if detail["score"] > 0 else "wrong")
406
+ feedback_lines.append(
407
+ f" row:{detail['row']},col:{detail['col']} -> {status} ({detail['reason']})"
408
+ )
409
+ feedback_lines.append(
410
+ f"\n--- Combined Reward: {combined_reward:.3f} (identify={identify_score:.3f} x {IDENTIFY_WEIGHT} + fix={fix_score:.3f} x {FIX_WEIGHT}) ---"
411
+ )
412
+ else:
413
+ feedback_lines += [
414
+ "",
415
+ "Tip: Submit fixes with format row:<N>,col:<name>,fix:<value> for bonus reward.",
416
+ ]
417
+
418
  if parse_errors:
419
+ feedback_lines.append(f"\nParse errors ({len(parse_errors)}): {'; '.join(parse_errors[:5])}")
420
 
421
  if not is_done:
 
422
  if metrics["fn"] > 0:
423
  feedback_lines.append(
424
+ f"\nYou missed {metrics['fn']} issue(s). Review the dataset carefully."
425
  )
426
  if metrics["fp"] > 0:
427
  feedback_lines.append(
428
+ f"Remove the {metrics['fp']} false positive(s) listed above and look for real issues."
429
  )
430
+ feedback_lines.append("You can submit again with updated issues and/or fixes.")
431
  else:
432
+ feedback_lines.append(f"\nTask complete! Final best reward: {self._best_score:.3f}")
433
+
434
+ # ── Flag items for human review ──
435
+ # In a production data QA pipeline, these would go to a human reviewer.
436
+ # The grader flags cases where automated scoring has low confidence.
437
+ human_review_flags: list[dict] = []
438
+
439
+ # 1. False positives with a recognized issue type: these could be legitimate issues
440
+ # the task designer didn't plant (agent may be smarter than the grader)
441
+ issue_map = self._current_task.get_planted_issue_map()
442
+ valid_issue_types = {"missing_value", "wrong_type", "duplicate_row", "out_of_range",
443
+ "format_violation", "inconsistent_value", "statistical_outlier",
444
+ "referential_integrity"}
445
+ for fp_key in fp_keys:
446
+ parts = fp_key.split(",")
447
+ itype = parts[2].split(":")[1] if len(parts) >= 3 else ""
448
+ if itype in valid_issue_types:
449
+ human_review_flags.append({
450
+ "item": fp_key,
451
+ "reason": "Agent reported this issue but it's not in ground truth — may be a real issue the grader missed",
452
+ "type": "possible_unplanted_issue",
453
+ })
454
+
455
+ # 2. Partial fix matches — fix was close but not exact, human should verify
456
+ for detail in fix_result["fix_details"]:
457
+ if 0 < detail["score"] < 0.99:
458
+ human_review_flags.append({
459
+ "item": f"row:{detail['row']},col:{detail['col']}",
460
+ "reason": f"Fix scored {detail['score']:.2f} ({detail['reason']}) — human should verify if acceptable",
461
+ "type": "partial_fix",
462
+ })
463
+
464
+ # 3. High-difficulty issues that were missed — flag for training data review
465
+ planted_by_key = {i.to_key(): i for i in self._current_task.planted_issues}
466
+ fn_keys = self._planted_keys - reported_keys
467
+ for fn_key in fn_keys:
468
+ issue = planted_by_key.get(fn_key)
469
+ if issue and issue.difficulty >= 2.5:
470
+ human_review_flags.append({
471
+ "item": fn_key,
472
+ "reason": f"High-difficulty issue (difficulty={issue.difficulty}) missed — {issue.description}",
473
+ "type": "missed_hard_issue",
474
+ })
475
+
476
+ if human_review_flags:
477
+ feedback_lines.append(f"\n--- Flagged for Human Review ({len(human_review_flags)}) ---")
478
+ for flag in human_review_flags:
479
+ feedback_lines.append(f" [{flag['type']}] {flag['item']}: {flag['reason']}")
480
 
481
  return DataQAObservation(
482
  dataset_csv=self._current_task.corrupted_csv,
 
489
  max_steps=self._state.max_steps,
490
  done=is_done,
491
  reward=self._best_score,
492
+ metadata={
493
+ "identify_f1": identify_f1,
494
+ "identify_score": identify_score,
495
+ "fix_score": fix_score,
496
+ "combined_reward": combined_reward,
497
+ "precision": metrics["precision"],
498
+ "recall": metrics["recall"],
499
+ "tp": metrics["tp"],
500
+ "fp": metrics["fp"],
501
+ "fn": metrics["fn"],
502
+ "difficulty_found": weighted["difficulty_found"],
503
+ "difficulty_missed": weighted["difficulty_missed"],
504
+ "fixes_correct": fix_result["fixes_correct"],
505
+ "fixes_partial": fix_result["fixes_partial"],
506
+ "fixes_wrong": fix_result["fixes_wrong"],
507
+ "fixes_attempted": fix_result["fixes_attempted"],
508
+ "fix_details": fix_result["fix_details"],
509
+ "human_review_flags": human_review_flags,
510
+ },
511
  )
512
 
513
  @property
dataqa_env/server/gradio_ui.py ADDED
@@ -0,0 +1,508 @@
1
+ """
2
+ Gradio UI — Agent Trajectory Replay Viewer for DataQA.
3
+
4
+ Designed for judges: the first replay step renders on load, with no clicks needed.
5
+ Task dropdown, step slider, prominent metric cards, color-coded dataset.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import csv
11
+ import io
12
+
13
+ import gradio as gr
14
+
15
+ from .environment import DataQAEnvironment, parse_issue_key
16
+ from .tasks import list_tasks, PlantedIssue
17
+ from ..models import DataQAAction
18
+
19
+
20
+ # ── Pre-built agent trajectories (simulates baseline agent) ──
21
+
22
+ AGENT_TRAJECTORIES = {
23
+ "easy": [
24
+ {
25
+ "issues": [
26
+ "row:4,col:name,issue:missing_value",
27
+ "row:7,col:salary,issue:wrong_type",
28
+ "row:9,col:salary,issue:out_of_range",
29
+ "row:18,col:start_date,issue:out_of_range",
30
+ "row:3,col:email,issue:format_violation", # FP
31
+ ],
32
+ "fixes": [],
33
+ },
34
+ {
35
+ "issues": [
36
+ "row:4,col:name,issue:missing_value",
37
+ "row:7,col:salary,issue:wrong_type",
38
+ "row:9,col:salary,issue:out_of_range",
39
+ "row:21,col:employee_id,issue:duplicate_row",
40
+ "row:15,col:email,issue:inconsistent_value",
41
+ "row:18,col:start_date,issue:out_of_range",
42
+ ],
43
+ "fixes": [
44
+ "row:4,col:name,fix:David Kim",
45
+ "row:7,col:salary,fix:75000",
46
+ "row:9,col:salary,fix:73000",
47
+ "row:15,col:email,fix:oscar.rivera@company.com",
48
+ "row:18,col:start_date,fix:2022-01-19",
49
+ ],
50
+ },
51
+ ],
52
+ "medium": [
53
+ {
54
+ "issues": [
55
+ "row:5,col:total,issue:inconsistent_value",
56
+ "row:10,col:category,issue:format_violation",
57
+ "row:14,col:product_name,issue:missing_value",
58
+ "row:17,col:quantity,issue:out_of_range",
59
+ "row:19,col:order_id,issue:duplicate_row",
60
+ "row:12,col:order_date,issue:format_violation",
61
+ "row:24,col:shipping_country,issue:format_violation",
62
+ ],
63
+ "fixes": [],
64
+ },
65
+ {
66
+ "issues": [
67
+ "row:5,col:total,issue:inconsistent_value",
68
+ "row:10,col:category,issue:format_violation",
69
+ "row:14,col:product_name,issue:missing_value",
70
+ "row:17,col:quantity,issue:out_of_range",
71
+ "row:19,col:order_id,issue:duplicate_row",
72
+ "row:12,col:order_date,issue:format_violation",
73
+ "row:24,col:shipping_country,issue:format_violation",
74
+ "row:29,col:order_date,issue:inconsistent_value",
75
+ ],
76
+ "fixes": [
77
+ "row:5,col:total,fix:42.00",
78
+ "row:10,col:category,fix:Sports",
79
+ "row:12,col:order_date,fix:2024-01-26",
80
+ "row:14,col:product_name,fix:LED Strip Lights",
81
+ "row:24,col:shipping_country,fix:US",
82
+ "row:29,col:order_date,fix:2024-02-12",
83
+ ],
84
+ },
85
+ ],
86
+ "hard": [
87
+ {
88
+ "issues": [
89
+ "row:14,col:training_time_hours,issue:out_of_range",
90
+ "row:13,col:learning_rate,issue:out_of_range",
91
+ "row:15,col:model_name,issue:missing_value",
92
+ "row:9,col:batch_size,issue:format_violation",
93
+ "row:10,col:train_size,issue:inconsistent_value",
94
+ ],
95
+ "fixes": [],
96
+ },
97
+ {
98
+ "issues": [
99
+ "row:14,col:training_time_hours,issue:out_of_range",
100
+ "row:13,col:learning_rate,issue:out_of_range",
101
+ "row:15,col:model_name,issue:missing_value",
102
+ "row:9,col:batch_size,issue:format_violation",
103
+ "row:10,col:train_size,issue:inconsistent_value",
104
+ "row:5,col:val_loss,issue:inconsistent_value",
105
+ "row:7,col:gpu_memory_gb,issue:statistical_outlier",
106
+ "row:11,col:timestamp,issue:inconsistent_value",
107
+ "row:9,col:training_time_hours,issue:statistical_outlier",
108
+ "row:12,col:test_accuracy,issue:statistical_outlier",
109
+ ],
110
+ "fixes": [
111
+ "row:14,col:training_time_hours,fix:72.0",
112
+ "row:13,col:learning_rate,fix:0.00001",
113
+ "row:15,col:model_name,fix:whisper-small",
114
+ "row:9,col:batch_size,fix:256",
115
+ "row:9,col:training_time_hours,fix:36.0",
116
+ ],
117
+ },
118
+ ],
119
+ }
120
+
121
+
122
+ # ── HTML rendering ──
123
+
124
+ def _metric_card(label: str, value: str, color: str = "#333") -> str:
125
+ return (
126
+ f'<div style="text-align:center;padding:12px 16px;background:#f8f9fa;'
127
+ f'border-radius:8px;min-width:100px;">'
128
+ f'<div style="font-size:11px;color:#666;text-transform:uppercase;letter-spacing:1px;">{label}</div>'
129
+ f'<div style="font-size:28px;font-weight:700;color:{color};margin-top:2px;">{value}</div>'
130
+ f'</div>'
131
+ )
132
+
133
+
134
+ def _csv_to_html(
135
+ csv_text: str,
136
+ planted: list[PlantedIssue],
137
+ correct: set[tuple[int, str]],
138
+ fp: set[tuple[int, str]],
139
+ missed: set[tuple[int, str]],
140
+ fixed: dict[tuple[int, str], str],
141
+ fix_values: dict[tuple[int, str], str] | None = None,
142
+ ) -> str:
143
+ """Render CSV as HTML with color-coded cells and inline fix proposals."""
144
+ fix_values = fix_values or {}
145
+ desc_map = {(i.row, i.col): i for i in planted}
146
+ reader = csv.reader(io.StringIO(csv_text.strip()))
147
+ rows = list(reader)
148
+ if not rows:
149
+ return ""
150
+
151
+ header = rows[0]
152
+ header_lower = [h.strip().lower() for h in header]
153
+ data = rows[1:]
154
+
155
+ t = ['<table style="border-collapse:collapse;width:100%;font-size:12px;font-family:\'SF Mono\',monospace;">']
156
+ t.append('<tr>')
157
+ t.append('<th style="border:1px solid #dee2e6;padding:6px 8px;background:#343a40;color:#fff;font-size:11px;">Row</th>')
158
+ for h in header:
159
+ t.append(f'<th style="border:1px solid #dee2e6;padding:6px 8px;background:#343a40;color:#fff;font-size:11px;">{h}</th>')
160
+ t.append('</tr>')
161
+
162
+ for i, row in enumerate(data):
163
+ rn = i + 1
164
+ bg = "#fff" if i % 2 == 0 else "#f8f9fa"
165
+ t.append(f'<tr style="background:{bg};">')
166
+ t.append(f'<td style="border:1px solid #dee2e6;padding:4px 8px;color:#adb5bd;text-align:center;font-size:11px;">{rn}</td>')
167
+ for j, val in enumerate(row):
168
+ col = header_lower[j] if j < len(header_lower) else ""
169
+ ck = (rn, col)
170
+ s = "border:1px solid #dee2e6;padding:4px 8px;"
171
+ tip = ""
172
+ badge = ""
173
+
174
+ issue = desc_map.get(ck)
175
+
176
+ if ck in correct:
177
+ s += "background:#d4edda;"
178
+ tip = f"FOUND: {issue.description}" if issue else ""
179
+ badge = '<span style="font-size:9px;background:#28a745;color:#fff;padding:1px 4px;border-radius:3px;margin-left:4px;">TP</span>'
180
+ elif ck in fp:
181
+ s += "background:#f8d7da;"
182
+ badge = '<span style="font-size:9px;background:#dc3545;color:#fff;padding:1px 4px;border-radius:3px;margin-left:4px;">FP</span>'
183
+ elif ck in missed:
184
+ s += "background:#fff3cd;"
185
+ tip = f"MISSED: {issue.description}" if issue else ""
186
+ badge = '<span style="font-size:9px;background:#856404;color:#fff;padding:1px 4px;border-radius:3px;margin-left:4px;">MISS</span>'
187
+
188
+ fx = fixed.get(ck)
189
+ proposed = fix_values.get(ck)
190
+ if fx == "correct":
191
+ s += "box-shadow:inset 0 0 0 2px #28a745;"
192
+ badge += '<span style="font-size:9px;background:#28a745;color:#fff;padding:1px 4px;border-radius:3px;margin-left:2px;">FIX</span>'
193
+ elif fx == "partial":
194
+ s += "box-shadow:inset 0 0 0 2px #ffc107;"
195
+ badge += '<span style="font-size:9px;background:#ffc107;color:#333;padding:1px 4px;border-radius:3px;margin-left:2px;">~FIX</span>'
196
+
197
+ dv = val if val.strip() else '<em style="color:#dc3545;font-style:italic;">empty</em>'
198
+
199
+ # Show proposed fix value below the corrupted value
200
+ fix_line = ""
201
+ if proposed is not None:
202
+ fix_color = "#28a745" if fx == "correct" else ("#b8860b" if fx == "partial" else "#dc3545")
203
+ fix_line = (
204
+ f'<div style="font-size:10px;color:{fix_color};margin-top:2px;'
205
+ f'border-top:1px dashed {fix_color};padding-top:2px;">'
206
+ f'\u2192 {proposed}</div>'
207
+ )
208
+
209
+ t.append(f'<td style="{s}" title="{tip}">{dv}{badge}{fix_line}</td>')
210
+ t.append('</tr>')
211
+ t.append('</table>')
212
+ return "".join(t)
213
+
214
+
215
+ LEGEND_HTML = (
216
+ '<div style="display:flex;gap:12px;flex-wrap:wrap;margin-top:10px;font-size:11px;">'
217
+ '<span style="background:#d4edda;padding:2px 8px;border-radius:4px;">Found (TP)</span>'
218
+ '<span style="background:#f8d7da;padding:2px 8px;border-radius:4px;">False Positive</span>'
219
+ '<span style="background:#fff3cd;padding:2px 8px;border-radius:4px;">Missed</span>'
220
+ '<span style="box-shadow:inset 0 0 0 2px #28a745;padding:2px 8px;border-radius:4px;">Fix Correct</span>'
221
+ '<span style="box-shadow:inset 0 0 0 2px #ffc107;padding:2px 8px;border-radius:4px;">Fix Partial</span>'
222
+ '</div>'
223
+ )
224
+
225
+
226
+ # ── Core replay logic ──
227
+
228
+ def _replay_task(task_id: str) -> list[dict]:
229
+ """Run the agent trajectory and collect per-step data."""
230
+ env = DataQAEnvironment()
231
+ obs = env.reset(task_id=task_id)
232
+ task = env._current_task
233
+ planted_keys = {i.to_key() for i in task.planted_issues}
234
+ steps_data = []
235
+
236
+ # Step 0: initial state
237
+ steps_data.append({
238
+ "label": "Initial — corrupted dataset",
239
+ "html": _csv_to_html(obs.dataset_csv, task.planted_issues, set(), set(), set(), {}),
240
+ "metrics": {"reward": 0.0, "tp": 0, "fp": 0, "fn": len(task.planted_issues),
241
+ "identify": 0.0, "fix": 0.0, "fixes_correct": 0},
242
+ "feedback": f"Task: {task.name}\nIssues to find: {obs.num_issues_hint}\n\n{task.description}",
243
+ })
244
+
245
+ trajectory = AGENT_TRAJECTORIES.get(task_id, [])
246
+ for i, step_data in enumerate(trajectory):
247
+ action = DataQAAction(
248
+ issues=step_data["issues"],
249
+ fixes=step_data.get("fixes", []),
250
+ task_id=task_id,
251
+ )
252
+ obs = env.step(action)
253
+
254
+ reported_keys = set()
255
+ for iss in step_data["issues"]:
256
+ key = parse_issue_key(iss)
257
+ if key:
258
+ reported_keys.add(key)
259
+
260
+ tp_keys = reported_keys & planted_keys
261
+ fp_keys = reported_keys - planted_keys
262
+ fn_keys = planted_keys - reported_keys
263
+
264
+ correct = {_kc(k) for k in tp_keys}
265
+ fp = {_kc(k) for k in fp_keys}
266
+ missed = {_kc(k) for k in fn_keys} if obs.done else set()
267
+
268
+ fixed: dict[tuple[int, str], str] = {}
269
+ for d in obs.metadata.get("fix_details", []):
270
+ c = (d["row"], d["col"])
271
+ fixed[c] = "correct" if d["score"] >= 0.99 else ("partial" if d["score"] > 0 else "wrong")
272
+
273
+ # Extract proposed fix values from the raw fix strings
274
+ fix_values: dict[tuple[int, str], str] = {}
275
+ from .environment import parse_fix
276
+ for raw_fix in step_data.get("fixes", []):
277
+ parsed = parse_fix(raw_fix)
278
+ if parsed:
279
+ row, col, val = parsed
280
+ fix_values[(row, col)] = val
281
+
282
+ html = _csv_to_html(obs.dataset_csv, task.planted_issues, correct, fp, missed, fixed, fix_values)
283
+
284
+ has_fixes = bool(step_data.get("fixes"))
285
+ if has_fixes:
286
+ label = f"Step {i+1} — identify + fix"
287
+ else:
288
+ label = f"Step {i+1} — identify only"
289
+
290
+ steps_data.append({
291
+ "label": label,
292
+ "html": html,
293
+ "metrics": {
294
+ "reward": obs.reward,
295
+ "tp": obs.metadata["tp"],
296
+ "fp": obs.metadata["fp"],
297
+ "fn": obs.metadata["fn"],
298
+ "identify": obs.metadata["identify_score"],
299
+ "fix": obs.metadata["fix_score"],
300
+ "fixes_correct": obs.metadata["fixes_correct"],
301
+ },
302
+ "feedback": obs.feedback,
303
+ })
304
+
305
+ return steps_data
306
+
307
+
308
+ def _kc(key: str) -> tuple[int, str]:
309
+ parts = key.split(",")
310
+ return (int(parts[0].split(":")[1]), parts[1].split(":")[1])
311
+
312
+
313
+ # ── Gradio app ──
314
+
315
+ def build_gradio_ui():
316
+ # Pre-compute all replays at startup
317
+ all_replays: dict[str, list[dict]] = {}
318
+ for tid in list_tasks():
319
+ all_replays[tid] = _replay_task(tid)
320
+
321
+ def show_step(task_id: str, step_idx: int):
322
+ replay = all_replays.get(task_id, [])
323
+ step_idx = int(step_idx)
324
+ if step_idx >= len(replay):
325
+ step_idx = len(replay) - 1
326
+ sd = replay[step_idx]
327
+ m = sd["metrics"]
328
+
329
+ # Reward color
330
+ r = m["reward"]
331
+ rc = "#28a745" if r >= 0.8 else ("#ffc107" if r >= 0.4 else "#dc3545")
332
+
333
+ cards = (
334
+ '<div style="display:flex;gap:10px;flex-wrap:wrap;margin-bottom:12px;">'
335
+ + _metric_card("Reward", f"{r:.2f}", rc)
336
+ + _metric_card("Found", str(m["tp"]), "#28a745")
337
+ + _metric_card("False Pos", str(m["fp"]), "#dc3545" if m["fp"] > 0 else "#28a745")
338
+ + _metric_card("Missed", str(m["fn"]), "#dc3545" if m["fn"] > 0 else "#28a745")
339
+ + _metric_card("Identify", f"{m['identify']:.2f}", "#333")
340
+ + _metric_card("Fix", f"{m['fix']:.2f}", "#333")
341
+ + '</div>'
342
+ )
343
+
344
+ full_html = (
345
+ f'<div style="font-size:14px;font-weight:600;margin-bottom:8px;color:#495057;">'
346
+ f'{sd["label"]}</div>'
347
+ + cards + sd["html"] + LEGEND_HTML
348
+ )
349
+
350
+ return full_html, sd["feedback"]
351
+
352
+ def on_task_change(task_id):
353
+ replay = all_replays.get(task_id, [])
354
+ max_step = len(replay) - 1
355
+ html, fb = show_step(task_id, 0)
356
+ return (
357
+ gr.update(maximum=max_step, value=0),
358
+ html,
359
+ fb,
360
+ )
361
+
362
+ def on_step_change(task_id, step_idx):
363
+ html, fb = show_step(task_id, step_idx)
364
+ return html, fb
365
+
366
+ # ── Live agent runner (connects to the env server) ──
367
+
368
+ live_env = DataQAEnvironment()
369
+ live_state: dict = {"obs": None, "task_id": "easy", "steps": []}
370
+
371
+ def live_reset(task_id):
372
+ obs = live_env.reset(task_id=task_id)
373
+ task = live_env._current_task
374
+ live_state["obs"] = obs
375
+ live_state["task_id"] = task_id
376
+ live_state["steps"] = []
377
+ html = _csv_to_html(obs.dataset_csv, task.planted_issues, set(), set(), set(), {})
378
+ info = f"**{task.name}** — {obs.num_issues_hint} issues to find, {obs.max_steps} steps max"
379
+ return html, info, "", "0.000"
380
+
381
+ def live_step(issues_text, fixes_text):
382
+ if live_state["obs"] is None:
383
+ return "Reset first.", "", "", ""
384
+ obs = live_state["obs"]
385
+ task = live_env._current_task
386
+ planted_keys = {i.to_key() for i in task.planted_issues}
387
+
388
+ issues = [l.strip() for l in issues_text.strip().split("\n") if l.strip()]
389
+ fixes = [l.strip() for l in fixes_text.strip().split("\n") if l.strip()] if fixes_text.strip() else []
390
+
391
+ action = DataQAAction(issues=issues, fixes=fixes, task_id=live_state["task_id"])
392
+ obs = live_env.step(action)
393
+ live_state["obs"] = obs
394
+
395
+ reported_keys = set()
396
+ for iss in issues:
397
+ key = parse_issue_key(iss)
398
+ if key:
399
+ reported_keys.add(key)
400
+
401
+ tp_keys = reported_keys & planted_keys
402
+ fp_keys = reported_keys - planted_keys
403
+ fn_keys = planted_keys - reported_keys
404
+
405
+ correct = {_kc(k) for k in tp_keys}
406
+ fp_set = {_kc(k) for k in fp_keys}
407
+ missed = {_kc(k) for k in fn_keys} if obs.done else set()
408
+
409
+ fixed: dict[tuple[int, str], str] = {}
410
+ for d in obs.metadata.get("fix_details", []):
411
+ c = (d["row"], d["col"])
412
+ fixed[c] = "correct" if d["score"] >= 0.99 else ("partial" if d["score"] > 0 else "wrong")
413
+
414
+ from .environment import parse_fix
415
+ fix_values: dict[tuple[int, str], str] = {}
416
+ for raw in fixes:
417
+ parsed = parse_fix(raw)
418
+ if parsed:
419
+ fix_values[(parsed[0], parsed[1])] = parsed[2]
420
+
421
+ html = _csv_to_html(obs.dataset_csv, task.planted_issues, correct, fp_set, missed, fixed, fix_values)
422
+
423
+ m = obs.metadata
424
+ r = obs.reward
425
+ rc = "#28a745" if r >= 0.8 else ("#ffc107" if r >= 0.4 else "#dc3545")
426
+ cards = (
427
+ '<div style="display:flex;gap:10px;flex-wrap:wrap;margin-bottom:12px;">'
428
+ + _metric_card("Reward", f"{r:.2f}", rc)
429
+ + _metric_card("Found", str(m["tp"]), "#28a745")
430
+ + _metric_card("False Pos", str(m["fp"]), "#dc3545" if m["fp"] > 0 else "#28a745")
431
+ + _metric_card("Missed", str(m["fn"]), "#dc3545" if m["fn"] > 0 else "#28a745")
432
+ + '</div>'
433
+ )
434
+ full_html = cards + html + LEGEND_HTML
435
+ return full_html, obs.feedback, f"{r:.3f}", ""
436
+
437
+ # ── Build the UI ──
438
+
439
+ with gr.Blocks(title="DataQA Environment") as demo:
440
+ gr.Markdown(
441
+ "# DataQA — Data Quality Assurance Environment\n"
442
+ "Two-phase RL environment: **Identify** data quality issues, then **Fix** them."
443
+ )
444
+
445
+ with gr.Tabs():
446
+ # ── Tab 1: Demo replay ──
447
+ with gr.Tab("Demo (Baseline Agent)"):
448
+ gr.Markdown(
449
+ "*Replay of the baseline Qwen-72B agent. "
450
+ "Use the slider to step through the agent's trajectory.*"
451
+ )
452
+ with gr.Row():
453
+ task_dd = gr.Dropdown(choices=list_tasks(), value="easy", label="Task", scale=1)
454
+ step_slider = gr.Slider(minimum=0, maximum=2, step=1, value=0, label="Step", scale=3)
455
+
456
+ viz_html = gr.HTML()
457
+ feedback_box = gr.Textbox(label="Agent Feedback", lines=10, interactive=False)
458
+
459
+ task_dd.change(on_task_change, inputs=[task_dd], outputs=[step_slider, viz_html, feedback_box])
460
+ step_slider.change(on_step_change, inputs=[task_dd, step_slider], outputs=[viz_html, feedback_box])
461
+ demo.load(on_task_change, inputs=[task_dd], outputs=[step_slider, viz_html, feedback_box])
462
+
463
+ # ── Tab 2: Try your own agent ──
464
+ with gr.Tab("Try Your Own Agent"):
465
+ gr.Markdown(
466
+ "*Submit your own issues and fixes to see how the environment scores them. "
467
+ "This is the same environment the baseline agent talks to.*"
468
+ )
469
+ with gr.Row():
470
+ live_task_dd = gr.Dropdown(choices=list_tasks(), value="easy", label="Task", scale=1)
471
+ live_reset_btn = gr.Button("Reset", variant="primary", scale=1)
472
+
473
+ with gr.Row():
474
+ live_info = gr.Markdown()
475
+ live_reward = gr.Textbox(label="Reward", interactive=False, scale=1)
476
+
477
+ live_viz = gr.HTML()
478
+
479
+ with gr.Row():
480
+ live_issues = gr.Textbox(
481
+ label="Issues (one per line)",
482
+ placeholder="row:4,col:name,issue:missing_value\nrow:7,col:salary,issue:wrong_type",
483
+ lines=5,
484
+ )
485
+ live_fixes = gr.Textbox(
486
+ label="Fixes (one per line, optional)",
487
+ placeholder="row:4,col:name,fix:David Kim\nrow:7,col:salary,fix:75000",
488
+ lines=5,
489
+ )
490
+
491
+ live_step_btn = gr.Button("Submit Step", variant="primary")
492
+ live_feedback = gr.Textbox(label="Feedback", lines=10, interactive=False)
493
+
494
+ live_reset_btn.click(
495
+ live_reset, inputs=[live_task_dd],
496
+ outputs=[live_viz, live_info, live_feedback, live_reward],
497
+ )
498
+ live_step_btn.click(
499
+ live_step, inputs=[live_issues, live_fixes],
500
+ outputs=[live_viz, live_feedback, live_reward, live_issues],
501
+ )
502
+
503
+ return demo
504
+
505
+
506
+ if __name__ == "__main__":
507
+ demo = build_gradio_ui()
508
+ demo.launch()
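The "Try Your Own Agent" tab takes issues and fixes as one entry per line, in the form shown by the textbox placeholders (`row:4,col:name,issue:missing_value` and `row:4,col:name,fix:David Kim`). The parser the environment actually uses is not part of this hunk, so the following is only a minimal sketch of how such lines could be parsed. The names `ParsedLine`, `parse_line`, and `parse_block` are hypothetical and do not come from the repository.

```python
import re
from typing import List, NamedTuple, Optional

# Hypothetical helper: the line format is taken from the textbox placeholders,
# but this regex is an assumption, not the environment's own parser.
_LINE_RE = re.compile(
    r"^\s*row\s*:\s*(?P<row>\d+)\s*,"
    r"\s*col\s*:\s*(?P<col>[^,]+?)\s*,"
    r"\s*(?P<kind>issue|fix)\s*:\s*(?P<value>.*?)\s*$",
    re.IGNORECASE,
)


class ParsedLine(NamedTuple):
    row: int    # 1-indexed data row (first row after the header is 1)
    col: str    # column name, lowercased
    kind: str   # "issue" or "fix"
    value: str  # issue type, or the proposed replacement value


def parse_line(line: str) -> Optional[ParsedLine]:
    """Parse one 'row:N,col:NAME,issue:TYPE' or 'row:N,col:NAME,fix:VALUE' line."""
    m = _LINE_RE.match(line)
    if m is None:
        return None
    return ParsedLine(int(m["row"]), m["col"].lower(), m["kind"].lower(), m["value"])


def parse_block(text: str) -> List[ParsedLine]:
    """Parse a multi-line textbox value, silently skipping malformed lines."""
    parsed = (parse_line(line) for line in text.splitlines())
    return [p for p in parsed if p is not None]
```

Skipping malformed lines rather than raising keeps a single typo from discarding an otherwise valid submission; a stricter variant could instead collect the rejected lines and surface them in the feedback box.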
dataqa_env/server/tasks.py CHANGED
@@ -25,6 +25,7 @@ class PlantedIssue:
25
  col: str
26
  issue_type: str
27
  description: str
 
28
 
29
  def to_key(self) -> str:
30
  return f"row:{self.row},col:{self.col},issue:{self.issue_type}"
@@ -42,6 +43,28 @@ class Task:
42
  corrupted_csv: str = ""
43
  max_steps: int = 3
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  def _csv_to_rows(csv_text: str) -> List[List[str]]:
47
  reader = csv.reader(io.StringIO(csv_text.strip()))
@@ -72,7 +95,17 @@ def create_task_easy(seed: int = 42) -> Task:
72
  107,Grace Lee,grace.lee@company.com,Marketing,75000,2021-12-01
73
  108,Hank Brown,hank.brown@company.com,Sales,65000,2023-04-18
74
  109,Iris Patel,iris.patel@company.com,HR,73000,2020-02-28
75
- 110,Jack Taylor,jack.taylor@company.com,Engineering,97000,2022-09-14"""
 
 
 
 
 
 
 
 
 
 
76
 
77
  schema_desc = """Columns:
78
  - employee_id: integer, unique, range 100-999
@@ -93,29 +126,43 @@ def create_task_easy(seed: int = 42) -> Task:
93
  data = rows[1:]
94
  issues: List[PlantedIssue] = []
95
 
96
- # Issue 1: Missing value - null out a name
97
  r = 3 # row index in data (0-based), displayed as row 4 in CSV
98
  data[r][1] = ""
99
  issues.append(PlantedIssue(row=r + 1, col="name", issue_type="missing_value",
100
- description="Empty name field"))
101
 
102
- # Issue 2: Wrong type - salary as text
103
  r = 6
104
  data[r][4] = "seventy-five thousand"
105
  issues.append(PlantedIssue(row=r + 1, col="salary", issue_type="wrong_type",
106
- description="Salary is text instead of integer"))
107
 
108
- # Issue 3: Duplicate row
109
  dup_source = 1
110
  data.append(list(data[dup_source]))
111
  issues.append(PlantedIssue(row=len(data), col="employee_id", issue_type="duplicate_row",
112
- description=f"Exact duplicate of row {dup_source + 1}"))
113
 
114
- # Issue 4: Out of range salary
115
  r = 8
116
  data[r][4] = "5000"
117
  issues.append(PlantedIssue(row=r + 1, col="salary", issue_type="out_of_range",
118
- description="Salary 5000 is below minimum 50000"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  corrupted = _rows_to_csv([header] + data)
121
 
@@ -163,7 +210,17 @@ ORD-016,CUST-114,Bluetooth Speaker,Electronics,1,49.99,2024-01-30,UK,delivered,4
163
  ORD-017,CUST-115,Jump Rope,Sports,2,8.99,2024-01-31,US,shipped,17.98
164
  ORD-018,CUST-116,Coffee Table Book,Books,1,32.00,2024-02-01,CA,delivered,32.00
165
  ORD-019,CUST-117,Ergonomic Chair,Home,1,450.00,2024-02-02,US,processing,450.00
166
- ORD-020,CUST-118,Fitness Tracker,Electronics,1,79.99,2024-02-03,AU,delivered,79.99"""
 
 
 
 
 
 
 
 
 
 
167
 
168
  schema_desc = """Columns:
169
  - order_id: string, unique, format ORD-NNN
@@ -190,41 +247,55 @@ ORD-020,CUST-118,Fitness Tracker,Electronics,1,79.99,2024-02-03,AU,delivered,79.
190
  data = rows[1:]
191
  issues: List[PlantedIssue] = []
192
 
193
- # Issue 1: total doesn't match quantity * unit_price
194
  r = 4 # ORD-005
195
  data[r][9] = "84.00" # should be 42.00 (qty=1, price=42.00)
196
  issues.append(PlantedIssue(row=r + 1, col="total", issue_type="inconsistent_value",
197
- description="total (84.00) != quantity (1) * unit_price (42.00)"))
198
 
199
- # Issue 2: Invalid category
200
  r = 9 # ORD-010
201
  data[r][3] = "Fitness" # should be Sports
202
  issues.append(PlantedIssue(row=r + 1, col="category", issue_type="format_violation",
203
- description="'Fitness' is not in allowed categories"))
204
 
205
- # Issue 3: Missing value in product_name
206
  r = 13 # ORD-014
207
  data[r][2] = ""
208
  issues.append(PlantedIssue(row=r + 1, col="product_name", issue_type="missing_value",
209
- description="Empty product_name"))
210
 
211
- # Issue 4: Out of range quantity
212
  r = 16 # ORD-017
213
  data[r][4] = "-1"
214
  issues.append(PlantedIssue(row=r + 1, col="quantity", issue_type="out_of_range",
215
- description="Negative quantity"))
216
 
217
- # Issue 5: Duplicate order_id
218
  r = 18 # ORD-019
219
  data[r][0] = "ORD-003"
220
  issues.append(PlantedIssue(row=r + 1, col="order_id", issue_type="duplicate_row",
221
- description="Duplicate order_id ORD-003"))
222
 
223
- # Issue 6: Wrong date format
224
  r = 11 # ORD-012
225
  data[r][6] = "26/01/2024"
226
  issues.append(PlantedIssue(row=r + 1, col="order_date", issue_type="format_violation",
227
- description="Date format DD/MM/YYYY instead of YYYY-MM-DD"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
 
229
  corrupted = _rows_to_csv([header] + data)
230
 
@@ -267,7 +338,22 @@ EXP-011,yolov5-m,coco-2017,118287,5000,40670,0.01,32,300,0.032,0.045,0.0,10.2,24
267
  EXP-012,wav2vec2,librispeech,281241,5567,2620,0.0001,8,20,0.92,1.05,0.0,12.6,15.0,2024-03-13T11:30:00
268
  EXP-013,clip-base,cc3m,2818102,15000,15000,0.00001,256,32,2.15,2.38,0.0,22.4,48.0,2024-03-14T08:00:00
269
  EXP-014,detr,coco-2017,118287,5000,40670,0.0001,4,500,1.85,2.12,0.0,16.0,72.0,2024-03-15T10:00:00
270
- EXP-015,whisper-small,common-voice,520000,16000,16000,0.00005,16,5,0.55,0.68,0.0,7.4,6.5,2024-03-16T14:00:00"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
  schema_desc = """Columns:
273
  - experiment_id: string, unique, format EXP-NNN
@@ -301,53 +387,83 @@ EXP-015,whisper-small,common-voice,520000,16000,16000,0.00005,16,5,0.55,0.68,0.0
301
  data = rows[1:]
302
  issues: List[PlantedIssue] = []
303
 
304
- # Issue 1: Data leakage signal — val_loss much lower than train_loss
305
  r = 4 # EXP-005
306
  data[r][10] = "0.15" # val_loss=0.15 but train_loss=0.28 → suspicious
307
  issues.append(PlantedIssue(row=r + 1, col="val_loss", issue_type="inconsistent_value",
308
- description="val_loss (0.15) significantly less than train_loss (0.28), potential data leakage"))
 
309
 
310
- # Issue 2: Batch size not power of 2
311
  r = 8 # EXP-009
312
  data[r][7] = "250" # not a power of 2
313
  issues.append(PlantedIssue(row=r + 1, col="batch_size", issue_type="format_violation",
314
- description="batch_size 250 is not a power of 2"))
315
 
316
- # Issue 3: GPU memory unreasonable for model
317
  r = 6 # EXP-007 resnet18 on cifar10
318
  data[r][12] = "42.5" # resnet18 shouldn't need 42.5 GB
319
  issues.append(PlantedIssue(row=r + 1, col="gpu_memory_gb", issue_type="statistical_outlier",
320
- description="resnet18 on cifar10 using 42.5 GB GPU memory is unreasonable"))
 
321
 
322
- # Issue 4: Timestamp out of order
323
  r = 10 # EXP-011
324
  data[r][14] = "2024-03-02T09:00:00" # should be after EXP-010's timestamp
325
  issues.append(PlantedIssue(row=r + 1, col="timestamp", issue_type="inconsistent_value",
326
- description="Timestamp 2024-03-02 is before EXP-010's timestamp 2024-03-11"))
 
327
 
328
- # Issue 5: Train size smaller than test size
329
  r = 9 # EXP-010
330
  data[r][3] = "500" # train_size=500 but test_size=1821
331
  issues.append(PlantedIssue(row=r + 1, col="train_size", issue_type="inconsistent_value",
332
- description="train_size (500) is smaller than test_size (1821)"))
 
333
 
334
- # Issue 6: Negative training time
335
  r = 13 # EXP-014
336
  data[r][13] = "-72.0"
337
  issues.append(PlantedIssue(row=r + 1, col="training_time_hours", issue_type="out_of_range",
338
- description="Negative training time"))
339
 
340
- # Issue 7: Learning rate out of range
341
  r = 12 # EXP-013
342
  data[r][6] = "2.5" # way too high
343
  issues.append(PlantedIssue(row=r + 1, col="learning_rate", issue_type="out_of_range",
344
- description="Learning rate 2.5 exceeds maximum of 1.0"))
345
 
346
- # Issue 8: Missing model name (subtle: single space instead of empty)
347
  r = 14 # EXP-015
348
  data[r][1] = " "
349
  issues.append(PlantedIssue(row=r + 1, col="model_name", issue_type="missing_value",
350
- description="model_name is whitespace-only"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
 
352
  corrupted = _rows_to_csv([header] + data)
353
 
@@ -370,6 +486,123 @@ EXP-015,whisper-small,common-voice,520000,16000,16000,0.00005,16,5,0.55,0.68,0.0
370
  )
371
 
372
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
  # ---------------------------------------------------------------------------
374
  # Task registry
375
  # ---------------------------------------------------------------------------
 
25
  col: str
26
  issue_type: str
27
  description: str
28
+ difficulty: float = 1.0 # 1.0=easy, 2.0=medium, 3.0=hard (for weighted reward)
29
 
30
  def to_key(self) -> str:
31
  return f"row:{self.row},col:{self.col},issue:{self.issue_type}"
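The new `difficulty` field is annotated "for weighted reward", but the reward computation itself is not shown in this file. As a rough illustration only, one way such weights could be used is a difficulty-weighted recall, where finding a hard issue counts for more than finding an easy one. The `PlantedIssue` class below is a minimal stand-in that mirrors the fields in the diff, and `weighted_recall` is a hypothetical name, not the environment's scoring function.

```python
from dataclasses import dataclass
from typing import Iterable, List


@dataclass
class PlantedIssue:
    """Minimal stand-in mirroring the fields shown in the diff."""
    row: int
    col: str
    issue_type: str
    description: str = ""
    difficulty: float = 1.0

    def to_key(self) -> str:
        return f"row:{self.row},col:{self.col},issue:{self.issue_type}"


def weighted_recall(planted: List[PlantedIssue], reported_keys: Iterable[str]) -> float:
    """Fraction of total difficulty weight covered by the reported issue keys.

    Assumption: this is one plausible weighting scheme, not the formula
    used by the DataQA environment.
    """
    reported = set(reported_keys)
    total = sum(issue.difficulty for issue in planted)
    if total == 0:
        return 0.0
    found = sum(issue.difficulty for issue in planted if issue.to_key() in reported)
    return found / total


planted = [
    PlantedIssue(row=4, col="name", issue_type="missing_value", difficulty=1.0),
    PlantedIssue(row=5, col="val_loss", issue_type="inconsistent_value", difficulty=3.0),
]
# Reporting only the hard issue covers 3.0 of the 4.0 total weight.
score = weighted_recall(planted, ["row:5,col:val_loss,issue:inconsistent_value"])
```

Under this scheme an agent that finds only the easy issue scores 0.25, while one that finds only the hard issue scores 0.75, which is the kind of asymmetry the difficulty values appear intended to produce.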
 
43
  corrupted_csv: str = ""
44
  max_steps: int = 3
45
 
46
+ def get_clean_value(self, row: int, col: str) -> str | None:
47
+ """
48
+ Look up the original clean value for a given (row, col).
49
+ Row is 1-indexed (data row after header).
50
+ Returns None if row/col is out of bounds or column not found.
51
+ """
52
+ rows = _csv_to_rows(self.clean_csv)
53
+ if len(rows) < 2:
54
+ return None
55
+ header = [h.strip().lower() for h in rows[0]]
56
+ if col.lower() not in header:
57
+ return None
58
+ col_idx = header.index(col.lower())
59
+ data_row_idx = row # row is 1-indexed, rows[0] is header, so rows[row] is the data row
60
+ if data_row_idx < 1 or data_row_idx >= len(rows):
61
+ return None
62
+ return rows[data_row_idx][col_idx].strip()
63
+
64
+ def get_planted_issue_map(self) -> dict:
65
+ """Return dict mapping issue key -> PlantedIssue for quick lookups."""
66
+ return {issue.to_key(): issue for issue in self.planted_issues}
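The row convention in `get_clean_value` is easy to get wrong: `row` is 1-indexed over data rows, so it happens to equal the index into the parsed list that still includes the header. The standalone sketch below restates that lookup as a free function (`clean_value`, a hypothetical name) so the boundary cases can be checked in isolation. The two-row CSV is made-up illustration data, not taken from any task.

```python
import csv
import io
from typing import Optional


def clean_value(clean_csv: str, row: int, col: str) -> Optional[str]:
    """Standalone restatement of the lookup logic shown in the diff.

    `row` is 1-indexed over data rows; because rows[0] is the header,
    rows[row] is already the right data row with no offset needed.
    """
    rows = list(csv.reader(io.StringIO(clean_csv.strip())))
    if len(rows) < 2:
        return None
    header = [h.strip().lower() for h in rows[0]]
    if col.lower() not in header:
        return None
    if row < 1 or row >= len(rows):
        return None
    return rows[row][header.index(col.lower())].strip()


# Made-up illustration data (not from any DataQA task).
CLEAN = "id,name\n1,Ann\n2,Ben"

first = clean_value(CLEAN, 1, "name")      # first data row
header_row = clean_value(CLEAN, 0, "name") # row 0 would be the header: rejected
past_end = clean_value(CLEAN, 3, "name")   # only two data rows exist
```

Rejecting `row=0` matters here: without the lower bound, a caller passing a 0-based index would silently get the header text back as a "clean value".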
67
+
68
 
69
  def _csv_to_rows(csv_text: str) -> List[List[str]]:
70
  reader = csv.reader(io.StringIO(csv_text.strip()))
 
95
  107,Grace Lee,grace.lee@company.com,Marketing,75000,2021-12-01
96
  108,Hank Brown,hank.brown@company.com,Sales,65000,2023-04-18
97
  109,Iris Patel,iris.patel@company.com,HR,73000,2020-02-28
98
+ 110,Jack Taylor,jack.taylor@company.com,Engineering,97000,2022-09-14
99
+ 111,Kevin Zhang,kevin.zhang@company.com,Engineering,91000,2021-05-22
100
+ 112,Laura Adams,laura.adams@company.com,Sales,69000,2022-11-03
101
+ 113,Mike Torres,mike.torres@company.com,Marketing,74000,2020-08-17
102
+ 114,Nina Sharma,nina.sharma@company.com,HR,76000,2019-04-30
103
+ 115,Oscar Rivera,oscar.rivera@company.com,Engineering,105000,2018-12-10
104
+ 116,Paula Green,paula.green@company.com,Sales,67000,2023-06-25
105
+ 117,Quinn Murphy,quinn.murphy@company.com,Marketing,78000,2021-03-08
106
+ 118,Rosa Diaz,rosa.diaz@company.com,Engineering,99000,2022-01-19
107
+ 119,Sam Cooper,sam.cooper@company.com,HR,70000,2020-10-05
108
+ 120,Tara Singh,tara.singh@company.com,Sales,66000,2023-02-14"""
109
 
110
  schema_desc = """Columns:
111
  - employee_id: integer, unique, range 100-999
 
126
  data = rows[1:]
127
  issues: List[PlantedIssue] = []
128
 
129
+ # Issue 1: Missing value - null out a name (easy to spot)
130
  r = 3 # row index in data (0-based), displayed as row 4 in CSV
131
  data[r][1] = ""
132
  issues.append(PlantedIssue(row=r + 1, col="name", issue_type="missing_value",
133
+ description="Empty name field", difficulty=1.0))
134
 
135
+ # Issue 2: Wrong type - salary as text (easy to spot)
136
  r = 6
137
  data[r][4] = "seventy-five thousand"
138
  issues.append(PlantedIssue(row=r + 1, col="salary", issue_type="wrong_type",
139
+ description="Salary is text instead of integer", difficulty=1.0))
140
 
141
+ # Issue 3: Duplicate row (moderate — requires cross-row comparison)
142
  dup_source = 1
143
  data.append(list(data[dup_source]))
144
  issues.append(PlantedIssue(row=len(data), col="employee_id", issue_type="duplicate_row",
145
+ description=f"Exact duplicate of row {dup_source + 1}", difficulty=1.5))
146
 
147
+ # Issue 4: Out of range salary (easy to spot)
148
  r = 8
149
  data[r][4] = "5000"
150
  issues.append(PlantedIssue(row=r + 1, col="salary", issue_type="out_of_range",
151
+ description="Salary 5000 is below minimum 50000", difficulty=1.0))
152
+
153
+ # Issue 5: Email doesn't match name pattern (moderate — cross-column check)
154
+ r = 14 # Oscar Rivera -> email should be oscar.rivera@company.com
155
+ data[r][2] = "john.doe@company.com"
156
+ issues.append(PlantedIssue(row=r + 1, col="email", issue_type="inconsistent_value",
157
+ description="Email john.doe@company.com doesn't match name Oscar Rivera",
158
+ difficulty=1.5))
159
+
160
+ # Issue 6: Future start date (requires knowing current date context)
161
+ r = 17 # Rosa Diaz
162
+ data[r][5] = "2027-06-15"
163
+ issues.append(PlantedIssue(row=r + 1, col="start_date", issue_type="out_of_range",
164
+ description="Start date 2027-06-15 is in the future (beyond 2025-12-31)",
165
+ difficulty=1.5))
166
 
167
  corrupted = _rows_to_csv([header] + data)
168
 
 
210
  ORD-017,CUST-115,Jump Rope,Sports,2,8.99,2024-01-31,US,shipped,17.98
211
  ORD-018,CUST-116,Coffee Table Book,Books,1,32.00,2024-02-01,CA,delivered,32.00
212
  ORD-019,CUST-117,Ergonomic Chair,Home,1,450.00,2024-02-02,US,processing,450.00
213
+ ORD-020,CUST-118,Fitness Tracker,Electronics,1,79.99,2024-02-03,AU,delivered,79.99
214
+ ORD-021,CUST-119,Laptop Sleeve,Electronics,1,24.99,2024-02-04,US,delivered,24.99
215
+ ORD-022,CUST-120,Hiking Backpack,Sports,1,65.00,2024-02-05,CA,shipped,65.00
216
+ ORD-023,CUST-121,Machine Learning Book,Books,1,54.99,2024-02-06,UK,delivered,54.99
217
+ ORD-024,CUST-122,Plant Pot Set,Home,3,15.00,2024-02-07,US,delivered,45.00
218
+ ORD-025,CUST-123,Noise Cancelling Headphones,Electronics,1,199.99,2024-02-08,DE,shipped,199.99
219
+ ORD-026,CUST-124,Basketball,Sports,1,29.99,2024-02-09,US,delivered,29.99
220
+ ORD-027,CUST-125,Cookbook Collection,Books,2,22.50,2024-02-10,AU,delivered,45.00
221
+ ORD-028,CUST-126,Smart Plug,Home,4,12.99,2024-02-11,US,processing,51.96
222
+ ORD-029,CUST-127,Wireless Charger,Electronics,1,34.99,2024-02-12,UK,delivered,34.99
223
+ ORD-030,CUST-128,Dumbbells Set,Sports,1,89.00,2024-02-13,US,shipped,89.00"""
224
 
225
  schema_desc = """Columns:
226
  - order_id: string, unique, format ORD-NNN
 
247
  data = rows[1:]
248
  issues: List[PlantedIssue] = []
249
 
250
+ # Issue 1: total doesn't match quantity * unit_price (requires cross-column check)
251
  r = 4 # ORD-005
252
  data[r][9] = "84.00" # should be 42.00 (qty=1, price=42.00)
253
  issues.append(PlantedIssue(row=r + 1, col="total", issue_type="inconsistent_value",
254
+ description="total (84.00) != quantity (1) * unit_price (42.00)", difficulty=2.0))
255
 
256
+ # Issue 2: Invalid category (requires knowing the allowed set)
257
  r = 9 # ORD-010
258
  data[r][3] = "Fitness" # should be Sports
259
  issues.append(PlantedIssue(row=r + 1, col="category", issue_type="format_violation",
260
+ description="'Fitness' is not in allowed categories", difficulty=1.5))
261
 
262
+ # Issue 3: Missing value in product_name (easy to spot)
263
  r = 13 # ORD-014
264
  data[r][2] = ""
265
  issues.append(PlantedIssue(row=r + 1, col="product_name", issue_type="missing_value",
266
+ description="Empty product_name", difficulty=1.0))
267
 
268
+ # Issue 4: Out of range quantity (easy to spot)
269
  r = 16 # ORD-017
270
  data[r][4] = "-1"
271
  issues.append(PlantedIssue(row=r + 1, col="quantity", issue_type="out_of_range",
272
+ description="Negative quantity", difficulty=1.0))
273
 
274
+ # Issue 5: Duplicate order_id (requires cross-row comparison)
275
  r = 18 # ORD-019
276
  data[r][0] = "ORD-003"
277
  issues.append(PlantedIssue(row=r + 1, col="order_id", issue_type="duplicate_row",
278
+ description="Duplicate order_id ORD-003", difficulty=1.5))
279
 
280
+ # Issue 6: Wrong date format (moderate — format mismatch)
281
  r = 11 # ORD-012
282
  data[r][6] = "26/01/2024"
283
  issues.append(PlantedIssue(row=r + 1, col="order_date", issue_type="format_violation",
284
+ description="Date format DD/MM/YYYY instead of YYYY-MM-DD", difficulty=1.5))
285
+
286
+ # Issue 7: Invalid country code (requires ISO knowledge)
287
+ r = 23 # ORD-024
288
+ data[r][7] = "XX" # not a valid ISO country code
289
+ issues.append(PlantedIssue(row=r + 1, col="shipping_country", issue_type="format_violation",
290
+ description="'XX' is not a valid ISO 2-letter country code", difficulty=1.5))
291
+
292
+ # Issue 8: Status-date inconsistency (requires cross-column check):
293
+ # a "delivered" order whose order_date lies in the future
294
+ r = 28 # ORD-029
295
+ data[r][6] = "2025-12-25" # future date but status is "delivered"
296
+ issues.append(PlantedIssue(row=r + 1, col="order_date", issue_type="inconsistent_value",
297
+ description="Order date 2025-12-25 is in the future but status is 'delivered'",
298
+ difficulty=2.0))
299
 
300
  corrupted = _rows_to_csv([header] + data)
301
 
 
338
  EXP-012,wav2vec2,librispeech,281241,5567,2620,0.0001,8,20,0.92,1.05,0.0,12.6,15.0,2024-03-13T11:30:00
339
  EXP-013,clip-base,cc3m,2818102,15000,15000,0.00001,256,32,2.15,2.38,0.0,22.4,48.0,2024-03-14T08:00:00
340
  EXP-014,detr,coco-2017,118287,5000,40670,0.0001,4,500,1.85,2.12,0.0,16.0,72.0,2024-03-15T10:00:00
341
+ EXP-015,whisper-small,common-voice,520000,16000,16000,0.00005,16,5,0.55,0.68,0.0,7.4,6.5,2024-03-16T14:00:00
342
+ EXP-016,mobilenet-v3,imagenet-1k,1281167,50000,100000,0.004,128,150,0.92,1.05,72.8,4.1,18.0,2024-03-17T08:30:00
343
+ EXP-017,albert-base,mnli,392702,9815,9796,0.00002,32,5,0.32,0.41,83.1,6.2,1.8,2024-03-18T11:00:00
344
+ EXP-018,gpt-neo-1.3b,pile-subset,1500000,50000,50000,0.0002,8,2,2.85,2.98,0.0,18.5,36.0,2024-03-19T14:00:00
345
+ EXP-019,swin-tiny,imagenet-1k,1281167,50000,100000,0.001,256,300,0.78,0.95,78.2,8.6,42.0,2024-03-20T09:00:00
346
+ EXP-020,deberta-large,squad-v2,130319,11873,8862,0.00001,16,5,0.35,0.42,85.7,15.2,4.5,2024-03-21T10:30:00
347
+ EXP-021,yolov8-s,coco-2017,118287,5000,40670,0.01,64,200,0.028,0.038,0.0,6.8,16.0,2024-03-22T13:00:00
348
+ EXP-022,bart-base,xsum,204045,11332,11334,0.0001,32,10,1.22,1.38,0.0,8.4,6.2,2024-03-23T15:30:00
349
+ EXP-023,convnext-tiny,imagenet-1k,1281167,50000,100000,0.002,256,300,0.74,0.92,79.5,7.2,38.0,2024-03-24T08:00:00
350
+ EXP-024,xlm-roberta,xnli,392702,2490,5010,0.00002,16,10,0.41,0.48,82.3,12.4,5.8,2024-03-25T11:00:00
351
+ EXP-025,stable-diffusion,laion-400m,400000000,10000,10000,0.0001,4,1,0.45,0.52,0.0,24.0,168.0,2024-03-26T09:00:00
352
+ EXP-026,phi-2,dolly-15k,15011,500,500,0.00005,8,3,0.82,0.95,0.0,10.2,2.5,2024-03-27T14:00:00
353
+ EXP-027,dino-v2,imagenet-1k,1281167,50000,100000,0.0005,64,100,0.42,0.58,0.0,11.8,28.0,2024-03-28T10:00:00
354
+ EXP-028,electra-small,glue-mrpc,3668,408,1725,0.0001,32,10,0.38,0.44,87.2,3.8,0.8,2024-03-29T16:00:00
355
+ EXP-029,sam-base,sa-1b,11000000,50000,50000,0.0001,4,1,0.95,1.08,0.0,16.4,96.0,2024-03-30T08:00:00
356
+ EXP-030,llama2-13b,oasst1,84437,4401,4401,0.00001,2,3,0.78,0.88,0.0,52.0,12.0,2024-03-31T12:00:00"""
357
 
358
  schema_desc = """Columns:
359
  - experiment_id: string, unique, format EXP-NNN
 
387
  data = rows[1:]
388
  issues: List[PlantedIssue] = []
389
 
390
+ # Issue 1: Data leakage signal — val_loss much lower than train_loss (hard — requires ML knowledge)
391
  r = 4 # EXP-005
392
  data[r][10] = "0.15" # val_loss=0.15 but train_loss=0.28 → suspicious
393
  issues.append(PlantedIssue(row=r + 1, col="val_loss", issue_type="inconsistent_value",
394
+ description="val_loss (0.15) significantly less than train_loss (0.28), potential data leakage",
395
+ difficulty=3.0))
396
 
397
+ # Issue 2: Batch size not power of 2 (moderate — domain convention)
398
  r = 8 # EXP-009
399
  data[r][7] = "250" # not a power of 2
400
  issues.append(PlantedIssue(row=r + 1, col="batch_size", issue_type="format_violation",
401
+ description="batch_size 250 is not a power of 2", difficulty=2.0))
402
 
403
+ # Issue 3: GPU memory unreasonable for model (hard — requires model size reasoning)
404
  r = 6 # EXP-007 resnet18 on cifar10
405
  data[r][12] = "42.5" # resnet18 shouldn't need 42.5 GB
406
  issues.append(PlantedIssue(row=r + 1, col="gpu_memory_gb", issue_type="statistical_outlier",
407
+ description="resnet18 on cifar10 using 42.5 GB GPU memory is unreasonable",
408
+ difficulty=3.0))
409
 
410
+ # Issue 4: Timestamp out of order (moderate — requires sequential comparison)
411
  r = 10 # EXP-011
412
  data[r][14] = "2024-03-02T09:00:00" # should be after EXP-010's timestamp
413
  issues.append(PlantedIssue(row=r + 1, col="timestamp", issue_type="inconsistent_value",
414
+ description="Timestamp 2024-03-02 is before EXP-010's timestamp 2024-03-11",
415
+ difficulty=2.0))
416
 
417
+ # Issue 5: Train size smaller than test size (moderate — cross-column logic)
418
  r = 9 # EXP-010
419
  data[r][3] = "500" # train_size=500 but test_size=1821
420
  issues.append(PlantedIssue(row=r + 1, col="train_size", issue_type="inconsistent_value",
421
+ description="train_size (500) is smaller than test_size (1821)",
422
+ difficulty=2.0))
423
 
424
+ # Issue 6: Negative training time (easy to spot)
425
  r = 13 # EXP-014
426
  data[r][13] = "-72.0"
427
  issues.append(PlantedIssue(row=r + 1, col="training_time_hours", issue_type="out_of_range",
428
+ description="Negative training time", difficulty=1.0))
429
 
430
+ # Issue 7: Learning rate out of range (easy to spot)
431
  r = 12 # EXP-013
432
  data[r][6] = "2.5" # way too high
433
  issues.append(PlantedIssue(row=r + 1, col="learning_rate", issue_type="out_of_range",
434
+ description="Learning rate 2.5 exceeds maximum of 1.0", difficulty=1.5))
435
 
436
+ # Issue 8: Missing model name (hard: whitespace-only is subtle)
437
  r = 14 # EXP-015
438
  data[r][1] = " "
439
  issues.append(PlantedIssue(row=r + 1, col="model_name", issue_type="missing_value",
440
+ description="model_name is whitespace-only", difficulty=2.5))
441
+
442
+ # Issue 9: Training time impossibly fast for dataset size and epochs (hard)
443
+ # EXP-009: efficientnet-b0 on imagenet-1k for 350 epochs should take
444
+ # roughly 40+ hours; set it to 0.5 hours, which is implausible for
445
+ # 1.2M images * 350 epochs
446
+ r = 8 # EXP-009 (same row as batch_size issue, different column)
447
+ data[r][13] = "0.5" # 30 minutes for 350 epochs on imagenet? impossible
448
+ issues.append(PlantedIssue(row=r + 1, col="training_time_hours", issue_type="statistical_outlier",
449
+ description="0.5 hours for 350 epochs on imagenet-1k (1.2M images) is impossibly fast",
450
+ difficulty=3.0))
451
+
452
+ # Issue 10: test_accuracy implausibly high (hard, requires ML knowledge)
453
+ # EXP-012: wav2vec2 on librispeech with only 20 epochs; set test_accuracy to 98.5,
454
+ # which exceeds the ~96% SOTA reached with much more training
462
+ r = 11 # EXP-012
463
+ data[r][11] = "98.5" # wav2vec2 with 20 epochs shouldn't hit 98.5% — SOTA is ~96%
464
+ issues.append(PlantedIssue(row=r + 1, col="test_accuracy", issue_type="statistical_outlier",
465
+ description="test_accuracy 98.5% for wav2vec2 with only 20 epochs exceeds known SOTA (~96%), likely evaluation error",
466
+ difficulty=3.0))
467
 
468
  corrupted = _rows_to_csv([header] + data)
469
 
 
486
  )
487
 
488
 
489
+ # ---------------------------------------------------------------------------
490
+ # Contamination rules for extensible task creation
491
+ # ---------------------------------------------------------------------------
492
+
493
+ # Each contamination rule is a callable: (rows, header, col_idx, row_idx, rng) -> (new_value, PlantedIssue)
494
+ # Users can define their own and register them.
495
+
496
+ CONTAMINATION_RULES = {
497
+ "missing_value": lambda rows, header, col_idx, row_idx, rng: (
498
+ "",
499
+ PlantedIssue(
500
+ row=row_idx + 1, col=header[col_idx], issue_type="missing_value",
501
+ description=f"Empty {header[col_idx]} field", difficulty=1.0,
502
+ ),
503
+ ),
504
+ "whitespace_value": lambda rows, header, col_idx, row_idx, rng: (
505
+ " ",
506
+ PlantedIssue(
507
+ row=row_idx + 1, col=header[col_idx], issue_type="missing_value",
508
+ description=f"Whitespace-only {header[col_idx]} field", difficulty=2.5,
509
+ ),
510
+ ),
511
+ "wrong_type_text": lambda rows, header, col_idx, row_idx, rng: (
512
+ rng.choice(["not-a-number", "N/A", "null", "undefined"]),
513
+ PlantedIssue(
514
+ row=row_idx + 1, col=header[col_idx], issue_type="wrong_type",
515
+ description=f"{header[col_idx]} is text instead of expected type", difficulty=1.0,
516
+ ),
517
+ ),
518
+ "negative_value": lambda rows, header, col_idx, row_idx, rng: (
519
+ str(-abs(float(rows[row_idx][col_idx]) if rows[row_idx][col_idx] else 1)),
520
+ PlantedIssue(
521
+ row=row_idx + 1, col=header[col_idx], issue_type="out_of_range",
522
+ description=f"Negative {header[col_idx]}", difficulty=1.0,
523
+ ),
524
+ ),
525
+ }
526
+
527
+
528
+ def create_task_from_config(
529
+ task_id: str,
530
+ name: str,
531
+ description: str,
532
+ schema_description: str,
533
+ validation_rules: str,
534
+ clean_csv: str,
535
+ contaminations: List[dict],
536
+ max_steps: int = 3,
537
+ seed: int = 42,
538
+ ) -> Task:
539
+ """
540
+ Create a custom task from explicit configuration arguments.
541
+
542
+ Each contamination entry should have:
543
+ - rule: str (key in CONTAMINATION_RULES) or callable
544
+ - row: int (0-based row index in data)
545
+ - col: int (column index in header)
546
+ - difficulty: float (optional, overrides rule default)
547
+
548
+ Example:
549
+ contaminations = [
550
+ {"rule": "missing_value", "row": 2, "col": 1, "difficulty": 1.5},
551
+ {"rule": "negative_value", "row": 5, "col": 4},
552
+ ]
553
+ """
554
+ rng = random.Random(seed)
555
+ rows = _csv_to_rows(clean_csv)
556
+ header = rows[0]
557
+ data = rows[1:]
558
+ issues: List[PlantedIssue] = []
559
+
560
+ for spec in contaminations:
561
+ rule = spec["rule"]
562
+ row_idx = spec["row"]
563
+ col_idx = spec["col"]
564
+
565
+ if callable(rule):
566
+ new_val, issue = rule(data, header, col_idx, row_idx, rng)
567
+ elif rule in CONTAMINATION_RULES:
568
+ new_val, issue = CONTAMINATION_RULES[rule](data, header, col_idx, row_idx, rng)
569
+ else:
570
+ raise ValueError(f"Unknown contamination rule: {rule}. Available: {list(CONTAMINATION_RULES.keys())}")
571
+
572
+ data[row_idx][col_idx] = new_val
573
+ if "difficulty" in spec:
574
+ issue.difficulty = spec["difficulty"]
575
+ issues.append(issue)
576
+
577
+ corrupted = _rows_to_csv([header] + data)
578
+
579
+ return Task(
580
+ task_id=task_id,
581
+ name=name,
582
+ description=description,
583
+ schema_description=schema_description,
584
+ validation_rules=validation_rules,
585
+ clean_csv=clean_csv,
586
+ planted_issues=issues,
587
+ corrupted_csv=corrupted,
588
+ max_steps=max_steps,
589
+ )
590
+
591
+
592
+ def register_task(task_id: str, factory_fn):
593
+ """Register a custom task factory. Factory must accept (seed: int) -> Task."""
594
+ TASK_REGISTRY[task_id] = factory_fn
595
+
596
+
597
+ def register_contamination_rule(name: str, rule_fn):
598
+ """
599
+ Register a custom contamination rule.
600
+
601
+ rule_fn signature: (rows, header, col_idx, row_idx, rng) -> (new_value, PlantedIssue)
602
+ """
603
+ CONTAMINATION_RULES[name] = rule_fn
604
+
605
+
606
  # ---------------------------------------------------------------------------
607
  # Task registry
608
  # ---------------------------------------------------------------------------
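The contamination-rule contract introduced in `tasks.py` above is `(rows, header, col_idx, row_idx, rng) -> (new_value, PlantedIssue)`, where `rows` holds data rows only and the returned issue uses a 1-indexed row. Below is a minimal sketch of a custom rule written against that contract. It runs standalone, so `PlantedIssue` is a stand-in dataclass mirroring the diff, and `future_date_rule` plus the sample rows are hypothetical illustrations.

```python
import random
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class PlantedIssue:
    """Minimal stand-in mirroring the fields shown in the diff."""
    row: int
    col: str
    issue_type: str
    description: str
    difficulty: float = 1.0


def future_date_rule(
    rows: List[List[str]],
    header: List[str],
    col_idx: int,
    row_idx: int,
    rng: random.Random,
) -> Tuple[str, PlantedIssue]:
    """Hypothetical rule: replace a date cell with a date far in the future."""
    year = rng.randint(2030, 2039)  # seeded rng keeps the task reproducible
    new_value = f"{year}-01-01"
    issue = PlantedIssue(
        row=row_idx + 1,  # rules receive a 0-based index, issues are 1-indexed
        col=header[col_idx],
        issue_type="out_of_range",
        description=f"{header[col_idx]} {new_value} is in the future",
        difficulty=1.5,
    )
    return new_value, issue


# Made-up illustration data; the caller, not the rule, writes the new value back.
header = ["id", "start_date"]
data = [["1", "2021-05-22"], ["2", "2022-11-03"]]
new_value, issue = future_date_rule(data, header, 1, 0, random.Random(42))
data[0][1] = new_value
```

With the real module, the same function could presumably be registered via `register_contamination_rule("future_date", future_date_rule)` and then referenced as `{"rule": "future_date", "row": 0, "col": 1}` in the `contaminations` list passed to `create_task_from_config`. Note that the rule only returns the new value; `create_task_from_config` performs the assignment into `data`.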
inference.py CHANGED
@@ -1,26 +1,31 @@
1
  #!/usr/bin/env python3
2
  """
3
- DataQA Inference Script
4
- -----------------------
5
- LLM agent that plays the DataQA environment.
 
 
 
6
  Uses the OpenAI client to interact with any OpenAI-compatible LLM API.
7
 
8
  Required environment variables:
9
- API_BASE_URL - LLM API endpoint (e.g., https://api.groq.com/openai/v1)
10
- MODEL_NAME - Model identifier (e.g., llama-3.3-70b-versatile)
11
- HF_TOKEN - HuggingFace token (for HF Spaces access)
12
-
13
- Structured logging format: [START], [STEP], [END] tags for evaluation.
 
 
 
14
  """
15
 
16
  from __future__ import annotations
17
 
18
- import json
19
  import os
20
  import re
21
  import sys
22
  import time
23
- from typing import Optional
24
 
25
  import requests
26
  from openai import OpenAI
@@ -28,52 +33,43 @@ from openai import OpenAI
28
  # ---------------------------------------------------------------------------
29
  # Configuration
30
  # ---------------------------------------------------------------------------
31
- API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.groq.com/openai/v1")
32
- MODEL_NAME = os.environ.get("MODEL_NAME", "llama-3.3-70b-versatile")
33
- HF_TOKEN = os.environ.get("HF_TOKEN", "")
34
- ENV_URL = os.environ.get("ENV_URL", "http://localhost:8000")
35
 
 
36
  TASKS = ["easy", "medium", "hard"]
37
  MAX_STEPS_PER_TASK = 3
38
 
 
39
  # ---------------------------------------------------------------------------
40
- # Logging helpers (structured stdout for evaluation)
41
  # ---------------------------------------------------------------------------
42
 
43
- def log_start(task_id: str, metadata: Optional[dict] = None):
44
- entry = {"event": "START", "task_id": task_id, "timestamp": time.time()}
45
- if metadata:
46
- entry["metadata"] = metadata
47
- print(f"[START] {json.dumps(entry)}", flush=True)
48
-
49
-
50
- def log_step(task_id: str, step: int, reward: float, details: Optional[dict] = None):
51
- entry = {
52
- "event": "STEP",
53
- "task_id": task_id,
54
- "step": step,
55
- "reward": reward,
56
- "timestamp": time.time(),
57
- }
58
- if details:
59
- entry["details"] = details
60
- print(f"[STEP] {json.dumps(entry)}", flush=True)
61
-
62
-
63
- def log_end(task_id: str, final_score: float, metadata: Optional[dict] = None):
64
- entry = {
65
- "event": "END",
66
- "task_id": task_id,
67
- "final_score": final_score,
68
- "timestamp": time.time(),
69
- }
70
- if metadata:
71
- entry["metadata"] = metadata
72
- print(f"[END] {json.dumps(entry)}", flush=True)
73
 
74
 
75
  # ---------------------------------------------------------------------------
76
- # Environment HTTP client (simple, no WebSocket needed for inference)
77
  # ---------------------------------------------------------------------------
78
 
79
  class EnvHTTPClient:
@@ -99,26 +95,21 @@ class EnvHTTPClient:
99
  r.raise_for_status()
100
  return r.json()
101
 
102
- def step(self, issues: list[str], task_id: str = "easy") -> dict:
103
  r = self.session.post(
104
  f"{self.base_url}/step",
105
- json={"action": {"issues": issues, "task_id": task_id}},
106
  timeout=30,
107
  )
108
  r.raise_for_status()
109
  return r.json()
110
 
111
- def state(self) -> dict:
112
- r = self.session.get(f"{self.base_url}/state", timeout=10)
113
- r.raise_for_status()
114
- return r.json()
115
-
116
 
117
  # ---------------------------------------------------------------------------
118
- # LLM Agent
119
  # ---------------------------------------------------------------------------
120
 
121
- SYSTEM_PROMPT = """You are a data quality analyst. Your job is to inspect datasets and identify data quality issues.
122
 
123
  You will be given:
124
  1. A dataset in CSV format
@@ -142,7 +133,6 @@ CRITICAL INSTRUCTIONS FOR ROW NUMBERING:
142
  - Row numbers refer to the ROW POSITION in the CSV data, NOT the value of any ID column
143
  - Row 1 = the FIRST data row after the header
144
  - Row 2 = the SECOND data row after the header
145
- - For example, if the CSV has header on line 1 and data starting on line 2, the data on line 2 is row 1, line 3 is row 2, etc.
146
  - DO NOT use the employee_id, order_id, or experiment_id as the row number
147
  - Column names must match exactly (use the CSV header names, lowercase)
148
  - Check EVERY row and EVERY column systematically
@@ -154,7 +144,26 @@ Respond with ONLY the list of issues, one per line. No other text.
154
  Example: row:3,col:salary,issue:missing_value"""
155
 
156
 
157
- def build_user_prompt(observation: dict) -> str:
158
  obs = observation if isinstance(observation, dict) else observation
159
  parts = []
160
 
@@ -173,6 +182,12 @@ def build_user_prompt(observation: dict) -> str:
173
  if feedback and "reset" not in feedback.lower():
174
  parts.append(f"FEEDBACK FROM PREVIOUS ATTEMPT:\n{feedback}")
175
 
 
 
 
 
 
 
176
  return "\n\n".join(parts)
177
 
178
 
@@ -183,88 +198,142 @@ def parse_llm_response(response: str) -> list[str]:
183
  line = line.strip()
184
  if not line:
185
  continue
186
- # Remove numbering like "1. " or "- " or "* "
187
  line = re.sub(r"^\s*[\d]+[.\)]\s*", "", line)
188
  line = re.sub(r"^\s*[-*]\s*", "", line)
189
  line = line.strip()
190
  if "row" in line.lower() and "col" in line.lower():
191
- # Lenient regex: accept : or = as delimiters, case-insensitive
192
  match = re.search(
193
  r"row\s*[:=]\s*(\d+)\s*[,;\s]+col(?:umn)?\s*[:=]\s*([\w_]+)\s*[,;\s]+issue\s*[:=]\s*([\w_]+)",
194
  line,
195
  re.IGNORECASE,
196
  )
197
  if match:
198
- # Normalize to lowercase canonical format
199
  normalized = f"row:{match.group(1)},col:{match.group(2).lower()},issue:{match.group(3).lower()}"
200
  issues.append(normalized)
201
  return issues
202
 
203
 
204
- def run_task(client: OpenAI, env: EnvHTTPClient, task_id: str) -> float:
205
- """Run a single task and return the best score."""
206
- log_start(task_id)
207
 
208
- # Reset environment for this task
209
- reset_response = env.reset(task_id=task_id)
210
- observation = reset_response.get("observation", reset_response)
211
 
212
  best_score = 0.0
213
- messages = [{"role": "system", "content": SYSTEM_PROMPT}]
214
-
215
- for step_num in range(1, MAX_STEPS_PER_TASK + 1):
216
- user_prompt = build_user_prompt(observation)
217
- messages_for_call = messages + [{"role": "user", "content": user_prompt}]
218
-
219
- # Call LLM with retry on rate limit
220
- llm_output = ""
221
- for attempt in range(3):
222
- try:
223
- response = client.chat.completions.create(
224
- model=MODEL_NAME,
225
- messages=messages_for_call,
226
- temperature=0.1,
227
- max_tokens=2048,
228
- )
229
- llm_output = response.choices[0].message.content or ""
230
  break
231
- except Exception as e:
232
- if "rate_limit" in str(e).lower() or "429" in str(e):
233
- wait = 10 * (attempt + 1)
234
- print(f"[WARN] Rate limited, waiting {wait}s...", flush=True)
235
- time.sleep(wait)
236
- else:
237
- print(f"[ERROR] LLM call failed: {e}", file=sys.stderr, flush=True)
238
- break
239
-
240
- # Parse issues from LLM response
241
- issues = parse_llm_response(llm_output)
242
-
243
- if not issues:
244
- print(f"[WARN] No issues parsed from LLM response for {task_id} step {step_num}", file=sys.stderr, flush=True)
245
-
246
- # Submit to environment
247
- step_response = env.step(issues, task_id=task_id)
248
- observation = step_response.get("observation", step_response)
249
-
250
- # reward and done are at the top level of the response, not inside observation
251
- reward = float(step_response.get("reward", 0.0) or 0.0)
252
- done = bool(step_response.get("done", False))
253
- best_score = max(best_score, reward)
254
-
255
- log_step(task_id, step_num, reward, {
256
- "issues_reported": len(issues),
257
- "feedback": observation.get("feedback", ""),
258
- })
259
-
260
- if done:
261
- break
262
-
263
- # Add context for next attempt
264
- messages.append({"role": "user", "content": user_prompt})
265
- messages.append({"role": "assistant", "content": llm_output})
266
-
267
- log_end(task_id, best_score)
268
  return best_score
269
 
270
 
@@ -273,49 +342,34 @@ def run_task(client: OpenAI, env: EnvHTTPClient, task_id: str) -> float:
273
  # ---------------------------------------------------------------------------
274
 
275
  def main():
276
- print(f"[INFO] DataQA Inference starting", flush=True)
277
- print(f"[INFO] ENV_URL={ENV_URL}", flush=True)
278
- print(f"[INFO] API_BASE_URL={API_BASE_URL}", flush=True)
279
- print(f"[INFO] MODEL_NAME={MODEL_NAME}", flush=True)
280
 
281
- # Initialize clients
282
  env = EnvHTTPClient(ENV_URL)
283
  llm_client = OpenAI(
284
  base_url=API_BASE_URL,
285
- api_key=os.environ.get("LLM_API_KEY", HF_TOKEN or "no-key"),
286
  )
287
 
288
- # Check environment health
289
  if not env.health():
290
- print("[ERROR] Environment is not healthy. Exiting.", file=sys.stderr, flush=True)
291
  sys.exit(1)
292
 
293
- print(f"[INFO] Environment is healthy", flush=True)
294
 
295
- # Run all tasks
296
  scores = {}
297
  for task_id in TASKS:
298
- print(f"\n{'='*60}", flush=True)
299
- print(f"[INFO] Starting task: {task_id}", flush=True)
300
- print(f"{'='*60}", flush=True)
301
-
302
  try:
303
  score = run_task(llm_client, env, task_id)
304
  scores[task_id] = score
305
- print(f"[INFO] Task {task_id} completed with score: {score:.3f}", flush=True)
306
  except Exception as e:
307
- print(f"[ERROR] Task {task_id} failed: {e}", file=sys.stderr, flush=True)
308
  scores[task_id] = 0.0
309
 
310
- # Summary
311
- print(f"\n{'='*60}", flush=True)
312
- print("[INFO] FINAL RESULTS", flush=True)
313
- print(f"{'='*60}", flush=True)
314
- for task_id, score in scores.items():
315
- print(f"[INFO] {task_id}: {score:.3f}", flush=True)
316
-
317
  avg_score = sum(scores.values()) / len(scores) if scores else 0.0
318
- print(f"[INFO] Average score: {avg_score:.3f}", flush=True)
319
 
320
 
321
  if __name__ == "__main__":
 
1
  #!/usr/bin/env python3
2
  """
3
+ DataQA Inference Script — Two-Phase Agent
4
+ ------------------------------------------
5
+ LLM agent that plays the DataQA environment in two phases:
6
+ Phase 1: Identify all data quality issues
7
+ Phase 2: Propose fixes for identified issues
8
+
9
  Uses the OpenAI client to interact with any OpenAI-compatible LLM API.
10
 
11
  Required environment variables:
12
+ API_BASE_URL - LLM API endpoint (e.g., https://router.huggingface.co/v1)
13
+ MODEL_NAME - Model identifier (e.g., Qwen/Qwen2.5-72B-Instruct)
14
+ HF_TOKEN - HuggingFace token / API key
15
+
16
+ STDOUT FORMAT (mandatory for evaluation):
17
+ [START] task=<task_name> env=<benchmark> model=<model_name>
18
+ [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
19
+ [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
20
  """
21
 
22
  from __future__ import annotations
23
 
 
24
  import os
25
  import re
26
  import sys
27
  import time
28
+ from typing import List, Optional
29
 
30
  import requests
31
  from openai import OpenAI
 
33
  # ---------------------------------------------------------------------------
34
  # Configuration
35
  # ---------------------------------------------------------------------------
36
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
37
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
38
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
39
+ ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")
40
 
41
+ BENCHMARK = "dataqa_env"
42
  TASKS = ["easy", "medium", "hard"]
43
  MAX_STEPS_PER_TASK = 3
44
 
45
+
46
  # ---------------------------------------------------------------------------
47
+ # Logging helpers (structured stdout exact format required by evaluation)
48
  # ---------------------------------------------------------------------------
49
 
50
+ def log_start(task: str, env: str, model: str) -> None:
51
+ print(f"[START] task={task} env={env} model={model}", flush=True)
52
+
53
+
54
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
55
+ error_val = error if error else "null"
56
+ done_val = str(done).lower()
57
+ print(
58
+ f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
59
+ flush=True,
60
+ )
61
+
62
+
63
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
64
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
65
+ print(
66
+ f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
67
+ flush=True,
68
+ )
69
 
70
 
71
  # ---------------------------------------------------------------------------
72
+ # Environment HTTP client
73
  # ---------------------------------------------------------------------------
74
 
75
  class EnvHTTPClient:
 
95
  r.raise_for_status()
96
  return r.json()
97
 
98
+ def step(self, issues: list[str], fixes: list[str], task_id: str = "easy") -> dict:
99
  r = self.session.post(
100
  f"{self.base_url}/step",
101
+ json={"action": {"issues": issues, "fixes": fixes, "task_id": task_id}},
102
  timeout=30,
103
  )
104
  r.raise_for_status()
105
  return r.json()
106
 
 
 
 
 
 
107
 
108
  # ---------------------------------------------------------------------------
109
+ # LLM Prompts
110
  # ---------------------------------------------------------------------------
111
 
112
+ IDENTIFY_SYSTEM_PROMPT = """You are a data quality analyst. Your job is to inspect datasets and identify data quality issues.
113
 
114
  You will be given:
115
  1. A dataset in CSV format
 
133
  - Row numbers refer to the ROW POSITION in the CSV data, NOT the value of any ID column
134
  - Row 1 = the FIRST data row after the header
135
  - Row 2 = the SECOND data row after the header
 
136
  - DO NOT use the employee_id, order_id, or experiment_id as the row number
137
  - Column names must match exactly (use the CSV header names, lowercase)
138
  - Check EVERY row and EVERY column systematically
 
144
  Example: row:3,col:salary,issue:missing_value"""
145
 
146
 
147
+ FIX_SYSTEM_PROMPT = """You are a data repair specialist. You have already identified data quality issues in a dataset. Now you must propose the correct values to fix each issue.
148
+
149
+ For each issue you identified, propose a fix in EXACTLY this format:
150
+ row:<row_number>,col:<column_name>,fix:<corrected_value>
151
+
152
+ Guidelines for proposing fixes:
153
+ - For missing_value: infer the correct value from context, schema, and other rows
154
+ - For wrong_type: convert to the correct type (e.g., "seventy-five thousand" → "75000")
155
+ - For out_of_range: propose a value within the valid range that makes sense in context
156
+ - For format_violation: correct the format (e.g., "26/01/2024" → "2024-01-26")
157
+ - For inconsistent_value: compute the correct value from related fields
158
+ - For duplicate_row: propose a corrected unique key or indicate removal
159
+ - For statistical_outlier: propose a reasonable value given the model/context
160
+
161
+ Use the schema, validation rules, and surrounding data to determine the correct fix.
162
+ Respond with ONLY the list of fixes, one per line. No other text.
163
+ Example: row:3,col:salary,fix:75000"""
164
+
165
+
166
+ def build_user_prompt(observation: dict, include_fixes: bool = False) -> str:
167
  obs = observation if isinstance(observation, dict) else observation
168
  parts = []
169
 
 
182
  if feedback and "reset" not in feedback.lower():
183
  parts.append(f"FEEDBACK FROM PREVIOUS ATTEMPT:\n{feedback}")
184
 
185
+ if include_fixes:
186
+ parts.append(
187
+ "Now propose fixes for ALL issues. "
188
+ "Use format: row:<N>,col:<name>,fix:<corrected_value>"
189
+ )
190
+
191
  return "\n\n".join(parts)
192
 
193
 
 
198
  line = line.strip()
199
  if not line:
200
  continue
 
201
  line = re.sub(r"^\s*[\d]+[.\)]\s*", "", line)
202
  line = re.sub(r"^\s*[-*]\s*", "", line)
203
  line = line.strip()
204
  if "row" in line.lower() and "col" in line.lower():
 
205
  match = re.search(
206
  r"row\s*[:=]\s*(\d+)\s*[,;\s]+col(?:umn)?\s*[:=]\s*([\w_]+)\s*[,;\s]+issue\s*[:=]\s*([\w_]+)",
207
  line,
208
  re.IGNORECASE,
209
  )
210
  if match:
 
211
  normalized = f"row:{match.group(1)},col:{match.group(2).lower()},issue:{match.group(3).lower()}"
212
  issues.append(normalized)
213
  return issues
214
 
215
 
216
+ def parse_fix_response(response: str) -> list[str]:
217
+ """Extract fix lines from LLM response."""
218
+ fixes = []
219
+ for line in response.strip().split("\n"):
220
+ line = line.strip()
221
+ if not line:
222
+ continue
223
+ line = re.sub(r"^\s*[\d]+[.\)]\s*", "", line)
224
+ line = re.sub(r"^\s*[-*]\s*", "", line)
225
+ line = line.strip()
226
+ if "row" in line.lower() and "fix" in line.lower():
227
+ match = re.search(
228
+ r"row\s*[:=]\s*(\d+)\s*[,;\s]+col(?:umn)?\s*[:=]\s*([\w_]+)\s*[,;\s]+fix\s*[:=]\s*(.+?)$",
229
+ line,
230
+ re.IGNORECASE,
231
+ )
232
+ if match:
233
+ normalized = f"row:{match.group(1)},col:{match.group(2).lower()},fix:{match.group(3).strip()}"
234
+ fixes.append(normalized)
235
+ return fixes
236
 
 
 
 
237
 
238
+ def call_llm(client: OpenAI, system_prompt: str, user_prompt: str) -> str:
239
+ """Call the LLM with retry on rate limit."""
240
+ for attempt in range(3):
241
+ try:
242
+ response = client.chat.completions.create(
243
+ model=MODEL_NAME,
244
+ messages=[
245
+ {"role": "system", "content": system_prompt},
246
+ {"role": "user", "content": user_prompt},
247
+ ],
248
+ temperature=0.1,
249
+ max_tokens=2048,
250
+ )
251
+ return response.choices[0].message.content or ""
252
+ except Exception as e:
253
+ if "rate_limit" in str(e).lower() or "429" in str(e):
254
+ wait = 10 * (attempt + 1)
255
+ print(f"[DEBUG] Rate limited, waiting {wait}s...", file=sys.stderr, flush=True)
256
+ time.sleep(wait)
257
+ else:
258
+ print(f"[DEBUG] LLM call failed: {e}", file=sys.stderr, flush=True)
259
+ return ""
260
+ return ""
261
+
262
+
263
+ def run_task(client: OpenAI, env: EnvHTTPClient, task_id: str) -> float:
264
+ """
265
+ Run a single task with two-phase strategy:
266
+ Step 1: Identify issues only
267
+ Step 2: Identify + Fix (using feedback from step 1)
268
+ Step 3: Refined identify + fix (if needed)
269
+ """
270
+ log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
271
+
272
+ rewards: List[float] = []
273
+ steps_taken = 0
274
  best_score = 0.0
275
+ success = False
276
+
277
+ try:
278
+ reset_response = env.reset(task_id=task_id)
279
+ observation = reset_response.get("observation", reset_response)
280
+
281
+ last_issues: list[str] = []
282
+ last_llm_output = ""
283
+
284
+ for step_num in range(1, MAX_STEPS_PER_TASK + 1):
285
+ error_msg = None
286
+
287
+ # ── Phase 1: Identify issues ──
288
+ user_prompt = build_user_prompt(observation)
289
+ identify_output = call_llm(client, IDENTIFY_SYSTEM_PROMPT, user_prompt)
290
+ issues = parse_llm_response(identify_output)
291
+
292
+ if not issues and not error_msg:
293
+ error_msg = "no issues parsed from LLM response"
294
+
295
+ # ── Phase 2: Propose fixes (from step 2 onward, and only when issues were identified) ──
296
+ fixes: list[str] = []
297
+ if issues and step_num >= 2:
298
+ # Build a fix prompt that includes the identified issues
299
+ fix_prompt = build_user_prompt(observation, include_fixes=True)
300
+ fix_prompt += f"\n\nISSUES FOUND:\n" + "\n".join(issues)
301
+ fix_output = call_llm(client, FIX_SYSTEM_PROMPT, fix_prompt)
302
+ fixes = parse_fix_response(fix_output)
303
+
304
+ # ── Submit to environment ──
305
+ action_str = ";".join(issues[:5]) if issues else "none"
306
+ if fixes:
307
+ action_str += "|fixes:" + ";".join(fixes[:3])
308
+
309
+ step_response = env.step(issues, fixes, task_id=task_id)
310
+ observation = step_response.get("observation", step_response)
311
+
312
+ reward = float(step_response.get("reward", 0.0) or 0.0)
313
+ done = bool(step_response.get("done", False))
314
+ best_score = max(best_score, reward)
315
+ rewards.append(reward)
316
+ steps_taken = step_num
317
+
318
+ log_step(
319
+ step=step_num,
320
+ action=action_str,
321
+ reward=reward,
322
+ done=done,
323
+ error=error_msg,
324
+ )
325
+
326
+ if done:
327
  break
328
+
329
+ last_issues = issues
330
+ last_llm_output = identify_output
331
+
332
+ success = best_score >= 0.5
333
+
334
+ finally:
335
+ log_end(success=success, steps=steps_taken, score=best_score, rewards=rewards)
336
+
337
  return best_score
338
 
339
 
 
342
  # ---------------------------------------------------------------------------
343
 
344
  def main():
345
+ print(f"[DEBUG] DataQA Inference starting", file=sys.stderr, flush=True)
346
+ print(f"[DEBUG] ENV_URL={ENV_URL}", file=sys.stderr, flush=True)
347
+ print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", file=sys.stderr, flush=True)
348
+ print(f"[DEBUG] MODEL_NAME={MODEL_NAME}", file=sys.stderr, flush=True)
349
 
 
350
  env = EnvHTTPClient(ENV_URL)
351
  llm_client = OpenAI(
352
  base_url=API_BASE_URL,
353
+ api_key=API_KEY or "no-key",
354
  )
355
 
 
356
  if not env.health():
357
+ print("[DEBUG] Environment is not healthy. Exiting.", file=sys.stderr, flush=True)
358
  sys.exit(1)
359
 
360
+ print(f"[DEBUG] Environment is healthy", file=sys.stderr, flush=True)
361
 
 
362
  scores = {}
363
  for task_id in TASKS:
 
 
 
 
364
  try:
365
  score = run_task(llm_client, env, task_id)
366
  scores[task_id] = score
 
367
  except Exception as e:
368
+ print(f"[DEBUG] Task {task_id} failed: {e}", file=sys.stderr, flush=True)
369
  scores[task_id] = 0.0
370
 
 
 
 
 
 
 
 
371
  avg_score = sum(scores.values()) / len(scores) if scores else 0.0
372
+ print(f"\n[DEBUG] FINAL RESULTS: {scores} avg={avg_score:.3f}", file=sys.stderr, flush=True)
373
 
374
 
375
  if __name__ == "__main__":
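Review note: the fix-line format accepted by `parse_fix_response` can be exercised in isolation. The following is a minimal sketch that reuses the regular expression from the diff above; the names `FIX_PATTERN` and `normalize_fix` are illustrative and are not part of the commit.

```python
import re
from typing import Optional

# Same pattern as in parse_fix_response: accepts ":" or "=" as delimiters,
# "col" or "column", and takes everything after "fix" as the proposed value.
FIX_PATTERN = re.compile(
    r"row\s*[:=]\s*(\d+)\s*[,;\s]+col(?:umn)?\s*[:=]\s*([\w_]+)\s*[,;\s]+fix\s*[:=]\s*(.+?)$",
    re.IGNORECASE,
)


def normalize_fix(line: str) -> Optional[str]:
    """Return the canonical 'row:N,col:name,fix:value' form, or None if no match."""
    # Strip list numbering ("1. ", "2) ") and bullet markers ("- ", "* ").
    line = re.sub(r"^\s*[\d]+[.\)]\s*", "", line.strip())
    line = re.sub(r"^\s*[-*]\s*", "", line).strip()
    match = FIX_PATTERN.search(line)
    if not match:
        return None
    row = match.group(1)
    col = match.group(2).lower()   # column names are lowercased
    value = match.group(3).strip() # the fix value keeps its original case
    return f"row:{row},col:{col},fix:{value}"


if __name__ == "__main__":
    print(normalize_fix("1. Row: 3, Column: Salary, fix: 75000"))
    print(normalize_fix("- row=4;col=date;fix=2024-01-26"))
```

Only the column name is lowercased; the value is passed through unchanged, which matters for case-sensitive fixes such as names or codes.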
openenv.yaml CHANGED
@@ -3,4 +3,4 @@ name: dataqa_env
3
  type: space
4
  runtime: fastapi
5
  app: dataqa_env.server.app:app
6
- port: 8000
 
3
  type: space
4
  runtime: fastapi
5
  app: dataqa_env.server.app:app
6
+ port: 7860
scripts/prevalidation_script.sh ADDED
@@ -0,0 +1,185 @@
1
+ #!/usr/bin/env bash
2
+ #
3
+ # validate-submission.sh — OpenEnv Submission Validator
4
+ #
5
+ # Checks that your HF Space is live, Docker image builds, and openenv validate passes.
6
+ #
7
+ # Prerequisites:
8
+ # - Docker: https://docs.docker.com/get-docker/
9
+ # - openenv-core: pip install openenv-core
10
+ # - curl (usually pre-installed)
11
+ #
12
+ # Run:
13
+ # curl -fsSL https://raw.githubusercontent.com/<owner>/<repo>/main/scripts/validate-submission.sh | bash -s -- <ping_url> [repo_dir]
14
+ #
15
+ # Or download and run locally:
16
+ # chmod +x validate-submission.sh
17
+ # ./validate-submission.sh <ping_url> [repo_dir]
18
+ #
19
+ # Arguments:
20
+ # ping_url Your HuggingFace Space URL (e.g. https://your-space.hf.space)
21
+ # repo_dir Path to your repo (default: current directory)
22
+ #
23
+ # Examples:
24
+ # ./validate-submission.sh https://my-team.hf.space
25
+ # ./validate-submission.sh https://my-team.hf.space ./my-repo
26
+ #
27
+
28
+ set -uo pipefail
29
+
30
+ DOCKER_BUILD_TIMEOUT=600
31
+ if [ -t 1 ]; then
32
+ RED='\033[0;31m'
33
+ GREEN='\033[0;32m'
34
+ YELLOW='\033[1;33m'
35
+ BOLD='\033[1m'
36
+ NC='\033[0m'
37
+ else
38
+ RED='' GREEN='' YELLOW='' BOLD='' NC=''
39
+ fi
40
+
41
+ run_with_timeout() {
42
+ local secs="$1"; shift
43
+ if command -v timeout &>/dev/null; then
44
+ timeout "$secs" "$@"
45
+ elif command -v gtimeout &>/dev/null; then
46
+ gtimeout "$secs" "$@"
47
+ else
48
+ "$@" &
49
+ local pid=$!
50
+ ( sleep "$secs" && kill "$pid" 2>/dev/null ) &
51
+ local watcher=$!
52
+ wait "$pid" 2>/dev/null
53
+ local rc=$?
54
+ kill "$watcher" 2>/dev/null
55
+ wait "$watcher" 2>/dev/null
56
+ return $rc
57
+ fi
58
+ }
59
+
60
+ portable_mktemp() {
61
+ local prefix="${1:-validate}"
62
+ mktemp "${TMPDIR:-/tmp}/${prefix}-XXXXXX" 2>/dev/null || mktemp
63
+ }
64
+
65
+ CLEANUP_FILES=()
66
+ cleanup() { rm -f "${CLEANUP_FILES[@]+"${CLEANUP_FILES[@]}"}"; }
67
+ trap cleanup EXIT
68
+
69
+ PING_URL="${1:-}"
70
+ REPO_DIR="${2:-.}"
71
+
72
+ if [ -z "$PING_URL" ]; then
73
+ printf "Usage: %s <ping_url> [repo_dir]\n" "$0"
74
+ printf "\n"
75
+ printf " ping_url Your HuggingFace Space URL (e.g. https://your-space.hf.space)\n"
76
+ printf " repo_dir Path to your repo (default: current directory)\n"
77
+ exit 1
78
+ fi
79
+
80
+ if ! REPO_DIR="$(cd "$REPO_DIR" 2>/dev/null && pwd)"; then
81
+ printf "Error: directory '%s' not found\n" "${2:-.}"
82
+ exit 1
83
+ fi
84
+ PING_URL="${PING_URL%/}"
85
+ export PING_URL
86
+ PASS=0
87
+
88
+ log() { printf "[%s] %b\n" "$(date -u +%H:%M:%S)" "$*"; }
89
+ pass() { log "${GREEN}PASSED${NC} -- $1"; PASS=$((PASS + 1)); }
90
+ fail() { log "${RED}FAILED${NC} -- $1"; }
91
+ hint() { printf " ${YELLOW}Hint:${NC} %b\n" "$1"; }
92
+ stop_at() {
93
+ printf "\n"
94
+ printf "${RED}${BOLD}Validation stopped at %s.${NC} Fix the above before continuing.\n" "$1"
95
+ exit 1
96
+ }
97
+
98
+ printf "\n"
99
+ printf "${BOLD}========================================${NC}\n"
100
+ printf "${BOLD} OpenEnv Submission Validator${NC}\n"
101
+ printf "${BOLD}========================================${NC}\n"
102
+ log "Repo: $REPO_DIR"
103
+ log "Ping URL: $PING_URL"
104
+ printf "\n"
105
+
106
+ log "${BOLD}Step 1/3: Pinging HF Space${NC} ($PING_URL/reset) ..."
107
+
108
+ CURL_OUTPUT=$(portable_mktemp "validate-curl")
109
+ CLEANUP_FILES+=("$CURL_OUTPUT")
110
+ HTTP_CODE=$(curl -s -o "$CURL_OUTPUT" -w "%{http_code}" -X POST \
111
+ -H "Content-Type: application/json" -d '{}' \
112
+ "$PING_URL/reset" --max-time 30 2>"$CURL_OUTPUT" || printf "000")
113
+
114
+ if [ "$HTTP_CODE" = "200" ]; then
115
+ pass "HF Space is live and responds to /reset"
116
+ elif [ "$HTTP_CODE" = "000" ]; then
117
+ fail "HF Space not reachable (connection failed or timed out)"
118
+ hint "Check your network connection and that the Space is running."
119
+ hint "Try: curl -s -o /dev/null -w '%%{http_code}' -X POST $PING_URL/reset"
120
+ stop_at "Step 1"
121
+ else
122
+ fail "HF Space /reset returned HTTP $HTTP_CODE (expected 200)"
123
+ hint "Make sure your Space is running and the URL is correct."
124
+ hint "Try opening $PING_URL in your browser first."
125
+ stop_at "Step 1"
126
+ fi
127
+
128
+ log "${BOLD}Step 2/3: Running docker build${NC} ..."
129
+
130
+ if ! command -v docker &>/dev/null; then
131
+ fail "docker command not found"
132
+ hint "Install Docker: https://docs.docker.com/get-docker/"
133
+ stop_at "Step 2"
134
+ fi
135
+
136
+ if [ -f "$REPO_DIR/Dockerfile" ]; then
137
+ DOCKER_CONTEXT="$REPO_DIR"
138
+ elif [ -f "$REPO_DIR/server/Dockerfile" ]; then
139
+ DOCKER_CONTEXT="$REPO_DIR/server"
140
+ else
141
+ fail "No Dockerfile found in repo root or server/ directory"
142
+ stop_at "Step 2"
143
+ fi
144
+
145
+ log " Found Dockerfile in $DOCKER_CONTEXT"
146
+
147
+ BUILD_OK=false
148
+ BUILD_OUTPUT=$(run_with_timeout "$DOCKER_BUILD_TIMEOUT" docker build "$DOCKER_CONTEXT" 2>&1) && BUILD_OK=true
149
+
150
+ if [ "$BUILD_OK" = true ]; then
151
+ pass "Docker build succeeded"
152
+ else
153
+ fail "Docker build failed (timeout=${DOCKER_BUILD_TIMEOUT}s)"
154
+ printf "%s\n" "$BUILD_OUTPUT" | tail -20
155
+ stop_at "Step 2"
156
+ fi
157
+
158
+ log "${BOLD}Step 3/3: Running openenv validate${NC} ..."
159
+
160
+ if ! command -v openenv &>/dev/null; then
161
+ fail "openenv command not found"
162
+ hint "Install it: pip install openenv-core"
163
+ stop_at "Step 3"
164
+ fi
165
+
166
+ VALIDATE_OK=false
167
+ VALIDATE_OUTPUT=$(cd "$REPO_DIR" && openenv validate 2>&1) && VALIDATE_OK=true
168
+
169
+ if [ "$VALIDATE_OK" = true ]; then
170
+ pass "openenv validate passed"
171
+ [ -n "$VALIDATE_OUTPUT" ] && log " $VALIDATE_OUTPUT"
172
+ else
173
+ fail "openenv validate failed"
174
+ printf "%s\n" "$VALIDATE_OUTPUT"
175
+ stop_at "Step 3"
176
+ fi
177
+
178
+ printf "\n"
179
+ printf "${BOLD}========================================${NC}\n"
180
+ printf "${GREEN}${BOLD} All 3/3 checks passed!${NC}\n"
181
+ printf "${GREEN}${BOLD} Your submission is ready to submit.${NC}\n"
182
+ printf "${BOLD}========================================${NC}\n"
183
+ printf "\n"
184
+
185
+ exit 0
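Review note: the `run_with_timeout` helper in the script above falls back to a background watcher when neither `timeout` nor `gtimeout` is installed. For comparison, here is a rough Python sketch of the same idea using only the standard library; the function name mirrors the shell helper, but the return shape `(ok, output)` is an assumption made for illustration.

```python
import subprocess
import sys
from typing import List, Tuple


def run_with_timeout(cmd: List[str], secs: float) -> Tuple[bool, str]:
    """Run cmd and return (ok, combined_output).

    ok is False when the command exits non-zero or exceeds the time limit.
    """
    try:
        proc = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr into stdout, like 2>&1
            text=True,
            timeout=secs,
        )
    except subprocess.TimeoutExpired:
        # subprocess.run kills the child before raising this exception.
        return False, f"timed out after {secs}s"
    return proc.returncode == 0, proc.stdout


if __name__ == "__main__":
    ok, out = run_with_timeout([sys.executable, "-c", "print('built')"], 5)
    print(ok, out.strip())
```

In the validator this would wrap a call such as `["docker", "build", context_dir]`, with the captured output tailed on failure.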
scripts/sample_inference_script.py ADDED
@@ -0,0 +1,188 @@
1
+ """
2
+ Inference Script Example
3
+ ===================================
4
+ MANDATORY
5
+ - Before submitting, ensure the following variables are defined in your environment configuration:
6
+ API_BASE_URL The API endpoint for the LLM.
7
+ MODEL_NAME The model identifier to use for inference.
8
+ HF_TOKEN Your Hugging Face / API key.
9
+ LOCAL_IMAGE_NAME The name of the local image to use for the environment if you are using from_docker_image()
10
+ method
11
+
12
+ - Defaults are set only for API_BASE_URL and MODEL_NAME
13
+ (and should reflect your active inference setup):
14
+ API_BASE_URL = os.getenv("API_BASE_URL", "<your-active-endpoint>")
15
+ MODEL_NAME = os.getenv("MODEL_NAME", "<your-active-model>")
16
+
17
+ - The inference script must be named `inference.py` and placed in the root directory of the project
18
+ - Participants must use the OpenAI client for all LLM calls, using the variables above
19
+
20
+ STDOUT FORMAT
21
+ - The script must emit exactly three line types to stdout, in this order:
22
+
23
+ [START] task=<task_name> env=<benchmark> model=<model_name>
24
+ [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
25
+ [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
26
+
27
+ Rules:
28
+ - One [START] line at episode begin.
29
+ - One [STEP] line per step, immediately after env.step() returns.
30
+ - One [END] line after env.close(), always emitted (even on exception).
31
+ - reward and rewards are formatted to 2 decimal places.
32
+ - done and success are lowercase booleans: true or false.
33
+ - error is the raw last_action_error string, or null if none.
34
+ - All fields on a single line with no newlines within a line.
35
+ - Each task should return a score in [0, 1]
36
+
37
+ Example:
38
+ [START] task=click-test env=miniwob model=Qwen3-VL-30B
39
+ [STEP] step=1 action=click('123') reward=0.00 done=false error=null
40
+ [STEP] step=2 action=fill('456','text') reward=0.00 done=false error=null
41
+ [STEP] step=3 action=click('789') reward=1.00 done=true error=null
42
+ [END] success=true steps=3 score=1.00 rewards=0.00,0.00,1.00
43
+ """
44
+
45
+ import asyncio
46
+ import os
47
+ import textwrap
48
+ from typing import List, Optional
49
+
50
+ from openai import OpenAI
51
+
52
+ from my_env_v4 import MyEnvV4Action, MyEnvV4Env
53
+ IMAGE_NAME = os.getenv("IMAGE_NAME") # If you are using docker image
54
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
55
+
56
+ API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
57
+ MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
58
+ TASK_NAME = os.getenv("MY_ENV_V4_TASK", "echo")
59
+ BENCHMARK = os.getenv("MY_ENV_V4_BENCHMARK", "my_env_v4")
60
+ MAX_STEPS = 8
61
+ TEMPERATURE = 0.7
62
+ MAX_TOKENS = 150
63
+ SUCCESS_SCORE_THRESHOLD = 0.1 # normalized score in [0, 1]
64
+
65
+ # Max possible reward: each token contributes 0.1, across all steps
66
+ _MAX_REWARD_PER_STEP = MAX_TOKENS * 0.1
67
+ MAX_TOTAL_REWARD = MAX_STEPS * _MAX_REWARD_PER_STEP
68
+
69
+ SYSTEM_PROMPT = textwrap.dedent(
70
+ """
71
+ You are interacting with a simple echo environment.
72
+ Each turn you must send a message. The environment will echo it back.
73
+ Reward is proportional to message length: reward = len(message) * 0.1
74
+ Your goal is to maximize total reward by sending meaningful, substantive messages.
75
+ Reply with exactly one message string — no quotes, no prefixes, just the message text.
76
+ """
77
+ ).strip()
78
+
79
+
80
+ def log_start(task: str, env: str, model: str) -> None:
81
+ print(f"[START] task={task} env={env} model={model}", flush=True)
82
+
83
+
84
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
85
+ error_val = error if error else "null"
86
+ done_val = str(done).lower()
87
+ print(
88
+ f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
89
+ flush=True,
90
+ )
91
+
92
+
93
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
94
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
95
+ print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
96
+
97
+
98
+ def build_user_prompt(step: int, last_echoed: str, last_reward: float, history: List[str]) -> str:
99
+ history_block = "\n".join(history[-4:]) if history else "None"
100
+ return textwrap.dedent(
101
+ f"""
102
+ Step: {step}
103
+ Last echoed message: {last_echoed!r}
104
+ Last reward: {last_reward:.2f}
105
+ Previous steps:
106
+ {history_block}
107
+ Send your next message.
108
+ """
109
+ ).strip()
110
+
111
+
112
+ def get_model_message(client: OpenAI, step: int, last_echoed: str, last_reward: float, history: List[str]) -> str:
113
+ user_prompt = build_user_prompt(step, last_echoed, last_reward, history)
114
+ try:
115
+ completion = client.chat.completions.create(
116
+ model=MODEL_NAME,
117
+ messages=[
118
+ {"role": "system", "content": SYSTEM_PROMPT},
119
+ {"role": "user", "content": user_prompt},
120
+ ],
121
+ temperature=TEMPERATURE,
122
+ max_tokens=MAX_TOKENS,
123
+ stream=False,
124
+ )
125
+ text = (completion.choices[0].message.content or "").strip()
126
+ return text if text else "hello"
127
+ except Exception as exc:
128
+ print(f"[DEBUG] Model request failed: {exc}", flush=True)
129
+ return "hello"
130
+
131
+
132
+ async def main() -> None:
133
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
134
+
135
+ env = await MyEnvV4Env.from_docker_image(IMAGE_NAME)
136
+
137
+ history: List[str] = []
138
+ rewards: List[float] = []
139
+ steps_taken = 0
140
+ score = 0.0
141
+ success = False
142
+
143
+ log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
144
+
145
+ try:
146
+ result = await env.reset() # OpenENV.reset()
147
+ last_echoed = result.observation.echoed_message
148
+ last_reward = 0.0
149
+
150
+ for step in range(1, MAX_STEPS + 1):
151
+ if result.done:
152
+ break
153
+
154
+ message = get_model_message(client, step, last_echoed, last_reward, history)
155
+
156
+ result = await env.step(MyEnvV4Action(message=message))
157
+ obs = result.observation
158
+
159
+ reward = result.reward or 0.0
160
+ done = result.done
161
+ error = None
162
+
163
+ rewards.append(reward)
164
+ steps_taken = step
165
+ last_echoed = obs.echoed_message
166
+ last_reward = reward
167
+
168
+ log_step(step=step, action=message, reward=reward, done=done, error=error)
169
+
170
+ history.append(f"Step {step}: {message!r} -> reward {reward:+.2f}")
171
+
172
+ if done:
173
+ break
174
+
175
+ score = sum(rewards) / MAX_TOTAL_REWARD if MAX_TOTAL_REWARD > 0 else 0.0
176
+ score = min(max(score, 0.0), 1.0) # clamp to [0, 1]
177
+ success = score >= SUCCESS_SCORE_THRESHOLD
178
+
179
+ finally:
180
+ try:
181
+ await env.close()
182
+ except Exception as e:
183
+ print(f"[DEBUG] env.close() error (container cleanup): {e}", flush=True)
184
+ log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
185
+
186
+
187
+ if __name__ == "__main__":
188
+ asyncio.run(main())
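The `[STEP]` lines emitted by `log_step` above follow a fixed `key=value` layout, so they can be read back mechanically. The following is a minimal sketch of such a reader, not part of this commit: `STEP_RE` and `parse_step_line` are hypothetical names, and the pattern assumes the action text contains no newline.

```python
import re
from typing import Any, Dict, Optional

# Mirrors the f-string in log_step: step, action, reward (2 decimals), done, error.
# The action group is greedy, so an action containing spaces is still captured whole.
STEP_RE = re.compile(
    r"^\[STEP\] step=(?P<step>\d+) action=(?P<action>.*) "
    r"reward=(?P<reward>-?\d+\.\d{2}) done=(?P<done>true|false) error=(?P<error>.*)$"
)


def parse_step_line(line: str) -> Optional[Dict[str, Any]]:
    """Parse one [STEP] log line into a dict, or return None if it does not match."""
    match = STEP_RE.match(line.strip())
    if match is None:
        return None
    error = match.group("error")
    return {
        "step": int(match.group("step")),
        "action": match.group("action"),
        "reward": float(match.group("reward")),
        "done": match.group("done") == "true",
        # log_step writes the literal string "null" when there is no error.
        "error": None if error == "null" else error,
    }
```

Lines that do not start with `[STEP]`, such as the `[START]`, `[END]`, and `[DEBUG]` output, come back as `None` and can be skipped by the caller.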
server/__init__.py CHANGED
@@ -0,0 +1 @@
1
+ """Root-level server package — delegates to dataqa_env.server."""
server/app.py CHANGED
@@ -1,13 +1,12 @@
- """
- Root-level server entry point for OpenEnv compatibility.
- """
+ """Entrypoint for openenv-core deployment. Delegates to dataqa_env.server.app."""
 
  from dataqa_env.server.app import app # noqa: F401
 
 
  def main():
+     """Start the environment server."""
      import uvicorn
-     uvicorn.run(app, host="0.0.0.0", port=8000)
+     uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
  if __name__ == "__main__":
tests/__init__.py ADDED
File without changes
tests/test_environment.py ADDED
@@ -0,0 +1,455 @@
1
+ """Tests for the DataQA environment (reset, step, scoring, two-phase identify+fix)."""
2
+
3
+ import pytest
4
+ from dataqa_env.server.environment import (
5
+ DataQAEnvironment,
6
+ parse_issue_key,
7
+ parse_fix,
8
+ compute_f1,
9
+ compute_weighted_reward,
10
+ grade_fixes,
11
+ IDENTIFY_WEIGHT,
12
+ FIX_WEIGHT,
13
+ )
14
+ from dataqa_env.models import DataQAAction
15
+ from dataqa_env.server.tasks import PlantedIssue, create_task_easy, create_task_medium
16
+
17
+
18
+ # ──────────────────────────────────────────────────────
19
+ # Issue parsing
20
+ # ──────────────────────────────────────────────────────
21
+
22
+ class TestParseIssueKey:
23
+ def test_standard_format(self):
24
+ assert parse_issue_key("row:3,col:salary,issue:missing_value") == "row:3,col:salary,issue:missing_value"
25
+
26
+ def test_with_equals(self):
27
+ assert parse_issue_key("row=3,col=salary,issue=missing_value") == "row:3,col:salary,issue:missing_value"
28
+
29
+ def test_case_insensitive(self):
30
+ assert parse_issue_key("Row:3,Col:Salary,Issue:Missing_Value") == "row:3,col:salary,issue:missing_value"
31
+
32
+ def test_with_spaces(self):
33
+ assert parse_issue_key("row: 3, col: salary, issue: missing_value") == "row:3,col:salary,issue:missing_value"
34
+
35
+ def test_unparseable(self):
36
+ assert parse_issue_key("this is garbage") is None
37
+
38
+ def test_partial_match(self):
39
+ assert parse_issue_key("row:3,col:salary") is None
40
+
41
+ def test_empty_string(self):
42
+ assert parse_issue_key("") is None
43
+
44
+ def test_semicolon_separator(self):
45
+ result = parse_issue_key("row:3;col:salary;issue:missing_value")
46
+ assert result == "row:3,col:salary,issue:missing_value"
47
+
48
+
49
+ # ──────────────────────────────────────────────────────
50
+ # Fix parsing
51
+ # ──────────────────────────────────────────────────────
52
+
53
+ class TestParseFix:
54
+ def test_standard_format(self):
55
+ result = parse_fix("row:4,col:name,fix:Alice Chen")
56
+ assert result == (4, "name", "Alice Chen")
57
+
58
+ def test_with_equals(self):
59
+ result = parse_fix("row=4,col=name,fix=Alice Chen")
60
+ assert result == (4, "name", "Alice Chen")
61
+
62
+ def test_numeric_fix(self):
63
+ result = parse_fix("row:7,col:salary,fix:75000")
64
+ assert result == (7, "salary", "75000")
65
+
66
+ def test_date_fix(self):
67
+ result = parse_fix("row:12,col:order_date,fix:2024-01-26")
68
+ assert result == (12, "order_date", "2024-01-26")
69
+
70
+ def test_case_insensitive(self):
71
+ result = parse_fix("Row:4,Col:Name,Fix:Alice Chen")
72
+ assert result == (4, "name", "Alice Chen")
73
+
74
+ def test_unparseable(self):
75
+ assert parse_fix("garbage") is None
76
+ assert parse_fix("row:4,col:name") is None
77
+
78
+ def test_fix_with_special_chars(self):
79
+ result = parse_fix("row:1,col:email,fix:alice.chen@company.com")
80
+ assert result == (1, "email", "alice.chen@company.com")
81
+
82
+
83
+ # ──────────────────────────────────────────────────────
84
+ # F1 scoring
85
+ # ──────────────────────────────────────────────────────
86
+
87
+ class TestComputeF1:
88
+ def test_perfect_match(self):
89
+ keys = {"row:1,col:a,issue:missing_value"}
90
+ result = compute_f1(keys, keys)
91
+ assert result["f1"] == 1.0
92
+
93
+ def test_no_reported_no_planted(self):
94
+ result = compute_f1(set(), set())
95
+ assert result["f1"] == 1.0
96
+
97
+ def test_no_reported_some_planted(self):
98
+ planted = {"row:1,col:a,issue:missing_value"}
99
+ result = compute_f1(set(), planted)
100
+ assert result["f1"] == 0.0
101
+ assert result["fn"] == 1
102
+
103
+ def test_all_false_positives(self):
104
+ reported = {"row:99,col:x,issue:wrong_type"}
105
+ planted = {"row:1,col:a,issue:missing_value"}
106
+ result = compute_f1(reported, planted)
107
+ assert result["f1"] == 0.0
108
+
109
+ def test_partial_match(self):
110
+ reported = {"row:1,col:a,issue:missing_value", "row:2,col:b,issue:wrong_type"}
111
+ planted = {"row:1,col:a,issue:missing_value", "row:3,col:c,issue:duplicate_row"}
112
+ result = compute_f1(reported, planted)
113
+ assert result["tp"] == 1
114
+ assert result["fp"] == 1
115
+ assert result["fn"] == 1
116
+ assert 0 < result["f1"] < 1
117
+
118
+ def test_precision_recall_calculation(self):
119
+ reported = {"a", "b", "c"}
120
+ planted = {"a", "b", "d"}
121
+ result = compute_f1(reported, planted)
122
+ assert result["precision"] == pytest.approx(2 / 3)
123
+ assert result["recall"] == pytest.approx(2 / 3)
124
+
125
+
126
+ # ──────────────────────────────────────────────────────
127
+ # Weighted reward
128
+ # ──────────────────────────────────────────────────────
129
+
130
+ class TestComputeWeightedReward:
131
+ def test_perfect_match(self):
132
+ issues = [
133
+ PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=1.0),
134
+ PlantedIssue(row=2, col="b", issue_type="wrong_type", description="", difficulty=3.0),
135
+ ]
136
+ reported = {i.to_key() for i in issues}
137
+ result = compute_weighted_reward(reported, issues)
138
+ assert result["weighted_reward"] == 1.0
139
+
140
+ def test_empty_both(self):
141
+ result = compute_weighted_reward(set(), [])
142
+ assert result["weighted_reward"] == 1.0
143
+
144
+ def test_no_reported(self):
145
+ issues = [PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=2.0)]
146
+ result = compute_weighted_reward(set(), issues)
147
+ assert result["weighted_reward"] == 0.0
148
+
149
+ def test_hard_issue_worth_more(self):
150
+ easy = PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=1.0)
151
+ hard = PlantedIssue(row=2, col="b", issue_type="statistical_outlier", description="", difficulty=3.0)
152
+ issues = [easy, hard]
153
+ hard_found = compute_weighted_reward({hard.to_key()}, issues)
154
+ easy_found = compute_weighted_reward({easy.to_key()}, issues)
155
+ assert hard_found["weighted_reward"] > easy_found["weighted_reward"]
156
+
157
+ def test_false_positives_reduce_reward(self):
158
+ issues = [PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=1.0)]
159
+ correct = {issues[0].to_key()}
160
+ with_fp = correct | {"row:99,col:x,issue:wrong_type"}
161
+ r_correct = compute_weighted_reward(correct, issues)
162
+ r_with_fp = compute_weighted_reward(with_fp, issues)
163
+ assert r_correct["weighted_reward"] > r_with_fp["weighted_reward"]
164
+
165
+
166
+ # ──────────────────────────────────────────────────────
167
+ # Fix grading
168
+ # ──────────────────────────────────────────────────────
169
+
170
+ class TestGradeFixes:
171
+ @pytest.fixture
172
+ def easy_task(self):
173
+ return create_task_easy()
174
+
175
+ def test_no_fixes_no_issues(self):
176
+ from dataqa_env.server.tasks import Task
177
+ task = Task(task_id="empty", name="", description="", schema_description="",
178
+ validation_rules="", clean_csv="a\n1")
179
+ result = grade_fixes([], task)
180
+ assert result["fix_score"] == 1.0
181
+
182
+ def test_no_fixes_submitted(self, easy_task):
183
+ result = grade_fixes([], easy_task)
184
+ assert result["fix_score"] == 0.0
185
+ assert result["fixes_attempted"] == 0
186
+
187
+ def test_exact_fix_for_missing_name(self, easy_task):
188
+ # Row 4 has empty name — clean value is "David Kim"
189
+ fixes = [(4, "name", "David Kim")]
190
+ result = grade_fixes(fixes, easy_task)
191
+ assert result["fix_score"] > 0.0
192
+ assert result["fixes_correct"] == 1
193
+
194
+ def test_exact_fix_for_wrong_type_salary(self, easy_task):
195
+ # Row 7 has "seventy-five thousand" — clean value is "75000"
196
+ fixes = [(7, "salary", "75000")]
197
+ result = grade_fixes(fixes, easy_task)
198
+ assert result["fixes_correct"] == 1
199
+
200
+ def test_numeric_close_match(self, easy_task):
201
+ # Row 9 has salary "5000" — clean value is "73000"
202
+ # Propose 73100 (within 1% of 73000)
203
+ fixes = [(9, "salary", "73100")]
204
+ result = grade_fixes(fixes, easy_task)
205
+ assert result["fixes_partial"] == 1
206
+
207
+ def test_wrong_value_for_issue_cell(self, easy_task):
208
+ # Row 4 name is empty — propose wrong name
209
+ fixes = [(4, "name", "Wrong Person")]
210
+ result = grade_fixes(fixes, easy_task)
211
+ assert result["fixes_partial"] == 1 # correct cell, wrong value
212
+ assert result["fix_score"] > 0.0 # gets partial credit
213
+
214
+ def test_fix_for_non_issue_cell(self, easy_task):
215
+ # Row 1 col name is fine — no issue there
216
+ fixes = [(1, "name", "Some Name")]
217
+ result = grade_fixes(fixes, easy_task)
218
+ assert result["fixes_wrong"] == 1
219
+ assert result["fix_score"] == 0.0
220
+
221
+ def test_multiple_fixes_best_wins(self, easy_task):
222
+ # Submit two fixes for same cell — best one should count
223
+ fixes = [
224
+ (4, "name", "Wrong Person"), # partial credit
225
+ (4, "name", "David Kim"), # exact match
226
+ ]
227
+ result = grade_fixes(fixes, easy_task)
228
+ assert result["fixes_correct"] >= 1
229
+
230
+ def test_all_fixes_correct(self, easy_task):
231
+ # Fix most issues with exact values
232
+ fixes = [
233
+ (4, "name", "David Kim"),
234
+ (7, "salary", "75000"),
235
+ (9, "salary", "73000"),
236
+ (15, "email", "oscar.rivera@company.com"),
237
+ (18, "start_date", "2022-01-19"),
238
+ ]
239
+ result = grade_fixes(fixes, easy_task)
240
+ assert result["fix_score"] > 0.7 # 5 out of 6 issues fixed (duplicate can't be fixed)
241
+
242
+ def test_fix_score_bounded(self, easy_task):
243
+ fixes = [(4, "name", "David Kim"), (99, "x", "bad")]
244
+ result = grade_fixes(fixes, easy_task)
245
+ assert 0.0 <= result["fix_score"] <= 1.0
246
+
247
+
248
+ # ──────────────────────────────────────────────────────
249
+ # Full environment lifecycle
250
+ # ──────────────────────────────────────────────────────
251
+
252
+ class TestDataQAEnvironment:
253
+ @pytest.fixture
254
+ def env(self):
255
+ return DataQAEnvironment()
256
+
257
+ def test_reset_returns_observation(self, env):
258
+ obs = env.reset(task_id="easy")
259
+ assert obs.dataset_csv
260
+ assert obs.schema_description
261
+ assert obs.validation_rules
262
+ assert obs.task_description
263
+ assert obs.num_issues_hint == 6
264
+ assert obs.max_steps == 3
265
+ assert obs.done is False
266
+ assert obs.reward == 0.0
267
+ assert "fix" in obs.feedback.lower() # mentions fix phase
268
+
269
+ def test_reset_medium(self, env):
270
+ obs = env.reset(task_id="medium")
271
+ assert obs.num_issues_hint == 8
272
+
273
+ def test_reset_hard(self, env):
274
+ obs = env.reset(task_id="hard")
275
+ assert obs.num_issues_hint == 10
276
+
277
+ def test_step_identify_only(self, env):
278
+ """Backward compatible: only issues, no fixes."""
279
+ env.reset(task_id="easy")
280
+ # Submit all 6 correct issues for easy task
281
+ action = DataQAAction(
282
+ issues=[
283
+ "row:4,col:name,issue:missing_value",
284
+ "row:7,col:salary,issue:wrong_type",
285
+ "row:21,col:employee_id,issue:duplicate_row",
286
+ "row:9,col:salary,issue:out_of_range",
287
+ "row:15,col:email,issue:inconsistent_value",
288
+ "row:18,col:start_date,issue:out_of_range",
289
+ ],
290
+ task_id="easy",
291
+ )
292
+ obs = env.step(action)
293
+ assert obs.done is True
294
+ assert obs.reward >= 0.999 # identify-only uses identify_score directly
295
+
296
+ def test_step_with_fixes_increases_reward(self, env):
297
+ """Submitting correct fixes should produce high combined reward."""
298
+ env.reset(task_id="easy")
299
+ # All 6 issues + 3 fixes
300
+ action = DataQAAction(
301
+ issues=[
302
+ "row:4,col:name,issue:missing_value",
303
+ "row:7,col:salary,issue:wrong_type",
304
+ "row:21,col:employee_id,issue:duplicate_row",
305
+ "row:9,col:salary,issue:out_of_range",
306
+ "row:15,col:email,issue:inconsistent_value",
307
+ "row:18,col:start_date,issue:out_of_range",
308
+ ],
309
+ fixes=[
310
+ "row:4,col:name,fix:David Kim",
311
+ "row:7,col:salary,fix:75000",
312
+ "row:9,col:salary,fix:73000",
313
+ ],
314
+ task_id="easy",
315
+ )
316
+ obs = env.step(action)
317
+ # Perfect identify + partial fixes -> high combined reward
318
+ assert obs.metadata["combined_reward"] > 0.7
319
+
320
+ def test_step_with_partial_issues(self, env):
321
+ env.reset(task_id="easy")
322
+ action = DataQAAction(
323
+ issues=["row:4,col:name,issue:missing_value"],
324
+ task_id="easy",
325
+ )
326
+ obs = env.step(action)
327
+ assert 0 < obs.reward < 1.0
328
+ assert obs.done is False
329
+
330
+ def test_step_with_no_issues(self, env):
331
+ env.reset(task_id="easy")
332
+ action = DataQAAction(issues=[], task_id="easy")
333
+ obs = env.step(action)
334
+ assert obs.reward == 0.0
335
+
336
+ def test_step_exhausts_max_steps(self, env):
337
+ env.reset(task_id="easy")
338
+ for _ in range(3):
339
+ action = DataQAAction(issues=["row:99,col:x,issue:wrong_type"], task_id="easy")
340
+ obs = env.step(action)
341
+ assert obs.done is True
342
+
343
+ def test_auto_reset_on_step(self, env):
344
+ action = DataQAAction(
345
+ issues=["row:4,col:name,issue:missing_value"],
346
+ task_id="easy",
347
+ )
348
+ obs = env.step(action)
349
+ assert obs.task_id == "easy"
350
+
351
+ def test_state_tracking(self, env):
352
+ env.reset(task_id="easy")
353
+ assert env.state.task_id == "easy"
354
+ assert env.state.current_step == 0
355
+ assert env.state.best_score == 0.0
356
+
357
+ action = DataQAAction(issues=["row:4,col:name,issue:missing_value"], task_id="easy")
358
+ env.step(action)
359
+ assert env.state.current_step == 1
360
+ assert env.state.best_score > 0.0
361
+
362
+ def test_best_score_monotonic(self, env):
363
+ env.reset(task_id="easy")
364
+ action1 = DataQAAction(
365
+ issues=["row:4,col:name,issue:missing_value", "row:7,col:salary,issue:wrong_type"],
366
+ task_id="easy",
367
+ )
368
+ env.step(action1)
369
+ score_after_1 = env.state.best_score
370
+
371
+ action2 = DataQAAction(issues=["row:99,col:x,issue:wrong_type"], task_id="easy")
372
+ env.step(action2)
373
+ assert env.state.best_score >= score_after_1
374
+
375
+ def test_metadata_includes_both_phases(self, env):
376
+ env.reset(task_id="easy")
377
+ action = DataQAAction(
378
+ issues=["row:4,col:name,issue:missing_value"],
379
+ fixes=["row:4,col:name,fix:David Kim"],
380
+ task_id="easy",
381
+ )
382
+ obs = env.step(action)
383
+ m = obs.metadata
384
+ assert "identify_f1" in m
385
+ assert "identify_score" in m
386
+ assert "fix_score" in m
387
+ assert "combined_reward" in m
388
+ assert "tp" in m
389
+ assert "fixes_correct" in m
390
+ assert "fixes_attempted" in m
391
+
392
+ def test_parse_error_in_feedback(self, env):
393
+ env.reset(task_id="easy")
394
+ action = DataQAAction(issues=["garbage input"], task_id="easy")
395
+ obs = env.step(action)
396
+ assert "Parse error" in obs.feedback
397
+
398
+ def test_concurrent_sessions_flag(self):
399
+ assert DataQAEnvironment.SUPPORTS_CONCURRENT_SESSIONS is True
400
+
401
+ def test_reward_between_0_and_1(self, env):
402
+ """Hackathon requirement: scores must be 0.0-1.0."""
403
+ env.reset(task_id="hard")
404
+ for _ in range(3):
405
+ action = DataQAAction(
406
+ issues=["row:1,col:x,issue:wrong_type", "row:99,col:y,issue:missing_value"],
407
+ fixes=["row:1,col:x,fix:wrong"],
408
+ task_id="hard",
409
+ )
410
+ obs = env.step(action)
411
+ assert 0.0 <= obs.reward <= 1.0
412
+
413
+ def test_combined_reward_weights(self, env):
414
+ """Verify combined = IDENTIFY_WEIGHT * identify + FIX_WEIGHT * fix."""
415
+ env.reset(task_id="easy")
416
+ action = DataQAAction(
417
+ issues=["row:4,col:name,issue:missing_value"],
418
+ fixes=["row:4,col:name,fix:David Kim"],
419
+ task_id="easy",
420
+ )
421
+ obs = env.step(action)
422
+ m = obs.metadata
423
+ expected = IDENTIFY_WEIGHT * m["identify_score"] + FIX_WEIGHT * m["fix_score"]
424
+ assert abs(m["combined_reward"] - expected) < 0.01
425
+
426
+ def test_fix_feedback_shown_when_fixes_submitted(self, env):
427
+ env.reset(task_id="easy")
428
+ action = DataQAAction(
429
+ issues=["row:4,col:name,issue:missing_value"],
430
+ fixes=["row:4,col:name,fix:David Kim"],
431
+ task_id="easy",
432
+ )
433
+ obs = env.step(action)
434
+ assert "Fix Proposals" in obs.feedback
435
+ assert "Combined Reward" in obs.feedback
436
+
437
+ def test_no_fix_penalty_when_no_fixes_submitted(self, env):
438
+ """If agent submits no fixes, reward = identify_score (no penalty)."""
439
+ env.reset(task_id="easy")
440
+ action = DataQAAction(
441
+ issues=[
442
+ "row:4,col:name,issue:missing_value",
443
+ "row:7,col:salary,issue:wrong_type",
444
+ "row:21,col:employee_id,issue:duplicate_row",
445
+ "row:9,col:salary,issue:out_of_range",
446
+ "row:15,col:email,issue:inconsistent_value",
447
+ "row:18,col:start_date,issue:out_of_range",
448
+ ],
449
+ task_id="easy",
450
+ )
451
+ obs = env.step(action)
452
+ # identify_score should be ~1.0 since all 6 issues found
453
+ assert obs.reward >= 0.99
454
+ # combined_reward equals identify_score when no fixes
455
+ assert obs.metadata["combined_reward"] == obs.metadata["identify_score"]
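The `TestComputeF1` cases above pin down the behaviour expected of `compute_f1` without showing its body. The following is a minimal sketch that satisfies those cases, assuming plain set arithmetic over issue keys; `compute_f1_sketch` is a hypothetical name and the real implementation in `dataqa_env.server.environment` may differ.

```python
from typing import Dict, Set, Union

Number = Union[int, float]


def compute_f1_sketch(reported: Set[str], planted: Set[str]) -> Dict[str, Number]:
    """Set-based precision, recall and F1 over issue keys (illustrative only)."""
    # Convention taken from test_no_reported_no_planted:
    # nothing planted and nothing reported counts as a perfect score.
    if not reported and not planted:
        return {"tp": 0, "fp": 0, "fn": 0, "precision": 1.0, "recall": 1.0, "f1": 1.0}

    tp = len(reported & planted)   # reported and actually planted
    fp = len(reported - planted)   # reported but not planted
    fn = len(planted - reported)   # planted but missed

    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0

    return {"tp": tp, "fp": fp, "fn": fn, "precision": precision, "recall": recall, "f1": f1}
```

Guarding each division keeps the all-false-positive and nothing-reported cases at an F1 of 0.0 instead of raising `ZeroDivisionError`.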
tests/test_extensibility.py ADDED
@@ -0,0 +1,215 @@
1
+ """Tests for the extensibility API — custom tasks and contamination rules."""
2
+
3
+ import pytest
4
+ from dataqa_env.server.tasks import (
5
+ PlantedIssue,
6
+ create_task_from_config,
7
+ register_task,
8
+ register_contamination_rule,
9
+ CONTAMINATION_RULES,
10
+ get_task,
11
+ list_tasks,
12
+ )
13
+ from dataqa_env.server.environment import DataQAEnvironment, compute_weighted_reward
14
+ from dataqa_env.models import DataQAAction
15
+
16
+
17
+ SIMPLE_CSV = "id,name,score\n1,Alice,95\n2,Bob,87\n3,Carol,92\n4,Dave,78"
18
+
19
+
20
+ class TestCreateTaskFromConfig:
21
+ def test_basic_creation(self):
22
+ task = create_task_from_config(
23
+ task_id="test_custom",
24
+ name="Test Task",
25
+ description="Test",
26
+ schema_description="id: int, name: str, score: int",
27
+ validation_rules="No missing values",
28
+ clean_csv=SIMPLE_CSV,
29
+ contaminations=[
30
+ {"rule": "missing_value", "row": 0, "col": 1},
31
+ ],
32
+ )
33
+ assert task.task_id == "test_custom"
34
+ assert len(task.planted_issues) == 1
35
+ assert task.planted_issues[0].issue_type == "missing_value"
36
+ assert task.planted_issues[0].col == "name"
37
+
38
+ def test_multiple_contaminations(self):
39
+ task = create_task_from_config(
40
+ task_id="multi",
41
+ name="Multi",
42
+ description="Test",
43
+ schema_description="",
44
+ validation_rules="",
45
+ clean_csv=SIMPLE_CSV,
46
+ contaminations=[
47
+ {"rule": "missing_value", "row": 0, "col": 1},
48
+ {"rule": "missing_value", "row": 2, "col": 1},
49
+ ],
50
+ )
51
+ assert len(task.planted_issues) == 2
52
+
53
+ def test_custom_difficulty_override(self):
54
+ task = create_task_from_config(
55
+ task_id="custom_diff",
56
+ name="Custom Difficulty",
57
+ description="Test",
58
+ schema_description="",
59
+ validation_rules="",
60
+ clean_csv=SIMPLE_CSV,
61
+ contaminations=[
62
+ {"rule": "missing_value", "row": 0, "col": 1, "difficulty": 2.5},
63
+ ],
64
+ )
65
+ assert task.planted_issues[0].difficulty == 2.5
66
+
67
+ def test_callable_rule(self):
68
+ def custom_rule(rows, header, col_idx, row_idx, rng):
69
+ return "CORRUPTED", PlantedIssue(
70
+ row=row_idx + 1, col=header[col_idx], issue_type="wrong_type",
71
+ description="Custom corruption", difficulty=1.5,
72
+ )
73
+
74
+ task = create_task_from_config(
75
+ task_id="callable",
76
+ name="Callable Rule",
77
+ description="Test",
78
+ schema_description="",
79
+ validation_rules="",
80
+ clean_csv=SIMPLE_CSV,
81
+ contaminations=[
82
+ {"rule": custom_rule, "row": 1, "col": 2},
83
+ ],
84
+ )
85
+ assert task.planted_issues[0].issue_type == "wrong_type"
86
+ assert "CORRUPTED" in task.corrupted_csv
87
+
88
+ def test_unknown_rule_raises(self):
89
+ with pytest.raises(ValueError, match="Unknown contamination rule"):
90
+ create_task_from_config(
91
+ task_id="bad",
92
+ name="Bad",
93
+ description="",
94
+ schema_description="",
95
+ validation_rules="",
96
+ clean_csv=SIMPLE_CSV,
97
+ contaminations=[{"rule": "nonexistent_rule", "row": 0, "col": 0}],
98
+ )
99
+
100
+
101
+ class TestRegisterContaminationRule:
102
+ def test_register_and_use(self):
103
+ def reverse_value(rows, header, col_idx, row_idx, rng):
104
+ val = rows[row_idx][col_idx]
105
+ return val[::-1], PlantedIssue(
106
+ row=row_idx + 1, col=header[col_idx], issue_type="format_violation",
107
+ description="Reversed value", difficulty=1.5,
108
+ )
109
+
110
+ register_contamination_rule("reverse", reverse_value)
111
+ assert "reverse" in CONTAMINATION_RULES
112
+
113
+ task = create_task_from_config(
114
+ task_id="rev_test",
115
+ name="Reverse Test",
116
+ description="",
117
+ schema_description="",
118
+ validation_rules="",
119
+ clean_csv=SIMPLE_CSV,
120
+ contaminations=[{"rule": "reverse", "row": 0, "col": 1}],
121
+ )
122
+ assert task.planted_issues[0].issue_type == "format_violation"
123
+ # "Alice" reversed is "ecilA"
124
+ assert "ecilA" in task.corrupted_csv
125
+
126
+ # Cleanup
127
+ del CONTAMINATION_RULES["reverse"]
128
+
129
+
130
+ class TestRegisterTask:
131
+ def test_register_and_get(self):
132
+ task = create_task_from_config(
133
+ task_id="registered",
134
+ name="Registered Task",
135
+ description="Test registered task",
136
+ schema_description="id: int, name: str",
137
+ validation_rules="No missing values",
138
+ clean_csv=SIMPLE_CSV,
139
+ contaminations=[{"rule": "missing_value", "row": 1, "col": 1}],
140
+ )
141
+ register_task("registered", lambda seed: task)
142
+ assert "registered" in list_tasks()
143
+
144
+ fetched = get_task("registered")
145
+ assert fetched.task_id == "registered"
146
+ assert len(fetched.planted_issues) == 1
147
+
148
+ # Cleanup
149
+ from dataqa_env.server.tasks import TASK_REGISTRY
150
+ del TASK_REGISTRY["registered"]
151
+
152
+
153
+ class TestCustomTaskInEnvironment:
154
+ def test_full_lifecycle_identify_only(self):
155
+ """Custom task works end-to-end with identify-only."""
156
+ task = create_task_from_config(
157
+ task_id="e2e_custom",
158
+ name="E2E Custom",
159
+ description="End-to-end test",
160
+ schema_description="id: int, name: str, score: int",
161
+ validation_rules="No missing values",
162
+ clean_csv=SIMPLE_CSV,
163
+ contaminations=[
164
+ {"rule": "missing_value", "row": 0, "col": 1, "difficulty": 1.0},
165
+ {"rule": "whitespace_value", "row": 2, "col": 1, "difficulty": 2.5},
166
+ ],
167
+ )
168
+ register_task("e2e_custom", lambda seed: task)
169
+
170
+ env = DataQAEnvironment()
171
+ obs = env.reset(task_id="e2e_custom")
172
+ assert obs.num_issues_hint == 2
173
+
174
+ action = DataQAAction(
175
+ issues=[i.to_key() for i in task.planted_issues],
176
+ task_id="e2e_custom",
177
+ )
178
+ obs = env.step(action)
179
+ assert obs.done is True
180
+ assert obs.reward >= 0.999
181
+
182
+ from dataqa_env.server.tasks import TASK_REGISTRY
183
+ del TASK_REGISTRY["e2e_custom"]
184
+
185
+ def test_full_lifecycle_identify_and_fix(self):
186
+ """Custom task works end-to-end with both identify and fix."""
187
+ task = create_task_from_config(
188
+ task_id="e2e_fix",
189
+ name="E2E Fix",
190
+ description="End-to-end test with fixes",
191
+ schema_description="id: int, name: str, score: int",
192
+ validation_rules="No missing values",
193
+ clean_csv=SIMPLE_CSV,
194
+ contaminations=[
195
+ {"rule": "missing_value", "row": 0, "col": 1, "difficulty": 1.0},
196
+ ],
197
+ )
198
+ register_task("e2e_fix", lambda seed: task)
199
+
200
+ env = DataQAEnvironment()
201
+ env.reset(task_id="e2e_fix")
202
+
203
+ # Submit issues + fix
204
+ action = DataQAAction(
205
+ issues=[task.planted_issues[0].to_key()],
206
+ fixes=["row:1,col:name,fix:Alice"], # clean value is "Alice"
207
+ task_id="e2e_fix",
208
+ )
209
+ obs = env.step(action)
210
+ assert obs.done is True
211
+ assert obs.metadata["fix_score"] > 0.0
212
+ assert obs.metadata["combined_reward"] > 0.0
213
+
214
+ from dataqa_env.server.tasks import TASK_REGISTRY
215
+ del TASK_REGISTRY["e2e_fix"]
tests/test_inference.py ADDED
@@ -0,0 +1,191 @@
1
+ """Tests for the inference script's parsing, prompt building, and log format."""
2
+
3
+ import pytest
4
+ import sys
5
+ import os
6
+
7
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
8
+ from inference import parse_llm_response, parse_fix_response, build_user_prompt, log_start, log_step, log_end
9
+
10
+
11
+ class TestParseLLMResponse:
12
+ def test_standard_format(self):
13
+ response = "row:1,col:name,issue:missing_value\nrow:2,col:salary,issue:wrong_type"
14
+ issues = parse_llm_response(response)
15
+ assert len(issues) == 2
16
+ assert "row:1,col:name,issue:missing_value" in issues
17
+
18
+ def test_numbered_list(self):
19
+ response = "1. row:1,col:name,issue:missing_value\n2. row:2,col:salary,issue:wrong_type"
20
+ issues = parse_llm_response(response)
21
+ assert len(issues) == 2
22
+
23
+ def test_bullet_list(self):
24
+ response = "- row:1,col:name,issue:missing_value\n* row:2,col:salary,issue:wrong_type"
25
+ issues = parse_llm_response(response)
26
+ assert len(issues) == 2
27
+
28
+ def test_equals_delimiter(self):
29
+ response = "row=1,col=name,issue=missing_value"
30
+ issues = parse_llm_response(response)
31
+ assert len(issues) == 1
32
+ assert issues[0] == "row:1,col:name,issue:missing_value"
33
+
34
+ def test_mixed_case(self):
35
+ response = "Row:1,Col:Name,Issue:Missing_Value"
36
+ issues = parse_llm_response(response)
37
+ assert len(issues) == 1
38
+ assert issues[0] == "row:1,col:name,issue:missing_value"
39
+
40
+ def test_empty_response(self):
41
+ assert parse_llm_response("") == []
42
+ assert parse_llm_response(" ") == []
43
+
44
+ def test_garbage_lines_skipped(self):
45
+ response = "Here are the issues:\nrow:1,col:name,issue:missing_value\nNo more issues."
46
+ issues = parse_llm_response(response)
47
+ assert len(issues) == 1
48
+
49
+ def test_deduplication_not_applied(self):
50
+ response = "row:1,col:name,issue:missing_value\nrow:1,col:name,issue:missing_value"
51
+ issues = parse_llm_response(response)
52
+ assert len(issues) == 2
53
+
54
+ def test_with_column_variant(self):
55
+ response = "row:1,column:name,issue:missing_value"
56
+ issues = parse_llm_response(response)
57
+ assert len(issues) == 1
58
+
59
+
60
+ class TestParseFixResponse:
61
+ def test_standard_format(self):
62
+ response = "row:4,col:name,fix:David Kim\nrow:7,col:salary,fix:75000"
63
+ fixes = parse_fix_response(response)
64
+ assert len(fixes) == 2
65
+ assert "row:4,col:name,fix:David Kim" in fixes
66
+
67
+ def test_numbered_list(self):
68
+ response = "1. row:4,col:name,fix:David Kim\n2. row:7,col:salary,fix:75000"
69
+ fixes = parse_fix_response(response)
70
+ assert len(fixes) == 2
71
+
72
+ def test_with_special_chars(self):
73
+ response = "row:1,col:email,fix:alice.chen@company.com"
74
+ fixes = parse_fix_response(response)
75
+ assert len(fixes) == 1
76
+ assert "alice.chen@company.com" in fixes[0]
77
+
78
+ def test_empty_response(self):
79
+ assert parse_fix_response("") == []
80
+
81
+ def test_date_fix(self):
82
+ response = "row:12,col:order_date,fix:2024-01-26"
83
+ fixes = parse_fix_response(response)
84
+ assert len(fixes) == 1
85
+
86
+ def test_ignores_issue_lines(self):
87
+ response = "row:4,col:name,issue:missing_value\nrow:4,col:name,fix:David Kim"
88
+ fixes = parse_fix_response(response)
89
+ assert len(fixes) == 1 # only the fix line
90
+
91
+
92
+ class TestBuildUserPrompt:
+     def test_includes_all_fields(self):
+         obs = {
+             "task_description": "Find issues",
+             "schema_description": "col: int",
+             "validation_rules": "no nulls",
+             "dataset_csv": "a,b\n1,2",
+             "num_issues_hint": 3,
+             "feedback": "",
+         }
+         prompt = build_user_prompt(obs)
+         assert "Find issues" in prompt
+         assert "col: int" in prompt
+         assert "no nulls" in prompt
+         assert "a,b" in prompt
+         assert "3 issues" in prompt
+
+     def test_includes_feedback_on_retry(self):
+         obs = {
+             "task_description": "Find issues",
+             "schema_description": "",
+             "validation_rules": "",
+             "dataset_csv": "a\n1",
+             "num_issues_hint": 0,
+             "feedback": "Step 1/3: You missed 2 issues",
+         }
+         prompt = build_user_prompt(obs)
+         assert "FEEDBACK" in prompt
+         assert "missed 2" in prompt
+
+     def test_excludes_reset_feedback(self):
+         obs = {
+             "task_description": "",
+             "schema_description": "",
+             "validation_rules": "",
+             "dataset_csv": "",
+             "num_issues_hint": 0,
+             "feedback": "Environment reset. Start inspecting.",
+         }
+         prompt = build_user_prompt(obs)
+         assert "FEEDBACK" not in prompt
+
+     def test_include_fixes_flag(self):
+         obs = {
+             "task_description": "Find issues",
+             "schema_description": "",
+             "validation_rules": "",
+             "dataset_csv": "a\n1",
+             "num_issues_hint": 0,
+             "feedback": "",
+         }
+         prompt = build_user_prompt(obs, include_fixes=True)
+         assert "fix" in prompt.lower()
+
+
+ class TestLogFormat:
+     """Verify stdout log format matches hackathon evaluation requirements."""
+
+     def test_log_start_format(self, capsys):
+         log_start(task="easy", env="dataqa_env", model="test-model")
+         out = capsys.readouterr().out.strip()
+         assert out == "[START] task=easy env=dataqa_env model=test-model"
+
+     def test_log_step_format(self, capsys):
+         log_step(step=1, action="row:1,col:name,issue:missing_value", reward=0.50, done=False, error=None)
+         out = capsys.readouterr().out.strip()
+         assert out == "[STEP] step=1 action=row:1,col:name,issue:missing_value reward=0.50 done=false error=null"
+
+     def test_log_step_with_error(self, capsys):
+         log_step(step=2, action="none", reward=0.00, done=True, error="timeout")
+         out = capsys.readouterr().out.strip()
+         assert "error=timeout" in out
+         assert "done=true" in out
+
+     def test_log_end_format(self, capsys):
+         log_end(success=True, steps=3, score=0.85, rewards=[0.25, 0.50, 0.85])
+         out = capsys.readouterr().out.strip()
+         assert out == "[END] success=true steps=3 score=0.850 rewards=0.25,0.50,0.85"
+
+     def test_log_end_failure(self, capsys):
+         log_end(success=False, steps=1, score=0.0, rewards=[0.0])
+         out = capsys.readouterr().out.strip()
+         assert "success=false" in out
+         assert "score=0.000" in out
+
+     def test_reward_format_2_decimal(self, capsys):
+         log_step(step=1, action="test", reward=0.123456, done=False, error=None)
+         out = capsys.readouterr().out.strip()
+         assert "reward=0.12" in out
+
+     def test_no_newlines_within_line(self, capsys):
+         log_start(task="easy", env="dataqa_env", model="model")
+         log_step(step=1, action="act", reward=0.0, done=False, error=None)
+         log_end(success=False, steps=1, score=0.0, rewards=[0.0])
+         out = capsys.readouterr().out
+         lines = [l for l in out.split("\n") if l.strip()]
+         assert len(lines) == 3
+         assert lines[0].startswith("[START]")
+         assert lines[1].startswith("[STEP]")
+         assert lines[2].startswith("[END]")
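The `TestLogFormat` cases above pin down an exact stdout format. A minimal sketch of helpers that would satisfy those assertions follows. It is inferred from the tests only, not taken from the repository's actual `log_start`, `log_step`, and `log_end`; the `format_*` function names are hypothetical and exist here so the strings can be checked without capturing stdout.

```python
from typing import List, Optional


# Hypothetical helpers inferred from the test assertions, not the repository's code.
def format_start(task: str, env: str, model: str) -> str:
    return f"[START] task={task} env={env} model={model}"


def format_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> str:
    # Booleans are lowercased, a missing error is printed as "null",
    # and the reward is fixed at two decimal places.
    error_text = error if error is not None else "null"
    return (
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={str(done).lower()} error={error_text}"
    )


def format_end(success: bool, steps: int, score: float, rewards: List[float]) -> str:
    # The score uses three decimals; per-step rewards use two and are comma-joined.
    rewards_text = ",".join(f"{r:.2f}" for r in rewards)
    return (
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={rewards_text}"
    )


def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    # One record per line; flush so an external evaluator sees it immediately.
    print(format_step(step, action, reward, done, error), flush=True)
```

Splitting formatting from printing is a design choice for this sketch: the string builders can be asserted on directly, while the printing wrappers stay trivial.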
tests/test_tasks.py ADDED
@@ -0,0 +1,162 @@
+ """Tests for task definitions, data corruption, and issue planting."""
+
+ import pytest
+ from dataqa_env.server.tasks import (
+     PlantedIssue,
+     Task,
+     create_task_easy,
+     create_task_medium,
+     create_task_hard,
+     get_task,
+     list_tasks,
+     _csv_to_rows,
+     _rows_to_csv,
+ )
+
+
+ class TestPlantedIssue:
+     def test_to_key(self):
+         issue = PlantedIssue(row=3, col="salary", issue_type="missing_value", description="test")
+         assert issue.to_key() == "row:3,col:salary,issue:missing_value"
+
+     def test_difficulty_default(self):
+         issue = PlantedIssue(row=1, col="name", issue_type="missing_value", description="test")
+         assert issue.difficulty == 1.0
+
+     def test_difficulty_custom(self):
+         issue = PlantedIssue(row=1, col="name", issue_type="missing_value", description="test", difficulty=3.0)
+         assert issue.difficulty == 3.0
+
+
+ class TestCSVHelpers:
+     def test_roundtrip(self):
+         csv_text = "a,b,c\n1,2,3\n4,5,6"
+         rows = _csv_to_rows(csv_text)
+         assert len(rows) == 3
+         result = _rows_to_csv(rows)
+         assert "1,2,3" in result
+
+     def test_empty_csv(self):
+         rows = _csv_to_rows("a,b\n")
+         assert len(rows) == 1  # header only
+
+
+ class TestTaskEasy:
+     @pytest.fixture
+     def task(self):
+         return create_task_easy()
+
+     def test_task_id(self, task):
+         assert task.task_id == "easy"
+
+     def test_has_6_issues(self, task):
+         assert len(task.planted_issues) == 6
+
+     def test_issue_types(self, task):
+         types = {i.issue_type for i in task.planted_issues}
+         assert "missing_value" in types
+         assert "wrong_type" in types
+         assert "duplicate_row" in types
+         assert "out_of_range" in types
+         assert "inconsistent_value" in types
+
+     def test_corrupted_csv_differs_from_clean(self, task):
+         assert task.corrupted_csv != task.clean_csv
+
+     def test_issue_keys_unique(self, task):
+         keys = [i.to_key() for i in task.planted_issues]
+         assert len(keys) == len(set(keys))
+
+     def test_max_steps(self, task):
+         assert task.max_steps == 3
+
+     def test_corrupted_csv_has_more_rows(self, task):
+         clean_rows = _csv_to_rows(task.clean_csv)
+         corrupt_rows = _csv_to_rows(task.corrupted_csv)
+         assert len(corrupt_rows) > len(clean_rows)  # duplicate row added
+
+     def test_difficulty_weights(self, task):
+         for issue in task.planted_issues:
+             assert 1.0 <= issue.difficulty <= 3.0
+
+
+ class TestTaskMedium:
+     @pytest.fixture
+     def task(self):
+         return create_task_medium()
+
+     def test_task_id(self, task):
+         assert task.task_id == "medium"
+
+     def test_has_8_issues(self, task):
+         assert len(task.planted_issues) == 8
+
+     def test_issue_types(self, task):
+         types = {i.issue_type for i in task.planted_issues}
+         assert "inconsistent_value" in types
+         assert "format_violation" in types
+         assert "missing_value" in types
+
+     def test_issue_keys_unique(self, task):
+         keys = [i.to_key() for i in task.planted_issues]
+         assert len(keys) == len(set(keys))
+
+     def test_difficulty_weights(self, task):
+         for issue in task.planted_issues:
+             assert 1.0 <= issue.difficulty <= 3.0
+
+
+ class TestTaskHard:
+     @pytest.fixture
+     def task(self):
+         return create_task_hard()
+
+     def test_task_id(self, task):
+         assert task.task_id == "hard"
+
+     def test_has_10_issues(self, task):
+         assert len(task.planted_issues) == 10
+
+     def test_issue_types(self, task):
+         types = {i.issue_type for i in task.planted_issues}
+         assert "inconsistent_value" in types
+         assert "format_violation" in types
+         assert "statistical_outlier" in types
+         assert "out_of_range" in types
+         assert "missing_value" in types
+
+     def test_has_high_difficulty_issues(self, task):
+         hard_issues = [i for i in task.planted_issues if i.difficulty >= 2.5]
+         assert len(hard_issues) >= 2  # data leakage, GPU outlier, whitespace
+
+     def test_issue_keys_unique(self, task):
+         keys = [i.to_key() for i in task.planted_issues]
+         assert len(keys) == len(set(keys))
+
+
+ class TestTaskRegistry:
+     def test_list_tasks(self):
+         tasks = list_tasks()
+         assert set(tasks) == {"easy", "medium", "hard"}
+
+     def test_get_task_easy(self):
+         task = get_task("easy")
+         assert task.task_id == "easy"
+
+     def test_get_task_medium(self):
+         task = get_task("medium")
+         assert task.task_id == "medium"
+
+     def test_get_task_hard(self):
+         task = get_task("hard")
+         assert task.task_id == "hard"
+
+     def test_get_task_unknown_raises(self):
+         with pytest.raises(ValueError, match="Unknown task"):
+             get_task("nonexistent")
+
+     def test_seed_determinism(self):
+         t1 = get_task("easy", seed=42)
+         t2 = get_task("easy", seed=42)
+         assert t1.corrupted_csv == t2.corrupted_csv
+         assert [i.to_key() for i in t1.planted_issues] == [i.to_key() for i in t2.planted_issues]
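For readers without `dataqa_env/server/tasks.py` at hand, the `TestPlantedIssue` cases imply a small record type with a default difficulty and a canonical key. The following is a rough sketch under that assumption. The class name `PlantedIssueSketch` is hypothetical, and the real `PlantedIssue` may be defined differently, for example as a Pydantic model or with extra fields.

```python
from dataclasses import dataclass


@dataclass
class PlantedIssueSketch:
    """Hypothetical stand-in for PlantedIssue, inferred only from the tests."""

    row: int
    col: str
    issue_type: str
    description: str
    difficulty: float = 1.0  # the tests expect 1.0 when no weight is given

    def to_key(self) -> str:
        # Canonical "row:<n>,col:<name>,issue:<type>" form used for matching.
        return f"row:{self.row},col:{self.col},issue:{self.issue_type}"


issue = PlantedIssueSketch(row=3, col="salary", issue_type="missing_value", description="test")
key = issue.to_key()
```

Because the key omits `description` and `difficulty`, two issues planted in the same cell with the same type would collide, which is presumably what the `test_issue_keys_unique` cases guard against.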