Varshith Bathini committed on
Commit
ca01572
·
unverified ·
2 Parent(s): 4c1a85d85257bc

Merge pull request #1 from varshith15/enhancementsv1

.gitignore CHANGED
@@ -6,4 +6,8 @@ build/
6
  .venv/
7
  *.egg
8
  .env
 
 
9
  uv.lock
 
 
 
6
  .venv/
7
  *.egg
8
  .env
9
+ .claude/
10
+ .pytest_cache/
11
  uv.lock
12
+ *.mov
13
+ docs/*.png
Dockerfile CHANGED
@@ -26,10 +26,10 @@ RUN uv sync --no-editable 2>/dev/null || pip install -e .
26
  ENV PATH="/app/.venv/bin:$PATH"
27
  ENV PYTHONPATH="/app:$PYTHONPATH"
28
 
29
- # Health check
30
  HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
31
- CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
32
 
33
- EXPOSE 8000
34
 
35
- CMD ["uvicorn", "dataqa_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
 
26
  ENV PATH="/app/.venv/bin:$PATH"
27
  ENV PYTHONPATH="/app:$PYTHONPATH"
28
 
29
+ # Health check — HF Spaces uses port 7860
30
  HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
31
+ CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:7860/health')" || exit 1
32
 
33
+ EXPOSE 7860
34
 
35
+ CMD ["uvicorn", "dataqa_env.server.app:app", "--host", "0.0.0.0", "--port", "7860"]
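The `HEALTHCHECK` above probes `/health` on port 7860 with a one-line `urllib` call. As a minimal sketch, the same probe can be written as a reusable function for local smoke tests. The helper name `is_healthy` is an assumption for illustration and is not part of the repository.

```python
import urllib.error
import urllib.request


def is_healthy(url: str, timeout: float = 3.0) -> bool:
    """Return True if a GET on `url` answers with an HTTP 2xx status within `timeout` seconds."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response:
            return 200 <= response.status < 300
    except (urllib.error.URLError, OSError, ValueError):
        # Covers connection refused, timeouts, non-2xx responses (HTTPError),
        # and malformed URLs (ValueError).
        return False


if __name__ == "__main__":
    # Port 7860 matches the EXPOSE/CMD lines in the Dockerfile diff above.
    print(is_healthy("http://localhost:7860/health"))
```

Returning a boolean instead of raising makes the helper easy to poll in a retry loop, which roughly mirrors what Docker does with `--retries=3`.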
README.md CHANGED
@@ -1,65 +1,237 @@
1
  ---
2
  title: DataQA Environment Server
3
- emoji: 🔍
4
  colorFrom: blue
5
  colorTo: gray
6
  sdk: docker
7
  pinned: false
8
- app_port: 8000
9
- base_path: /web
10
  tags:
11
  - openenv
12
  ---
13
 
14
  # DataQA Environment
15
 
16
- An OpenEnv environment for **Data Quality Assurance** — an LLM agent inspects datasets with planted quality issues and must identify them all.
17
 
18
- ## Overview
19
 
20
- DataQA simulates the real-world task of validating datasets before they enter ML training pipelines or production databases. The agent receives a corrupted dataset along with its schema and validation rules, then must identify all planted data quality issues.
21
 
22
- ### Why Data QA?
23
 
24
- Every ML engineer and data scientist spends significant time debugging data quality issues — missing values, type mismatches, inconsistencies, and subtle statistical anomalies. This environment turns that task into a structured, gradable challenge.
 
 
 
 
 
 
25
 
26
  ## Environment API
27
 
28
- | Endpoint | Description |
29
- |----------|-------------|
30
- | `reset(task_id)` | Start a new episode with a corrupted dataset |
31
- | `step(issues)` | Submit identified issues, receive F1-scored feedback |
32
- | `state()` | Get current episode state |
 
33
 
34
  ## Tasks
35
 
36
- | Task | Issues | Difficulty | Description |
37
- |------|--------|-----------|-------------|
38
- | `easy` | 4 | Beginner | Employee directory nulls, wrong types, duplicates, out-of-range |
39
- | `medium` | 6 | Intermediate | E-commerce orders format violations, inconsistent totals, duplicate keys |
40
- | `hard` | 8 | Advanced | ML experiment metadata data leakage signals, unreasonable GPU usage, timestamp ordering |
41
 
42
  ## Reward Function
43
 
44
- Scoring uses **F1 score** (harmonic mean of precision and recall):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- - **Precision**: What fraction of reported issues are real?
47
- - **Recall**: What fraction of planted issues did the agent find?
48
- - **F1**: `2 * precision * recall / (precision + recall)`
49
 
50
- Issues are matched by `row:<N>,col:<column>,issue:<type>` keys.
51
 
52
- The agent gets up to 3 attempts per task with feedback on each attempt (true positives, false positives, missed count).
 
 
 
 
 
53
 
54
- ## Action/Observation Space
55
 
56
- **Action**: List of issue strings in format `row:<row_number>,col:<column_name>,issue:<issue_type>`
57
 
58
- **Observation**: Dataset CSV + schema + validation rules + feedback from previous attempt
 
 
 
 
59
 
60
- **Issue Types**: `missing_value`, `wrong_type`, `duplicate_row`, `out_of_range`, `format_violation`, `inconsistent_value`, `statistical_outlier`, `referential_integrity`
61
 
62
- ## Quick Start
63
 
64
  ```bash
65
  # Install
@@ -68,42 +240,76 @@ pip install -e .
68
  # Run server locally
69
  uvicorn dataqa_env.server.app:app --host 0.0.0.0 --port 8000
70
 
71
- # Run inference
72
- API_BASE_URL=https://api.groq.com/openai/v1 \
73
- MODEL_NAME=llama-3.3-70b-versatile \
74
- LLM_API_KEY=your-key \
75
  python inference.py
76
  ```
77
 
78
  ## Docker
79
 
80
  ```bash
81
- docker build -t dataqa-env -f dataqa_env/server/Dockerfile .
82
  docker run -p 8000:8000 dataqa-env
83
  ```
84
 
85
  ## Environment Variables
86
 
87
  | Variable | Description | Default |
88
  |----------|-------------|---------|
89
- | `API_BASE_URL` | LLM API endpoint | `https://api.groq.com/openai/v1` |
90
- | `MODEL_NAME` | Model identifier | `llama-3.3-70b-versatile` |
91
- | `HF_TOKEN` | HuggingFace token | - |
92
  | `ENV_URL` | Environment server URL | `http://localhost:8000` |
93
- | `LLM_API_KEY` | API key for LLM provider | Falls back to HF_TOKEN |
94
 
95
  ## Architecture
96
 
97
  ```
98
  dataqa_env/
99
- ├── models.py # Pydantic: DataQAAction, DataQAObservation, DataQAState
 
100
  ├── client.py # EnvClient for WebSocket connections
101
  ├── server/
102
- │ ├── environment.py # Core DataQAEnvironment (reset/step/state)
103
- │ ├── tasks.py # Task definitions + data corruption + grading
104
- │ ├── app.py # FastAPI server
105
  │ └── Dockerfile
106
- ├── openenv.yaml
107
- ├── pyproject.toml
108
- └── inference.py # LLM agent using OpenAI client
 
 
 
 
 
 
109
  ```
 
1
  ---
2
  title: DataQA Environment Server
3
+ emoji: "\U0001F50D"
4
  colorFrom: blue
5
  colorTo: gray
6
  sdk: docker
7
  pinned: false
8
+ app_port: 7860
 
9
  tags:
10
  - openenv
11
  ---
12
 
13
  # DataQA Environment
14
 
15
+ A two-phase OpenEnv RL environment for **Data Quality Assurance** — an LLM agent inspects corrupted datasets, identifies all planted quality issues, and proposes data repairs.
16
 
17
+ ### Demo: Agent Trajectory Replay
18
 
19
+ ```
20
+ EASY TASK (Step 2) — All 6 issues found + 5 fixes proposed
21
+ Reward: 0.87 | Identify: 1.00 | Fix: 0.67
22
+ ✓ row:4 name: empty → "David Kim" (fix correct)
23
+ ✓ row:7 salary: "seventy-five thousand" → "75000" (fix correct)
24
+ ✓ row:9 salary: "5000" → "73000" (fix correct)
25
+ ✓ row:15 email: mismatch → "oscar.rivera@company.com" (fix correct)
26
+ ✓ row:18 start_date: "2027-06-15" → "2022-01-19" (fix correct)
27
+ ✓ row:21 duplicate row detected
28
+
29
+ HARD TASK (Step 1 → Step 2)
30
+ Step 1: Found 5/10, missed hard issues → Reward: 0.69
31
+ Step 2: Found 10/10 + 5 fixes proposed → Reward: 0.77
32
+ Issues requiring ML knowledge:
33
+ • val_loss < train_loss (data leakage signal)
34
+ • resnet18 using 42.5GB GPU (impossible)
35
+ • 350 epochs on ImageNet in 30 min (impossible)
36
+ • wav2vec2 at 98.5% accuracy (exceeds SOTA)
37
+ ```
38
+
39
+ > The interactive replay UI with color-coded dataset visualization is available on the HF Space.
40
 
41
+ ## Motivation
42
 
43
+ Every ML engineer and data scientist spends significant time debugging data quality issues — missing values, type mismatches, logical inconsistencies, and subtle statistical anomalies — before data enters ML pipelines or production databases. This is a genuine, high-frequency human task that directly impacts model quality and business outcomes.
44
+
45
+ DataQA turns this into a **two-phase RL challenge**:
46
+ 1. **Identify** — systematically inspect corrupted data and pinpoint every planted issue
47
+ 2. **Fix** — propose corrected values by reasoning about schema, constraints, and context
48
+
49
+ This creates a rich multi-step decision problem where agents must explore datasets strategically, distinguish subtle anomalies from noise, and reason about what the correct data should be.
50
 
51
  ## Environment API
52
 
53
+ | Endpoint | Method | Description |
54
+ |----------|--------|-------------|
55
+ | `/reset` | POST | Start a new episode with a corrupted dataset |
56
+ | `/step` | POST | Submit identified issues + proposed fixes |
57
+ | `/state` | GET | Get current episode state |
58
+ | `/health` | GET | Health check |
59
 
60
  ## Tasks
61
 
62
+ | Task | Issues | Difficulty | Domain | Description |
63
+ |------|--------|-----------|--------|-------------|
64
+ | `easy` | 6 | Beginner | HR/Employee data (21 rows) | Nulls, wrong types, duplicates, out-of-range, email-name mismatch, future dates |
65
+ | `medium` | 8 | Intermediate | E-commerce orders (31 rows) | Inconsistent totals, invalid categories, duplicate keys, wrong date formats, invalid country codes, future-date deliveries |
66
+ | `hard` | 10 | Advanced | ML experiment metadata (31 rows) | Data leakage signals, unreasonable GPU memory, impossibly fast training, SOTA-exceeding accuracy, timestamp ordering, whitespace-only fields |
67
+
68
+ **Difficulty progression**: Easy issues are individually obvious (empty fields, text in numeric columns). Medium issues require cross-column reasoning (total != qty * price) and set membership checks. Hard issues require ML domain knowledge (val_loss < train_loss = data leakage) and multi-row temporal reasoning.
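The cross-column check mentioned for the medium task (`total != qty * price`) can be sketched as follows. This is an illustration only: the column names `qty`, `price`, and `total`, the tolerance, and the sample rows are assumptions, not the task's actual schema or data.

```python
import csv
import io


def find_inconsistent_totals(csv_text, qty_col="qty", price_col="price",
                             total_col="total", tol=0.01):
    """Return 1-indexed data-row numbers where total differs from qty * price by more than tol."""
    flagged = []
    reader = csv.DictReader(io.StringIO(csv_text))
    for row_number, row in enumerate(reader, start=1):
        try:
            expected = float(row[qty_col]) * float(row[price_col])
            actual = float(row[total_col])
        except (KeyError, TypeError, ValueError):
            # Missing or non-numeric cells fall under other issue types.
            continue
        if abs(expected - actual) > tol:
            flagged.append(row_number)
    return flagged


# Hypothetical sample: only the second data row has a mismatched total.
sample = "order_id,qty,price,total\n1,2,10.00,20.00\n2,3,5.00,18.00\n3,1,7.50,7.50\n"
print(find_inconsistent_totals(sample))
```

Row numbers are 1-indexed and counted after the header, matching the `row_number` convention used for issue strings.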
69
+
70
+ ## Two-Phase Action Space
71
+
72
+ ### Phase 1: Identify Issues
73
+
74
+ Submit issues in format: `row:<row_number>,col:<column_name>,issue:<issue_type>`
75
+
76
+ - `row_number`: 1-indexed data row position (after header)
77
+ - `column_name`: Exact column header name, lowercase
78
+ - `issue_type`: One of the supported types below
79
+
80
+ ### Phase 2: Propose Fixes
81
+
82
+ Submit fixes in format: `row:<row_number>,col:<column_name>,fix:<corrected_value>`
83
+
84
+ The agent proposes the **correct value** that should replace the corrupted data. Fixes are graded against the original clean dataset.
85
+
86
+ Both phases can be submitted in the same step or across multiple steps.
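A minimal sketch of building and parsing these strings. The two `format_*` helpers are hypothetical conveniences; `parse_fix` copies the regular expressions from the `parse_fix` function added to `environment.py` in this commit. The `issues`, `fixes`, and `task_id` keys follow the `DataQAAction` fields.

```python
import re
from typing import Optional, Tuple


def format_issue(row: int, col: str, issue_type: str) -> str:
    """Build a Phase 1 string: row:<row_number>,col:<column_name>,issue:<issue_type>."""
    return f"row:{row},col:{col.lower()},issue:{issue_type}"


def format_fix(row: int, col: str, value: str) -> str:
    """Build a Phase 2 string: row:<row_number>,col:<column_name>,fix:<corrected_value>."""
    return f"row:{row},col:{col.lower()},fix:{value}"


def parse_fix(raw: str) -> Optional[Tuple[int, str, str]]:
    """Parse a fix string into (row, col, value); return None if any part is missing."""
    raw = raw.strip()
    row_match = re.search(r"row\s*[:=]\s*(\d+)", raw, re.IGNORECASE)
    col_match = re.search(r"col(?:umn)?\s*[:=]\s*([\w_]+)", raw, re.IGNORECASE)
    fix_match = re.search(r"fix\s*[:=]\s*(.+?)$", raw, re.IGNORECASE)
    if row_match and col_match and fix_match:
        return (int(row_match.group(1)), col_match.group(1).lower(), fix_match.group(1).strip())
    return None


# Both phases submitted in a single step.
action = {
    "issues": [format_issue(7, "salary", "wrong_type")],
    "fixes": [format_fix(7, "salary", "75000")],
    "task_id": "easy",
}
print(action["issues"][0])
print(parse_fix(action["fixes"][0]))
```

The parser is lenient about separators and casing, so a string such as `Row = 4, Column = name, fix = David Kim` is also accepted.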
87
+
88
+ **Supported Issue Types:**
89
+
90
+ | Type | Description | Example |
91
+ |------|-------------|---------|
92
+ | `missing_value` | Null, empty, or whitespace-only | Empty name field |
93
+ | `wrong_type` | Value doesn't match expected type | Salary as "seventy-five thousand" |
94
+ | `duplicate_row` | Exact duplicate or duplicate key | Two rows with same employee_id |
95
+ | `out_of_range` | Value outside valid range | Salary of 5000 when min is 50000 |
96
+ | `format_violation` | Wrong format or invalid enum | Date as DD/MM/YYYY instead of YYYY-MM-DD |
97
+ | `inconsistent_value` | Computed field mismatch, logical inconsistency | total != qty * price |
98
+ | `statistical_outlier` | Unreasonable value given context | resnet18 using 42.5GB GPU |
99
+ | `referential_integrity` | Foreign key violation | (available for custom tasks) |
100
+
101
+ ## Observation Space
102
+
103
+ | Field | Type | Description |
104
+ |-------|------|-------------|
105
+ | `dataset_csv` | str | The corrupted dataset in CSV format |
106
+ | `schema_description` | str | Column types, ranges, and constraints |
107
+ | `validation_rules` | str | Business rules the data must satisfy |
108
+ | `task_description` | str | Task context and instructions |
109
+ | `feedback` | str | Per-step results: TP/FP/FN, precision/recall, fix scores |
110
+ | `num_issues_hint` | int | Exact count of planted issues |
111
+ | `max_steps` | int | Maximum attempts allowed |
112
+ | `done` | bool | Whether episode has terminated |
113
+ | `reward` | float | Best combined reward so far (0.0-1.0) |
114
+
115
+ **Observation Metadata** (per step):
116
+ - Identify: `identify_f1`, `identify_score`, `precision`, `recall`, `tp`, `fp`, `fn`
117
+ - Fix: `fix_score`, `fixes_correct`, `fixes_partial`, `fixes_wrong`, `fixes_attempted`
118
+ - Combined: `combined_reward`, `difficulty_found`, `difficulty_missed`
119
 
120
  ## Reward Function
121
 
122
+ ### Combined Reward
123
+
124
+ ```
125
+ combined_reward = 0.6 * identify_score + 0.4 * fix_score
126
+ ```
127
+
128
+ If no fixes are submitted, `combined_reward = identify_score` (no penalty — backward compatible).
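As a small worked sketch of the formula, using the weights defined in `environment.py`. Representing "no fixes submitted" as `None` is an assumption made here for illustration; the environment may signal it differently.

```python
from typing import Optional

IDENTIFY_WEIGHT = 0.6
FIX_WEIGHT = 0.4


def combined_reward(identify_score: float, fix_score: Optional[float] = None) -> float:
    """Blend the two phase scores; fall back to identify_score when no fixes were submitted."""
    if fix_score is None:
        return identify_score
    return IDENTIFY_WEIGHT * identify_score + FIX_WEIGHT * fix_score


# Easy-task demo values from the trajectory replay: Identify 1.00, Fix 0.67.
print(round(combined_reward(1.00, 0.67), 2))
# Identify-only submission keeps its identify score.
print(combined_reward(0.80))
```

With the demo values, 0.6 × 1.00 + 0.4 × 0.67 = 0.868, which rounds to the 0.87 shown in the replay.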
129
+
130
+ ### Identify Score (Difficulty-Weighted F1)
131
+
132
+ Each planted issue has a **difficulty weight** (1.0-3.0):
133
+
134
+ | Weight | Category | Examples |
135
+ |--------|----------|----------|
136
+ | 1.0 | Easy | Missing values, obvious out-of-range, wrong type |
137
+ | 1.5-2.0 | Medium | Duplicate keys, format violations, cross-column checks |
138
+ | 2.5-3.0 | Hard | Data leakage, statistical outliers, whitespace-only |
139
+
140
+ - **Weighted Recall** = (difficulty of found issues) / (total difficulty)
141
+ - **Weighted Precision** = penalizes false positives proportional to average difficulty
142
+ - **Weighted F1** = harmonic mean
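The three quantities above can be sketched in a few lines. The arithmetic follows `compute_weighted_reward` from this commit, simplified to take a plain mapping from issue key to difficulty; the keys and weights below are made-up examples.

```python
from typing import Dict, Set


def weighted_f1(reported: Set[str], difficulty_by_key: Dict[str, float]) -> float:
    """Difficulty-weighted F1 over issue keys; false positives cost the average difficulty."""
    planted = set(difficulty_by_key)
    if not planted and not reported:
        return 1.0
    if not planted or not reported:
        return 0.0

    total = sum(difficulty_by_key.values())
    found = sum(difficulty_by_key[key] for key in reported & planted)
    false_positives = len(reported - planted)

    recall = found / total if total > 0 else 0.0
    penalty = false_positives * (total / len(planted))
    precision = found / (found + penalty) if (found + penalty) > 0 else 0.0

    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Hypothetical planted issues with difficulty weights 1.0, 2.0, and 3.0.
planted = {
    "row:4,col:name,issue:missing_value": 1.0,
    "row:21,col:employee_id,issue:duplicate_row": 2.0,
    "row:12,col:val_loss,issue:inconsistent_value": 3.0,
}
easy_only = {"row:4,col:name,issue:missing_value"}
hard_only = {"row:12,col:val_loss,issue:inconsistent_value"}
print(round(weighted_f1(easy_only, planted), 4))
print(round(weighted_f1(hard_only, planted), 4))
```

Reporting only the hardest issue gives recall 3/6 and F1 of about 0.667, while reporting only the easiest gives recall 1/6 and F1 of about 0.286, so finding the subtle issue is worth more.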
143
 
144
+ ### Fix Score (Difficulty-Weighted Quality)
 
 
145
 
146
+ Each proposed fix is compared against the original clean value:
147
 
148
+ | Fix Quality | Score | Description |
149
+ |-------------|-------|-------------|
150
+ | Exact match | 1.0 | Case-insensitive, whitespace-stripped match |
151
+ | Numeric close | 0.8 | Within 1% of correct numeric value |
152
+ | Correct cell | 0.1 | Right location, wrong value |
153
+ | Non-issue cell | 0.0 | Fix targets a cell with no issue |
154
 
155
+ Fix score = (sum of best fix score per issue × difficulty weight) / (total difficulty weight)
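A hedged sketch of how the table and the formula fit together. It is not the repository's `grade_fixes`: the helper names are hypothetical, and "within 1%" is assumed here to mean relative error against a non-zero clean value.

```python
from typing import Dict, List, Tuple

Cell = Tuple[int, str]


def score_fix(proposed: str, clean: str) -> float:
    """Score one fix for a cell that is known to contain a planted issue."""
    if proposed.strip().lower() == clean.strip().lower():
        return 1.0  # exact match (case-insensitive, whitespace-stripped)
    try:
        proposed_num, clean_num = float(proposed), float(clean)
    except ValueError:
        return 0.1  # right cell, wrong non-numeric value
    if clean_num != 0 and abs(proposed_num - clean_num) / abs(clean_num) <= 0.01:
        return 0.8  # numeric close match
    return 0.1


def aggregate_fix_score(fixes: List[Tuple[int, str, str]],
                        clean_values: Dict[Cell, str],
                        difficulty: Dict[Cell, float]) -> float:
    """Sum (best score per issue x difficulty) and divide by the total difficulty."""
    best: Dict[Cell, float] = {}
    for row, col, proposed in fixes:
        cell = (row, col)
        if cell not in clean_values:
            continue  # non-issue cell: earns 0.0
        best[cell] = max(best.get(cell, 0.0), score_fix(proposed, clean_values[cell]))
    total = sum(difficulty.values())
    if total == 0:
        return 0.0
    return sum(score * difficulty[cell] for cell, score in best.items()) / total


# Hypothetical planted cells with their clean values and difficulty weights.
clean_values = {(7, "salary"): "75000", (15, "email"): "oscar.rivera@company.com"}
difficulty = {(7, "salary"): 1.0, (15, "email"): 2.0}
fixes = [
    (7, "salary", "75500"),                      # within 1% of 75000
    (15, "email", "Oscar.Rivera@company.com "),  # exact after normalisation
    (3, "name", "Alice"),                        # not a planted cell
]
print(round(aggregate_fix_score(fixes, clean_values, difficulty), 4))
```

Here the score is (0.8 × 1.0 + 1.0 × 2.0) / 3.0, roughly 0.933; the fix aimed at a clean cell adds nothing.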
156
 
157
+ ### Reward Properties
158
 
159
+ - **Per-step partial progress**: reward increases as more issues are found/fixed
160
+ - **Difficulty-aware**: finding subtle issues earns more than obvious ones
161
+ - **Penalizes bad behavior**: false positives reduce score, fixing non-issues earns nothing
162
+ - **Monotonically non-decreasing**: best score across all steps is the final reward
163
+ - **Always in [0.0, 1.0]**: meets hackathon requirement
164
 
165
+ ### Episode Boundaries
166
 
167
+ - Each task allows up to 3 steps (attempts)
168
+ - Episode ends when F1 >= 0.999 (perfect identification) or max steps reached
169
+ - Agent receives detailed feedback after each step to improve on next attempt
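The termination rule and the best-score tracking can be sketched as a small loop. The function name and the `(identify_f1, combined_reward)` input pairs are assumptions for illustration; in the environment these values come from grading each step.

```python
from typing import Iterable, List, Tuple


def run_episode(step_results: Iterable[Tuple[float, float]],
                max_steps: int = 3,
                perfect_f1: float = 0.999) -> List[Tuple[int, float, bool]]:
    """Return (step, best_reward_so_far, done) for each step until the episode ends."""
    best = 0.0
    history = []
    for step, (identify_f1, reward) in enumerate(step_results, start=1):
        best = max(best, reward)  # reported reward never decreases
        done = identify_f1 >= perfect_f1 or step >= max_steps
        history.append((step, best, done))
        if done:
            break
    return history


# Rewards 0.69 and 0.77 follow the hard-task replay; the step-1 F1 of 0.6 is made up.
print(run_episode([(0.6, 0.69), (1.0, 0.77)]))
# A weaker later attempt does not lower the reported reward.
print(run_episode([(0.5, 0.5), (0.4, 0.3), (0.45, 0.4)]))
```

Taking the maximum over steps is what makes the reward monotonically non-decreasing, so a poor retry cannot undo earlier progress.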
170
+
171
+ ## Baseline Scores
172
+
173
+ Baseline agent uses Qwen2.5-72B-Instruct via HuggingFace Router:
174
+
175
+ | Task | Identify Score | Fix Score | Combined | Notes |
176
+ |------|---------------|-----------|----------|-------|
177
+ | `easy` | 0.7-1.0 | 0.5-0.9 | 0.6-1.0 | Most LLMs find obvious issues reliably |
178
+ | `medium` | 0.5-0.8 | 0.3-0.6 | 0.4-0.7 | Cross-column reasoning challenges models |
179
+ | `hard` | 0.3-0.6 | 0.2-0.4 | 0.3-0.5 | ML domain knowledge and subtle patterns |
180
+
181
+ Scores vary by model. The hard task is designed to challenge frontier models.
182
+
183
+ ## Extensibility
184
+
185
+ ### Custom Contamination Rules
186
+
187
+ ```python
188
+ from dataqa_env import register_contamination_rule
189
+ from dataqa_env.server.tasks import PlantedIssue
190
+
191
+ def swap_digits(rows, header, col_idx, row_idx, rng):
192
+ val = rows[row_idx][col_idx]
193
+ corrupted = val[::-1]
194
+ issue = PlantedIssue(
195
+ row=row_idx + 1, col=header[col_idx],
196
+ issue_type="format_violation",
197
+ description=f"Digits swapped in {header[col_idx]}",
198
+ difficulty=2.0,
199
+ )
200
+ return corrupted, issue
201
+
202
+ register_contamination_rule("swap_digits", swap_digits)
203
+ ```
204
+
205
+ ### Custom Tasks from Config
206
+
207
+ ```python
208
+ from dataqa_env import create_task_from_config, register_task
209
+
210
+ task = create_task_from_config(
211
+ task_id="custom",
212
+ name="Custom Validation",
213
+ description="Find quality issues in this dataset.",
214
+ schema_description="id: int, name: str, score: int (0-100)",
215
+ validation_rules="No missing values. Scores must be 0-100.",
216
+ clean_csv="id,name,score\n1,Alice,95\n2,Bob,87\n3,Carol,92",
217
+ contaminations=[
218
+ {"rule": "missing_value", "row": 0, "col": 1, "difficulty": 1.0},
219
+ {"rule": "negative_value", "row": 2, "col": 2, "difficulty": 1.5},
220
+ ],
221
+ )
222
+ register_task("custom", lambda seed: task)
223
+ ```
224
+
225
+ ### Built-in Contamination Rules
226
+
227
+ | Rule | Effect | Default Difficulty |
228
+ |------|--------|--------------------|
229
+ | `missing_value` | Sets field to empty string | 1.0 |
230
+ | `whitespace_value` | Sets field to single space | 2.5 |
231
+ | `wrong_type_text` | Replaces with random text | 1.0 |
232
+ | `negative_value` | Negates numeric value | 1.0 |
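To show how a table row maps onto the rule signature used in the custom-rule example, here is a hedged sketch of a `whitespace_value`-style rule. The `PlantedIssue` dataclass is a stand-in with the fields seen in that example, not the class from `dataqa_env.server.tasks`, and tagging the result as `missing_value` is an assumption based on the issue-type table.

```python
import random
from dataclasses import dataclass


@dataclass
class PlantedIssue:
    """Hypothetical stand-in for dataqa_env.server.tasks.PlantedIssue."""
    row: int
    col: str
    issue_type: str
    description: str
    difficulty: float


def whitespace_value(rows, header, col_idx, row_idx, rng):
    """Replace the target cell with a single space and describe the planted issue."""
    issue = PlantedIssue(
        row=row_idx + 1,  # issues use 1-indexed data rows
        col=header[col_idx],
        issue_type="missing_value",
        description=f"Whitespace-only value in {header[col_idx]}",
        difficulty=2.5,
    )
    return " ", issue


header = ["id", "name", "score"]
rows = [["1", "Alice", "95"], ["2", "Bob", "87"]]
corrupted, issue = whitespace_value(rows, header, col_idx=1, row_idx=1, rng=random.Random(0))
rows[1][1] = corrupted  # the caller writes the corrupted value back
print(repr(rows[1][1]), issue.row, issue.col, issue.difficulty)
```

The `rng` argument is unused in this rule but kept so the signature matches rules that need randomness, such as `wrong_type_text`.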
233
+
234
+ ## Setup & Quick Start
235
 
236
  ```bash
237
  # Install
 
240
  # Run server locally
241
  uvicorn dataqa_env.server.app:app --host 0.0.0.0 --port 8000
242
 
243
+ # Run inference (set your API credentials)
244
+ API_BASE_URL=https://router.huggingface.co/v1 \
245
+ MODEL_NAME=Qwen/Qwen2.5-72B-Instruct \
246
+ HF_TOKEN=your-token \
247
  python inference.py
248
  ```
249
 
250
  ## Docker
251
 
252
  ```bash
253
+ docker build -t dataqa-env .
254
  docker run -p 8000:8000 dataqa-env
255
  ```
256
 
257
+ ## Testing
258
+
259
+ ```bash
260
+ pip install -e ".[dev]"
261
+ pytest tests/ -v
262
+ ```
263
+
264
+ 118 tests covering:
265
+ - Task creation, corruption, and difficulty weights
266
+ - Issue key and fix parsing (standard, lenient, edge cases)
267
+ - F1, weighted reward, and fix quality computation
268
+ - Full environment lifecycle (identify-only and identify+fix)
269
+ - Combined reward calculation and weight verification
270
+ - Inference script parsing and prompt building
271
+ - Structured log format ([START], [STEP], [END])
272
+ - Score bounds (0.0-1.0), best-score monotonicity
273
+ - Extensibility API (custom rules, custom tasks)
274
+
275
+ ## Validation
276
+
277
+ ```bash
278
+ # OpenEnv spec validation
279
+ openenv validate .
280
+
281
+ # Pre-submission validation (requires HF Space URL)
282
+ ./prevalidation_script.sh https://your-space.hf.space
283
+ ```
284
+
285
  ## Environment Variables
286
 
287
  | Variable | Description | Default |
288
  |----------|-------------|---------|
289
+ | `API_BASE_URL` | LLM API endpoint | `https://router.huggingface.co/v1` |
290
+ | `MODEL_NAME` | Model identifier | `Qwen/Qwen2.5-72B-Instruct` |
291
+ | `HF_TOKEN` | HuggingFace token / API key | - |
292
  | `ENV_URL` | Environment server URL | `http://localhost:8000` |
 
293
 
294
  ## Architecture
295
 
296
  ```
297
  dataqa_env/
298
+ ├── __init__.py # Public API + extensibility exports
299
+ ├── models.py # Pydantic: DataQAAction (issues + fixes), DataQAObservation, DataQAState
300
  ├── client.py # EnvClient for WebSocket connections
301
  ├── server/
302
+ │ ├── environment.py # Two-phase DataQAEnvironment (identify + fix + combined reward)
303
+ │ ├── tasks.py # Task definitions + contamination rules + extensibility API
304
+ │ ├── app.py # FastAPI server (via openenv-core create_app)
305
  │ └── Dockerfile
306
+ tests/
307
+ ├── test_tasks.py # Task creation, corruption, difficulty weights
308
+ ├── test_environment.py # Identify scoring, fix grading, combined reward, lifecycle
309
+ ├── test_inference.py # LLM response parsing, fix parsing, prompt building, log format
310
+ └── test_extensibility.py # Custom rules, custom tasks, registration API
311
+ inference.py # Two-phase baseline agent (identify → fix)
312
+ openenv.yaml # OpenEnv/HF Spaces spec
313
+ pyproject.toml # Package metadata and dependencies
314
+ Dockerfile # Production container
315
  ```
__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from dataqa_env import DataQAEnv, DataQAAction, DataQAObservation, DataQAState
2
-
3
- __all__ = ["DataQAEnv", "DataQAAction", "DataQAObservation", "DataQAState"]
 
 
 
 
client.py DELETED
@@ -1,5 +0,0 @@
1
- """Root-level client for OpenEnv compatibility."""
2
- from dataqa_env.client import DataQAEnv
3
- from dataqa_env.models import DataQAAction, DataQAObservation, DataQAState
4
-
5
- __all__ = ["DataQAEnv", "DataQAAction", "DataQAObservation", "DataQAState"]
 
 
 
 
 
 
dataqa_env/__init__.py CHANGED
@@ -1,4 +1,19 @@
1
  from .client import DataQAEnv
2
  from .models import DataQAAction, DataQAObservation, DataQAState
 
 
 
 
 
 
3
 
4
- __all__ = ["DataQAEnv", "DataQAAction", "DataQAObservation", "DataQAState"]
 
 
 
 
 
 
 
 
 
 
1
  from .client import DataQAEnv
2
  from .models import DataQAAction, DataQAObservation, DataQAState
3
+ from .server.tasks import (
4
+ create_task_from_config,
5
+ register_task,
6
+ register_contamination_rule,
7
+ CONTAMINATION_RULES,
8
+ )
9
 
10
+ __all__ = [
11
+ "DataQAEnv",
12
+ "DataQAAction",
13
+ "DataQAObservation",
14
+ "DataQAState",
15
+ "create_task_from_config",
16
+ "register_task",
17
+ "register_contamination_rule",
18
+ "CONTAMINATION_RULES",
19
+ ]
dataqa_env/models.py CHANGED
@@ -16,21 +16,23 @@ from openenv.core.env_server.interfaces import Action, Observation, State
16
 
17
  class DataQAAction(Action):
18
  """
19
- Agent submits a list of identified data quality issues.
 
 
 
 
 
 
 
20
 
21
- Each issue is a string in the format: "row:<row_idx>,col:<col_name>,issue:<issue_type>"
22
  Supported issue types:
23
- - missing_value
24
- - wrong_type
25
- - duplicate_row
26
- - out_of_range
27
- - format_violation
28
- - inconsistent_value
29
- - statistical_outlier
30
- - referential_integrity
31
  """
32
 
33
  issues: List[str]
 
34
  # Include task_id so step() can reconstruct context in stateless HTTP mode
35
  task_id: str = "easy"
36
 
 
16
 
17
  class DataQAAction(Action):
18
  """
19
+ Agent submits identified issues AND optional proposed fixes.
20
+
21
+ Two-phase action space:
22
+ Phase 1 (Identify): List issues in format "row:<N>,col:<name>,issue:<type>"
23
+ Phase 2 (Fix): List fixes in format "row:<N>,col:<name>,fix:<proposed_value>"
24
+
25
+ The agent can submit both in the same step or across multiple steps.
26
+ Combined reward = 0.6 * identify_score + 0.4 * fix_score
27
 
 
28
  Supported issue types:
29
+ missing_value, wrong_type, duplicate_row, out_of_range,
30
+ format_violation, inconsistent_value, statistical_outlier,
31
+ referential_integrity
 
 
 
 
 
32
  """
33
 
34
  issues: List[str]
35
+ fixes: List[str] = []
36
  # Include task_id so step() can reconstruct context in stateless HTTP mode
37
  task_id: str = "easy"
38
 
dataqa_env/server/app.py CHANGED
@@ -19,9 +19,20 @@ app = create_app(
19
  )
20
 
21
 
 
 
 
 
 
 
 
 
 
 
 
22
  def main():
23
  import uvicorn
24
- uvicorn.run(app, host="0.0.0.0", port=8000)
25
 
26
 
27
  if __name__ == "__main__":
 
19
  )
20
 
21
 
22
+ @app.get("/")
23
+ def root():
24
+ """Root endpoint — environment info."""
25
+ return {
26
+ "name": "DataQA Environment",
27
+ "description": "Two-phase data quality assurance environment: identify issues + propose fixes",
28
+ "tasks": ["easy", "medium", "hard"],
29
+ "endpoints": ["/health", "/reset", "/step", "/state"],
30
+ }
31
+
32
+
33
  def main():
34
  import uvicorn
35
+ uvicorn.run(app, host="0.0.0.0", port=7860)
36
 
37
 
38
  if __name__ == "__main__":
dataqa_env/server/environment.py CHANGED
@@ -3,8 +3,12 @@ DataQA Environment
3
  ------------------
4
  Server-side environment for data quality assurance tasks.
5
 
6
- The agent receives corrupted datasets and must identify planted quality issues.
7
- Scoring is based on F1 (precision-recall) of correctly matched issues.
 
 
 
 
8
  """
9
 
10
  from __future__ import annotations
@@ -18,6 +22,10 @@ from openenv.core.env_server.interfaces import Action, Environment, Observation
18
  from ..models import DataQAAction, DataQAObservation, DataQAState
19
  from .tasks import PlantedIssue, Task, get_task, list_tasks
20
 
 
 
 
 
21
 
22
  def parse_issue_key(raw: str) -> Optional[str]:
23
  """
@@ -26,7 +34,6 @@ def parse_issue_key(raw: str) -> Optional[str]:
26
  Returns normalized key or None if unparseable.
27
  """
28
  raw = raw.strip().lower()
29
- # Be lenient with formatting
30
  row_match = re.search(r"row\s*[:=]\s*(\d+)", raw)
31
  col_match = re.search(r"col\s*[:=]\s*([\w_]+)", raw)
32
  issue_match = re.search(r"issue\s*[:=]\s*([\w_]+)", raw)
@@ -36,6 +43,22 @@ def parse_issue_key(raw: str) -> Optional[str]:
36
  return None
37
 
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  def compute_f1(reported_keys: Set[str], planted_keys: Set[str]) -> dict:
40
  """Compute precision, recall, and F1 score."""
41
  if not reported_keys and not planted_keys:
@@ -58,12 +81,185 @@ def compute_f1(reported_keys: Set[str], planted_keys: Set[str]) -> dict:
58
  return {"precision": precision, "recall": recall, "f1": f1, "tp": tp, "fp": fp, "fn": fn}
59
 
60
 
61
  class DataQAEnvironment(Environment):
62
  """
63
- Data Quality Assurance environment.
64
 
65
- The agent inspects corrupted datasets and reports quality issues.
66
- Reward is F1 score of correctly identified issues vs planted ground truth.
 
 
 
67
  """
68
 
69
  SUPPORTS_CONCURRENT_SESSIONS = True
@@ -103,7 +299,11 @@ class DataQAEnvironment(Environment):
103
  schema_description=self._current_task.schema_description,
104
  validation_rules=self._current_task.validation_rules,
105
  task_description=self._current_task.description,
106
- feedback="Environment reset. Inspect the dataset and report all quality issues.",
 
 
 
 
107
  task_id=task_id,
108
  num_issues_hint=len(self._current_task.planted_issues),
109
  max_steps=self._current_task.max_steps,
@@ -120,15 +320,14 @@ class DataQAEnvironment(Environment):
120
  if not isinstance(action, DataQAAction):
121
  raise ValueError(f"Expected DataQAAction, got {type(action)}")
122
 
123
- # In stateless HTTP mode, each request creates a fresh env instance.
124
- # Auto-reset using the task_id from the action so step() works standalone.
125
  if self._current_task is None:
126
  self.reset(task_id=action.task_id)
127
 
128
  self._state.step_count += 1
129
  self._state.current_step += 1
130
 
131
- # Parse reported issues
132
  reported_keys: Set[str] = set()
133
  parse_errors: list[str] = []
134
  for raw_issue in action.issues:
@@ -136,44 +335,148 @@ class DataQAEnvironment(Environment):
136
  if key:
137
  reported_keys.add(key)
138
  else:
139
- parse_errors.append(f"Could not parse: '{raw_issue}'")
140
 
141
- # Compute score
142
  metrics = compute_f1(reported_keys, self._planted_keys)
143
- score = metrics["f1"]
144
- self._best_score = max(self._best_score, score)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  self._state.best_score = self._best_score
146
 
147
- # Check if done
148
  is_done = (
149
- score >= 0.999 # Perfect score
150
  or self._state.current_step >= self._state.max_steps
151
  )
152
 
153
- # Build feedback
 
 
 
 
154
  feedback_lines = [
155
  f"Step {self._state.current_step}/{self._state.max_steps}",
 
 
156
  f"Issues reported: {len(reported_keys)}",
157
  f"True positives: {metrics['tp']}, False positives: {metrics['fp']}, Missed: {metrics['fn']}",
158
- f"Precision: {metrics['precision']:.3f}, Recall: {metrics['recall']:.3f}, F1: {score:.3f}",
 
159
  ]
160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  if parse_errors:
162
- feedback_lines.append(f"Parse errors ({len(parse_errors)}): {'; '.join(parse_errors[:3])}")
163
 
164
  if not is_done:
165
- # Give hints about what was missed without revealing exact answers
166
  if metrics["fn"] > 0:
167
  feedback_lines.append(
168
- f"You missed {metrics['fn']} issue(s). Review the dataset carefully."
169
  )
170
  if metrics["fp"] > 0:
171
  feedback_lines.append(
172
- f"{metrics['fp']} of your reported issues were incorrect."
173
  )
174
- feedback_lines.append("You can submit again with an updated list of issues.")
175
  else:
176
- feedback_lines.append(f"Task complete! Final best F1 score: {self._best_score:.3f}")
177
 
178
  return DataQAObservation(
179
  dataset_csv=self._current_task.corrupted_csv,
@@ -186,6 +489,25 @@ class DataQAEnvironment(Environment):
186
  max_steps=self._state.max_steps,
187
  done=is_done,
188
  reward=self._best_score,
189
  )
190
 
191
  @property
 
3
  ------------------
4
  Server-side environment for data quality assurance tasks.
5
 
6
+ Two-phase RL environment:
7
+ Phase 1 (Identify): Agent inspects corrupted datasets and reports quality issues.
8
+ Phase 2 (Fix): Agent proposes corrections for identified issues.
9
+
10
+ Combined reward = 0.6 * identify_score + 0.4 * fix_score
11
+ Both phases scored with difficulty-weighted metrics for rich per-step signal.
12
  """
13
 
14
  from __future__ import annotations
 
22
  from ..models import DataQAAction, DataQAObservation, DataQAState
23
  from .tasks import PlantedIssue, Task, get_task, list_tasks
24
 
25
+ # Reward weights for the two phases
26
+ IDENTIFY_WEIGHT = 0.6
27
+ FIX_WEIGHT = 0.4
28
+
29
 
30
  def parse_issue_key(raw: str) -> Optional[str]:
31
  """
 
34
  Returns normalized key or None if unparseable.
35
  """
36
  raw = raw.strip().lower()
 
37
  row_match = re.search(r"row\s*[:=]\s*(\d+)", raw)
38
  col_match = re.search(r"col\s*[:=]\s*([\w_]+)", raw)
39
  issue_match = re.search(r"issue\s*[:=]\s*([\w_]+)", raw)
 
43
  return None
44
 
45
 
46
+ def parse_fix(raw: str) -> Optional[tuple[int, str, str]]:
47
+ """
48
+ Parse an agent-proposed fix into (row, col, proposed_value).
49
+ Expected format: row:<N>,col:<name>,fix:<value>
50
+ Returns (row, col, value) or None if unparseable.
51
+ """
52
+ raw = raw.strip()
53
+ row_match = re.search(r"row\s*[:=]\s*(\d+)", raw, re.IGNORECASE)
54
+ col_match = re.search(r"col(?:umn)?\s*[:=]\s*([\w_]+)", raw, re.IGNORECASE)
55
+ fix_match = re.search(r"fix\s*[:=]\s*(.+?)$", raw, re.IGNORECASE)
56
+
57
+ if row_match and col_match and fix_match:
58
+ return (int(row_match.group(1)), col_match.group(1).lower(), fix_match.group(1).strip())
59
+ return None
60
+
61
+
62
  def compute_f1(reported_keys: Set[str], planted_keys: Set[str]) -> dict:
63
  """Compute precision, recall, and F1 score."""
64
  if not reported_keys and not planted_keys:
 
81
  return {"precision": precision, "recall": recall, "f1": f1, "tp": tp, "fp": fp, "fn": fn}
82
 
83
 
84
+ def compute_weighted_reward(
85
+ reported_keys: Set[str],
86
+ planted_issues: list,
87
+ ) -> dict:
88
+ """
89
+ Compute difficulty-weighted reward for richer per-step signal.
90
+
91
+ Each planted issue has a difficulty weight (1.0-3.0). Finding harder issues
92
+ earns more reward. False positives incur a penalty scaled by average difficulty.
93
+
94
+ Returns dict with weighted_reward (0.0-1.0), plus per-issue breakdown.
95
+ """
96
+ if not planted_issues and not reported_keys:
97
+ return {"weighted_reward": 1.0, "difficulty_found": 0.0, "difficulty_missed": 0.0}
98
+
99
+ planted_by_key = {issue.to_key(): issue for issue in planted_issues}
100
+ planted_keys = set(planted_by_key.keys())
101
+
102
+ if not reported_keys:
103
+ total_weight = sum(i.difficulty for i in planted_issues)
104
+ return {"weighted_reward": 0.0, "difficulty_found": 0.0, "difficulty_missed": total_weight}
105
+
106
+ if not planted_keys:
107
+ return {"weighted_reward": 0.0, "difficulty_found": 0.0, "difficulty_missed": 0.0}
108
+
109
+ found_keys = reported_keys & planted_keys
110
+ missed_keys = planted_keys - reported_keys
111
+ false_positive_count = len(reported_keys - planted_keys)
112
+
113
+ difficulty_found = sum(planted_by_key[k].difficulty for k in found_keys)
114
+ difficulty_missed = sum(planted_by_key[k].difficulty for k in missed_keys)
115
+ total_weight = sum(i.difficulty for i in planted_issues)
116
+
117
+ weighted_recall = difficulty_found / total_weight if total_weight > 0 else 0.0
118
+
119
+ avg_difficulty = total_weight / len(planted_issues)
120
+ fp_penalty_weight = false_positive_count * avg_difficulty
121
+ weighted_precision = difficulty_found / (difficulty_found + fp_penalty_weight) if (difficulty_found + fp_penalty_weight) > 0 else 0.0
122
+
123
+ if (weighted_precision + weighted_recall) > 0:
124
+ weighted_reward = 2 * weighted_precision * weighted_recall / (weighted_precision + weighted_recall)
125
+ else:
126
+ weighted_reward = 0.0
127
+
128
+ return {
129
+ "weighted_reward": round(weighted_reward, 4),
130
+ "difficulty_found": round(difficulty_found, 2),
131
+ "difficulty_missed": round(difficulty_missed, 2),
132
+ }
133
+
134
+
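A small worked example may make the difficulty weighting easier to follow. The sketch below recomputes the same weighted precision, recall and F1 on a toy set of three issues. The `Issue` dataclass and the function name `weighted_f1` are hypothetical stand-ins, not part of the repository.

```python
from dataclasses import dataclass
from typing import List, Set


@dataclass(frozen=True)
class Issue:  # hypothetical stand-in for PlantedIssue
    key: str
    difficulty: float


def weighted_f1(reported: Set[str], planted: List[Issue]) -> float:
    """Difficulty-weighted F1, following the arithmetic in the hunk above."""
    if not planted and not reported:
        return 1.0
    if not planted or not reported:
        return 0.0
    weights = {issue.key: issue.difficulty for issue in planted}
    total = sum(weights.values())
    found = sum(w for key, w in weights.items() if key in reported)
    false_positives = len(reported - set(weights))
    # Each false positive costs as much as one issue of average difficulty.
    penalty = false_positives * (total / len(planted))
    precision = found / (found + penalty) if (found + penalty) > 0 else 0.0
    recall = found / total if total > 0 else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


planted = [Issue("a", 1.0), Issue("b", 2.0), Issue("c", 3.0)]
print(round(weighted_f1({"a"}, planted), 4))            # easy issue only
print(round(weighted_f1({"c"}, planted), 4))            # hard issue only
print(round(weighted_f1({"a", "c", "x"}, planted), 4))  # two hits, one false positive
```

Finding only the hardest issue scores higher than finding only the easiest one, which is the extra signal the weighting adds over a plain F1.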
135
+ def grade_fixes(
136
+ fixes: list[tuple[int, str, str]],
137
+ task: Task,
138
+ ) -> dict:
139
+ """
140
+ Grade proposed fixes against the clean dataset.
141
+
142
+ For each fix (row, col, proposed_value), compare to the original clean value.
143
+ Scoring per fix:
144
+ - Exact match (case-insensitive, whitespace-stripped): 1.0
145
+ - Numeric close match (within 1%): 0.8
146
+ - Correct cell but wrong value: 0.1
147
+ - Targets a non-issue cell: 0.0 (no credit, counted as wrong)
148
+
149
+ Returns dict with fix_score (0.0-1.0), details per fix, and counts.
150
+ """
151
+ if not fixes and not task.planted_issues:
152
+ return {"fix_score": 1.0, "fixes_correct": 0, "fixes_partial": 0,
153
+ "fixes_wrong": 0, "fixes_attempted": 0, "fix_details": []}
154
+
155
+ if not fixes:
156
+ return {"fix_score": 0.0, "fixes_correct": 0, "fixes_partial": 0,
157
+ "fixes_wrong": 0, "fixes_attempted": 0, "fix_details": []}
158
+
159
+ issue_map = task.get_planted_issue_map()
160
+ # Build set of (row, col) that are actual issues
161
+ issue_cells = {(issue.row, issue.col) for issue in task.planted_issues}
162
+
163
+ total_weight = sum(i.difficulty for i in task.planted_issues) if task.planted_issues else 1.0
164
+ earned_weight = 0.0
165
+ fixes_correct = 0
166
+ fixes_partial = 0
167
+ fixes_wrong = 0
168
+ fix_details = []
169
+
170
+ # Track which issues have been fixed (best fix wins)
171
+ fixed_issues: dict[tuple[int, str], float] = {}
172
+
173
+ for row, col, proposed in fixes:
174
+ clean_value = task.get_clean_value(row, col)
175
+ cell_key = (row, col)
176
+
177
+ if cell_key not in issue_cells:
178
+ # Fix targets a non-issue cell — no credit
179
+ fix_details.append({"row": row, "col": col, "score": 0.0, "reason": "not an issue cell"})
180
+ fixes_wrong += 1
181
+ continue
182
+
183
+ if clean_value is None:
184
+ fix_details.append({"row": row, "col": col, "score": 0.0, "reason": "cell not found"})
185
+ fixes_wrong += 1
186
+ continue
187
+
188
+ # Find the planted issue for this cell to get its difficulty weight
189
+ matching_issue = None
190
+ for issue in task.planted_issues:
191
+ if issue.row == row and issue.col == col:
192
+ matching_issue = issue
193
+ break
194
+
195
+ difficulty = matching_issue.difficulty if matching_issue else 1.0
196
+
197
+ # Score the fix
198
+ score = 0.0
199
+ reason = "wrong value"
200
+
201
+ # Exact match (case-insensitive, whitespace-stripped)
202
+ if proposed.strip().lower() == clean_value.lower():
203
+ score = 1.0
204
+ reason = "exact match"
205
+ fixes_correct += 1
206
+ else:
207
+ # Try numeric close match
208
+ try:
209
+ proposed_num = float(proposed.strip())
210
+ clean_num = float(clean_value)
211
+ if proposed_num == clean_num: # check equality first so "75000.0" vs "75000" earns full credit
212
+ score = 1.0
213
+ reason = "exact numeric match"
214
+ fixes_correct += 1
215
+ elif clean_num != 0 and abs(proposed_num - clean_num) / abs(clean_num) <= 0.01:
216
+ score = 0.8
217
+ reason = "numeric close match"
218
+ fixes_partial += 1
219
+ else:
220
+ score = 0.1
221
+ reason = "correct cell, wrong value"
222
+ fixes_partial += 1
223
+ except (ValueError, ZeroDivisionError):
224
+ # Not numeric — just a wrong value but at least right cell
225
+ score = 0.1
226
+ reason = "correct cell, wrong value"
227
+ fixes_partial += 1
228
+
229
+ # Keep best fix per cell
230
+ if cell_key not in fixed_issues or score > fixed_issues[cell_key]:
231
+ fixed_issues[cell_key] = score
232
+
233
+ fix_details.append({"row": row, "col": col, "score": score, "reason": reason})
234
+
235
+ # Compute fix score: weighted sum of best fix per issue / total weight
236
+ for issue in task.planted_issues:
237
+ cell_key = (issue.row, issue.col)
238
+ if cell_key in fixed_issues:
239
+ earned_weight += issue.difficulty * fixed_issues[cell_key]
240
+
241
+ fix_score = earned_weight / total_weight if total_weight > 0 else 0.0
242
+ fix_score = min(max(fix_score, 0.0), 1.0)
243
+
244
+ return {
245
+ "fix_score": round(fix_score, 4),
246
+ "fixes_correct": fixes_correct,
247
+ "fixes_partial": fixes_partial,
248
+ "fixes_wrong": fixes_wrong,
249
+ "fixes_attempted": len(fixes),
250
+ "fix_details": fix_details,
251
+ }
252
+
253
+
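The per-fix scoring tiers can be read as a small pure function. The sketch below is a hypothetical helper, `score_fix`, which is not in the repository. It compares one proposed value against the clean value. It checks numeric equality before the 1% tolerance band, an ordering chosen here so that values such as `75000.0` and `75000` earn full credit. Treat that ordering as an assumption about the intended behaviour.

```python
from typing import Tuple


def score_fix(proposed: str, clean_value: str) -> Tuple[float, str]:
    """Score one proposed fix for a cell that is known to be a planted issue."""
    proposed = proposed.strip()
    clean_value = clean_value.strip()

    # Tier 1: textual match, ignoring case and surrounding whitespace.
    if proposed.lower() == clean_value.lower():
        return 1.0, "exact match"

    # Tiers 2 and 3 only apply when both values parse as numbers.
    try:
        proposed_num = float(proposed)
        clean_num = float(clean_value)
    except ValueError:
        return 0.1, "correct cell, wrong value"

    if proposed_num == clean_num:
        return 1.0, "exact numeric match"
    if clean_num != 0 and abs(proposed_num - clean_num) / abs(clean_num) <= 0.01:
        return 0.8, "numeric close match"
    return 0.1, "correct cell, wrong value"


for proposed, clean in [("David Kim ", "david kim"), ("75000.0", "75000"),
                        ("75500", "75000"), ("80000", "75000"), ("n/a", "75000")]:
    print(proposed, "->", score_fix(proposed, clean))
```

Returning a `(score, reason)` pair keeps the grading rule separate from the bookkeeping of correct, partial and wrong counts, so each tier can be tested without constructing a full task.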
254
  class DataQAEnvironment(Environment):
255
  """
256
+ Data Quality Assurance environment — two-phase identify + fix.
257
 
258
+ Phase 1 (Identify): Agent inspects corrupted datasets and reports quality issues.
259
+ Phase 2 (Fix): Agent proposes corrections for identified issues.
260
+
261
+ Combined reward = 0.6 * identify_score + 0.4 * fix_score when fixes are submitted; otherwise identify_score alone.
262
+ Both phases use difficulty-weighted scoring for rich per-step reward signals.
263
  """
264
 
265
  SUPPORTS_CONCURRENT_SESSIONS = True
 
299
  schema_description=self._current_task.schema_description,
300
  validation_rules=self._current_task.validation_rules,
301
  task_description=self._current_task.description,
302
+ feedback=(
303
+ "Environment reset. Inspect the dataset and report all quality issues.\n"
304
+ "You can also propose fixes in format: row:<N>,col:<name>,fix:<corrected_value>\n"
305
+ "Combined reward = 0.6 * identify_score + 0.4 * fix_score"
306
+ ),
307
  task_id=task_id,
308
  num_issues_hint=len(self._current_task.planted_issues),
309
  max_steps=self._current_task.max_steps,
 
320
  if not isinstance(action, DataQAAction):
321
  raise ValueError(f"Expected DataQAAction, got {type(action)}")
322
 
323
+ # Auto-reset in stateless HTTP mode
 
324
  if self._current_task is None:
325
  self.reset(task_id=action.task_id)
326
 
327
  self._state.step_count += 1
328
  self._state.current_step += 1
329
 
330
+ # ── Phase 1: Parse and score issue identification ──
331
  reported_keys: Set[str] = set()
332
  parse_errors: list[str] = []
333
  for raw_issue in action.issues:
 
335
  if key:
336
  reported_keys.add(key)
337
  else:
338
+ parse_errors.append(f"Could not parse issue: '{raw_issue}'")
339
 
 
340
  metrics = compute_f1(reported_keys, self._planted_keys)
341
+ identify_f1 = metrics["f1"]
342
+
343
+ weighted = compute_weighted_reward(reported_keys, self._current_task.planted_issues)
344
+ identify_score = weighted["weighted_reward"]
345
+
346
+ # ── Phase 2: Parse and score proposed fixes ──
347
+ parsed_fixes: list[tuple[int, str, str]] = []
348
+ for raw_fix in action.fixes:
349
+ fix = parse_fix(raw_fix)
350
+ if fix:
351
+ parsed_fixes.append(fix)
352
+ else:
353
+ parse_errors.append(f"Could not parse fix: '{raw_fix}'")
354
+
355
+ fix_result = grade_fixes(parsed_fixes, self._current_task)
356
+ fix_score = fix_result["fix_score"]
357
+
358
+ # ── Combined reward ──
359
+ # If no fixes submitted, score is identify-only (no penalty for not fixing)
360
+ if action.fixes:
361
+ combined_reward = IDENTIFY_WEIGHT * identify_score + FIX_WEIGHT * fix_score
362
+ else:
363
+ combined_reward = identify_score # backward compatible
364
+
365
+ self._best_score = max(self._best_score, combined_reward)
366
  self._state.best_score = self._best_score
367
 
368
+ # ── Check if done ──
369
  is_done = (
370
+ identify_f1 >= 0.999 # Perfect identification
371
  or self._state.current_step >= self._state.max_steps
372
  )
373
 
374
+ # ── Build feedback with actionable diagnostics ──
375
+ # Show the agent exactly which reported issues were correct (TP) and which were wrong (FP)
376
+ tp_keys = reported_keys & self._planted_keys
377
+ fp_keys = reported_keys - self._planted_keys
378
+
379
  feedback_lines = [
380
  f"Step {self._state.current_step}/{self._state.max_steps}",
381
+ "",
382
+ "--- Identification ---",
383
  f"Issues reported: {len(reported_keys)}",
384
  f"True positives: {metrics['tp']}, False positives: {metrics['fp']}, Missed: {metrics['fn']}",
385
+ f"Precision: {metrics['precision']:.3f}, Recall: {metrics['recall']:.3f}, F1: {identify_f1:.3f}",
386
+ f"Identify score (weighted): {identify_score:.3f}",
387
  ]
388
 
389
+ # Show which reported issues were correct vs wrong (helps agent self-correct)
390
+ if tp_keys:
391
+ feedback_lines.append(f"Correct issues: {', '.join(sorted(tp_keys))}")
392
+ if fp_keys:
393
+ feedback_lines.append(f"Incorrect issues (false positives): {', '.join(sorted(fp_keys))}")
394
+
395
+ if action.fixes:
396
+ feedback_lines += [
397
+ "",
398
+ "--- Fix Proposals ---",
399
+ f"Fixes attempted: {fix_result['fixes_attempted']}",
400
+ f"Correct: {fix_result['fixes_correct']}, Partial: {fix_result['fixes_partial']}, Wrong: {fix_result['fixes_wrong']}",
401
+ f"Fix score: {fix_score:.3f}",
402
+ ]
403
+ # Show per-fix feedback so agent knows which fixes worked
404
+ for detail in fix_result["fix_details"]:
405
+ status = "correct" if detail["score"] >= 0.99 else ("partial" if detail["score"] > 0 else "wrong")
406
+ feedback_lines.append(
407
+ f" row:{detail['row']},col:{detail['col']} -> {status} ({detail['reason']})"
408
+ )
409
+ feedback_lines.append(
410
+ f"\n--- Combined Reward: {combined_reward:.3f} (identify={identify_score:.3f} x {IDENTIFY_WEIGHT} + fix={fix_score:.3f} x {FIX_WEIGHT}) ---"
411
+ )
412
+ else:
413
+ feedback_lines += [
414
+ "",
415
+ "Tip: Submit fixes with format row:<N>,col:<name>,fix:<value> for bonus reward.",
416
+ ]
417
+
418
  if parse_errors:
419
+ feedback_lines.append(f"\nParse errors ({len(parse_errors)}): {'; '.join(parse_errors[:5])}")
420
 
421
  if not is_done:
 
422
  if metrics["fn"] > 0:
423
  feedback_lines.append(
424
+ f"\nYou missed {metrics['fn']} issue(s). Review the dataset carefully."
425
  )
426
  if metrics["fp"] > 0:
427
  feedback_lines.append(
428
+ f"Remove the {metrics['fp']} false positive(s) listed above and look for real issues."
429
  )
430
+ feedback_lines.append("You can submit again with updated issues and/or fixes.")
431
  else:
432
+ feedback_lines.append(f"\nTask complete! Final best reward: {self._best_score:.3f}")
433
+
434
+ # ── Flag items for human review ──
435
+ # In a production data QA pipeline, these would go to a human reviewer.
436
+ # The grader flags cases where automated scoring has low confidence.
437
+ human_review_flags: list[dict] = []
438
+
439
+ # 1. False positives with a recognized issue type: these could be legitimate issues
440
+ # the task designer didn't plant (agent may be smarter than the grader)
441
+ issue_map = self._current_task.get_planted_issue_map()
442
+ valid_issue_types = {"missing_value", "wrong_type", "duplicate_row", "out_of_range",
443
+ "format_violation", "inconsistent_value", "statistical_outlier",
444
+ "referential_integrity"}
445
+ for fp_key in fp_keys:
446
+ parts = fp_key.split(",")
447
+ itype = parts[2].split(":")[1] if len(parts) >= 3 else ""
448
+ if itype in valid_issue_types:
449
+ human_review_flags.append({
450
+ "item": fp_key,
451
+ "reason": "Agent reported this issue but it's not in ground truth — may be a real issue the grader missed",
452
+ "type": "possible_unplanted_issue",
453
+ })
454
+
455
+ # 2. Partial fix matches — fix was close but not exact, human should verify
456
+ for detail in fix_result["fix_details"]:
457
+ if 0 < detail["score"] < 0.99:
458
+ human_review_flags.append({
459
+ "item": f"row:{detail['row']},col:{detail['col']}",
460
+ "reason": f"Fix scored {detail['score']:.2f} ({detail['reason']}) — human should verify if acceptable",
461
+ "type": "partial_fix",
462
+ })
463
+
464
+ # 3. High-difficulty issues that were missed — flag for training data review
465
+ planted_by_key = {i.to_key(): i for i in self._current_task.planted_issues}
466
+ fn_keys = self._planted_keys - reported_keys
467
+ for fn_key in fn_keys:
468
+ issue = planted_by_key.get(fn_key)
469
+ if issue and issue.difficulty >= 2.5:
470
+ human_review_flags.append({
471
+ "item": fn_key,
472
+ "reason": f"High-difficulty issue (difficulty={issue.difficulty}) missed — {issue.description}",
473
+ "type": "missed_hard_issue",
474
+ })
475
+
476
+ if human_review_flags:
477
+ feedback_lines.append(f"\n--- Flagged for Human Review ({len(human_review_flags)}) ---")
478
+ for flag in human_review_flags:
479
+ feedback_lines.append(f" [{flag['type']}] {flag['item']}: {flag['reason']}")
480
 
481
  return DataQAObservation(
482
  dataset_csv=self._current_task.corrupted_csv,
 
489
  max_steps=self._state.max_steps,
490
  done=is_done,
491
  reward=self._best_score,
492
+ metadata={
493
+ "identify_f1": identify_f1,
494
+ "identify_score": identify_score,
495
+ "fix_score": fix_score,
496
+ "combined_reward": combined_reward,
497
+ "precision": metrics["precision"],
498
+ "recall": metrics["recall"],
499
+ "tp": metrics["tp"],
500
+ "fp": metrics["fp"],
501
+ "fn": metrics["fn"],
502
+ "difficulty_found": weighted["difficulty_found"],
503
+ "difficulty_missed": weighted["difficulty_missed"],
504
+ "fixes_correct": fix_result["fixes_correct"],
505
+ "fixes_partial": fix_result["fixes_partial"],
506
+ "fixes_wrong": fix_result["fixes_wrong"],
507
+ "fixes_attempted": fix_result["fixes_attempted"],
508
+ "fix_details": fix_result["fix_details"],
509
+ "human_review_flags": human_review_flags,
510
+ },
511
  )
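The reward combination in `step` reduces to a few lines. The sketch below restates it as a hypothetical standalone function, `combined_reward`. The 0.6 and 0.4 weights come from the class docstring, since the definitions of `IDENTIFY_WEIGHT` and `FIX_WEIGHT` are not visible in this hunk.

```python
IDENTIFY_WEIGHT = 0.6  # assumed from the class docstring
FIX_WEIGHT = 0.4       # assumed from the class docstring


def combined_reward(identify_score: float, fix_score: float, fixes_submitted: bool) -> float:
    """Blend the two phase scores; fall back to identify-only when no fixes were sent."""
    if fixes_submitted:
        return IDENTIFY_WEIGHT * identify_score + FIX_WEIGHT * fix_score
    return identify_score


print(combined_reward(0.9, 0.0, fixes_submitted=False))  # identify-only step
print(combined_reward(0.9, 0.5, fixes_submitted=True))   # weak fixes pull the blend down
print(combined_reward(0.9, 1.0, fixes_submitted=True))   # strong fixes raise it
```

Under this rule a step with mediocre fixes can score lower than the same identification with no fixes at all. Because the environment reports the best reward across steps, an earlier identify-only score is not lost.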
512
 
513
  @property
dataqa_env/server/gradio_ui.py ADDED
@@ -0,0 +1,508 @@
1
+ """
2
+ Gradio UI — Agent Trajectory Replay Viewer for DataQA.
3
+
4
+ Designed for judges: the first replay step renders on page load, with no clicks needed.
5
+ Demo and live tabs, task dropdown, step slider, prominent metric cards, color-coded dataset.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import csv
11
+ import io
12
+
13
+ import gradio as gr
14
+
15
+ from .environment import DataQAEnvironment, parse_issue_key
16
+ from .tasks import list_tasks, PlantedIssue
17
+ from ..models import DataQAAction
18
+
19
+
20
+ # ── Pre-built agent trajectories (simulates baseline agent) ──
21
+
22
+ AGENT_TRAJECTORIES = {
23
+ "easy": [
24
+ {
25
+ "issues": [
26
+ "row:4,col:name,issue:missing_value",
27
+ "row:7,col:salary,issue:wrong_type",
28
+ "row:9,col:salary,issue:out_of_range",
29
+ "row:18,col:start_date,issue:out_of_range",
30
+ "row:3,col:email,issue:format_violation", # FP
31
+ ],
32
+ "fixes": [],
33
+ },
34
+ {
35
+ "issues": [
36
+ "row:4,col:name,issue:missing_value",
37
+ "row:7,col:salary,issue:wrong_type",
38
+ "row:9,col:salary,issue:out_of_range",
39
+ "row:21,col:employee_id,issue:duplicate_row",
40
+ "row:15,col:email,issue:inconsistent_value",
41
+ "row:18,col:start_date,issue:out_of_range",
42
+ ],
43
+ "fixes": [
44
+ "row:4,col:name,fix:David Kim",
45
+ "row:7,col:salary,fix:75000",
46
+ "row:9,col:salary,fix:73000",
47
+ "row:15,col:email,fix:oscar.rivera@company.com",
48
+ "row:18,col:start_date,fix:2022-01-19",
49
+ ],
50
+ },
51
+ ],
52
+ "medium": [
53
+ {
54
+ "issues": [
55
+ "row:5,col:total,issue:inconsistent_value",
56
+ "row:10,col:category,issue:format_violation",
57
+ "row:14,col:product_name,issue:missing_value",
58
+ "row:17,col:quantity,issue:out_of_range",
59
+ "row:19,col:order_id,issue:duplicate_row",
60
+ "row:12,col:order_date,issue:format_violation",
61
+ "row:24,col:shipping_country,issue:format_violation",
62
+ ],
63
+ "fixes": [],
64
+ },
65
+ {
66
+ "issues": [
67
+ "row:5,col:total,issue:inconsistent_value",
68
+ "row:10,col:category,issue:format_violation",
69
+ "row:14,col:product_name,issue:missing_value",
70
+ "row:17,col:quantity,issue:out_of_range",
71
+ "row:19,col:order_id,issue:duplicate_row",
72
+ "row:12,col:order_date,issue:format_violation",
73
+ "row:24,col:shipping_country,issue:format_violation",
74
+ "row:29,col:order_date,issue:inconsistent_value",
75
+ ],
76
+ "fixes": [
77
+ "row:5,col:total,fix:42.00",
78
+ "row:10,col:category,fix:Sports",
79
+ "row:12,col:order_date,fix:2024-01-26",
80
+ "row:14,col:product_name,fix:LED Strip Lights",
81
+ "row:24,col:shipping_country,fix:US",
82
+ "row:29,col:order_date,fix:2024-02-12",
83
+ ],
84
+ },
85
+ ],
86
+ "hard": [
87
+ {
88
+ "issues": [
89
+ "row:14,col:training_time_hours,issue:out_of_range",
90
+ "row:13,col:learning_rate,issue:out_of_range",
91
+ "row:15,col:model_name,issue:missing_value",
92
+ "row:9,col:batch_size,issue:format_violation",
93
+ "row:10,col:train_size,issue:inconsistent_value",
94
+ ],
95
+ "fixes": [],
96
+ },
97
+ {
98
+ "issues": [
99
+ "row:14,col:training_time_hours,issue:out_of_range",
100
+ "row:13,col:learning_rate,issue:out_of_range",
101
+ "row:15,col:model_name,issue:missing_value",
102
+ "row:9,col:batch_size,issue:format_violation",
103
+ "row:10,col:train_size,issue:inconsistent_value",
104
+ "row:5,col:val_loss,issue:inconsistent_value",
105
+ "row:7,col:gpu_memory_gb,issue:statistical_outlier",
106
+ "row:11,col:timestamp,issue:inconsistent_value",
107
+ "row:9,col:training_time_hours,issue:statistical_outlier",
108
+ "row:12,col:test_accuracy,issue:statistical_outlier",
109
+ ],
110
+ "fixes": [
111
+ "row:14,col:training_time_hours,fix:72.0",
112
+ "row:13,col:learning_rate,fix:0.00001",
113
+ "row:15,col:model_name,fix:whisper-small",
114
+ "row:9,col:batch_size,fix:256",
115
+ "row:9,col:training_time_hours,fix:36.0",
116
+ ],
117
+ },
118
+ ],
119
+ }
120
+
121
+
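Hand-written trajectories like the ones above are easy to mistype, so a format check before replay can help. The sketch below is a hypothetical validator, not part of the repository. It assumes the strict `row:<N>,col:<name>,issue:<type>` and `row:<N>,col:<name>,fix:<value>` shapes used in `AGENT_TRAJECTORIES`.

```python
import re
from typing import Dict, List, Tuple

# Strict patterns matching the literal strings used in the trajectories above.
ISSUE_RE = re.compile(r"^row:\d+,col:\w+,issue:\w+$")
FIX_RE = re.compile(r"^row:\d+,col:\w+,fix:.+$")


def malformed_entries(trajectories: Dict[str, List[dict]]) -> List[Tuple[str, int, str]]:
    """Return (task_id, step_index, raw_string) for every entry that fails its pattern."""
    bad = []
    for task_id, steps in trajectories.items():
        for index, step in enumerate(steps):
            for issue in step.get("issues", []):
                if not ISSUE_RE.match(issue):
                    bad.append((task_id, index, issue))
            for fix in step.get("fixes", []):
                if not FIX_RE.match(fix):
                    bad.append((task_id, index, fix))
    return bad


sample = {
    "easy": [
        {"issues": ["row:4,col:name,issue:missing_value"], "fixes": []},
        {"issues": ["row 7 salary wrong"], "fixes": ["row:4,col:name,fix:David Kim"]},
    ]
}
print(malformed_entries(sample))
```

Running a check like this at import time, or in a unit test, would surface a typo as an explicit failure instead of a silently lower replay score.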
122
+ # ── HTML rendering ──
123
+
124
+ def _metric_card(label: str, value: str, color: str = "#333") -> str:
125
+ return (
126
+ f'<div style="text-align:center;padding:12px 16px;background:#f8f9fa;'
127
+ f'border-radius:8px;min-width:100px;">'
128
+ f'<div style="font-size:11px;color:#666;text-transform:uppercase;letter-spacing:1px;">{label}</div>'
129
+ f'<div style="font-size:28px;font-weight:700;color:{color};margin-top:2px;">{value}</div>'
130
+ f'</div>'
131
+ )
132
+
133
+
134
+ def _csv_to_html(
135
+ csv_text: str,
136
+ planted: list[PlantedIssue],
137
+ correct: set[tuple[int, str]],
138
+ fp: set[tuple[int, str]],
139
+ missed: set[tuple[int, str]],
140
+ fixed: dict[tuple[int, str], str],
141
+ fix_values: dict[tuple[int, str], str] | None = None,
142
+ ) -> str:
143
+ """Render CSV as HTML with color-coded cells and inline fix proposals."""
144
+ fix_values = fix_values or {}
145
+ desc_map = {(i.row, i.col): i for i in planted}
146
+ reader = csv.reader(io.StringIO(csv_text.strip()))
147
+ rows = list(reader)
148
+ if not rows:
149
+ return ""
150
+
151
+ header = rows[0]
152
+ header_lower = [h.strip().lower() for h in header]
153
+ data = rows[1:]
154
+
155
+ t = ['<table style="border-collapse:collapse;width:100%;font-size:12px;font-family:\'SF Mono\',monospace;">']
156
+ t.append('<tr>')
157
+ t.append('<th style="border:1px solid #dee2e6;padding:6px 8px;background:#343a40;color:#fff;font-size:11px;">Row</th>')
158
+ for h in header:
159
+ t.append(f'<th style="border:1px solid #dee2e6;padding:6px 8px;background:#343a40;color:#fff;font-size:11px;">{h}</th>')
160
+ t.append('</tr>')
161
+
162
+ for i, row in enumerate(data):
163
+ rn = i + 1
164
+ bg = "#fff" if i % 2 == 0 else "#f8f9fa"
165
+ t.append(f'<tr style="background:{bg};">')
166
+ t.append(f'<td style="border:1px solid #dee2e6;padding:4px 8px;color:#adb5bd;text-align:center;font-size:11px;">{rn}</td>')
167
+ for j, val in enumerate(row):
168
+ col = header_lower[j] if j < len(header_lower) else ""
169
+ ck = (rn, col)
170
+ s = "border:1px solid #dee2e6;padding:4px 8px;"
171
+ tip = ""
172
+ badge = ""
173
+
174
+ issue = desc_map.get(ck)
175
+
176
+ if ck in correct:
177
+ s += "background:#d4edda;"
178
+ tip = f"FOUND: {issue.description}" if issue else ""
179
+ badge = '<span style="font-size:9px;background:#28a745;color:#fff;padding:1px 4px;border-radius:3px;margin-left:4px;">TP</span>'
180
+ elif ck in fp:
181
+ s += "background:#f8d7da;"
182
+ badge = '<span style="font-size:9px;background:#dc3545;color:#fff;padding:1px 4px;border-radius:3px;margin-left:4px;">FP</span>'
183
+ elif ck in missed:
184
+ s += "background:#fff3cd;"
185
+ tip = f"MISSED: {issue.description}" if issue else ""
186
+ badge = '<span style="font-size:9px;background:#856404;color:#fff;padding:1px 4px;border-radius:3px;margin-left:4px;">MISS</span>'
187
+
188
+ fx = fixed.get(ck)
189
+ proposed = fix_values.get(ck)
190
+ if fx == "correct":
191
+ s += "box-shadow:inset 0 0 0 2px #28a745;"
192
+ badge += '<span style="font-size:9px;background:#28a745;color:#fff;padding:1px 4px;border-radius:3px;margin-left:2px;">FIX</span>'
193
+ elif fx == "partial":
194
+ s += "box-shadow:inset 0 0 0 2px #ffc107;"
195
+ badge += '<span style="font-size:9px;background:#ffc107;color:#333;padding:1px 4px;border-radius:3px;margin-left:2px;">~FIX</span>'
196
+
197
+ dv = val if val.strip() else '<em style="color:#dc3545;font-style:italic;">empty</em>'
198
+
199
+ # Show proposed fix value below the corrupted value
200
+ fix_line = ""
201
+ if proposed is not None:
202
+ fix_color = "#28a745" if fx == "correct" else ("#b8860b" if fx == "partial" else "#dc3545")
203
+ fix_line = (
204
+ f'<div style="font-size:10px;color:{fix_color};margin-top:2px;'
205
+ f'border-top:1px dashed {fix_color};padding-top:2px;">'
206
+ f'\u2192 {proposed}</div>'
207
+ )
208
+
209
+ t.append(f'<td style="{s}" title="{tip}">{dv}{badge}{fix_line}</td>')
210
+ t.append('</tr>')
211
+ t.append('</table>')
212
+ return "".join(t)
213
+
214
+
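The renderer interpolates cell values, tooltips and proposed fixes directly into HTML, and in the live tab those strings can come from user input. The sketch below shows one way to escape them with the standard-library `html.escape`. The helper name `render_cell` and its reduced styling are hypothetical simplifications of the cell markup above.

```python
import html


def render_cell(value: str, tip: str = "") -> str:
    """Render one table cell with the value and tooltip escaped for HTML."""
    # Empty cells get a fixed placeholder; everything else is escaped before insertion.
    shown = html.escape(value) if value.strip() else "<em>empty</em>"
    # quote=True also escapes double quotes, which matters inside an attribute value.
    return f'<td title="{html.escape(tip, quote=True)}">{shown}</td>'


print(render_cell("<script>alert(1)</script>", 'say "hi"'))
print(render_cell("   "))
```

Escaping at the point where strings enter the markup keeps the badge and style fragments, which are trusted literals, unchanged.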
215
+ LEGEND_HTML = (
216
+ '<div style="display:flex;gap:12px;flex-wrap:wrap;margin-top:10px;font-size:11px;">'
217
+ '<span style="background:#d4edda;padding:2px 8px;border-radius:4px;">Found (TP)</span>'
218
+ '<span style="background:#f8d7da;padding:2px 8px;border-radius:4px;">False Positive</span>'
219
+ '<span style="background:#fff3cd;padding:2px 8px;border-radius:4px;">Missed</span>'
220
+ '<span style="box-shadow:inset 0 0 0 2px #28a745;padding:2px 8px;border-radius:4px;">Fix Correct</span>'
221
+ '<span style="box-shadow:inset 0 0 0 2px #ffc107;padding:2px 8px;border-radius:4px;">Fix Partial</span>'
222
+ '</div>'
223
+ )
224
+
225
+
226
+ # ── Core replay logic ──
227
+
228
+ def _replay_task(task_id: str) -> list[dict]:
229
+ """Run the agent trajectory and collect per-step data."""
230
+ env = DataQAEnvironment()
231
+ obs = env.reset(task_id=task_id)
232
+ task = env._current_task
233
+ planted_keys = {i.to_key() for i in task.planted_issues}
234
+ steps_data = []
235
+
236
+ # Step 0: initial state
237
+ steps_data.append({
238
+ "label": "Initial — corrupted dataset",
239
+ "html": _csv_to_html(obs.dataset_csv, task.planted_issues, set(), set(), set(), {}),
240
+ "metrics": {"reward": 0.0, "tp": 0, "fp": 0, "fn": len(task.planted_issues),
241
+ "identify": 0.0, "fix": 0.0, "fixes_correct": 0},
242
+ "feedback": f"Task: {task.name}\nIssues to find: {obs.num_issues_hint}\n\n{task.description}",
243
+ })
244
+
245
+ trajectory = AGENT_TRAJECTORIES.get(task_id, [])
246
+ for i, step_data in enumerate(trajectory):
247
+ action = DataQAAction(
248
+ issues=step_data["issues"],
249
+ fixes=step_data.get("fixes", []),
250
+ task_id=task_id,
251
+ )
252
+ obs = env.step(action)
253
+
254
+ reported_keys = set()
255
+ for iss in step_data["issues"]:
256
+ key = parse_issue_key(iss)
257
+ if key:
258
+ reported_keys.add(key)
259
+
260
+ tp_keys = reported_keys & planted_keys
261
+ fp_keys = reported_keys - planted_keys
262
+ fn_keys = planted_keys - reported_keys
263
+
264
+ correct = {_kc(k) for k in tp_keys}
265
+ fp = {_kc(k) for k in fp_keys}
266
+ missed = {_kc(k) for k in fn_keys} if obs.done else set()
267
+
268
+ fixed: dict[tuple[int, str], str] = {}
269
+ for d in obs.metadata.get("fix_details", []):
270
+ c = (d["row"], d["col"])
271
+ fixed[c] = "correct" if d["score"] >= 0.99 else ("partial" if d["score"] > 0 else "wrong")
272
+
273
+ # Extract proposed fix values from the raw fix strings
274
+ fix_values: dict[tuple[int, str], str] = {}
275
+ from .environment import parse_fix
276
+ for raw_fix in step_data.get("fixes", []):
277
+ parsed = parse_fix(raw_fix)
278
+ if parsed:
279
+ row, col, val = parsed
280
+ fix_values[(row, col)] = val
281
+
282
+ html = _csv_to_html(obs.dataset_csv, task.planted_issues, correct, fp, missed, fixed, fix_values)
283
+
284
+ has_fixes = bool(step_data.get("fixes"))
285
+ if has_fixes:
286
+ label = f"Step {i+1} — identify + fix"
287
+ else:
288
+ label = f"Step {i+1} — identify only"
289
+
290
+ steps_data.append({
291
+ "label": label,
292
+ "html": html,
293
+ "metrics": {
294
+ "reward": obs.reward,
295
+ "tp": obs.metadata["tp"],
296
+ "fp": obs.metadata["fp"],
297
+ "fn": obs.metadata["fn"],
298
+ "identify": obs.metadata["identify_score"],
299
+ "fix": obs.metadata["fix_score"],
300
+ "fixes_correct": obs.metadata["fixes_correct"],
301
+ },
302
+ "feedback": obs.feedback,
303
+ })
304
+
305
+ return steps_data
306
+
307
+
308
+ def _kc(key: str) -> tuple[int, str]:
309
+ parts = key.split(",")
310
+ return (int(parts[0].split(":")[1]), parts[1].split(":")[1])
311
+
312
+
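`_kc` assumes every key is well formed and raises an `IndexError` or `ValueError` otherwise. The sketch below is a more defensive variant under the hypothetical name `key_to_cell`. It assumes the `row:<N>,col:<name>,issue:<type>` key layout produced by `to_key()` and returns `None` for anything else.

```python
import re
from typing import Optional, Tuple

# One anchored pattern for the whole key, so partial matches are rejected.
_KEY_RE = re.compile(r"^row:(\d+),col:([^,]+),issue:([^,]+)$")


def key_to_cell(key: str) -> Optional[Tuple[int, str]]:
    """Convert an issue key to a (row, col) cell, or None if the key is malformed."""
    match = _KEY_RE.match(key)
    if match is None:
        return None
    return int(match.group(1)), match.group(2)


print(key_to_cell("row:4,col:name,issue:missing_value"))
print(key_to_cell("not-a-key"))
```

Callers that build sets of cells would then need to skip `None` results, for example with a filtering comprehension.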
313
+ # ── Gradio app ──
314
+
315
+ def build_gradio_ui():
316
+ # Pre-compute all replays at startup
317
+ all_replays: dict[str, list[dict]] = {}
318
+ for tid in list_tasks():
319
+ all_replays[tid] = _replay_task(tid)
320
+
321
+ def show_step(task_id: str, step_idx: int):
322
+ replay = all_replays.get(task_id, [])
323
+ step_idx = int(step_idx)
324
+ if step_idx >= len(replay):
325
+ step_idx = len(replay) - 1
326
+ sd = replay[step_idx]
327
+ m = sd["metrics"]
328
+
329
+ # Reward color
330
+ r = m["reward"]
331
+ rc = "#28a745" if r >= 0.8 else ("#ffc107" if r >= 0.4 else "#dc3545")
332
+
333
+ cards = (
334
+ '<div style="display:flex;gap:10px;flex-wrap:wrap;margin-bottom:12px;">'
335
+ + _metric_card("Reward", f"{r:.2f}", rc)
336
+ + _metric_card("Found", str(m["tp"]), "#28a745")
337
+ + _metric_card("False Pos", str(m["fp"]), "#dc3545" if m["fp"] > 0 else "#28a745")
338
+ + _metric_card("Missed", str(m["fn"]), "#dc3545" if m["fn"] > 0 else "#28a745")
339
+ + _metric_card("Identify", f"{m['identify']:.2f}", "#333")
340
+ + _metric_card("Fix", f"{m['fix']:.2f}", "#333")
341
+ + '</div>'
342
+ )
343
+
344
+ full_html = (
345
+ f'<div style="font-size:14px;font-weight:600;margin-bottom:8px;color:#495057;">'
346
+ f'{sd["label"]}</div>'
347
+ + cards + sd["html"] + LEGEND_HTML
348
+ )
349
+
350
+ return full_html, sd["feedback"]
351
+
352
+ def on_task_change(task_id):
353
+ replay = all_replays.get(task_id, [])
354
+ max_step = len(replay) - 1
355
+ html, fb = show_step(task_id, 0)
356
+ return (
357
+ gr.update(maximum=max_step, value=0),
358
+ html,
359
+ fb,
360
+ )
361
+
362
+ def on_step_change(task_id, step_idx):
363
+ html, fb = show_step(task_id, step_idx)
364
+ return html, fb
365
+
366
+ # ── Live agent runner (uses an in-process DataQAEnvironment) ──
367
+
368
+ live_env = DataQAEnvironment()
369
+ live_state: dict = {"obs": None, "task_id": "easy", "steps": []}
370
+
371
+ def live_reset(task_id):
372
+ obs = live_env.reset(task_id=task_id)
373
+ task = live_env._current_task
374
+ live_state["obs"] = obs
375
+ live_state["task_id"] = task_id
376
+ live_state["steps"] = []
377
+ html = _csv_to_html(obs.dataset_csv, task.planted_issues, set(), set(), set(), {})
378
+ info = f"**{task.name}** — {obs.num_issues_hint} issues to find, {obs.max_steps} steps max"
379
+ return html, info, "", "0.000"
380
+
381
+ def live_step(issues_text, fixes_text):
382
+ if live_state["obs"] is None:
383
+ return "Reset first.", "", "", ""
384
+ obs = live_state["obs"]
385
+ task = live_env._current_task
386
+ planted_keys = {i.to_key() for i in task.planted_issues}
387
+
388
+ issues = [l.strip() for l in issues_text.strip().split("\n") if l.strip()]
389
+ fixes = [l.strip() for l in fixes_text.strip().split("\n") if l.strip()] if fixes_text.strip() else []
390
+
391
+ action = DataQAAction(issues=issues, fixes=fixes, task_id=live_state["task_id"])
392
+ obs = live_env.step(action)
393
+ live_state["obs"] = obs
394
+
395
+ reported_keys = set()
396
+ for iss in issues:
397
+ key = parse_issue_key(iss)
398
+ if key:
399
+ reported_keys.add(key)
400
+
401
+ tp_keys = reported_keys & planted_keys
402
+ fp_keys = reported_keys - planted_keys
403
+ fn_keys = planted_keys - reported_keys
404
+
405
+ correct = {_kc(k) for k in tp_keys}
406
+ fp_set = {_kc(k) for k in fp_keys}
407
+ missed = {_kc(k) for k in fn_keys} if obs.done else set()
408
+
409
+ fixed: dict[tuple[int, str], str] = {}
410
+ for d in obs.metadata.get("fix_details", []):
411
+ c = (d["row"], d["col"])
412
+ fixed[c] = "correct" if d["score"] >= 0.99 else ("partial" if d["score"] > 0 else "wrong")
413
+
414
+ from .environment import parse_fix
415
+ fix_values: dict[tuple[int, str], str] = {}
416
+ for raw in fixes:
417
+ parsed = parse_fix(raw)
418
+ if parsed:
419
+ fix_values[(parsed[0], parsed[1])] = parsed[2]
420
+
421
+ html = _csv_to_html(obs.dataset_csv, task.planted_issues, correct, fp_set, missed, fixed, fix_values)
422
+
423
+ m = obs.metadata
424
+ r = obs.reward
425
+ rc = "#28a745" if r >= 0.8 else ("#ffc107" if r >= 0.4 else "#dc3545")
426
+ cards = (
427
+ '<div style="display:flex;gap:10px;flex-wrap:wrap;margin-bottom:12px;">'
428
+ + _metric_card("Reward", f"{r:.2f}", rc)
429
+ + _metric_card("Found", str(m["tp"]), "#28a745")
430
+ + _metric_card("False Pos", str(m["fp"]), "#dc3545" if m["fp"] > 0 else "#28a745")
431
+ + _metric_card("Missed", str(m["fn"]), "#dc3545" if m["fn"] > 0 else "#28a745")
432
+ + '</div>'
433
+ )
434
+ full_html = cards + html + LEGEND_HTML
435
+ return full_html, obs.feedback, f"{r:.3f}", ""
436
+
437
+ # ── Build the UI ──
438
+
439
+ with gr.Blocks(title="DataQA Environment") as demo:
440
+ gr.Markdown(
441
+ "# DataQA — Data Quality Assurance Environment\n"
442
+ "Two-phase RL environment: **Identify** data quality issues, then **Fix** them."
443
+ )
444
+
445
+ with gr.Tabs():
446
+ # ── Tab 1: Demo replay ──
447
+ with gr.Tab("Demo (Baseline Agent)"):
448
+ gr.Markdown(
449
+ "*Replay of the baseline Qwen-72B agent. "
450
+ "Use the slider to step through the agent's trajectory.*"
451
+ )
452
+ with gr.Row():
453
+ task_dd = gr.Dropdown(choices=list_tasks(), value="easy", label="Task", scale=1)
454
+ step_slider = gr.Slider(minimum=0, maximum=2, step=1, value=0, label="Step", scale=3)
455
+
456
+ viz_html = gr.HTML()
457
+ feedback_box = gr.Textbox(label="Agent Feedback", lines=10, interactive=False)
458
+
459
+ task_dd.change(on_task_change, inputs=[task_dd], outputs=[step_slider, viz_html, feedback_box])
460
+ step_slider.change(on_step_change, inputs=[task_dd, step_slider], outputs=[viz_html, feedback_box])
461
+ demo.load(on_task_change, inputs=[task_dd], outputs=[step_slider, viz_html, feedback_box])
462
+
463
+ # ── Tab 2: Try your own agent ──
464
+ with gr.Tab("Try Your Own Agent"):
465
+ gr.Markdown(
466
+ "*Submit your own issues and fixes to see how the environment scores them. "
467
+ "This is the same environment the baseline agent talks to.*"
468
+ )
469
+ with gr.Row():
470
+ live_task_dd = gr.Dropdown(choices=list_tasks(), value="easy", label="Task", scale=1)
471
+ live_reset_btn = gr.Button("Reset", variant="primary", scale=1)
472
+
473
+ with gr.Row():
474
+ live_info = gr.Markdown()
475
+ live_reward = gr.Textbox(label="Reward", interactive=False, scale=1)
476
+
477
+ live_viz = gr.HTML()
478
+
479
+ with gr.Row():
480
+ live_issues = gr.Textbox(
481
+ label="Issues (one per line)",
482
+ placeholder="row:4,col:name,issue:missing_value\nrow:7,col:salary,issue:wrong_type",
483
+ lines=5,
484
+ )
485
+ live_fixes = gr.Textbox(
486
+ label="Fixes (one per line, optional)",
487
+ placeholder="row:4,col:name,fix:David Kim\nrow:7,col:salary,fix:75000",
488
+ lines=5,
489
+ )
490
+
491
+ live_step_btn = gr.Button("Submit Step", variant="primary")
492
+ live_feedback = gr.Textbox(label="Feedback", lines=10, interactive=False)
493
+
494
+ live_reset_btn.click(
495
+ live_reset, inputs=[live_task_dd],
496
+ outputs=[live_viz, live_info, live_feedback, live_reward],
497
+ )
498
+ live_step_btn.click(
499
+ live_step, inputs=[live_issues, live_fixes],
500
+ outputs=[live_viz, live_feedback, live_reward, live_issues],
501
+ )
502
+
503
+ return demo
504
+
505
+
506
+ if __name__ == "__main__":
507
+ demo = build_gradio_ui()
508
+ demo.launch()
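
The text boxes in the "Try Your Own Agent" tab take one entry per line, in the forms `row:N,col:NAME,issue:TYPE` and `row:N,col:NAME,fix:VALUE`. The sketch below shows one way such lines could be parsed. `parse_entry` is a hypothetical helper written for illustration; it is not the repository's `parse_fix`, whose implementation does not appear in this diff.

```python
from typing import Optional, Tuple


def parse_entry(line: str, kind: str) -> Optional[Tuple[int, str, str]]:
    """Parse 'row:N,col:NAME,<kind>:VALUE' into (row, col, value).

    Hypothetical helper: `kind` is "issue" or "fix". Returns None when the
    line does not follow the expected shape.
    """
    # Split on the first two commas only, so VALUE may itself contain commas.
    parts = line.strip().split(",", 2)
    if len(parts) != 3:
        return None

    fields = []
    for part, expected_key in zip(parts, ("row", "col", kind)):
        key, sep, value = part.partition(":")
        if not sep or key.strip().lower() != expected_key:
            return None
        fields.append(value.strip())

    row_text, col, value = fields
    if not row_text.isdecimal() or not col:
        return None
    # Column names are lowercased to match the CSV header convention.
    return int(row_text), col.lower(), value
```

Returning `None` for malformed lines, rather than raising, lets a caller skip bad entries in the same way the `if parsed:` check does in `live_step`.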
dataqa_env/server/tasks.py CHANGED
@@ -25,6 +25,7 @@ class PlantedIssue:
25
  col: str
26
  issue_type: str
27
  description: str
 
28
 
29
  def to_key(self) -> str:
30
  return f"row:{self.row},col:{self.col},issue:{self.issue_type}"
@@ -42,6 +43,28 @@ class Task:
42
  corrupted_csv: str = ""
43
  max_steps: int = 3
44
 
45
 
46
  def _csv_to_rows(csv_text: str) -> List[List[str]]:
47
  reader = csv.reader(io.StringIO(csv_text.strip()))
@@ -72,7 +95,17 @@ def create_task_easy(seed: int = 42) -> Task:
72
  107,Grace Lee,grace.lee@company.com,Marketing,75000,2021-12-01
73
  108,Hank Brown,hank.brown@company.com,Sales,65000,2023-04-18
74
  109,Iris Patel,iris.patel@company.com,HR,73000,2020-02-28
75
- 110,Jack Taylor,jack.taylor@company.com,Engineering,97000,2022-09-14"""
76
 
77
  schema_desc = """Columns:
78
  - employee_id: integer, unique, range 100-999
@@ -93,29 +126,43 @@ def create_task_easy(seed: int = 42) -> Task:
93
  data = rows[1:]
94
  issues: List[PlantedIssue] = []
95
 
96
- # Issue 1: Missing value - null out a name
97
  r = 3 # row index in data (0-based), displayed as row 4 in CSV
98
  data[r][1] = ""
99
  issues.append(PlantedIssue(row=r + 1, col="name", issue_type="missing_value",
100
- description="Empty name field"))
101
 
102
- # Issue 2: Wrong type - salary as text
103
  r = 6
104
  data[r][4] = "seventy-five thousand"
105
  issues.append(PlantedIssue(row=r + 1, col="salary", issue_type="wrong_type",
106
- description="Salary is text instead of integer"))
107
 
108
- # Issue 3: Duplicate row
109
  dup_source = 1
110
  data.append(list(data[dup_source]))
111
  issues.append(PlantedIssue(row=len(data), col="employee_id", issue_type="duplicate_row",
112
- description=f"Exact duplicate of row {dup_source + 1}"))
113
 
114
- # Issue 4: Out of range salary
115
  r = 8
116
  data[r][4] = "5000"
117
  issues.append(PlantedIssue(row=r + 1, col="salary", issue_type="out_of_range",
118
- description="Salary 5000 is below minimum 50000"))
119
 
120
  corrupted = _rows_to_csv([header] + data)
121
 
@@ -163,7 +210,17 @@ ORD-016,CUST-114,Bluetooth Speaker,Electronics,1,49.99,2024-01-30,UK,delivered,4
163
  ORD-017,CUST-115,Jump Rope,Sports,2,8.99,2024-01-31,US,shipped,17.98
164
  ORD-018,CUST-116,Coffee Table Book,Books,1,32.00,2024-02-01,CA,delivered,32.00
165
  ORD-019,CUST-117,Ergonomic Chair,Home,1,450.00,2024-02-02,US,processing,450.00
166
- ORD-020,CUST-118,Fitness Tracker,Electronics,1,79.99,2024-02-03,AU,delivered,79.99"""
167
 
168
  schema_desc = """Columns:
169
  - order_id: string, unique, format ORD-NNN
@@ -190,41 +247,55 @@ ORD-020,CUST-118,Fitness Tracker,Electronics,1,79.99,2024-02-03,AU,delivered,79.
190
  data = rows[1:]
191
  issues: List[PlantedIssue] = []
192
 
193
- # Issue 1: total doesn't match quantity * unit_price
194
  r = 4 # ORD-005
195
  data[r][9] = "84.00" # should be 42.00 (qty=1, price=42.00)
196
  issues.append(PlantedIssue(row=r + 1, col="total", issue_type="inconsistent_value",
197
- description="total (84.00) != quantity (1) * unit_price (42.00)"))
198
 
199
- # Issue 2: Invalid category
200
  r = 9 # ORD-010
201
  data[r][3] = "Fitness" # should be Sports
202
  issues.append(PlantedIssue(row=r + 1, col="category", issue_type="format_violation",
203
- description="'Fitness' is not in allowed categories"))
204
 
205
- # Issue 3: Missing value in product_name
206
  r = 13 # ORD-014
207
  data[r][2] = ""
208
  issues.append(PlantedIssue(row=r + 1, col="product_name", issue_type="missing_value",
209
- description="Empty product_name"))
210
 
211
- # Issue 4: Out of range quantity
212
  r = 16 # ORD-017
213
  data[r][4] = "-1"
214
  issues.append(PlantedIssue(row=r + 1, col="quantity", issue_type="out_of_range",
215
- description="Negative quantity"))
216
 
217
- # Issue 5: Duplicate order_id
218
  r = 18 # ORD-019
219
  data[r][0] = "ORD-003"
220
  issues.append(PlantedIssue(row=r + 1, col="order_id", issue_type="duplicate_row",
221
- description="Duplicate order_id ORD-003"))
222
 
223
- # Issue 6: Wrong date format
224
  r = 11 # ORD-012
225
  data[r][6] = "26/01/2024"
226
  issues.append(PlantedIssue(row=r + 1, col="order_date", issue_type="format_violation",
227
- description="Date format DD/MM/YYYY instead of YYYY-MM-DD"))
228
 
229
  corrupted = _rows_to_csv([header] + data)
230
 
@@ -267,7 +338,22 @@ EXP-011,yolov5-m,coco-2017,118287,5000,40670,0.01,32,300,0.032,0.045,0.0,10.2,24
267
  EXP-012,wav2vec2,librispeech,281241,5567,2620,0.0001,8,20,0.92,1.05,0.0,12.6,15.0,2024-03-13T11:30:00
268
  EXP-013,clip-base,cc3m,2818102,15000,15000,0.00001,256,32,2.15,2.38,0.0,22.4,48.0,2024-03-14T08:00:00
269
  EXP-014,detr,coco-2017,118287,5000,40670,0.0001,4,500,1.85,2.12,0.0,16.0,72.0,2024-03-15T10:00:00
270
- EXP-015,whisper-small,common-voice,520000,16000,16000,0.00005,16,5,0.55,0.68,0.0,7.4,6.5,2024-03-16T14:00:00"""
271
 
272
  schema_desc = """Columns:
273
  - experiment_id: string, unique, format EXP-NNN
@@ -301,53 +387,83 @@ EXP-015,whisper-small,common-voice,520000,16000,16000,0.00005,16,5,0.55,0.68,0.0
301
  data = rows[1:]
302
  issues: List[PlantedIssue] = []
303
 
304
- # Issue 1: Data leakage signal — val_loss much lower than train_loss
305
  r = 4 # EXP-005
306
  data[r][10] = "0.15" # val_loss=0.15 but train_loss=0.28 → suspicious
307
  issues.append(PlantedIssue(row=r + 1, col="val_loss", issue_type="inconsistent_value",
308
- description="val_loss (0.15) significantly less than train_loss (0.28), potential data leakage"))
 
309
 
310
- # Issue 2: Batch size not power of 2
311
  r = 8 # EXP-009
312
  data[r][7] = "250" # not a power of 2
313
  issues.append(PlantedIssue(row=r + 1, col="batch_size", issue_type="format_violation",
314
- description="batch_size 250 is not a power of 2"))
315
 
316
- # Issue 3: GPU memory unreasonable for model
317
  r = 6 # EXP-007 resnet18 on cifar10
318
  data[r][12] = "42.5" # resnet18 shouldn't need 42.5 GB
319
  issues.append(PlantedIssue(row=r + 1, col="gpu_memory_gb", issue_type="statistical_outlier",
320
- description="resnet18 on cifar10 using 42.5 GB GPU memory is unreasonable"))
 
321
 
322
- # Issue 4: Timestamp out of order
323
  r = 10 # EXP-011
324
  data[r][14] = "2024-03-02T09:00:00" # should be after EXP-010's timestamp
325
  issues.append(PlantedIssue(row=r + 1, col="timestamp", issue_type="inconsistent_value",
326
- description="Timestamp 2024-03-02 is before EXP-010's timestamp 2024-03-11"))
 
327
 
328
- # Issue 5: Train size smaller than test size
329
  r = 9 # EXP-010
330
  data[r][3] = "500" # train_size=500 but test_size=1821
331
  issues.append(PlantedIssue(row=r + 1, col="train_size", issue_type="inconsistent_value",
332
- description="train_size (500) is smaller than test_size (1821)"))
 
333
 
334
- # Issue 6: Negative training time
335
  r = 13 # EXP-014
336
  data[r][13] = "-72.0"
337
  issues.append(PlantedIssue(row=r + 1, col="training_time_hours", issue_type="out_of_range",
338
- description="Negative training time"))
339
 
340
- # Issue 7: Learning rate out of range
341
  r = 12 # EXP-013
342
  data[r][6] = "2.5" # way too high
343
  issues.append(PlantedIssue(row=r + 1, col="learning_rate", issue_type="out_of_range",
344
- description="Learning rate 2.5 exceeds maximum of 1.0"))
345
 
346
- # Issue 8: Missing model name (subtle: single space instead of empty)
347
  r = 14 # EXP-015
348
  data[r][1] = " "
349
  issues.append(PlantedIssue(row=r + 1, col="model_name", issue_type="missing_value",
350
- description="model_name is whitespace-only"))
351
 
352
  corrupted = _rows_to_csv([header] + data)
353
 
@@ -370,6 +486,123 @@ EXP-015,whisper-small,common-voice,520000,16000,16000,0.00005,16,5,0.55,0.68,0.0
370
  )
371
 
372
 
373
  # ---------------------------------------------------------------------------
374
  # Task registry
375
  # ---------------------------------------------------------------------------
 
25
  col: str
26
  issue_type: str
27
  description: str
28
+ difficulty: float = 1.0 # 1.0=easy, 2.0=medium, 3.0=hard (for weighted reward)
29
 
30
  def to_key(self) -> str:
31
  return f"row:{self.row},col:{self.col},issue:{self.issue_type}"
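
The new `difficulty` field is annotated "for weighted reward", but the scoring code is not part of this hunk. The sketch below shows one plausible reading, a difficulty-weighted recall in which finding a hard issue counts for more than finding an easy one. The `PlantedIssue` class here is a stand-in that mirrors the fields in the diff, and `weighted_recall` is a hypothetical name, not the environment's actual reward function.

```python
from dataclasses import dataclass
from typing import Iterable, Set


@dataclass
class PlantedIssue:
    """Stand-in mirroring the fields shown in the diff."""
    row: int
    col: str
    issue_type: str
    description: str
    difficulty: float = 1.0

    def to_key(self) -> str:
        return f"row:{self.row},col:{self.col},issue:{self.issue_type}"


def weighted_recall(planted: Iterable[PlantedIssue], reported: Set[str]) -> float:
    """Fraction of total difficulty weight covered by the reported issue keys.

    Assumption: each planted issue contributes its `difficulty` as a weight.
    """
    planted = list(planted)
    total = sum(issue.difficulty for issue in planted)
    if total == 0:
        return 0.0
    found = sum(issue.difficulty for issue in planted if issue.to_key() in reported)
    return found / total


issues = [
    PlantedIssue(4, "name", "missing_value", "Empty name field", difficulty=1.0),
    PlantedIssue(5, "val_loss", "inconsistent_value", "Possible leakage", difficulty=3.0),
]
score = weighted_recall(issues, {"row:5,col:val_loss,issue:inconsistent_value"})
```

Under this weighting, reporting only the hard issue covers three quarters of the total weight, whereas an unweighted recall would credit it with one half.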
 
43
  corrupted_csv: str = ""
44
  max_steps: int = 3
45
 
46
+ def get_clean_value(self, row: int, col: str) -> str | None:
47
+ """
48
+ Look up the original clean value for a given (row, col).
49
+ Row is 1-indexed (data row after header).
50
+ Returns None if row/col is out of bounds or column not found.
51
+ """
52
+ rows = _csv_to_rows(self.clean_csv)
53
+ if len(rows) < 2:
54
+ return None
55
+ header = [h.strip().lower() for h in rows[0]]
56
+ if col.lower() not in header:
57
+ return None
58
+ col_idx = header.index(col.lower())
59
+ data_row_idx = row # row is 1-indexed, rows[0] is header, so rows[row] is the data row
60
+ if data_row_idx < 1 or data_row_idx >= len(rows):
61
+ return None
62
+ return rows[data_row_idx][col_idx].strip()
63
+
64
+ def get_planted_issue_map(self) -> dict:
65
+ """Return dict mapping issue key -> PlantedIssue for quick lookups."""
66
+ return {issue.to_key(): issue for issue in self.planted_issues}
67
+
68
 
69
  def _csv_to_rows(csv_text: str) -> List[List[str]]:
70
  reader = csv.reader(io.StringIO(csv_text.strip()))
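
`get_clean_value` treats `row` as 1-indexed over data rows, so `rows[row]` lands on the right record because `rows[0]` is the header. The standalone sketch below reproduces that indexing and adds a check of a proposed fix against the clean value. `clean_value` and `fix_matches_clean` are hypothetical names, and the exact-match comparison is an assumption; the environment's real fix grading is not shown in this diff.

```python
import csv
import io
from typing import Optional


def clean_value(clean_csv: str, row: int, col: str) -> Optional[str]:
    """Look up the clean cell for a 1-indexed data row and a column name."""
    rows = list(csv.reader(io.StringIO(clean_csv.strip())))
    if len(rows) < 2:
        return None
    header = [h.strip().lower() for h in rows[0]]
    if col.lower() not in header:
        return None
    col_idx = header.index(col.lower())
    # rows[0] is the header, so data row N sits at rows[N].
    if row < 1 or row >= len(rows):
        return None
    return rows[row][col_idx].strip()


def fix_matches_clean(clean_csv: str, row: int, col: str, proposed: str) -> bool:
    """Assumed grading rule: a fix is correct if it equals the clean value."""
    expected = clean_value(clean_csv, row, col)
    return expected is not None and proposed.strip() == expected
```

Exact string matching is the strictest option. A real grader might want looser comparison for numeric or date columns, which this sketch does not attempt.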
 
95
  107,Grace Lee,grace.lee@company.com,Marketing,75000,2021-12-01
96
  108,Hank Brown,hank.brown@company.com,Sales,65000,2023-04-18
97
  109,Iris Patel,iris.patel@company.com,HR,73000,2020-02-28
98
+ 110,Jack Taylor,jack.taylor@company.com,Engineering,97000,2022-09-14
99
+ 111,Kevin Zhang,kevin.zhang@company.com,Engineering,91000,2021-05-22
100
+ 112,Laura Adams,laura.adams@company.com,Sales,69000,2022-11-03
101
+ 113,Mike Torres,mike.torres@company.com,Marketing,74000,2020-08-17
102
+ 114,Nina Sharma,nina.sharma@company.com,HR,76000,2019-04-30
103
+ 115,Oscar Rivera,oscar.rivera@company.com,Engineering,105000,2018-12-10
104
+ 116,Paula Green,paula.green@company.com,Sales,67000,2023-06-25
105
+ 117,Quinn Murphy,quinn.murphy@company.com,Marketing,78000,2021-03-08
106
+ 118,Rosa Diaz,rosa.diaz@company.com,Engineering,99000,2022-01-19
107
+ 119,Sam Cooper,sam.cooper@company.com,HR,70000,2020-10-05
108
+ 120,Tara Singh,tara.singh@company.com,Sales,66000,2023-02-14"""
109
 
110
  schema_desc = """Columns:
111
  - employee_id: integer, unique, range 100-999
 
126
  data = rows[1:]
127
  issues: List[PlantedIssue] = []
128
 
129
+ # Issue 1: Missing value - null out a name (easy to spot)
130
  r = 3 # row index in data (0-based), displayed as row 4 in CSV
131
  data[r][1] = ""
132
  issues.append(PlantedIssue(row=r + 1, col="name", issue_type="missing_value",
133
+ description="Empty name field", difficulty=1.0))
134
 
135
+ # Issue 2: Wrong type - salary as text (easy to spot)
136
  r = 6
137
  data[r][4] = "seventy-five thousand"
138
  issues.append(PlantedIssue(row=r + 1, col="salary", issue_type="wrong_type",
139
+ description="Salary is text instead of integer", difficulty=1.0))
140
 
141
+ # Issue 3: Duplicate row (moderate — requires cross-row comparison)
142
  dup_source = 1
143
  data.append(list(data[dup_source]))
144
  issues.append(PlantedIssue(row=len(data), col="employee_id", issue_type="duplicate_row",
145
+ description=f"Exact duplicate of row {dup_source + 1}", difficulty=1.5))
146
 
147
+ # Issue 4: Out of range salary (easy to spot)
148
  r = 8
149
  data[r][4] = "5000"
150
  issues.append(PlantedIssue(row=r + 1, col="salary", issue_type="out_of_range",
151
+ description="Salary 5000 is below minimum 50000", difficulty=1.0))
152
+
153
+ # Issue 5: Email doesn't match name pattern (moderate — cross-column check)
154
+ r = 14 # Oscar Rivera -> email should be oscar.rivera@company.com
155
+ data[r][2] = "john.doe@company.com"
156
+ issues.append(PlantedIssue(row=r + 1, col="email", issue_type="inconsistent_value",
157
+ description="Email john.doe@company.com doesn't match name Oscar Rivera",
158
+ difficulty=1.5))
159
+
160
+ # Issue 6: Future start date (requires knowing current date context)
161
+ r = 17 # Rosa Diaz
162
+ data[r][5] = "2027-06-15"
163
+ issues.append(PlantedIssue(row=r + 1, col="start_date", issue_type="out_of_range",
164
+ description="Start date 2027-06-15 is in the future (beyond 2025-12-31)",
165
+ difficulty=1.5))
166
 
167
  corrupted = _rows_to_csv([header] + data)
168
 
 
210
  ORD-017,CUST-115,Jump Rope,Sports,2,8.99,2024-01-31,US,shipped,17.98
211
  ORD-018,CUST-116,Coffee Table Book,Books,1,32.00,2024-02-01,CA,delivered,32.00
212
  ORD-019,CUST-117,Ergonomic Chair,Home,1,450.00,2024-02-02,US,processing,450.00
213
+ ORD-020,CUST-118,Fitness Tracker,Electronics,1,79.99,2024-02-03,AU,delivered,79.99
214
+ ORD-021,CUST-119,Laptop Sleeve,Electronics,1,24.99,2024-02-04,US,delivered,24.99
215
+ ORD-022,CUST-120,Hiking Backpack,Sports,1,65.00,2024-02-05,CA,shipped,65.00
216
+ ORD-023,CUST-121,Machine Learning Book,Books,1,54.99,2024-02-06,UK,delivered,54.99
217
+ ORD-024,CUST-122,Plant Pot Set,Home,3,15.00,2024-02-07,US,delivered,45.00
218
+ ORD-025,CUST-123,Noise Cancelling Headphones,Electronics,1,199.99,2024-02-08,DE,shipped,199.99
219
+ ORD-026,CUST-124,Basketball,Sports,1,29.99,2024-02-09,US,delivered,29.99
220
+ ORD-027,CUST-125,Cookbook Collection,Books,2,22.50,2024-02-10,AU,delivered,45.00
221
+ ORD-028,CUST-126,Smart Plug,Home,4,12.99,2024-02-11,US,processing,51.96
222
+ ORD-029,CUST-127,Wireless Charger,Electronics,1,34.99,2024-02-12,UK,delivered,34.99
223
+ ORD-030,CUST-128,Dumbbells Set,Sports,1,89.00,2024-02-13,US,shipped,89.00"""
224
 
225
  schema_desc = """Columns:
226
  - order_id: string, unique, format ORD-NNN
 
247
  data = rows[1:]
248
  issues: List[PlantedIssue] = []
249
 
250
+ # Issue 1: total doesn't match quantity * unit_price (requires cross-column check)
251
  r = 4 # ORD-005
252
  data[r][9] = "84.00" # should be 42.00 (qty=1, price=42.00)
253
  issues.append(PlantedIssue(row=r + 1, col="total", issue_type="inconsistent_value",
254
+ description="total (84.00) != quantity (1) * unit_price (42.00)", difficulty=2.0))
255
 
256
+ # Issue 2: Invalid category (requires knowing the allowed set)
257
  r = 9 # ORD-010
258
  data[r][3] = "Fitness" # should be Sports
259
  issues.append(PlantedIssue(row=r + 1, col="category", issue_type="format_violation",
260
+ description="'Fitness' is not in allowed categories", difficulty=1.5))
261
 
262
+ # Issue 3: Missing value in product_name (easy to spot)
263
  r = 13 # ORD-014
264
  data[r][2] = ""
265
  issues.append(PlantedIssue(row=r + 1, col="product_name", issue_type="missing_value",
266
+ description="Empty product_name", difficulty=1.0))
267
 
268
+ # Issue 4: Out of range quantity (easy to spot)
269
  r = 16 # ORD-017
270
  data[r][4] = "-1"
271
  issues.append(PlantedIssue(row=r + 1, col="quantity", issue_type="out_of_range",
272
+ description="Negative quantity", difficulty=1.0))
273
 
274
+ # Issue 5: Duplicate order_id (requires cross-row comparison)
275
  r = 18 # ORD-019
276
  data[r][0] = "ORD-003"
277
  issues.append(PlantedIssue(row=r + 1, col="order_id", issue_type="duplicate_row",
278
+ description="Duplicate order_id ORD-003", difficulty=1.5))
279
 
280
+ # Issue 6: Wrong date format (moderate — format mismatch)
281
  r = 11 # ORD-012
282
  data[r][6] = "26/01/2024"
283
  issues.append(PlantedIssue(row=r + 1, col="order_date", issue_type="format_violation",
284
+ description="Date format DD/MM/YYYY instead of YYYY-MM-DD", difficulty=1.5))
285
+
286
+ # Issue 7: Invalid country code (requires ISO knowledge)
287
+ r = 23 # ORD-024
288
+ data[r][7] = "XX" # not a valid ISO country code
289
+ issues.append(PlantedIssue(row=r + 1, col="shipping_country", issue_type="format_violation",
290
+ description="'XX' is not a valid ISO 2-letter country code", difficulty=1.5))
291
+
292
+ # Issue 8: Status-date inconsistency (requires cross-column check)
293
+ # ORD-029 keeps status "delivered" while its order_date is moved into the future
294
+ r = 28 # ORD-029
295
+ data[r][6] = "2025-12-25" # future date but status is "delivered"
296
+ issues.append(PlantedIssue(row=r + 1, col="order_date", issue_type="inconsistent_value",
297
+ description="Order date 2025-12-25 is in the future but status is 'delivered'",
298
+ difficulty=2.0))
299
 
300
  corrupted = _rows_to_csv([header] + data)
301
 
 
338
  EXP-012,wav2vec2,librispeech,281241,5567,2620,0.0001,8,20,0.92,1.05,0.0,12.6,15.0,2024-03-13T11:30:00
339
  EXP-013,clip-base,cc3m,2818102,15000,15000,0.00001,256,32,2.15,2.38,0.0,22.4,48.0,2024-03-14T08:00:00
340
  EXP-014,detr,coco-2017,118287,5000,40670,0.0001,4,500,1.85,2.12,0.0,16.0,72.0,2024-03-15T10:00:00
341
+ EXP-015,whisper-small,common-voice,520000,16000,16000,0.00005,16,5,0.55,0.68,0.0,7.4,6.5,2024-03-16T14:00:00
342
+ EXP-016,mobilenet-v3,imagenet-1k,1281167,50000,100000,0.004,128,150,0.92,1.05,72.8,4.1,18.0,2024-03-17T08:30:00
343
+ EXP-017,albert-base,mnli,392702,9815,9796,0.00002,32,5,0.32,0.41,83.1,6.2,1.8,2024-03-18T11:00:00
344
+ EXP-018,gpt-neo-1.3b,pile-subset,1500000,50000,50000,0.0002,8,2,2.85,2.98,0.0,18.5,36.0,2024-03-19T14:00:00
345
+ EXP-019,swin-tiny,imagenet-1k,1281167,50000,100000,0.001,256,300,0.78,0.95,78.2,8.6,42.0,2024-03-20T09:00:00
346
+ EXP-020,deberta-large,squad-v2,130319,11873,8862,0.00001,16,5,0.35,0.42,85.7,15.2,4.5,2024-03-21T10:30:00
347
+ EXP-021,yolov8-s,coco-2017,118287,5000,40670,0.01,64,200,0.028,0.038,0.0,6.8,16.0,2024-03-22T13:00:00
348
+ EXP-022,bart-base,xsum,204045,11332,11334,0.0001,32,10,1.22,1.38,0.0,8.4,6.2,2024-03-23T15:30:00
349
+ EXP-023,convnext-tiny,imagenet-1k,1281167,50000,100000,0.002,256,300,0.74,0.92,79.5,7.2,38.0,2024-03-24T08:00:00
350
+ EXP-024,xlm-roberta,xnli,392702,2490,5010,0.00002,16,10,0.41,0.48,82.3,12.4,5.8,2024-03-25T11:00:00
351
+ EXP-025,stable-diffusion,laion-400m,400000000,10000,10000,0.0001,4,1,0.45,0.52,0.0,24.0,168.0,2024-03-26T09:00:00
352
+ EXP-026,phi-2,dolly-15k,15011,500,500,0.00005,8,3,0.82,0.95,0.0,10.2,2.5,2024-03-27T14:00:00
353
+ EXP-027,dino-v2,imagenet-1k,1281167,50000,100000,0.0005,64,100,0.42,0.58,0.0,11.8,28.0,2024-03-28T10:00:00
354
+ EXP-028,electra-small,glue-mrpc,3668,408,1725,0.0001,32,10,0.38,0.44,87.2,3.8,0.8,2024-03-29T16:00:00
355
+ EXP-029,sam-base,sa-1b,11000000,50000,50000,0.0001,4,1,0.95,1.08,0.0,16.4,96.0,2024-03-30T08:00:00
356
+ EXP-030,llama2-13b,oasst1,84437,4401,4401,0.00001,2,3,0.78,0.88,0.0,52.0,12.0,2024-03-31T12:00:00"""
357
 
358
  schema_desc = """Columns:
359
  - experiment_id: string, unique, format EXP-NNN
 
387
  data = rows[1:]
388
  issues: List[PlantedIssue] = []
389
 
390
+ # Issue 1: Data leakage signal — val_loss much lower than train_loss (hard — requires ML knowledge)
391
  r = 4 # EXP-005
392
  data[r][10] = "0.15" # val_loss=0.15 but train_loss=0.28 → suspicious
393
  issues.append(PlantedIssue(row=r + 1, col="val_loss", issue_type="inconsistent_value",
394
+ description="val_loss (0.15) significantly less than train_loss (0.28), potential data leakage",
395
+ difficulty=3.0))
396
 
397
+ # Issue 2: Batch size not power of 2 (moderate — domain convention)
398
  r = 8 # EXP-009
399
  data[r][7] = "250" # not a power of 2
400
  issues.append(PlantedIssue(row=r + 1, col="batch_size", issue_type="format_violation",
401
+ description="batch_size 250 is not a power of 2", difficulty=2.0))
402
 
403
+ # Issue 3: GPU memory unreasonable for model (hard — requires model size reasoning)
404
  r = 6 # EXP-007 resnet18 on cifar10
405
  data[r][12] = "42.5" # resnet18 shouldn't need 42.5 GB
406
  issues.append(PlantedIssue(row=r + 1, col="gpu_memory_gb", issue_type="statistical_outlier",
407
+ description="resnet18 on cifar10 using 42.5 GB GPU memory is unreasonable",
408
+ difficulty=3.0))
409
 
410
+ # Issue 4: Timestamp out of order (moderate — requires sequential comparison)
411
  r = 10 # EXP-011
412
  data[r][14] = "2024-03-02T09:00:00" # should be after EXP-010's timestamp
413
  issues.append(PlantedIssue(row=r + 1, col="timestamp", issue_type="inconsistent_value",
414
+ description="Timestamp 2024-03-02 is before EXP-010's timestamp 2024-03-11",
415
+ difficulty=2.0))
416
 
417
+ # Issue 5: Train size smaller than test size (moderate — cross-column logic)
418
  r = 9 # EXP-010
419
  data[r][3] = "500" # train_size=500 but test_size=1821
420
  issues.append(PlantedIssue(row=r + 1, col="train_size", issue_type="inconsistent_value",
421
+ description="train_size (500) is smaller than test_size (1821)",
422
+ difficulty=2.0))
423
 
424
+ # Issue 6: Negative training time (easy to spot)
425
  r = 13 # EXP-014
426
  data[r][13] = "-72.0"
427
  issues.append(PlantedIssue(row=r + 1, col="training_time_hours", issue_type="out_of_range",
428
+ description="Negative training time", difficulty=1.0))
429
 
430
+ # Issue 7: Learning rate out of range (easy to spot)
431
  r = 12 # EXP-013
432
  data[r][6] = "2.5" # way too high
433
  issues.append(PlantedIssue(row=r + 1, col="learning_rate", issue_type="out_of_range",
434
+ description="Learning rate 2.5 exceeds maximum of 1.0", difficulty=1.5))
435
 
436
+ # Issue 8: Missing model name (hard: whitespace-only is subtle)
437
  r = 14 # EXP-015
438
  data[r][1] = " "
439
  issues.append(PlantedIssue(row=r + 1, col="model_name", issue_type="missing_value",
440
+ description="model_name is whitespace-only", difficulty=2.5))
441
+
442
+ # Issue 9: Training time impossibly fast for dataset size and epochs
443
+ # EXP-009 is efficientnet-b0 on imagenet-1k for 350 epochs, which should take
444
+ # roughly 40+ hours. Setting training_time_hours to 0.5 is not plausible for
445
+ # about 1.2M images * 350 epochs.
446
+ r = 8 # EXP-009 (same row as batch_size issue, different column)
447
+ data[r][13] = "0.5" # 30 minutes for 350 epochs on imagenet? impossible
448
+ issues.append(PlantedIssue(row=r + 1, col="training_time_hours", issue_type="statistical_outlier",
449
+ description="0.5 hours for 350 epochs on imagenet-1k (1.2M images) is impossibly fast",
450
+ difficulty=3.0))
451
+
452
+ # Issue 10: test_accuracy implausibly high for the training budget
453
+ # EXP-012 is wav2vec2 on librispeech trained for only 20 epochs; a test_accuracy
454
+ # of 98.5 exceeds the ~96% SOTA reached with much more training, which suggests
455
+ # data contamination or an evaluation bug.
456
+ #
457
+ # Rows considered and rejected for this issue:
458
+ # - EXP-010 (roberta-large on SST-2, test_accuracy 95.1): train_size is already
459
+ # corrupted to 500 by Issue 5, so the mismatch there arises without editing
460
+ # test_accuracy and would not be a separately planted value.
461
+ # - EXP-001 (resnet50 on imagenet, test_accuracy 76.3): plausible as-is.
462
+ r = 11 # EXP-012
463
+ data[r][11] = "98.5" # wav2vec2 with 20 epochs shouldn't hit 98.5% — SOTA is ~96%
464
+ issues.append(PlantedIssue(row=r + 1, col="test_accuracy", issue_type="statistical_outlier",
465
+ description="test_accuracy 98.5% for wav2vec2 with only 20 epochs exceeds known SOTA (~96%), likely evaluation error",
466
+ difficulty=3.0))
467
 
468
  corrupted = _rows_to_csv([header] + data)
469
 
 
486
  )
487
 
488
 
489
+ # ---------------------------------------------------------------------------
490
+ # Contamination rules for extensible task creation
491
+ # ---------------------------------------------------------------------------
492
+
493
+ # Each contamination rule is a callable: (rows, header, col_idx, row_idx, rng) -> (new_value, PlantedIssue)
494
+ # Users can define their own and register them.
495
+
496
+ CONTAMINATION_RULES = {
497
+ "missing_value": lambda rows, header, col_idx, row_idx, rng: (
498
+ "",
499
+ PlantedIssue(
500
+ row=row_idx + 1, col=header[col_idx], issue_type="missing_value",
501
+ description=f"Empty {header[col_idx]} field", difficulty=1.0,
502
+ ),
503
+ ),
504
+ "whitespace_value": lambda rows, header, col_idx, row_idx, rng: (
505
+ " ",
506
+ PlantedIssue(
507
+ row=row_idx + 1, col=header[col_idx], issue_type="missing_value",
508
+ description=f"Whitespace-only {header[col_idx]} field", difficulty=2.5,
509
+ ),
510
+ ),
511
+ "wrong_type_text": lambda rows, header, col_idx, row_idx, rng: (
512
+ rng.choice(["not-a-number", "N/A", "null", "undefined"]),
513
+ PlantedIssue(
514
+ row=row_idx + 1, col=header[col_idx], issue_type="wrong_type",
515
+ description=f"{header[col_idx]} is text instead of expected type", difficulty=1.0,
516
+ ),
517
+ ),
518
+ "negative_value": lambda rows, header, col_idx, row_idx, rng: (
519
+ str(-abs(float(rows[row_idx][col_idx]) if rows[row_idx][col_idx] else 1)),
520
+ PlantedIssue(
521
+ row=row_idx + 1, col=header[col_idx], issue_type="out_of_range",
522
+ description=f"Negative {header[col_idx]}", difficulty=1.0,
523
+ ),
524
+ ),
525
+ }
526
+
527
+
528
+ def create_task_from_config(
529
+ task_id: str,
530
+ name: str,
531
+ description: str,
532
+ schema_description: str,
533
+ validation_rules: str,
534
+ clean_csv: str,
535
+ contaminations: List[dict],
536
+ max_steps: int = 3,
537
+ seed: int = 42,
538
+ ) -> Task:
539
+ """
540
+ Create a custom task from explicit configuration arguments.
541
+
542
+ Each contamination entry should have:
543
+ - rule: str (key in CONTAMINATION_RULES) or callable
544
+ - row: int (0-based row index in data)
545
+ - col: int (column index in header)
546
+ - difficulty: float (optional, overrides rule default)
547
+
548
+ Example:
549
+ contaminations = [
550
+ {"rule": "missing_value", "row": 2, "col": 1, "difficulty": 1.5},
551
+ {"rule": "negative_value", "row": 5, "col": 4},
552
+ ]
553
+ """
554
+ rng = random.Random(seed)
555
+ rows = _csv_to_rows(clean_csv)
556
+ header = rows[0]
557
+ data = rows[1:]
558
+ issues: List[PlantedIssue] = []
559
+
560
+ for spec in contaminations:
561
+ rule = spec["rule"]
562
+ row_idx = spec["row"]
563
+ col_idx = spec["col"]
564
+
565
+ if callable(rule):
566
+ new_val, issue = rule(data, header, col_idx, row_idx, rng)
567
+ elif rule in CONTAMINATION_RULES:
568
+ new_val, issue = CONTAMINATION_RULES[rule](data, header, col_idx, row_idx, rng)
569
+ else:
570
+ raise ValueError(f"Unknown contamination rule: {rule}. Available: {list(CONTAMINATION_RULES.keys())}")
571
+
572
+ data[row_idx][col_idx] = new_val
573
+ if "difficulty" in spec:
574
+ issue.difficulty = spec["difficulty"]
575
+ issues.append(issue)
576
+
577
+ corrupted = _rows_to_csv([header] + data)
578
+
579
+ return Task(
580
+ task_id=task_id,
581
+ name=name,
582
+ description=description,
583
+ schema_description=schema_description,
584
+ validation_rules=validation_rules,
585
+ clean_csv=clean_csv,
586
+ planted_issues=issues,
587
+ corrupted_csv=corrupted,
588
+ max_steps=max_steps,
589
+ )
590
+
591
+
592
+ def register_task(task_id: str, factory_fn):
593
+ """Register a custom task factory. Factory must accept (seed: int) -> Task."""
594
+ TASK_REGISTRY[task_id] = factory_fn
595
+
596
+
597
+ def register_contamination_rule(name: str, rule_fn):
598
+ """
599
+ Register a custom contamination rule.
600
+
601
+ rule_fn signature: (rows, header, col_idx, row_idx, rng) -> (new_value, PlantedIssue)
602
+ """
603
+ CONTAMINATION_RULES[name] = rule_fn
604
+
605
+
606
  # ---------------------------------------------------------------------------
607
  # Task registry
608
  # ---------------------------------------------------------------------------
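
A contamination rule is any callable with the signature `(rows, header, col_idx, row_idx, rng) -> (new_value, PlantedIssue)`. The sketch below writes one such rule, which rewrites an ISO date as `DD/MM/YYYY` (the same corruption as Issue 6 in the medium task), and applies it the way `create_task_from_config` does. To stay self-contained it uses a stand-in `PlantedIssue` and a local `RULES` dict in place of `CONTAMINATION_RULES`; `date_format_rule` is a hypothetical name.

```python
import random
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple


@dataclass
class PlantedIssue:
    """Stand-in mirroring the fields shown in the diff."""
    row: int
    col: str
    issue_type: str
    description: str
    difficulty: float = 1.0


Rule = Callable[
    [List[List[str]], List[str], int, int, random.Random],
    Tuple[str, PlantedIssue],
]
RULES: Dict[str, Rule] = {}  # local stand-in for CONTAMINATION_RULES


def date_format_rule(rows, header, col_idx, row_idx, rng):
    """Rewrite a YYYY-MM-DD cell as DD/MM/YYYY (rng is unused here)."""
    year, month, day = rows[row_idx][col_idx].split("-")
    new_value = f"{day}/{month}/{year}"
    issue = PlantedIssue(
        row=row_idx + 1,  # rules receive 0-based indices; issues are 1-indexed
        col=header[col_idx],
        issue_type="format_violation",
        description="Date format DD/MM/YYYY instead of YYYY-MM-DD",
        difficulty=1.5,
    )
    return new_value, issue


RULES["date_format"] = date_format_rule

header = ["order_id", "order_date"]
data = [["ORD-001", "2024-01-15"], ["ORD-002", "2024-01-26"]]
new_value, issue = RULES["date_format"](data, header, 1, 1, random.Random(42))
data[1][1] = new_value  # the caller, not the rule, writes the value back
```

In the real module the same function would presumably be passed to `register_contamination_rule("date_format", date_format_rule)` and then referenced by name in a `contaminations` entry.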
inference.py CHANGED
@@ -1,26 +1,31 @@
1
  #!/usr/bin/env python3
2
  """
3
- DataQA Inference Script
4
- -----------------------
5
- LLM agent that plays the DataQA environment.
 
 
 
6
  Uses the OpenAI client to interact with any OpenAI-compatible LLM API.
7
 
8
  Required environment variables:
9
- API_BASE_URL - LLM API endpoint (e.g., https://api.groq.com/openai/v1)
10
- MODEL_NAME - Model identifier (e.g., llama-3.3-70b-versatile)
11
- HF_TOKEN - HuggingFace token (for HF Spaces access)
12
-
13
- Structured logging format: [START], [STEP], [END] tags for evaluation.
 
 
 
14
  """
15
 
16
  from __future__ import annotations
17
 
18
- import json
19
  import os
20
  import re
21
  import sys
22
  import time
23
- from typing import Optional
24
 
25
  import requests
26
  from openai import OpenAI
@@ -28,52 +33,43 @@ from openai import OpenAI
  # ---------------------------------------------------------------------------
  # Configuration
  # ---------------------------------------------------------------------------
- API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.groq.com/openai/v1")
- MODEL_NAME = os.environ.get("MODEL_NAME", "llama-3.3-70b-versatile")
- HF_TOKEN = os.environ.get("HF_TOKEN", "")
- ENV_URL = os.environ.get("ENV_URL", "http://localhost:8000")

  TASKS = ["easy", "medium", "hard"]
  MAX_STEPS_PER_TASK = 3

  # ---------------------------------------------------------------------------
- # Logging helpers (structured stdout for evaluation)
  # ---------------------------------------------------------------------------

- def log_start(task_id: str, metadata: Optional[dict] = None):
-     entry = {"event": "START", "task_id": task_id, "timestamp": time.time()}
-     if metadata:
-         entry["metadata"] = metadata
-     print(f"[START] {json.dumps(entry)}", flush=True)
-
-
- def log_step(task_id: str, step: int, reward: float, details: Optional[dict] = None):
-     entry = {
-         "event": "STEP",
-         "task_id": task_id,
-         "step": step,
-         "reward": reward,
-         "timestamp": time.time(),
-     }
-     if details:
-         entry["details"] = details
-     print(f"[STEP] {json.dumps(entry)}", flush=True)
-
-
- def log_end(task_id: str, final_score: float, metadata: Optional[dict] = None):
-     entry = {
-         "event": "END",
-         "task_id": task_id,
-         "final_score": final_score,
-         "timestamp": time.time(),
-     }
-     if metadata:
-         entry["metadata"] = metadata
-     print(f"[END] {json.dumps(entry)}", flush=True)


  # ---------------------------------------------------------------------------
- # Environment HTTP client (simple, no WebSocket needed for inference)
  # ---------------------------------------------------------------------------

  class EnvHTTPClient:
@@ -99,26 +95,21 @@ class EnvHTTPClient:
          r.raise_for_status()
          return r.json()

-     def step(self, issues: list[str], task_id: str = "easy") -> dict:
          r = self.session.post(
              f"{self.base_url}/step",
-             json={"action": {"issues": issues, "task_id": task_id}},
              timeout=30,
          )
          r.raise_for_status()
          return r.json()

-     def state(self) -> dict:
-         r = self.session.get(f"{self.base_url}/state", timeout=10)
-         r.raise_for_status()
-         return r.json()
-

  # ---------------------------------------------------------------------------
- # LLM Agent
  # ---------------------------------------------------------------------------

- SYSTEM_PROMPT = """You are a data quality analyst. Your job is to inspect datasets and identify data quality issues.

  You will be given:
  1. A dataset in CSV format
@@ -142,7 +133,6 @@ CRITICAL INSTRUCTIONS FOR ROW NUMBERING:
  - Row numbers refer to the ROW POSITION in the CSV data, NOT the value of any ID column
  - Row 1 = the FIRST data row after the header
  - Row 2 = the SECOND data row after the header
- - For example, if the CSV has header on line 1 and data starting on line 2, the data on line 2 is row 1, line 3 is row 2, etc.
  - DO NOT use the employee_id, order_id, or experiment_id as the row number
  - Column names must match exactly (use the CSV header names, lowercase)
  - Check EVERY row and EVERY column systematically
@@ -154,7 +144,26 @@ Respond with ONLY the list of issues, one per line. No other text.
  Example: row:3,col:salary,issue:missing_value"""


- def build_user_prompt(observation: dict) -> str:
      obs = observation if isinstance(observation, dict) else observation
      parts = []

@@ -173,6 +182,12 @@ def build_user_prompt(observation: dict) -> str:
      if feedback and "reset" not in feedback.lower():
          parts.append(f"FEEDBACK FROM PREVIOUS ATTEMPT:\n{feedback}")

      return "\n\n".join(parts)


@@ -183,88 +198,142 @@ def parse_llm_response(response: str) -> list[str]:
          line = line.strip()
          if not line:
              continue
-         # Remove numbering like "1. " or "- " or "* "
          line = re.sub(r"^\s*[\d]+[.\)]\s*", "", line)
          line = re.sub(r"^\s*[-*]\s*", "", line)
          line = line.strip()
          if "row" in line.lower() and "col" in line.lower():
-             # Lenient regex: accept : or = as delimiters, case-insensitive
              match = re.search(
                  r"row\s*[:=]\s*(\d+)\s*[,;\s]+col(?:umn)?\s*[:=]\s*([\w_]+)\s*[,;\s]+issue\s*[:=]\s*([\w_]+)",
                  line,
                  re.IGNORECASE,
              )
              if match:
-                 # Normalize to lowercase canonical format
                  normalized = f"row:{match.group(1)},col:{match.group(2).lower()},issue:{match.group(3).lower()}"
                  issues.append(normalized)
      return issues


- def run_task(client: OpenAI, env: EnvHTTPClient, task_id: str) -> float:
-     """Run a single task and return the best score."""
-     log_start(task_id)

-     # Reset environment for this task
-     reset_response = env.reset(task_id=task_id)
-     observation = reset_response.get("observation", reset_response)

      best_score = 0.0
-     messages = [{"role": "system", "content": SYSTEM_PROMPT}]
-
-     for step_num in range(1, MAX_STEPS_PER_TASK + 1):
-         user_prompt = build_user_prompt(observation)
-         messages_for_call = messages + [{"role": "user", "content": user_prompt}]
-
-         # Call LLM with retry on rate limit
-         llm_output = ""
-         for attempt in range(3):
-             try:
-                 response = client.chat.completions.create(
-                     model=MODEL_NAME,
-                     messages=messages_for_call,
-                     temperature=0.1,
-                     max_tokens=2048,
-                 )
-                 llm_output = response.choices[0].message.content or ""
                  break
-             except Exception as e:
-                 if "rate_limit" in str(e).lower() or "429" in str(e):
-                     wait = 10 * (attempt + 1)
-                     print(f"[WARN] Rate limited, waiting {wait}s...", flush=True)
-                     time.sleep(wait)
-                 else:
-                     print(f"[ERROR] LLM call failed: {e}", file=sys.stderr, flush=True)
-                     break
-
-         # Parse issues from LLM response
-         issues = parse_llm_response(llm_output)
-
-         if not issues:
-             print(f"[WARN] No issues parsed from LLM response for {task_id} step {step_num}", file=sys.stderr, flush=True)
-
-         # Submit to environment
-         step_response = env.step(issues, task_id=task_id)
-         observation = step_response.get("observation", step_response)
-
-         # reward and done are at the top level of the response, not inside observation
-         reward = float(step_response.get("reward", 0.0) or 0.0)
-         done = bool(step_response.get("done", False))
-         best_score = max(best_score, reward)
-
-         log_step(task_id, step_num, reward, {
-             "issues_reported": len(issues),
-             "feedback": observation.get("feedback", ""),
-         })
-
-         if done:
-             break
-
-         # Add context for next attempt
-         messages.append({"role": "user", "content": user_prompt})
-         messages.append({"role": "assistant", "content": llm_output})
-
-     log_end(task_id, best_score)
      return best_score


@@ -273,49 +342,34 @@ def run_task(client: OpenAI, env: EnvHTTPClient, task_id: str) -> float:
  # ---------------------------------------------------------------------------

  def main():
-     print(f"[INFO] DataQA Inference starting", flush=True)
-     print(f"[INFO] ENV_URL={ENV_URL}", flush=True)
-     print(f"[INFO] API_BASE_URL={API_BASE_URL}", flush=True)
-     print(f"[INFO] MODEL_NAME={MODEL_NAME}", flush=True)

-     # Initialize clients
      env = EnvHTTPClient(ENV_URL)
      llm_client = OpenAI(
          base_url=API_BASE_URL,
-         api_key=os.environ.get("LLM_API_KEY", HF_TOKEN or "no-key"),
      )

-     # Check environment health
      if not env.health():
-         print("[ERROR] Environment is not healthy. Exiting.", file=sys.stderr, flush=True)
          sys.exit(1)

-     print(f"[INFO] Environment is healthy", flush=True)

-     # Run all tasks
      scores = {}
      for task_id in TASKS:
-         print(f"\n{'='*60}", flush=True)
-         print(f"[INFO] Starting task: {task_id}", flush=True)
-         print(f"{'='*60}", flush=True)
-
          try:
              score = run_task(llm_client, env, task_id)
              scores[task_id] = score
-             print(f"[INFO] Task {task_id} completed with score: {score:.3f}", flush=True)
          except Exception as e:
-             print(f"[ERROR] Task {task_id} failed: {e}", file=sys.stderr, flush=True)
              scores[task_id] = 0.0

-     # Summary
-     print(f"\n{'='*60}", flush=True)
-     print("[INFO] FINAL RESULTS", flush=True)
-     print(f"{'='*60}", flush=True)
-     for task_id, score in scores.items():
-         print(f"[INFO] {task_id}: {score:.3f}", flush=True)
-
      avg_score = sum(scores.values()) / len(scores) if scores else 0.0
-     print(f"[INFO] Average score: {avg_score:.3f}", flush=True)


  if __name__ == "__main__":
 
  #!/usr/bin/env python3
  """
+ DataQA Inference Script — Two-Phase Agent
+ ------------------------------------------
+ LLM agent that plays the DataQA environment in two phases:
+ Phase 1: Identify all data quality issues
+ Phase 2: Propose fixes for identified issues
+
  Uses the OpenAI client to interact with any OpenAI-compatible LLM API.

  Required environment variables:
+ API_BASE_URL - LLM API endpoint (e.g., https://router.huggingface.co/v1)
+ MODEL_NAME - Model identifier (e.g., Qwen/Qwen2.5-72B-Instruct)
+ HF_TOKEN - HuggingFace token / API key
+
+ STDOUT FORMAT (mandatory for evaluation):
+ [START] task=<task_name> env=<benchmark> model=<model_name>
+ [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
+ [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
  """

  from __future__ import annotations

  import os
  import re
  import sys
  import time
+ from typing import List, Optional

  import requests
  from openai import OpenAI
 
  # ---------------------------------------------------------------------------
  # Configuration
  # ---------------------------------------------------------------------------
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
+ ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")

+ BENCHMARK = "dataqa_env"
  TASKS = ["easy", "medium", "hard"]
  MAX_STEPS_PER_TASK = 3

+
  # ---------------------------------------------------------------------------
+ # Logging helpers (structured stdout exact format required by evaluation)
  # ---------------------------------------------------------------------------

+ def log_start(task: str, env: str, model: str) -> None:
+     print(f"[START] task={task} env={env} model={model}", flush=True)
+
+
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
+     error_val = error if error else "null"
+     done_val = str(done).lower()
+     print(
+         f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
+         flush=True,
+     )
+
+
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
+     rewards_str = ",".join(f"{r:.2f}" for r in rewards)
+     print(
+         f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
+         flush=True,
+     )


  # ---------------------------------------------------------------------------
+ # Environment HTTP client
  # ---------------------------------------------------------------------------

  class EnvHTTPClient:
 
          r.raise_for_status()
          return r.json()

+     def step(self, issues: list[str], fixes: list[str], task_id: str = "easy") -> dict:
          r = self.session.post(
              f"{self.base_url}/step",
+             json={"action": {"issues": issues, "fixes": fixes, "task_id": task_id}},
              timeout=30,
          )
          r.raise_for_status()
          return r.json()


  # ---------------------------------------------------------------------------
+ # LLM Prompts
  # ---------------------------------------------------------------------------

+ IDENTIFY_SYSTEM_PROMPT = """You are a data quality analyst. Your job is to inspect datasets and identify data quality issues.

  You will be given:
  1. A dataset in CSV format
 
  - Row numbers refer to the ROW POSITION in the CSV data, NOT the value of any ID column
  - Row 1 = the FIRST data row after the header
  - Row 2 = the SECOND data row after the header
  - DO NOT use the employee_id, order_id, or experiment_id as the row number
  - Column names must match exactly (use the CSV header names, lowercase)
  - Check EVERY row and EVERY column systematically
 
  Example: row:3,col:salary,issue:missing_value"""


+ FIX_SYSTEM_PROMPT = """You are a data repair specialist. You have already identified data quality issues in a dataset. Now you must propose the correct values to fix each issue.
+
+ For each issue you identified, propose a fix in EXACTLY this format:
+ row:<row_number>,col:<column_name>,fix:<corrected_value>
+
+ Guidelines for proposing fixes:
+ - For missing_value: infer the correct value from context, schema, and other rows
+ - For wrong_type: convert to the correct type (e.g., "seventy-five thousand" → "75000")
+ - For out_of_range: propose a value within the valid range that makes sense in context
+ - For format_violation: correct the format (e.g., "26/01/2024" → "2024-01-26")
+ - For inconsistent_value: compute the correct value from related fields
+ - For duplicate_row: propose a corrected unique key or indicate removal
+ - For statistical_outlier: propose a reasonable value given the model/context
+
+ Use the schema, validation rules, and surrounding data to determine the correct fix.
+ Respond with ONLY the list of fixes, one per line. No other text.
+ Example: row:3,col:salary,fix:75000"""
+
+
+ def build_user_prompt(observation: dict, include_fixes: bool = False) -> str:
      obs = observation if isinstance(observation, dict) else observation
      parts = []

 
      if feedback and "reset" not in feedback.lower():
          parts.append(f"FEEDBACK FROM PREVIOUS ATTEMPT:\n{feedback}")

+     if include_fixes:
+         parts.append(
+             "Now propose fixes for ALL issues. "
+             "Use format: row:<N>,col:<name>,fix:<corrected_value>"
+         )
+
      return "\n\n".join(parts)


 
          line = line.strip()
          if not line:
              continue
          line = re.sub(r"^\s*[\d]+[.\)]\s*", "", line)
          line = re.sub(r"^\s*[-*]\s*", "", line)
          line = line.strip()
          if "row" in line.lower() and "col" in line.lower():
              match = re.search(
                  r"row\s*[:=]\s*(\d+)\s*[,;\s]+col(?:umn)?\s*[:=]\s*([\w_]+)\s*[,;\s]+issue\s*[:=]\s*([\w_]+)",
                  line,
                  re.IGNORECASE,
              )
              if match:
                  normalized = f"row:{match.group(1)},col:{match.group(2).lower()},issue:{match.group(3).lower()}"
                  issues.append(normalized)
      return issues


+ def parse_fix_response(response: str) -> list[str]:
+     """Extract fix lines from LLM response."""
+     fixes = []
+     for line in response.strip().split("\n"):
+         line = line.strip()
+         if not line:
+             continue
+         line = re.sub(r"^\s*[\d]+[.\)]\s*", "", line)
+         line = re.sub(r"^\s*[-*]\s*", "", line)
+         line = line.strip()
+         if "row" in line.lower() and "fix" in line.lower():
+             match = re.search(
+                 r"row\s*[:=]\s*(\d+)\s*[,;\s]+col(?:umn)?\s*[:=]\s*([\w_]+)\s*[,;\s]+fix\s*[:=]\s*(.+?)$",
+                 line,
+                 re.IGNORECASE,
+             )
+             if match:
+                 normalized = f"row:{match.group(1)},col:{match.group(2).lower()},fix:{match.group(3).strip()}"
+                 fixes.append(normalized)
+     return fixes


+ def call_llm(client: OpenAI, system_prompt: str, user_prompt: str) -> str:
+     """Call the LLM with retry on rate limit."""
+     for attempt in range(3):
+         try:
+             response = client.chat.completions.create(
+                 model=MODEL_NAME,
+                 messages=[
+                     {"role": "system", "content": system_prompt},
+                     {"role": "user", "content": user_prompt},
+                 ],
+                 temperature=0.1,
+                 max_tokens=2048,
+             )
+             return response.choices[0].message.content or ""
+         except Exception as e:
+             if "rate_limit" in str(e).lower() or "429" in str(e):
+                 wait = 10 * (attempt + 1)
+                 print(f"[DEBUG] Rate limited, waiting {wait}s...", file=sys.stderr, flush=True)
+                 time.sleep(wait)
+             else:
+                 print(f"[DEBUG] LLM call failed: {e}", file=sys.stderr, flush=True)
+                 return ""
+     return ""
+
+
+ def run_task(client: OpenAI, env: EnvHTTPClient, task_id: str) -> float:
+     """
+     Run a single task with two-phase strategy:
+     Step 1: Identify issues only
+     Step 2: Identify + Fix (using feedback from step 1)
+     Step 3: Refined identify + fix (if needed)
+     """
+     log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
+
+     rewards: List[float] = []
+     steps_taken = 0
      best_score = 0.0
+     success = False
+
+     try:
+         reset_response = env.reset(task_id=task_id)
+         observation = reset_response.get("observation", reset_response)
+
+         last_issues: list[str] = []
+         last_llm_output = ""
+
+         for step_num in range(1, MAX_STEPS_PER_TASK + 1):
+             error_msg = None
+
+             # ── Phase 1: Identify issues ──
+             user_prompt = build_user_prompt(observation)
+             identify_output = call_llm(client, IDENTIFY_SYSTEM_PROMPT, user_prompt)
+             issues = parse_llm_response(identify_output)
+
+             if not issues and not error_msg:
+                 error_msg = "no issues parsed from LLM response"
+
+             # ── Phase 2: Propose fixes (from step 2 onward, or always if we have issues) ──
+             fixes: list[str] = []
+             if issues and step_num >= 2:
+                 # Build a fix prompt that includes the identified issues
+                 fix_prompt = build_user_prompt(observation, include_fixes=True)
+                 fix_prompt += f"\n\nISSUES FOUND:\n" + "\n".join(issues)
+                 fix_output = call_llm(client, FIX_SYSTEM_PROMPT, fix_prompt)
+                 fixes = parse_fix_response(fix_output)
+
+             # ── Submit to environment ──
+             action_str = ";".join(issues[:5]) if issues else "none"
+             if fixes:
+                 action_str += "|fixes:" + ";".join(fixes[:3])
+
+             step_response = env.step(issues, fixes, task_id=task_id)
+             observation = step_response.get("observation", step_response)
+
+             reward = float(step_response.get("reward", 0.0) or 0.0)
+             done = bool(step_response.get("done", False))
+             best_score = max(best_score, reward)
+             rewards.append(reward)
+             steps_taken = step_num
+
+             log_step(
+                 step=step_num,
+                 action=action_str,
+                 reward=reward,
+                 done=done,
+                 error=error_msg,
+             )
+
+             if done:
                  break
+
+             last_issues = issues
+             last_llm_output = identify_output
+
+         success = best_score >= 0.5
+
+     finally:
+         log_end(success=success, steps=steps_taken, score=best_score, rewards=rewards)
+
      return best_score


 
  # ---------------------------------------------------------------------------

  def main():
+     print(f"[DEBUG] DataQA Inference starting", file=sys.stderr, flush=True)
+     print(f"[DEBUG] ENV_URL={ENV_URL}", file=sys.stderr, flush=True)
+     print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", file=sys.stderr, flush=True)
+     print(f"[DEBUG] MODEL_NAME={MODEL_NAME}", file=sys.stderr, flush=True)

      env = EnvHTTPClient(ENV_URL)
      llm_client = OpenAI(
          base_url=API_BASE_URL,
+         api_key=API_KEY or "no-key",
      )

      if not env.health():
+         print("[DEBUG] Environment is not healthy. Exiting.", file=sys.stderr, flush=True)
          sys.exit(1)

+     print(f"[DEBUG] Environment is healthy", file=sys.stderr, flush=True)

      scores = {}
      for task_id in TASKS:
          try:
              score = run_task(llm_client, env, task_id)
              scores[task_id] = score
          except Exception as e:
+             print(f"[DEBUG] Task {task_id} failed: {e}", file=sys.stderr, flush=True)
              scores[task_id] = 0.0

      avg_score = sum(scores.values()) / len(scores) if scores else 0.0
+     print(f"\n[DEBUG] FINAL RESULTS: {scores} avg={avg_score:.3f}", file=sys.stderr, flush=True)


  if __name__ == "__main__":
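
The `[START]`, `[STEP]` and `[END]` lines that the new `inference.py` prints are plain `key=value` records. A minimal sketch of how a consumer might read them back, assuming the fields always appear in the order given in the script's docstring; `parse_log_line` is a hypothetical helper and is not part of this commit or of OpenEnv:

```python
import re
from typing import Dict, Optional

# Assumption: field order per tag is copied from the docstring of inference.py.
_FIELDS = {
    "START": ["task", "env", "model"],
    "STEP": ["step", "action", "reward", "done", "error"],
    "END": ["success", "steps", "score", "rewards"],
}


def parse_log_line(line: str) -> Optional[Dict[str, str]]:
    """Split one structured stdout line into a dict of raw string fields."""
    m = re.match(r"^\[(START|STEP|END)\]\s+(.*)$", line.strip())
    if not m:
        return None  # not one of the three structured line types
    tag, rest = m.group(1), m.group(2)
    names = _FIELDS[tag]
    # Anchor on the known key order so a value such as `action` may itself
    # contain ':' or ',' characters without confusing the split.
    pattern = r"\s+".join(rf"{re.escape(n)}=(?P<{n}>.*?)" for n in names) + r"$"
    fm = re.match(pattern, rest)
    if not fm:
        return None
    record = {"tag": tag}
    record.update(fm.groupdict())
    return record
```

Values are left as strings on purpose; converting `reward` or `rewards` to floats is a separate step for the caller.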
models.py DELETED
@@ -1,4 +0,0 @@
- """Root-level models for OpenEnv compatibility."""
- from dataqa_env.models import DataQAAction, DataQAObservation, DataQAState
-
- __all__ = ["DataQAAction", "DataQAObservation", "DataQAState"]
openenv.yaml CHANGED
@@ -3,4 +3,4 @@ name: dataqa_env
  type: space
  runtime: fastapi
  app: dataqa_env.server.app:app
- port: 8000
+ port: 7860
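
This commit moves the server from port 8000 to 7860 in several files (`Dockerfile`, `README.md`, `openenv.yaml`). A minimal sketch of a consistency check over file contents; `find_ports` and `ports_consistent` are hypothetical helpers, and the regular expressions only cover the patterns visible in this diff:

```python
import re
from typing import Dict, Set


def find_ports(files: Dict[str, str]) -> Dict[str, Set[int]]:
    """Map each file name to the set of port numbers it declares.

    Recognised patterns (assumed from this diff only): `port: N` and
    `app_port: N` in YAML, `EXPOSE N`, `--port N` or `"--port", "N"` in a
    Dockerfile CMD, and `localhost:N` in a URL.
    """
    patterns = [
        r"^\s*(?:app_)?port:\s*(\d+)\s*$",
        r"^\s*EXPOSE\s+(\d+)\s*$",
        r"--port\"?[,\s]+\"?(\d+)",
        r"localhost:(\d+)",
    ]
    found: Dict[str, Set[int]] = {}
    for name, text in files.items():
        ports: Set[int] = set()
        for pat in patterns:
            ports.update(int(p) for p in re.findall(pat, text, flags=re.MULTILINE))
        found[name] = ports
    return found


def ports_consistent(files: Dict[str, str]) -> bool:
    """True when all files together declare exactly one port."""
    all_ports = set().union(*find_ports(files).values())
    return len(all_ports) == 1
```

Run over `inference.py` as well, this check would flag the `http://localhost:8000` default of `ENV_URL`, which the commit leaves unchanged; the diff does not say whether that default is intended for local runs.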
scripts/prevalidation_script.sh ADDED
@@ -0,0 +1,185 @@
+ #!/usr/bin/env bash
+ #
+ # validate-submission.sh — OpenEnv Submission Validator
+ #
+ # Checks that your HF Space is live, Docker image builds, and openenv validate passes.
+ #
+ # Prerequisites:
+ # - Docker: https://docs.docker.com/get-docker/
+ # - openenv-core: pip install openenv-core
+ # - curl (usually pre-installed)
+ #
+ # Run:
+ # curl -fsSL https://raw.githubusercontent.com/<owner>/<repo>/main/scripts/validate-submission.sh | bash -s -- <ping_url> [repo_dir]
+ #
+ # Or download and run locally:
+ # chmod +x validate-submission.sh
+ # ./validate-submission.sh <ping_url> [repo_dir]
+ #
+ # Arguments:
+ # ping_url Your HuggingFace Space URL (e.g. https://your-space.hf.space)
+ # repo_dir Path to your repo (default: current directory)
+ #
+ # Examples:
+ # ./validate-submission.sh https://my-team.hf.space
+ # ./validate-submission.sh https://my-team.hf.space ./my-repo
+ #
+
+ set -uo pipefail
+
+ DOCKER_BUILD_TIMEOUT=600
+ if [ -t 1 ]; then
+   RED='\033[0;31m'
+   GREEN='\033[0;32m'
+   YELLOW='\033[1;33m'
+   BOLD='\033[1m'
+   NC='\033[0m'
+ else
+   RED='' GREEN='' YELLOW='' BOLD='' NC=''
+ fi
+
+ run_with_timeout() {
+   local secs="$1"; shift
+   if command -v timeout &>/dev/null; then
+     timeout "$secs" "$@"
+   elif command -v gtimeout &>/dev/null; then
+     gtimeout "$secs" "$@"
+   else
+     "$@" &
+     local pid=$!
+     ( sleep "$secs" && kill "$pid" 2>/dev/null ) &
+     local watcher=$!
+     wait "$pid" 2>/dev/null
+     local rc=$?
+     kill "$watcher" 2>/dev/null
+     wait "$watcher" 2>/dev/null
+     return $rc
+   fi
+ }
+
+ portable_mktemp() {
+   local prefix="${1:-validate}"
+   mktemp "${TMPDIR:-/tmp}/${prefix}-XXXXXX" 2>/dev/null || mktemp
+ }
+
+ CLEANUP_FILES=()
+ cleanup() { rm -f "${CLEANUP_FILES[@]+"${CLEANUP_FILES[@]}"}"; }
+ trap cleanup EXIT
+
+ PING_URL="${1:-}"
+ REPO_DIR="${2:-.}"
+
+ if [ -z "$PING_URL" ]; then
+   printf "Usage: %s <ping_url> [repo_dir]\n" "$0"
+   printf "\n"
+   printf " ping_url Your HuggingFace Space URL (e.g. https://your-space.hf.space)\n"
+   printf " repo_dir Path to your repo (default: current directory)\n"
+   exit 1
+ fi
+
+ if ! REPO_DIR="$(cd "$REPO_DIR" 2>/dev/null && pwd)"; then
+   printf "Error: directory '%s' not found\n" "${2:-.}"
+   exit 1
+ fi
+ PING_URL="${PING_URL%/}"
+ export PING_URL
+ PASS=0
+
+ log() { printf "[%s] %b\n" "$(date -u +%H:%M:%S)" "$*"; }
+ pass() { log "${GREEN}PASSED${NC} -- $1"; PASS=$((PASS + 1)); }
+ fail() { log "${RED}FAILED${NC} -- $1"; }
+ hint() { printf " ${YELLOW}Hint:${NC} %b\n" "$1"; }
+ stop_at() {
+   printf "\n"
+   printf "${RED}${BOLD}Validation stopped at %s.${NC} Fix the above before continuing.\n" "$1"
+   exit 1
+ }
+
+ printf "\n"
+ printf "${BOLD}========================================${NC}\n"
+ printf "${BOLD} OpenEnv Submission Validator${NC}\n"
+ printf "${BOLD}========================================${NC}\n"
+ log "Repo: $REPO_DIR"
+ log "Ping URL: $PING_URL"
+ printf "\n"
+
+ log "${BOLD}Step 1/3: Pinging HF Space${NC} ($PING_URL/reset) ..."
+
+ CURL_OUTPUT=$(portable_mktemp "validate-curl")
+ CLEANUP_FILES+=("$CURL_OUTPUT")
+ HTTP_CODE=$(curl -s -o "$CURL_OUTPUT" -w "%{http_code}" -X POST \
+   -H "Content-Type: application/json" -d '{}' \
+   "$PING_URL/reset" --max-time 30 2>"$CURL_OUTPUT" || printf "000")
+
+ if [ "$HTTP_CODE" = "200" ]; then
+   pass "HF Space is live and responds to /reset"
+ elif [ "$HTTP_CODE" = "000" ]; then
+   fail "HF Space not reachable (connection failed or timed out)"
+   hint "Check your network connection and that the Space is running."
+   hint "Try: curl -s -o /dev/null -w '%%{http_code}' -X POST $PING_URL/reset"
+   stop_at "Step 1"
+ else
+   fail "HF Space /reset returned HTTP $HTTP_CODE (expected 200)"
+   hint "Make sure your Space is running and the URL is correct."
+   hint "Try opening $PING_URL in your browser first."
+   stop_at "Step 1"
+ fi
+
+ log "${BOLD}Step 2/3: Running docker build${NC} ..."
+
+ if ! command -v docker &>/dev/null; then
+   fail "docker command not found"
+   hint "Install Docker: https://docs.docker.com/get-docker/"
+   stop_at "Step 2"
+ fi
+
+ if [ -f "$REPO_DIR/Dockerfile" ]; then
+   DOCKER_CONTEXT="$REPO_DIR"
+ elif [ -f "$REPO_DIR/server/Dockerfile" ]; then
+   DOCKER_CONTEXT="$REPO_DIR/server"
+ else
+   fail "No Dockerfile found in repo root or server/ directory"
+   stop_at "Step 2"
+ fi
+
+ log " Found Dockerfile in $DOCKER_CONTEXT"
+
+ BUILD_OK=false
+ BUILD_OUTPUT=$(run_with_timeout "$DOCKER_BUILD_TIMEOUT" docker build "$DOCKER_CONTEXT" 2>&1) && BUILD_OK=true
+
+ if [ "$BUILD_OK" = true ]; then
+   pass "Docker build succeeded"
+ else
+   fail "Docker build failed (timeout=${DOCKER_BUILD_TIMEOUT}s)"
+   printf "%s\n" "$BUILD_OUTPUT" | tail -20
+   stop_at "Step 2"
+ fi
+
+ log "${BOLD}Step 3/3: Running openenv validate${NC} ..."
+
+ if ! command -v openenv &>/dev/null; then
+   fail "openenv command not found"
+   hint "Install it: pip install openenv-core"
+   stop_at "Step 3"
+ fi
+
+ VALIDATE_OK=false
+ VALIDATE_OUTPUT=$(cd "$REPO_DIR" && openenv validate 2>&1) && VALIDATE_OK=true
+
+ if [ "$VALIDATE_OK" = true ]; then
+   pass "openenv validate passed"
+   [ -n "$VALIDATE_OUTPUT" ] && log " $VALIDATE_OUTPUT"
+ else
+   fail "openenv validate failed"
+   printf "%s\n" "$VALIDATE_OUTPUT"
+   stop_at "Step 3"
+ fi
+
+ printf "\n"
+ printf "${BOLD}========================================${NC}\n"
+ printf "${GREEN}${BOLD} All 3/3 checks passed!${NC}\n"
+ printf "${GREEN}${BOLD} Your submission is ready to submit.${NC}\n"
+ printf "${BOLD}========================================${NC}\n"
+ printf "\n"
+
+ exit 0
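
Step 1 of the script above posts an empty JSON body to `/reset` and branches on the HTTP status. The same check can be sketched in Python with only the standard library; `ping_reset` and `classify_ping` are hypothetical names, and status `0` stands in for curl's `000` (connection failed or timed out):

```python
import json
import urllib.error
import urllib.request


def ping_reset(ping_url: str, timeout: float = 30.0) -> int:
    """POST '{}' to <ping_url>/reset and return the HTTP status (0 if unreachable)."""
    req = urllib.request.Request(
        ping_url.rstrip("/") + "/reset",
        data=json.dumps({}).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            return resp.status
    except urllib.error.HTTPError as exc:
        return exc.code  # the server answered, but with an error status
    except (urllib.error.URLError, OSError):
        return 0  # connection failed or timed out


def classify_ping(http_code: int) -> str:
    """Mirror the three branches of Step 1 in the shell script."""
    if http_code == 200:
        return "PASSED -- HF Space is live and responds to /reset"
    if http_code == 0:
        return "FAILED -- HF Space not reachable (connection failed or timed out)"
    return f"FAILED -- HF Space /reset returned HTTP {http_code} (expected 200)"
```

Keeping the network call and the classification separate makes the branching logic checkable without a live Space, for example `classify_ping(ping_reset("https://your-space.hf.space"))`.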
scripts/sample_inference_script.py ADDED
@@ -0,0 +1,188 @@
+ """
+ Inference Script Example
+ ===================================
+ MANDATORY
+ - Before submitting, ensure the following variables are defined in your environment configuration:
+ API_BASE_URL The API endpoint for the LLM.
+ MODEL_NAME The model identifier to use for inference.
+ HF_TOKEN Your Hugging Face / API key.
+ LOCAL_IMAGE_NAME The name of the local image to use for the environment if you are using from_docker_image()
+ method
+
+ - Defaults are set only for API_BASE_URL and MODEL_NAME
+ (and should reflect your active inference setup):
+ API_BASE_URL = os.getenv("API_BASE_URL", "<your-active-endpoint>")
+ MODEL_NAME = os.getenv("MODEL_NAME", "<your-active-model>")
+
+ - The inference script must be named `inference.py` and placed in the root directory of the project
+ - Participants must use OpenAI Client for all LLM calls using above variables
+
+ STDOUT FORMAT
+ - The script must emit exactly three line types to stdout, in this order:
+
+ [START] task=<task_name> env=<benchmark> model=<model_name>
+ [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
+ [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
+
+ Rules:
+ - One [START] line at episode begin.
+ - One [STEP] line per step, immediately after env.step() returns.
+ - One [END] line after env.close(), always emitted (even on exception).
+ - reward and rewards are formatted to 2 decimal places.
+ - done and success are lowercase booleans: true or false.
+ - error is the raw last_action_error string, or null if none.
+ - All fields on a single line with no newlines within a line.
+ - Each task should return a score in [0, 1]
+
+ Example:
+ [START] task=click-test env=miniwob model=Qwen3-VL-30B
+ [STEP] step=1 action=click('123') reward=0.00 done=false error=null
+ [STEP] step=2 action=fill('456','text') reward=0.00 done=false error=null
+ [STEP] step=3 action=click('789') reward=1.00 done=true error=null
+ [END] success=true steps=3 score=1.00 rewards=0.00,0.00,1.00
+ """
+
+ import asyncio
+ import os
+ import textwrap
+ from typing import List, Optional
+
+ from openai import OpenAI
+
+ from my_env_v4 import MyEnvV4Action, MyEnvV4Env
+ IMAGE_NAME = os.getenv("IMAGE_NAME")  # If you are using docker image
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
+
+ API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
+ MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
+ TASK_NAME = os.getenv("MY_ENV_V4_TASK", "echo")
+ BENCHMARK = os.getenv("MY_ENV_V4_BENCHMARK", "my_env_v4")
+ MAX_STEPS = 8
+ TEMPERATURE = 0.7
+ MAX_TOKENS = 150
+ SUCCESS_SCORE_THRESHOLD = 0.1  # normalized score in [0, 1]
+
+ # Max possible reward: each token contributes 0.1, across all steps
+ _MAX_REWARD_PER_STEP = MAX_TOKENS * 0.1
+ MAX_TOTAL_REWARD = MAX_STEPS * _MAX_REWARD_PER_STEP
+
+ SYSTEM_PROMPT = textwrap.dedent(
+     """
+     You are interacting with a simple echo environment.
+     Each turn you must send a message. The environment will echo it back.
+     Reward is proportional to message length: reward = len(message) * 0.1
+     Your goal is to maximize total reward by sending meaningful, substantive messages.
+     Reply with exactly one message string — no quotes, no prefixes, just the message text.
+     """
+ ).strip()
+
+
+ def log_start(task: str, env: str, model: str) -> None:
+     print(f"[START] task={task} env={env} model={model}", flush=True)
+
+
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
+     error_val = error if error else "null"
+     done_val = str(done).lower()
+     print(
+         f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
+         flush=True,
+     )
+
+
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
+     rewards_str = ",".join(f"{r:.2f}" for r in rewards)
+     print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
+
+
+ def build_user_prompt(step: int, last_echoed: str, last_reward: float, history: List[str]) -> str:
+     history_block = "\n".join(history[-4:]) if history else "None"
+     return textwrap.dedent(
+         f"""
+         Step: {step}
+         Last echoed message: {last_echoed!r}
+         Last reward: {last_reward:.2f}
+         Previous steps:
+         {history_block}
+         Send your next message.
+         """
+     ).strip()
+
+
+ def get_model_message(client: OpenAI, step: int, last_echoed: str, last_reward: float, history: List[str]) -> str:
+     user_prompt = build_user_prompt(step, last_echoed, last_reward, history)
+     try:
+         completion = client.chat.completions.create(
+             model=MODEL_NAME,
+             messages=[
+                 {"role": "system", "content": SYSTEM_PROMPT},
+                 {"role": "user", "content": user_prompt},
+             ],
+             temperature=TEMPERATURE,
+             max_tokens=MAX_TOKENS,
+             stream=False,
+         )
+         text = (completion.choices[0].message.content or "").strip()
+         return text if text else "hello"
+     except Exception as exc:
+         print(f"[DEBUG] Model request failed: {exc}", flush=True)
+         return "hello"
+
+
+ async def main() -> None:
+     client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
+
+     env = await MyEnvV4Env.from_docker_image(IMAGE_NAME)
+
+     history: List[str] = []
+     rewards: List[float] = []
+     steps_taken = 0
+     score = 0.0
+     success = False
+
+     log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
+
+     try:
+         result = await env.reset()  # OpenENV.reset()
+         last_echoed = result.observation.echoed_message
+         last_reward = 0.0
+
+         for step in range(1, MAX_STEPS + 1):
+             if result.done:
+                 break
+
+             message = get_model_message(client, step, last_echoed, last_reward, history)
+
+             result = await env.step(MyEnvV4Action(message=message))
+             obs = result.observation
+
+             reward = result.reward or 0.0
+             done = result.done
+             error = None
+
+             rewards.append(reward)
+             steps_taken = step
+             last_echoed = obs.echoed_message
+             last_reward = reward
+
+             log_step(step=step, action=message, reward=reward, done=done, error=error)
+
+             history.append(f"Step {step}: {message!r} -> reward {reward:+.2f}")
+
+             if done:
+                 break
+
+         score = sum(rewards) / MAX_TOTAL_REWARD if MAX_TOTAL_REWARD > 0 else 0.0
+         score = min(max(score, 0.0), 1.0)  # clamp to [0, 1]
+         success = score >= SUCCESS_SCORE_THRESHOLD
+
+     finally:
+         try:
+             await env.close()
+         except Exception as e:
+             print(f"[DEBUG] env.close() error (container cleanup): {e}", flush=True)
+         log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
185
+
186
+
187
+ if __name__ == "__main__":
188
+ asyncio.run(main())
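The end of `main()` above normalizes the summed rewards by the maximum attainable total and clamps the result to [0, 1]. As a minimal standalone sketch of that step (the helper name `normalize_score` is hypothetical and does not appear in the commit):

```python
from typing import List


def normalize_score(rewards: List[float], max_total_reward: float) -> float:
    """Scale summed rewards into [0, 1].

    Hypothetical helper mirroring the two scoring lines in main();
    returns 0.0 when the cap is zero or negative to avoid dividing by zero.
    """
    if max_total_reward <= 0:
        return 0.0
    raw = sum(rewards) / max_total_reward
    # Clamp so rewards above the cap or below zero cannot leave [0, 1].
    return min(max(raw, 0.0), 1.0)


if __name__ == "__main__":
    print(normalize_score([5.0, 5.0], 20.0))
```

Keeping the clamp in one function makes the boundary cases (empty reward list, rewards above the cap) easy to unit-test apart from the async episode loop.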
server/__init__.py CHANGED
@@ -0,0 +1 @@
1
+ """Root-level server package — delegates to dataqa_env.server."""
server/app.py CHANGED
@@ -1,13 +1,12 @@
- """
- Root-level server entry point for OpenEnv compatibility.
- """
+ """Entrypoint for openenv-core deployment. Delegates to dataqa_env.server.app."""
 
  from dataqa_env.server.app import app  # noqa: F401
 
 
  def main():
+     """Start the environment server."""
      import uvicorn
-     uvicorn.run(app, host="0.0.0.0", port=8000)
+     uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
  if __name__ == "__main__":
tests/__init__.py ADDED
File without changes
tests/test_environment.py ADDED
@@ -0,0 +1,455 @@
1
+ """Tests for the DataQA environment (reset, step, scoring, two-phase identify+fix)."""
2
+
3
+ import pytest
4
+ from dataqa_env.server.environment import (
5
+ DataQAEnvironment,
6
+ parse_issue_key,
7
+ parse_fix,
8
+ compute_f1,
9
+ compute_weighted_reward,
10
+ grade_fixes,
11
+ IDENTIFY_WEIGHT,
12
+ FIX_WEIGHT,
13
+ )
14
+ from dataqa_env.models import DataQAAction
15
+ from dataqa_env.server.tasks import PlantedIssue, create_task_easy, create_task_medium
16
+
17
+
18
+ # ──────────────────────────────────────────────────────
19
+ # Issue parsing
20
+ # ──────────────────────────────────────────────────────
21
+
22
+ class TestParseIssueKey:
23
+ def test_standard_format(self):
24
+ assert parse_issue_key("row:3,col:salary,issue:missing_value") == "row:3,col:salary,issue:missing_value"
25
+
26
+ def test_with_equals(self):
27
+ assert parse_issue_key("row=3,col=salary,issue=missing_value") == "row:3,col:salary,issue:missing_value"
28
+
29
+ def test_case_insensitive(self):
30
+ assert parse_issue_key("Row:3,Col:Salary,Issue:Missing_Value") == "row:3,col:salary,issue:missing_value"
31
+
32
+ def test_with_spaces(self):
33
+ assert parse_issue_key("row: 3, col: salary, issue: missing_value") == "row:3,col:salary,issue:missing_value"
34
+
35
+ def test_unparseable(self):
36
+ assert parse_issue_key("this is garbage") is None
37
+
38
+ def test_partial_match(self):
39
+ assert parse_issue_key("row:3,col:salary") is None
40
+
41
+ def test_empty_string(self):
42
+ assert parse_issue_key("") is None
43
+
44
+ def test_semicolon_separator(self):
45
+ result = parse_issue_key("row:3;col:salary;issue:missing_value")
46
+ assert result == "row:3,col:salary,issue:missing_value"
47
+
48
+
49
+ # ──────────────────────────────────────────────────────
50
+ # Fix parsing
51
+ # ──────────────────────────────────────────────────────
52
+
53
+ class TestParseFix:
54
+ def test_standard_format(self):
55
+ result = parse_fix("row:4,col:name,fix:Alice Chen")
56
+ assert result == (4, "name", "Alice Chen")
57
+
58
+ def test_with_equals(self):
59
+ result = parse_fix("row=4,col=name,fix=Alice Chen")
60
+ assert result == (4, "name", "Alice Chen")
61
+
62
+ def test_numeric_fix(self):
63
+ result = parse_fix("row:7,col:salary,fix:75000")
64
+ assert result == (7, "salary", "75000")
65
+
66
+ def test_date_fix(self):
67
+ result = parse_fix("row:12,col:order_date,fix:2024-01-26")
68
+ assert result == (12, "order_date", "2024-01-26")
69
+
70
+ def test_case_insensitive(self):
71
+ result = parse_fix("Row:4,Col:Name,Fix:Alice Chen")
72
+ assert result == (4, "name", "Alice Chen")
73
+
74
+ def test_unparseable(self):
75
+ assert parse_fix("garbage") is None
76
+ assert parse_fix("row:4,col:name") is None
77
+
78
+ def test_fix_with_special_chars(self):
79
+ result = parse_fix("row:1,col:email,fix:alice.chen@company.com")
80
+ assert result == (1, "email", "alice.chen@company.com")
81
+
82
+
83
+ # ──────────────────────────────────────────────────────
84
+ # F1 scoring
85
+ # ──────────────────────────────────────────────────────
86
+
87
+ class TestComputeF1:
88
+ def test_perfect_match(self):
89
+ keys = {"row:1,col:a,issue:missing_value"}
90
+ result = compute_f1(keys, keys)
91
+ assert result["f1"] == 1.0
92
+
93
+ def test_no_reported_no_planted(self):
94
+ result = compute_f1(set(), set())
95
+ assert result["f1"] == 1.0
96
+
97
+ def test_no_reported_some_planted(self):
98
+ planted = {"row:1,col:a,issue:missing_value"}
99
+ result = compute_f1(set(), planted)
100
+ assert result["f1"] == 0.0
101
+ assert result["fn"] == 1
102
+
103
+ def test_all_false_positives(self):
104
+ reported = {"row:99,col:x,issue:wrong_type"}
105
+ planted = {"row:1,col:a,issue:missing_value"}
106
+ result = compute_f1(reported, planted)
107
+ assert result["f1"] == 0.0
108
+
109
+ def test_partial_match(self):
110
+ reported = {"row:1,col:a,issue:missing_value", "row:2,col:b,issue:wrong_type"}
111
+ planted = {"row:1,col:a,issue:missing_value", "row:3,col:c,issue:duplicate_row"}
112
+ result = compute_f1(reported, planted)
113
+ assert result["tp"] == 1
114
+ assert result["fp"] == 1
115
+ assert result["fn"] == 1
116
+ assert 0 < result["f1"] < 1
117
+
118
+ def test_precision_recall_calculation(self):
119
+ reported = {"a", "b", "c"}
120
+ planted = {"a", "b", "d"}
121
+ result = compute_f1(reported, planted)
122
+ assert result["precision"] == pytest.approx(2 / 3)
123
+ assert result["recall"] == pytest.approx(2 / 3)
124
+
125
+
126
+ # ──────────────────────────────────────────────────────
127
+ # Weighted reward
128
+ # ──────────────────────────────────────────────────────
129
+
130
+ class TestComputeWeightedReward:
131
+ def test_perfect_match(self):
132
+ issues = [
133
+ PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=1.0),
134
+ PlantedIssue(row=2, col="b", issue_type="wrong_type", description="", difficulty=3.0),
135
+ ]
136
+ reported = {i.to_key() for i in issues}
137
+ result = compute_weighted_reward(reported, issues)
138
+ assert result["weighted_reward"] == 1.0
139
+
140
+ def test_empty_both(self):
141
+ result = compute_weighted_reward(set(), [])
142
+ assert result["weighted_reward"] == 1.0
143
+
144
+ def test_no_reported(self):
145
+ issues = [PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=2.0)]
146
+ result = compute_weighted_reward(set(), issues)
147
+ assert result["weighted_reward"] == 0.0
148
+
149
+ def test_hard_issue_worth_more(self):
150
+ easy = PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=1.0)
151
+ hard = PlantedIssue(row=2, col="b", issue_type="statistical_outlier", description="", difficulty=3.0)
152
+ issues = [easy, hard]
153
+ hard_found = compute_weighted_reward({hard.to_key()}, issues)
154
+ easy_found = compute_weighted_reward({easy.to_key()}, issues)
155
+ assert hard_found["weighted_reward"] > easy_found["weighted_reward"]
156
+
157
+ def test_false_positives_reduce_reward(self):
158
+ issues = [PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=1.0)]
159
+ correct = {issues[0].to_key()}
160
+ with_fp = correct | {"row:99,col:x,issue:wrong_type"}
161
+ r_correct = compute_weighted_reward(correct, issues)
162
+ r_with_fp = compute_weighted_reward(with_fp, issues)
163
+ assert r_correct["weighted_reward"] > r_with_fp["weighted_reward"]
164
+
165
+
166
+ # ──────────────────────────────────────────────────────
167
+ # Fix grading
168
+ # ──────────────────────────────────────────────────────
169
+
170
+ class TestGradeFixes:
171
+ @pytest.fixture
172
+ def easy_task(self):
173
+ return create_task_easy()
174
+
175
+ def test_no_fixes_no_issues(self):
176
+ from dataqa_env.server.tasks import Task
177
+ task = Task(task_id="empty", name="", description="", schema_description="",
178
+ validation_rules="", clean_csv="a\n1")
179
+ result = grade_fixes([], task)
180
+ assert result["fix_score"] == 1.0
181
+
182
+ def test_no_fixes_submitted(self, easy_task):
183
+ result = grade_fixes([], easy_task)
184
+ assert result["fix_score"] == 0.0
185
+ assert result["fixes_attempted"] == 0
186
+
187
+ def test_exact_fix_for_missing_name(self, easy_task):
188
+ # Row 4 has empty name — clean value is "David Kim"
189
+ fixes = [(4, "name", "David Kim")]
190
+ result = grade_fixes(fixes, easy_task)
191
+ assert result["fix_score"] > 0.0
192
+ assert result["fixes_correct"] == 1
193
+
194
+ def test_exact_fix_for_wrong_type_salary(self, easy_task):
195
+ # Row 7 has "seventy-five thousand" — clean value is "75000"
196
+ fixes = [(7, "salary", "75000")]
197
+ result = grade_fixes(fixes, easy_task)
198
+ assert result["fixes_correct"] == 1
199
+
200
+ def test_numeric_close_match(self, easy_task):
201
+ # Row 9 has salary "5000" — clean value is "73000"
202
+ # Propose 73100 (within 1% of 73000)
203
+ fixes = [(9, "salary", "73100")]
204
+ result = grade_fixes(fixes, easy_task)
205
+ assert result["fixes_partial"] == 1
206
+
207
+ def test_wrong_value_for_issue_cell(self, easy_task):
208
+ # Row 4 name is empty — propose wrong name
209
+ fixes = [(4, "name", "Wrong Person")]
210
+ result = grade_fixes(fixes, easy_task)
211
+ assert result["fixes_partial"] == 1 # correct cell, wrong value
212
+ assert result["fix_score"] > 0.0 # gets partial credit
213
+
214
+ def test_fix_for_non_issue_cell(self, easy_task):
215
+ # Row 1 col name is fine — no issue there
216
+ fixes = [(1, "name", "Some Name")]
217
+ result = grade_fixes(fixes, easy_task)
218
+ assert result["fixes_wrong"] == 1
219
+ assert result["fix_score"] == 0.0
220
+
221
+ def test_multiple_fixes_best_wins(self, easy_task):
222
+ # Submit two fixes for same cell — best one should count
223
+ fixes = [
224
+ (4, "name", "Wrong Person"), # partial credit
225
+ (4, "name", "David Kim"), # exact match
226
+ ]
227
+ result = grade_fixes(fixes, easy_task)
228
+ assert result["fixes_correct"] >= 1
229
+
230
+ def test_all_fixes_correct(self, easy_task):
231
+ # Fix most issues with exact values
232
+ fixes = [
233
+ (4, "name", "David Kim"),
234
+ (7, "salary", "75000"),
235
+ (9, "salary", "73000"),
236
+ (15, "email", "oscar.rivera@company.com"),
237
+ (18, "start_date", "2022-01-19"),
238
+ ]
239
+ result = grade_fixes(fixes, easy_task)
240
+ assert result["fix_score"] > 0.7 # 5 out of 6 issues fixed (duplicate can't be fixed)
241
+
242
+ def test_fix_score_bounded(self, easy_task):
243
+ fixes = [(4, "name", "David Kim"), (99, "x", "bad")]
244
+ result = grade_fixes(fixes, easy_task)
245
+ assert 0.0 <= result["fix_score"] <= 1.0
246
+
247
+
248
+ # ──────────────────────────────────────────────────────
249
+ # Full environment lifecycle
250
+ # ──────────────────────────────────────────────────────
251
+
252
+ class TestDataQAEnvironment:
253
+ @pytest.fixture
254
+ def env(self):
255
+ return DataQAEnvironment()
256
+
257
+ def test_reset_returns_observation(self, env):
258
+ obs = env.reset(task_id="easy")
259
+ assert obs.dataset_csv
260
+ assert obs.schema_description
261
+ assert obs.validation_rules
262
+ assert obs.task_description
263
+ assert obs.num_issues_hint == 6
264
+ assert obs.max_steps == 3
265
+ assert obs.done is False
266
+ assert obs.reward == 0.0
267
+ assert "fix" in obs.feedback.lower() # mentions fix phase
268
+
269
+ def test_reset_medium(self, env):
270
+ obs = env.reset(task_id="medium")
271
+ assert obs.num_issues_hint == 8
272
+
273
+ def test_reset_hard(self, env):
274
+ obs = env.reset(task_id="hard")
275
+ assert obs.num_issues_hint == 10
276
+
277
+ def test_step_identify_only(self, env):
278
+ """Backward compatible: only issues, no fixes."""
279
+ env.reset(task_id="easy")
280
+ # Submit all 6 correct issues for easy task
281
+ action = DataQAAction(
282
+ issues=[
283
+ "row:4,col:name,issue:missing_value",
284
+ "row:7,col:salary,issue:wrong_type",
285
+ "row:21,col:employee_id,issue:duplicate_row",
286
+ "row:9,col:salary,issue:out_of_range",
287
+ "row:15,col:email,issue:inconsistent_value",
288
+ "row:18,col:start_date,issue:out_of_range",
289
+ ],
290
+ task_id="easy",
291
+ )
292
+ obs = env.step(action)
293
+ assert obs.done is True
294
+ assert obs.reward >= 0.999 # identify-only uses identify_score directly
295
+
296
+ def test_step_with_fixes_increases_reward(self, env):
297
+ """Submitting correct fixes should produce high combined reward."""
298
+ env.reset(task_id="easy")
299
+ # All 6 issues + 3 fixes
300
+ action = DataQAAction(
301
+ issues=[
302
+ "row:4,col:name,issue:missing_value",
303
+ "row:7,col:salary,issue:wrong_type",
304
+ "row:21,col:employee_id,issue:duplicate_row",
305
+ "row:9,col:salary,issue:out_of_range",
306
+ "row:15,col:email,issue:inconsistent_value",
307
+ "row:18,col:start_date,issue:out_of_range",
308
+ ],
309
+ fixes=[
310
+ "row:4,col:name,fix:David Kim",
311
+ "row:7,col:salary,fix:75000",
312
+ "row:9,col:salary,fix:73000",
313
+ ],
314
+ task_id="easy",
315
+ )
316
+ obs = env.step(action)
317
+ # Perfect identify + partial fixes -> high combined reward
318
+ assert obs.metadata["combined_reward"] > 0.7
319
+
320
+ def test_step_with_partial_issues(self, env):
321
+ env.reset(task_id="easy")
322
+ action = DataQAAction(
323
+ issues=["row:4,col:name,issue:missing_value"],
324
+ task_id="easy",
325
+ )
326
+ obs = env.step(action)
327
+ assert 0 < obs.reward < 1.0
328
+ assert obs.done is False
329
+
330
+ def test_step_with_no_issues(self, env):
331
+ env.reset(task_id="easy")
332
+ action = DataQAAction(issues=[], task_id="easy")
333
+ obs = env.step(action)
334
+ assert obs.reward == 0.0
335
+
336
+ def test_step_exhausts_max_steps(self, env):
337
+ env.reset(task_id="easy")
338
+ for _ in range(3):
339
+ action = DataQAAction(issues=["row:99,col:x,issue:wrong_type"], task_id="easy")
340
+ obs = env.step(action)
341
+ assert obs.done is True
342
+
343
+ def test_auto_reset_on_step(self, env):
344
+ action = DataQAAction(
345
+ issues=["row:4,col:name,issue:missing_value"],
346
+ task_id="easy",
347
+ )
348
+ obs = env.step(action)
349
+ assert obs.task_id == "easy"
350
+
351
+ def test_state_tracking(self, env):
352
+ env.reset(task_id="easy")
353
+ assert env.state.task_id == "easy"
354
+ assert env.state.current_step == 0
355
+ assert env.state.best_score == 0.0
356
+
357
+ action = DataQAAction(issues=["row:4,col:name,issue:missing_value"], task_id="easy")
358
+ env.step(action)
359
+ assert env.state.current_step == 1
360
+ assert env.state.best_score > 0.0
361
+
362
+ def test_best_score_monotonic(self, env):
363
+ env.reset(task_id="easy")
364
+ action1 = DataQAAction(
365
+ issues=["row:4,col:name,issue:missing_value", "row:7,col:salary,issue:wrong_type"],
366
+ task_id="easy",
367
+ )
368
+ env.step(action1)
369
+ score_after_1 = env.state.best_score
370
+
371
+ action2 = DataQAAction(issues=["row:99,col:x,issue:wrong_type"], task_id="easy")
372
+ env.step(action2)
373
+ assert env.state.best_score >= score_after_1
374
+
375
+ def test_metadata_includes_both_phases(self, env):
376
+ env.reset(task_id="easy")
377
+ action = DataQAAction(
378
+ issues=["row:4,col:name,issue:missing_value"],
379
+ fixes=["row:4,col:name,fix:David Kim"],
380
+ task_id="easy",
381
+ )
382
+ obs = env.step(action)
383
+ m = obs.metadata
384
+ assert "identify_f1" in m
385
+ assert "identify_score" in m
386
+ assert "fix_score" in m
387
+ assert "combined_reward" in m
388
+ assert "tp" in m
389
+ assert "fixes_correct" in m
390
+ assert "fixes_attempted" in m
391
+
392
+ def test_parse_error_in_feedback(self, env):
393
+ env.reset(task_id="easy")
394
+ action = DataQAAction(issues=["garbage input"], task_id="easy")
395
+ obs = env.step(action)
396
+ assert "Parse error" in obs.feedback
397
+
398
+ def test_concurrent_sessions_flag(self):
399
+ assert DataQAEnvironment.SUPPORTS_CONCURRENT_SESSIONS is True
400
+
401
+ def test_reward_between_0_and_1(self, env):
402
+ """Hackathon requirement: scores must be 0.0-1.0."""
403
+ env.reset(task_id="hard")
404
+ for _ in range(3):
405
+ action = DataQAAction(
406
+ issues=["row:1,col:x,issue:wrong_type", "row:99,col:y,issue:missing_value"],
407
+ fixes=["row:1,col:x,fix:wrong"],
408
+ task_id="hard",
409
+ )
410
+ obs = env.step(action)
411
+ assert 0.0 <= obs.reward <= 1.0
412
+
413
+ def test_combined_reward_weights(self, env):
414
+ """Verify combined = IDENTIFY_WEIGHT * identify + FIX_WEIGHT * fix."""
415
+ env.reset(task_id="easy")
416
+ action = DataQAAction(
417
+ issues=["row:4,col:name,issue:missing_value"],
418
+ fixes=["row:4,col:name,fix:David Kim"],
419
+ task_id="easy",
420
+ )
421
+ obs = env.step(action)
422
+ m = obs.metadata
423
+ expected = IDENTIFY_WEIGHT * m["identify_score"] + FIX_WEIGHT * m["fix_score"]
424
+ assert abs(m["combined_reward"] - expected) < 0.01
425
+
426
+ def test_fix_feedback_shown_when_fixes_submitted(self, env):
427
+ env.reset(task_id="easy")
428
+ action = DataQAAction(
429
+ issues=["row:4,col:name,issue:missing_value"],
430
+ fixes=["row:4,col:name,fix:David Kim"],
431
+ task_id="easy",
432
+ )
433
+ obs = env.step(action)
434
+ assert "Fix Proposals" in obs.feedback
435
+ assert "Combined Reward" in obs.feedback
436
+
437
+ def test_no_fix_penalty_when_no_fixes_submitted(self, env):
438
+ """If agent submits no fixes, reward = identify_score (no penalty)."""
439
+ env.reset(task_id="easy")
440
+ action = DataQAAction(
441
+ issues=[
442
+ "row:4,col:name,issue:missing_value",
443
+ "row:7,col:salary,issue:wrong_type",
444
+ "row:21,col:employee_id,issue:duplicate_row",
445
+ "row:9,col:salary,issue:out_of_range",
446
+ "row:15,col:email,issue:inconsistent_value",
447
+ "row:18,col:start_date,issue:out_of_range",
448
+ ],
449
+ task_id="easy",
450
+ )
451
+ obs = env.step(action)
452
+ # identify_score should be ~1.0 since all 6 issues found
453
+ assert obs.reward >= 0.99
454
+ # combined_reward equals identify_score when no fixes
455
+ assert obs.metadata["combined_reward"] == obs.metadata["identify_score"]
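The `TestComputeF1` cases above pin down the expected behaviour of the set-based scorer: two empty sets score 1.0, and the result exposes `tp`, `fp`, `fn`, `precision`, `recall`, and `f1`. The following is only a sketch that satisfies those assertions; the name `f1_sketch` is hypothetical, and the real `compute_f1` in `dataqa_env.server.environment` may be written differently.

```python
from typing import Dict, Set


def f1_sketch(reported: Set[str], planted: Set[str]) -> Dict[str, float]:
    """Set-based precision/recall/F1 over issue keys (illustrative only)."""
    # Nothing planted and nothing reported counts as a perfect answer,
    # matching test_no_reported_no_planted above.
    if not reported and not planted:
        return {"tp": 0, "fp": 0, "fn": 0,
                "precision": 1.0, "recall": 1.0, "f1": 1.0}

    tp = len(reported & planted)   # correctly reported issues
    fp = len(reported - planted)   # reported but not planted
    fn = len(planted - reported)   # planted but missed

    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return {"tp": tp, "fp": fp, "fn": fn,
            "precision": precision, "recall": recall, "f1": f1}


if __name__ == "__main__":
    print(f1_sketch({"a", "b", "c"}, {"a", "b", "d"}))
```

Guarding each division separately keeps the all-false-positive and nothing-reported cases at 0.0 instead of raising `ZeroDivisionError`.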
tests/test_extensibility.py ADDED
@@ -0,0 +1,215 @@
1
+ """Tests for the extensibility API — custom tasks and contamination rules."""
2
+
3
+ import pytest
4
+ from dataqa_env.server.tasks import (
5
+ PlantedIssue,
6
+ create_task_from_config,
7
+ register_task,
8
+ register_contamination_rule,
9
+ CONTAMINATION_RULES,
10
+ get_task,
11
+ list_tasks,
12
+ )
13
+ from dataqa_env.server.environment import DataQAEnvironment, compute_weighted_reward
14
+ from dataqa_env.models import DataQAAction
15
+
16
+
17
+ SIMPLE_CSV = "id,name,score\n1,Alice,95\n2,Bob,87\n3,Carol,92\n4,Dave,78"
18
+
19
+
20
+ class TestCreateTaskFromConfig:
21
+ def test_basic_creation(self):
22
+ task = create_task_from_config(
23
+ task_id="test_custom",
24
+ name="Test Task",
25
+ description="Test",
26
+ schema_description="id: int, name: str, score: int",
27
+ validation_rules="No missing values",
28
+ clean_csv=SIMPLE_CSV,
29
+ contaminations=[
30
+ {"rule": "missing_value", "row": 0, "col": 1},
31
+ ],
32
+ )
33
+ assert task.task_id == "test_custom"
34
+ assert len(task.planted_issues) == 1
35
+ assert task.planted_issues[0].issue_type == "missing_value"
36
+ assert task.planted_issues[0].col == "name"
37
+
38
+ def test_multiple_contaminations(self):
39
+ task = create_task_from_config(
40
+ task_id="multi",
41
+ name="Multi",
42
+ description="Test",
43
+ schema_description="",
44
+ validation_rules="",
45
+ clean_csv=SIMPLE_CSV,
46
+ contaminations=[
47
+ {"rule": "missing_value", "row": 0, "col": 1},
48
+ {"rule": "missing_value", "row": 2, "col": 1},
49
+ ],
50
+ )
51
+ assert len(task.planted_issues) == 2
52
+
53
+ def test_custom_difficulty_override(self):
54
+ task = create_task_from_config(
55
+ task_id="custom_diff",
56
+ name="Custom Difficulty",
57
+ description="Test",
58
+ schema_description="",
59
+ validation_rules="",
60
+ clean_csv=SIMPLE_CSV,
61
+ contaminations=[
62
+ {"rule": "missing_value", "row": 0, "col": 1, "difficulty": 2.5},
63
+ ],
64
+ )
65
+ assert task.planted_issues[0].difficulty == 2.5
66
+
67
+ def test_callable_rule(self):
68
+ def custom_rule(rows, header, col_idx, row_idx, rng):
69
+ return "CORRUPTED", PlantedIssue(
70
+ row=row_idx + 1, col=header[col_idx], issue_type="wrong_type",
71
+ description="Custom corruption", difficulty=1.5,
72
+ )
73
+
74
+ task = create_task_from_config(
75
+ task_id="callable",
76
+ name="Callable Rule",
77
+ description="Test",
78
+ schema_description="",
79
+ validation_rules="",
80
+ clean_csv=SIMPLE_CSV,
81
+ contaminations=[
82
+ {"rule": custom_rule, "row": 1, "col": 2},
83
+ ],
84
+ )
85
+ assert task.planted_issues[0].issue_type == "wrong_type"
86
+ assert "CORRUPTED" in task.corrupted_csv
87
+
88
+ def test_unknown_rule_raises(self):
89
+ with pytest.raises(ValueError, match="Unknown contamination rule"):
90
+ create_task_from_config(
91
+ task_id="bad",
92
+ name="Bad",
93
+ description="",
94
+ schema_description="",
95
+ validation_rules="",
96
+ clean_csv=SIMPLE_CSV,
97
+ contaminations=[{"rule": "nonexistent_rule", "row": 0, "col": 0}],
98
+ )
99
+
100
+
101
+ class TestRegisterContaminationRule:
102
+ def test_register_and_use(self):
103
+ def reverse_value(rows, header, col_idx, row_idx, rng):
104
+ val = rows[row_idx][col_idx]
105
+ return val[::-1], PlantedIssue(
106
+ row=row_idx + 1, col=header[col_idx], issue_type="format_violation",
107
+ description="Reversed value", difficulty=1.5,
108
+ )
109
+
110
+ register_contamination_rule("reverse", reverse_value)
111
+ assert "reverse" in CONTAMINATION_RULES
112
+
113
+ task = create_task_from_config(
114
+ task_id="rev_test",
115
+ name="Reverse Test",
116
+ description="",
117
+ schema_description="",
118
+ validation_rules="",
119
+ clean_csv=SIMPLE_CSV,
120
+ contaminations=[{"rule": "reverse", "row": 0, "col": 1}],
121
+ )
122
+ assert task.planted_issues[0].issue_type == "format_violation"
123
+ # "Alice" reversed is "ecilA"
124
+ assert "ecilA" in task.corrupted_csv
125
+
126
+ # Cleanup
127
+ del CONTAMINATION_RULES["reverse"]
128
+
129
+
130
+ class TestRegisterTask:
131
+ def test_register_and_get(self):
132
+ task = create_task_from_config(
133
+ task_id="registered",
134
+ name="Registered Task",
135
+ description="Test registered task",
136
+ schema_description="id: int, name: str",
137
+ validation_rules="No missing values",
138
+ clean_csv=SIMPLE_CSV,
139
+ contaminations=[{"rule": "missing_value", "row": 1, "col": 1}],
140
+ )
141
+ register_task("registered", lambda seed: task)
142
+ assert "registered" in list_tasks()
143
+
144
+ fetched = get_task("registered")
145
+ assert fetched.task_id == "registered"
146
+ assert len(fetched.planted_issues) == 1
147
+
148
+ # Cleanup
149
+ from dataqa_env.server.tasks import TASK_REGISTRY
150
+ del TASK_REGISTRY["registered"]
151
+
152
+
153
+ class TestCustomTaskInEnvironment:
154
+ def test_full_lifecycle_identify_only(self):
155
+ """Custom task works end-to-end with identify-only."""
156
+ task = create_task_from_config(
157
+ task_id="e2e_custom",
158
+ name="E2E Custom",
159
+ description="End-to-end test",
160
+ schema_description="id: int, name: str, score: int",
161
+ validation_rules="No missing values",
162
+ clean_csv=SIMPLE_CSV,
163
+ contaminations=[
164
+ {"rule": "missing_value", "row": 0, "col": 1, "difficulty": 1.0},
165
+ {"rule": "whitespace_value", "row": 2, "col": 1, "difficulty": 2.5},
166
+ ],
167
+ )
168
+ register_task("e2e_custom", lambda seed: task)
169
+
170
+ env = DataQAEnvironment()
171
+ obs = env.reset(task_id="e2e_custom")
172
+ assert obs.num_issues_hint == 2
173
+
174
+ action = DataQAAction(
175
+ issues=[i.to_key() for i in task.planted_issues],
176
+ task_id="e2e_custom",
177
+ )
178
+ obs = env.step(action)
179
+ assert obs.done is True
180
+ assert obs.reward >= 0.999
181
+
182
+ from dataqa_env.server.tasks import TASK_REGISTRY
183
+ del TASK_REGISTRY["e2e_custom"]
184
+
185
+ def test_full_lifecycle_identify_and_fix(self):
186
+ """Custom task works end-to-end with both identify and fix."""
187
+ task = create_task_from_config(
188
+ task_id="e2e_fix",
189
+ name="E2E Fix",
190
+ description="End-to-end test with fixes",
191
+ schema_description="id: int, name: str, score: int",
192
+ validation_rules="No missing values",
193
+ clean_csv=SIMPLE_CSV,
194
+ contaminations=[
195
+ {"rule": "missing_value", "row": 0, "col": 1, "difficulty": 1.0},
196
+ ],
197
+ )
198
+ register_task("e2e_fix", lambda seed: task)
199
+
200
+ env = DataQAEnvironment()
201
+ env.reset(task_id="e2e_fix")
202
+
203
+ # Submit issues + fix
204
+ action = DataQAAction(
205
+ issues=[task.planted_issues[0].to_key()],
206
+ fixes=["row:1,col:name,fix:Alice"], # clean value is "Alice"
207
+ task_id="e2e_fix",
208
+ )
209
+ obs = env.step(action)
210
+ assert obs.done is True
211
+ assert obs.metadata["fix_score"] > 0.0
212
+ assert obs.metadata["combined_reward"] > 0.0
213
+
214
+ from dataqa_env.server.tasks import TASK_REGISTRY
215
+ del TASK_REGISTRY["e2e_fix"]
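The tests above call contamination rules with the arguments `(rows, header, col_idx, row_idx, rng)` and expect a `(new_value, issue)` pair back, with the issue row reported 1-based. A self-contained sketch of such a rule follows; `IssueStub` is a hypothetical stand-in for `PlantedIssue` so the snippet runs without the package, and `uppercase_rule` is an invented example rule, not one shipped in the commit.

```python
import random
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class IssueStub:
    """Hypothetical stand-in carrying the PlantedIssue fields used in the tests."""
    row: int
    col: str
    issue_type: str
    description: str
    difficulty: float


def uppercase_rule(
    rows: List[List[str]],
    header: List[str],
    col_idx: int,
    row_idx: int,
    rng: random.Random,
) -> Tuple[str, IssueStub]:
    """Corrupt one cell by upper-casing it; rng is accepted but unused here."""
    value = rows[row_idx][col_idx]
    issue = IssueStub(
        row=row_idx + 1,              # tests above report rows 1-based
        col=header[col_idx],
        issue_type="format_violation",
        description="Upper-cased value",
        difficulty=1.5,
    )
    return value.upper(), issue


header = ["id", "name", "score"]
rows = [["1", "Alice", "95"], ["2", "Bob", "87"]]
new_value, issue = uppercase_rule(rows, header, 1, 0, random.Random(0))
```

In the package itself, a rule of this shape would presumably be passed to `register_contamination_rule` or given directly as the `"rule"` value in a `contaminations` entry, as `test_callable_rule` does.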
tests/test_inference.py ADDED
@@ -0,0 +1,191 @@
1
+ """Tests for the inference script's parsing, prompt building, and log format."""
2
+
3
+ import pytest
4
+ import sys
5
+ import os
6
+
7
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
8
+ from inference import parse_llm_response, parse_fix_response, build_user_prompt, log_start, log_step, log_end
9
+
10
+
11
+ class TestParseLLMResponse:
12
+ def test_standard_format(self):
13
+ response = "row:1,col:name,issue:missing_value\nrow:2,col:salary,issue:wrong_type"
14
+ issues = parse_llm_response(response)
15
+ assert len(issues) == 2
16
+ assert "row:1,col:name,issue:missing_value" in issues
17
+
18
+ def test_numbered_list(self):
19
+ response = "1. row:1,col:name,issue:missing_value\n2. row:2,col:salary,issue:wrong_type"
20
+ issues = parse_llm_response(response)
21
+ assert len(issues) == 2
22
+
23
+ def test_bullet_list(self):
24
+ response = "- row:1,col:name,issue:missing_value\n* row:2,col:salary,issue:wrong_type"
25
+ issues = parse_llm_response(response)
26
+ assert len(issues) == 2
27
+
28
+ def test_equals_delimiter(self):
29
+ response = "row=1,col=name,issue=missing_value"
30
+ issues = parse_llm_response(response)
31
+ assert len(issues) == 1
32
+ assert issues[0] == "row:1,col:name,issue:missing_value"
33
+
34
+ def test_mixed_case(self):
35
+ response = "Row:1,Col:Name,Issue:Missing_Value"
36
+ issues = parse_llm_response(response)
37
+ assert len(issues) == 1
38
+ assert issues[0] == "row:1,col:name,issue:missing_value"
39
+
40
+ def test_empty_response(self):
41
+ assert parse_llm_response("") == []
42
+ assert parse_llm_response(" ") == []
43
+
44
+ def test_garbage_lines_skipped(self):
45
+ response = "Here are the issues:\nrow:1,col:name,issue:missing_value\nNo more issues."
46
+ issues = parse_llm_response(response)
47
+ assert len(issues) == 1
48
+
49
+ def test_deduplication_not_applied(self):
50
+ response = "row:1,col:name,issue:missing_value\nrow:1,col:name,issue:missing_value"
51
+ issues = parse_llm_response(response)
52
+ assert len(issues) == 2
53
+
54
+ def test_with_column_variant(self):
55
+ response = "row:1,column:name,issue:missing_value"
56
+ issues = parse_llm_response(response)
57
+ assert len(issues) == 1
58
+
59
+
+ class TestParseFixResponse:
+     def test_standard_format(self):
+         response = "row:4,col:name,fix:David Kim\nrow:7,col:salary,fix:75000"
+         fixes = parse_fix_response(response)
+         assert len(fixes) == 2
+         assert "row:4,col:name,fix:David Kim" in fixes
+
+     def test_numbered_list(self):
+         response = "1. row:4,col:name,fix:David Kim\n2. row:7,col:salary,fix:75000"
+         fixes = parse_fix_response(response)
+         assert len(fixes) == 2
+
+     def test_with_special_chars(self):
+         response = "row:1,col:email,fix:alice.chen@company.com"
+         fixes = parse_fix_response(response)
+         assert len(fixes) == 1
+         assert "alice.chen@company.com" in fixes[0]
+
+     def test_empty_response(self):
+         assert parse_fix_response("") == []
+
+     def test_date_fix(self):
+         response = "row:12,col:order_date,fix:2024-01-26"
+         fixes = parse_fix_response(response)
+         assert len(fixes) == 1
+
+     def test_ignores_issue_lines(self):
+         response = "row:4,col:name,issue:missing_value\nrow:4,col:name,fix:David Kim"
+         fixes = parse_fix_response(response)
+         assert len(fixes) == 1  # only the fix line
+
+
+ class TestBuildUserPrompt:
+     def test_includes_all_fields(self):
+         obs = {
+             "task_description": "Find issues",
+             "schema_description": "col: int",
+             "validation_rules": "no nulls",
+             "dataset_csv": "a,b\n1,2",
+             "num_issues_hint": 3,
+             "feedback": "",
+         }
+         prompt = build_user_prompt(obs)
+         assert "Find issues" in prompt
+         assert "col: int" in prompt
+         assert "no nulls" in prompt
+         assert "a,b" in prompt
+         assert "3 issues" in prompt
+
+     def test_includes_feedback_on_retry(self):
+         obs = {
+             "task_description": "Find issues",
+             "schema_description": "",
+             "validation_rules": "",
+             "dataset_csv": "a\n1",
+             "num_issues_hint": 0,
+             "feedback": "Step 1/3: You missed 2 issues",
+         }
+         prompt = build_user_prompt(obs)
+         assert "FEEDBACK" in prompt
+         assert "missed 2" in prompt
+
+     def test_excludes_reset_feedback(self):
+         obs = {
+             "task_description": "",
+             "schema_description": "",
+             "validation_rules": "",
+             "dataset_csv": "",
+             "num_issues_hint": 0,
+             "feedback": "Environment reset. Start inspecting.",
+         }
+         prompt = build_user_prompt(obs)
+         assert "FEEDBACK" not in prompt
+
+     def test_include_fixes_flag(self):
+         obs = {
+             "task_description": "Find issues",
+             "schema_description": "",
+             "validation_rules": "",
+             "dataset_csv": "a\n1",
+             "num_issues_hint": 0,
+             "feedback": "",
+         }
+         prompt = build_user_prompt(obs, include_fixes=True)
+         assert "fix" in prompt.lower()
+
+
+ class TestLogFormat:
+     """Verify stdout log format matches hackathon evaluation requirements."""
+
+     def test_log_start_format(self, capsys):
+         log_start(task="easy", env="dataqa_env", model="test-model")
+         out = capsys.readouterr().out.strip()
+         assert out == "[START] task=easy env=dataqa_env model=test-model"
+
+     def test_log_step_format(self, capsys):
+         log_step(step=1, action="row:1,col:name,issue:missing_value", reward=0.50, done=False, error=None)
+         out = capsys.readouterr().out.strip()
+         assert out == "[STEP] step=1 action=row:1,col:name,issue:missing_value reward=0.50 done=false error=null"
+
+     def test_log_step_with_error(self, capsys):
+         log_step(step=2, action="none", reward=0.00, done=True, error="timeout")
+         out = capsys.readouterr().out.strip()
+         assert "error=timeout" in out
+         assert "done=true" in out
+
+     def test_log_end_format(self, capsys):
+         log_end(success=True, steps=3, score=0.85, rewards=[0.25, 0.50, 0.85])
+         out = capsys.readouterr().out.strip()
+         assert out == "[END] success=true steps=3 score=0.850 rewards=0.25,0.50,0.85"
+
+     def test_log_end_failure(self, capsys):
+         log_end(success=False, steps=1, score=0.0, rewards=[0.0])
+         out = capsys.readouterr().out.strip()
+         assert "success=false" in out
+         assert "score=0.000" in out
+
+     def test_reward_format_2_decimal(self, capsys):
+         log_step(step=1, action="test", reward=0.123456, done=False, error=None)
+         out = capsys.readouterr().out.strip()
+         assert "reward=0.12" in out
+
+     def test_no_newlines_within_line(self, capsys):
+         log_start(task="easy", env="dataqa_env", model="model")
+         log_step(step=1, action="act", reward=0.0, done=False, error=None)
+         log_end(success=False, steps=1, score=0.0, rewards=[0.0])
+         out = capsys.readouterr().out
+         lines = [l for l in out.split("\n") if l.strip()]
+         assert len(lines) == 3
+         assert lines[0].startswith("[START]")
+         assert lines[1].startswith("[STEP]")
+         assert lines[2].startswith("[END]")
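The `TestLogFormat` assertions above pin down the exact stdout lines the helpers must emit. The following is a minimal sketch of helpers that would satisfy those assertions. It is not the commit's actual implementation: the bodies and the `_fmt_bool` helper are assumptions inferred only from the expected strings in the tests.

```python
from typing import List, Optional


def _fmt_bool(value: bool) -> str:
    # Hypothetical helper: the tests expect lowercase "true"/"false".
    return "true" if value else "false"


def log_start(task: str, env: str, model: str) -> None:
    # One line per event, space-separated key=value pairs.
    print(f"[START] task={task} env={env} model={model}", flush=True)


def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    # A missing error is rendered as the literal "null"; reward uses 2 decimals.
    error_text = "null" if error is None else error
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={_fmt_bool(done)} error={error_text}",
        flush=True,
    )


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    # Score uses 3 decimals; per-step rewards are comma-joined with 2 decimals.
    rewards_text = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={_fmt_bool(success)} steps={steps} "
        f"score={score:.3f} rewards={rewards_text}",
        flush=True,
    )
```

Keeping each event on a single line with fixed decimal widths is what lets an evaluator parse the log with simple string matching, which appears to be the intent of `test_no_newlines_within_line`.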
tests/test_tasks.py ADDED
@@ -0,0 +1,162 @@
+ """Tests for task definitions, data corruption, and issue planting."""
+
+ import pytest
+ from dataqa_env.server.tasks import (
+     PlantedIssue,
+     Task,
+     create_task_easy,
+     create_task_medium,
+     create_task_hard,
+     get_task,
+     list_tasks,
+     _csv_to_rows,
+     _rows_to_csv,
+ )
+
+
+ class TestPlantedIssue:
+     def test_to_key(self):
+         issue = PlantedIssue(row=3, col="salary", issue_type="missing_value", description="test")
+         assert issue.to_key() == "row:3,col:salary,issue:missing_value"
+
+     def test_difficulty_default(self):
+         issue = PlantedIssue(row=1, col="name", issue_type="missing_value", description="test")
+         assert issue.difficulty == 1.0
+
+     def test_difficulty_custom(self):
+         issue = PlantedIssue(row=1, col="name", issue_type="missing_value", description="test", difficulty=3.0)
+         assert issue.difficulty == 3.0
+
+
+ class TestCSVHelpers:
+     def test_roundtrip(self):
+         csv_text = "a,b,c\n1,2,3\n4,5,6"
+         rows = _csv_to_rows(csv_text)
+         assert len(rows) == 3
+         result = _rows_to_csv(rows)
+         assert "1,2,3" in result
+
+     def test_empty_csv(self):
+         rows = _csv_to_rows("a,b\n")
+         assert len(rows) == 1  # header only
+
+
+ class TestTaskEasy:
+     @pytest.fixture
+     def task(self):
+         return create_task_easy()
+
+     def test_task_id(self, task):
+         assert task.task_id == "easy"
+
+     def test_has_6_issues(self, task):
+         assert len(task.planted_issues) == 6
+
+     def test_issue_types(self, task):
+         types = {i.issue_type for i in task.planted_issues}
+         assert "missing_value" in types
+         assert "wrong_type" in types
+         assert "duplicate_row" in types
+         assert "out_of_range" in types
+         assert "inconsistent_value" in types
+
+     def test_corrupted_csv_differs_from_clean(self, task):
+         assert task.corrupted_csv != task.clean_csv
+
+     def test_issue_keys_unique(self, task):
+         keys = [i.to_key() for i in task.planted_issues]
+         assert len(keys) == len(set(keys))
+
+     def test_max_steps(self, task):
+         assert task.max_steps == 3
+
+     def test_corrupted_csv_has_more_rows(self, task):
+         clean_rows = _csv_to_rows(task.clean_csv)
+         corrupt_rows = _csv_to_rows(task.corrupted_csv)
+         assert len(corrupt_rows) > len(clean_rows)  # duplicate row added
+
+     def test_difficulty_weights(self, task):
+         for issue in task.planted_issues:
+             assert 1.0 <= issue.difficulty <= 3.0
+
+
+ class TestTaskMedium:
+     @pytest.fixture
+     def task(self):
+         return create_task_medium()
+
+     def test_task_id(self, task):
+         assert task.task_id == "medium"
+
+     def test_has_8_issues(self, task):
+         assert len(task.planted_issues) == 8
+
+     def test_issue_types(self, task):
+         types = {i.issue_type for i in task.planted_issues}
+         assert "inconsistent_value" in types
+         assert "format_violation" in types
+         assert "missing_value" in types
+
+     def test_issue_keys_unique(self, task):
+         keys = [i.to_key() for i in task.planted_issues]
+         assert len(keys) == len(set(keys))
+
+     def test_difficulty_weights(self, task):
+         for issue in task.planted_issues:
+             assert 1.0 <= issue.difficulty <= 3.0
+
+
+ class TestTaskHard:
+     @pytest.fixture
+     def task(self):
+         return create_task_hard()
+
+     def test_task_id(self, task):
+         assert task.task_id == "hard"
+
+     def test_has_10_issues(self, task):
+         assert len(task.planted_issues) == 10
+
+     def test_issue_types(self, task):
+         types = {i.issue_type for i in task.planted_issues}
+         assert "inconsistent_value" in types
+         assert "format_violation" in types
+         assert "statistical_outlier" in types
+         assert "out_of_range" in types
+         assert "missing_value" in types
+
+     def test_has_high_difficulty_issues(self, task):
+         hard_issues = [i for i in task.planted_issues if i.difficulty >= 2.5]
+         assert len(hard_issues) >= 2  # data leakage, GPU outlier, whitespace
+
+     def test_issue_keys_unique(self, task):
+         keys = [i.to_key() for i in task.planted_issues]
+         assert len(keys) == len(set(keys))
+
+
+ class TestTaskRegistry:
+     def test_list_tasks(self):
+         tasks = list_tasks()
+         assert set(tasks) == {"easy", "medium", "hard"}
+
+     def test_get_task_easy(self):
+         task = get_task("easy")
+         assert task.task_id == "easy"
+
+     def test_get_task_medium(self):
+         task = get_task("medium")
+         assert task.task_id == "medium"
+
+     def test_get_task_hard(self):
+         task = get_task("hard")
+         assert task.task_id == "hard"
+
+     def test_get_task_unknown_raises(self):
+         with pytest.raises(ValueError, match="Unknown task"):
+             get_task("nonexistent")
+
+     def test_seed_determinism(self):
+         t1 = get_task("easy", seed=42)
+         t2 = get_task("easy", seed=42)
+         assert t1.corrupted_csv == t2.corrupted_csv
+         assert [i.to_key() for i in t1.planted_issues] == [i.to_key() for i in t2.planted_issues]
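For readers without `dataqa_env/server/tasks.py` at hand, the contract that `TestPlantedIssue` checks can be illustrated with a small stand-in. This is only a sketch, assuming a dataclass layout; the real `PlantedIssue` may carry extra fields or be defined differently. The field names, the `1.0` default for `difficulty`, and the key format are taken from the assertions in the tests.

```python
from dataclasses import dataclass


@dataclass
class PlantedIssue:
    # Stand-in for dataqa_env.server.tasks.PlantedIssue (assumed shape).
    row: int
    col: str
    issue_type: str
    description: str
    difficulty: float = 1.0  # tests expect 1.0 when not specified

    def to_key(self) -> str:
        # Canonical key format asserted in test_to_key.
        return f"row:{self.row},col:{self.col},issue:{self.issue_type}"


issue = PlantedIssue(row=3, col="salary", issue_type="missing_value", description="test")
print(issue.to_key())  # row:3,col:salary,issue:missing_value
```

The key has the same `row:…,col:…,issue:…` shape as the strings the parser tests expect from `parse_llm_response`, so reported issues can presumably be compared against planted ones by string equality.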