Spaces:
Sleeping
Sleeping
Commit ·
ce59113
0
Parent(s):
Initial commit
Browse files- .claude/settings.local.json +13 -0
- .dockerignore +8 -0
- .gitignore +33 -0
- Dockerfile +24 -0
- README.md +569 -0
- inference.py +394 -0
- openenv.yaml +61 -0
- sqlsherlock_env/__init__.py +19 -0
- sqlsherlock_env/client.py +186 -0
- sqlsherlock_env/models.py +171 -0
- sqlsherlock_env/pyproject.toml +46 -0
- sqlsherlock_env/server/__init__.py +11 -0
- sqlsherlock_env/server/app.py +199 -0
- sqlsherlock_env/server/database.py +563 -0
- sqlsherlock_env/server/dataset_loader.py +467 -0
- sqlsherlock_env/server/environment.py +408 -0
- sqlsherlock_env/server/exporter.py +160 -0
- sqlsherlock_env/server/graders/__init__.py +73 -0
- sqlsherlock_env/server/graders/task1.py +75 -0
- sqlsherlock_env/server/graders/task2.py +93 -0
- sqlsherlock_env/server/graders/task3.py +94 -0
- sqlsherlock_env/server/graders/universal.py +442 -0
- sqlsherlock_env/server/issue_detector.py +920 -0
- sqlsherlock_env/server/requirements.txt +9 -0
- sqlsherlock_env/server/reward.py +411 -0
- sqlsherlock_env/server/schema_profiler.py +255 -0
- sqlsherlock_env/server/sqlsherlock_env_environment.py +155 -0
- sqlsherlock_env/server/validator.py +545 -0
- tests/__init__.py +0 -0
- tests/conftest.py +198 -0
- tests/test_environment.py +447 -0
- tests/test_graders.py +354 -0
- tests/test_issue_detector.py +341 -0
- train.py +334 -0
.claude/settings.local.json
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"permissions": {
|
| 3 |
+
"allow": [
|
| 4 |
+
"Bash(PYTHONPATH=sqlsherlock_env .venv/Scripts/python -c \"from openenv.core.env_server import Environment; import inspect; print\\([m for m in dir\\(Environment\\) if not m.startswith\\('__'\\)]\\); print\\(inspect.getmembers\\(Environment, predicate=inspect.isfunction\\)\\)\")",
|
| 5 |
+
"Bash(PYTHONPATH=sqlsherlock_env .venv/Scripts/python -c \"from openenv.core.env_server import Environment; import inspect; src = inspect.getsource\\(Environment.state\\); print\\(src\\)\")",
|
| 6 |
+
"Bash(PYTHONPATH=sqlsherlock_env .venv/Scripts/python -m pytest tests/ -v)",
|
| 7 |
+
"Bash(PYTHONPATH=sqlsherlock_env .venv/Scripts/python -m pytest tests/test_issue_detector.py::TestDuplicateDetection tests/test_issue_detector.py::TestDetectTrap tests/test_graders.py::TestTask1Grader -v)",
|
| 8 |
+
"Bash(PYTHONPATH=\"c:/Users/HP/OneDrive/Desktop/SQLSherlock-env/sqlsherlock_env\" \"c:/Users/HP/OneDrive/Desktop/SQLSherlock-env/.venv/Scripts/uvicorn\" server.app:app --host 0.0.0.0 --port 7860)",
|
| 9 |
+
"Bash(.venv/Scripts/python -c ':*)",
|
| 10 |
+
"Bash(PYTHONPATH=sqlsherlock_env .venv/Scripts/python -m pytest tests/ -v --tb=short)"
|
| 11 |
+
]
|
| 12 |
+
}
|
| 13 |
+
}
|
.dockerignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv/
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
.pytest_cache/
|
| 5 |
+
.git/
|
| 6 |
+
*.egg-info/
|
| 7 |
+
grpo_output/
|
| 8 |
+
.env
|
.gitignore
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Virtual environments
|
| 2 |
+
.venv/
|
| 3 |
+
venv/
|
| 4 |
+
env/
|
| 5 |
+
|
| 6 |
+
# Python
|
| 7 |
+
__pycache__/
|
| 8 |
+
*.pyc
|
| 9 |
+
*.pyo
|
| 10 |
+
*.egg-info/
|
| 11 |
+
dist/
|
| 12 |
+
build/
|
| 13 |
+
|
| 14 |
+
# Training outputs
|
| 15 |
+
grpo_output/
|
| 16 |
+
|
| 17 |
+
# IDE
|
| 18 |
+
.vscode/
|
| 19 |
+
.idea/
|
| 20 |
+
|
| 21 |
+
# OS
|
| 22 |
+
.DS_Store
|
| 23 |
+
Thumbs.db
|
| 24 |
+
|
| 25 |
+
# Secrets
|
| 26 |
+
.env
|
| 27 |
+
*.key
|
| 28 |
+
|
| 29 |
+
# Pytest
|
| 30 |
+
.pytest_cache/
|
| 31 |
+
|
| 32 |
+
# UV lock (package-level, not needed at repo root)
|
| 33 |
+
uv.lock
|
Dockerfile
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM ghcr.io/meta-pytorch/openenv-base:latest
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Install Python dependencies first so this layer is cached
|
| 6 |
+
COPY sqlsherlock_env/server/requirements.txt ./requirements.txt
|
| 7 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 8 |
+
|
| 9 |
+
# Copy entire repo
|
| 10 |
+
COPY . .
|
| 11 |
+
|
| 12 |
+
EXPOSE 7860
|
| 13 |
+
|
| 14 |
+
# PYTHONPATH so "from models import ..." and "from server.xxx import ..." resolve correctly
|
| 15 |
+
ENV PYTHONPATH=/app/sqlsherlock_env
|
| 16 |
+
|
| 17 |
+
# Health check — must pass before HF Spaces routes traffic
|
| 18 |
+
HEALTHCHECK --interval=30s --timeout=10s --start-period=15s \
|
| 19 |
+
--retries=3 CMD curl -f http://localhost:7860/health || exit 1
|
| 20 |
+
|
| 21 |
+
# Run from sqlsherlock_env/ so relative module paths match the import structure
|
| 22 |
+
WORKDIR /app/sqlsherlock_env
|
| 23 |
+
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", \
|
| 24 |
+
"--port", "7860", "--workers", "2"]
|
README.md
ADDED
|
@@ -0,0 +1,569 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: SQLSherlock Env
|
| 3 |
+
emoji: 🔍
|
| 4 |
+
colorFrom: indigo
|
| 5 |
+
colorTo: cyan
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
tags:
|
| 9 |
+
- openenv
|
| 10 |
+
- reinforcement-learning
|
| 11 |
+
- data-quality
|
| 12 |
+
pinned: false
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
# SQLSherlock-Env
|
| 16 |
+
|
| 17 |
+
An RL environment where an AI agent acts as a data scientist investigating a dirty dataset.
|
| 18 |
+
|
| 19 |
+
The agent discovers real data quality issues through statistical investigation — exactly like a human data scientist — fixes them with documented reasoning, validates fixes against the raw baseline, and exports the cleaned output in the same format as the input.
|
| 20 |
+
|
| 21 |
+
**The environment does NOT plant or inject issues.** Real datasets already have data quality problems. The issue detector scans the dataset at `reset()` time and builds a ground-truth catalogue from what it finds. The agent never sees this catalogue — it must discover everything through investigation.
|
| 22 |
+
|
| 23 |
+
---
|
| 24 |
+
|
| 25 |
+
## Architecture
|
| 26 |
+
|
| 27 |
+
### Episode Flow
|
| 28 |
+
|
| 29 |
+
```
|
| 30 |
+
reset(dataset, task_id)
|
| 31 |
+
│
|
| 32 |
+
▼
|
| 33 |
+
┌───────────────────────────────────────────────────────────────────┐
|
| 34 |
+
│ DatabaseEngine.__init__ │
|
| 35 |
+
│ │
|
| 36 |
+
│ 1. load(source) ← CSV / JSON / JSONL / Parquet / HF │
|
| 37 |
+
│ 2. records_to_sqlite() ← In-memory SQLite, isolated per episode│
|
| 38 |
+
│ 3. deep_copy(originals) ← Immutable snapshot before any edits │
|
| 39 |
+
│ 4. profile_table() ← mean/std/z-scores per column │
|
| 40 |
+
│ 5. detect_issues() ← null / type / constraint / outlier │
|
| 41 |
+
│ duplicate / fk_violation │
|
| 42 |
+
│ 6. Validator(baseline) ← 6-check baseline captured │
|
| 43 |
+
│ 7. detect_trap() ← Task 3 only: plant 2x value in DB │
|
| 44 |
+
└───────────────────────────────────────────────────────────────────┘
|
| 45 |
+
│
|
| 46 |
+
▼
|
| 47 |
+
SQLSherlockObservation returned to agent
|
| 48 |
+
│
|
| 49 |
+
▼
|
| 50 |
+
┌─────────────────────────────────────────────────────┐
|
| 51 |
+
│ Agent Step Loop │
|
| 52 |
+
│ │
|
| 53 |
+
│ ┌──────────────────────────────────────────────┐ │
|
| 54 |
+
│ │ Agent decides action (LLM call) │ │
|
| 55 |
+
│ │ │ │
|
| 56 |
+
│ │ investigate: inspect / profile / run_sql │ │
|
| 57 |
+
│ │ fix: fix_cell / delete_row │ │
|
| 58 |
+
│ │ check: validate │ │
|
| 59 |
+
│ │ end: submit / export │ │
|
| 60 |
+
│ └───────────────────┬──────────────────────────┘ │
|
| 61 |
+
│ │ │
|
| 62 |
+
│ ▼ │
|
| 63 |
+
│ ┌──────────────────────────────────────────────┐ │
|
| 64 |
+
│ │ Environment.step(action) │ │
|
| 65 |
+
│ │ │ │
|
| 66 |
+
│ │ 1. dispatch action → DatabaseEngine │ │
|
| 67 |
+
│ │ 2. reward.calc() → RB breakdown │ │
|
| 68 |
+
│ │ 3. build observation (feedback + results) │ │
|
| 69 |
+
│ │ 4. return (obs, reward, done, info) │ │
|
| 70 |
+
│ └──────────────────────────────────────────────┘ │
|
| 71 |
+
│ │
|
| 72 |
+
│ Repeat until submit/export or budget exhausted │
|
| 73 |
+
└─────────────────────────────────────────────────────┘
|
| 74 |
+
│
|
| 75 |
+
▼
|
| 76 |
+
Grader.score() → final score [0.0 – 1.0]
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
### Component Diagram
|
| 80 |
+
|
| 81 |
+
```
|
| 82 |
+
inference.py / train.py / custom agent
|
| 83 |
+
│ HTTP + WebSocket
|
| 84 |
+
▼
|
| 85 |
+
┌─────────────────────────────────────────────────────────────┐
|
| 86 |
+
│ FastAPI App (server/app.py) │
|
| 87 |
+
│ POST /reset POST /step GET /state GET /health │
|
| 88 |
+
│ WS /ws │
|
| 89 |
+
└──────────────────────┬──────────────────────────────────────┘
|
| 90 |
+
│
|
| 91 |
+
▼
|
| 92 |
+
┌─────────────────────────────────────────────────────────────┐
|
| 93 |
+
│ SQLSherlockEnvironment (server/environment.py) │
|
| 94 |
+
│ │
|
| 95 |
+
│ reset() ─────────────────────────────────────────────► │
|
| 96 |
+
│ DatabaseEngine │
|
| 97 |
+
│ step(action) ─────► dispatch ──────────────────────► │
|
| 98 |
+
│ │ │
|
| 99 |
+
│ │ │
|
| 100 |
+
│ ┌────▼────┐ │
|
| 101 |
+
│ │ reward │ │
|
| 102 |
+
│ │ .calc()│ │
|
| 103 |
+
│ └─────────┘ │
|
| 104 |
+
│ │
|
| 105 |
+
│ on submit/export ─────► Grader.score() │
|
| 106 |
+
└─────────────────────────────────────────────────────────────┘
|
| 107 |
+
│
|
| 108 |
+
┌──────────────┼──────────────────────┐
|
| 109 |
+
▼ ▼ ▼
|
| 110 |
+
┌─────────────┐ ┌─────────────────┐ ┌──────────────────┐
|
| 111 |
+
│ Database │ │ IssueDetector │ │ Validator │
|
| 112 |
+
│ Engine │ │ │ │ │
|
| 113 |
+
│ │ │ detect_issues()│ │ 6-check before/ │
|
| 114 |
+
│ SQLite │ │ detect_trap() │ │ after comparison │
|
| 115 |
+
│ in-memory │ │ │ │ │
|
| 116 |
+
│ per episode│ │ null │ │ null_check │
|
| 117 |
+
│ │ │ type_error │ │ type_check │
|
| 118 |
+
│ profile_ │ │ constraint │ │ range_check │
|
| 119 |
+
│ table() │ │ outlier │ │ distribution_ │
|
| 120 |
+
│ │ │ duplicate │ │ check │
|
| 121 |
+
│ z_scores │ │ fk_violation │ │ duplicate_check │
|
| 122 |
+
│ per row │ │ │ │ outlier_check │
|
| 123 |
+
└─────────────┘ └─────────────────┘ └──────────────────┘
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
### Grading Pipeline (7 steps)
|
| 127 |
+
|
| 128 |
+
```
|
| 129 |
+
submit / export triggered
|
| 130 |
+
│
|
| 131 |
+
▼
|
| 132 |
+
┌─────────────────────────────────────────────────────────────┐
|
| 133 |
+
│ universal.py — 7-step grader │
|
| 134 |
+
│ │
|
| 135 |
+
│ Step 1: Zero-change guard │
|
| 136 |
+
│ └── if nothing changed → score = 0.0 │
|
| 137 |
+
│ │
|
| 138 |
+
│ Step 2: Resolution score (0.0 – 1.0) │
|
| 139 |
+
│ └── per issue: confidence-weighted correct/total │
|
| 140 |
+
│ null: confidence 0.20 – 1.0 (structural=0.20) │
|
| 141 |
+
│ type_error: always 1.0 │
|
| 142 |
+
│ constraint / outlier: 0.80 │
|
| 143 |
+
│ duplicate: 0.70 │
|
| 144 |
+
│ │
|
| 145 |
+
│ Step 3: False-positive penalty │
|
| 146 |
+
│ └── −0.15 per clean cell touched │
|
| 147 |
+
│ │
|
| 148 |
+
│ Step 4: Trap penalty (Task 3 only) │
|
| 149 |
+
│ └── −0.40 if trap cell was modified │
|
| 150 |
+
│ │
|
| 151 |
+
│ Step 5: Validation score (0.0 – 0.30) │
|
| 152 |
+
│ └── checks_passed / total_checks × 0.30 │
|
| 153 |
+
│ │
|
| 154 |
+
│ Step 6: Reasoning bonus (0.0 – 0.10) │
|
| 155 |
+
│ └── +0.02 per fix_cell/delete_row with reason str │
|
| 156 |
+
│ │
|
| 157 |
+
│ Step 7: Final clamp │
|
| 158 |
+
│ raw = res×0.60 + val×0.30 + bonus×0.10 − fp − trap│
|
| 159 |
+
│ score = clamp(raw, 0.0, 1.0) │
|
| 160 |
+
└─────────────────────────────────────────────────────────────┘
|
| 161 |
+
```
|
| 162 |
+
|
| 163 |
+
---
|
| 164 |
+
|
| 165 |
+
## Quick Start
|
| 166 |
+
|
| 167 |
+
### 1. Docker (recommended)
|
| 168 |
+
|
| 169 |
+
```bash
|
| 170 |
+
# Build from repo root
|
| 171 |
+
docker build -t sqlsherlock-env:latest .
|
| 172 |
+
|
| 173 |
+
# Run
|
| 174 |
+
docker run -p 7860:7860 sqlsherlock-env:latest
|
| 175 |
+
|
| 176 |
+
# Verify
|
| 177 |
+
curl http://localhost:7860/health
|
| 178 |
+
```
|
| 179 |
+
|
| 180 |
+
### 2. Local (without Docker)
|
| 181 |
+
|
| 182 |
+
```bash
|
| 183 |
+
cd sqlsherlock_env
|
| 184 |
+
pip install -r server/requirements.txt
|
| 185 |
+
uvicorn server.app:app --host 0.0.0.0 --port 7860
|
| 186 |
+
```
|
| 187 |
+
|
| 188 |
+
### 3. Run baseline inference
|
| 189 |
+
|
| 190 |
+
```bash
|
| 191 |
+
export API_BASE_URL="https://router.huggingface.co/v1"
|
| 192 |
+
export MODEL_NAME="Qwen/Qwen2.5-72B-Instruct"
|
| 193 |
+
export HF_TOKEN="hf_..."
|
| 194 |
+
export SPACE_URL="http://localhost:7860"
|
| 195 |
+
|
| 196 |
+
python inference.py
|
| 197 |
+
```
|
| 198 |
+
|
| 199 |
+
Expected stdout (judges parse this exactly):
|
| 200 |
+
|
| 201 |
+
```
|
| 202 |
+
[START] task=task1_null_and_types env=sqlsherlock_env model=Qwen/Qwen2.5-72B-Instruct
|
| 203 |
+
[STEP] step=1 action=inspect reward=0.02 done=false error=null
|
| 204 |
+
[STEP] step=2 action=profile_column(age) reward=0.03 done=false error=null
|
| 205 |
+
...
|
| 206 |
+
[END] success=true steps=8 score=0.820 rewards=0.02,0.03,0.15,0.15,0.05,0.15,0.10
|
| 207 |
+
```
|
| 208 |
+
|
| 209 |
+
---
|
| 210 |
+
|
| 211 |
+
## Using Your Own Dataset
|
| 212 |
+
|
| 213 |
+
`inference.py` uses `phihung/titanic` for hackathon validation. To use your own dataset, connect the client directly:
|
| 214 |
+
|
| 215 |
+
### HuggingFace dataset
|
| 216 |
+
|
| 217 |
+
```python
|
| 218 |
+
from sqlsherlock_env.client import SQLSherlockEnv
|
| 219 |
+
|
| 220 |
+
env = SQLSherlockEnv(base_url="http://localhost:7860")
|
| 221 |
+
obs = env.reset(
|
| 222 |
+
dataset="your_org/your_dataset", # any public HF dataset
|
| 223 |
+
task_id="task1_null_and_types",
|
| 224 |
+
max_rows=500,
|
| 225 |
+
)
|
| 226 |
+
```
|
| 227 |
+
|
| 228 |
+
### Local file (CSV / JSON / JSONL / Parquet)
|
| 229 |
+
|
| 230 |
+
```python
|
| 231 |
+
obs = env.reset(
|
| 232 |
+
dataset="/absolute/path/to/data.csv",
|
| 233 |
+
task_id="task2_constraints_and_fk",
|
| 234 |
+
)
|
| 235 |
+
```
|
| 236 |
+
|
| 237 |
+
### Raw CSV string
|
| 238 |
+
|
| 239 |
+
```python
|
| 240 |
+
csv_text = "id,name,age,fare\n1,Alice,,25.0\n2,Bob,FORTY,50.0\n..."
|
| 241 |
+
obs = env.reset(
|
| 242 |
+
dataset=csv_text,
|
| 243 |
+
task_id="task1_null_and_types",
|
| 244 |
+
)
|
| 245 |
+
```
|
| 246 |
+
|
| 247 |
+
### Upload via API
|
| 248 |
+
|
| 249 |
+
```bash
|
| 250 |
+
curl -X POST http://localhost:7860/upload_dataset \
|
| 251 |
+
-F "file=@data.csv" \
|
| 252 |
+
-F "task_id=task1_null_and_types"
|
| 253 |
+
```
|
| 254 |
+
|
| 255 |
+
**What the environment does with your dataset:**
|
| 256 |
+
1. Loads the data (any format above)
|
| 257 |
+
2. Auto-detects column types (int / float / str / bool)
|
| 258 |
+
3. Scans for real data quality issues — no injection
|
| 259 |
+
4. Builds a ground-truth issue catalogue the agent never sees
|
| 260 |
+
5. Plants a trap value in Task 3
|
| 261 |
+
|
| 262 |
+
The agent then investigates, fixes, validates, and exports. The exported file matches the input format (CSV in → CSV out, Parquet in → Parquet out).
|
| 263 |
+
|
| 264 |
+
---
|
| 265 |
+
|
| 266 |
+
## Action Space
|
| 267 |
+
|
| 268 |
+
| `action_type` | Required fields | Description |
|
| 269 |
+
|---|---|---|
|
| 270 |
+
| `inspect` | `table` | View all rows |
|
| 271 |
+
| `profile_column` | `table`, `column` | Stats: mean/std/min/max/nulls/z-scores |
|
| 272 |
+
| `run_sql` | `sql` | SELECT query (read-only, max 50 rows) |
|
| 273 |
+
| `fix_cell` | `table`, `row_id`, `column`, `value`, `reason` | Fix one cell with justification |
|
| 274 |
+
| `fix_column` | `table`, `column`, `value`, `reason` | Fix ALL nulls in a column at once (bulk) |
|
| 275 |
+
| `delete_row` | `table`, `row_id`, `reason` | Remove duplicate or FK row |
|
| 276 |
+
| `validate` | — | Run all 6 before/after checks |
|
| 277 |
+
| `submit` | — | Score and end episode |
|
| 278 |
+
| `export` | — | Write cleaned file, score and end episode |
|
| 279 |
+
|
| 280 |
+
---
|
| 281 |
+
|
| 282 |
+
## Reward System
|
| 283 |
+
|
| 284 |
+
| Action | Reward | Cap |
|
| 285 |
+
|---|---|---|
|
| 286 |
+
| `inspect` | +0.02 | 3 rewarded |
|
| 287 |
+
| `profile_column` | +0.03 | 3 rewarded |
|
| 288 |
+
| `run_sql` | +0.03 | 3 rewarded |
|
| 289 |
+
| `validate` | +0.05 × (checks_passed / 6) | 2 rewarded |
|
| 290 |
+
| `fix_cell` — correct | **+0.15** | — |
|
| 291 |
+
| `fix_cell` — false positive | **−0.20** | — |
|
| 292 |
+
| `fix_cell` — trap cell | **−0.40** | — |
|
| 293 |
+
| `fix_cell` — wrong value | **−0.10** | — |
|
| 294 |
+
| `delete_row` — valid | **+0.15** | — |
|
| 295 |
+
| `delete_row` — false positive | **−0.20** | — |
|
| 296 |
+
| `submit` — all resolved | +0.10 | — |
|
| 297 |
+
| `submit` — issues remain | −0.10 | — |
|
| 298 |
+
|
| 299 |
+
---
|
| 300 |
+
|
| 301 |
+
## Three Tasks
|
| 302 |
+
|
| 303 |
+
### Task 1 — `task1_null_and_types` (Easy, max 20 steps)
|
| 304 |
+
|
| 305 |
+
Find and fix **null values** and **type errors**.
|
| 306 |
+
|
| 307 |
+
- Null: `None` or empty string in any non-PK column
|
| 308 |
+
- Type error: text in a numeric column (e.g. `"FORTY"` in age)
|
| 309 |
+
- Score: `resolution × 0.70 + validation × 0.30`
|
| 310 |
+
|
| 311 |
+
### Task 2 — `task2_constraints_and_fk` (Medium, max 25 steps)
|
| 312 |
+
|
| 313 |
+
Everything in Task 1 plus:
|
| 314 |
+
|
| 315 |
+
- **Constraint violations**: negative values in must-be-positive columns (age, fare, price)
|
| 316 |
+
- **FK violations**: orphan references in related tables
|
| 317 |
+
|
| 318 |
+
### Task 3 — `task3_full_audit_with_trap` (Hard, max 30 steps)
|
| 319 |
+
|
| 320 |
+
Full audit including:
|
| 321 |
+
|
| 322 |
+
- **Statistical outliers**: z-score > 5 in any numeric column
|
| 323 |
+
- **Duplicates**: natural key appearing more than once
|
| 324 |
+
|
| 325 |
+
**THE TRAP**: One numeric value is set to 2x original — looks suspicious but has `z < 3`. Touching it costs **−0.40**.
|
| 326 |
+
|
| 327 |
+
> Rule: Always `profile_column` before fixing any numeric value.
|
| 328 |
+
> `z > 5` → real outlier → fix it. `z < 3` → legitimate → leave it.
|
| 329 |
+
|
| 330 |
+
---
|
| 331 |
+
|
| 332 |
+
## Validation (6 Checks)
|
| 333 |
+
|
| 334 |
+
Run with `validate` action. Compares current state against the baseline from `reset()`:
|
| 335 |
+
|
| 336 |
+
| Check | Passes when |
|
| 337 |
+
|---|---|
|
| 338 |
+
| `null_check` | High-confidence nulls resolved |
|
| 339 |
+
| `type_check` | All type errors castable to float |
|
| 340 |
+
| `range_check` | No negatives in must-be-positive columns |
|
| 341 |
+
| `distribution_check` | Column mean drift < 20% |
|
| 342 |
+
| `duplicate_check` | Duplicate count reduced |
|
| 343 |
+
| `outlier_check` | No previously-flagged rows still exceed z > 5 |
|
| 344 |
+
|
| 345 |
+
Returns `PASS` / `PARTIAL` / `FAIL` with per-check detail and drift warnings.
|
| 346 |
+
|
| 347 |
+
---
|
| 348 |
+
|
| 349 |
+
## API Reference
|
| 350 |
+
|
| 351 |
+
| Method | Path | Description |
|
| 352 |
+
|---|---|---|
|
| 353 |
+
| `WS` | `/ws` | Persistent WebSocket session |
|
| 354 |
+
| `POST` | `/reset` | Reset environment, load dataset |
|
| 355 |
+
| `POST` | `/step` | Execute one action |
|
| 356 |
+
| `GET` | `/state` | Current episode state |
|
| 357 |
+
| `GET` | `/health` | Health check (`{"status":"ok"}`) |
|
| 358 |
+
| `GET` | `/tasks` | List all 3 tasks |
|
| 359 |
+
| `POST` | `/upload_dataset` | Upload file, get session |
|
| 360 |
+
| `GET` | `/download/{file_id}` | Download cleaned output |
|
| 361 |
+
| `GET` | `/docs` | OpenAPI docs (Swagger UI) |
|
| 362 |
+
|
| 363 |
+
---
|
| 364 |
+
|
| 365 |
+
## Testing
|
| 366 |
+
|
| 367 |
+
### Run all tests
|
| 368 |
+
|
| 369 |
+
```bash
|
| 370 |
+
cd SQLSherlock-env
|
| 371 |
+
pip install pytest
|
| 372 |
+
pytest tests/ -v
|
| 373 |
+
```
|
| 374 |
+
|
| 375 |
+
### Test checklist
|
| 376 |
+
|
| 377 |
+
```
|
| 378 |
+
tests/test_issue_detector.py ← null / type_error / constraint / outlier / duplicate
|
| 379 |
+
tests/test_graders.py ← task1 / task2 / task3 scoring, trap penalty, FP penalty
|
| 380 |
+
tests/test_environment.py ← reset → step → submit full episode
|
| 381 |
+
```
|
| 382 |
+
|
| 383 |
+
Expected: all tests pass. If any fail, check [tests/conftest.py](tests/conftest.py) — the `DIRTY_RECORDS` fixture must cover all issue types.
|
| 384 |
+
|
| 385 |
+
### Manual smoke test
|
| 386 |
+
|
| 387 |
+
```bash
|
| 388 |
+
# 1. Start server
|
| 389 |
+
docker run -p 7860:7860 sqlsherlock-env:latest
|
| 390 |
+
|
| 391 |
+
# 2. Health check
|
| 392 |
+
curl http://localhost:7860/health
|
| 393 |
+
# → {"status":"ok"}
|
| 394 |
+
|
| 395 |
+
# 3. List tasks
|
| 396 |
+
curl http://localhost:7860/tasks
|
| 397 |
+
# → [{id: task1_null_and_types, ...}, ...]
|
| 398 |
+
|
| 399 |
+
# 4. Run inference (requires HF_TOKEN for model access)
|
| 400 |
+
export HF_TOKEN="hf_..."
|
| 401 |
+
python inference.py 2>results.txt
|
| 402 |
+
# → check stdout for [START]/[STEP]/[END] lines
|
| 403 |
+
# → check stderr (results.txt) for score summary
|
| 404 |
+
```
|
| 405 |
+
|
| 406 |
+
---
|
| 407 |
+
|
| 408 |
+
## Submission Checklist
|
| 409 |
+
|
| 410 |
+
```
|
| 411 |
+
[ ] docker build -t sqlsherlock-env:latest . ← must succeed from repo root
|
| 412 |
+
[ ] docker run -p 7860:7860 sqlsherlock-env:latest ← must start, port 7860
|
| 413 |
+
[ ] curl http://localhost:7860/health ← must return {"status":"ok"}
|
| 414 |
+
[ ] python inference.py ← must emit [START]/[STEP]/[END]
|
| 415 |
+
[ ] openenv validate ← must pass (openenv.yaml at root)
|
| 416 |
+
[ ] Dockerfile is at repo root (not inside subdir) ← validate-submission.sh checks this
|
| 417 |
+
[ ] openenv.yaml is at repo root ← openenv validate checks this
|
| 418 |
+
[ ] No hardcoded secrets in any file ← use env vars only
|
| 419 |
+
[ ] All env vars documented (API_BASE_URL, MODEL_NAME, HF_TOKEN, SPACE_URL)
|
| 420 |
+
[ ] pytest tests/ -v ← all tests pass
|
| 421 |
+
```
|
| 422 |
+
|
| 423 |
+
---
|
| 424 |
+
|
| 425 |
+
## Setup on a New Device
|
| 426 |
+
|
| 427 |
+
### Option A: Docker (recommended for deployment)
|
| 428 |
+
|
| 429 |
+
```bash
|
| 430 |
+
# 1. Clone
|
| 431 |
+
git clone <your-repo-url>
|
| 432 |
+
cd SQLSherlock-env
|
| 433 |
+
|
| 434 |
+
# 2. Build and run
|
| 435 |
+
docker build -t sqlsherlock-env:latest .
|
| 436 |
+
docker run -p 7860:7860 sqlsherlock-env:latest
|
| 437 |
+
|
| 438 |
+
# 3. Verify (in another terminal)
|
| 439 |
+
curl http://localhost:7860/health
|
| 440 |
+
# → {"status":"healthy"}
|
| 441 |
+
|
| 442 |
+
# 4. Run inference
|
| 443 |
+
export HF_TOKEN="hf_your_token_here"
|
| 444 |
+
export SPACE_URL="http://localhost:7860"
|
| 445 |
+
python inference.py
|
| 446 |
+
```
|
| 447 |
+
|
| 448 |
+
### Option B: Local Python (for development)
|
| 449 |
+
|
| 450 |
+
```bash
|
| 451 |
+
# 1. Clone
|
| 452 |
+
git clone <your-repo-url>
|
| 453 |
+
cd SQLSherlock-env
|
| 454 |
+
|
| 455 |
+
# 2. Create virtual environment (Python 3.11+ required)
|
| 456 |
+
python -m venv .venv
|
| 457 |
+
|
| 458 |
+
# 3. Activate venv
|
| 459 |
+
# Linux/Mac:
|
| 460 |
+
source .venv/bin/activate
|
| 461 |
+
# Windows PowerShell:
|
| 462 |
+
.venv\Scripts\Activate.ps1
|
| 463 |
+
# Windows CMD:
|
| 464 |
+
.venv\Scripts\activate.bat
|
| 465 |
+
|
| 466 |
+
# 4. Install dependencies
|
| 467 |
+
pip install -r sqlsherlock_env/server/requirements.txt
|
| 468 |
+
pip install pytest # for tests
|
| 469 |
+
|
| 470 |
+
# 5. Start the server (Terminal 1)
|
| 471 |
+
cd sqlsherlock_env
|
| 472 |
+
# Linux/Mac:
|
| 473 |
+
PYTHONPATH=. uvicorn server.app:app --host 0.0.0.0 --port 7860
|
| 474 |
+
# Windows PowerShell:
|
| 475 |
+
$env:PYTHONPATH = (Get-Location).Path
|
| 476 |
+
python -m uvicorn server.app:app --host 0.0.0.0 --port 7860
|
| 477 |
+
|
| 478 |
+
# 6. Run inference (Terminal 2)
|
| 479 |
+
cd SQLSherlock-env
|
| 480 |
+
# Linux/Mac:
|
| 481 |
+
export HF_TOKEN="hf_your_token_here"
|
| 482 |
+
export SPACE_URL="http://localhost:7860"
|
| 483 |
+
python inference.py
|
| 484 |
+
# Windows PowerShell:
|
| 485 |
+
$env:HF_TOKEN = "hf_your_token_here"
|
| 486 |
+
$env:SPACE_URL = "http://localhost:7860"
|
| 487 |
+
python inference.py
|
| 488 |
+
|
| 489 |
+
# 7. Run tests (server not needed for tests)
|
| 490 |
+
cd SQLSherlock-env
|
| 491 |
+
# Linux/Mac:
|
| 492 |
+
PYTHONPATH=sqlsherlock_env pytest tests/ -v
|
| 493 |
+
# Windows PowerShell:
|
| 494 |
+
$env:PYTHONPATH = "sqlsherlock_env"
|
| 495 |
+
python -m pytest tests/ -v
|
| 496 |
+
```
|
| 497 |
+
|
| 498 |
+
**Python version**: 3.11+ required. Dependencies: `fastapi`, `uvicorn`, `openai`, `datasets`, `pandas`, `pyarrow`.
|
| 499 |
+
|
| 500 |
+
---
|
| 501 |
+
|
| 502 |
+
## GRPO Training
|
| 503 |
+
|
| 504 |
+
```bash
|
| 505 |
+
pip install trl transformers torch
|
| 506 |
+
|
| 507 |
+
export SPACE_URL="http://localhost:7860"
|
| 508 |
+
export MODEL_ID="Qwen/Qwen2.5-1.5B-Instruct"
|
| 509 |
+
python train.py
|
| 510 |
+
```
|
| 511 |
+
|
| 512 |
+
---
|
| 513 |
+
|
| 514 |
+
## Environment Variables
|
| 515 |
+
|
| 516 |
+
| Variable | Default | Description |
|
| 517 |
+
|---|---|---|
|
| 518 |
+
| `API_BASE_URL` | `https://router.huggingface.co/v1` | LLM endpoint |
|
| 519 |
+
| `MODEL_NAME` | `Qwen/Qwen2.5-72B-Instruct` | Model ID |
|
| 520 |
+
| `HF_TOKEN` | — | HuggingFace token (dataset access + LLM) |
|
| 521 |
+
| `SPACE_URL` | `http://localhost:7860` | Environment server URL |
|
| 522 |
+
|
| 523 |
+
---
|
| 524 |
+
|
| 525 |
+
## Baseline Scores (phihung/titanic, 150 rows)
|
| 526 |
+
|
| 527 |
+
| Task | Difficulty | Expected Score |
|
| 528 |
+
|---|---|---|
|
| 529 |
+
| `task1_null_and_types` | Easy | 0.70 – 0.88 |
|
| 530 |
+
| `task2_constraints_and_fk` | Medium | 0.55 – 0.76 |
|
| 531 |
+
| `task3_full_audit_with_trap` | Hard | 0.40 – 0.65 |
|
| 532 |
+
|
| 533 |
+
---
|
| 534 |
+
|
| 535 |
+
## Project Structure
|
| 536 |
+
|
| 537 |
+
```
|
| 538 |
+
SQLSherlock-env/
|
| 539 |
+
├── Dockerfile ← repo root (required for HF Spaces)
|
| 540 |
+
├── README.md ← this file
|
| 541 |
+
├── openenv.yaml ← OpenEnv + HF Spaces manifest (repo root)
|
| 542 |
+
├── inference.py ← baseline agent ([START]/[STEP]/[END] format)
|
| 543 |
+
├── train.py ← TRL GRPO training loop
|
| 544 |
+
├── sqlsherlock_env/
|
| 545 |
+
│ ├── __init__.py
|
| 546 |
+
│ ├── client.py ← SQLSherlockEnv WebSocket/HTTP client
|
| 547 |
+
│ ├── models.py ← Action / Observation / State (Pydantic)
|
| 548 |
+
│ └── server/
|
| 549 |
+
│ ├── app.py ← FastAPI application + WebSocket handler
|
| 550 |
+
│ ├── environment.py ← RL core: reset() / step() / get_state()
|
| 551 |
+
│ ├── database.py ← In-memory SQLite engine, per-episode
|
| 552 |
+
│ ├── dataset_loader.py ← CSV / JSON / JSONL / Parquet / HF loader
|
| 553 |
+
│ ├── schema_profiler.py ← Column statistics + z-scores
|
| 554 |
+
│ ├── issue_detector.py ← Real issue detection + trap planting
|
| 555 |
+
│ ├── validator.py ← 6-check before/after validator
|
| 556 |
+
│ ├── reward.py ← Dense per-step reward with InvestCounter
|
| 557 |
+
│ ├── exporter.py ← Format-fidelity output (CSV→CSV, etc.)
|
| 558 |
+
│ ├── requirements.txt
|
| 559 |
+
│ └── graders/
|
| 560 |
+
│ ├── universal.py ← 7-step scoring pipeline
|
| 561 |
+
│ ├── task1.py ← Task 1 grader
|
| 562 |
+
│ ├── task2.py ← Task 2 grader
|
| 563 |
+
│ └── task3.py ← Task 3 grader (trap-aware)
|
| 564 |
+
└── tests/
|
| 565 |
+
├── conftest.py ← DIRTY_RECORDS fixture (all issue types)
|
| 566 |
+
├── test_issue_detector.py
|
| 567 |
+
├── test_graders.py
|
| 568 |
+
└── test_environment.py
|
| 569 |
+
```
|
inference.py
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
SQLSherlock-Env — Baseline Inference Script.
|
| 9 |
+
|
| 10 |
+
STDOUT FORMAT (mandatory — judges parse this exactly):
|
| 11 |
+
|
| 12 |
+
[START] task=<task_name> env=sqlsherlock_env model=<model_name>
|
| 13 |
+
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
|
| 14 |
+
[END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...>
|
| 15 |
+
|
| 16 |
+
Environment variables:
|
| 17 |
+
API_BASE_URL LLM endpoint (default: https://router.huggingface.co/v1)
|
| 18 |
+
MODEL_NAME Model id (default: Qwen/Qwen2.5-72B-Instruct)
|
| 19 |
+
HF_TOKEN HuggingFace / API key
|
| 20 |
+
SPACE_URL Server URL (default: http://localhost:7860)
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import json
|
| 24 |
+
import os
|
| 25 |
+
import re
|
| 26 |
+
import sys
|
| 27 |
+
import time
|
| 28 |
+
from typing import Any, Optional
|
| 29 |
+
|
| 30 |
+
from openai import OpenAI
|
| 31 |
+
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
# Configuration
|
| 34 |
+
# ---------------------------------------------------------------------------
|
| 35 |
+
|
| 36 |
+
DEMO_DATASET = "phihung/titanic"
|
| 37 |
+
INFERENCE_MAX_ROWS = 500
|
| 38 |
+
ENV_NAME = "sqlsherlock_env"
|
| 39 |
+
|
| 40 |
+
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
|
| 41 |
+
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
|
| 42 |
+
API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or "none"
|
| 43 |
+
SPACE_URL = os.getenv("SPACE_URL", "http://localhost:7860")
|
| 44 |
+
|
| 45 |
+
STEP_BUDGETS: dict[str, int] = {
|
| 46 |
+
"task1_null_and_types": 20,
|
| 47 |
+
"task2_constraints_and_fk": 25,
|
| 48 |
+
"task3_full_audit_with_trap": 30,
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
TASKS = [
|
| 52 |
+
("task1_null_and_types", "easy"),
|
| 53 |
+
("task2_constraints_and_fk", "medium"),
|
| 54 |
+
("task3_full_audit_with_trap", "hard"),
|
| 55 |
+
]
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# ---------------------------------------------------------------------------
|
| 59 |
+
# Mandatory log helpers
|
| 60 |
+
# ---------------------------------------------------------------------------
|
| 61 |
+
|
| 62 |
+
def log_start(task: str, model: str) -> None:
|
| 63 |
+
print(f"[START] task={task} env={ENV_NAME} model={model}", flush=True)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def log_step(step: int, action: str, reward: float, done: bool,
|
| 67 |
+
error: Optional[str] = None) -> None:
|
| 68 |
+
action_str = action.replace("\n", " ").replace("\r", " ").strip()[:120]
|
| 69 |
+
print(
|
| 70 |
+
f"[STEP] step={step} action={action_str} "
|
| 71 |
+
f"reward={reward:.2f} done={str(done).lower()} "
|
| 72 |
+
f"error={error if error else 'null'}",
|
| 73 |
+
flush=True,
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
|
| 78 |
+
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
|
| 79 |
+
print(
|
| 80 |
+
f"[END] success={str(success).lower()} steps={steps} "
|
| 81 |
+
f"score={score:.3f} rewards={rewards_str}",
|
| 82 |
+
flush=True,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _parse_score(feedback: str) -> Optional[float]:
|
| 87 |
+
m = re.search(r"[Gg]rader\s+score\s*=?\s*(\d+\.\d+)", feedback)
|
| 88 |
+
if m:
|
| 89 |
+
try:
|
| 90 |
+
return float(m.group(1))
|
| 91 |
+
except (ValueError, TypeError):
|
| 92 |
+
pass
|
| 93 |
+
return None
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _label(d: dict) -> str:
|
| 97 |
+
a = d.get("action_type", "?")
|
| 98 |
+
if a == "fix_cell":
|
| 99 |
+
return f"fix_cell(row={d.get('row_id')},col={d.get('column')},val={d.get('value')})"
|
| 100 |
+
if a == "profile_column":
|
| 101 |
+
return f"profile_column({d.get('column')})"
|
| 102 |
+
if a == "run_sql":
|
| 103 |
+
return f"run_sql({(d.get('sql') or '')[:40]})"
|
| 104 |
+
if a == "delete_row":
|
| 105 |
+
return f"delete_row(row={d.get('row_id')})"
|
| 106 |
+
return a
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# ---------------------------------------------------------------------------
|
| 110 |
+
# LLM-assisted action selection
|
| 111 |
+
# ---------------------------------------------------------------------------
|
| 112 |
+
|
| 113 |
+
SYSTEM_PROMPT = """You are an expert data scientist performing a complete data quality audit.
|
| 114 |
+
You have inspected and profiled the data. Now decide the NEXT action.
|
| 115 |
+
|
| 116 |
+
Available actions (respond with ONE JSON object):
|
| 117 |
+
- {"action_type":"profile_column","table":"T","column":"C"}
|
| 118 |
+
- {"action_type":"run_sql","sql":"SELECT id FROM \\"T\\" WHERE \\"C\\" IS NULL LIMIT 10"}
|
| 119 |
+
- {"action_type":"fix_column","table":"T","column":"C","value":29.5,"reason":"bulk fix nulls, median=29.5"}
|
| 120 |
+
- {"action_type":"fix_cell","table":"T","row_id":6,"column":"C","value":29.5,"reason":"null, median=29.5"}
|
| 121 |
+
- {"action_type":"delete_row","table":"T","row_id":8,"reason":"duplicate Name"}
|
| 122 |
+
- {"action_type":"validate"}
|
| 123 |
+
- {"action_type":"submit"}
|
| 124 |
+
|
| 125 |
+
PREFER fix_column over fix_cell when fixing nulls — it fixes ALL nulls in one step.
|
| 126 |
+
|
| 127 |
+
RULES:
|
| 128 |
+
- Use EXACT row ids from SQL results. NEVER guess ids.
|
| 129 |
+
- Numeric nulls → column MEDIAN. String nulls → column MODE.
|
| 130 |
+
- Skip columns with null_rate >= 0.70 (structural nulls).
|
| 131 |
+
- TRAP: z > 5 = outlier fix. z < 3 = normal, DO NOT TOUCH.
|
| 132 |
+
|
| 133 |
+
Respond with ONLY one JSON object. No markdown, no text."""
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def _call_llm(client: OpenAI, messages: list[dict]) -> Optional[dict]:
|
| 137 |
+
"""Call LLM and parse JSON. Returns None on failure."""
|
| 138 |
+
try:
|
| 139 |
+
resp = client.chat.completions.create(
|
| 140 |
+
model=MODEL_NAME, messages=messages,
|
| 141 |
+
max_tokens=300, temperature=0.0,
|
| 142 |
+
)
|
| 143 |
+
raw = (resp.choices[0].message.content or "").strip()
|
| 144 |
+
raw = re.sub(r"^```[a-z]*\n?", "", raw)
|
| 145 |
+
raw = re.sub(r"\n?```\s*$", "", raw)
|
| 146 |
+
raw = raw.strip()
|
| 147 |
+
if not raw.startswith("{"):
|
| 148 |
+
start = raw.find("{")
|
| 149 |
+
end = raw.rfind("}")
|
| 150 |
+
if start >= 0 and end > start:
|
| 151 |
+
raw = raw[start:end + 1]
|
| 152 |
+
return json.loads(raw)
|
| 153 |
+
except Exception:
|
| 154 |
+
return None
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# ---------------------------------------------------------------------------
|
| 158 |
+
# Smart data scientist workflow (programmatic + LLM hybrid)
|
| 159 |
+
# ---------------------------------------------------------------------------
|
| 160 |
+
|
| 161 |
+
def _build_action_plan(
|
| 162 |
+
env, table: str, columns: list[str], task_id: str, llm: OpenAI,
|
| 163 |
+
) -> list[dict]:
|
| 164 |
+
"""Build a complete action plan by profiling all columns, then fixing issues.
|
| 165 |
+
|
| 166 |
+
This is the core data scientist workflow:
|
| 167 |
+
1. Inspect the table
|
| 168 |
+
2. Profile each column to understand statistics
|
| 169 |
+
3. For each column with issues, query and fix
|
| 170 |
+
4. Validate and submit
|
| 171 |
+
"""
|
| 172 |
+
from models import SQLSherlockAction
|
| 173 |
+
|
| 174 |
+
plan: list[dict] = []
|
| 175 |
+
col_stats: dict[str, dict] = {}
|
| 176 |
+
visible_cols = [c for c in columns if c not in ("id", "_source_format")]
|
| 177 |
+
|
| 178 |
+
# Step 1: Inspect
|
| 179 |
+
plan.append({"action_type": "inspect", "table": table})
|
| 180 |
+
|
| 181 |
+
# Step 2: Profile key columns (max 3 rewarded, but profile more for info)
|
| 182 |
+
for col in visible_cols[:6]:
|
| 183 |
+
plan.append({"action_type": "profile_column", "table": table, "column": col})
|
| 184 |
+
|
| 185 |
+
# We'll execute the plan up to here, collect profiles, then build fix actions
|
| 186 |
+
return plan
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def run_task(task_id: str) -> float:
|
| 190 |
+
pkg_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "sqlsherlock_env")
|
| 191 |
+
if pkg_dir not in sys.path:
|
| 192 |
+
sys.path.insert(0, pkg_dir)
|
| 193 |
+
|
| 194 |
+
from client import SQLSherlockEnv
|
| 195 |
+
from models import SQLSherlockAction
|
| 196 |
+
|
| 197 |
+
budget = STEP_BUDGETS[task_id]
|
| 198 |
+
rewards: list[float] = []
|
| 199 |
+
steps_taken = 0
|
| 200 |
+
score = 0.0
|
| 201 |
+
success = False
|
| 202 |
+
|
| 203 |
+
log_start(task=task_id, model=MODEL_NAME)
|
| 204 |
+
|
| 205 |
+
try:
|
| 206 |
+
llm = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 207 |
+
except Exception as exc:
|
| 208 |
+
log_step(1, "init_llm", 0.0, True, str(exc)[:80])
|
| 209 |
+
log_end(False, 0, 0.0, [])
|
| 210 |
+
return 0.0
|
| 211 |
+
|
| 212 |
+
env = SQLSherlockEnv(base_url=SPACE_URL)
|
| 213 |
+
|
| 214 |
+
try:
|
| 215 |
+
# --- Reset ---
|
| 216 |
+
try:
|
| 217 |
+
obs = env.reset(dataset=DEMO_DATASET, task_id=task_id,
|
| 218 |
+
max_rows=INFERENCE_MAX_ROWS)
|
| 219 |
+
except Exception as exc:
|
| 220 |
+
log_step(1, "reset", 0.0, True, str(exc)[:80])
|
| 221 |
+
log_end(False, 0, 0.0, [])
|
| 222 |
+
return 0.0
|
| 223 |
+
|
| 224 |
+
table = list(obs.tables_summary.keys())[0] if obs.tables_summary else "dataset"
|
| 225 |
+
columns = obs.tables_summary.get(table, {}).get("columns", [])
|
| 226 |
+
visible_cols = [c for c in columns if c not in ("id", "_source_format")]
|
| 227 |
+
|
| 228 |
+
done = False
|
| 229 |
+
step_num = 0
|
| 230 |
+
col_profiles: dict[str, dict] = {} # column → profile stats
|
| 231 |
+
llm_messages = [
|
| 232 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 233 |
+
]
|
| 234 |
+
|
| 235 |
+
def _do_step(action_dict: dict) -> tuple:
|
| 236 |
+
nonlocal step_num, done, obs
|
| 237 |
+
step_num += 1
|
| 238 |
+
if step_num > budget or done:
|
| 239 |
+
return 0.0, True
|
| 240 |
+
action = SQLSherlockAction(**{k: v for k, v in action_dict.items() if v is not None})
|
| 241 |
+
try:
|
| 242 |
+
obs, reward, done, _ = env.step(action)
|
| 243 |
+
reward = float(reward or 0.0)
|
| 244 |
+
except Exception as exc:
|
| 245 |
+
reward = 0.0
|
| 246 |
+
rewards.append(reward)
|
| 247 |
+
log_step(step_num, _label(action_dict), reward, done, None)
|
| 248 |
+
return reward, done
|
| 249 |
+
|
| 250 |
+
# ===== PHASE 1: Inspect =====
|
| 251 |
+
_do_step({"action_type": "inspect", "table": table})
|
| 252 |
+
|
| 253 |
+
# ===== PHASE 2: Profile + Bulk Fix interleaved =====
|
| 254 |
+
# Profile each column. If it has fixable nulls, use fix_column to
|
| 255 |
+
# fix ALL nulls in ONE step. This handles the complete dataset.
|
| 256 |
+
for col in visible_cols:
|
| 257 |
+
if done or step_num >= budget - 2:
|
| 258 |
+
break
|
| 259 |
+
|
| 260 |
+
# Profile this column
|
| 261 |
+
_do_step({"action_type": "profile_column", "table": table, "column": col})
|
| 262 |
+
if not obs.query_result or len(obs.query_result) == 0:
|
| 263 |
+
continue
|
| 264 |
+
profile = obs.query_result[0]
|
| 265 |
+
col_profiles[col] = profile
|
| 266 |
+
|
| 267 |
+
null_count = profile.get("null_count", 0)
|
| 268 |
+
null_rate = profile.get("null_rate", 0.0)
|
| 269 |
+
dtype = profile.get("dtype", "unknown")
|
| 270 |
+
median_val = profile.get("median")
|
| 271 |
+
mode_val = profile.get("mode")
|
| 272 |
+
mean_val = profile.get("mean")
|
| 273 |
+
|
| 274 |
+
# Skip if no nulls at all
|
| 275 |
+
if null_count == 0:
|
| 276 |
+
continue
|
| 277 |
+
|
| 278 |
+
# For high-null columns (structural), still fix but with "Unknown"
|
| 279 |
+
# These have low confidence in the grader but still count toward score
|
| 280 |
+
|
| 281 |
+
# Determine fill value based on column type and null_rate
|
| 282 |
+
if dtype in ("int", "float"):
|
| 283 |
+
fill_value = median_val or mean_val or 0
|
| 284 |
+
elif null_rate >= 0.70:
|
| 285 |
+
fill_value = "Unknown" # structural nulls — safe generic fill
|
| 286 |
+
else:
|
| 287 |
+
fill_value = mode_val or "Unknown"
|
| 288 |
+
|
| 289 |
+
# Bulk fix: fix ALL nulls in this column in one step
|
| 290 |
+
strategy = "median" if dtype in ("int", "float") else "mode"
|
| 291 |
+
reason = f"bulk fix {null_count} nulls in {col}, {strategy}={fill_value}"
|
| 292 |
+
_do_step({
|
| 293 |
+
"action_type": "fix_column",
|
| 294 |
+
"table": table,
|
| 295 |
+
"column": col,
|
| 296 |
+
"value": fill_value,
|
| 297 |
+
"reason": reason,
|
| 298 |
+
})
|
| 299 |
+
|
| 300 |
+
# ===== PHASE 4: LLM-assisted advanced cleaning =====
|
| 301 |
+
# Give the LLM a chance to find issues we missed (type errors, constraints, etc.)
|
| 302 |
+
if not done and step_num < budget - 3:
|
| 303 |
+
# Build context for LLM
|
| 304 |
+
fixed_summary = f"Profiled {len(col_profiles)} columns. Fixed nulls in columns with issues."
|
| 305 |
+
remaining_budget = budget - step_num - 2 # reserve 2 for validate+submit
|
| 306 |
+
|
| 307 |
+
llm_messages.append({"role": "user", "content": (
|
| 308 |
+
f"Table: \"{table}\", Columns: {visible_cols}\n"
|
| 309 |
+
f"I've already: {fixed_summary}\n"
|
| 310 |
+
f"Remaining budget: {remaining_budget} actions before validate+submit.\n"
|
| 311 |
+
f"What other data quality issues should I check? "
|
| 312 |
+
f"Consider: type errors, negative values, duplicates, whitespace. "
|
| 313 |
+
f"Respond with one JSON action, or {{\"action_type\":\"validate\"}} if done."
|
| 314 |
+
)})
|
| 315 |
+
|
| 316 |
+
for _ in range(min(remaining_budget, 5)):
|
| 317 |
+
if done or step_num >= budget - 2:
|
| 318 |
+
break
|
| 319 |
+
|
| 320 |
+
action_dict = _call_llm(llm, llm_messages)
|
| 321 |
+
if action_dict is None or action_dict.get("action_type") in ("validate", "submit"):
|
| 322 |
+
break
|
| 323 |
+
|
| 324 |
+
r, d = _do_step(action_dict)
|
| 325 |
+
if d:
|
| 326 |
+
break
|
| 327 |
+
|
| 328 |
+
# Feed result back to LLM
|
| 329 |
+
feedback = (obs.last_feedback or "")[:300]
|
| 330 |
+
if obs.query_result:
|
| 331 |
+
ids = [r2.get("id") for r2 in obs.query_result if r2.get("id") is not None]
|
| 332 |
+
if ids:
|
| 333 |
+
feedback += f"\nRow IDs: {ids[:15]}"
|
| 334 |
+
llm_messages.append({"role": "assistant", "content": json.dumps(action_dict)})
|
| 335 |
+
llm_messages.append({"role": "user", "content": feedback + "\nNext action?"})
|
| 336 |
+
|
| 337 |
+
# ===== PHASE 5: Validate =====
|
| 338 |
+
if not done and step_num < budget:
|
| 339 |
+
_do_step({"action_type": "validate"})
|
| 340 |
+
|
| 341 |
+
# ===== PHASE 6: Submit =====
|
| 342 |
+
if not done:
|
| 343 |
+
_do_step({"action_type": "submit"})
|
| 344 |
+
if obs.last_feedback:
|
| 345 |
+
parsed = _parse_score(obs.last_feedback)
|
| 346 |
+
if parsed is not None:
|
| 347 |
+
score = max(0.0, min(1.0, parsed))
|
| 348 |
+
|
| 349 |
+
# Fallback score from rewards
|
| 350 |
+
if score == 0.0 and rewards:
|
| 351 |
+
positive = sum(r for r in rewards if r > 0)
|
| 352 |
+
score = max(0.0, min(1.0, positive / max(budget * 0.15, 0.01)))
|
| 353 |
+
|
| 354 |
+
success = score >= 0.50
|
| 355 |
+
steps_taken = step_num
|
| 356 |
+
|
| 357 |
+
finally:
|
| 358 |
+
try:
|
| 359 |
+
env.close()
|
| 360 |
+
except Exception:
|
| 361 |
+
pass
|
| 362 |
+
log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
|
| 363 |
+
|
| 364 |
+
return score
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
# ---------------------------------------------------------------------------
|
| 368 |
+
# Main
|
| 369 |
+
# ---------------------------------------------------------------------------
|
| 370 |
+
|
| 371 |
+
def main() -> None:
|
| 372 |
+
wall_start = time.time()
|
| 373 |
+
all_scores: list[float] = []
|
| 374 |
+
|
| 375 |
+
for task_id, _ in TASKS:
|
| 376 |
+
score = run_task(task_id)
|
| 377 |
+
all_scores.append(score)
|
| 378 |
+
time.sleep(1)
|
| 379 |
+
|
| 380 |
+
avg = sum(all_scores) / len(all_scores) if all_scores else 0.0
|
| 381 |
+
total = time.time() - wall_start
|
| 382 |
+
|
| 383 |
+
print(
|
| 384 |
+
f"\n=== SQLSherlock-Env Results avg={avg:.3f} "
|
| 385 |
+
f"runtime={total:.1f}s ===",
|
| 386 |
+
file=sys.stderr,
|
| 387 |
+
)
|
| 388 |
+
for (tid, _), sc in zip(TASKS, all_scores):
|
| 389 |
+
bar = "\u2588" * int(sc * 20) + "\u2591" * (20 - int(sc * 20))
|
| 390 |
+
print(f" {tid:<38} [{bar}] {sc:.3f}", file=sys.stderr)
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
if __name__ == "__main__":
|
| 394 |
+
main()
|
openenv.yaml
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: SQLSherlock Env
|
| 3 |
+
emoji: 🔍
|
| 4 |
+
colorFrom: indigo
|
| 5 |
+
colorTo: cyan
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
tags:
|
| 9 |
+
- openenv
|
| 10 |
+
pinned: false
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
name: sqlsherlock_env
|
| 14 |
+
version: "1.0.0"
|
| 15 |
+
description: >
|
| 16 |
+
RL environment where an AI agent acts as a data scientist.
|
| 17 |
+
Investigates real dirty datasets, discovers issues through
|
| 18 |
+
statistical profiling and SQL queries, fixes with reasoning,
|
| 19 |
+
validates fixes against raw baseline, exports in original format.
|
| 20 |
+
No issues are planted — the agent discovers them exactly like
|
| 21 |
+
a human data scientist would.
|
| 22 |
+
|
| 23 |
+
tasks:
|
| 24 |
+
- id: task1_null_and_types
|
| 25 |
+
name: "Null and type error repair"
|
| 26 |
+
difficulty: easy
|
| 27 |
+
max_steps: 20
|
| 28 |
+
description: >
|
| 29 |
+
Find and fix null values and type errors in the primary table.
|
| 30 |
+
Profile columns, identify anomalies, fix with reasoning,
|
| 31 |
+
validate your work, and export the cleaned dataset.
|
| 32 |
+
|
| 33 |
+
- id: task2_constraints_and_fk
|
| 34 |
+
name: "Constraint and FK integrity"
|
| 35 |
+
difficulty: medium
|
| 36 |
+
max_steps: 25
|
| 37 |
+
description: >
|
| 38 |
+
Everything in Task 1 plus constraint violations
|
| 39 |
+
(negative values in must-be-positive columns) and FK
|
| 40 |
+
violations (orphan references in related tables).
|
| 41 |
+
|
| 42 |
+
- id: task3_full_audit_with_trap
|
| 43 |
+
name: "Full statistical audit with trap"
|
| 44 |
+
difficulty: hard
|
| 45 |
+
max_steps: 30
|
| 46 |
+
description: >
|
| 47 |
+
Full audit including statistical outliers. TRAP WARNING:
|
| 48 |
+
one numeric value looks suspicious but is legitimate.
|
| 49 |
+
You MUST check z-scores before fixing any numeric value.
|
| 50 |
+
z > 5 = real outlier. z < 3 = leave alone.
|
| 51 |
+
|
| 52 |
+
env_vars:
|
| 53 |
+
API_BASE_URL:
|
| 54 |
+
description: "LLM API endpoint"
|
| 55 |
+
default: "https://router.huggingface.co/v1"
|
| 56 |
+
MODEL_NAME:
|
| 57 |
+
description: "Model identifier for inference"
|
| 58 |
+
default: "Qwen/Qwen2.5-72B-Instruct"
|
| 59 |
+
HF_TOKEN:
|
| 60 |
+
description: "HuggingFace API token (set as Space secret)"
|
| 61 |
+
required: true
|
sqlsherlock_env/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""SQLSherlock-Env — RL environment for AI data scientist agents."""
|
| 8 |
+
|
| 9 |
+
from client import SQLSherlockEnv
|
| 10 |
+
from models import SQLSherlockAction, SQLSherlockObservation, SQLSherlockState
|
| 11 |
+
|
| 12 |
+
__version__ = "1.0.0"
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
"SQLSherlockEnv",
|
| 16 |
+
"SQLSherlockAction",
|
| 17 |
+
"SQLSherlockObservation",
|
| 18 |
+
"SQLSherlockState",
|
| 19 |
+
]
|
sqlsherlock_env/client.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
SQLSherlock-Env client.
|
| 9 |
+
|
| 10 |
+
Wraps the OpenEnv EnvClient to provide a typed, synchronous interface for
|
| 11 |
+
SQLSherlockAction / SQLSherlockObservation / SQLSherlockState.
|
| 12 |
+
|
| 13 |
+
Usage::
|
| 14 |
+
|
| 15 |
+
with SQLSherlockEnv(base_url="http://localhost:7860") as env:
|
| 16 |
+
obs = env.reset(dataset="mstz/titanic", task_id="task1_null_and_types")
|
| 17 |
+
obs, reward, done, info = env.step(
|
| 18 |
+
SQLSherlockAction(action_type="inspect", table="titanic")
|
| 19 |
+
)
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from typing import Any, Dict, Optional, Tuple
|
| 23 |
+
|
| 24 |
+
from openenv.core import EnvClient
|
| 25 |
+
from openenv.core.client_types import StepResult
|
| 26 |
+
from openenv.core.env_server.types import State
|
| 27 |
+
|
| 28 |
+
from models import SQLSherlockAction, SQLSherlockObservation, SQLSherlockState
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class _AsyncSQLSherlockClient(
|
| 32 |
+
EnvClient[SQLSherlockAction, SQLSherlockObservation, SQLSherlockState]
|
| 33 |
+
):
|
| 34 |
+
"""Async EnvClient subclass with custom payload/parsing logic."""
|
| 35 |
+
|
| 36 |
+
def _step_payload(self, action: SQLSherlockAction) -> Dict[str, Any]:
|
| 37 |
+
payload: Dict[str, Any] = {"action_type": action.action_type}
|
| 38 |
+
|
| 39 |
+
if action.table is not None:
|
| 40 |
+
payload["table"] = action.table
|
| 41 |
+
if action.row_id is not None:
|
| 42 |
+
payload["row_id"] = action.row_id
|
| 43 |
+
if action.column is not None:
|
| 44 |
+
payload["column"] = action.column
|
| 45 |
+
if action.value is not None:
|
| 46 |
+
payload["value"] = action.value
|
| 47 |
+
if action.sql is not None:
|
| 48 |
+
payload["sql"] = action.sql
|
| 49 |
+
if action.cleaned_rows is not None:
|
| 50 |
+
payload["cleaned_rows"] = action.cleaned_rows
|
| 51 |
+
if action.removed_ids is not None:
|
| 52 |
+
payload["removed_ids"] = action.removed_ids
|
| 53 |
+
if action.reason is not None:
|
| 54 |
+
payload["reason"] = action.reason
|
| 55 |
+
|
| 56 |
+
return payload
|
| 57 |
+
|
| 58 |
+
def _parse_result(
|
| 59 |
+
self, payload: Dict[str, Any]
|
| 60 |
+
) -> StepResult[SQLSherlockObservation]:
|
| 61 |
+
obs_data = payload.get("observation", {})
|
| 62 |
+
|
| 63 |
+
observation = SQLSherlockObservation(
|
| 64 |
+
task_id=obs_data.get("task_id", ""),
|
| 65 |
+
task_description=obs_data.get("task_description", ""),
|
| 66 |
+
step=obs_data.get("step", 0),
|
| 67 |
+
max_steps=obs_data.get("max_steps", 20),
|
| 68 |
+
tables_summary=obs_data.get("tables_summary", {}),
|
| 69 |
+
query_result=obs_data.get("query_result"),
|
| 70 |
+
validation_result=obs_data.get("validation_result"),
|
| 71 |
+
last_feedback=obs_data.get("last_feedback", ""),
|
| 72 |
+
reward_trace=obs_data.get("reward_trace", []),
|
| 73 |
+
done=payload.get("done", False),
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
return StepResult(
|
| 77 |
+
observation=observation,
|
| 78 |
+
reward=payload.get("reward"),
|
| 79 |
+
done=payload.get("done", False),
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
def _parse_state(self, payload: Dict[str, Any]) -> SQLSherlockState:
|
| 83 |
+
return SQLSherlockState(
|
| 84 |
+
episode_id=payload.get("episode_id", ""),
|
| 85 |
+
task_id=payload.get("task_id", ""),
|
| 86 |
+
step_count=payload.get("step_count", 0),
|
| 87 |
+
grader_score=payload.get("grader_score", 0.0),
|
| 88 |
+
done=payload.get("done", False),
|
| 89 |
+
dataset_name=payload.get("dataset_name", ""),
|
| 90 |
+
source_format=payload.get("source_format", ""),
|
| 91 |
+
investigation_count=payload.get("investigation_count", 0),
|
| 92 |
+
validation_called=payload.get("validation_called", False),
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class SQLSherlockEnv:
|
| 97 |
+
"""Synchronous client for the SQLSherlock-Env RL environment.
|
| 98 |
+
|
| 99 |
+
Provides the standard RL interface:
|
| 100 |
+
obs = env.reset(dataset=..., task_id=...)
|
| 101 |
+
obs, reward, done, info = env.step(action)
|
| 102 |
+
|
| 103 |
+
Example::
|
| 104 |
+
|
| 105 |
+
with SQLSherlockEnv(base_url="http://localhost:7860") as env:
|
| 106 |
+
obs = env.reset(
|
| 107 |
+
dataset="mstz/titanic",
|
| 108 |
+
task_id="task1_null_and_types",
|
| 109 |
+
)
|
| 110 |
+
print(obs.tables_summary)
|
| 111 |
+
|
| 112 |
+
obs, reward, done, info = env.step(
|
| 113 |
+
SQLSherlockAction(action_type="inspect", table="titanic")
|
| 114 |
+
)
|
| 115 |
+
print(obs.last_feedback, reward)
|
| 116 |
+
"""
|
| 117 |
+
|
| 118 |
+
def __init__(self, base_url: str = "http://localhost:7860") -> None:
|
| 119 |
+
self._async_client = _AsyncSQLSherlockClient(base_url=base_url)
|
| 120 |
+
self._sync = self._async_client.sync()
|
| 121 |
+
|
| 122 |
+
def __enter__(self):
|
| 123 |
+
self._sync.connect()
|
| 124 |
+
return self
|
| 125 |
+
|
| 126 |
+
def __exit__(self, *args):
|
| 127 |
+
self.close()
|
| 128 |
+
|
| 129 |
+
def reset(self, **kwargs) -> SQLSherlockObservation:
|
| 130 |
+
"""Reset the environment and return initial observation.
|
| 131 |
+
|
| 132 |
+
Keyword Args:
|
| 133 |
+
dataset (str): Dataset source — required.
|
| 134 |
+
task_id (str): Task identifier — required.
|
| 135 |
+
seed (int): RNG seed (default 42).
|
| 136 |
+
max_rows(int): Row limit (default 500).
|
| 137 |
+
"""
|
| 138 |
+
result: StepResult = self._sync.reset(**kwargs)
|
| 139 |
+
return result.observation
|
| 140 |
+
|
| 141 |
+
def step(
|
| 142 |
+
self, action: SQLSherlockAction
|
| 143 |
+
) -> Tuple[SQLSherlockObservation, float, bool, dict]:
|
| 144 |
+
"""Execute one action. Returns (obs, reward, done, info)."""
|
| 145 |
+
result: StepResult = self._sync.step(action)
|
| 146 |
+
return (
|
| 147 |
+
result.observation,
|
| 148 |
+
float(result.reward or 0.0),
|
| 149 |
+
result.done,
|
| 150 |
+
{},
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
def get_state(self) -> SQLSherlockState:
|
| 154 |
+
"""Return current episode state."""
|
| 155 |
+
return self._sync.state()
|
| 156 |
+
|
| 157 |
+
def close(self) -> None:
|
| 158 |
+
"""Close the connection."""
|
| 159 |
+
try:
|
| 160 |
+
self._sync.disconnect()
|
| 161 |
+
except Exception:
|
| 162 |
+
pass
|
| 163 |
+
|
| 164 |
+
@classmethod
|
| 165 |
+
def from_docker_image(cls, image: str, port: int = 7860) -> "SQLSherlockEnv":
|
| 166 |
+
"""Create client connected to a freshly launched Docker container."""
|
| 167 |
+
import subprocess
|
| 168 |
+
import time
|
| 169 |
+
|
| 170 |
+
container_id = subprocess.check_output(
|
| 171 |
+
["docker", "run", "-d", "-p", f"{port}:{port}", image],
|
| 172 |
+
text=True,
|
| 173 |
+
).strip()
|
| 174 |
+
|
| 175 |
+
# Wait for server to be ready
|
| 176 |
+
import urllib.request
|
| 177 |
+
for _ in range(30):
|
| 178 |
+
try:
|
| 179 |
+
urllib.request.urlopen(f"http://localhost:{port}/health", timeout=2)
|
| 180 |
+
break
|
| 181 |
+
except Exception:
|
| 182 |
+
time.sleep(1)
|
| 183 |
+
|
| 184 |
+
client = cls(base_url=f"http://localhost:{port}")
|
| 185 |
+
client._container_id = container_id
|
| 186 |
+
return client
|
sqlsherlock_env/models.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Data models for the SQLSherlock-Env RL environment.
|
| 9 |
+
|
| 10 |
+
An AI agent acts as a data scientist investigating a dirty dataset,
|
| 11 |
+
discovering real data quality issues through statistical investigation,
|
| 12 |
+
fixing them with reasoning, validating fixes, and exporting cleaned output.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from typing import Any, Literal, Optional
|
| 16 |
+
|
| 17 |
+
from openenv.core.env_server.types import Action, Observation, State
|
| 18 |
+
from pydantic import Field
|
| 19 |
+
|
| 20 |
+
ActionType = Literal[
|
| 21 |
+
"inspect", # view all rows in a table
|
| 22 |
+
"profile_column", # stats: mean/std/min/max/nulls/z_scores per col
|
| 23 |
+
"run_sql", # SELECT query only
|
| 24 |
+
"fix_cell", # correct one cell value with reason
|
| 25 |
+
"fix_column", # fix ALL nulls in a column with one value (bulk operation)
|
| 26 |
+
"delete_row", # remove a row with reason
|
| 27 |
+
"validate", # run all 6 checks: before vs after
|
| 28 |
+
"submit", # end episode and score
|
| 29 |
+
"export", # terminal: write cleaned file, return URL
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class SQLSherlockAction(Action):
|
| 34 |
+
"""Action for the SQLSherlock-Env environment.
|
| 35 |
+
|
| 36 |
+
The agent issues one of 8 action types per step.
|
| 37 |
+
Every fix action MUST include a reason field with statistical justification.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
action_type: ActionType = Field(
|
| 41 |
+
...,
|
| 42 |
+
description="Type of action to perform.",
|
| 43 |
+
)
|
| 44 |
+
table: Optional[str] = Field(
|
| 45 |
+
default=None,
|
| 46 |
+
description="Target table name (required for inspect, profile_column, fix_cell, delete_row).",
|
| 47 |
+
)
|
| 48 |
+
row_id: Optional[int] = Field(
|
| 49 |
+
default=None,
|
| 50 |
+
description="Row primary key (required for fix_cell, delete_row).",
|
| 51 |
+
)
|
| 52 |
+
column: Optional[str] = Field(
|
| 53 |
+
default=None,
|
| 54 |
+
description="Column name (required for profile_column, fix_cell).",
|
| 55 |
+
)
|
| 56 |
+
value: Optional[Any] = Field(
|
| 57 |
+
default=None,
|
| 58 |
+
description="Corrected value to write (required for fix_cell).",
|
| 59 |
+
)
|
| 60 |
+
sql: Optional[str] = Field(
|
| 61 |
+
default=None,
|
| 62 |
+
description="SELECT SQL query string (required for run_sql).",
|
| 63 |
+
)
|
| 64 |
+
cleaned_rows: Optional[list[dict]] = Field(
|
| 65 |
+
default=None,
|
| 66 |
+
description="Full list of cleaned rows for export action.",
|
| 67 |
+
)
|
| 68 |
+
removed_ids: Optional[list[int]] = Field(
|
| 69 |
+
default=None,
|
| 70 |
+
description="List of deleted row primary keys for export action.",
|
| 71 |
+
)
|
| 72 |
+
reason: Optional[str] = Field(
|
| 73 |
+
default=None,
|
| 74 |
+
description="Statistical justification for this action (required for fix_cell, delete_row).",
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class SQLSherlockObservation(Observation):
|
| 79 |
+
"""Observation returned to the agent after each step.
|
| 80 |
+
|
| 81 |
+
Contains the current environment state the agent can see.
|
| 82 |
+
The issue_registry is NEVER included here — the agent must discover issues.
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
task_id: str = Field(
|
| 86 |
+
default="",
|
| 87 |
+
description="Current task identifier.",
|
| 88 |
+
)
|
| 89 |
+
task_description: str = Field(
|
| 90 |
+
default="",
|
| 91 |
+
description="Human-readable task description for the agent.",
|
| 92 |
+
)
|
| 93 |
+
step: int = Field(
|
| 94 |
+
default=0,
|
| 95 |
+
description="Current step number (1-indexed).",
|
| 96 |
+
)
|
| 97 |
+
max_steps: int = Field(
|
| 98 |
+
default=20,
|
| 99 |
+
description="Maximum steps allowed for this task.",
|
| 100 |
+
)
|
| 101 |
+
tables_summary: dict[str, Any] = Field(
|
| 102 |
+
default_factory=dict,
|
| 103 |
+
description=(
|
| 104 |
+
"Summary of all loaded tables: "
|
| 105 |
+
"{table_name: {row_count: int, columns: list[str], dtypes: dict}}"
|
| 106 |
+
),
|
| 107 |
+
)
|
| 108 |
+
query_result: Optional[list[dict]] = Field(
|
| 109 |
+
default=None,
|
| 110 |
+
description="Result rows from inspect or run_sql actions.",
|
| 111 |
+
)
|
| 112 |
+
validation_result: Optional[dict] = Field(
|
| 113 |
+
default=None,
|
| 114 |
+
description="Detailed validation results after a validate action.",
|
| 115 |
+
)
|
| 116 |
+
last_feedback: str = Field(
|
| 117 |
+
default="",
|
| 118 |
+
description="Human-readable feedback about the last action taken.",
|
| 119 |
+
)
|
| 120 |
+
reward_trace: list[dict] = Field(
|
| 121 |
+
default_factory=list,
|
| 122 |
+
description="Cumulative reward log — grows every step; judges review this.",
|
| 123 |
+
)
|
| 124 |
+
done: bool = Field(
|
| 125 |
+
default=False,
|
| 126 |
+
description="True when the episode has ended.",
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class SQLSherlockState(State):
|
| 131 |
+
"""Internal server-side state for one SQLSherlock episode.
|
| 132 |
+
|
| 133 |
+
Not exposed to the agent. Used by the environment and graders.
|
| 134 |
+
"""
|
| 135 |
+
|
| 136 |
+
episode_id: str = Field(
|
| 137 |
+
default="",
|
| 138 |
+
description="Unique identifier for this episode.",
|
| 139 |
+
)
|
| 140 |
+
task_id: str = Field(
|
| 141 |
+
default="",
|
| 142 |
+
description="Task identifier for this episode.",
|
| 143 |
+
)
|
| 144 |
+
step_count: int = Field(
|
| 145 |
+
default=0,
|
| 146 |
+
description="Number of steps taken so far.",
|
| 147 |
+
)
|
| 148 |
+
grader_score: float = Field(
|
| 149 |
+
default=0.0,
|
| 150 |
+
description="Most recent grader score (0.0–1.0).",
|
| 151 |
+
)
|
| 152 |
+
done: bool = Field(
|
| 153 |
+
default=False,
|
| 154 |
+
description="Whether the episode has ended.",
|
| 155 |
+
)
|
| 156 |
+
dataset_name: str = Field(
|
| 157 |
+
default="",
|
| 158 |
+
description="Name or path of the loaded dataset.",
|
| 159 |
+
)
|
| 160 |
+
source_format: str = Field(
|
| 161 |
+
default="",
|
| 162 |
+
description="Detected source format: csv|json|jsonl|parquet|hf_dataset.",
|
| 163 |
+
)
|
| 164 |
+
investigation_count: int = Field(
|
| 165 |
+
default=0,
|
| 166 |
+
description="Number of investigation actions taken (inspect + profile + sql).",
|
| 167 |
+
)
|
| 168 |
+
validation_called: bool = Field(
|
| 169 |
+
default=False,
|
| 170 |
+
description="Whether the agent called validate() at least once.",
|
| 171 |
+
)
|
sqlsherlock_env/pyproject.toml
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
[build-system]
|
| 8 |
+
requires = ["hatchling"]
|
| 9 |
+
build-backend = "hatchling.build"
|
| 10 |
+
|
| 11 |
+
[project]
|
| 12 |
+
name = "sqlsherlock-env"
|
| 13 |
+
version = "1.0.0"
|
| 14 |
+
description = "RL environment where an AI agent acts as a data scientist investigating dirty datasets"
|
| 15 |
+
requires-python = ">=3.11"
|
| 16 |
+
dependencies = [
|
| 17 |
+
"openenv-core>=0.2.1",
|
| 18 |
+
"fastapi>=0.115.0",
|
| 19 |
+
"uvicorn[standard]>=0.30.0",
|
| 20 |
+
"pydantic>=2.8.2",
|
| 21 |
+
"openai>=1.40.0",
|
| 22 |
+
"python-multipart>=0.0.9",
|
| 23 |
+
"datasets>=2.20.0",
|
| 24 |
+
"pandas>=2.0.0",
|
| 25 |
+
"pyarrow>=14.0.0",
|
| 26 |
+
]
|
| 27 |
+
|
| 28 |
+
[project.optional-dependencies]
|
| 29 |
+
train = [
|
| 30 |
+
"trl>=0.15.0",
|
| 31 |
+
"transformers>=4.47.0",
|
| 32 |
+
"torch>=2.5.0",
|
| 33 |
+
]
|
| 34 |
+
dev = [
|
| 35 |
+
"pytest>=8.0",
|
| 36 |
+
"httpx>=0.27",
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
[project.scripts]
|
| 40 |
+
server = "server.app:main"
|
| 41 |
+
|
| 42 |
+
[tool.hatch.build.targets.wheel]
|
| 43 |
+
packages = ["."]
|
| 44 |
+
|
| 45 |
+
[tool.pytest.ini_options]
|
| 46 |
+
testpaths = ["tests"]
|
sqlsherlock_env/server/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""SQLSherlock-Env server components."""
|
| 8 |
+
|
| 9 |
+
from server.environment import SQLSherlockEnvironment, TASKS
|
| 10 |
+
|
| 11 |
+
__all__ = ["SQLSherlockEnvironment", "TASKS"]
|
sqlsherlock_env/server/app.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
FastAPI application for SQLSherlock-Env.
|
| 9 |
+
|
| 10 |
+
Mounts the OpenEnv core WebSocket/HTTP app and adds extra endpoints:
|
| 11 |
+
GET /health
|
| 12 |
+
GET /tasks
|
| 13 |
+
POST /upload_dataset
|
| 14 |
+
GET /download/{file_id}
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import tempfile
|
| 19 |
+
import time
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
|
| 22 |
+
from fastapi import FastAPI, File, HTTPException, UploadFile
|
| 23 |
+
from fastapi.responses import FileResponse
|
| 24 |
+
|
| 25 |
+
from openenv.core.env_server import create_app
|
| 26 |
+
|
| 27 |
+
from models import SQLSherlockAction, SQLSherlockObservation
|
| 28 |
+
from server.environment import SQLSherlockEnvironment, TASKS
|
| 29 |
+
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
# Core OpenEnv app
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
|
| 34 |
+
app: FastAPI = create_app(
|
| 35 |
+
SQLSherlockEnvironment, # class (factory), not instance
|
| 36 |
+
SQLSherlockAction,
|
| 37 |
+
SQLSherlockObservation,
|
| 38 |
+
env_name="sqlsherlock_env",
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
# ---------------------------------------------------------------------------
|
| 42 |
+
# /health
|
| 43 |
+
# ---------------------------------------------------------------------------
|
| 44 |
+
|
| 45 |
+
@app.get("/health")
|
| 46 |
+
async def health() -> dict:
|
| 47 |
+
return {
|
| 48 |
+
"status": "healthy",
|
| 49 |
+
"version": "1.0.0",
|
| 50 |
+
"timestamp": time.time(),
|
| 51 |
+
"tasks": [t["id"] for t in TASKS],
|
| 52 |
+
"supported_formats": ["csv", "json", "jsonl", "parquet", "hf"],
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# ---------------------------------------------------------------------------
|
| 57 |
+
# /tasks
|
| 58 |
+
# ---------------------------------------------------------------------------
|
| 59 |
+
|
| 60 |
+
@app.get("/tasks")
|
| 61 |
+
async def list_tasks() -> list[dict]:
|
| 62 |
+
return [
|
| 63 |
+
{
|
| 64 |
+
"id": t["id"],
|
| 65 |
+
"name": t["name"],
|
| 66 |
+
"difficulty": t["difficulty"],
|
| 67 |
+
"max_steps": t["max_steps"],
|
| 68 |
+
"description": t["description"],
|
| 69 |
+
}
|
| 70 |
+
for t in TASKS
|
| 71 |
+
]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# ---------------------------------------------------------------------------
|
| 75 |
+
# /upload_dataset
|
| 76 |
+
# ---------------------------------------------------------------------------
|
| 77 |
+
|
| 78 |
+
@app.post("/upload_dataset")
|
| 79 |
+
async def upload_dataset(file: UploadFile = File(...)) -> dict:
|
| 80 |
+
"""Accept a dataset file, validate it is loadable, return a preview.
|
| 81 |
+
|
| 82 |
+
Supported file types: .csv, .json, .jsonl, .parquet
|
| 83 |
+
"""
|
| 84 |
+
from server.dataset_loader import load
|
| 85 |
+
|
| 86 |
+
filename = file.filename or "upload"
|
| 87 |
+
suffix = Path(filename).suffix.lower()
|
| 88 |
+
|
| 89 |
+
if suffix not in (".csv", ".json", ".jsonl", ".parquet"):
|
| 90 |
+
raise HTTPException(
|
| 91 |
+
status_code=400,
|
| 92 |
+
detail=(
|
| 93 |
+
f"Unsupported file type '{suffix}'. "
|
| 94 |
+
"Upload a .csv, .json, .jsonl, or .parquet file."
|
| 95 |
+
),
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
# Save to temp file
|
| 99 |
+
tmp_path = os.path.join(tempfile.gettempdir(), f"sqlsherlock_upload_{filename}")
|
| 100 |
+
try:
|
| 101 |
+
contents = await file.read()
|
| 102 |
+
with open(tmp_path, "wb") as f:
|
| 103 |
+
f.write(contents)
|
| 104 |
+
except Exception as exc:
|
| 105 |
+
raise HTTPException(status_code=500, detail=f"File save failed: {exc}")
|
| 106 |
+
|
| 107 |
+
# Attempt load
|
| 108 |
+
try:
|
| 109 |
+
table_records = load(tmp_path, max_rows=500)
|
| 110 |
+
except ValueError as exc:
|
| 111 |
+
raise HTTPException(status_code=422, detail=str(exc))
|
| 112 |
+
finally:
|
| 113 |
+
try:
|
| 114 |
+
os.remove(tmp_path)
|
| 115 |
+
except OSError:
|
| 116 |
+
pass
|
| 117 |
+
|
| 118 |
+
table_name = list(table_records.keys())[0]
|
| 119 |
+
records = table_records[table_name]
|
| 120 |
+
columns = list(records[0].keys()) if records else []
|
| 121 |
+
|
| 122 |
+
issue_preview = _quick_issue_preview(records, columns)
|
| 123 |
+
|
| 124 |
+
return {
|
| 125 |
+
"dataset_key": filename,
|
| 126 |
+
"table_name": table_name,
|
| 127 |
+
"columns": columns,
|
| 128 |
+
"row_count": len(records),
|
| 129 |
+
"detected_issues_preview": issue_preview,
|
| 130 |
+
"usage_example": (
|
| 131 |
+
f'{{"dataset": "{filename}", '
|
| 132 |
+
f'"task_id": "task1_null_and_types"}}'
|
| 133 |
+
),
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# ---------------------------------------------------------------------------
|
| 138 |
+
# /download/{file_id}
|
| 139 |
+
# ---------------------------------------------------------------------------
|
| 140 |
+
|
| 141 |
+
@app.get("/download/{file_id}")
|
| 142 |
+
async def download_file(file_id: str) -> FileResponse:
|
| 143 |
+
"""Serve a previously exported cleaned dataset file."""
|
| 144 |
+
tmp_dir = tempfile.gettempdir()
|
| 145 |
+
matches = [
|
| 146 |
+
f for f in os.listdir(tmp_dir)
|
| 147 |
+
if f.startswith(file_id)
|
| 148 |
+
]
|
| 149 |
+
if not matches:
|
| 150 |
+
raise HTTPException(
|
| 151 |
+
status_code=404,
|
| 152 |
+
detail=f"No exported file found for file_id='{file_id}'.",
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
filepath = os.path.join(tmp_dir, matches[0])
|
| 156 |
+
filename = matches[0][len(file_id) + 1:] # strip "{uuid}_" prefix
|
| 157 |
+
|
| 158 |
+
return FileResponse(
|
| 159 |
+
path=filepath,
|
| 160 |
+
filename=filename,
|
| 161 |
+
media_type="application/octet-stream",
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
# ---------------------------------------------------------------------------
|
| 166 |
+
# Dev entry point
|
| 167 |
+
# ---------------------------------------------------------------------------
|
| 168 |
+
|
| 169 |
+
def main(host: str = "0.0.0.0", port: int = 7860):
|
| 170 |
+
import uvicorn
|
| 171 |
+
uvicorn.run(app, host=host, port=port)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
if __name__ == "__main__":
|
| 175 |
+
import argparse
|
| 176 |
+
parser = argparse.ArgumentParser()
|
| 177 |
+
parser.add_argument("--port", type=int, default=7860)
|
| 178 |
+
args = parser.parse_args()
|
| 179 |
+
main(port=args.port)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# ---------------------------------------------------------------------------
|
| 183 |
+
# Helpers
|
| 184 |
+
# ---------------------------------------------------------------------------
|
| 185 |
+
|
| 186 |
+
def _quick_issue_preview(records: list[dict], columns: list[str]) -> int:
|
| 187 |
+
"""Count obvious null cells for the upload preview."""
|
| 188 |
+
import math
|
| 189 |
+
count = 0
|
| 190 |
+
for row in records:
|
| 191 |
+
for col in columns:
|
| 192 |
+
val = row.get(col)
|
| 193 |
+
if val is None:
|
| 194 |
+
count += 1
|
| 195 |
+
elif isinstance(val, float) and math.isnan(val):
|
| 196 |
+
count += 1
|
| 197 |
+
elif isinstance(val, str) and val.strip() == "":
|
| 198 |
+
count += 1
|
| 199 |
+
return count
|
sqlsherlock_env/server/database.py
ADDED
|
@@ -0,0 +1,563 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
DatabaseEngine for SQLSherlock-Env.
|
| 9 |
+
|
| 10 |
+
Manages one in-memory SQLite database per episode.
|
| 11 |
+
Owns: dataset loading, profiling, issue detection, trap planting,
|
| 12 |
+
baseline validation, and all agent-facing read/write operations.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import copy
|
| 16 |
+
import math
|
| 17 |
+
import re
|
| 18 |
+
import sqlite3
|
| 19 |
+
from typing import Any, Optional
|
| 20 |
+
|
| 21 |
+
from server.dataset_loader import load, records_to_sqlite, coerce
|
| 22 |
+
from server.schema_profiler import profile_table, find_primary_key
|
| 23 |
+
from server.issue_detector import detect_issues, detect_trap, Issue, Trap
|
| 24 |
+
from server.validator import Validator, ValidationResult
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# ---------------------------------------------------------------------------
|
| 28 |
+
# SQL injection block-list
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
|
| 31 |
+
_BLOCKED = frozenset({
|
| 32 |
+
"DROP", "DELETE", "UPDATE", "INSERT", "ALTER",
|
| 33 |
+
"CREATE", "ATTACH", "DETACH", "LOAD_EXTENSION", "PRAGMA", "VACUUM",
|
| 34 |
+
"REINDEX", "SAVEPOINT", "RELEASE", "BEGIN", "COMMIT", "ROLLBACK",
|
| 35 |
+
})
|
| 36 |
+
_WORD_RE = re.compile(r"\b(\w+)\b")
|
| 37 |
+
_MAX_QUERY_ROWS = 50
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
# DatabaseEngine
|
| 42 |
+
# ---------------------------------------------------------------------------
|
| 43 |
+
|
| 44 |
+
class DatabaseEngine:
|
| 45 |
+
"""In-memory SQLite environment, isolated per episode.
|
| 46 |
+
|
| 47 |
+
Initialisation sequence
|
| 48 |
+
-----------------------
|
| 49 |
+
1. Load dataset from source.
|
| 50 |
+
2. Write records to SQLite.
|
| 51 |
+
3. Deep-copy originals (before any mutation).
|
| 52 |
+
4. Profile all columns.
|
| 53 |
+
5. Capture validator baseline.
|
| 54 |
+
6. Detect real issues (+ synthetic top-up).
|
| 55 |
+
7. Plant trap (task3 only).
|
| 56 |
+
8. Initialise action log.
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
def __init__(
|
| 60 |
+
self,
|
| 61 |
+
task_id: str,
|
| 62 |
+
seed: int,
|
| 63 |
+
dataset_source: str,
|
| 64 |
+
max_rows: int = 500,
|
| 65 |
+
) -> None:
|
| 66 |
+
if not dataset_source or not dataset_source.strip():
|
| 67 |
+
raise ValueError("dataset_source must not be empty.")
|
| 68 |
+
|
| 69 |
+
self.task_id = task_id
|
| 70 |
+
self.seed = seed
|
| 71 |
+
|
| 72 |
+
# --- 1. Load ---
|
| 73 |
+
table_records = load(dataset_source, max_rows=max_rows)
|
| 74 |
+
|
| 75 |
+
# --- 2. SQLite ---
|
| 76 |
+
self._conn = sqlite3.connect(":memory:", check_same_thread=False)
|
| 77 |
+
self._conn.row_factory = sqlite3.Row
|
| 78 |
+
|
| 79 |
+
self._table_names: list[str] = []
|
| 80 |
+
self._records: dict[str, list[dict]] = {}
|
| 81 |
+
|
| 82 |
+
for tname, recs in table_records.items():
|
| 83 |
+
records_to_sqlite(self._conn, tname, recs)
|
| 84 |
+
self._table_names.append(tname)
|
| 85 |
+
self._records[tname] = recs
|
| 86 |
+
|
| 87 |
+
# Primary table is always the first one
|
| 88 |
+
self._primary_table: str = self._table_names[0]
|
| 89 |
+
|
| 90 |
+
# --- 3. Deep-copy originals (clean snapshot) ---
|
| 91 |
+
self._originals: dict[str, list[dict]] = {
|
| 92 |
+
t: copy.deepcopy(recs) for t, recs in self._records.items()
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
# --- 4. Profile ---
|
| 96 |
+
self._profiles: dict[str, dict[str, dict]] = {}
|
| 97 |
+
for tname, recs in self._records.items():
|
| 98 |
+
self._profiles[tname] = profile_table(tname, recs, self._conn)
|
| 99 |
+
|
| 100 |
+
# Determine PK column for primary table
|
| 101 |
+
primary_recs = self._records[self._primary_table]
|
| 102 |
+
self._pk_col: str = (
|
| 103 |
+
find_primary_key(primary_recs) or list(primary_recs[0].keys())[0]
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
# Source format (from injected _source_format key)
|
| 107 |
+
self.source_format: str = (
|
| 108 |
+
primary_recs[0].get("_source_format", "csv") if primary_recs else "csv"
|
| 109 |
+
)
|
| 110 |
+
self.dataset_name: str = dataset_source
|
| 111 |
+
|
| 112 |
+
# --- 5. Validator baseline ---
|
| 113 |
+
# Issue registry not yet built — pass empty list for baseline;
|
| 114 |
+
# we rebuild after detection.
|
| 115 |
+
self._validator: Optional[Validator] = None # initialised after step 6
|
| 116 |
+
|
| 117 |
+
# --- 6. Issue detection ---
|
| 118 |
+
primary_profile = self._profiles[self._primary_table]
|
| 119 |
+
self._issues: list[Issue] = detect_issues(
|
| 120 |
+
conn=self._conn,
|
| 121 |
+
profile=primary_profile,
|
| 122 |
+
records=primary_recs,
|
| 123 |
+
task_id=task_id,
|
| 124 |
+
seed=seed,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
# NOW build validator with the real issue registry
|
| 128 |
+
self._validator = Validator(
|
| 129 |
+
conn=self._conn,
|
| 130 |
+
profile=primary_profile,
|
| 131 |
+
issue_registry=self._issues,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
# --- 7. Trap (task3 only) ---
|
| 135 |
+
self._trap: Optional[Trap] = None
|
| 136 |
+
if task_id == "task3_full_audit_with_trap":
|
| 137 |
+
self._trap = detect_trap(
|
| 138 |
+
conn=self._conn,
|
| 139 |
+
profile=primary_profile,
|
| 140 |
+
records=primary_recs,
|
| 141 |
+
issue_registry=self._issues,
|
| 142 |
+
seed=seed,
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
# --- 8. Action log ---
|
| 146 |
+
self._action_log: list[Any] = []
|
| 147 |
+
|
| 148 |
+
# Track which columns the agent has touched (for distribution warnings)
|
| 149 |
+
self._touched_columns: set[str] = set()
|
| 150 |
+
|
| 151 |
+
# ------------------------------------------------------------------
|
| 152 |
+
# Read operations
|
| 153 |
+
# ------------------------------------------------------------------
|
| 154 |
+
|
| 155 |
+
def rows(self, table: str) -> list[dict]:
|
| 156 |
+
"""Return current rows for *table* as plain dicts."""
|
| 157 |
+
self._require_table(table)
|
| 158 |
+
cur = self._conn.execute(f'SELECT * FROM "{table}"')
|
| 159 |
+
return [dict(row) for row in cur.fetchall()]
|
| 160 |
+
|
| 161 |
+
def columns(self, table: str) -> list[str]:
|
| 162 |
+
"""Return column names for *table*."""
|
| 163 |
+
self._require_table(table)
|
| 164 |
+
cur = self._conn.execute(f'PRAGMA table_info("{table}")')
|
| 165 |
+
return [row[1] for row in cur.fetchall()]
|
| 166 |
+
|
| 167 |
+
def table_names(self) -> list[str]:
|
| 168 |
+
"""Return all table names in this episode's database."""
|
| 169 |
+
return list(self._table_names)
|
| 170 |
+
|
| 171 |
+
def tables_summary(self) -> dict[str, Any]:
|
| 172 |
+
"""Return a compact summary of every table (for observations)."""
|
| 173 |
+
summary = {}
|
| 174 |
+
for tname in self._table_names:
|
| 175 |
+
cols = self.columns(tname)
|
| 176 |
+
profile = self._profiles.get(tname, {})
|
| 177 |
+
dtypes = {col: profile[col]["dtype"] for col in cols if col in profile}
|
| 178 |
+
current_rows = self.rows(tname)
|
| 179 |
+
summary[tname] = {
|
| 180 |
+
"row_count": len(current_rows),
|
| 181 |
+
"columns": cols,
|
| 182 |
+
"dtypes": dtypes,
|
| 183 |
+
}
|
| 184 |
+
return summary
|
| 185 |
+
|
| 186 |
+
def query(self, sql: str) -> list[dict]:
|
| 187 |
+
"""Execute a read-only SELECT query and return up to 50 rows.
|
| 188 |
+
|
| 189 |
+
Raises:
|
| 190 |
+
ValueError: If the query is not a SELECT or contains blocked keywords.
|
| 191 |
+
"""
|
| 192 |
+
if not sql or not sql.strip():
|
| 193 |
+
raise ValueError("SQL query must not be empty.")
|
| 194 |
+
|
| 195 |
+
stripped = sql.strip()
|
| 196 |
+
if not stripped.upper().startswith("SELECT"):
|
| 197 |
+
raise ValueError("Only SELECT queries are permitted.")
|
| 198 |
+
|
| 199 |
+
if ";" in stripped:
|
| 200 |
+
raise ValueError("Semicolons are not permitted in queries.")
|
| 201 |
+
|
| 202 |
+
# Word-boundary check for blocked keywords
|
| 203 |
+
words = {m.group(1).upper() for m in _WORD_RE.finditer(stripped)}
|
| 204 |
+
blocked_found = words & _BLOCKED
|
| 205 |
+
if blocked_found:
|
| 206 |
+
raise ValueError(
|
| 207 |
+
f"Query contains blocked keyword(s): {sorted(blocked_found)}. "
|
| 208 |
+
"Only SELECT is permitted."
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
try:
|
| 212 |
+
cur = self._conn.execute(stripped)
|
| 213 |
+
rows = cur.fetchmany(_MAX_QUERY_ROWS)
|
| 214 |
+
return [dict(row) for row in rows]
|
| 215 |
+
except sqlite3.Error as exc:
|
| 216 |
+
raise ValueError(f"SQL error: {exc}") from exc
|
| 217 |
+
|
| 218 |
+
def profile_col(self, table: str, column: str) -> dict:
|
| 219 |
+
"""Return statistical profile for one column.
|
| 220 |
+
|
| 221 |
+
Returns dict with: mean, std, min, max, null_count,
|
| 222 |
+
z_scores {row_id: z}, must_be_positive.
|
| 223 |
+
"""
|
| 224 |
+
self._require_table(table)
|
| 225 |
+
profile = self._profiles.get(table, {})
|
| 226 |
+
if column not in profile:
|
| 227 |
+
# Re-profile on demand (column may have been modified)
|
| 228 |
+
current = self.rows(table)
|
| 229 |
+
updated_profile = profile_table(table, current, self._conn)
|
| 230 |
+
self._profiles[table] = updated_profile
|
| 231 |
+
profile = updated_profile
|
| 232 |
+
|
| 233 |
+
if column not in profile:
|
| 234 |
+
raise ValueError(f"Column '{column}' not found in table '{table}'.")
|
| 235 |
+
|
| 236 |
+
p = profile[column]
|
| 237 |
+
|
| 238 |
+
# Compute median and mode for smarter imputation hints
|
| 239 |
+
current_rows = self.rows(table)
|
| 240 |
+
non_null_vals = [r.get(column) for r in current_rows if not _is_null(r.get(column))]
|
| 241 |
+
|
| 242 |
+
median_val = None
|
| 243 |
+
mode_val = None
|
| 244 |
+
if non_null_vals:
|
| 245 |
+
if p.get("dtype") in ("int", "float"):
|
| 246 |
+
nums = sorted(float(v) for v in non_null_vals if _can_cast_float(v))
|
| 247 |
+
if nums:
|
| 248 |
+
mid = len(nums) // 2
|
| 249 |
+
median_val = round(nums[mid] if len(nums) % 2 else (nums[mid-1]+nums[mid])/2, 4)
|
| 250 |
+
# Mode: most common value (works for both string and numeric)
|
| 251 |
+
from collections import Counter
|
| 252 |
+
counts = Counter(str(v) for v in non_null_vals)
|
| 253 |
+
if counts:
|
| 254 |
+
mode_val = counts.most_common(1)[0][0]
|
| 255 |
+
|
| 256 |
+
return {
|
| 257 |
+
"mean": p.get("mean"),
|
| 258 |
+
"median": median_val,
|
| 259 |
+
"mode": mode_val,
|
| 260 |
+
"std": p.get("std"),
|
| 261 |
+
"min": p.get("min"),
|
| 262 |
+
"max": p.get("max"),
|
| 263 |
+
"null_count": p.get("null_count", 0),
|
| 264 |
+
"null_rate": p.get("null_rate", 0.0),
|
| 265 |
+
"z_scores": p.get("z_scores", {}),
|
| 266 |
+
"must_be_positive": p.get("must_be_positive", False),
|
| 267 |
+
"dtype": p.get("dtype", "unknown"),
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
# ------------------------------------------------------------------
|
| 271 |
+
# Write operations
|
| 272 |
+
# ------------------------------------------------------------------
|
| 273 |
+
|
| 274 |
+
def fix_cell(self, table: str, row_id: int, column: str, value: Any) -> None:
|
| 275 |
+
"""Update one cell in the database.
|
| 276 |
+
|
| 277 |
+
Raises:
|
| 278 |
+
ValueError: If table/column not found or row_id does not exist.
|
| 279 |
+
"""
|
| 280 |
+
self._require_table(table)
|
| 281 |
+
cols = self.columns(table)
|
| 282 |
+
if column not in cols:
|
| 283 |
+
raise ValueError(f"Column '{column}' not found in table '{table}'.")
|
| 284 |
+
|
| 285 |
+
pk = self._pk_col
|
| 286 |
+
existing = self._conn.execute(
|
| 287 |
+
f'SELECT "{pk}" FROM "{table}" WHERE "{pk}" = ?', (row_id,)
|
| 288 |
+
).fetchone()
|
| 289 |
+
if existing is None:
|
| 290 |
+
raise ValueError(f"Row id={row_id} not found in table '{table}'.")
|
| 291 |
+
|
| 292 |
+
# Coerce value to the column's detected dtype so SQLite stores correctly.
|
| 293 |
+
# Without this, an agent sending value="25.5" for a REAL column would
|
| 294 |
+
# store TEXT instead of REAL, causing false type_error flags in validation.
|
| 295 |
+
profile = self._profiles.get(table, {})
|
| 296 |
+
col_dtype = profile.get(column, {}).get("dtype", "str")
|
| 297 |
+
if col_dtype in ("int", "float") and value is not None:
|
| 298 |
+
try:
|
| 299 |
+
fval = float(str(value))
|
| 300 |
+
safe_val = int(fval) if col_dtype == "int" and fval == int(fval) else fval
|
| 301 |
+
except (ValueError, TypeError):
|
| 302 |
+
safe_val = _to_sqlite(value)
|
| 303 |
+
else:
|
| 304 |
+
safe_val = _to_sqlite(value)
|
| 305 |
+
|
| 306 |
+
self._conn.execute(
|
| 307 |
+
f'UPDATE "{table}" SET "{column}" = ? WHERE "{pk}" = ?',
|
| 308 |
+
(safe_val, row_id),
|
| 309 |
+
)
|
| 310 |
+
self._conn.commit()
|
| 311 |
+
self._touched_columns.add(column)
|
| 312 |
+
|
| 313 |
+
# Invalidate cached profile for this column
|
| 314 |
+
if table in self._profiles and column in self._profiles[table]:
|
| 315 |
+
del self._profiles[table][column]
|
| 316 |
+
|
| 317 |
+
def fix_column(self, table: str, column: str, value: Any) -> dict:
|
| 318 |
+
"""Fix ALL data quality issues in a column in one bulk operation.
|
| 319 |
+
|
| 320 |
+
Fixes: nulls, empty strings, type errors (non-castable values in
|
| 321 |
+
numeric columns), and negative values in must-be-positive columns.
|
| 322 |
+
|
| 323 |
+
Returns dict with counts: {nulls_fixed, type_errors_fixed,
|
| 324 |
+
negatives_fixed, total_fixed}.
|
| 325 |
+
"""
|
| 326 |
+
self._require_table(table)
|
| 327 |
+
cols = self.columns(table)
|
| 328 |
+
if column not in cols:
|
| 329 |
+
raise ValueError(f"Column '{column}' not found in table '{table}'.")
|
| 330 |
+
|
| 331 |
+
profile = self._profiles.get(table, {})
|
| 332 |
+
col_profile = profile.get(column, {})
|
| 333 |
+
col_dtype = col_profile.get("dtype", "str")
|
| 334 |
+
must_be_positive = col_profile.get("must_be_positive", False)
|
| 335 |
+
|
| 336 |
+
# Coerce fill value to column dtype
|
| 337 |
+
if col_dtype in ("int", "float") and value is not None:
|
| 338 |
+
try:
|
| 339 |
+
fval = float(str(value))
|
| 340 |
+
safe_val = int(fval) if col_dtype == "int" and fval == int(fval) else fval
|
| 341 |
+
except (ValueError, TypeError):
|
| 342 |
+
safe_val = _to_sqlite(value)
|
| 343 |
+
else:
|
| 344 |
+
safe_val = _to_sqlite(value)
|
| 345 |
+
|
| 346 |
+
total = 0
|
| 347 |
+
|
| 348 |
+
# 1. Fix NULLs and empty strings
|
| 349 |
+
cur = self._conn.execute(
|
| 350 |
+
f'UPDATE "{table}" SET "{column}" = ? '
|
| 351 |
+
f'WHERE "{column}" IS NULL OR TRIM("{column}") = ?',
|
| 352 |
+
(safe_val, ""),
|
| 353 |
+
)
|
| 354 |
+
nulls_fixed = cur.rowcount
|
| 355 |
+
total += nulls_fixed
|
| 356 |
+
|
| 357 |
+
# 2. Fix type errors: non-castable strings in numeric columns
|
| 358 |
+
type_errors_fixed = 0
|
| 359 |
+
if col_dtype in ("int", "float"):
|
| 360 |
+
# Find rows where the value can't be cast to a number
|
| 361 |
+
pk = self._pk_col
|
| 362 |
+
rows = self._conn.execute(
|
| 363 |
+
f'SELECT "{pk}", "{column}" FROM "{table}" '
|
| 364 |
+
f'WHERE "{column}" IS NOT NULL AND TRIM("{column}") != ?',
|
| 365 |
+
("",),
|
| 366 |
+
).fetchall()
|
| 367 |
+
for row in rows:
|
| 368 |
+
rid = row[0]
|
| 369 |
+
val = row[1]
|
| 370 |
+
try:
|
| 371 |
+
float(str(val))
|
| 372 |
+
except (ValueError, TypeError):
|
| 373 |
+
# This value is not castable to float — it's a type error
|
| 374 |
+
self._conn.execute(
|
| 375 |
+
f'UPDATE "{table}" SET "{column}" = ? WHERE "{pk}" = ?',
|
| 376 |
+
(safe_val, rid),
|
| 377 |
+
)
|
| 378 |
+
type_errors_fixed += 1
|
| 379 |
+
total += type_errors_fixed
|
| 380 |
+
|
| 381 |
+
# 3. Fix negative values in must-be-positive columns
|
| 382 |
+
negatives_fixed = 0
|
| 383 |
+
if must_be_positive and col_dtype in ("int", "float"):
|
| 384 |
+
cur = self._conn.execute(
|
| 385 |
+
f'UPDATE "{table}" SET "{column}" = ABS(CAST("{column}" AS REAL)) '
|
| 386 |
+
f'WHERE CAST("{column}" AS REAL) < 0',
|
| 387 |
+
)
|
| 388 |
+
negatives_fixed = cur.rowcount
|
| 389 |
+
total += negatives_fixed
|
| 390 |
+
|
| 391 |
+
self._conn.commit()
|
| 392 |
+
self._touched_columns.add(column)
|
| 393 |
+
|
| 394 |
+
# Invalidate profile cache
|
| 395 |
+
if table in self._profiles and column in self._profiles[table]:
|
| 396 |
+
del self._profiles[table][column]
|
| 397 |
+
|
| 398 |
+
return {
|
| 399 |
+
"nulls_fixed": nulls_fixed,
|
| 400 |
+
"type_errors_fixed": type_errors_fixed,
|
| 401 |
+
"negatives_fixed": negatives_fixed,
|
| 402 |
+
"total_fixed": total,
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
def delete_row(self, table: str, row_id: int) -> None:
|
| 406 |
+
"""Delete a row from the database.
|
| 407 |
+
|
| 408 |
+
Raises:
|
| 409 |
+
ValueError: If table not found or row does not exist.
|
| 410 |
+
"""
|
| 411 |
+
self._require_table(table)
|
| 412 |
+
pk = self._pk_col
|
| 413 |
+
existing = self._conn.execute(
|
| 414 |
+
f'SELECT "{pk}" FROM "{table}" WHERE "{pk}" = ?', (row_id,)
|
| 415 |
+
).fetchone()
|
| 416 |
+
if existing is None:
|
| 417 |
+
raise ValueError(f"Row id={row_id} not found in table '{table}'.")
|
| 418 |
+
|
| 419 |
+
self._conn.execute(
|
| 420 |
+
f'DELETE FROM "{table}" WHERE "{pk}" = ?', (row_id,)
|
| 421 |
+
)
|
| 422 |
+
self._conn.commit()
|
| 423 |
+
|
| 424 |
+
# ------------------------------------------------------------------
|
| 425 |
+
# Validation
|
| 426 |
+
# ------------------------------------------------------------------
|
| 427 |
+
|
| 428 |
+
def validate(self) -> ValidationResult:
|
| 429 |
+
"""Run all 6 validator checks against current state."""
|
| 430 |
+
current = self.rows(self._primary_table)
|
| 431 |
+
return self._validator.validate(
|
| 432 |
+
conn=self._conn,
|
| 433 |
+
current_records=current,
|
| 434 |
+
touched_columns=self._touched_columns,
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
# ------------------------------------------------------------------
|
| 438 |
+
# State / scoring helpers
|
| 439 |
+
# ------------------------------------------------------------------
|
| 440 |
+
|
| 441 |
+
def current_state(self) -> list[dict]:
|
| 442 |
+
"""Return current rows of the primary table."""
|
| 443 |
+
return self.rows(self._primary_table)
|
| 444 |
+
|
| 445 |
+
def original_state(self) -> list[dict]:
|
| 446 |
+
"""Return the deep-copied original rows (before any fixes)."""
|
| 447 |
+
return copy.deepcopy(self._originals[self._primary_table])
|
| 448 |
+
|
| 449 |
+
@property
|
| 450 |
+
def primary_table(self) -> str:
|
| 451 |
+
return self._primary_table
|
| 452 |
+
|
| 453 |
+
@property
|
| 454 |
+
def pk_col(self) -> str:
|
| 455 |
+
return self._pk_col
|
| 456 |
+
|
| 457 |
+
@property
|
| 458 |
+
def trap(self) -> Optional[Trap]:
|
| 459 |
+
return self._trap
|
| 460 |
+
|
| 461 |
+
@property
|
| 462 |
+
def issue_registry(self) -> list[Issue]:
|
| 463 |
+
"""The ground-truth issue list. NEVER sent to the agent."""
|
| 464 |
+
return self._issues
|
| 465 |
+
|
| 466 |
+
@property
|
| 467 |
+
def total_issues(self) -> int:
|
| 468 |
+
return len(self._issues)
|
| 469 |
+
|
| 470 |
+
def issues_remaining(self) -> int:
|
| 471 |
+
"""Count issues not yet resolved by the current DB state."""
|
| 472 |
+
current = self.rows(self._primary_table)
|
| 473 |
+
pk_col = self._pk_col
|
| 474 |
+
row_map = {row[pk_col]: row for row in current}
|
| 475 |
+
current_ids = set(row_map.keys())
|
| 476 |
+
|
| 477 |
+
remaining = 0
|
| 478 |
+
for iss in self._issues:
|
| 479 |
+
if iss.issue_type in ("duplicate", "fk_violation"):
|
| 480 |
+
if iss.row_id in current_ids:
|
| 481 |
+
remaining += 1
|
| 482 |
+
elif iss.issue_type == "null":
|
| 483 |
+
row = row_map.get(iss.row_id)
|
| 484 |
+
if row is not None and _is_null(row.get(iss.column)):
|
| 485 |
+
remaining += 1
|
| 486 |
+
elif iss.issue_type == "type_error":
|
| 487 |
+
row = row_map.get(iss.row_id)
|
| 488 |
+
if row is not None:
|
| 489 |
+
val = row.get(iss.column)
|
| 490 |
+
# Only count as remaining if non-null AND still non-castable
|
| 491 |
+
# (prevents null cells being double-counted as type errors)
|
| 492 |
+
if not _is_null(val) and not _can_cast_float(val):
|
| 493 |
+
remaining += 1
|
| 494 |
+
elif iss.issue_type == "constraint":
|
| 495 |
+
row = row_map.get(iss.row_id)
|
| 496 |
+
if row is not None:
|
| 497 |
+
val = row.get(iss.column)
|
| 498 |
+
if val is not None and _can_cast_float(val) and float(val) < 0:
|
| 499 |
+
remaining += 1
|
| 500 |
+
elif iss.issue_type == "outlier":
|
| 501 |
+
row = row_map.get(iss.row_id)
|
| 502 |
+
if row is not None:
|
| 503 |
+
val = row.get(iss.column)
|
| 504 |
+
if val is not None and _can_cast_float(val):
|
| 505 |
+
profile = self._profiles.get(self._primary_table, {})
|
| 506 |
+
p = profile.get(iss.column, {})
|
| 507 |
+
mean = p.get("mean")
|
| 508 |
+
std = p.get("std")
|
| 509 |
+
if mean is not None and std and std > 0:
|
| 510 |
+
z = abs(float(val) - mean) / std
|
| 511 |
+
if z > 5.0:
|
| 512 |
+
remaining += 1
|
| 513 |
+
return remaining
|
| 514 |
+
|
| 515 |
+
def log_action(self, action: Any) -> None:
|
| 516 |
+
"""Append an action to the episode log."""
|
| 517 |
+
self._action_log.append(action)
|
| 518 |
+
|
| 519 |
+
# ------------------------------------------------------------------
|
| 520 |
+
# Private helpers
|
| 521 |
+
# ------------------------------------------------------------------
|
| 522 |
+
|
| 523 |
+
def _require_table(self, table: str) -> None:
|
| 524 |
+
if table not in self._table_names:
|
| 525 |
+
raise ValueError(
|
| 526 |
+
f"Table '{table}' not found. "
|
| 527 |
+
f"Available tables: {self._table_names}"
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
# ---------------------------------------------------------------------------
|
| 532 |
+
# Module-level helpers
|
| 533 |
+
# ---------------------------------------------------------------------------
|
| 534 |
+
|
| 535 |
+
def _to_sqlite(value: Any) -> Any:
|
| 536 |
+
"""Convert a Python value to a SQLite-safe scalar."""
|
| 537 |
+
if value is None:
|
| 538 |
+
return None
|
| 539 |
+
if isinstance(value, bool):
|
| 540 |
+
return int(value)
|
| 541 |
+
if isinstance(value, (int, float, str, bytes)):
|
| 542 |
+
return value
|
| 543 |
+
if isinstance(value, float) and math.isnan(value):
|
| 544 |
+
return None
|
| 545 |
+
return str(value)
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
def _is_null(value: Any) -> bool:
|
| 549 |
+
if value is None:
|
| 550 |
+
return True
|
| 551 |
+
if isinstance(value, float) and math.isnan(value):
|
| 552 |
+
return True
|
| 553 |
+
if isinstance(value, str) and value.strip() == "":
|
| 554 |
+
return True
|
| 555 |
+
return False
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
def _can_cast_float(value: Any) -> bool:
|
| 559 |
+
try:
|
| 560 |
+
float(str(value))
|
| 561 |
+
return True
|
| 562 |
+
except (ValueError, TypeError):
|
| 563 |
+
return False
|
sqlsherlock_env/server/dataset_loader.py
ADDED
|
@@ -0,0 +1,467 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Dataset loader for SQLSherlock-Env.
|
| 9 |
+
|
| 10 |
+
Supports: local CSV/JSON/JSONL/Parquet, HuggingFace dataset names, raw CSV text.
|
| 11 |
+
ZERO defaults — raises ValueError if source is empty or unrecognisable.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import csv
|
| 15 |
+
import io
|
| 16 |
+
import json
|
| 17 |
+
import math
|
| 18 |
+
import os
|
| 19 |
+
import sqlite3
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Any, Optional
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# ---------------------------------------------------------------------------
|
| 25 |
+
# Public API
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
|
| 28 |
+
def load(source: str, max_rows: int = 500) -> dict[str, list[dict]]:
|
| 29 |
+
"""Load a dataset from *source* and return a table-name → records mapping.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
source: One of:
|
| 33 |
+
- Absolute/relative path ending in .csv/.json/.jsonl/.parquet
|
| 34 |
+
- HuggingFace dataset name "owner/name" or "owner/name:split"
|
| 35 |
+
- Raw CSV text (multi-line string with comma-separated header)
|
| 36 |
+
max_rows: Maximum rows to keep per table.
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
Dict mapping table name (str) to list of row dicts.
|
| 40 |
+
Each dict has an "id" key added if not already present.
|
| 41 |
+
A ``_source_format`` key is injected into each record for the
|
| 42 |
+
exporter to reconstruct the original format.
|
| 43 |
+
|
| 44 |
+
Raises:
|
| 45 |
+
ValueError: On empty source, auth failure, not found, too few rows,
|
| 46 |
+
no columns, or unrecognised format.
|
| 47 |
+
"""
|
| 48 |
+
if not source or not source.strip():
|
| 49 |
+
raise ValueError("Dataset source must not be empty.")
|
| 50 |
+
|
| 51 |
+
source = source.strip()
|
| 52 |
+
|
| 53 |
+
# Dispatch to loader
|
| 54 |
+
if _is_local_file(source):
|
| 55 |
+
records, fmt = _load_local(source, max_rows)
|
| 56 |
+
elif _is_hf_dataset(source):
|
| 57 |
+
records, fmt = _load_hf(source, max_rows)
|
| 58 |
+
elif _looks_like_csv_text(source):
|
| 59 |
+
records, fmt = _load_raw_csv(source, max_rows)
|
| 60 |
+
else:
|
| 61 |
+
raise ValueError(
|
| 62 |
+
f"Unrecognised source '{source}'. "
|
| 63 |
+
"Provide a file path (.csv/.json/.jsonl/.parquet), "
|
| 64 |
+
"a HuggingFace dataset name (owner/name), "
|
| 65 |
+
"or raw CSV text."
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
_validate_records(records)
|
| 69 |
+
records = _ensure_id_column(records)
|
| 70 |
+
records = coerce(records)
|
| 71 |
+
|
| 72 |
+
# Inject source format so exporter can match output format
|
| 73 |
+
for row in records:
|
| 74 |
+
row["_source_format"] = fmt
|
| 75 |
+
|
| 76 |
+
table_name = _table_name_from_source(source)
|
| 77 |
+
return {table_name: records}
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def coerce(records: list[dict]) -> list[dict]:
|
| 81 |
+
"""Auto-detect and coerce int/float values per column.
|
| 82 |
+
|
| 83 |
+
For each column, if ALL non-null values can be cast to int → cast to int.
|
| 84 |
+
Else if ALL non-null values can be cast to float → cast to float.
|
| 85 |
+
Otherwise leave as string.
|
| 86 |
+
|
| 87 |
+
The ``_source_format`` and ``id`` columns are never coerced.
|
| 88 |
+
"""
|
| 89 |
+
if not records:
|
| 90 |
+
return records
|
| 91 |
+
|
| 92 |
+
columns = [c for c in records[0].keys() if c not in ("_source_format",)]
|
| 93 |
+
|
| 94 |
+
for col in columns:
|
| 95 |
+
values = [r.get(col) for r in records]
|
| 96 |
+
non_null = [v for v in values if not _is_null(v)]
|
| 97 |
+
if not non_null:
|
| 98 |
+
continue
|
| 99 |
+
|
| 100 |
+
target_type = _detect_target_type(non_null)
|
| 101 |
+
if target_type is None:
|
| 102 |
+
continue
|
| 103 |
+
|
| 104 |
+
for row in records:
|
| 105 |
+
v = row.get(col)
|
| 106 |
+
if _is_null(v):
|
| 107 |
+
row[col] = None
|
| 108 |
+
continue
|
| 109 |
+
try:
|
| 110 |
+
fval = float(str(v))
|
| 111 |
+
if target_type == "int":
|
| 112 |
+
# Only cast to int if value is genuinely whole-number
|
| 113 |
+
# (avoids silently truncating 3.7 → 3)
|
| 114 |
+
row[col] = int(fval) if fval == int(fval) else fval
|
| 115 |
+
else:
|
| 116 |
+
row[col] = fval
|
| 117 |
+
except (ValueError, TypeError):
|
| 118 |
+
pass # leave as-is if cast fails (type_error issue will detect it)
|
| 119 |
+
|
| 120 |
+
return records
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def records_to_sqlite(
|
| 124 |
+
conn: sqlite3.Connection,
|
| 125 |
+
table: str,
|
| 126 |
+
records: list[dict],
|
| 127 |
+
) -> None:
|
| 128 |
+
"""Write *records* into an in-memory SQLite table.
|
| 129 |
+
|
| 130 |
+
Creates the table fresh (DROP IF EXISTS then CREATE).
|
| 131 |
+
Column types are inferred from the records.
|
| 132 |
+
|
| 133 |
+
The ``_source_format`` column is NOT written to SQLite
|
| 134 |
+
(it is preserved in the Python records only).
|
| 135 |
+
"""
|
| 136 |
+
if not records:
|
| 137 |
+
raise ValueError(f"Cannot create table '{table}' from empty records.")
|
| 138 |
+
|
| 139 |
+
# Filter out the internal metadata column
|
| 140 |
+
columns = [c for c in records[0].keys() if c != "_source_format"]
|
| 141 |
+
|
| 142 |
+
# Infer SQLite column types
|
| 143 |
+
col_types = {}
|
| 144 |
+
for col in columns:
|
| 145 |
+
vals = [r.get(col) for r in records if not _is_null(r.get(col))]
|
| 146 |
+
col_types[col] = _sqlite_type(vals)
|
| 147 |
+
|
| 148 |
+
col_defs = ", ".join(
|
| 149 |
+
f'"{col}" {col_types[col]}' for col in columns
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
conn.execute(f'DROP TABLE IF EXISTS "{table}"')
|
| 153 |
+
conn.execute(f'CREATE TABLE "{table}" ({col_defs})')
|
| 154 |
+
|
| 155 |
+
placeholders = ", ".join("?" for _ in columns)
|
| 156 |
+
rows_to_insert = [
|
| 157 |
+
tuple(_sqlite_val(r.get(col)) for col in columns)
|
| 158 |
+
for r in records
|
| 159 |
+
]
|
| 160 |
+
conn.executemany(
|
| 161 |
+
f'INSERT INTO "{table}" VALUES ({placeholders})',
|
| 162 |
+
rows_to_insert,
|
| 163 |
+
)
|
| 164 |
+
conn.commit()
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# ---------------------------------------------------------------------------
|
| 168 |
+
# Local file loaders
|
| 169 |
+
# ---------------------------------------------------------------------------
|
| 170 |
+
|
| 171 |
+
def _load_local(path: str, max_rows: int) -> tuple[list[dict], str]:
|
| 172 |
+
p = Path(path)
|
| 173 |
+
if not p.exists():
|
| 174 |
+
raise ValueError(f"File not found: {path}")
|
| 175 |
+
|
| 176 |
+
suffix = p.suffix.lower()
|
| 177 |
+
if suffix == ".csv":
|
| 178 |
+
return _load_csv_file(p, max_rows), "csv"
|
| 179 |
+
elif suffix == ".json":
|
| 180 |
+
return _load_json_file(p, max_rows), "json"
|
| 181 |
+
elif suffix == ".jsonl":
|
| 182 |
+
return _load_jsonl_file(p, max_rows), "jsonl"
|
| 183 |
+
elif suffix == ".parquet":
|
| 184 |
+
return _load_parquet_file(p, max_rows), "parquet"
|
| 185 |
+
else:
|
| 186 |
+
raise ValueError(
|
| 187 |
+
f"Unsupported file extension '{suffix}'. "
|
| 188 |
+
"Use .csv, .json, .jsonl, or .parquet."
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def _load_csv_file(path: Path, max_rows: int) -> list[dict]:
|
| 193 |
+
with open(path, newline="", encoding="utf-8-sig") as f:
|
| 194 |
+
reader = csv.DictReader(f)
|
| 195 |
+
rows = []
|
| 196 |
+
for i, row in enumerate(reader):
|
| 197 |
+
if i >= max_rows:
|
| 198 |
+
break
|
| 199 |
+
rows.append(dict(row))
|
| 200 |
+
return rows
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def _load_json_file(path: Path, max_rows: int) -> list[dict]:
|
| 204 |
+
with open(path, encoding="utf-8") as f:
|
| 205 |
+
data = json.load(f)
|
| 206 |
+
if isinstance(data, dict):
|
| 207 |
+
# Might be {records: [...]} or similar
|
| 208 |
+
for key in ("records", "data", "rows", "items"):
|
| 209 |
+
if key in data and isinstance(data[key], list):
|
| 210 |
+
data = data[key]
|
| 211 |
+
break
|
| 212 |
+
else:
|
| 213 |
+
raise ValueError("JSON file must contain a list of records.")
|
| 214 |
+
if not isinstance(data, list):
|
| 215 |
+
raise ValueError("JSON file must contain a list of records.")
|
| 216 |
+
return [dict(r) for r in data[:max_rows]]
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def _load_jsonl_file(path: Path, max_rows: int) -> list[dict]:
|
| 220 |
+
rows = []
|
| 221 |
+
with open(path, encoding="utf-8") as f:
|
| 222 |
+
for i, line in enumerate(f):
|
| 223 |
+
if i >= max_rows:
|
| 224 |
+
break
|
| 225 |
+
line = line.strip()
|
| 226 |
+
if line:
|
| 227 |
+
rows.append(json.loads(line))
|
| 228 |
+
return rows
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def _load_parquet_file(path: Path, max_rows: int) -> list[dict]:
|
| 232 |
+
try:
|
| 233 |
+
import pandas as pd
|
| 234 |
+
except ImportError:
|
| 235 |
+
raise ValueError("pandas is required to load Parquet files. pip install pandas pyarrow")
|
| 236 |
+
df = pd.read_parquet(path)
|
| 237 |
+
df = df.head(max_rows)
|
| 238 |
+
return _df_to_records(df)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# ---------------------------------------------------------------------------
|
| 242 |
+
# HuggingFace dataset loader
|
| 243 |
+
# ---------------------------------------------------------------------------
|
| 244 |
+
|
| 245 |
+
def _load_hf(source: str, max_rows: int) -> tuple[list[dict], str]:
|
| 246 |
+
"""Load a dataset from HuggingFace Hub.
|
| 247 |
+
|
| 248 |
+
source format: "owner/name" or "owner/name:split"
|
| 249 |
+
"""
|
| 250 |
+
try:
|
| 251 |
+
from datasets import load_dataset
|
| 252 |
+
except ImportError:
|
| 253 |
+
raise ValueError(
|
| 254 |
+
"The 'datasets' package is required for HuggingFace datasets. "
|
| 255 |
+
"pip install datasets"
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
# Parse split
|
| 259 |
+
split = "train"
|
| 260 |
+
name = source
|
| 261 |
+
if ":" in source:
|
| 262 |
+
name, split = source.rsplit(":", 1)
|
| 263 |
+
|
| 264 |
+
hf_token = os.environ.get("HF_TOKEN")
|
| 265 |
+
|
| 266 |
+
try:
|
| 267 |
+
ds = load_dataset(name, split=split, token=hf_token)
|
| 268 |
+
except Exception as exc:
|
| 269 |
+
msg = str(exc).lower()
|
| 270 |
+
if "401" in msg or "unauthorized" in msg or "authentication" in msg:
|
| 271 |
+
raise ValueError(
|
| 272 |
+
f"Dataset '{name}' requires authentication. "
|
| 273 |
+
"Use a public dataset or set the HF_TOKEN environment variable."
|
| 274 |
+
) from exc
|
| 275 |
+
if "404" in msg or "not found" in msg or "doesn't exist" in msg:
|
| 276 |
+
raise ValueError(
|
| 277 |
+
f"Dataset '{name}' not found. "
|
| 278 |
+
"Check the owner/name format (e.g. 'mstz/titanic')."
|
| 279 |
+
) from exc
|
| 280 |
+
raise ValueError(f"Failed to load HuggingFace dataset '{source}': {exc}") from exc
|
| 281 |
+
|
| 282 |
+
# Convert to list of dicts
|
| 283 |
+
try:
|
| 284 |
+
import pandas as pd
|
| 285 |
+
df = ds.to_pandas().head(max_rows)
|
| 286 |
+
records = _df_to_records(df)
|
| 287 |
+
except Exception:
|
| 288 |
+
records = [dict(row) for row in ds.select(range(min(max_rows, len(ds))))]
|
| 289 |
+
|
| 290 |
+
return records, "hf_dataset"
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
# ---------------------------------------------------------------------------
|
| 294 |
+
# Raw CSV text loader
|
| 295 |
+
# ---------------------------------------------------------------------------
|
| 296 |
+
|
| 297 |
+
def _load_raw_csv(source: str, max_rows: int) -> tuple[list[dict], str]:
|
| 298 |
+
reader = csv.DictReader(io.StringIO(source))
|
| 299 |
+
rows = []
|
| 300 |
+
for i, row in enumerate(reader):
|
| 301 |
+
if i >= max_rows:
|
| 302 |
+
break
|
| 303 |
+
rows.append(dict(row))
|
| 304 |
+
return rows, "csv"
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
# ---------------------------------------------------------------------------
|
| 308 |
+
# Validation & helpers
|
| 309 |
+
# ---------------------------------------------------------------------------
|
| 310 |
+
|
| 311 |
+
def _validate_records(records: list[dict]) -> None:
|
| 312 |
+
if not records:
|
| 313 |
+
raise ValueError("Dataset loaded 0 rows. Need at least 5.")
|
| 314 |
+
if len(records) < 5:
|
| 315 |
+
raise ValueError(
|
| 316 |
+
f"Dataset has only {len(records)} rows. Need at least 5."
|
| 317 |
+
)
|
| 318 |
+
if not records[0]:
|
| 319 |
+
raise ValueError("Dataset has no columns.")
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def _ensure_id_column(records: list[dict]) -> list[dict]:
|
| 323 |
+
"""Guarantee every record has an integer 'id' column as the FIRST field."""
|
| 324 |
+
if not records:
|
| 325 |
+
return records
|
| 326 |
+
|
| 327 |
+
# Check all columns for a PK-like column (not just the first)
|
| 328 |
+
all_cols = list(records[0].keys())
|
| 329 |
+
pk_col = None
|
| 330 |
+
for col in all_cols:
|
| 331 |
+
if col.lower() in ("id", "passengerid", "index", "passengerId"):
|
| 332 |
+
pk_col = col
|
| 333 |
+
break
|
| 334 |
+
|
| 335 |
+
if pk_col is not None:
|
| 336 |
+
# Rename to 'id' and reorder to put it first
|
| 337 |
+
for i, row in enumerate(records):
|
| 338 |
+
pk_val = row.pop(pk_col) if pk_col != "id" else row.pop("id")
|
| 339 |
+
try:
|
| 340 |
+
pk_val = int(pk_val)
|
| 341 |
+
except (ValueError, TypeError):
|
| 342 |
+
pk_val = i + 1
|
| 343 |
+
# Rebuild dict with 'id' first
|
| 344 |
+
records[i] = {"id": pk_val, **row}
|
| 345 |
+
return records
|
| 346 |
+
|
| 347 |
+
# No obvious PK — inject sequential id as first field
|
| 348 |
+
for i, row in enumerate(records):
|
| 349 |
+
records[i] = {"id": i + 1, **row}
|
| 350 |
+
|
| 351 |
+
return records
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def _table_name_from_source(source: str) -> str:
|
| 355 |
+
"""Derive a clean table name from the source string."""
|
| 356 |
+
if _is_local_file(source):
|
| 357 |
+
stem = Path(source).stem
|
| 358 |
+
return _sanitise_name(stem)
|
| 359 |
+
if _is_hf_dataset(source):
|
| 360 |
+
base = source.split(":")[0] # strip split
|
| 361 |
+
parts = base.split("/")
|
| 362 |
+
return _sanitise_name(parts[-1]) # e.g. "titanic"
|
| 363 |
+
return "dataset"
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
def _sanitise_name(name: str) -> str:
|
| 367 |
+
"""Return a SQLite-safe lowercase identifier."""
|
| 368 |
+
safe = "".join(c if c.isalnum() or c == "_" else "_" for c in name.lower())
|
| 369 |
+
if safe and safe[0].isdigit():
|
| 370 |
+
safe = "t_" + safe
|
| 371 |
+
return safe or "dataset"
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def _is_local_file(source: str) -> bool:
|
| 375 |
+
return any(source.lower().endswith(ext) for ext in (".csv", ".json", ".jsonl", ".parquet"))
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def _is_hf_dataset(source: str) -> bool:
|
| 379 |
+
"""Heuristic: 'owner/name' with no spaces and not a file path."""
|
| 380 |
+
if "/" not in source:
|
| 381 |
+
return False
|
| 382 |
+
if any(source.lower().endswith(ext) for ext in (".csv", ".json", ".jsonl", ".parquet")):
|
| 383 |
+
return False
|
| 384 |
+
if "\n" in source or "," not in source.split("\n")[0]:
|
| 385 |
+
# Might still be HF if no comma in first line
|
| 386 |
+
parts = source.split("/")
|
| 387 |
+
return len(parts) == 2 or (len(parts) == 2 and ":" in parts[-1])
|
| 388 |
+
return "/" in source and "\n" not in source and len(source.split("/")) == 2
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def _looks_like_csv_text(source: str) -> bool:
|
| 392 |
+
"""Return True if source looks like raw CSV text (has newlines and commas)."""
|
| 393 |
+
lines = source.strip().splitlines()
|
| 394 |
+
return len(lines) >= 2 and "," in lines[0]
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def _detect_target_type(non_null: list[Any]) -> Optional[str]:
|
| 398 |
+
"""Return 'int' or 'float' if all values are numeric, else None."""
|
| 399 |
+
# Try int
|
| 400 |
+
try:
|
| 401 |
+
for v in non_null:
|
| 402 |
+
f = float(str(v))
|
| 403 |
+
if f != int(f):
|
| 404 |
+
raise ValueError
|
| 405 |
+
return "int"
|
| 406 |
+
except (ValueError, TypeError):
|
| 407 |
+
pass
|
| 408 |
+
# Try float
|
| 409 |
+
try:
|
| 410 |
+
for v in non_null:
|
| 411 |
+
float(str(v))
|
| 412 |
+
return "float"
|
| 413 |
+
except (ValueError, TypeError):
|
| 414 |
+
pass
|
| 415 |
+
return None
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def _is_null(value: Any) -> bool:
|
| 419 |
+
if value is None:
|
| 420 |
+
return True
|
| 421 |
+
if isinstance(value, float) and math.isnan(value):
|
| 422 |
+
return True
|
| 423 |
+
if isinstance(value, str) and value.strip() == "":
|
| 424 |
+
return True
|
| 425 |
+
return False
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def _sqlite_type(non_null_vals: list[Any]) -> str:
|
| 429 |
+
if not non_null_vals:
|
| 430 |
+
return "TEXT"
|
| 431 |
+
target = _detect_target_type(non_null_vals)
|
| 432 |
+
if target == "int":
|
| 433 |
+
return "INTEGER"
|
| 434 |
+
if target == "float":
|
| 435 |
+
return "REAL"
|
| 436 |
+
return "TEXT"
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def _sqlite_val(value: Any) -> Any:
|
| 440 |
+
"""Convert a Python value to a SQLite-compatible scalar."""
|
| 441 |
+
if value is None:
|
| 442 |
+
return None
|
| 443 |
+
if isinstance(value, float) and math.isnan(value):
|
| 444 |
+
return None
|
| 445 |
+
if isinstance(value, (int, float, str, bytes)):
|
| 446 |
+
return value
|
| 447 |
+
return str(value)
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
def _df_to_records(df) -> list[dict]:
|
| 451 |
+
"""Convert a pandas DataFrame to a list of plain Python dicts."""
|
| 452 |
+
import math as _math
|
| 453 |
+
records = []
|
| 454 |
+
for _, row in df.iterrows():
|
| 455 |
+
d = {}
|
| 456 |
+
for col, val in row.items():
|
| 457 |
+
# Convert numpy/pandas scalars to Python natives
|
| 458 |
+
if hasattr(val, "item"):
|
| 459 |
+
try:
|
| 460 |
+
val = val.item()
|
| 461 |
+
except Exception:
|
| 462 |
+
val = str(val)
|
| 463 |
+
if isinstance(val, float) and _math.isnan(val):
|
| 464 |
+
val = None
|
| 465 |
+
d[str(col)] = val
|
| 466 |
+
records.append(d)
|
| 467 |
+
return records
|
sqlsherlock_env/server/environment.py
ADDED
|
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
SQLSherlock RL environment — server-side implementation.
|
| 9 |
+
|
| 10 |
+
Implements the OpenEnv Environment interface. One instance per
|
| 11 |
+
WebSocket session; each reset() creates a fresh DatabaseEngine.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import uuid
|
| 15 |
+
from typing import Any, Optional
|
| 16 |
+
|
| 17 |
+
from openenv.core.env_server import Environment
|
| 18 |
+
|
| 19 |
+
from models import SQLSherlockAction, SQLSherlockObservation, SQLSherlockState
|
| 20 |
+
from server.database import DatabaseEngine
|
| 21 |
+
from server.reward import calc, RB, InvestCounter
|
| 22 |
+
from server import graders
|
| 23 |
+
from server.exporter import export_cleaned
|
| 24 |
+
from server.validator import Validator
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# ---------------------------------------------------------------------------
|
| 28 |
+
# Task catalogue
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
|
| 31 |
+
TASKS: list[dict] = [
|
| 32 |
+
{
|
| 33 |
+
"id": "task1_null_and_types",
|
| 34 |
+
"name": "Null and type error repair",
|
| 35 |
+
"difficulty": "easy",
|
| 36 |
+
"max_steps": 20,
|
| 37 |
+
"description": (
|
| 38 |
+
"Find and fix null values and type errors in the primary table. "
|
| 39 |
+
"Profile columns, identify anomalies, fix with reasoning, "
|
| 40 |
+
"validate your work, and export the cleaned dataset."
|
| 41 |
+
),
|
| 42 |
+
},
|
| 43 |
+
{
|
| 44 |
+
"id": "task2_constraints_and_fk",
|
| 45 |
+
"name": "Constraint and FK integrity",
|
| 46 |
+
"difficulty": "medium",
|
| 47 |
+
"max_steps": 25,
|
| 48 |
+
"description": (
|
| 49 |
+
"Everything in Task 1 plus constraint violations "
|
| 50 |
+
"(negative values in must-be-positive columns) and FK "
|
| 51 |
+
"violations (orphan references in related tables)."
|
| 52 |
+
),
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"id": "task3_full_audit_with_trap",
|
| 56 |
+
"name": "Full statistical audit with trap",
|
| 57 |
+
"difficulty": "hard",
|
| 58 |
+
"max_steps": 30,
|
| 59 |
+
"description": (
|
| 60 |
+
"Full audit including statistical outliers. TRAP WARNING: "
|
| 61 |
+
"one numeric value looks suspicious but is legitimate. "
|
| 62 |
+
"You MUST check z-scores before fixing any numeric value. "
|
| 63 |
+
"z > 5 = real outlier. z < 3 = leave alone."
|
| 64 |
+
),
|
| 65 |
+
},
|
| 66 |
+
]
|
| 67 |
+
|
| 68 |
+
_TASK_MAP: dict[str, dict] = {t["id"]: t for t in TASKS}
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# ---------------------------------------------------------------------------
|
| 72 |
+
# Environment
|
| 73 |
+
# ---------------------------------------------------------------------------
|
| 74 |
+
|
| 75 |
+
class SQLSherlockEnvironment(Environment):
|
| 76 |
+
"""One episode of the SQLSherlock RL environment."""
|
| 77 |
+
|
| 78 |
+
# Called by create_app() as a factory — __init__ must be zero-arg.
|
| 79 |
+
def __init__(self) -> None:
|
| 80 |
+
self._db: Optional[DatabaseEngine] = None
|
| 81 |
+
self._state: Optional[SQLSherlockState] = None
|
| 82 |
+
self._counter: Optional[InvestCounter] = None
|
| 83 |
+
self._reward_trace: list[dict] = []
|
| 84 |
+
self._validation_called: bool = False
|
| 85 |
+
self._export_result: Optional[dict] = None
|
| 86 |
+
|
| 87 |
+
# ------------------------------------------------------------------
|
| 88 |
+
# reset()
|
| 89 |
+
# ------------------------------------------------------------------
|
| 90 |
+
|
| 91 |
+
def reset(self, **kwargs) -> SQLSherlockObservation:
|
| 92 |
+
"""Start a new episode.
|
| 93 |
+
|
| 94 |
+
Keyword Args:
|
| 95 |
+
dataset (str): Dataset source — required, no default.
|
| 96 |
+
task_id (str): Task identifier — required, no default.
|
| 97 |
+
seed (int): RNG seed (default 42).
|
| 98 |
+
max_rows(int): Row limit (default 500).
|
| 99 |
+
|
| 100 |
+
Raises:
|
| 101 |
+
ValueError: If dataset or task_id is missing/invalid.
|
| 102 |
+
"""
|
| 103 |
+
dataset = kwargs.get("dataset", "")
|
| 104 |
+
task_id = kwargs.get("task_id", "")
|
| 105 |
+
seed = int(kwargs.get("seed", 42))
|
| 106 |
+
max_rows = int(kwargs.get("max_rows", 500))
|
| 107 |
+
|
| 108 |
+
if not dataset or not dataset.strip():
|
| 109 |
+
raise ValueError(
|
| 110 |
+
"reset() requires 'dataset' keyword argument. "
|
| 111 |
+
"Provide a file path, HuggingFace dataset name, or raw CSV text."
|
| 112 |
+
)
|
| 113 |
+
if not task_id or not task_id.strip():
|
| 114 |
+
raise ValueError(
|
| 115 |
+
"reset() requires 'task_id' keyword argument. "
|
| 116 |
+
f"Valid tasks: {sorted(_TASK_MAP.keys())}"
|
| 117 |
+
)
|
| 118 |
+
if task_id not in _TASK_MAP:
|
| 119 |
+
raise ValueError(
|
| 120 |
+
f"Unknown task_id '{task_id}'. "
|
| 121 |
+
f"Valid tasks: {sorted(_TASK_MAP.keys())}"
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
task_cfg = _TASK_MAP[task_id]
|
| 125 |
+
|
| 126 |
+
# Fresh database for this episode
|
| 127 |
+
self._db = DatabaseEngine(
|
| 128 |
+
task_id=task_id,
|
| 129 |
+
seed=seed,
|
| 130 |
+
dataset_source=dataset,
|
| 131 |
+
max_rows=max_rows,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
self._state = SQLSherlockState(
|
| 135 |
+
episode_id=str(uuid.uuid4()),
|
| 136 |
+
task_id=task_id,
|
| 137 |
+
step_count=0,
|
| 138 |
+
grader_score=0.0,
|
| 139 |
+
done=False,
|
| 140 |
+
dataset_name=dataset,
|
| 141 |
+
source_format=self._db.source_format,
|
| 142 |
+
investigation_count=0,
|
| 143 |
+
validation_called=False,
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
self._counter = InvestCounter()
|
| 147 |
+
self._reward_trace = []
|
| 148 |
+
self._validation_called = False
|
| 149 |
+
self._export_result = None
|
| 150 |
+
self._deleted_row_ids: list[int] = [] # track deletes for grader
|
| 151 |
+
|
| 152 |
+
return self._make_obs(
|
| 153 |
+
last_feedback=(
|
| 154 |
+
f"Episode started. Dataset loaded: {self._db.primary_table} "
|
| 155 |
+
f"({len(self._db.rows(self._db.primary_table))} rows). "
|
| 156 |
+
f"Task: {task_cfg['name']}. Max steps: {task_cfg['max_steps']}. "
|
| 157 |
+
"Begin by inspecting the table or profiling columns."
|
| 158 |
+
),
|
| 159 |
+
query_result=None,
|
| 160 |
+
validation_result=None,
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
# ------------------------------------------------------------------
|
| 164 |
+
# step()
|
| 165 |
+
# ------------------------------------------------------------------
|
| 166 |
+
|
| 167 |
+
def step(
|
| 168 |
+
self, action: SQLSherlockAction, **kwargs
|
| 169 |
+
) -> SQLSherlockObservation:
|
| 170 |
+
"""Execute one agent action.
|
| 171 |
+
|
| 172 |
+
Returns the observation with reward and done set on it.
|
| 173 |
+
The openenv framework extracts reward/done from the observation.
|
| 174 |
+
"""
|
| 175 |
+
if self._db is None or self._state is None:
|
| 176 |
+
raise RuntimeError("Call reset() before step().")
|
| 177 |
+
|
| 178 |
+
task_cfg = _TASK_MAP[self._state.task_id]
|
| 179 |
+
max_steps = task_cfg["max_steps"]
|
| 180 |
+
|
| 181 |
+
self._state.step_count += 1
|
| 182 |
+
step = self._state.step_count
|
| 183 |
+
|
| 184 |
+
# Log action for reasoning bonus check
|
| 185 |
+
self._db.log_action(action)
|
| 186 |
+
|
| 187 |
+
query_result = None
|
| 188 |
+
validation_result = None
|
| 189 |
+
feedback = ""
|
| 190 |
+
done = False
|
| 191 |
+
|
| 192 |
+
atype = action.action_type
|
| 193 |
+
|
| 194 |
+
# ------------------------------------------------------------------
|
| 195 |
+
# Dispatch
|
| 196 |
+
# ------------------------------------------------------------------
|
| 197 |
+
try:
|
| 198 |
+
if atype == "inspect":
|
| 199 |
+
table = action.table or self._db.primary_table
|
| 200 |
+
rows = self._db.rows(table)
|
| 201 |
+
query_result = rows
|
| 202 |
+
feedback = f"inspect: returned {len(rows)} rows from '{table}'."
|
| 203 |
+
|
| 204 |
+
elif atype == "profile_column":
|
| 205 |
+
table = action.table or self._db.primary_table
|
| 206 |
+
column = action.column
|
| 207 |
+
if not column:
|
| 208 |
+
raise ValueError("profile_column requires 'column' field.")
|
| 209 |
+
profile = self._db.profile_col(table, column)
|
| 210 |
+
query_result = [profile]
|
| 211 |
+
feedback = (
|
| 212 |
+
f"profile_column '{column}': "
|
| 213 |
+
f"mean={profile.get('mean')}, std={profile.get('std')}, "
|
| 214 |
+
f"null_count={profile.get('null_count')}, "
|
| 215 |
+
f"must_be_positive={profile.get('must_be_positive')}."
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
elif atype == "run_sql":
|
| 219 |
+
sql = action.sql
|
| 220 |
+
if not sql:
|
| 221 |
+
raise ValueError("run_sql requires 'sql' field.")
|
| 222 |
+
rows = self._db.query(sql)
|
| 223 |
+
query_result = rows
|
| 224 |
+
feedback = f"run_sql: returned {len(rows)} rows."
|
| 225 |
+
|
| 226 |
+
elif atype == "fix_cell":
|
| 227 |
+
table = action.table or self._db.primary_table
|
| 228 |
+
row_id = action.row_id
|
| 229 |
+
column = action.column
|
| 230 |
+
value = action.value
|
| 231 |
+
if row_id is None or column is None:
|
| 232 |
+
raise ValueError("fix_cell requires 'row_id' and 'column'.")
|
| 233 |
+
self._db.fix_cell(table, row_id, column, value)
|
| 234 |
+
feedback = (
|
| 235 |
+
f"fix_cell: set [{table}].{column}[id={row_id}] = {value!r}. "
|
| 236 |
+
f"Reason: {action.reason or '(none provided)'}."
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
elif atype == "fix_column":
|
| 240 |
+
table = action.table or self._db.primary_table
|
| 241 |
+
column = action.column
|
| 242 |
+
value = action.value
|
| 243 |
+
if column is None:
|
| 244 |
+
raise ValueError("fix_column requires 'column'.")
|
| 245 |
+
result = self._db.fix_column(table, column, value)
|
| 246 |
+
parts = []
|
| 247 |
+
if result["nulls_fixed"]:
|
| 248 |
+
parts.append(f"{result['nulls_fixed']} nulls")
|
| 249 |
+
if result["type_errors_fixed"]:
|
| 250 |
+
parts.append(f"{result['type_errors_fixed']} type errors")
|
| 251 |
+
if result["negatives_fixed"]:
|
| 252 |
+
parts.append(f"{result['negatives_fixed']} negatives")
|
| 253 |
+
detail = ", ".join(parts) if parts else "0 issues"
|
| 254 |
+
feedback = (
|
| 255 |
+
f"fix_column '{column}': fixed {detail} "
|
| 256 |
+
f"(total {result['total_fixed']} rows) with value={value!r}. "
|
| 257 |
+
f"Reason: {action.reason or '(none provided)'}."
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
elif atype == "delete_row":
|
| 261 |
+
table = action.table or self._db.primary_table
|
| 262 |
+
row_id = action.row_id
|
| 263 |
+
if row_id is None:
|
| 264 |
+
raise ValueError("delete_row requires 'row_id'.")
|
| 265 |
+
self._db.delete_row(table, row_id)
|
| 266 |
+
if row_id not in self._deleted_row_ids:
|
| 267 |
+
self._deleted_row_ids.append(row_id)
|
| 268 |
+
feedback = (
|
| 269 |
+
f"delete_row: removed row id={row_id} from '{table}'. "
|
| 270 |
+
f"Reason: {action.reason or '(none provided)'}."
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
elif atype == "validate":
|
| 274 |
+
vr = self._db.validate()
|
| 275 |
+
validation_result = vr.to_dict()
|
| 276 |
+
self._validation_called = True
|
| 277 |
+
self._state.validation_called = True
|
| 278 |
+
self._last_vr = vr # cache — avoid second validate() call
|
| 279 |
+
feedback = (
|
| 280 |
+
f"validate: {vr.overall} — "
|
| 281 |
+
f"{vr.checks_passed}/{vr.total_checks} checks passed. "
|
| 282 |
+
+ (f"Warnings: {vr.warnings}" if vr.warnings else "")
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
elif atype == "submit":
|
| 286 |
+
current = self._db.current_state()
|
| 287 |
+
score = graders.grade(
|
| 288 |
+
db=self._db,
|
| 289 |
+
cleaned_rows=current,
|
| 290 |
+
removed_ids=list(self._deleted_row_ids),
|
| 291 |
+
task_id=self._state.task_id,
|
| 292 |
+
validation_was_called=self._validation_called,
|
| 293 |
+
)
|
| 294 |
+
self._state.grader_score = score
|
| 295 |
+
done = True
|
| 296 |
+
feedback = (
|
| 297 |
+
f"submit: episode complete. "
|
| 298 |
+
f"Grader score = {score:.4f}. "
|
| 299 |
+
f"Issues remaining: {self._db.issues_remaining()}."
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
elif atype == "export":
|
| 303 |
+
cleaned_rows = action.cleaned_rows or self._db.current_state()
|
| 304 |
+
removed_ids = action.removed_ids or []
|
| 305 |
+
score = graders.grade(
|
| 306 |
+
db=self._db,
|
| 307 |
+
cleaned_rows=cleaned_rows,
|
| 308 |
+
removed_ids=removed_ids,
|
| 309 |
+
task_id=self._state.task_id,
|
| 310 |
+
validation_was_called=self._validation_called,
|
| 311 |
+
)
|
| 312 |
+
self._state.grader_score = score
|
| 313 |
+
export_info = export_cleaned(
|
| 314 |
+
cleaned_rows=cleaned_rows,
|
| 315 |
+
source_format=self._db.source_format,
|
| 316 |
+
dataset_name=self._db.dataset_name,
|
| 317 |
+
)
|
| 318 |
+
self._export_result = export_info
|
| 319 |
+
done = True
|
| 320 |
+
feedback = (
|
| 321 |
+
f"export: {export_info['row_count']} rows written to "
|
| 322 |
+
f"{export_info['download_url']} ({export_info['format']}). "
|
| 323 |
+
f"Grader score = {score:.4f}."
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
else:
|
| 327 |
+
feedback = f"Unknown action_type '{atype}'. No-op."
|
| 328 |
+
|
| 329 |
+
except ValueError as exc:
|
| 330 |
+
feedback = f"Action error: {exc}"
|
| 331 |
+
|
| 332 |
+
# ------------------------------------------------------------------
|
| 333 |
+
# Reward
|
| 334 |
+
# ------------------------------------------------------------------
|
| 335 |
+
rb: RB = calc(
|
| 336 |
+
action_type=atype,
|
| 337 |
+
db=self._db,
|
| 338 |
+
counter=self._counter,
|
| 339 |
+
action=action,
|
| 340 |
+
validation_result=(
|
| 341 |
+
getattr(self, "_last_vr", None) if atype == "validate" else None
|
| 342 |
+
),
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
step_reward = rb.total
|
| 346 |
+
rb_dict = rb.to_dict()
|
| 347 |
+
rb_dict["step"] = step
|
| 348 |
+
rb_dict["action_type"] = atype
|
| 349 |
+
self._reward_trace.append(rb_dict)
|
| 350 |
+
|
| 351 |
+
# Update investigation count
|
| 352 |
+
if atype in ("inspect", "profile_column", "run_sql"):
|
| 353 |
+
self._state.investigation_count += 1
|
| 354 |
+
|
| 355 |
+
# Max-steps termination
|
| 356 |
+
if step >= max_steps and not done:
|
| 357 |
+
done = True
|
| 358 |
+
feedback += f" [max_steps={max_steps} reached]"
|
| 359 |
+
|
| 360 |
+
self._state.done = done
|
| 361 |
+
|
| 362 |
+
obs = self._make_obs(
|
| 363 |
+
last_feedback=feedback,
|
| 364 |
+
query_result=query_result,
|
| 365 |
+
validation_result=validation_result,
|
| 366 |
+
)
|
| 367 |
+
obs.done = done
|
| 368 |
+
obs.reward = step_reward
|
| 369 |
+
|
| 370 |
+
return obs
|
| 371 |
+
|
| 372 |
+
# ------------------------------------------------------------------
|
| 373 |
+
# get_state()
|
| 374 |
+
# ------------------------------------------------------------------
|
| 375 |
+
|
| 376 |
+
@property
|
| 377 |
+
def state(self) -> SQLSherlockState:
|
| 378 |
+
"""Required by openenv-core Environment base class."""
|
| 379 |
+
return self.get_state()
|
| 380 |
+
|
| 381 |
+
def get_state(self) -> SQLSherlockState:
|
| 382 |
+
if self._state is None:
|
| 383 |
+
return SQLSherlockState()
|
| 384 |
+
return self._state
|
| 385 |
+
|
| 386 |
+
# ------------------------------------------------------------------
|
| 387 |
+
# Private helpers
|
| 388 |
+
# ------------------------------------------------------------------
|
| 389 |
+
|
| 390 |
+
def _make_obs(
|
| 391 |
+
self,
|
| 392 |
+
last_feedback: str,
|
| 393 |
+
query_result: Optional[list],
|
| 394 |
+
validation_result: Optional[dict],
|
| 395 |
+
) -> SQLSherlockObservation:
|
| 396 |
+
task_cfg = _TASK_MAP.get(self._state.task_id, TASKS[0]) if self._state else TASKS[0]
|
| 397 |
+
return SQLSherlockObservation(
|
| 398 |
+
task_id=self._state.task_id if self._state else "",
|
| 399 |
+
task_description=task_cfg["description"],
|
| 400 |
+
step=self._state.step_count if self._state else 0,
|
| 401 |
+
max_steps=task_cfg["max_steps"],
|
| 402 |
+
tables_summary=self._db.tables_summary() if self._db else {},
|
| 403 |
+
query_result=query_result,
|
| 404 |
+
validation_result=validation_result,
|
| 405 |
+
last_feedback=last_feedback,
|
| 406 |
+
reward_trace=list(self._reward_trace),
|
| 407 |
+
done=self._state.done if self._state else False,
|
| 408 |
+
)
|
sqlsherlock_env/server/exporter.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Exporter for SQLSherlock-Env.
|
| 9 |
+
|
| 10 |
+
Writes the cleaned dataset in the SAME FORMAT as the original input.
|
| 11 |
+
Supported output formats: csv, json, jsonl, parquet, hf_dataset (→ csv).
|
| 12 |
+
|
| 13 |
+
Returns a file descriptor dict that the environment embeds in the
|
| 14 |
+
observation and that the /download/{file_id} endpoint serves.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import csv
|
| 18 |
+
import io
|
| 19 |
+
import json
|
| 20 |
+
import os
|
| 21 |
+
import tempfile
|
| 22 |
+
import uuid
|
| 23 |
+
from typing import Any
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
# Public API
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
|
| 30 |
+
def export_cleaned(
|
| 31 |
+
cleaned_rows: list[dict],
|
| 32 |
+
source_format: str,
|
| 33 |
+
dataset_name: str,
|
| 34 |
+
) -> dict:
|
| 35 |
+
"""Write cleaned rows to a temp file matching the original format.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
cleaned_rows: List of cleaned row dicts (no _source_format key).
|
| 39 |
+
source_format: One of csv | json | jsonl | parquet | hf_dataset.
|
| 40 |
+
dataset_name: Original dataset name/path (used to derive filename).
|
| 41 |
+
|
| 42 |
+
Returns:
|
| 43 |
+
Dict with keys:
|
| 44 |
+
file_id — UUID string (used in /download/{file_id})
|
| 45 |
+
filename — human-readable filename
|
| 46 |
+
format — detected output format
|
| 47 |
+
download_url — relative URL path
|
| 48 |
+
row_count — number of rows written
|
| 49 |
+
"""
|
| 50 |
+
if not cleaned_rows:
|
| 51 |
+
raise ValueError("Cannot export empty cleaned_rows list.")
|
| 52 |
+
|
| 53 |
+
# Strip internal metadata column before writing
|
| 54 |
+
rows = _strip_meta(cleaned_rows)
|
| 55 |
+
|
| 56 |
+
file_id = str(uuid.uuid4())
|
| 57 |
+
stem = _stem_from_name(dataset_name)
|
| 58 |
+
fmt = source_format if source_format in _WRITERS else "csv"
|
| 59 |
+
|
| 60 |
+
filename, filepath = _make_temp_path(file_id, stem, fmt)
|
| 61 |
+
|
| 62 |
+
_WRITERS[fmt](rows, filepath)
|
| 63 |
+
|
| 64 |
+
return {
|
| 65 |
+
"file_id": file_id,
|
| 66 |
+
"filename": filename,
|
| 67 |
+
"format": fmt,
|
| 68 |
+
"download_url": f"/download/{file_id}",
|
| 69 |
+
"row_count": len(rows),
|
| 70 |
+
"filepath": filepath, # kept server-side for FileResponse
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# ---------------------------------------------------------------------------
|
| 75 |
+
# Format writers
|
| 76 |
+
# ---------------------------------------------------------------------------
|
| 77 |
+
|
| 78 |
+
def _write_csv(rows: list[dict], path: str) -> None:
|
| 79 |
+
if not rows:
|
| 80 |
+
return
|
| 81 |
+
with open(path, "w", newline="", encoding="utf-8") as f:
|
| 82 |
+
writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
|
| 83 |
+
writer.writeheader()
|
| 84 |
+
writer.writerows(rows)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _write_json(rows: list[dict], path: str) -> None:
|
| 88 |
+
with open(path, "w", encoding="utf-8") as f:
|
| 89 |
+
json.dump(rows, f, indent=2, default=str)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _write_jsonl(rows: list[dict], path: str) -> None:
|
| 93 |
+
with open(path, "w", encoding="utf-8") as f:
|
| 94 |
+
for row in rows:
|
| 95 |
+
f.write(json.dumps(row, default=str) + "\n")
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def _write_parquet(rows: list[dict], path: str) -> None:
|
| 99 |
+
try:
|
| 100 |
+
import pandas as pd
|
| 101 |
+
except ImportError:
|
| 102 |
+
raise ValueError(
|
| 103 |
+
"pandas is required to export Parquet files. "
|
| 104 |
+
"pip install pandas pyarrow"
|
| 105 |
+
)
|
| 106 |
+
df = pd.DataFrame(rows)
|
| 107 |
+
df.to_parquet(path, index=False)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
_WRITERS = {
|
| 111 |
+
"csv": _write_csv,
|
| 112 |
+
"json": _write_json,
|
| 113 |
+
"jsonl": _write_jsonl,
|
| 114 |
+
"parquet": _write_parquet,
|
| 115 |
+
"hf_dataset": _write_csv, # HF datasets exported as CSV
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# ---------------------------------------------------------------------------
|
| 120 |
+
# Helpers
|
| 121 |
+
# ---------------------------------------------------------------------------
|
| 122 |
+
|
| 123 |
+
def _strip_meta(rows: list[dict]) -> list[dict]:
|
| 124 |
+
"""Remove _source_format from every row."""
|
| 125 |
+
return [
|
| 126 |
+
{k: v for k, v in row.items() if k != "_source_format"}
|
| 127 |
+
for row in rows
|
| 128 |
+
]
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def _stem_from_name(dataset_name: str) -> str:
|
| 132 |
+
"""Derive a clean file stem from the dataset name."""
|
| 133 |
+
if not dataset_name:
|
| 134 |
+
return "cleaned"
|
| 135 |
+
# HF dataset: "owner/name" or "owner/name:split"
|
| 136 |
+
# For raw CSV text, take only the first line (header) to avoid huge filenames.
|
| 137 |
+
first_line = dataset_name.strip().split("\n")[0]
|
| 138 |
+
base = first_line.split(":")[0].split("/")[-1]
|
| 139 |
+
safe = "".join(c if c.isalnum() or c == "_" else "_" for c in base.lower())
|
| 140 |
+
# Truncate to 40 chars to stay well under filesystem path length limits.
|
| 141 |
+
safe = (safe or "cleaned")[:40].rstrip("_")
|
| 142 |
+
return (safe or "cleaned") + "_cleaned"
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _ext_for_format(fmt: str) -> str:
|
| 146 |
+
return {
|
| 147 |
+
"csv": ".csv",
|
| 148 |
+
"json": ".json",
|
| 149 |
+
"jsonl": ".jsonl",
|
| 150 |
+
"parquet": ".parquet",
|
| 151 |
+
"hf_dataset": ".csv",
|
| 152 |
+
}.get(fmt, ".csv")
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def _make_temp_path(file_id: str, stem: str, fmt: str) -> tuple[str, str]:
|
| 156 |
+
"""Return (filename, full_filepath) in the system temp directory."""
|
| 157 |
+
ext = _ext_for_format(fmt)
|
| 158 |
+
filename = f"{stem}{ext}"
|
| 159 |
+
filepath = os.path.join(tempfile.gettempdir(), f"{file_id}_{filename}")
|
| 160 |
+
return filename, filepath
|
sqlsherlock_env/server/graders/__init__.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Graders package for SQLSherlock-Env.
|
| 9 |
+
|
| 10 |
+
Each task has a dedicated grader that delegates to universal.grade()
|
| 11 |
+
with task-appropriate filters.
|
| 12 |
+
|
| 13 |
+
Usage (from environment.py)::
|
| 14 |
+
|
| 15 |
+
from server import graders
|
| 16 |
+
|
| 17 |
+
score = graders.grade(
|
| 18 |
+
db=db,
|
| 19 |
+
cleaned_rows=cleaned_rows,
|
| 20 |
+
removed_ids=removed_ids,
|
| 21 |
+
task_id=task_id,
|
| 22 |
+
validation_was_called=validation_was_called,
|
| 23 |
+
)
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from server.graders.task1 import grade as _grade_task1
|
| 27 |
+
from server.graders.task2 import grade as _grade_task2
|
| 28 |
+
from server.graders.task3 import grade as _grade_task3
|
| 29 |
+
|
| 30 |
+
_GRADERS = {
|
| 31 |
+
"task1_null_and_types": _grade_task1,
|
| 32 |
+
"task2_constraints_and_fk": _grade_task2,
|
| 33 |
+
"task3_full_audit_with_trap": _grade_task3,
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def grade(
|
| 38 |
+
db,
|
| 39 |
+
cleaned_rows: list[dict],
|
| 40 |
+
removed_ids: list[int],
|
| 41 |
+
task_id: str,
|
| 42 |
+
validation_was_called: bool,
|
| 43 |
+
) -> float:
|
| 44 |
+
"""Dispatch to the correct task grader and return a score in [0.0, 1.0].
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
db: DatabaseEngine instance for this episode.
|
| 48 |
+
cleaned_rows: Agent-provided cleaned row list.
|
| 49 |
+
removed_ids: Agent-provided list of deleted row PKs.
|
| 50 |
+
task_id: Task identifier string.
|
| 51 |
+
validation_was_called: Whether the agent called validate() at least once.
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
Float score in [0.0, 1.0].
|
| 55 |
+
|
| 56 |
+
Raises:
|
| 57 |
+
ValueError: If task_id is not recognised.
|
| 58 |
+
"""
|
| 59 |
+
grader_fn = _GRADERS.get(task_id)
|
| 60 |
+
if grader_fn is None:
|
| 61 |
+
raise ValueError(
|
| 62 |
+
f"Unknown task_id '{task_id}'. "
|
| 63 |
+
f"Valid tasks: {sorted(_GRADERS.keys())}"
|
| 64 |
+
)
|
| 65 |
+
return grader_fn(
|
| 66 |
+
db=db,
|
| 67 |
+
cleaned_rows=cleaned_rows,
|
| 68 |
+
removed_ids=removed_ids,
|
| 69 |
+
validation_was_called=validation_was_called,
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
__all__ = ["grade"]
|
sqlsherlock_env/server/graders/task1.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Task 1 grader — Null and type error repair.
|
| 9 |
+
|
| 10 |
+
Scoring formula:
|
| 11 |
+
task1_score = resolution_score × 0.70 + validation_score × 0.30
|
| 12 |
+
|
| 13 |
+
Only null and type_error issues contribute to resolution_score.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from server.database import DatabaseEngine
|
| 17 |
+
from server.graders.universal import grade as universal_grade
|
| 18 |
+
|
| 19 |
+
_ISSUE_FILTER = {"null", "type_error"}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def grade(
|
| 23 |
+
db: DatabaseEngine,
|
| 24 |
+
cleaned_rows: list[dict],
|
| 25 |
+
removed_ids: list[int],
|
| 26 |
+
validation_was_called: bool,
|
| 27 |
+
) -> float:
|
| 28 |
+
"""Score a task1 submission.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
db: DatabaseEngine for this episode.
|
| 32 |
+
cleaned_rows: Agent-provided cleaned rows.
|
| 33 |
+
removed_ids: Agent-provided deleted row PKs.
|
| 34 |
+
validation_was_called: Whether validate() was called.
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
Float score in [0.0, 1.0].
|
| 38 |
+
"""
|
| 39 |
+
# universal.grade uses its own 0.60/0.30/0.10 weights internally.
|
| 40 |
+
# We get the raw universal score, then re-weight to task1 formula:
|
| 41 |
+
# resolution_score × 0.70 + validation_score × 0.30
|
| 42 |
+
#
|
| 43 |
+
# To do that cleanly we compute both sub-scores independently and
|
| 44 |
+
# combine them here.
|
| 45 |
+
|
| 46 |
+
from server.graders.universal import (
|
| 47 |
+
_resolution_score,
|
| 48 |
+
_false_positive_penalty,
|
| 49 |
+
_trap_penalty,
|
| 50 |
+
_validation_score,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
issue_registry = db.issue_registry
|
| 54 |
+
scored_issues = [i for i in issue_registry if i.issue_type in _ISSUE_FILTER]
|
| 55 |
+
pk_col = db.pk_col
|
| 56 |
+
|
| 57 |
+
# Zero-change guard — compare against ORIGINAL dirty state, not current state
|
| 58 |
+
dirty_rows = db.original_state()
|
| 59 |
+
from server.graders.universal import _rows_identical
|
| 60 |
+
if not removed_ids and _rows_identical(cleaned_rows, dirty_rows, pk_col):
|
| 61 |
+
if db.total_issues > 0:
|
| 62 |
+
return 0.0
|
| 63 |
+
|
| 64 |
+
res_score, _ = _resolution_score(
|
| 65 |
+
scored_issues, cleaned_rows, removed_ids, pk_col, db
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
fp_penalty = _false_positive_penalty(
|
| 69 |
+
db, cleaned_rows, removed_ids, pk_col, db.primary_table
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
val_score = _validation_score(db, cleaned_rows, validation_was_called)
|
| 73 |
+
|
| 74 |
+
raw = res_score * 0.70 + val_score * 0.30 - fp_penalty
|
| 75 |
+
return max(0.0, min(1.0, round(raw, 4)))
|
sqlsherlock_env/server/graders/task2.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Task 2 grader — Constraint and FK integrity.
|
| 9 |
+
|
| 10 |
+
Scoring formula:
|
| 11 |
+
task2_score = task1_score × 0.40
|
| 12 |
+
+ (constraint_resolved + fk_resolved) / 2 × 0.60
|
| 13 |
+
|
| 14 |
+
task1_score is computed by the task1 grader (null + type only).
|
| 15 |
+
constraint_resolved and fk_resolved are weighted resolution scores
|
| 16 |
+
for their respective issue types (each in [0.0, 1.0]).
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from server.database import DatabaseEngine
|
| 20 |
+
from server.graders.task1 import grade as task1_grade
|
| 21 |
+
from server.graders.universal import (
|
| 22 |
+
_resolution_score,
|
| 23 |
+
_false_positive_penalty,
|
| 24 |
+
_rows_identical,
|
| 25 |
+
_validation_score,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
_CONSTRAINT_FILTER = {"constraint"}
|
| 29 |
+
_FK_FILTER = {"fk_violation"}
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def grade(
|
| 33 |
+
db: DatabaseEngine,
|
| 34 |
+
cleaned_rows: list[dict],
|
| 35 |
+
removed_ids: list[int],
|
| 36 |
+
validation_was_called: bool,
|
| 37 |
+
) -> float:
|
| 38 |
+
"""Score a task2 submission.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
db: DatabaseEngine for this episode.
|
| 42 |
+
cleaned_rows: Agent-provided cleaned rows.
|
| 43 |
+
removed_ids: Agent-provided deleted row PKs.
|
| 44 |
+
validation_was_called: Whether validate() was called.
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
Float score in [0.0, 1.0].
|
| 48 |
+
"""
|
| 49 |
+
pk_col = db.pk_col
|
| 50 |
+
|
| 51 |
+
# Zero-change guard — compare against ORIGINAL dirty state, not current state
|
| 52 |
+
dirty_rows = db.original_state()
|
| 53 |
+
if not removed_ids and _rows_identical(cleaned_rows, dirty_rows, pk_col):
|
| 54 |
+
if db.total_issues > 0:
|
| 55 |
+
return 0.0
|
| 56 |
+
|
| 57 |
+
# task1 component (null + type errors)
|
| 58 |
+
t1 = task1_grade(
|
| 59 |
+
db=db,
|
| 60 |
+
cleaned_rows=cleaned_rows,
|
| 61 |
+
removed_ids=removed_ids,
|
| 62 |
+
validation_was_called=validation_was_called,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# Constraint resolution score
|
| 66 |
+
constraint_issues = [
|
| 67 |
+
i for i in db.issue_registry if i.issue_type in _CONSTRAINT_FILTER
|
| 68 |
+
]
|
| 69 |
+
if constraint_issues:
|
| 70 |
+
c_score, _ = _resolution_score(
|
| 71 |
+
constraint_issues, cleaned_rows, removed_ids, pk_col, db
|
| 72 |
+
)
|
| 73 |
+
else:
|
| 74 |
+
c_score = 1.0 # No constraint issues → full credit
|
| 75 |
+
|
| 76 |
+
# FK resolution score
|
| 77 |
+
fk_issues = [
|
| 78 |
+
i for i in db.issue_registry if i.issue_type in _FK_FILTER
|
| 79 |
+
]
|
| 80 |
+
if fk_issues:
|
| 81 |
+
fk_score, _ = _resolution_score(
|
| 82 |
+
fk_issues, cleaned_rows, removed_ids, pk_col, db
|
| 83 |
+
)
|
| 84 |
+
else:
|
| 85 |
+
fk_score = 1.0 # No FK issues → full credit
|
| 86 |
+
|
| 87 |
+
fp_penalty = _false_positive_penalty(
|
| 88 |
+
db, cleaned_rows, removed_ids, pk_col, db.primary_table
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
combined = (c_score + fk_score) / 2.0
|
| 92 |
+
raw = t1 * 0.40 + combined * 0.60 - fp_penalty
|
| 93 |
+
return max(0.0, min(1.0, round(raw, 4)))
|
sqlsherlock_env/server/graders/task3.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Task 3 grader — Full statistical audit with trap.
|
| 9 |
+
|
| 10 |
+
Scoring formula:
|
| 11 |
+
task3_score = task2_score × 0.50
|
| 12 |
+
+ audit_issues_resolved × 0.50
|
| 13 |
+
+ reasoning_bonus (0.05)
|
| 14 |
+
- trap_penalty (0.40 if trap hit)
|
| 15 |
+
|
| 16 |
+
audit_issues_resolved = weighted resolution score for
|
| 17 |
+
outlier + duplicate issue types.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from server.database import DatabaseEngine
|
| 21 |
+
from server.graders.task2 import grade as task2_grade
|
| 22 |
+
from server.graders.universal import (
|
| 23 |
+
_resolution_score,
|
| 24 |
+
_trap_penalty,
|
| 25 |
+
_rows_identical,
|
| 26 |
+
_reasoning_bonus,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
_AUDIT_FILTER = {"outlier", "duplicate"}
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def grade(
|
| 33 |
+
db: DatabaseEngine,
|
| 34 |
+
cleaned_rows: list[dict],
|
| 35 |
+
removed_ids: list[int],
|
| 36 |
+
validation_was_called: bool,
|
| 37 |
+
) -> float:
|
| 38 |
+
"""Score a task3 submission.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
db: DatabaseEngine for this episode.
|
| 42 |
+
cleaned_rows: Agent-provided cleaned rows.
|
| 43 |
+
removed_ids: Agent-provided deleted row PKs.
|
| 44 |
+
validation_was_called: Whether validate() was called.
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
Float score in [0.0, 1.0].
|
| 48 |
+
"""
|
| 49 |
+
pk_col = db.pk_col
|
| 50 |
+
|
| 51 |
+
# Zero-change guard — compare against ORIGINAL dirty state, not current state
|
| 52 |
+
dirty_rows = db.original_state()
|
| 53 |
+
if not removed_ids and _rows_identical(cleaned_rows, dirty_rows, pk_col):
|
| 54 |
+
if db.total_issues > 0:
|
| 55 |
+
return 0.0
|
| 56 |
+
|
| 57 |
+
# task2 component (null + type + constraint + fk)
|
| 58 |
+
t2 = task2_grade(
|
| 59 |
+
db=db,
|
| 60 |
+
cleaned_rows=cleaned_rows,
|
| 61 |
+
removed_ids=removed_ids,
|
| 62 |
+
validation_was_called=validation_was_called,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# Audit issues: outlier + duplicate
|
| 66 |
+
audit_issues = [
|
| 67 |
+
i for i in db.issue_registry if i.issue_type in _AUDIT_FILTER
|
| 68 |
+
]
|
| 69 |
+
if audit_issues:
|
| 70 |
+
audit_score, _ = _resolution_score(
|
| 71 |
+
audit_issues, cleaned_rows, removed_ids, pk_col, db
|
| 72 |
+
)
|
| 73 |
+
else:
|
| 74 |
+
audit_score = 1.0 # No audit issues → full credit
|
| 75 |
+
|
| 76 |
+
# Trap penalty
|
| 77 |
+
trap_pen = _trap_penalty(
|
| 78 |
+
db, cleaned_rows, removed_ids, pk_col,
|
| 79 |
+
task_id="task3_full_audit_with_trap",
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# Reasoning bonus
|
| 83 |
+
r_bonus = _reasoning_bonus(db, "task3_full_audit_with_trap", validation_was_called)
|
| 84 |
+
|
| 85 |
+
# NOTE: FP penalty is already applied inside t2 (task2_grade) — not applied
|
| 86 |
+
# again here to avoid double-counting.
|
| 87 |
+
|
| 88 |
+
raw = (
|
| 89 |
+
t2 * 0.50
|
| 90 |
+
+ audit_score * 0.50
|
| 91 |
+
+ r_bonus
|
| 92 |
+
- trap_pen
|
| 93 |
+
)
|
| 94 |
+
return max(0.0, min(1.0, round(raw, 4)))
|
sqlsherlock_env/server/graders/universal.py
ADDED
|
@@ -0,0 +1,442 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Universal grader for SQLSherlock-Env.
|
| 9 |
+
|
| 10 |
+
Implements the 7-step scoring pipeline shared by all task graders.
|
| 11 |
+
Task graders (task1/task2/task3) call grade() with an issue_filter
|
| 12 |
+
to restrict which issue types count toward the score.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from typing import Any, Optional
|
| 17 |
+
|
| 18 |
+
from server.issue_detector import SENTINEL_UNKNOWN
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# ---------------------------------------------------------------------------
|
| 22 |
+
# Public API
|
| 23 |
+
# ---------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
def grade(
|
| 26 |
+
db: Any,
|
| 27 |
+
cleaned_rows: list[dict],
|
| 28 |
+
removed_ids: list[int],
|
| 29 |
+
task_id: str,
|
| 30 |
+
validation_was_called: bool,
|
| 31 |
+
issue_filter: Optional[set[str]] = None,
|
| 32 |
+
) -> float:
|
| 33 |
+
"""Score an agent's submitted solution in [0.0, 1.0].
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
db: DatabaseEngine for this episode.
|
| 37 |
+
cleaned_rows: Rows the agent claims are clean.
|
| 38 |
+
removed_ids: Row PKs the agent deleted.
|
| 39 |
+
task_id: Task identifier (used for trap / reasoning checks).
|
| 40 |
+
validation_was_called: Whether validate() was called during the episode.
|
| 41 |
+
issue_filter: If set, only issues whose type is in this set
|
| 42 |
+
contribute to resolution_score. None = all types.
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
Float in [0.0, 1.0].
|
| 46 |
+
"""
|
| 47 |
+
issue_registry = db.issue_registry
|
| 48 |
+
pk_col = db.pk_col
|
| 49 |
+
primary_table = db.primary_table
|
| 50 |
+
|
| 51 |
+
# Filter issues by type if requested
|
| 52 |
+
if issue_filter:
|
| 53 |
+
scored_issues = [i for i in issue_registry if i.issue_type in issue_filter]
|
| 54 |
+
else:
|
| 55 |
+
scored_issues = list(issue_registry)
|
| 56 |
+
|
| 57 |
+
# --- STEP 1: Zero-change check ---
|
| 58 |
+
# Compare against the ORIGINAL dirty state (before any fixes), not the current state.
|
| 59 |
+
# db.rows() returns the current (post-fix) state, so it would always match cleaned_rows.
|
| 60 |
+
dirty_rows = db.original_state()
|
| 61 |
+
|
| 62 |
+
if not removed_ids and _rows_identical(cleaned_rows, dirty_rows, pk_col):
|
| 63 |
+
if db.total_issues > 0:
|
| 64 |
+
return 0.0
|
| 65 |
+
|
| 66 |
+
# --- STEP 2: Resolution score ---
|
| 67 |
+
resolution_score, total_weight = _resolution_score(
|
| 68 |
+
scored_issues, cleaned_rows, removed_ids, pk_col, db
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
# --- STEP 3: False positive penalty ---
|
| 72 |
+
fp_penalty = _false_positive_penalty(
|
| 73 |
+
db, cleaned_rows, removed_ids, pk_col, primary_table
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
# --- STEP 4: Trap penalty (task3 only) ---
|
| 77 |
+
trap_penalty = _trap_penalty(db, cleaned_rows, removed_ids, pk_col, task_id)
|
| 78 |
+
|
| 79 |
+
# --- STEP 5: Validation score ---
|
| 80 |
+
validation_score = _validation_score(
|
| 81 |
+
db, cleaned_rows, validation_was_called
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
# --- STEP 6: Reasoning bonus (task3 only) ---
|
| 85 |
+
reasoning_bonus = _reasoning_bonus(db, task_id, validation_was_called)
|
| 86 |
+
|
| 87 |
+
# --- STEP 7: Final score ---
|
| 88 |
+
raw = (
|
| 89 |
+
resolution_score * 0.60
|
| 90 |
+
+ validation_score * 0.30
|
| 91 |
+
+ reasoning_bonus * 0.10
|
| 92 |
+
- fp_penalty
|
| 93 |
+
- trap_penalty
|
| 94 |
+
)
|
| 95 |
+
return max(0.0, min(1.0, round(raw, 4)))
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# ---------------------------------------------------------------------------
|
| 99 |
+
# Step implementations
|
| 100 |
+
# ---------------------------------------------------------------------------
|
| 101 |
+
|
| 102 |
+
def _resolution_score(
|
| 103 |
+
issues: list,
|
| 104 |
+
cleaned_rows: list[dict],
|
| 105 |
+
removed_ids: list[int],
|
| 106 |
+
pk_col: str,
|
| 107 |
+
db: Any,
|
| 108 |
+
) -> tuple[float, float]:
|
| 109 |
+
"""Return (weighted_resolution_score, total_weight)."""
|
| 110 |
+
if not issues:
|
| 111 |
+
return 1.0, 1.0 # No issues to resolve → full resolution score
|
| 112 |
+
|
| 113 |
+
cleaned_map = {row[pk_col]: row for row in cleaned_rows}
|
| 114 |
+
removed_set = set(removed_ids)
|
| 115 |
+
total_weight = sum(i.confidence for i in issues)
|
| 116 |
+
|
| 117 |
+
if total_weight == 0:
|
| 118 |
+
return 0.0, 0.0
|
| 119 |
+
|
| 120 |
+
# Per-column stats for outlier z-score recheck
|
| 121 |
+
col_stats: dict[str, dict] = {}
|
| 122 |
+
profile = db._profiles.get(db.primary_table, {})
|
| 123 |
+
|
| 124 |
+
weighted_sum = 0.0
|
| 125 |
+
|
| 126 |
+
for iss in issues:
|
| 127 |
+
C = iss.confidence
|
| 128 |
+
col = iss.column
|
| 129 |
+
rid = iss.row_id
|
| 130 |
+
|
| 131 |
+
p = profile.get(col, {}) if col else {}
|
| 132 |
+
col_mean = p.get("mean")
|
| 133 |
+
col_std = p.get("std")
|
| 134 |
+
|
| 135 |
+
resolved = _resolve_issue(
|
| 136 |
+
iss, cleaned_map, removed_set, col_mean, col_std
|
| 137 |
+
)
|
| 138 |
+
weighted_sum += resolved * C
|
| 139 |
+
|
| 140 |
+
return weighted_sum / total_weight, total_weight
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def _resolve_issue(
|
| 144 |
+
iss: Any,
|
| 145 |
+
cleaned_map: dict,
|
| 146 |
+
removed_set: set,
|
| 147 |
+
col_mean: Optional[float],
|
| 148 |
+
col_std: Optional[float],
|
| 149 |
+
) -> float:
|
| 150 |
+
"""Return a resolution score in [0.0, 1.0] for one issue."""
|
| 151 |
+
C = iss.confidence
|
| 152 |
+
col = iss.column
|
| 153 |
+
rid = iss.row_id
|
| 154 |
+
|
| 155 |
+
itype = iss.issue_type
|
| 156 |
+
|
| 157 |
+
# --- duplicate / fk_violation ---
|
| 158 |
+
if itype in ("duplicate", "fk_violation"):
|
| 159 |
+
if rid in removed_set:
|
| 160 |
+
return 1.0
|
| 161 |
+
if rid not in cleaned_map:
|
| 162 |
+
return 1.0 # row absent from cleaned output = deleted
|
| 163 |
+
return 0.0
|
| 164 |
+
|
| 165 |
+
# --- null ---
|
| 166 |
+
if itype == "null":
|
| 167 |
+
row = cleaned_map.get(rid)
|
| 168 |
+
if row is None:
|
| 169 |
+
return 0.5 * C # deleted instead of fixed
|
| 170 |
+
val = row.get(col)
|
| 171 |
+
if _is_null(val):
|
| 172 |
+
return 0.0
|
| 173 |
+
if iss.correct == SENTINEL_UNKNOWN:
|
| 174 |
+
# Any non-null value of correct type accepted
|
| 175 |
+
col_dtype = _guess_dtype(val)
|
| 176 |
+
return C if col_dtype != "unknown" else C * 0.5
|
| 177 |
+
return C if _values_match(val, iss.correct) else 0.0
|
| 178 |
+
|
| 179 |
+
# --- type_error ---
|
| 180 |
+
if itype == "type_error":
|
| 181 |
+
row = cleaned_map.get(rid)
|
| 182 |
+
if row is None:
|
| 183 |
+
return 0.5
|
| 184 |
+
val = row.get(col)
|
| 185 |
+
if _is_null(val):
|
| 186 |
+
return 0.0
|
| 187 |
+
try:
|
| 188 |
+
float(str(val))
|
| 189 |
+
return 1.0
|
| 190 |
+
except (ValueError, TypeError):
|
| 191 |
+
return 0.0
|
| 192 |
+
|
| 193 |
+
# --- constraint ---
|
| 194 |
+
if itype == "constraint":
|
| 195 |
+
row = cleaned_map.get(rid)
|
| 196 |
+
if row is None:
|
| 197 |
+
return 0.5 * C
|
| 198 |
+
val = row.get(col)
|
| 199 |
+
if _is_null(val):
|
| 200 |
+
return 0.0
|
| 201 |
+
try:
|
| 202 |
+
fval = float(str(val))
|
| 203 |
+
except (ValueError, TypeError):
|
| 204 |
+
return 0.0
|
| 205 |
+
if fval >= 0:
|
| 206 |
+
correct = iss.correct
|
| 207 |
+
if correct is not None and correct != SENTINEL_UNKNOWN:
|
| 208 |
+
if fval <= abs(float(correct)) * 5:
|
| 209 |
+
return C # positive and close to original
|
| 210 |
+
return C * 0.7 # positive but far from original
|
| 211 |
+
return C # unknown correct — any non-negative OK
|
| 212 |
+
return 0.0 # still negative
|
| 213 |
+
|
| 214 |
+
# --- outlier ---
|
| 215 |
+
if itype == "outlier":
|
| 216 |
+
row = cleaned_map.get(rid)
|
| 217 |
+
if row is None:
|
| 218 |
+
return 0.5 * C
|
| 219 |
+
val = row.get(col)
|
| 220 |
+
if _is_null(val):
|
| 221 |
+
return 0.0
|
| 222 |
+
if col_mean is None or col_std is None or col_std == 0:
|
| 223 |
+
return C # can't verify — assume resolved
|
| 224 |
+
try:
|
| 225 |
+
z = abs(float(str(val)) - col_mean) / col_std
|
| 226 |
+
except (ValueError, TypeError):
|
| 227 |
+
return 0.0
|
| 228 |
+
if z <= 3.0:
|
| 229 |
+
return C
|
| 230 |
+
if z <= 5.0:
|
| 231 |
+
return C * 0.5
|
| 232 |
+
return 0.0
|
| 233 |
+
|
| 234 |
+
# --- whitespace ---
|
| 235 |
+
if itype == "whitespace":
|
| 236 |
+
row = cleaned_map.get(rid)
|
| 237 |
+
if row is None:
|
| 238 |
+
return 0.0
|
| 239 |
+
val = row.get(col)
|
| 240 |
+
if _is_null(val):
|
| 241 |
+
return 0.0
|
| 242 |
+
s = str(val)
|
| 243 |
+
if s == " ".join(s.split()):
|
| 244 |
+
return C # whitespace cleaned
|
| 245 |
+
return 0.0
|
| 246 |
+
|
| 247 |
+
# --- inconsistent_category ---
|
| 248 |
+
if itype == "inconsistent_category":
|
| 249 |
+
row = cleaned_map.get(rid)
|
| 250 |
+
if row is None:
|
| 251 |
+
return 0.0
|
| 252 |
+
val = row.get(col)
|
| 253 |
+
if _is_null(val):
|
| 254 |
+
return 0.0
|
| 255 |
+
if _values_match(val, iss.correct):
|
| 256 |
+
return C # normalized to dominant form
|
| 257 |
+
# Accept if same lowercase (partially resolved)
|
| 258 |
+
if str(val).strip().lower() == str(iss.correct).strip().lower():
|
| 259 |
+
return C * 0.8
|
| 260 |
+
return 0.0
|
| 261 |
+
|
| 262 |
+
return 0.0
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def _false_positive_penalty(
|
| 266 |
+
db: Any,
|
| 267 |
+
cleaned_rows: list[dict],
|
| 268 |
+
removed_ids: list[int],
|
| 269 |
+
pk_col: str,
|
| 270 |
+
primary_table: str,
|
| 271 |
+
) -> float:
|
| 272 |
+
"""Penalise changes to cells that were not in the issue registry."""
|
| 273 |
+
originals = db._originals.get(primary_table, [])
|
| 274 |
+
orig_map = {row[pk_col]: row for row in originals}
|
| 275 |
+
issue_cells = {
|
| 276 |
+
(i.row_id, i.column)
|
| 277 |
+
for i in db.issue_registry
|
| 278 |
+
if i.column
|
| 279 |
+
}
|
| 280 |
+
issue_rows = {i.row_id for i in db.issue_registry}
|
| 281 |
+
removed_set = set(removed_ids)
|
| 282 |
+
|
| 283 |
+
fp_count = 0
|
| 284 |
+
|
| 285 |
+
# Changed cells that are not in issue_registry
|
| 286 |
+
for row in cleaned_rows:
|
| 287 |
+
rid = row.get(pk_col)
|
| 288 |
+
orig = orig_map.get(rid)
|
| 289 |
+
if orig is None:
|
| 290 |
+
continue
|
| 291 |
+
for col, val in row.items():
|
| 292 |
+
if col == pk_col or col == "_source_format":
|
| 293 |
+
continue
|
| 294 |
+
if (rid, col) in issue_cells:
|
| 295 |
+
continue
|
| 296 |
+
if not _values_match(val, orig.get(col)):
|
| 297 |
+
fp_count += 1
|
| 298 |
+
|
| 299 |
+
# Rows that were deleted but should NOT have been
|
| 300 |
+
orig_ids = set(orig_map.keys())
|
| 301 |
+
cleaned_ids = {row[pk_col] for row in cleaned_rows} | removed_set
|
| 302 |
+
wrongly_removed = orig_ids - cleaned_ids - issue_rows
|
| 303 |
+
fp_count += len(wrongly_removed)
|
| 304 |
+
|
| 305 |
+
return min(fp_count * 0.05, 0.20)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def _trap_penalty(
|
| 309 |
+
db: Any,
|
| 310 |
+
cleaned_rows: list[dict],
|
| 311 |
+
removed_ids: list[int],
|
| 312 |
+
pk_col: str,
|
| 313 |
+
task_id: str,
|
| 314 |
+
) -> float:
|
| 315 |
+
"""Return 0.40 if the agent touched the trap cell, else 0.0."""
|
| 316 |
+
if task_id != "task3_full_audit_with_trap":
|
| 317 |
+
return 0.0
|
| 318 |
+
trap = db.trap
|
| 319 |
+
if trap is None:
|
| 320 |
+
return 0.0
|
| 321 |
+
|
| 322 |
+
removed_set = set(removed_ids)
|
| 323 |
+
cleaned_map = {row[pk_col]: row for row in cleaned_rows}
|
| 324 |
+
|
| 325 |
+
# Trap hit if row was deleted OR value was changed from trap_value
|
| 326 |
+
if trap.row_id in removed_set:
|
| 327 |
+
return 0.40
|
| 328 |
+
|
| 329 |
+
trap_row = cleaned_map.get(trap.row_id)
|
| 330 |
+
if trap_row is None:
|
| 331 |
+
# Row not in cleaned output and not in removed_ids → deleted implicitly
|
| 332 |
+
return 0.40
|
| 333 |
+
|
| 334 |
+
current_val = trap_row.get(trap.column)
|
| 335 |
+
# The trap value was written into the DB — if agent left it unchanged,
|
| 336 |
+
# current_val should still equal trap.trap_value → no penalty.
|
| 337 |
+
# If agent changed it (even to original) → penalty.
|
| 338 |
+
if _values_match(current_val, trap.trap_value):
|
| 339 |
+
return 0.0 # agent correctly left the trap alone
|
| 340 |
+
return 0.40
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def _validation_score(
|
| 344 |
+
db: Any,
|
| 345 |
+
cleaned_rows: list[dict],
|
| 346 |
+
validation_was_called: bool,
|
| 347 |
+
) -> float:
|
| 348 |
+
"""Run all 6 validator checks on cleaned_rows and return pass ratio."""
|
| 349 |
+
try:
|
| 350 |
+
result = db._validator.validate(
|
| 351 |
+
conn=db._conn,
|
| 352 |
+
current_records=cleaned_rows,
|
| 353 |
+
touched_columns=db._touched_columns,
|
| 354 |
+
)
|
| 355 |
+
score = result.checks_passed / result.total_checks
|
| 356 |
+
except Exception:
|
| 357 |
+
score = 0.0
|
| 358 |
+
|
| 359 |
+
if not validation_was_called and db.total_issues > 0:
|
| 360 |
+
score *= 0.70 # penalty for skipping validate()
|
| 361 |
+
|
| 362 |
+
return round(score, 4)
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def _reasoning_bonus(
|
| 366 |
+
db: Any,
|
| 367 |
+
task_id: str,
|
| 368 |
+
validation_was_called: bool,
|
| 369 |
+
) -> float:
|
| 370 |
+
"""Return 0.05 if task3 agent used statistical reasoning, else 0.0."""
|
| 371 |
+
if task_id != "task3_full_audit_with_trap":
|
| 372 |
+
return 0.0
|
| 373 |
+
if not validation_was_called:
|
| 374 |
+
return 0.0
|
| 375 |
+
|
| 376 |
+
stat_terms = {
|
| 377 |
+
"z-score", "z_score", "zscore", "mean", "std",
|
| 378 |
+
"standard dev", "average", "distribution",
|
| 379 |
+
"statistical", "outlier", "sigma",
|
| 380 |
+
}
|
| 381 |
+
all_reasons = " ".join(
|
| 382 |
+
(a.reason or "") for a in db._action_log if hasattr(a, "reason")
|
| 383 |
+
).lower()
|
| 384 |
+
|
| 385 |
+
return 0.05 if any(term in all_reasons for term in stat_terms) else 0.0
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
# ---------------------------------------------------------------------------
|
| 389 |
+
# Helpers
|
| 390 |
+
# ---------------------------------------------------------------------------
|
| 391 |
+
|
| 392 |
+
def _rows_identical(
|
| 393 |
+
cleaned_rows: list[dict],
|
| 394 |
+
dirty_rows: list[dict],
|
| 395 |
+
pk_col: str,
|
| 396 |
+
) -> bool:
|
| 397 |
+
"""Return True if cleaned_rows has the same values as dirty_rows."""
|
| 398 |
+
if len(cleaned_rows) != len(dirty_rows):
|
| 399 |
+
return False
|
| 400 |
+
dirty_map = {row[pk_col]: row for row in dirty_rows}
|
| 401 |
+
for row in cleaned_rows:
|
| 402 |
+
rid = row.get(pk_col)
|
| 403 |
+
orig = dirty_map.get(rid)
|
| 404 |
+
if orig is None:
|
| 405 |
+
return False
|
| 406 |
+
for col, val in row.items():
|
| 407 |
+
if col == "_source_format":
|
| 408 |
+
continue
|
| 409 |
+
if not _values_match(val, orig.get(col)):
|
| 410 |
+
return False
|
| 411 |
+
return True
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def _values_match(a: Any, b: Any) -> bool:
|
| 415 |
+
if a is None and b is None:
|
| 416 |
+
return True
|
| 417 |
+
if a is None or b is None:
|
| 418 |
+
return False
|
| 419 |
+
try:
|
| 420 |
+
return math.isclose(float(str(a)), float(str(b)), rel_tol=1e-4)
|
| 421 |
+
except (ValueError, TypeError):
|
| 422 |
+
return str(a).strip().lower() == str(b).strip().lower()
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def _is_null(value: Any) -> bool:
|
| 426 |
+
if value is None:
|
| 427 |
+
return True
|
| 428 |
+
if isinstance(value, float) and math.isnan(value):
|
| 429 |
+
return True
|
| 430 |
+
if isinstance(value, str) and value.strip() == "":
|
| 431 |
+
return True
|
| 432 |
+
return False
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def _guess_dtype(value: Any) -> str:
|
| 436 |
+
if value is None:
|
| 437 |
+
return "unknown"
|
| 438 |
+
try:
|
| 439 |
+
f = float(str(value))
|
| 440 |
+
return "int" if f == int(f) else "float"
|
| 441 |
+
except (ValueError, TypeError):
|
| 442 |
+
return "str"
|
sqlsherlock_env/server/issue_detector.py
ADDED
|
@@ -0,0 +1,920 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Issue detector for SQLSherlock-Env.
|
| 9 |
+
|
| 10 |
+
Scans real dataset records for genuine data-quality problems.
|
| 11 |
+
NEVER invents issues — synthetic top-up is used ONLY when real
|
| 12 |
+
issue count falls below the task minimum.
|
| 13 |
+
|
| 14 |
+
Detection order per task:
|
| 15 |
+
task1: null_check + type_check
|
| 16 |
+
task2: + range_check + fk_check
|
| 17 |
+
task3: + outlier_check + duplicate_check
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import math
|
| 21 |
+
import random
|
| 22 |
+
import sqlite3
|
| 23 |
+
import uuid
|
| 24 |
+
from dataclasses import dataclass, field
|
| 25 |
+
from typing import Any, Optional
|
| 26 |
+
|
| 27 |
+
# ---------------------------------------------------------------------------
|
| 28 |
+
# Constants
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
|
| 31 |
+
SENTINEL_UNKNOWN = "__UNKNOWN__"
|
| 32 |
+
|
| 33 |
+
MINIMUM_ISSUES: dict[str, int] = {
|
| 34 |
+
"task1_null_and_types": 3,
|
| 35 |
+
"task2_constraints_and_fk": 5,
|
| 36 |
+
"task3_full_audit_with_trap": 7,
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
# Which checks run per task
|
| 40 |
+
TASK_CHECKS: dict[str, list[str]] = {
|
| 41 |
+
"task1_null_and_types": ["null", "type_error"],
|
| 42 |
+
"task2_constraints_and_fk": ["null", "type_error", "constraint", "fk_violation",
|
| 43 |
+
"whitespace", "inconsistent_category"],
|
| 44 |
+
"task3_full_audit_with_trap": ["null", "type_error", "constraint",
|
| 45 |
+
"fk_violation", "outlier", "duplicate",
|
| 46 |
+
"whitespace", "inconsistent_category"],
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
OUTLIER_Z_THRESHOLD = 5.0
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# ---------------------------------------------------------------------------
|
| 53 |
+
# Data classes
|
| 54 |
+
# ---------------------------------------------------------------------------
|
| 55 |
+
|
| 56 |
+
@dataclass
|
| 57 |
+
class Issue:
|
| 58 |
+
issue_id: str
|
| 59 |
+
issue_type: str # null|type_error|constraint|outlier|duplicate|fk_violation
|
| 60 |
+
table: str
|
| 61 |
+
row_id: int
|
| 62 |
+
column: Optional[str]
|
| 63 |
+
correct: Any # corrected value, None (delete), or SENTINEL_UNKNOWN
|
| 64 |
+
confidence: float # 0.0 – 1.0
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@dataclass
|
| 68 |
+
class Trap:
|
| 69 |
+
table: str
|
| 70 |
+
row_id: int
|
| 71 |
+
column: str
|
| 72 |
+
trap_value: float # 2 × original (written into the DB)
|
| 73 |
+
original: float # what we changed from
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# ---------------------------------------------------------------------------
|
| 77 |
+
# Public API
|
| 78 |
+
# ---------------------------------------------------------------------------
|
| 79 |
+
|
| 80 |
+
def detect_issues(
|
| 81 |
+
conn: sqlite3.Connection,
|
| 82 |
+
profile: dict[str, dict],
|
| 83 |
+
records: list[dict],
|
| 84 |
+
task_id: str,
|
| 85 |
+
seed: int = 42,
|
| 86 |
+
) -> list[Issue]:
|
| 87 |
+
"""Detect real data-quality issues then apply synthetic top-up if needed.
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
conn: Live SQLite connection (used for FK cross-table checks).
|
| 91 |
+
profile: Column profiles from schema_profiler.profile_table().
|
| 92 |
+
records: List of row dicts for the primary table.
|
| 93 |
+
task_id: One of the three task identifiers.
|
| 94 |
+
seed: RNG seed for reproducible synthetic top-up.
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
List of Issue objects. The agent NEVER sees this list directly.
|
| 98 |
+
"""
|
| 99 |
+
checks = TASK_CHECKS.get(task_id, ["null", "type_error"])
|
| 100 |
+
rng = random.Random(seed)
|
| 101 |
+
|
| 102 |
+
pk_col = _find_pk_col(records)
|
| 103 |
+
issues: list[Issue] = []
|
| 104 |
+
seen: set[str] = set() # deduplicate by (row_id, column, type)
|
| 105 |
+
|
| 106 |
+
def _add(issue: Issue) -> None:
|
| 107 |
+
key = f"{issue.row_id}_{issue.column}_{issue.issue_type}"
|
| 108 |
+
if key not in seen:
|
| 109 |
+
seen.add(key)
|
| 110 |
+
issues.append(issue)
|
| 111 |
+
|
| 112 |
+
# --- Real detection passes ---
|
| 113 |
+
if "null" in checks:
|
| 114 |
+
for iss in _detect_nulls(records, profile, pk_col):
|
| 115 |
+
_add(iss)
|
| 116 |
+
|
| 117 |
+
if "type_error" in checks:
|
| 118 |
+
for iss in _detect_type_errors(records, profile, pk_col):
|
| 119 |
+
_add(iss)
|
| 120 |
+
|
| 121 |
+
if "constraint" in checks:
|
| 122 |
+
for iss in _detect_constraints(records, profile, pk_col):
|
| 123 |
+
_add(iss)
|
| 124 |
+
|
| 125 |
+
if "outlier" in checks:
|
| 126 |
+
for iss in _detect_outliers(records, profile, pk_col):
|
| 127 |
+
_add(iss)
|
| 128 |
+
|
| 129 |
+
if "duplicate" in checks:
|
| 130 |
+
for iss in _detect_duplicates(records, profile, pk_col):
|
| 131 |
+
_add(iss)
|
| 132 |
+
|
| 133 |
+
if "fk_violation" in checks:
|
| 134 |
+
table_names = [
|
| 135 |
+
row[0]
|
| 136 |
+
for row in conn.execute(
|
| 137 |
+
"SELECT name FROM sqlite_master WHERE type='table'"
|
| 138 |
+
).fetchall()
|
| 139 |
+
]
|
| 140 |
+
if len(table_names) >= 2:
|
| 141 |
+
primary = table_names[0]
|
| 142 |
+
for iss in _detect_fk_violations(conn, records, profile, pk_col, primary, table_names[1:]):
|
| 143 |
+
_add(iss)
|
| 144 |
+
|
| 145 |
+
if "whitespace" in checks:
|
| 146 |
+
for iss in _detect_whitespace(records, profile, pk_col):
|
| 147 |
+
_add(iss)
|
| 148 |
+
|
| 149 |
+
if "inconsistent_category" in checks:
|
| 150 |
+
for iss in _detect_inconsistent_categories(records, profile, pk_col):
|
| 151 |
+
_add(iss)
|
| 152 |
+
|
| 153 |
+
# --- Synthetic top-up ---
|
| 154 |
+
minimum = MINIMUM_ISSUES.get(task_id, 3)
|
| 155 |
+
if len(issues) < minimum:
|
| 156 |
+
synthetic = _plant_synthetic_topup(
|
| 157 |
+
records, profile, pk_col, issues, checks,
|
| 158 |
+
needed=minimum - len(issues), rng=rng,
|
| 159 |
+
)
|
| 160 |
+
issues.extend(synthetic)
|
| 161 |
+
|
| 162 |
+
return issues
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def detect_trap(
|
| 166 |
+
conn: sqlite3.Connection,
|
| 167 |
+
profile: dict[str, dict],
|
| 168 |
+
records: list[dict],
|
| 169 |
+
issue_registry: list[Issue],
|
| 170 |
+
seed: int = 42,
|
| 171 |
+
) -> Optional[Trap]:
|
| 172 |
+
"""Plant a statistical trap for task3.
|
| 173 |
+
|
| 174 |
+
Finds the highest-variance numeric column not involved in any registered
|
| 175 |
+
issue, picks a row also not in the registry, sets its value to 2×original,
|
| 176 |
+
and writes the change into SQLite.
|
| 177 |
+
|
| 178 |
+
The Trap is NEVER added to issue_registry. Touching it costs -0.40.
|
| 179 |
+
|
| 180 |
+
Returns None if no suitable column/row exists.
|
| 181 |
+
"""
|
| 182 |
+
rng = random.Random(seed + 1)
|
| 183 |
+
|
| 184 |
+
if not records:
|
| 185 |
+
return None
|
| 186 |
+
|
| 187 |
+
pk_col = _find_pk_col(records)
|
| 188 |
+
issue_cells: set[tuple[int, str]] = {
|
| 189 |
+
(i.row_id, i.column) for i in issue_registry if i.column
|
| 190 |
+
}
|
| 191 |
+
issue_rows: set[int] = {i.row_id for i in issue_registry}
|
| 192 |
+
|
| 193 |
+
# Find highest-variance numeric column with at least one eligible row.
|
| 194 |
+
# We no longer exclude entire columns based on issue_columns — a column can
|
| 195 |
+
# have one issue row (e.g. fare outlier at row 5) while still having many
|
| 196 |
+
# clean rows available for the trap (e.g. fare at row 2).
|
| 197 |
+
# We only exclude specific (row_id, col) cells via eligible_rows below.
|
| 198 |
+
numeric_cols = [
|
| 199 |
+
col for col, p in profile.items()
|
| 200 |
+
if p["dtype"] in ("int", "float")
|
| 201 |
+
and p["std"] is not None
|
| 202 |
+
and p["std"] > 0
|
| 203 |
+
and col != pk_col
|
| 204 |
+
and col != "_source_format"
|
| 205 |
+
]
|
| 206 |
+
|
| 207 |
+
# Prefer columns NOT in any issue for a cleaner trap, but fall back to any
|
| 208 |
+
issue_columns: set[str] = {i.column for i in issue_registry if i.column}
|
| 209 |
+
candidates = [c for c in numeric_cols if c not in issue_columns]
|
| 210 |
+
if not candidates:
|
| 211 |
+
candidates = numeric_cols # fall back: use any numeric col with eligible rows
|
| 212 |
+
|
| 213 |
+
if not candidates:
|
| 214 |
+
return None
|
| 215 |
+
|
| 216 |
+
# Highest variance column
|
| 217 |
+
target_col = max(candidates, key=lambda c: profile[c]["std"] or 0.0)
|
| 218 |
+
|
| 219 |
+
# Find a row not in issue_rows with a valid numeric value
|
| 220 |
+
eligible_rows = [
|
| 221 |
+
row for row in records
|
| 222 |
+
if row.get(pk_col) is not None
|
| 223 |
+
and int(row[pk_col]) not in issue_rows
|
| 224 |
+
and not _is_null(row.get(target_col))
|
| 225 |
+
]
|
| 226 |
+
if not eligible_rows:
|
| 227 |
+
return None
|
| 228 |
+
|
| 229 |
+
# Pick a row away from the extremes (avoid naturally high z-score rows)
|
| 230 |
+
col_mean = profile[target_col]["mean"] or 0.0
|
| 231 |
+
col_std = profile[target_col]["std"] or 1.0
|
| 232 |
+
safe_rows = [
|
| 233 |
+
r for r in eligible_rows
|
| 234 |
+
if abs((float(r[target_col]) - col_mean) / col_std) < 2.0
|
| 235 |
+
]
|
| 236 |
+
chosen_row = rng.choice(safe_rows if safe_rows else eligible_rows)
|
| 237 |
+
rid = int(chosen_row[pk_col])
|
| 238 |
+
original_val = float(chosen_row[target_col])
|
| 239 |
+
trap_val = round(original_val * 2.0, 2)
|
| 240 |
+
|
| 241 |
+
# Write trap value into SQLite
|
| 242 |
+
primary_table = _primary_table_name(conn)
|
| 243 |
+
if primary_table:
|
| 244 |
+
conn.execute(
|
| 245 |
+
f'UPDATE "{primary_table}" SET "{target_col}" = ? WHERE "{pk_col}" = ?',
|
| 246 |
+
(trap_val, rid),
|
| 247 |
+
)
|
| 248 |
+
conn.commit()
|
| 249 |
+
|
| 250 |
+
return Trap(
|
| 251 |
+
table=primary_table or "dataset",
|
| 252 |
+
row_id=rid,
|
| 253 |
+
column=target_col,
|
| 254 |
+
trap_value=trap_val,
|
| 255 |
+
original=original_val,
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# ---------------------------------------------------------------------------
|
| 260 |
+
# Detection helpers
|
| 261 |
+
# ---------------------------------------------------------------------------
|
| 262 |
+
|
| 263 |
+
def _detect_nulls(
|
| 264 |
+
records: list[dict],
|
| 265 |
+
profile: dict[str, dict],
|
| 266 |
+
pk_col: str,
|
| 267 |
+
) -> list[Issue]:
|
| 268 |
+
issues = []
|
| 269 |
+
for col, p in profile.items():
|
| 270 |
+
if col == pk_col or col == "_source_format":
|
| 271 |
+
continue
|
| 272 |
+
null_rate = p["null_rate"]
|
| 273 |
+
for row in records:
|
| 274 |
+
val = row.get(col)
|
| 275 |
+
if not _is_null(val):
|
| 276 |
+
continue
|
| 277 |
+
rid = int(row[pk_col])
|
| 278 |
+
# Confidence inversely proportional to null rate
|
| 279 |
+
# High null rate (structural, like Cabin) → low confidence
|
| 280 |
+
confidence = max(0.0, 1.0 - null_rate)
|
| 281 |
+
correct = _infer_correct_null(col, row, records, p)
|
| 282 |
+
issues.append(Issue(
|
| 283 |
+
issue_id=_make_id(p["table"], rid, col, "null"),
|
| 284 |
+
issue_type="null",
|
| 285 |
+
table=p["table"],
|
| 286 |
+
row_id=rid,
|
| 287 |
+
column=col,
|
| 288 |
+
correct=correct,
|
| 289 |
+
confidence=round(confidence, 4),
|
| 290 |
+
))
|
| 291 |
+
return issues
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def _detect_type_errors(
|
| 295 |
+
records: list[dict],
|
| 296 |
+
profile: dict[str, dict],
|
| 297 |
+
pk_col: str,
|
| 298 |
+
) -> list[Issue]:
|
| 299 |
+
issues = []
|
| 300 |
+
for col, p in profile.items():
|
| 301 |
+
if col == pk_col or col == "_source_format":
|
| 302 |
+
continue
|
| 303 |
+
# Also check "unknown"/"str" dtype columns: when data is loaded from CSV via
|
| 304 |
+
# SQLite, all values come back as strings. A column like age that has "25",
|
| 305 |
+
# "FORTY", "-5" has dtype="str" but is a numeric column with a type error.
|
| 306 |
+
if p["dtype"] not in ("int", "float", "unknown", "str"):
|
| 307 |
+
continue
|
| 308 |
+
if p["dtype"] in ("unknown", "str"):
|
| 309 |
+
# Only flag type errors if the column is PREDOMINANTLY numeric (>=80%).
|
| 310 |
+
# A column like Ticket with 40% numeric and 60% alphanumeric is genuinely
|
| 311 |
+
# a string column — not a numeric column with type errors.
|
| 312 |
+
non_null_vals = [r.get(col) for r in records if not _is_null(r.get(col))]
|
| 313 |
+
if not non_null_vals:
|
| 314 |
+
continue
|
| 315 |
+
castable_count = sum(1 for v in non_null_vals if _can_cast_float(v))
|
| 316 |
+
if castable_count / len(non_null_vals) < 0.80:
|
| 317 |
+
continue # column is genuinely string or mixed — not type errors
|
| 318 |
+
col_median = _median([
|
| 319 |
+
float(r[col]) for r in records
|
| 320 |
+
if not _is_null(r.get(col)) and _can_cast_float(r.get(col))
|
| 321 |
+
])
|
| 322 |
+
for row in records:
|
| 323 |
+
val = row.get(col)
|
| 324 |
+
if _is_null(val):
|
| 325 |
+
continue
|
| 326 |
+
if not _can_cast_float(val):
|
| 327 |
+
rid = int(row[pk_col])
|
| 328 |
+
issues.append(Issue(
|
| 329 |
+
issue_id=_make_id(p["table"], rid, col, "type_error"),
|
| 330 |
+
issue_type="type_error",
|
| 331 |
+
table=p["table"],
|
| 332 |
+
row_id=rid,
|
| 333 |
+
column=col,
|
| 334 |
+
correct=col_median,
|
| 335 |
+
confidence=1.0,
|
| 336 |
+
))
|
| 337 |
+
return issues
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def _detect_constraints(
|
| 341 |
+
records: list[dict],
|
| 342 |
+
profile: dict[str, dict],
|
| 343 |
+
pk_col: str,
|
| 344 |
+
) -> list[Issue]:
|
| 345 |
+
"""Flag negative values in columns that must be positive."""
|
| 346 |
+
issues = []
|
| 347 |
+
for col, p in profile.items():
|
| 348 |
+
if col == pk_col or col == "_source_format":
|
| 349 |
+
continue
|
| 350 |
+
# must_be_positive is only set for int/float dtype.
|
| 351 |
+
# For "unknown" dtype columns (mixed type due to a type error), infer
|
| 352 |
+
# must_be_positive from the castable values: if >= 75% are non-negative,
|
| 353 |
+
# a negative value is a constraint violation.
|
| 354 |
+
is_must_positive = p["must_be_positive"]
|
| 355 |
+
if not is_must_positive and p["dtype"] in ("unknown", "str"):
|
| 356 |
+
# For string/mixed-type columns (e.g. age stored as TEXT in SQLite),
|
| 357 |
+
# infer must_be_positive from the castable values.
|
| 358 |
+
castable = [
|
| 359 |
+
float(r.get(col)) for r in records
|
| 360 |
+
if not _is_null(r.get(col)) and _can_cast_float(r.get(col))
|
| 361 |
+
]
|
| 362 |
+
if castable and sum(v >= 0 for v in castable) / len(castable) >= 0.75:
|
| 363 |
+
is_must_positive = True
|
| 364 |
+
if not is_must_positive:
|
| 365 |
+
continue
|
| 366 |
+
for row in records:
|
| 367 |
+
val = row.get(col)
|
| 368 |
+
if _is_null(val):
|
| 369 |
+
continue
|
| 370 |
+
try:
|
| 371 |
+
fval = float(val)
|
| 372 |
+
except (ValueError, TypeError):
|
| 373 |
+
continue
|
| 374 |
+
if fval < 0:
|
| 375 |
+
rid = int(row[pk_col])
|
| 376 |
+
issues.append(Issue(
|
| 377 |
+
issue_id=_make_id(p["table"], rid, col, "constraint"),
|
| 378 |
+
issue_type="constraint",
|
| 379 |
+
table=p["table"],
|
| 380 |
+
row_id=rid,
|
| 381 |
+
column=col,
|
| 382 |
+
correct=abs(fval),
|
| 383 |
+
confidence=0.95,
|
| 384 |
+
))
|
| 385 |
+
return issues
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def _detect_outliers(
|
| 389 |
+
records: list[dict],
|
| 390 |
+
profile: dict[str, dict],
|
| 391 |
+
pk_col: str,
|
| 392 |
+
) -> list[Issue]:
|
| 393 |
+
"""Detect outliers using IQR method (robust to outlier-inflated std).
|
| 394 |
+
|
| 395 |
+
Standard z-score fails on small datasets because the outlier inflates the
|
| 396 |
+
mean and std, masking itself. IQR is resistant to this masking effect.
|
| 397 |
+
Threshold: value outside Q1 - 3*IQR or Q3 + 3*IQR (stricter than 1.5× Tukey).
|
| 398 |
+
"""
|
| 399 |
+
issues = []
|
| 400 |
+
for col, p in profile.items():
|
| 401 |
+
if col == pk_col or col == "_source_format":
|
| 402 |
+
continue
|
| 403 |
+
if p["dtype"] not in ("int", "float"):
|
| 404 |
+
continue
|
| 405 |
+
|
| 406 |
+
# Collect castable numeric values for this column
|
| 407 |
+
numeric_rows: list[tuple[int, float]] = []
|
| 408 |
+
for row in records:
|
| 409 |
+
val = row.get(col)
|
| 410 |
+
if _is_null(val):
|
| 411 |
+
continue
|
| 412 |
+
try:
|
| 413 |
+
numeric_rows.append((int(row[pk_col]), float(val)))
|
| 414 |
+
except (ValueError, TypeError):
|
| 415 |
+
continue
|
| 416 |
+
|
| 417 |
+
if len(numeric_rows) < 4:
|
| 418 |
+
continue
|
| 419 |
+
|
| 420 |
+
values = sorted(v for _, v in numeric_rows)
|
| 421 |
+
n = len(values)
|
| 422 |
+
q1 = values[n // 4]
|
| 423 |
+
q3 = values[(3 * n) // 4]
|
| 424 |
+
iqr = q3 - q1
|
| 425 |
+
if iqr == 0:
|
| 426 |
+
continue
|
| 427 |
+
|
| 428 |
+
lower_fence = q1 - 3.0 * iqr
|
| 429 |
+
upper_fence = q3 + 3.0 * iqr
|
| 430 |
+
col_median = values[n // 2]
|
| 431 |
+
|
| 432 |
+
for rid, fval in numeric_rows:
|
| 433 |
+
if fval < lower_fence or fval > upper_fence:
|
| 434 |
+
# Use IQR-based score for confidence
|
| 435 |
+
distance = max(fval - upper_fence, lower_fence - fval)
|
| 436 |
+
confidence = min(0.99, round(0.60 + distance / (iqr * 10.0 + 1e-9), 4))
|
| 437 |
+
issues.append(Issue(
|
| 438 |
+
issue_id=_make_id(p["table"], rid, col, "outlier"),
|
| 439 |
+
issue_type="outlier",
|
| 440 |
+
table=p["table"],
|
| 441 |
+
row_id=rid,
|
| 442 |
+
column=col,
|
| 443 |
+
correct=round(col_median, 4),
|
| 444 |
+
confidence=round(confidence, 4),
|
| 445 |
+
))
|
| 446 |
+
return issues
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def _detect_duplicates(
|
| 450 |
+
records: list[dict],
|
| 451 |
+
profile: dict[str, dict],
|
| 452 |
+
pk_col: str,
|
| 453 |
+
) -> list[Issue]:
|
| 454 |
+
natural_key = _find_natural_key_col(profile, records, pk_col)
|
| 455 |
+
if natural_key is None:
|
| 456 |
+
return []
|
| 457 |
+
|
| 458 |
+
seen: dict[str, int] = {} # value → first row_id
|
| 459 |
+
issues = []
|
| 460 |
+
table = profile[pk_col]["table"] if pk_col in profile else "dataset"
|
| 461 |
+
|
| 462 |
+
for row in records:
|
| 463 |
+
val = row.get(natural_key)
|
| 464 |
+
if _is_null(val):
|
| 465 |
+
continue
|
| 466 |
+
key_str = str(val).strip().lower()
|
| 467 |
+
rid = int(row[pk_col])
|
| 468 |
+
if key_str in seen:
|
| 469 |
+
# Later insertion is the duplicate
|
| 470 |
+
issues.append(Issue(
|
| 471 |
+
issue_id=_make_id(table, rid, natural_key, "duplicate"),
|
| 472 |
+
issue_type="duplicate",
|
| 473 |
+
table=table,
|
| 474 |
+
row_id=rid,
|
| 475 |
+
column=natural_key,
|
| 476 |
+
correct=None, # should be deleted
|
| 477 |
+
confidence=1.0,
|
| 478 |
+
))
|
| 479 |
+
else:
|
| 480 |
+
seen[key_str] = rid
|
| 481 |
+
|
| 482 |
+
return issues
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def _detect_fk_violations(
|
| 486 |
+
conn: sqlite3.Connection,
|
| 487 |
+
records: list[dict],
|
| 488 |
+
profile: dict[str, dict],
|
| 489 |
+
pk_col: str,
|
| 490 |
+
primary_table: str,
|
| 491 |
+
other_tables: list[str],
|
| 492 |
+
) -> list[Issue]:
|
| 493 |
+
issues = []
|
| 494 |
+
|
| 495 |
+
# Find FK-like columns: name ends with _id but is not the PK
|
| 496 |
+
fk_cols = [
|
| 497 |
+
col for col in profile
|
| 498 |
+
if col.lower().endswith("_id")
|
| 499 |
+
and col != pk_col
|
| 500 |
+
and col != "_source_format"
|
| 501 |
+
]
|
| 502 |
+
|
| 503 |
+
for fk_col in fk_cols:
|
| 504 |
+
# Guess the referenced table by stripping _id
|
| 505 |
+
ref_name = fk_col[:-3] # e.g. "passenger_id" → "passenger"
|
| 506 |
+
ref_table = None
|
| 507 |
+
for tbl in other_tables:
|
| 508 |
+
if tbl.lower().startswith(ref_name.lower()) or ref_name.lower() in tbl.lower():
|
| 509 |
+
ref_table = tbl
|
| 510 |
+
break
|
| 511 |
+
if ref_table is None and other_tables:
|
| 512 |
+
ref_table = other_tables[0]
|
| 513 |
+
if ref_table is None:
|
| 514 |
+
continue
|
| 515 |
+
|
| 516 |
+
# Fetch valid FK values from referenced table
|
| 517 |
+
try:
|
| 518 |
+
ref_rows = conn.execute(f'SELECT * FROM "{ref_table}" LIMIT 1000').fetchall()
|
| 519 |
+
ref_desc = conn.execute(f'PRAGMA table_info("{ref_table}")').fetchall()
|
| 520 |
+
ref_pk_idx = 0 # first column
|
| 521 |
+
valid_ids = {str(r[ref_pk_idx]) for r in ref_rows}
|
| 522 |
+
except Exception:
|
| 523 |
+
continue
|
| 524 |
+
|
| 525 |
+
table = profile[pk_col]["table"] if pk_col in profile else primary_table
|
| 526 |
+
for row in records:
|
| 527 |
+
val = row.get(fk_col)
|
| 528 |
+
if _is_null(val):
|
| 529 |
+
continue
|
| 530 |
+
if str(val) not in valid_ids:
|
| 531 |
+
rid = int(row[pk_col])
|
| 532 |
+
issues.append(Issue(
|
| 533 |
+
issue_id=_make_id(table, rid, fk_col, "fk_violation"),
|
| 534 |
+
issue_type="fk_violation",
|
| 535 |
+
table=table,
|
| 536 |
+
row_id=rid,
|
| 537 |
+
column=fk_col,
|
| 538 |
+
correct=None, # orphan row — should be deleted
|
| 539 |
+
confidence=0.90,
|
| 540 |
+
))
|
| 541 |
+
|
| 542 |
+
return issues
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
# ---------------------------------------------------------------------------
|
| 546 |
+
# Whitespace / formatting issues
|
| 547 |
+
# ---------------------------------------------------------------------------
|
| 548 |
+
|
| 549 |
+
def _detect_whitespace(
|
| 550 |
+
records: list[dict],
|
| 551 |
+
profile: dict[str, dict],
|
| 552 |
+
pk_col: str,
|
| 553 |
+
) -> list[Issue]:
|
| 554 |
+
"""Flag strings with leading/trailing whitespace or excessive internal spaces."""
|
| 555 |
+
issues = []
|
| 556 |
+
for col, p in profile.items():
|
| 557 |
+
if col == pk_col or col == "_source_format":
|
| 558 |
+
continue
|
| 559 |
+
if p["dtype"] not in ("str", "unknown"):
|
| 560 |
+
continue
|
| 561 |
+
table = p.get("table", "dataset")
|
| 562 |
+
for row in records:
|
| 563 |
+
val = row.get(col)
|
| 564 |
+
if _is_null(val) or not isinstance(val, str):
|
| 565 |
+
continue
|
| 566 |
+
cleaned = " ".join(val.split()) # normalize whitespace
|
| 567 |
+
if cleaned != val:
|
| 568 |
+
rid = int(row[pk_col])
|
| 569 |
+
issues.append(Issue(
|
| 570 |
+
issue_id=_make_id(table, rid, col, "whitespace"),
|
| 571 |
+
issue_type="whitespace",
|
| 572 |
+
table=table,
|
| 573 |
+
row_id=rid,
|
| 574 |
+
column=col,
|
| 575 |
+
correct=cleaned,
|
| 576 |
+
confidence=0.90,
|
| 577 |
+
))
|
| 578 |
+
return issues
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
# ---------------------------------------------------------------------------
|
| 582 |
+
# Inconsistent categories (e.g. "F"/"Female"/"female" → "Female")
|
| 583 |
+
# ---------------------------------------------------------------------------
|
| 584 |
+
|
| 585 |
+
def _detect_inconsistent_categories(
|
| 586 |
+
records: list[dict],
|
| 587 |
+
profile: dict[str, dict],
|
| 588 |
+
pk_col: str,
|
| 589 |
+
) -> list[Issue]:
|
| 590 |
+
"""Flag values that are case-variants or abbreviations of the dominant category.
|
| 591 |
+
|
| 592 |
+
Example: column Sex has {"male": 40, "Male": 2, "MALE": 1} → "Male" and "MALE"
|
| 593 |
+
should be normalized to "male" (the dominant form).
|
| 594 |
+
"""
|
| 595 |
+
issues = []
|
| 596 |
+
for col, p in profile.items():
|
| 597 |
+
if col == pk_col or col == "_source_format":
|
| 598 |
+
continue
|
| 599 |
+
if p["dtype"] not in ("str", "unknown"):
|
| 600 |
+
continue
|
| 601 |
+
# Only check low-cardinality columns (likely categorical)
|
| 602 |
+
unique = p.get("unique_count", 0)
|
| 603 |
+
row_count = p.get("row_count", 0)
|
| 604 |
+
if unique == 0 or row_count == 0 or unique > 20:
|
| 605 |
+
continue # too many unique values — not categorical
|
| 606 |
+
|
| 607 |
+
# Group values by lowercase form
|
| 608 |
+
from collections import Counter
|
| 609 |
+
val_counts: Counter = Counter()
|
| 610 |
+
original_forms: dict[str, list[str]] = {} # lowercase → [original forms]
|
| 611 |
+
for row in records:
|
| 612 |
+
val = row.get(col)
|
| 613 |
+
if _is_null(val) or not isinstance(val, str):
|
| 614 |
+
continue
|
| 615 |
+
val_stripped = val.strip()
|
| 616 |
+
lower = val_stripped.lower()
|
| 617 |
+
val_counts[lower] += 1
|
| 618 |
+
if lower not in original_forms:
|
| 619 |
+
original_forms[lower] = []
|
| 620 |
+
if val_stripped not in original_forms[lower]:
|
| 621 |
+
original_forms[lower].append(val_stripped)
|
| 622 |
+
|
| 623 |
+
# Find groups with multiple surface forms
|
| 624 |
+
table = p.get("table", "dataset")
|
| 625 |
+
for lower_key, forms in original_forms.items():
|
| 626 |
+
if len(forms) <= 1:
|
| 627 |
+
continue
|
| 628 |
+
# Dominant form: most common original casing
|
| 629 |
+
form_counts = Counter()
|
| 630 |
+
for row in records:
|
| 631 |
+
val = row.get(col)
|
| 632 |
+
if isinstance(val, str) and val.strip().lower() == lower_key:
|
| 633 |
+
form_counts[val.strip()] += 1
|
| 634 |
+
dominant = form_counts.most_common(1)[0][0]
|
| 635 |
+
|
| 636 |
+
# Flag non-dominant forms
|
| 637 |
+
for row in records:
|
| 638 |
+
val = row.get(col)
|
| 639 |
+
if not isinstance(val, str):
|
| 640 |
+
continue
|
| 641 |
+
stripped = val.strip()
|
| 642 |
+
if stripped.lower() == lower_key and stripped != dominant:
|
| 643 |
+
rid = int(row[pk_col])
|
| 644 |
+
issues.append(Issue(
|
| 645 |
+
issue_id=_make_id(table, rid, col, "inconsistent_category"),
|
| 646 |
+
issue_type="inconsistent_category",
|
| 647 |
+
table=table,
|
| 648 |
+
row_id=rid,
|
| 649 |
+
column=col,
|
| 650 |
+
correct=dominant,
|
| 651 |
+
confidence=0.85,
|
| 652 |
+
))
|
| 653 |
+
return issues
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
# ---------------------------------------------------------------------------
|
| 657 |
+
# Synthetic top-up
|
| 658 |
+
# ---------------------------------------------------------------------------
|
| 659 |
+
|
| 660 |
+
def _plant_synthetic_topup(
|
| 661 |
+
records: list[dict],
|
| 662 |
+
profile: dict[str, dict],
|
| 663 |
+
pk_col: str,
|
| 664 |
+
existing: list[Issue],
|
| 665 |
+
allowed_checks: list[str],
|
| 666 |
+
needed: int,
|
| 667 |
+
rng: random.Random,
|
| 668 |
+
) -> list[Issue]:
|
| 669 |
+
"""Plant statistically valid synthetic issues when real count < minimum.
|
| 670 |
+
|
| 671 |
+
Never touches: PK column, natural-key column, columns already in existing.
|
| 672 |
+
"""
|
| 673 |
+
synthetic: list[Issue] = []
|
| 674 |
+
touched_cells: set[tuple[int, str]] = {(i.row_id, i.column) for i in existing if i.column}
|
| 675 |
+
natural_key = _find_natural_key_col(profile, records, pk_col)
|
| 676 |
+
|
| 677 |
+
# Columns available for synthetic planting
|
| 678 |
+
def available_cols(dtype_filter=None) -> list[str]:
|
| 679 |
+
cols = []
|
| 680 |
+
for col, p in profile.items():
|
| 681 |
+
if col == pk_col or col == "_source_format":
|
| 682 |
+
continue
|
| 683 |
+
if col == natural_key:
|
| 684 |
+
continue
|
| 685 |
+
if dtype_filter and p["dtype"] not in dtype_filter:
|
| 686 |
+
continue
|
| 687 |
+
cols.append(col)
|
| 688 |
+
return cols
|
| 689 |
+
|
| 690 |
+
table = profile[pk_col]["table"] if pk_col in profile else "dataset"
|
| 691 |
+
|
| 692 |
+
# Candidate issue types to synthesise (ordered by preference)
|
| 693 |
+
type_order = []
|
| 694 |
+
if "null" in allowed_checks:
|
| 695 |
+
type_order.append("null")
|
| 696 |
+
if "type_error" in allowed_checks:
|
| 697 |
+
type_order.append("type_error")
|
| 698 |
+
if "constraint" in allowed_checks:
|
| 699 |
+
type_order.append("constraint")
|
| 700 |
+
|
| 701 |
+
planted = 0
|
| 702 |
+
attempt = 0
|
| 703 |
+
max_attempts = needed * 20
|
| 704 |
+
|
| 705 |
+
while planted < needed and attempt < max_attempts:
|
| 706 |
+
attempt += 1
|
| 707 |
+
issue_type = type_order[planted % len(type_order)]
|
| 708 |
+
|
| 709 |
+
if issue_type == "null":
|
| 710 |
+
cols = available_cols()
|
| 711 |
+
if not cols:
|
| 712 |
+
continue
|
| 713 |
+
col = rng.choice(cols)
|
| 714 |
+
eligible = [
|
| 715 |
+
r for r in records
|
| 716 |
+
if not _is_null(r.get(col))
|
| 717 |
+
and (int(r[pk_col]), col) not in touched_cells
|
| 718 |
+
]
|
| 719 |
+
if not eligible:
|
| 720 |
+
continue
|
| 721 |
+
row = rng.choice(eligible)
|
| 722 |
+
rid = int(row[pk_col])
|
| 723 |
+
original = row[col]
|
| 724 |
+
# Plant NULL in the live records
|
| 725 |
+
row[col] = None
|
| 726 |
+
touched_cells.add((rid, col))
|
| 727 |
+
synthetic.append(Issue(
|
| 728 |
+
issue_id=_make_id(table, rid, col, "null"),
|
| 729 |
+
issue_type="null",
|
| 730 |
+
table=table,
|
| 731 |
+
row_id=rid,
|
| 732 |
+
column=col,
|
| 733 |
+
correct=original,
|
| 734 |
+
confidence=0.95,
|
| 735 |
+
))
|
| 736 |
+
planted += 1
|
| 737 |
+
|
| 738 |
+
elif issue_type == "type_error":
|
| 739 |
+
cols = available_cols(dtype_filter=("int", "float"))
|
| 740 |
+
if not cols:
|
| 741 |
+
continue
|
| 742 |
+
col = rng.choice(cols)
|
| 743 |
+
eligible = [
|
| 744 |
+
r for r in records
|
| 745 |
+
if not _is_null(r.get(col))
|
| 746 |
+
and _can_cast_float(r.get(col))
|
| 747 |
+
and (int(r[pk_col]), col) not in touched_cells
|
| 748 |
+
]
|
| 749 |
+
if not eligible:
|
| 750 |
+
continue
|
| 751 |
+
row = rng.choice(eligible)
|
| 752 |
+
rid = int(row[pk_col])
|
| 753 |
+
# Plant "INVALID_TEXT" in the live records
|
| 754 |
+
row[col] = "INVALID_TEXT"
|
| 755 |
+
col_median = _median([
|
| 756 |
+
float(r[col]) for r in records
|
| 757 |
+
if not _is_null(r.get(col)) and _can_cast_float(r.get(col))
|
| 758 |
+
])
|
| 759 |
+
touched_cells.add((rid, col))
|
| 760 |
+
synthetic.append(Issue(
|
| 761 |
+
issue_id=_make_id(table, rid, col, "type_error"),
|
| 762 |
+
issue_type="type_error",
|
| 763 |
+
table=table,
|
| 764 |
+
row_id=rid,
|
| 765 |
+
column=col,
|
| 766 |
+
correct=col_median,
|
| 767 |
+
confidence=1.0,
|
| 768 |
+
))
|
| 769 |
+
planted += 1
|
| 770 |
+
|
| 771 |
+
elif issue_type == "constraint":
|
| 772 |
+
cols = [
|
| 773 |
+
col for col in available_cols(dtype_filter=("int", "float"))
|
| 774 |
+
if profile[col].get("must_be_positive", False)
|
| 775 |
+
]
|
| 776 |
+
if not cols:
|
| 777 |
+
# Fall back to any positive-valued numeric col
|
| 778 |
+
cols = [
|
| 779 |
+
col for col in available_cols(dtype_filter=("int", "float"))
|
| 780 |
+
if profile[col].get("min", 0) is not None
|
| 781 |
+
and (profile[col].get("min") or 0) > 0
|
| 782 |
+
]
|
| 783 |
+
if not cols:
|
| 784 |
+
continue
|
| 785 |
+
col = rng.choice(cols)
|
| 786 |
+
eligible = [
|
| 787 |
+
r for r in records
|
| 788 |
+
if not _is_null(r.get(col))
|
| 789 |
+
and _can_cast_float(r.get(col))
|
| 790 |
+
and float(r.get(col, 0)) > 0
|
| 791 |
+
and (int(r[pk_col]), col) not in touched_cells
|
| 792 |
+
]
|
| 793 |
+
if not eligible:
|
| 794 |
+
continue
|
| 795 |
+
row = rng.choice(eligible)
|
| 796 |
+
rid = int(row[pk_col])
|
| 797 |
+
original = float(row[col])
|
| 798 |
+
row[col] = -abs(original)
|
| 799 |
+
touched_cells.add((rid, col))
|
| 800 |
+
synthetic.append(Issue(
|
| 801 |
+
issue_id=_make_id(table, rid, col, "constraint"),
|
| 802 |
+
issue_type="constraint",
|
| 803 |
+
table=table,
|
| 804 |
+
row_id=rid,
|
| 805 |
+
column=col,
|
| 806 |
+
correct=original,
|
| 807 |
+
confidence=0.95,
|
| 808 |
+
))
|
| 809 |
+
planted += 1
|
| 810 |
+
|
| 811 |
+
return synthetic
|
| 812 |
+
|
| 813 |
+
|
| 814 |
+
# ---------------------------------------------------------------------------
|
| 815 |
+
# Utility helpers
|
| 816 |
+
# ---------------------------------------------------------------------------
|
| 817 |
+
|
| 818 |
+
def _find_pk_col(records: list[dict]) -> str:
|
| 819 |
+
"""Return the primary key column name from records.
|
| 820 |
+
|
| 821 |
+
Looks for 'id' column first, then falls back to first column.
|
| 822 |
+
"""
|
| 823 |
+
if not records:
|
| 824 |
+
return "id"
|
| 825 |
+
keys = list(records[0].keys())
|
| 826 |
+
# Prefer explicit 'id' column
|
| 827 |
+
for k in keys:
|
| 828 |
+
if k.lower() == "id":
|
| 829 |
+
return k
|
| 830 |
+
# Fall back to first column
|
| 831 |
+
return keys[0]
|
| 832 |
+
|
| 833 |
+
|
| 834 |
+
def _find_natural_key_col(
|
| 835 |
+
profile: dict[str, dict],
|
| 836 |
+
records: list[dict],
|
| 837 |
+
pk_col: str,
|
| 838 |
+
) -> Optional[str]:
|
| 839 |
+
"""Return the natural key column if one exists, else None.
|
| 840 |
+
|
| 841 |
+
Natural key: high uniqueness (>= 70%), not float dtype, not PK,
|
| 842 |
+
name contains: name, email, code, ref, id_, key, title.
|
| 843 |
+
|
| 844 |
+
Uses 70% threshold (not strict all_unique) so that dirty datasets with
|
| 845 |
+
a small number of duplicates still have their natural key identified.
|
| 846 |
+
"""
|
| 847 |
+
KEY_HINTS = ("name", "email", "code", "ref", "id_", "key", "title")
|
| 848 |
+
for col, p in profile.items():
|
| 849 |
+
if col == pk_col or col == "_source_format":
|
| 850 |
+
continue
|
| 851 |
+
if p["dtype"] == "float":
|
| 852 |
+
continue
|
| 853 |
+
row_count = p.get("row_count", 0)
|
| 854 |
+
unique_count = p.get("unique_count", 0)
|
| 855 |
+
if row_count == 0:
|
| 856 |
+
continue
|
| 857 |
+
uniqueness_ratio = unique_count / row_count
|
| 858 |
+
if uniqueness_ratio < 0.70:
|
| 859 |
+
continue
|
| 860 |
+
col_lower = col.lower()
|
| 861 |
+
if any(hint in col_lower for hint in KEY_HINTS):
|
| 862 |
+
return col
|
| 863 |
+
return None
|
| 864 |
+
|
| 865 |
+
|
| 866 |
+
def _infer_correct_null(
|
| 867 |
+
col: str,
|
| 868 |
+
row: dict,
|
| 869 |
+
records: list[dict],
|
| 870 |
+
p: dict,
|
| 871 |
+
) -> Any:
|
| 872 |
+
"""Best-guess correct value for a null cell."""
|
| 873 |
+
if p["dtype"] in ("int", "float"):
|
| 874 |
+
non_null = [
|
| 875 |
+
float(r[col]) for r in records
|
| 876 |
+
if not _is_null(r.get(col)) and _can_cast_float(r.get(col))
|
| 877 |
+
]
|
| 878 |
+
if non_null:
|
| 879 |
+
return round(_median(non_null), 4)
|
| 880 |
+
return SENTINEL_UNKNOWN
|
| 881 |
+
|
| 882 |
+
|
| 883 |
+
def _median(values: list[float]) -> Optional[float]:
|
| 884 |
+
if not values:
|
| 885 |
+
return None
|
| 886 |
+
s = sorted(values)
|
| 887 |
+
n = len(s)
|
| 888 |
+
mid = n // 2
|
| 889 |
+
if n % 2 == 0:
|
| 890 |
+
return (s[mid - 1] + s[mid]) / 2.0
|
| 891 |
+
return s[mid]
|
| 892 |
+
|
| 893 |
+
|
| 894 |
+
def _can_cast_float(value: Any) -> bool:
|
| 895 |
+
try:
|
| 896 |
+
float(str(value))
|
| 897 |
+
return True
|
| 898 |
+
except (ValueError, TypeError):
|
| 899 |
+
return False
|
| 900 |
+
|
| 901 |
+
|
| 902 |
+
def _is_null(value: Any) -> bool:
|
| 903 |
+
if value is None:
|
| 904 |
+
return True
|
| 905 |
+
if isinstance(value, float) and math.isnan(value):
|
| 906 |
+
return True
|
| 907 |
+
if isinstance(value, str) and value.strip() == "":
|
| 908 |
+
return True
|
| 909 |
+
return False
|
| 910 |
+
|
| 911 |
+
|
| 912 |
+
def _make_id(table: str, row_id: int, col: Optional[str], issue_type: str) -> str:
|
| 913 |
+
return f"{table}_{row_id}_{col or 'row'}_{issue_type}"
|
| 914 |
+
|
| 915 |
+
|
| 916 |
+
def _primary_table_name(conn: sqlite3.Connection) -> Optional[str]:
|
| 917 |
+
rows = conn.execute(
|
| 918 |
+
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY rowid"
|
| 919 |
+
).fetchall()
|
| 920 |
+
return rows[0][0] if rows else None
|
sqlsherlock_env/server/requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.115.0
|
| 2 |
+
uvicorn[standard]>=0.30.0
|
| 3 |
+
pydantic>=2.8.2
|
| 4 |
+
openenv-core>=0.2.1
|
| 5 |
+
openai>=1.40.0
|
| 6 |
+
python-multipart>=0.0.9
|
| 7 |
+
datasets>=2.20.0
|
| 8 |
+
pandas>=2.0.0
|
| 9 |
+
pyarrow>=14.0.0
|
sqlsherlock_env/server/reward.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Reward calculator for SQLSherlock-Env.
|
| 9 |
+
|
| 10 |
+
Dense per-step rewards with hard caps on investigation bonuses.
|
| 11 |
+
Every action produces a reward signal so the RL agent gets
|
| 12 |
+
continuous feedback throughout the episode.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
from typing import Any, Optional
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# ---------------------------------------------------------------------------
|
| 21 |
+
# Per-action reward magnitudes
|
| 22 |
+
# ---------------------------------------------------------------------------
|
| 23 |
+
|
| 24 |
+
INVEST_REWARDS: dict[str, float] = {
|
| 25 |
+
"inspect": 0.02,
|
| 26 |
+
"profile_column": 0.03,
|
| 27 |
+
"run_sql": 0.03,
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
INVEST_CAPS: dict[str, int] = {
|
| 31 |
+
"inspect": 3,
|
| 32 |
+
"profile_column": 3,
|
| 33 |
+
"run_sql": 3,
|
| 34 |
+
"validate": 2,
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
FIX_CORRECT: float = 0.15
|
| 38 |
+
FIX_FALSE_POSITIVE: float = -0.20
|
| 39 |
+
FIX_TRAP: float = -0.40
|
| 40 |
+
FIX_WRONG_VALUE: float = -0.10
|
| 41 |
+
|
| 42 |
+
DELETE_CORRECT: float = 0.15
|
| 43 |
+
DELETE_FALSE_POSITIVE: float = -0.20
|
| 44 |
+
|
| 45 |
+
SUBMIT_ALL_RESOLVED: float = 0.10
|
| 46 |
+
SUBMIT_ISSUES_OPEN: float = -0.10
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# ---------------------------------------------------------------------------
|
| 50 |
+
# InvestCounter — tracks capped investigation calls
|
| 51 |
+
# ---------------------------------------------------------------------------
|
| 52 |
+
|
| 53 |
+
class InvestCounter:
|
| 54 |
+
"""Tracks how many times each investigation action has been called.
|
| 55 |
+
|
| 56 |
+
Once an action type hits its cap, further calls still execute
|
| 57 |
+
but return 0 reward (no error raised).
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
def __init__(self) -> None:
|
| 61 |
+
self._counts: dict[str, int] = {k: 0 for k in INVEST_CAPS}
|
| 62 |
+
|
| 63 |
+
def record(self, action_type: str) -> float:
|
| 64 |
+
"""Record one call of *action_type* and return the reward earned.
|
| 65 |
+
|
| 66 |
+
Returns 0.0 if the cap has already been reached.
|
| 67 |
+
Always increments the counter so validate_reward() can detect over-cap.
|
| 68 |
+
"""
|
| 69 |
+
if action_type not in INVEST_CAPS:
|
| 70 |
+
return 0.0
|
| 71 |
+
|
| 72 |
+
cap = INVEST_CAPS[action_type]
|
| 73 |
+
current = self._counts.get(action_type, 0)
|
| 74 |
+
|
| 75 |
+
# Always increment so validate_reward() can detect over-cap correctly.
|
| 76 |
+
self._counts[action_type] = current + 1
|
| 77 |
+
|
| 78 |
+
if current >= cap:
|
| 79 |
+
return 0.0 # cap already hit before this call
|
| 80 |
+
|
| 81 |
+
if action_type == "validate":
|
| 82 |
+
# Reward computed externally (depends on checks_passed)
|
| 83 |
+
return 0.0 # caller computes and adds the validate reward
|
| 84 |
+
|
| 85 |
+
return INVEST_REWARDS.get(action_type, 0.0)
|
| 86 |
+
|
| 87 |
+
def validate_reward(self, checks_passed: int, total_checks: int) -> float:
|
| 88 |
+
"""Return the validate reward if under cap, else 0.0.
|
| 89 |
+
|
| 90 |
+
Must be called AFTER record("validate") so the count is incremented.
|
| 91 |
+
"""
|
| 92 |
+
count = self._counts.get("validate", 0)
|
| 93 |
+
if count > INVEST_CAPS["validate"]: # count already incremented by record()
|
| 94 |
+
return 0.0
|
| 95 |
+
# count == cap means this IS the last rewarded call (e.g. cap=2, count=2 → reward)
|
| 96 |
+
# count > cap means over the limit → 0 (checked above)
|
| 97 |
+
if total_checks == 0:
|
| 98 |
+
return 0.0
|
| 99 |
+
return round(0.05 * (checks_passed / total_checks), 4)
|
| 100 |
+
|
| 101 |
+
def count(self, action_type: str) -> int:
|
| 102 |
+
return self._counts.get(action_type, 0)
|
| 103 |
+
|
| 104 |
+
def to_dict(self) -> dict:
|
| 105 |
+
return dict(self._counts)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# ---------------------------------------------------------------------------
|
| 109 |
+
# RB — per-step reward breakdown
|
| 110 |
+
# ---------------------------------------------------------------------------
|
| 111 |
+
|
| 112 |
+
@dataclass
|
| 113 |
+
class RB:
|
| 114 |
+
"""Reward breakdown for one step.
|
| 115 |
+
|
| 116 |
+
Stored in reward_trace every step so judges (and the agent) can
|
| 117 |
+
see exactly how reward was composed.
|
| 118 |
+
"""
|
| 119 |
+
invest: float = 0.0 # investigation bonus
|
| 120 |
+
fix_delta: float = 0.0 # fix / delete reward (positive or negative)
|
| 121 |
+
validate_b: float = 0.0 # validate bonus
|
| 122 |
+
penalty: float = 0.0 # trap / fp / submit penalties (stored negative)
|
| 123 |
+
|
| 124 |
+
@property
|
| 125 |
+
def total(self) -> float:
|
| 126 |
+
raw = self.invest + self.fix_delta + self.validate_b + self.penalty
|
| 127 |
+
return max(-1.0, min(1.0, round(raw, 4)))
|
| 128 |
+
|
| 129 |
+
def to_dict(self) -> dict:
|
| 130 |
+
return {
|
| 131 |
+
"invest": round(self.invest, 4),
|
| 132 |
+
"fix_delta": round(self.fix_delta, 4),
|
| 133 |
+
"validate_b": round(self.validate_b, 4),
|
| 134 |
+
"penalty": round(self.penalty, 4),
|
| 135 |
+
"total": self.total,
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# ---------------------------------------------------------------------------
|
| 140 |
+
# calc — main reward function called from environment.py
|
| 141 |
+
# ---------------------------------------------------------------------------
|
| 142 |
+
|
| 143 |
+
def calc(
|
| 144 |
+
action_type: str,
|
| 145 |
+
db: Any, # DatabaseEngine (typed loosely to avoid circular)
|
| 146 |
+
counter: InvestCounter,
|
| 147 |
+
action: Any, # SQLSherlockAction
|
| 148 |
+
validation_result: Optional[Any] = None, # ValidationResult | None
|
| 149 |
+
) -> RB:
|
| 150 |
+
"""Compute per-step reward for one action.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
action_type: The action type string.
|
| 154 |
+
db: Live DatabaseEngine instance.
|
| 155 |
+
counter: Shared InvestCounter for this episode.
|
| 156 |
+
action: The SQLSherlockAction taken.
|
| 157 |
+
validation_result: Result from Validator.validate() if action_type=="validate".
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
RB breakdown. Caller adds rb.to_dict() to reward_trace.
|
| 161 |
+
"""
|
| 162 |
+
rb = RB()
|
| 163 |
+
|
| 164 |
+
# ------------------------------------------------------------------
|
| 165 |
+
# Investigation actions
|
| 166 |
+
# ------------------------------------------------------------------
|
| 167 |
+
if action_type in ("inspect", "profile_column", "run_sql"):
|
| 168 |
+
rb.invest = counter.record(action_type)
|
| 169 |
+
return rb
|
| 170 |
+
|
| 171 |
+
# ------------------------------------------------------------------
|
| 172 |
+
# Validate
|
| 173 |
+
# ------------------------------------------------------------------
|
| 174 |
+
if action_type == "validate":
|
| 175 |
+
counter.record("validate") # increment count (may be over cap)
|
| 176 |
+
if validation_result is not None:
|
| 177 |
+
rb.validate_b = counter.validate_reward(
|
| 178 |
+
validation_result.checks_passed,
|
| 179 |
+
validation_result.total_checks,
|
| 180 |
+
)
|
| 181 |
+
return rb
|
| 182 |
+
|
| 183 |
+
# ------------------------------------------------------------------
|
| 184 |
+
# fix_cell
|
| 185 |
+
# ------------------------------------------------------------------
|
| 186 |
+
if action_type == "fix_cell":
|
| 187 |
+
table = action.table or db.primary_table
|
| 188 |
+
row_id = action.row_id
|
| 189 |
+
column = action.column
|
| 190 |
+
|
| 191 |
+
if row_id is None or column is None:
|
| 192 |
+
rb.penalty = FIX_FALSE_POSITIVE
|
| 193 |
+
return rb
|
| 194 |
+
|
| 195 |
+
# Trap check (task3 only — highest priority)
|
| 196 |
+
trap = db.trap
|
| 197 |
+
if trap and trap.row_id == row_id and trap.column == column:
|
| 198 |
+
rb.penalty = FIX_TRAP
|
| 199 |
+
return rb
|
| 200 |
+
|
| 201 |
+
# Is this cell in the issue registry?
|
| 202 |
+
issue_match = _find_issue(db, row_id, column)
|
| 203 |
+
|
| 204 |
+
if issue_match is None:
|
| 205 |
+
# Not a known issue — check if we changed a clean original cell
|
| 206 |
+
orig = _original_val(db, table, row_id, column)
|
| 207 |
+
current_val = action.value
|
| 208 |
+
if orig is not None and not _values_match(current_val, orig):
|
| 209 |
+
rb.penalty = FIX_FALSE_POSITIVE
|
| 210 |
+
# If we can't find original (row may not exist), small FP penalty
|
| 211 |
+
elif orig is None:
|
| 212 |
+
rb.penalty = FIX_FALSE_POSITIVE
|
| 213 |
+
return rb
|
| 214 |
+
|
| 215 |
+
# Issue exists — check if the fix actually resolves it
|
| 216 |
+
if _fix_resolves(issue_match, action.value, db):
|
| 217 |
+
rb.fix_delta = FIX_CORRECT
|
| 218 |
+
else:
|
| 219 |
+
rb.fix_delta = FIX_WRONG_VALUE
|
| 220 |
+
|
| 221 |
+
return rb
|
| 222 |
+
|
| 223 |
+
# ------------------------------------------------------------------
|
| 224 |
+
# delete_row
|
| 225 |
+
# ------------------------------------------------------------------
|
| 226 |
+
if action_type == "delete_row":
|
| 227 |
+
table = action.table or db.primary_table
|
| 228 |
+
row_id = action.row_id
|
| 229 |
+
|
| 230 |
+
if row_id is None:
|
| 231 |
+
rb.penalty = DELETE_FALSE_POSITIVE
|
| 232 |
+
return rb
|
| 233 |
+
|
| 234 |
+
# Valid delete: row must be a duplicate or fk_violation issue
|
| 235 |
+
valid_issue = any(
|
| 236 |
+
iss.row_id == row_id and iss.issue_type in ("duplicate", "fk_violation")
|
| 237 |
+
for iss in db.issue_registry
|
| 238 |
+
)
|
| 239 |
+
if valid_issue:
|
| 240 |
+
rb.fix_delta = DELETE_CORRECT
|
| 241 |
+
else:
|
| 242 |
+
rb.penalty = DELETE_FALSE_POSITIVE
|
| 243 |
+
|
| 244 |
+
return rb
|
| 245 |
+
|
| 246 |
+
# ------------------------------------------------------------------
|
| 247 |
+
# fix_column (bulk fix)
|
| 248 |
+
# ------------------------------------------------------------------
|
| 249 |
+
if action_type == "fix_column":
|
| 250 |
+
column = action.column
|
| 251 |
+
if column is None:
|
| 252 |
+
rb.penalty = FIX_FALSE_POSITIVE
|
| 253 |
+
return rb
|
| 254 |
+
|
| 255 |
+
# Count how many registered issues in this column were null-type
|
| 256 |
+
column_issues = [
|
| 257 |
+
iss for iss in db.issue_registry
|
| 258 |
+
if iss.column == column and iss.issue_type in ("null", "type_error", "whitespace")
|
| 259 |
+
]
|
| 260 |
+
if column_issues:
|
| 261 |
+
# Reward proportional to issues resolved (capped at +0.15)
|
| 262 |
+
resolved_fraction = min(len(column_issues) / max(db.total_issues, 1), 1.0)
|
| 263 |
+
rb.fix_delta = round(FIX_CORRECT * (1.0 + resolved_fraction), 4) # +0.15 to +0.30
|
| 264 |
+
else:
|
| 265 |
+
# No registered issues in this column — possible false positive
|
| 266 |
+
rb.penalty = FIX_FALSE_POSITIVE * 0.5 # lighter penalty for bulk ops
|
| 267 |
+
return rb
|
| 268 |
+
|
| 269 |
+
# ------------------------------------------------------------------
|
| 270 |
+
# submit
|
| 271 |
+
# ------------------------------------------------------------------
|
| 272 |
+
if action_type == "submit":
|
| 273 |
+
if db.issues_remaining() == 0:
|
| 274 |
+
rb.fix_delta = SUBMIT_ALL_RESOLVED
|
| 275 |
+
else:
|
| 276 |
+
rb.penalty = SUBMIT_ISSUES_OPEN
|
| 277 |
+
return rb
|
| 278 |
+
|
| 279 |
+
# ------------------------------------------------------------------
|
| 280 |
+
# export (no direct step reward; grader scores the file)
|
| 281 |
+
# ------------------------------------------------------------------
|
| 282 |
+
if action_type == "export":
|
| 283 |
+
return rb
|
| 284 |
+
|
| 285 |
+
return rb
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
# ---------------------------------------------------------------------------
|
| 289 |
+
# Helpers
|
| 290 |
+
# ---------------------------------------------------------------------------
|
| 291 |
+
|
| 292 |
+
def _find_issue(db: Any, row_id: int, column: str):
|
| 293 |
+
"""Return the matching Issue from the registry using O(1) dict lookup.
|
| 294 |
+
|
| 295 |
+
The issue index is lazily built and cached on the db object.
|
| 296 |
+
"""
|
| 297 |
+
if not hasattr(db, "_issue_index"):
|
| 298 |
+
db._issue_index = {
|
| 299 |
+
(iss.row_id, iss.column): iss
|
| 300 |
+
for iss in db.issue_registry
|
| 301 |
+
if iss.column is not None
|
| 302 |
+
}
|
| 303 |
+
return db._issue_index.get((row_id, column))
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def _original_val(db: Any, table: str, row_id: int, column: str) -> Any:
|
| 307 |
+
"""Return the original (pre-episode) value for a cell using O(1) dict lookup.
|
| 308 |
+
|
| 309 |
+
The originals index is lazily built and cached on the db object.
|
| 310 |
+
"""
|
| 311 |
+
cache_key = f"_orig_index_{table}"
|
| 312 |
+
if not hasattr(db, cache_key):
|
| 313 |
+
originals = db._originals.get(table, [])
|
| 314 |
+
pk = db.pk_col
|
| 315 |
+
setattr(db, cache_key, {row.get(pk): row for row in originals})
|
| 316 |
+
orig_map = getattr(db, cache_key)
|
| 317 |
+
row = orig_map.get(row_id)
|
| 318 |
+
return row.get(column) if row is not None else None
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def _fix_resolves(issue: Any, new_value: Any, db: Any) -> bool:
|
| 322 |
+
"""Return True if *new_value* resolves *issue*."""
|
| 323 |
+
from server.issue_detector import SENTINEL_UNKNOWN
|
| 324 |
+
|
| 325 |
+
itype = issue.issue_type
|
| 326 |
+
|
| 327 |
+
if itype == "null":
|
| 328 |
+
if _is_null(new_value):
|
| 329 |
+
return False
|
| 330 |
+
if issue.correct == SENTINEL_UNKNOWN:
|
| 331 |
+
return True # any non-null value accepted
|
| 332 |
+
# Accept the fix if the value matches OR is the same type.
|
| 333 |
+
# For numeric nulls: any valid numeric value is a reasonable fix
|
| 334 |
+
# (the agent imputes from column statistics, not from our stored correct).
|
| 335 |
+
if _values_match(new_value, issue.correct):
|
| 336 |
+
return True
|
| 337 |
+
# Type-compatible acceptance: if correct is numeric, accept any numeric
|
| 338 |
+
if _can_cast_float(issue.correct) and _can_cast_float(new_value):
|
| 339 |
+
return True
|
| 340 |
+
# If correct is string, accept any non-null string
|
| 341 |
+
if isinstance(issue.correct, str) and isinstance(new_value, str):
|
| 342 |
+
return True
|
| 343 |
+
return False
|
| 344 |
+
|
| 345 |
+
if itype == "type_error":
|
| 346 |
+
return _can_cast_float(new_value)
|
| 347 |
+
|
| 348 |
+
if itype == "constraint":
|
| 349 |
+
try:
|
| 350 |
+
return float(str(new_value)) >= 0
|
| 351 |
+
except (ValueError, TypeError):
|
| 352 |
+
return False
|
| 353 |
+
|
| 354 |
+
if itype == "outlier":
|
| 355 |
+
# Resolves if new z-score <= 3
|
| 356 |
+
profile = db._profiles.get(db.primary_table, {})
|
| 357 |
+
p = profile.get(issue.column, {})
|
| 358 |
+
mean = p.get("mean")
|
| 359 |
+
std = p.get("std")
|
| 360 |
+
if mean is None or not std or std == 0:
|
| 361 |
+
return True # can't compute z — assume resolved
|
| 362 |
+
try:
|
| 363 |
+
z = abs(float(str(new_value)) - mean) / std
|
| 364 |
+
return z <= 3.0
|
| 365 |
+
except (ValueError, TypeError):
|
| 366 |
+
return False
|
| 367 |
+
|
| 368 |
+
if itype == "whitespace":
|
| 369 |
+
# Resolved if the new value has no leading/trailing/excessive whitespace
|
| 370 |
+
if _is_null(new_value):
|
| 371 |
+
return False
|
| 372 |
+
s = str(new_value)
|
| 373 |
+
return s == " ".join(s.split())
|
| 374 |
+
|
| 375 |
+
if itype == "inconsistent_category":
|
| 376 |
+
# Resolved if new value matches the correct (dominant) form
|
| 377 |
+
if _is_null(new_value):
|
| 378 |
+
return False
|
| 379 |
+
return _values_match(new_value, issue.correct)
|
| 380 |
+
|
| 381 |
+
return False
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def _values_match(a: Any, b: Any) -> bool:
|
| 385 |
+
"""Loose equality: handles numeric vs string comparisons."""
|
| 386 |
+
if a is None and b is None:
|
| 387 |
+
return True
|
| 388 |
+
if a is None or b is None:
|
| 389 |
+
return False
|
| 390 |
+
try:
|
| 391 |
+
return math.isclose(float(str(a)), float(str(b)), rel_tol=1e-4)
|
| 392 |
+
except (ValueError, TypeError):
|
| 393 |
+
return str(a).strip().lower() == str(b).strip().lower()
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def _is_null(value: Any) -> bool:
|
| 397 |
+
if value is None:
|
| 398 |
+
return True
|
| 399 |
+
if isinstance(value, float) and math.isnan(value):
|
| 400 |
+
return True
|
| 401 |
+
if isinstance(value, str) and value.strip() == "":
|
| 402 |
+
return True
|
| 403 |
+
return False
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def _can_cast_float(value: Any) -> bool:
|
| 407 |
+
try:
|
| 408 |
+
float(str(value))
|
| 409 |
+
return True
|
| 410 |
+
except (ValueError, TypeError):
|
| 411 |
+
return False
|
sqlsherlock_env/server/schema_profiler.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Schema profiler for SQLSherlock-Env.
|
| 9 |
+
|
| 10 |
+
Computes per-column statistical profiles from raw records.
|
| 11 |
+
Used by DatabaseEngine at load time and by issue_detector / validator.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import math
|
| 15 |
+
import sqlite3
|
| 16 |
+
from typing import Any, Optional
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
# Public API
|
| 21 |
+
# ---------------------------------------------------------------------------
|
| 22 |
+
|
| 23 |
+
def profile_table(
|
| 24 |
+
table: str,
|
| 25 |
+
records: list[dict],
|
| 26 |
+
conn: Optional[sqlite3.Connection] = None,
|
| 27 |
+
) -> dict[str, dict]:
|
| 28 |
+
"""Return a statistical profile for every column in *records*.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
table: Table name (stored in the profile for reference).
|
| 32 |
+
records: List of row dicts (already coerced to Python types).
|
| 33 |
+
conn: Optional SQLite connection (unused currently; reserved for
|
| 34 |
+
future SQL-based profiling).
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
Dict keyed by column name. Each value is a column-profile dict::
|
| 38 |
+
|
| 39 |
+
{
|
| 40 |
+
"table": str,
|
| 41 |
+
"column": str,
|
| 42 |
+
"dtype": "int" | "float" | "str" | "bool" | "unknown",
|
| 43 |
+
"row_count": int,
|
| 44 |
+
"null_count": int,
|
| 45 |
+
"null_rate": float, # 0.0 – 1.0
|
| 46 |
+
"unique_count": int,
|
| 47 |
+
"all_unique": bool,
|
| 48 |
+
"mean": float | None, # numeric only
|
| 49 |
+
"std": float | None, # numeric only
|
| 50 |
+
"min": float | None, # numeric only
|
| 51 |
+
"max": float | None, # numeric only
|
| 52 |
+
"must_be_positive": bool, # numeric only
|
| 53 |
+
"z_scores": dict[int, float], # row_id → z
|
| 54 |
+
"sample_values": list[Any], # up to 5 non-null values
|
| 55 |
+
}
|
| 56 |
+
"""
|
| 57 |
+
if not records:
|
| 58 |
+
return {}
|
| 59 |
+
|
| 60 |
+
columns = list(records[0].keys())
|
| 61 |
+
profile: dict[str, dict] = {}
|
| 62 |
+
|
| 63 |
+
for col in columns:
|
| 64 |
+
values = [row.get(col) for row in records]
|
| 65 |
+
col_profile = _profile_column(table, col, values, records)
|
| 66 |
+
profile[col] = col_profile
|
| 67 |
+
|
| 68 |
+
return profile
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _profile_column(
|
| 72 |
+
table: str,
|
| 73 |
+
col: str,
|
| 74 |
+
values: list[Any],
|
| 75 |
+
records: list[dict],
|
| 76 |
+
) -> dict:
|
| 77 |
+
"""Compute statistics for a single column."""
|
| 78 |
+
row_count = len(values)
|
| 79 |
+
null_count = sum(1 for v in values if _is_null(v))
|
| 80 |
+
null_rate = null_count / row_count if row_count > 0 else 0.0
|
| 81 |
+
|
| 82 |
+
non_null = [v for v in values if not _is_null(v)]
|
| 83 |
+
unique_count = len(set(str(v) for v in non_null))
|
| 84 |
+
# all_unique: every non-null value is distinct AND covers all rows
|
| 85 |
+
# Compare against row_count so that a column with 1 null among unique values
|
| 86 |
+
# is NOT considered all-unique (the null breaks the uniqueness guarantee)
|
| 87 |
+
all_unique = (unique_count == row_count) and row_count > 0 and null_count == 0
|
| 88 |
+
|
| 89 |
+
dtype = _infer_dtype(non_null)
|
| 90 |
+
|
| 91 |
+
# Numeric statistics
|
| 92 |
+
mean = std = mn = mx = None
|
| 93 |
+
must_be_positive = False
|
| 94 |
+
z_scores: dict[int, float] = {}
|
| 95 |
+
|
| 96 |
+
if dtype in ("int", "float") and non_null:
|
| 97 |
+
numeric_vals = []
|
| 98 |
+
for v in non_null:
|
| 99 |
+
try:
|
| 100 |
+
numeric_vals.append(float(v))
|
| 101 |
+
except (ValueError, TypeError):
|
| 102 |
+
pass
|
| 103 |
+
|
| 104 |
+
if numeric_vals:
|
| 105 |
+
mean = sum(numeric_vals) / len(numeric_vals)
|
| 106 |
+
variance = sum((x - mean) ** 2 for x in numeric_vals) / len(numeric_vals)
|
| 107 |
+
std = math.sqrt(variance)
|
| 108 |
+
mn = min(numeric_vals)
|
| 109 |
+
mx = max(numeric_vals)
|
| 110 |
+
|
| 111 |
+
# must_be_positive: all non-null values are >= 0 and at least one > 0
|
| 112 |
+
# Handles columns like age/fare that should never be negative
|
| 113 |
+
must_be_positive = len(numeric_vals) > 0 and all(v >= 0 for v in numeric_vals) and any(v > 0 for v in numeric_vals)
|
| 114 |
+
|
| 115 |
+
# z-scores per row keyed by primary key value
|
| 116 |
+
# Use find_primary_key() for accuracy; fall back to first column
|
| 117 |
+
pk_col = find_primary_key(records) if records else None
|
| 118 |
+
if pk_col is None and records:
|
| 119 |
+
pk_col = list(records[0].keys())[0]
|
| 120 |
+
for row in records:
|
| 121 |
+
raw = row.get(col)
|
| 122 |
+
if _is_null(raw):
|
| 123 |
+
continue
|
| 124 |
+
try:
|
| 125 |
+
fval = float(raw)
|
| 126 |
+
except (ValueError, TypeError):
|
| 127 |
+
continue
|
| 128 |
+
rid = row.get(pk_col) if pk_col else None
|
| 129 |
+
if rid is not None and std > 0:
|
| 130 |
+
z = (fval - mean) / std
|
| 131 |
+
z_scores[int(rid)] = round(z, 4)
|
| 132 |
+
elif rid is not None:
|
| 133 |
+
z_scores[int(rid)] = 0.0
|
| 134 |
+
|
| 135 |
+
# Sample values: up to 5 non-null
|
| 136 |
+
sample_values = non_null[:5]
|
| 137 |
+
|
| 138 |
+
return {
|
| 139 |
+
"table": table,
|
| 140 |
+
"column": col,
|
| 141 |
+
"dtype": dtype,
|
| 142 |
+
"row_count": row_count,
|
| 143 |
+
"null_count": null_count,
|
| 144 |
+
"null_rate": round(null_rate, 4),
|
| 145 |
+
"unique_count": unique_count,
|
| 146 |
+
"all_unique": all_unique,
|
| 147 |
+
"mean": round(mean, 6) if mean is not None else None,
|
| 148 |
+
"std": round(std, 6) if std is not None else None,
|
| 149 |
+
"min": mn,
|
| 150 |
+
"max": mx,
|
| 151 |
+
"must_be_positive": must_be_positive,
|
| 152 |
+
"z_scores": z_scores,
|
| 153 |
+
"sample_values": sample_values,
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# ---------------------------------------------------------------------------
|
| 158 |
+
# Helpers
|
| 159 |
+
# ---------------------------------------------------------------------------
|
| 160 |
+
|
| 161 |
+
def _is_null(value: Any) -> bool:
|
| 162 |
+
"""Return True if *value* represents a missing / null entry."""
|
| 163 |
+
if value is None:
|
| 164 |
+
return True
|
| 165 |
+
if isinstance(value, float) and math.isnan(value):
|
| 166 |
+
return True
|
| 167 |
+
if isinstance(value, str) and value.strip() == "":
|
| 168 |
+
return True
|
| 169 |
+
return False
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def _infer_dtype(non_null_values: list[Any]) -> str:
|
| 173 |
+
"""Infer column dtype from a list of non-null values.
|
| 174 |
+
|
| 175 |
+
Priority: bool > int > float > str > unknown.
|
| 176 |
+
"""
|
| 177 |
+
if not non_null_values:
|
| 178 |
+
return "unknown"
|
| 179 |
+
|
| 180 |
+
# Bool check first (Python bool is subclass of int)
|
| 181 |
+
if all(isinstance(v, bool) for v in non_null_values):
|
| 182 |
+
return "bool"
|
| 183 |
+
|
| 184 |
+
# Try int
|
| 185 |
+
int_ok = True
|
| 186 |
+
for v in non_null_values:
|
| 187 |
+
if isinstance(v, bool):
|
| 188 |
+
int_ok = False
|
| 189 |
+
break
|
| 190 |
+
if isinstance(v, int):
|
| 191 |
+
continue
|
| 192 |
+
try:
|
| 193 |
+
f = float(v)
|
| 194 |
+
if f != int(f):
|
| 195 |
+
int_ok = False
|
| 196 |
+
break
|
| 197 |
+
except (ValueError, TypeError):
|
| 198 |
+
int_ok = False
|
| 199 |
+
break
|
| 200 |
+
if int_ok:
|
| 201 |
+
return "int"
|
| 202 |
+
|
| 203 |
+
# Try float
|
| 204 |
+
float_ok = True
|
| 205 |
+
for v in non_null_values:
|
| 206 |
+
if isinstance(v, (int, float)) and not isinstance(v, bool):
|
| 207 |
+
continue
|
| 208 |
+
try:
|
| 209 |
+
float(v)
|
| 210 |
+
except (ValueError, TypeError):
|
| 211 |
+
float_ok = False
|
| 212 |
+
break
|
| 213 |
+
if float_ok:
|
| 214 |
+
return "float"
|
| 215 |
+
|
| 216 |
+
# Default to str
|
| 217 |
+
if all(isinstance(v, str) for v in non_null_values):
|
| 218 |
+
return "str"
|
| 219 |
+
|
| 220 |
+
return "unknown"
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def find_primary_key(records: list[dict]) -> Optional[str]:
|
| 224 |
+
"""Return the name of the primary-key column.
|
| 225 |
+
|
| 226 |
+
Convention: the first column whose name is 'id' or ends with '_id',
|
| 227 |
+
OR simply the first column if all values are unique integers.
|
| 228 |
+
Falls back to the first column name.
|
| 229 |
+
"""
|
| 230 |
+
if not records:
|
| 231 |
+
return None
|
| 232 |
+
|
| 233 |
+
columns = list(records[0].keys())
|
| 234 |
+
if not columns:
|
| 235 |
+
return None
|
| 236 |
+
|
| 237 |
+
# Explicit id column
|
| 238 |
+
for col in columns:
|
| 239 |
+
if col.lower() == "id" or col.lower().endswith("_id"):
|
| 240 |
+
vals = [row.get(col) for row in records]
|
| 241 |
+
if len(set(str(v) for v in vals)) == len(vals):
|
| 242 |
+
return col
|
| 243 |
+
|
| 244 |
+
# First column with all-unique integer-like values
|
| 245 |
+
first = columns[0]
|
| 246 |
+
vals = [row.get(first) for row in records]
|
| 247 |
+
try:
|
| 248 |
+
int_vals = [int(v) for v in vals if v is not None]
|
| 249 |
+
if len(int_vals) == len(records) and len(set(int_vals)) == len(int_vals):
|
| 250 |
+
return first
|
| 251 |
+
except (ValueError, TypeError):
|
| 252 |
+
pass
|
| 253 |
+
|
| 254 |
+
# Last resort: first column
|
| 255 |
+
return first
|
sqlsherlock_env/server/sqlsherlock_env_environment.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
MCP-enabled SQLSherlock environment.
|
| 9 |
+
|
| 10 |
+
Exposes all agent actions as MCP tools that any MCP-compatible LLM
|
| 11 |
+
(Claude, GPT, etc.) can discover and invoke dynamically via
|
| 12 |
+
ListToolsAction / CallToolAction.
|
| 13 |
+
|
| 14 |
+
This adds MCP tool discoverability on top of the existing WebSocket/HTTP API.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from typing import Any, Optional
|
| 18 |
+
|
| 19 |
+
from fastmcp import FastMCP
|
| 20 |
+
|
| 21 |
+
from openenv.core.env_server.mcp_environment import MCPEnvironment
|
| 22 |
+
from openenv.core.env_server.types import Action
|
| 23 |
+
|
| 24 |
+
from models import SQLSherlockAction, SQLSherlockObservation, SQLSherlockState
|
| 25 |
+
from server.environment import SQLSherlockEnvironment
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
# FastMCP server — data-quality investigation tools
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
|
| 32 |
+
mcp = FastMCP("sqlsherlock")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@mcp.tool()
|
| 36 |
+
def inspect_table(table: str) -> str:
|
| 37 |
+
"""View all rows in a database table.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
table: Name of the table to inspect (e.g. 'titanic').
|
| 41 |
+
"""
|
| 42 |
+
return f"inspect:{table}"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@mcp.tool()
|
| 46 |
+
def profile_column(table: str, column: str) -> str:
|
| 47 |
+
"""Get statistical profile: mean, std, min, max, null_count, z-scores.
|
| 48 |
+
|
| 49 |
+
IMPORTANT: Always call this BEFORE fixing any numeric value.
|
| 50 |
+
z > 5 = real outlier (fix it). z < 3 = normal (DO NOT touch).
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
table: Table name.
|
| 54 |
+
column: Column to profile.
|
| 55 |
+
"""
|
| 56 |
+
return f"profile:{table}:{column}"
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@mcp.tool()
|
| 60 |
+
def run_sql(sql: str) -> str:
|
| 61 |
+
"""Execute a read-only SELECT SQL query to investigate data quality.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
sql: A SELECT query string. No write operations allowed.
|
| 65 |
+
"""
|
| 66 |
+
return f"sql:{sql}"
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@mcp.tool()
|
| 70 |
+
def fix_cell(table: str, row_id: int, column: str, value: str, reason: str) -> str:
|
| 71 |
+
"""Fix a data quality issue in one cell.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
table: Table name.
|
| 75 |
+
row_id: Primary key of the row.
|
| 76 |
+
column: Column to fix.
|
| 77 |
+
value: Corrected value to write.
|
| 78 |
+
reason: Statistical justification (e.g. 'median=29.0, z-score=N/A').
|
| 79 |
+
"""
|
| 80 |
+
return f"fix:{table}:{row_id}:{column}:{value}"
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@mcp.tool()
|
| 84 |
+
def delete_row(table: str, row_id: int, reason: str) -> str:
|
| 85 |
+
"""Delete a duplicate or FK-violation row.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
table: Table name.
|
| 89 |
+
row_id: Primary key to delete.
|
| 90 |
+
reason: Why this row should be removed.
|
| 91 |
+
"""
|
| 92 |
+
return f"delete:{table}:{row_id}"
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@mcp.tool()
|
| 96 |
+
def validate_data() -> str:
|
| 97 |
+
"""Run all 6 validation checks comparing current vs raw baseline.
|
| 98 |
+
|
| 99 |
+
Returns pass/partial/fail for: null_check, type_check, range_check,
|
| 100 |
+
distribution_check, duplicate_check, outlier_check.
|
| 101 |
+
"""
|
| 102 |
+
return "validate"
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@mcp.tool()
|
| 106 |
+
def submit_investigation() -> str:
|
| 107 |
+
"""Submit the investigation for final scoring. Call after all fixes."""
|
| 108 |
+
return "submit"
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# ---------------------------------------------------------------------------
|
| 112 |
+
# MCP Environment class
|
| 113 |
+
# ---------------------------------------------------------------------------
|
| 114 |
+
|
| 115 |
+
class SQLSherlockMCPEnvironment(MCPEnvironment):
|
| 116 |
+
"""SQLSherlock environment with MCP tool discoverability.
|
| 117 |
+
|
| 118 |
+
Wraps SQLSherlockEnvironment and exposes all actions as MCP tools.
|
| 119 |
+
MCP agents call ListToolsAction to discover tools, then CallToolAction
|
| 120 |
+
to invoke them.
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
SUPPORTS_CONCURRENT_SESSIONS: bool = True
|
| 124 |
+
|
| 125 |
+
def __init__(self) -> None:
|
| 126 |
+
super().__init__(mcp_server=mcp)
|
| 127 |
+
self._env = SQLSherlockEnvironment()
|
| 128 |
+
|
| 129 |
+
@property
|
| 130 |
+
def state(self) -> SQLSherlockState:
|
| 131 |
+
return self._env.state
|
| 132 |
+
|
| 133 |
+
def reset(self, **kwargs) -> SQLSherlockObservation:
|
| 134 |
+
return self._env.reset(**kwargs)
|
| 135 |
+
|
| 136 |
+
def _step_impl(
|
| 137 |
+
self,
|
| 138 |
+
action: Action,
|
| 139 |
+
timeout_s: Optional[float] = None,
|
| 140 |
+
**kwargs: Any,
|
| 141 |
+
) -> SQLSherlockObservation:
|
| 142 |
+
"""Handle standard SQLSherlock actions (non-MCP)."""
|
| 143 |
+
if isinstance(action, SQLSherlockAction):
|
| 144 |
+
return self._env.step(action, **kwargs)
|
| 145 |
+
|
| 146 |
+
# Fallback: construct from dict
|
| 147 |
+
if hasattr(action, "model_dump"):
|
| 148 |
+
d = action.model_dump()
|
| 149 |
+
elif isinstance(action, dict):
|
| 150 |
+
d = action
|
| 151 |
+
else:
|
| 152 |
+
d = {"action_type": "inspect"}
|
| 153 |
+
|
| 154 |
+
sa = SQLSherlockAction(**{k: v for k, v in d.items() if v is not None})
|
| 155 |
+
return self._env.step(sa, **kwargs)
|
sqlsherlock_env/server/validator.py
ADDED
|
@@ -0,0 +1,545 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Validator for SQLSherlock-Env.
|
| 9 |
+
|
| 10 |
+
Runs 6 checks comparing the current dataset state against the baseline
|
| 11 |
+
captured at reset() time. Called by:
|
| 12 |
+
- DatabaseEngine.__init__() → stores baseline_metrics
|
| 13 |
+
- environment.py step() → on "validate" action
|
| 14 |
+
- graders/universal.py → final scoring pass
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
import sqlite3
|
| 19 |
+
from dataclasses import dataclass, field
|
| 20 |
+
from typing import Any, Optional
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# ---------------------------------------------------------------------------
|
| 24 |
+
# Result types
|
| 25 |
+
# ---------------------------------------------------------------------------
|
| 26 |
+
|
| 27 |
+
@dataclass
|
| 28 |
+
class CheckResult:
|
| 29 |
+
name: str
|
| 30 |
+
passed: bool
|
| 31 |
+
before: Any
|
| 32 |
+
after: Any
|
| 33 |
+
detail: str = ""
|
| 34 |
+
warnings: list[str] = field(default_factory=list)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass
|
| 38 |
+
class ValidationResult:
|
| 39 |
+
checks: dict[str, CheckResult]
|
| 40 |
+
checks_passed: int
|
| 41 |
+
total_checks: int
|
| 42 |
+
overall: str # "PASS" | "PARTIAL" | "FAIL"
|
| 43 |
+
warnings: list[str] # distribution drift warnings
|
| 44 |
+
|
| 45 |
+
def to_dict(self) -> dict:
|
| 46 |
+
return {
|
| 47 |
+
"checks": {
|
| 48 |
+
name: {
|
| 49 |
+
"passed": cr.passed,
|
| 50 |
+
"before": cr.before,
|
| 51 |
+
"after": cr.after,
|
| 52 |
+
"detail": cr.detail,
|
| 53 |
+
"warnings": cr.warnings,
|
| 54 |
+
}
|
| 55 |
+
for name, cr in self.checks.items()
|
| 56 |
+
},
|
| 57 |
+
"checks_passed": self.checks_passed,
|
| 58 |
+
"total_checks": self.total_checks,
|
| 59 |
+
"overall": self.overall,
|
| 60 |
+
"warnings": self.warnings,
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# ---------------------------------------------------------------------------
|
| 65 |
+
# Validator class
|
| 66 |
+
# ---------------------------------------------------------------------------
|
| 67 |
+
|
| 68 |
+
class Validator:
|
| 69 |
+
"""Stateful validator that stores baseline metrics at construction time.
|
| 70 |
+
|
| 71 |
+
Usage::
|
| 72 |
+
|
| 73 |
+
v = Validator(conn, profile, issue_registry)
|
| 74 |
+
# ... agent makes fixes ...
|
| 75 |
+
result = v.validate(conn, current_records)
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
def __init__(
|
| 79 |
+
self,
|
| 80 |
+
conn: sqlite3.Connection,
|
| 81 |
+
profile: dict[str, dict],
|
| 82 |
+
issue_registry: list, # list[Issue] — typed loosely to avoid circular import
|
| 83 |
+
) -> None:
|
| 84 |
+
self._profile = profile
|
| 85 |
+
self._issue_registry = issue_registry
|
| 86 |
+
self._baseline = self._scan_baseline(conn, profile, issue_registry)
|
| 87 |
+
|
| 88 |
+
# ------------------------------------------------------------------
|
| 89 |
+
# Public
|
| 90 |
+
# ------------------------------------------------------------------
|
| 91 |
+
|
| 92 |
+
def validate(
|
| 93 |
+
self,
|
| 94 |
+
conn: sqlite3.Connection,
|
| 95 |
+
current_records: list[dict],
|
| 96 |
+
touched_columns: Optional[set[str]] = None,
|
| 97 |
+
) -> ValidationResult:
|
| 98 |
+
"""Run all 6 checks against the current state.
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
conn: Live SQLite connection (current state).
|
| 102 |
+
current_records: Current rows as list of dicts.
|
| 103 |
+
touched_columns: Set of column names the agent modified.
|
| 104 |
+
Used to distinguish false-positive drift warnings.
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
ValidationResult with per-check details.
|
| 108 |
+
"""
|
| 109 |
+
profile = self._profile
|
| 110 |
+
baseline = self._baseline
|
| 111 |
+
touched = touched_columns or set()
|
| 112 |
+
|
| 113 |
+
checks: dict[str, CheckResult] = {}
|
| 114 |
+
warnings: list[str] = []
|
| 115 |
+
|
| 116 |
+
# 1. Null check
|
| 117 |
+
checks["null_check"] = self._null_check(current_records, baseline, profile)
|
| 118 |
+
|
| 119 |
+
# 2. Type check
|
| 120 |
+
checks["type_check"] = self._type_check(current_records, baseline, profile)
|
| 121 |
+
|
| 122 |
+
# 3. Range check
|
| 123 |
+
checks["range_check"] = self._range_check(current_records, baseline, profile)
|
| 124 |
+
|
| 125 |
+
# 4. Distribution check
|
| 126 |
+
dist_cr = self._distribution_check(current_records, baseline, profile, touched)
|
| 127 |
+
checks["distribution_check"] = dist_cr
|
| 128 |
+
warnings.extend(dist_cr.warnings)
|
| 129 |
+
|
| 130 |
+
# 5. Duplicate check
|
| 131 |
+
checks["duplicate_check"] = self._duplicate_check(current_records, baseline, profile)
|
| 132 |
+
|
| 133 |
+
# 6. Outlier check
|
| 134 |
+
checks["outlier_check"] = self._outlier_check(current_records, baseline, profile)
|
| 135 |
+
|
| 136 |
+
passed = sum(1 for cr in checks.values() if cr.passed)
|
| 137 |
+
total = len(checks)
|
| 138 |
+
|
| 139 |
+
if passed == total:
|
| 140 |
+
overall = "PASS"
|
| 141 |
+
elif passed == 0:
|
| 142 |
+
overall = "FAIL"
|
| 143 |
+
else:
|
| 144 |
+
overall = "PARTIAL"
|
| 145 |
+
|
| 146 |
+
return ValidationResult(
|
| 147 |
+
checks=checks,
|
| 148 |
+
checks_passed=passed,
|
| 149 |
+
total_checks=total,
|
| 150 |
+
overall=overall,
|
| 151 |
+
warnings=warnings,
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
# ------------------------------------------------------------------
|
| 155 |
+
# Baseline scan
|
| 156 |
+
# ------------------------------------------------------------------
|
| 157 |
+
|
| 158 |
+
def _scan_baseline(
|
| 159 |
+
self,
|
| 160 |
+
conn: sqlite3.Connection,
|
| 161 |
+
profile: dict[str, dict],
|
| 162 |
+
issue_registry: list,
|
| 163 |
+
) -> dict:
|
| 164 |
+
"""Compute baseline metrics from the initial (dirty) state."""
|
| 165 |
+
# We use the profile (computed at load time) as our baseline source
|
| 166 |
+
# plus we do a quick live scan for null/type counts
|
| 167 |
+
|
| 168 |
+
baseline: dict = {}
|
| 169 |
+
|
| 170 |
+
# Null counts per column (high-confidence issues only)
|
| 171 |
+
high_conf_null_cols: set[str] = set()
|
| 172 |
+
for iss in issue_registry:
|
| 173 |
+
if iss.issue_type == "null" and iss.confidence > 0.50 and iss.column:
|
| 174 |
+
high_conf_null_cols.add(iss.column)
|
| 175 |
+
|
| 176 |
+
baseline["null_cols"] = high_conf_null_cols
|
| 177 |
+
baseline["null_counts"] = {
|
| 178 |
+
col: profile[col]["null_count"]
|
| 179 |
+
for col in high_conf_null_cols
|
| 180 |
+
if col in profile
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
# Type error columns
|
| 184 |
+
type_error_cols = {
|
| 185 |
+
iss.column
|
| 186 |
+
for iss in issue_registry
|
| 187 |
+
if iss.issue_type == "type_error" and iss.column
|
| 188 |
+
}
|
| 189 |
+
baseline["type_error_cols"] = type_error_cols
|
| 190 |
+
baseline["type_error_counts"] = {col: 0 for col in type_error_cols}
|
| 191 |
+
for iss in issue_registry:
|
| 192 |
+
if iss.issue_type == "type_error" and iss.column:
|
| 193 |
+
baseline["type_error_counts"][iss.column] = (
|
| 194 |
+
baseline["type_error_counts"].get(iss.column, 0) + 1
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
# Must-be-positive columns with negatives
|
| 198 |
+
constraint_cols = {
|
| 199 |
+
iss.column
|
| 200 |
+
for iss in issue_registry
|
| 201 |
+
if iss.issue_type == "constraint" and iss.column
|
| 202 |
+
}
|
| 203 |
+
baseline["constraint_cols"] = constraint_cols
|
| 204 |
+
baseline["constraint_counts"] = {}
|
| 205 |
+
for iss in issue_registry:
|
| 206 |
+
if iss.issue_type == "constraint" and iss.column:
|
| 207 |
+
baseline["constraint_counts"][iss.column] = (
|
| 208 |
+
baseline["constraint_counts"].get(iss.column, 0) + 1
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
# Distribution baseline (mean/std per numeric column)
|
| 212 |
+
baseline["distribution"] = {
|
| 213 |
+
col: {"mean": p["mean"], "std": p["std"]}
|
| 214 |
+
for col, p in profile.items()
|
| 215 |
+
if p["dtype"] in ("int", "float")
|
| 216 |
+
and p["mean"] is not None
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
# Duplicate baseline: count of rows with repeated natural-key values
|
| 220 |
+
baseline["duplicate_count"] = sum(
|
| 221 |
+
1 for iss in issue_registry if iss.issue_type == "duplicate"
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
# Outlier baseline: set of (row_id, col) pairs with z > 5
|
| 225 |
+
baseline["outlier_cells"] = {
|
| 226 |
+
(iss.row_id, iss.column)
|
| 227 |
+
for iss in issue_registry
|
| 228 |
+
if iss.issue_type == "outlier" and iss.column
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
return baseline
|
| 232 |
+
|
| 233 |
+
# ------------------------------------------------------------------
|
| 234 |
+
# Individual checks
|
| 235 |
+
# ------------------------------------------------------------------
|
| 236 |
+
|
| 237 |
+
def _null_check(
|
| 238 |
+
self,
|
| 239 |
+
records: list[dict],
|
| 240 |
+
baseline: dict,
|
| 241 |
+
profile: dict[str, dict],
|
| 242 |
+
) -> CheckResult:
|
| 243 |
+
null_cols = baseline.get("null_cols", set())
|
| 244 |
+
before_counts = baseline.get("null_counts", {})
|
| 245 |
+
|
| 246 |
+
if not null_cols:
|
| 247 |
+
return CheckResult(
|
| 248 |
+
name="null_check",
|
| 249 |
+
passed=True,
|
| 250 |
+
before=before_counts,
|
| 251 |
+
after={},
|
| 252 |
+
detail="No high-confidence null issues in registry.",
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
after_counts: dict[str, int] = {}
|
| 256 |
+
for col in null_cols:
|
| 257 |
+
after_counts[col] = sum(
|
| 258 |
+
1 for row in records if _is_null(row.get(col))
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
all_fixed = all(after_counts.get(col, 0) == 0 for col in null_cols)
|
| 262 |
+
return CheckResult(
|
| 263 |
+
name="null_check",
|
| 264 |
+
passed=all_fixed,
|
| 265 |
+
before=before_counts,
|
| 266 |
+
after=after_counts,
|
| 267 |
+
detail=(
|
| 268 |
+
"All high-confidence nulls resolved."
|
| 269 |
+
if all_fixed
|
| 270 |
+
else f"Remaining nulls: { {c:v for c,v in after_counts.items() if v>0} }"
|
| 271 |
+
),
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
def _type_check(
|
| 275 |
+
self,
|
| 276 |
+
records: list[dict],
|
| 277 |
+
baseline: dict,
|
| 278 |
+
profile: dict[str, dict],
|
| 279 |
+
) -> CheckResult:
|
| 280 |
+
type_cols = baseline.get("type_error_cols", set())
|
| 281 |
+
before_counts = baseline.get("type_error_counts", {})
|
| 282 |
+
|
| 283 |
+
if not type_cols:
|
| 284 |
+
return CheckResult(
|
| 285 |
+
name="type_check",
|
| 286 |
+
passed=True,
|
| 287 |
+
before=before_counts,
|
| 288 |
+
after={},
|
| 289 |
+
detail="No type errors in registry.",
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
after_counts: dict[str, int] = {}
|
| 293 |
+
for col in type_cols:
|
| 294 |
+
if col not in profile:
|
| 295 |
+
after_counts[col] = 0
|
| 296 |
+
continue
|
| 297 |
+
after_counts[col] = sum(
|
| 298 |
+
1 for row in records
|
| 299 |
+
if not _is_null(row.get(col))
|
| 300 |
+
and not _can_cast_float(row.get(col))
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
all_fixed = all(v == 0 for v in after_counts.values())
|
| 304 |
+
return CheckResult(
|
| 305 |
+
name="type_check",
|
| 306 |
+
passed=all_fixed,
|
| 307 |
+
before=before_counts,
|
| 308 |
+
after=after_counts,
|
| 309 |
+
detail=(
|
| 310 |
+
"All type errors resolved."
|
| 311 |
+
if all_fixed
|
| 312 |
+
else f"Remaining type errors: { {c:v for c,v in after_counts.items() if v>0} }"
|
| 313 |
+
),
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
def _range_check(
|
| 317 |
+
self,
|
| 318 |
+
records: list[dict],
|
| 319 |
+
baseline: dict,
|
| 320 |
+
profile: dict[str, dict],
|
| 321 |
+
) -> CheckResult:
|
| 322 |
+
constraint_cols = baseline.get("constraint_cols", set())
|
| 323 |
+
before_counts = baseline.get("constraint_counts", {})
|
| 324 |
+
|
| 325 |
+
if not constraint_cols:
|
| 326 |
+
return CheckResult(
|
| 327 |
+
name="range_check",
|
| 328 |
+
passed=True,
|
| 329 |
+
before=before_counts,
|
| 330 |
+
after={},
|
| 331 |
+
detail="No constraint violations in registry.",
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
after_counts: dict[str, int] = {}
|
| 335 |
+
for col in constraint_cols:
|
| 336 |
+
after_counts[col] = sum(
|
| 337 |
+
1 for row in records
|
| 338 |
+
if not _is_null(row.get(col))
|
| 339 |
+
and _can_cast_float(row.get(col))
|
| 340 |
+
and float(row[col]) < 0
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
all_fixed = all(v == 0 for v in after_counts.values())
|
| 344 |
+
return CheckResult(
|
| 345 |
+
name="range_check",
|
| 346 |
+
passed=all_fixed,
|
| 347 |
+
before=before_counts,
|
| 348 |
+
after=after_counts,
|
| 349 |
+
detail=(
|
| 350 |
+
"All constraint violations resolved."
|
| 351 |
+
if all_fixed
|
| 352 |
+
else f"Remaining negatives: { {c:v for c,v in after_counts.items() if v>0} }"
|
| 353 |
+
),
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
def _distribution_check(
|
| 357 |
+
self,
|
| 358 |
+
records: list[dict],
|
| 359 |
+
baseline: dict,
|
| 360 |
+
profile: dict[str, dict],
|
| 361 |
+
touched: set[str],
|
| 362 |
+
) -> CheckResult:
|
| 363 |
+
dist_baseline = baseline.get("distribution", {})
|
| 364 |
+
if not dist_baseline:
|
| 365 |
+
return CheckResult(
|
| 366 |
+
name="distribution_check",
|
| 367 |
+
passed=True,
|
| 368 |
+
before={},
|
| 369 |
+
after={},
|
| 370 |
+
detail="No numeric columns to check.",
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
after_dist: dict[str, dict] = {}
|
| 374 |
+
warnings: list[str] = []
|
| 375 |
+
drift_cols: list[str] = []
|
| 376 |
+
|
| 377 |
+
for col, bstats in dist_baseline.items():
|
| 378 |
+
b_mean = bstats.get("mean")
|
| 379 |
+
if b_mean is None or b_mean == 0:
|
| 380 |
+
continue
|
| 381 |
+
vals = [
|
| 382 |
+
float(row[col])
|
| 383 |
+
for row in records
|
| 384 |
+
if not _is_null(row.get(col)) and _can_cast_float(row.get(col))
|
| 385 |
+
]
|
| 386 |
+
if not vals:
|
| 387 |
+
continue
|
| 388 |
+
a_mean = sum(vals) / len(vals)
|
| 389 |
+
drift_pct = abs(a_mean - b_mean) / abs(b_mean) * 100.0
|
| 390 |
+
after_dist[col] = {"mean": round(a_mean, 4), "drift_pct": round(drift_pct, 2)}
|
| 391 |
+
|
| 392 |
+
if drift_pct >= 20.0:
|
| 393 |
+
drift_cols.append(col)
|
| 394 |
+
if drift_pct > 5.0 and col not in touched:
|
| 395 |
+
warnings.append(
|
| 396 |
+
f"Column '{col}' mean drifted {drift_pct:.1f}% but agent did not modify it — "
|
| 397 |
+
"possible false positive fix in a related column."
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
passed = len(drift_cols) == 0
|
| 401 |
+
return CheckResult(
|
| 402 |
+
name="distribution_check",
|
| 403 |
+
passed=passed,
|
| 404 |
+
before={c: {"mean": v["mean"]} for c, v in dist_baseline.items() if "mean" in v},
|
| 405 |
+
after=after_dist,
|
| 406 |
+
detail=(
|
| 407 |
+
"Distribution stable across all numeric columns."
|
| 408 |
+
if passed
|
| 409 |
+
else f"Mean drift ≥20% in: {drift_cols}"
|
| 410 |
+
),
|
| 411 |
+
warnings=warnings,
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
def _duplicate_check(
|
| 415 |
+
self,
|
| 416 |
+
records: list[dict],
|
| 417 |
+
baseline: dict,
|
| 418 |
+
profile: dict[str, dict],
|
| 419 |
+
) -> CheckResult:
|
| 420 |
+
before_count = baseline.get("duplicate_count", 0)
|
| 421 |
+
if before_count == 0:
|
| 422 |
+
return CheckResult(
|
| 423 |
+
name="duplicate_check",
|
| 424 |
+
passed=True,
|
| 425 |
+
before=0,
|
| 426 |
+
after=0,
|
| 427 |
+
detail="No duplicates in baseline.",
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
# Find natural key column from profile
|
| 431 |
+
natural_key = None
|
| 432 |
+
for col, p in profile.items():
|
| 433 |
+
if p.get("all_unique") and p["dtype"] != "float":
|
| 434 |
+
col_lower = col.lower()
|
| 435 |
+
if any(h in col_lower for h in ("name", "email", "code", "ref", "id_", "key", "title")):
|
| 436 |
+
natural_key = col
|
| 437 |
+
break
|
| 438 |
+
|
| 439 |
+
if natural_key is None:
|
| 440 |
+
return CheckResult(
|
| 441 |
+
name="duplicate_check",
|
| 442 |
+
passed=True,
|
| 443 |
+
before=before_count,
|
| 444 |
+
after=0,
|
| 445 |
+
detail="Natural key column not found; cannot recheck duplicates.",
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
seen: set[str] = set()
|
| 449 |
+
after_count = 0
|
| 450 |
+
for row in records:
|
| 451 |
+
val = row.get(natural_key)
|
| 452 |
+
if _is_null(val):
|
| 453 |
+
continue
|
| 454 |
+
key_str = str(val).strip().lower()
|
| 455 |
+
if key_str in seen:
|
| 456 |
+
after_count += 1
|
| 457 |
+
else:
|
| 458 |
+
seen.add(key_str)
|
| 459 |
+
|
| 460 |
+
passed = after_count < before_count or after_count == 0
|
| 461 |
+
return CheckResult(
|
| 462 |
+
name="duplicate_check",
|
| 463 |
+
passed=passed,
|
| 464 |
+
before=before_count,
|
| 465 |
+
after=after_count,
|
| 466 |
+
detail=(
|
| 467 |
+
f"Duplicates reduced from {before_count} to {after_count}."
|
| 468 |
+
if passed
|
| 469 |
+
else f"Duplicate count unchanged at {after_count}."
|
| 470 |
+
),
|
| 471 |
+
)
|
| 472 |
+
|
| 473 |
+
def _outlier_check(
|
| 474 |
+
self,
|
| 475 |
+
records: list[dict],
|
| 476 |
+
baseline: dict,
|
| 477 |
+
profile: dict[str, dict],
|
| 478 |
+
) -> CheckResult:
|
| 479 |
+
outlier_cells = baseline.get("outlier_cells", set())
|
| 480 |
+
if not outlier_cells:
|
| 481 |
+
return CheckResult(
|
| 482 |
+
name="outlier_check",
|
| 483 |
+
passed=True,
|
| 484 |
+
before=set(),
|
| 485 |
+
after=set(),
|
| 486 |
+
detail="No outliers in baseline.",
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
pk_col = list(records[0].keys())[0] if records else "id"
|
| 490 |
+
row_map = {int(r[pk_col]): r for r in records if not _is_null(r.get(pk_col))}
|
| 491 |
+
|
| 492 |
+
still_outliers: set[tuple] = set()
|
| 493 |
+
for (rid, col) in outlier_cells:
|
| 494 |
+
if col not in profile:
|
| 495 |
+
continue
|
| 496 |
+
p = profile[col]
|
| 497 |
+
mean = p.get("mean")
|
| 498 |
+
std = p.get("std")
|
| 499 |
+
if mean is None or std is None or std == 0:
|
| 500 |
+
continue
|
| 501 |
+
row = row_map.get(rid)
|
| 502 |
+
if row is None:
|
| 503 |
+
# Row was deleted — outlier resolved
|
| 504 |
+
continue
|
| 505 |
+
val = row.get(col)
|
| 506 |
+
if _is_null(val) or not _can_cast_float(val):
|
| 507 |
+
continue
|
| 508 |
+
z = abs(float(val) - mean) / std
|
| 509 |
+
if z > 5.0:
|
| 510 |
+
still_outliers.add((rid, col))
|
| 511 |
+
|
| 512 |
+
passed = len(still_outliers) == 0
|
| 513 |
+
return CheckResult(
|
| 514 |
+
name="outlier_check",
|
| 515 |
+
passed=passed,
|
| 516 |
+
before=len(outlier_cells),
|
| 517 |
+
after=len(still_outliers),
|
| 518 |
+
detail=(
|
| 519 |
+
"All outliers resolved."
|
| 520 |
+
if passed
|
| 521 |
+
else f"{len(still_outliers)} outlier(s) remain: {list(still_outliers)[:5]}"
|
| 522 |
+
),
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
# ---------------------------------------------------------------------------
|
| 527 |
+
# Helpers
|
| 528 |
+
# ---------------------------------------------------------------------------
|
| 529 |
+
|
| 530 |
+
def _is_null(value: Any) -> bool:
|
| 531 |
+
if value is None:
|
| 532 |
+
return True
|
| 533 |
+
if isinstance(value, float) and math.isnan(value):
|
| 534 |
+
return True
|
| 535 |
+
if isinstance(value, str) and value.strip() == "":
|
| 536 |
+
return True
|
| 537 |
+
return False
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
def _can_cast_float(value: Any) -> bool:
|
| 541 |
+
try:
|
| 542 |
+
float(str(value))
|
| 543 |
+
return True
|
| 544 |
+
except (ValueError, TypeError):
|
| 545 |
+
return False
|
tests/__init__.py
ADDED
|
File without changes
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Shared pytest fixtures for SQLSherlock-Env tests.
|
| 9 |
+
|
| 10 |
+
All fixtures use in-memory SQLite and synthetic data — no network calls,
|
| 11 |
+
no HuggingFace token required.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import sqlite3
|
| 15 |
+
import sys
|
| 16 |
+
import os
|
| 17 |
+
import pytest
|
| 18 |
+
|
| 19 |
+
# Ensure sqlsherlock_env/ is on the path so absolute imports resolve
|
| 20 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "sqlsherlock_env"))
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# ---------------------------------------------------------------------------
|
| 24 |
+
# Minimal synthetic dataset helpers
|
| 25 |
+
# ---------------------------------------------------------------------------
|
| 26 |
+
|
| 27 |
+
CLEAN_RECORDS = [
|
| 28 |
+
{"id": 1, "name": "Alice", "age": 30, "fare": 10.50, "survived": 1},
|
| 29 |
+
{"id": 2, "name": "Bob", "age": 25, "fare": 7.25, "survived": 0},
|
| 30 |
+
{"id": 3, "name": "Carol", "age": 40, "fare": 15.00, "survived": 1},
|
| 31 |
+
{"id": 4, "name": "Dave", "age": 35, "fare": 8.00, "survived": 0},
|
| 32 |
+
{"id": 5, "name": "Eve", "age": 28, "fare": 12.00, "survived": 1},
|
| 33 |
+
{"id": 6, "name": "Frank", "age": 45, "fare": 9.75, "survived": 0},
|
| 34 |
+
{"id": 7, "name": "Grace", "age": 33, "fare": 11.50, "survived": 1},
|
| 35 |
+
{"id": 8, "name": "Heidi", "age": 29, "fare": 6.50, "survived": 0},
|
| 36 |
+
{"id": 9, "name": "Ivan", "age": 38, "fare": 13.25, "survived": 1},
|
| 37 |
+
{"id": 10, "name": "Judy", "age": 22, "fare": 5.00, "survived": 0},
|
| 38 |
+
]
|
| 39 |
+
|
| 40 |
+
DIRTY_RECORDS = [
|
| 41 |
+
{"id": 1, "name": "Alice", "age": None, "fare": 10.50, "survived": 1}, # null age
|
| 42 |
+
{"id": 2, "name": "Bob", "age": 25, "fare": 7.25, "survived": 0},
|
| 43 |
+
{"id": 3, "name": "Carol", "age": "FORTY", "fare": 15.00, "survived": 1}, # type error
|
| 44 |
+
{"id": 4, "name": "Dave", "age": -5, "fare": 8.00, "survived": 0}, # constraint
|
| 45 |
+
{"id": 5, "name": "Eve", "age": 28, "fare": 512.33, "survived": 1}, # outlier (z>5)
|
| 46 |
+
{"id": 6, "name": "Frank", "age": 45, "fare": 9.75, "survived": 0},
|
| 47 |
+
{"id": 7, "name": "Grace", "age": 33, "fare": 11.50, "survived": 1},
|
| 48 |
+
{"id": 8, "name": "Alice", "age": 29, "fare": 6.50, "survived": 0}, # duplicate name
|
| 49 |
+
{"id": 9, "name": "Ivan", "age": 38, "fare": 13.25, "survived": 1},
|
| 50 |
+
{"id": 10, "name": "Judy", "age": 22, "fare": 5.00, "survived": 0},
|
| 51 |
+
]
|
| 52 |
+
|
| 53 |
+
RAW_CSV_TEXT = (
|
| 54 |
+
"id,name,age,fare,survived\n"
|
| 55 |
+
"1,Alice,,10.50,1\n"
|
| 56 |
+
"2,Bob,25,7.25,0\n"
|
| 57 |
+
"3,Carol,FORTY,15.00,1\n"
|
| 58 |
+
"4,Dave,-5,8.00,0\n"
|
| 59 |
+
"5,Eve,28,512.33,1\n"
|
| 60 |
+
"6,Frank,45,9.75,0\n"
|
| 61 |
+
"7,Grace,33,11.50,1\n"
|
| 62 |
+
"8,Alice,29,6.50,0\n"
|
| 63 |
+
"9,Ivan,38,13.25,1\n"
|
| 64 |
+
"10,Judy,22,5.00,0\n"
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# ---------------------------------------------------------------------------
|
| 69 |
+
# SQLite connection fixtures
|
| 70 |
+
# ---------------------------------------------------------------------------
|
| 71 |
+
|
| 72 |
+
@pytest.fixture
|
| 73 |
+
def clean_conn():
|
| 74 |
+
"""In-memory SQLite with clean records."""
|
| 75 |
+
conn = sqlite3.connect(":memory:")
|
| 76 |
+
conn.row_factory = sqlite3.Row
|
| 77 |
+
_create_table(conn, "passengers", CLEAN_RECORDS)
|
| 78 |
+
yield conn
|
| 79 |
+
conn.close()
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@pytest.fixture
|
| 83 |
+
def dirty_conn():
|
| 84 |
+
"""In-memory SQLite with dirty records (nulls, type errors, constraint, outlier, duplicate)."""
|
| 85 |
+
conn = sqlite3.connect(":memory:")
|
| 86 |
+
conn.row_factory = sqlite3.Row
|
| 87 |
+
_create_table(conn, "passengers", DIRTY_RECORDS)
|
| 88 |
+
yield conn
|
| 89 |
+
conn.close()
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _create_table(conn: sqlite3.Connection, table: str, records: list[dict]) -> None:
|
| 93 |
+
conn.execute(f'DROP TABLE IF EXISTS "{table}"')
|
| 94 |
+
conn.execute(
|
| 95 |
+
f'CREATE TABLE "{table}" '
|
| 96 |
+
f'(id INTEGER, name TEXT, age TEXT, fare REAL, survived INTEGER)'
|
| 97 |
+
)
|
| 98 |
+
for r in records:
|
| 99 |
+
conn.execute(
|
| 100 |
+
f'INSERT INTO "{table}" VALUES (?, ?, ?, ?, ?)',
|
| 101 |
+
(r["id"], r["name"], r.get("age"), r.get("fare"), r.get("survived")),
|
| 102 |
+
)
|
| 103 |
+
conn.commit()
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# ---------------------------------------------------------------------------
|
| 107 |
+
# Profile fixture
|
| 108 |
+
# ---------------------------------------------------------------------------
|
| 109 |
+
|
| 110 |
+
@pytest.fixture
|
| 111 |
+
def dirty_profile():
|
| 112 |
+
"""Column profile computed from DIRTY_RECORDS."""
|
| 113 |
+
from server.schema_profiler import profile_table
|
| 114 |
+
return profile_table("passengers", DIRTY_RECORDS)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@pytest.fixture
|
| 118 |
+
def clean_profile():
|
| 119 |
+
"""Column profile computed from CLEAN_RECORDS."""
|
| 120 |
+
from server.schema_profiler import profile_table
|
| 121 |
+
return profile_table("passengers", CLEAN_RECORDS)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# ---------------------------------------------------------------------------
|
| 125 |
+
# DatabaseEngine fixtures
|
| 126 |
+
# ---------------------------------------------------------------------------
|
| 127 |
+
|
| 128 |
+
@pytest.fixture
|
| 129 |
+
def db_task1():
|
| 130 |
+
"""DatabaseEngine for task1 loaded from raw CSV text."""
|
| 131 |
+
from server.database import DatabaseEngine
|
| 132 |
+
db = DatabaseEngine(
|
| 133 |
+
task_id="task1_null_and_types",
|
| 134 |
+
seed=42,
|
| 135 |
+
dataset_source=RAW_CSV_TEXT,
|
| 136 |
+
max_rows=50,
|
| 137 |
+
)
|
| 138 |
+
return db
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
@pytest.fixture
|
| 142 |
+
def db_task2():
|
| 143 |
+
"""DatabaseEngine for task2 loaded from raw CSV text."""
|
| 144 |
+
from server.database import DatabaseEngine
|
| 145 |
+
db = DatabaseEngine(
|
| 146 |
+
task_id="task2_constraints_and_fk",
|
| 147 |
+
seed=42,
|
| 148 |
+
dataset_source=RAW_CSV_TEXT,
|
| 149 |
+
max_rows=50,
|
| 150 |
+
)
|
| 151 |
+
return db
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
@pytest.fixture
|
| 155 |
+
def db_task3():
|
| 156 |
+
"""DatabaseEngine for task3 loaded from raw CSV text."""
|
| 157 |
+
from server.database import DatabaseEngine
|
| 158 |
+
db = DatabaseEngine(
|
| 159 |
+
task_id="task3_full_audit_with_trap",
|
| 160 |
+
seed=42,
|
| 161 |
+
dataset_source=RAW_CSV_TEXT,
|
| 162 |
+
max_rows=50,
|
| 163 |
+
)
|
| 164 |
+
return db
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# ---------------------------------------------------------------------------
|
| 168 |
+
# Issue registry fixture
|
| 169 |
+
# ---------------------------------------------------------------------------
|
| 170 |
+
|
| 171 |
+
@pytest.fixture
|
| 172 |
+
def task1_issues(dirty_conn, dirty_profile):
|
| 173 |
+
"""Issues detected for task1 on the dirty dataset."""
|
| 174 |
+
from server.issue_detector import detect_issues
|
| 175 |
+
import copy
|
| 176 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 177 |
+
return detect_issues(
|
| 178 |
+
conn=dirty_conn,
|
| 179 |
+
profile=dirty_profile,
|
| 180 |
+
records=records,
|
| 181 |
+
task_id="task1_null_and_types",
|
| 182 |
+
seed=42,
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
@pytest.fixture
|
| 187 |
+
def task3_issues(dirty_conn, dirty_profile):
|
| 188 |
+
"""Issues detected for task3 on the dirty dataset."""
|
| 189 |
+
from server.issue_detector import detect_issues
|
| 190 |
+
import copy
|
| 191 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 192 |
+
return detect_issues(
|
| 193 |
+
conn=dirty_conn,
|
| 194 |
+
profile=dirty_profile,
|
| 195 |
+
records=records,
|
| 196 |
+
task_id="task3_full_audit_with_trap",
|
| 197 |
+
seed=42,
|
| 198 |
+
)
|
tests/test_environment.py
ADDED
|
@@ -0,0 +1,447 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Tests for server/environment.py
|
| 9 |
+
|
| 10 |
+
Covers: reset validation, step dispatch for all 8 action types,
|
| 11 |
+
reward accumulation, done flag, max-steps termination,
|
| 12 |
+
and WebSocket minimal-action compatibility (Nemotron Phase 2).
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import pytest
|
| 16 |
+
|
| 17 |
+
from server.environment import SQLSherlockEnvironment, TASKS
|
| 18 |
+
from models import SQLSherlockAction, SQLSherlockObservation, SQLSherlockState
|
| 19 |
+
from tests.conftest import RAW_CSV_TEXT
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _step(env, action):
|
| 23 |
+
"""Call env.step() and unpack the observation into (obs, reward, done, info).
|
| 24 |
+
|
| 25 |
+
The openenv-core Environment.step() returns an Observation with reward/done
|
| 26 |
+
set on it. This helper provides the classic RL tuple interface for tests.
|
| 27 |
+
"""
|
| 28 |
+
obs = env.step(action)
|
| 29 |
+
return obs, float(obs.reward or 0.0), obs.done, {}
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
# Fixtures
|
| 34 |
+
# ---------------------------------------------------------------------------
|
| 35 |
+
|
| 36 |
+
@pytest.fixture
|
| 37 |
+
def env():
|
| 38 |
+
return SQLSherlockEnvironment()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@pytest.fixture
|
| 42 |
+
def env_task1(env):
|
| 43 |
+
env.reset(dataset=RAW_CSV_TEXT, task_id="task1_null_and_types")
|
| 44 |
+
return env
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@pytest.fixture
|
| 48 |
+
def env_task3(env):
|
| 49 |
+
env.reset(dataset=RAW_CSV_TEXT, task_id="task3_full_audit_with_trap")
|
| 50 |
+
return env
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# ---------------------------------------------------------------------------
|
| 54 |
+
# TASKS catalogue
|
| 55 |
+
# ---------------------------------------------------------------------------
|
| 56 |
+
|
| 57 |
+
class TestTasksCatalogue:
|
| 58 |
+
def test_three_tasks_defined(self):
|
| 59 |
+
assert len(TASKS) == 3
|
| 60 |
+
|
| 61 |
+
def test_task_ids_correct(self):
|
| 62 |
+
ids = {t["id"] for t in TASKS}
|
| 63 |
+
assert ids == {
|
| 64 |
+
"task1_null_and_types",
|
| 65 |
+
"task2_constraints_and_fk",
|
| 66 |
+
"task3_full_audit_with_trap",
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
def test_tasks_have_required_fields(self):
|
| 70 |
+
for t in TASKS:
|
| 71 |
+
for field in ("id", "name", "difficulty", "max_steps", "description"):
|
| 72 |
+
assert field in t, f"Task missing field '{field}': {t}"
|
| 73 |
+
|
| 74 |
+
def test_max_steps_values(self):
|
| 75 |
+
step_map = {t["id"]: t["max_steps"] for t in TASKS}
|
| 76 |
+
assert step_map["task1_null_and_types"] == 20
|
| 77 |
+
assert step_map["task2_constraints_and_fk"] == 25
|
| 78 |
+
assert step_map["task3_full_audit_with_trap"] == 30
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# ---------------------------------------------------------------------------
|
| 82 |
+
# reset() validation
|
| 83 |
+
# ---------------------------------------------------------------------------
|
| 84 |
+
|
| 85 |
+
class TestReset:
|
| 86 |
+
def test_reset_returns_observation(self, env):
|
| 87 |
+
obs = env.reset(dataset=RAW_CSV_TEXT, task_id="task1_null_and_types")
|
| 88 |
+
assert isinstance(obs, SQLSherlockObservation)
|
| 89 |
+
|
| 90 |
+
def test_reset_populates_tables_summary(self, env):
|
| 91 |
+
obs = env.reset(dataset=RAW_CSV_TEXT, task_id="task1_null_and_types")
|
| 92 |
+
assert len(obs.tables_summary) > 0
|
| 93 |
+
|
| 94 |
+
def test_reset_task_description_set(self, env):
|
| 95 |
+
obs = env.reset(dataset=RAW_CSV_TEXT, task_id="task2_constraints_and_fk")
|
| 96 |
+
assert "Task" in obs.task_description or len(obs.task_description) > 0
|
| 97 |
+
|
| 98 |
+
def test_reset_step_zero(self, env):
|
| 99 |
+
obs = env.reset(dataset=RAW_CSV_TEXT, task_id="task1_null_and_types")
|
| 100 |
+
assert obs.step == 0
|
| 101 |
+
|
| 102 |
+
def test_reset_no_dataset_raises(self, env):
|
| 103 |
+
with pytest.raises(ValueError, match="dataset"):
|
| 104 |
+
env.reset(dataset="", task_id="task1_null_and_types")
|
| 105 |
+
|
| 106 |
+
def test_reset_no_task_raises(self, env):
|
| 107 |
+
with pytest.raises(ValueError, match="task_id"):
|
| 108 |
+
env.reset(dataset=RAW_CSV_TEXT, task_id="")
|
| 109 |
+
|
| 110 |
+
def test_reset_invalid_task_raises(self, env):
|
| 111 |
+
with pytest.raises(ValueError, match="Unknown task_id"):
|
| 112 |
+
env.reset(dataset=RAW_CSV_TEXT, task_id="task99_bad")
|
| 113 |
+
|
| 114 |
+
def test_reset_clears_reward_trace(self, env):
|
| 115 |
+
env.reset(dataset=RAW_CSV_TEXT, task_id="task1_null_and_types")
|
| 116 |
+
env.step(SQLSherlockAction(action_type="inspect",
|
| 117 |
+
table=list(env._db.table_names())[0]))
|
| 118 |
+
# Second reset should clear trace
|
| 119 |
+
obs = env.reset(dataset=RAW_CSV_TEXT, task_id="task1_null_and_types")
|
| 120 |
+
assert obs.reward_trace == []
|
| 121 |
+
|
| 122 |
+
def test_reset_before_step_raises(self, env):
|
| 123 |
+
with pytest.raises(RuntimeError):
|
| 124 |
+
env.step(SQLSherlockAction(action_type="inspect"))
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# ---------------------------------------------------------------------------
|
| 128 |
+
# step() — inspect
|
| 129 |
+
# ---------------------------------------------------------------------------
|
| 130 |
+
|
| 131 |
+
class TestStepInspect:
|
| 132 |
+
def test_inspect_returns_rows(self, env_task1):
|
| 133 |
+
table = list(env_task1._db.table_names())[0]
|
| 134 |
+
obs, reward, done, info = _step(env_task1,
|
| 135 |
+
SQLSherlockAction(action_type="inspect", table=table)
|
| 136 |
+
)
|
| 137 |
+
assert obs.query_result is not None
|
| 138 |
+
assert len(obs.query_result) > 0
|
| 139 |
+
|
| 140 |
+
def test_inspect_positive_reward(self, env_task1):
|
| 141 |
+
table = list(env_task1._db.table_names())[0]
|
| 142 |
+
_, reward, _, _ = _step(env_task1,
|
| 143 |
+
SQLSherlockAction(action_type="inspect", table=table)
|
| 144 |
+
)
|
| 145 |
+
assert reward > 0
|
| 146 |
+
|
| 147 |
+
def test_inspect_capped_at_3(self, env_task1):
|
| 148 |
+
table = list(env_task1._db.table_names())[0]
|
| 149 |
+
rewards = []
|
| 150 |
+
for _ in range(5):
|
| 151 |
+
_, r, _, _ = _step(env_task1,
|
| 152 |
+
SQLSherlockAction(action_type="inspect", table=table)
|
| 153 |
+
)
|
| 154 |
+
rewards.append(r)
|
| 155 |
+
# First 3 positive, after that 0
|
| 156 |
+
assert rewards[0] > 0
|
| 157 |
+
assert rewards[1] > 0
|
| 158 |
+
assert rewards[2] > 0
|
| 159 |
+
assert rewards[3] == 0.0
|
| 160 |
+
assert rewards[4] == 0.0
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# ---------------------------------------------------------------------------
|
| 164 |
+
# step() — profile_column
|
| 165 |
+
# ---------------------------------------------------------------------------
|
| 166 |
+
|
| 167 |
+
class TestStepProfileColumn:
|
| 168 |
+
def test_profile_returns_stats(self, env_task1):
|
| 169 |
+
table = list(env_task1._db.table_names())[0]
|
| 170 |
+
obs, reward, done, _ = _step(env_task1,
|
| 171 |
+
SQLSherlockAction(action_type="profile_column",
|
| 172 |
+
table=table, column="fare")
|
| 173 |
+
)
|
| 174 |
+
assert obs.query_result is not None
|
| 175 |
+
profile = obs.query_result[0]
|
| 176 |
+
assert "mean" in profile
|
| 177 |
+
assert "std" in profile
|
| 178 |
+
assert "z_scores" in profile
|
| 179 |
+
|
| 180 |
+
def test_profile_missing_column_gives_feedback(self, env_task1):
|
| 181 |
+
table = list(env_task1._db.table_names())[0]
|
| 182 |
+
obs, _, _, _ = _step(env_task1,
|
| 183 |
+
SQLSherlockAction(action_type="profile_column",
|
| 184 |
+
table=table, column="nonexistent_col")
|
| 185 |
+
)
|
| 186 |
+
assert "error" in obs.last_feedback.lower() or "not found" in obs.last_feedback.lower()
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# ---------------------------------------------------------------------------
|
| 190 |
+
# step() — run_sql
|
| 191 |
+
# ---------------------------------------------------------------------------
|
| 192 |
+
|
| 193 |
+
class TestStepRunSQL:
|
| 194 |
+
def test_select_query_works(self, env_task1):
|
| 195 |
+
table = list(env_task1._db.table_names())[0]
|
| 196 |
+
obs, reward, done, _ = _step(env_task1,
|
| 197 |
+
SQLSherlockAction(
|
| 198 |
+
action_type="run_sql",
|
| 199 |
+
sql=f'SELECT * FROM "{table}" LIMIT 3',
|
| 200 |
+
)
|
| 201 |
+
)
|
| 202 |
+
assert obs.query_result is not None
|
| 203 |
+
assert len(obs.query_result) <= 3
|
| 204 |
+
|
| 205 |
+
def test_blocked_keyword_gives_error_feedback(self, env_task1):
|
| 206 |
+
obs, _, _, _ = _step(env_task1,
|
| 207 |
+
SQLSherlockAction(
|
| 208 |
+
action_type="run_sql",
|
| 209 |
+
sql="DROP TABLE passengers",
|
| 210 |
+
)
|
| 211 |
+
)
|
| 212 |
+
assert "error" in obs.last_feedback.lower() or "blocked" in obs.last_feedback.lower()
|
| 213 |
+
|
| 214 |
+
def test_non_select_gives_error_feedback(self, env_task1):
|
| 215 |
+
obs, _, _, _ = _step(env_task1,
|
| 216 |
+
SQLSherlockAction(
|
| 217 |
+
action_type="run_sql",
|
| 218 |
+
sql="UPDATE passengers SET age=0",
|
| 219 |
+
)
|
| 220 |
+
)
|
| 221 |
+
assert "error" in obs.last_feedback.lower() or "select" in obs.last_feedback.lower()
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
# ---------------------------------------------------------------------------
|
| 225 |
+
# step() — fix_cell
|
| 226 |
+
# ---------------------------------------------------------------------------
|
| 227 |
+
|
| 228 |
+
class TestStepFixCell:
|
| 229 |
+
def test_fix_real_issue_positive_reward(self, env_task1):
|
| 230 |
+
# Find a null issue
|
| 231 |
+
null_issue = next(
|
| 232 |
+
(i for i in env_task1._db.issue_registry if i.issue_type == "null"),
|
| 233 |
+
None,
|
| 234 |
+
)
|
| 235 |
+
if null_issue is None:
|
| 236 |
+
pytest.skip("No null issues in registry")
|
| 237 |
+
_, reward, _, _ = _step(env_task1,
|
| 238 |
+
SQLSherlockAction(
|
| 239 |
+
action_type="fix_cell",
|
| 240 |
+
table=null_issue.table,
|
| 241 |
+
row_id=null_issue.row_id,
|
| 242 |
+
column=null_issue.column,
|
| 243 |
+
value=30,
|
| 244 |
+
reason="median imputation",
|
| 245 |
+
)
|
| 246 |
+
)
|
| 247 |
+
assert reward > 0
|
| 248 |
+
|
| 249 |
+
def test_fix_clean_cell_negative_reward(self, env_task1):
|
| 250 |
+
# Fix a cell not in the issue registry
|
| 251 |
+
table = env_task1._db.primary_table
|
| 252 |
+
pk = env_task1._db.pk_col
|
| 253 |
+
issue_cells = {(i.row_id, i.column) for i in env_task1._db.issue_registry}
|
| 254 |
+
rows = env_task1._db.rows(table)
|
| 255 |
+
target = None
|
| 256 |
+
for row in rows:
|
| 257 |
+
rid = row[pk]
|
| 258 |
+
for col in row:
|
| 259 |
+
if col not in (pk, "_source_format") and (rid, col) not in issue_cells:
|
| 260 |
+
target = (rid, col)
|
| 261 |
+
break
|
| 262 |
+
if target:
|
| 263 |
+
break
|
| 264 |
+
if target is None:
|
| 265 |
+
pytest.skip("No clean cell available to test FP")
|
| 266 |
+
_, reward, _, _ = _step(env_task1,
|
| 267 |
+
SQLSherlockAction(
|
| 268 |
+
action_type="fix_cell",
|
| 269 |
+
table=table,
|
| 270 |
+
row_id=target[0],
|
| 271 |
+
column=target[1],
|
| 272 |
+
value="TAMPERED",
|
| 273 |
+
reason="test",
|
| 274 |
+
)
|
| 275 |
+
)
|
| 276 |
+
assert reward < 0
|
| 277 |
+
|
| 278 |
+
def test_fix_trap_negative_reward(self, env_task3):
|
| 279 |
+
trap = env_task3._db.trap
|
| 280 |
+
if trap is None:
|
| 281 |
+
pytest.skip("No trap in this episode")
|
| 282 |
+
_, reward, _, _ = _step(env_task3,
|
| 283 |
+
SQLSherlockAction(
|
| 284 |
+
action_type="fix_cell",
|
| 285 |
+
table=trap.table,
|
| 286 |
+
row_id=trap.row_id,
|
| 287 |
+
column=trap.column,
|
| 288 |
+
value=trap.original,
|
| 289 |
+
reason="looks like outlier",
|
| 290 |
+
)
|
| 291 |
+
)
|
| 292 |
+
assert reward <= -0.39
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
# ---------------------------------------------------------------------------
|
| 296 |
+
# step() — validate
|
| 297 |
+
# ---------------------------------------------------------------------------
|
| 298 |
+
|
| 299 |
+
class TestStepValidate:
|
| 300 |
+
def test_validate_returns_result(self, env_task1):
|
| 301 |
+
obs, _, _, _ = _step(env_task1,
|
| 302 |
+
SQLSherlockAction(action_type="validate")
|
| 303 |
+
)
|
| 304 |
+
assert obs.validation_result is not None
|
| 305 |
+
assert "checks_passed" in obs.validation_result
|
| 306 |
+
assert "overall" in obs.validation_result
|
| 307 |
+
|
| 308 |
+
def test_validate_reward_capped_at_2(self, env_task1):
|
| 309 |
+
rewards = []
|
| 310 |
+
for _ in range(4):
|
| 311 |
+
_, r, _, _ = _step(env_task1,
|
| 312 |
+
SQLSherlockAction(action_type="validate")
|
| 313 |
+
)
|
| 314 |
+
rewards.append(r)
|
| 315 |
+
# Reward only for first 2 calls
|
| 316 |
+
assert rewards[2] == 0.0
|
| 317 |
+
assert rewards[3] == 0.0
|
| 318 |
+
|
| 319 |
+
def test_validate_sets_validation_called(self, env_task1):
|
| 320 |
+
assert env_task1._validation_called is False
|
| 321 |
+
env_task1.step(SQLSherlockAction(action_type="validate"))
|
| 322 |
+
assert env_task1._validation_called is True
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
# ---------------------------------------------------------------------------
|
| 326 |
+
# step() — submit
|
| 327 |
+
# ---------------------------------------------------------------------------
|
| 328 |
+
|
| 329 |
+
class TestStepSubmit:
|
| 330 |
+
def test_submit_ends_episode(self, env_task1):
|
| 331 |
+
_, _, done, _ = _step(env_task1,
|
| 332 |
+
SQLSherlockAction(action_type="submit")
|
| 333 |
+
)
|
| 334 |
+
assert done is True
|
| 335 |
+
|
| 336 |
+
def test_submit_with_open_issues_negative_reward(self, env_task1):
|
| 337 |
+
_, reward, _, _ = _step(env_task1,
|
| 338 |
+
SQLSherlockAction(action_type="submit")
|
| 339 |
+
)
|
| 340 |
+
# Issues still open -> negative reward
|
| 341 |
+
assert reward < 0
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
# ---------------------------------------------------------------------------
|
| 345 |
+
# step() — export
|
| 346 |
+
# ---------------------------------------------------------------------------
|
| 347 |
+
|
| 348 |
+
class TestStepExport:
|
| 349 |
+
def test_export_ends_episode(self, env_task1):
|
| 350 |
+
_, _, done, _ = _step(env_task1,
|
| 351 |
+
SQLSherlockAction(action_type="export")
|
| 352 |
+
)
|
| 353 |
+
assert done is True
|
| 354 |
+
|
| 355 |
+
def test_export_feedback_contains_download(self, env_task1):
|
| 356 |
+
obs, _, _, _ = _step(env_task1,
|
| 357 |
+
SQLSherlockAction(action_type="export")
|
| 358 |
+
)
|
| 359 |
+
assert "download" in obs.last_feedback.lower() or "export" in obs.last_feedback.lower()
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
# ---------------------------------------------------------------------------
|
| 363 |
+
# Reward trace
|
| 364 |
+
# ---------------------------------------------------------------------------
|
| 365 |
+
|
| 366 |
+
class TestRewardTrace:
|
| 367 |
+
def test_reward_trace_grows_each_step(self, env_task1):
|
| 368 |
+
table = list(env_task1._db.table_names())[0]
|
| 369 |
+
for i in range(3):
|
| 370 |
+
obs, _, _, _ = _step(env_task1,
|
| 371 |
+
SQLSherlockAction(action_type="inspect", table=table)
|
| 372 |
+
)
|
| 373 |
+
assert len(obs.reward_trace) == 3
|
| 374 |
+
|
| 375 |
+
def test_reward_trace_has_required_keys(self, env_task1):
|
| 376 |
+
table = list(env_task1._db.table_names())[0]
|
| 377 |
+
obs, _, _, _ = _step(env_task1,
|
| 378 |
+
SQLSherlockAction(action_type="inspect", table=table)
|
| 379 |
+
)
|
| 380 |
+
entry = obs.reward_trace[-1]
|
| 381 |
+
for key in ("invest", "fix_delta", "validate_b", "penalty", "total", "step", "action_type"):
|
| 382 |
+
assert key in entry, f"reward_trace entry missing key '{key}'"
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
# ---------------------------------------------------------------------------
|
| 386 |
+
# Max-steps termination
|
| 387 |
+
# ---------------------------------------------------------------------------
|
| 388 |
+
|
| 389 |
+
class TestMaxSteps:
|
| 390 |
+
def test_done_at_max_steps(self, env):
|
| 391 |
+
env.reset(dataset=RAW_CSV_TEXT, task_id="task1_null_and_types")
|
| 392 |
+
table = list(env._db.table_names())[0]
|
| 393 |
+
done = False
|
| 394 |
+
for _ in range(25): # more than max_steps=20
|
| 395 |
+
_, _, done, _ = _step(env,
|
| 396 |
+
SQLSherlockAction(action_type="inspect", table=table)
|
| 397 |
+
)
|
| 398 |
+
if done:
|
| 399 |
+
break
|
| 400 |
+
assert done is True
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
# ---------------------------------------------------------------------------
|
| 404 |
+
# get_state()
|
| 405 |
+
# ---------------------------------------------------------------------------
|
| 406 |
+
|
| 407 |
+
class TestGetState:
|
| 408 |
+
def test_get_state_returns_state(self, env_task1):
|
| 409 |
+
state = env_task1.get_state()
|
| 410 |
+
assert isinstance(state, SQLSherlockState)
|
| 411 |
+
|
| 412 |
+
def test_get_state_task_id(self, env_task1):
|
| 413 |
+
state = env_task1.get_state()
|
| 414 |
+
assert state.task_id == "task1_null_and_types"
|
| 415 |
+
|
| 416 |
+
def test_get_state_step_count_increments(self, env_task1):
|
| 417 |
+
table = list(env_task1._db.table_names())[0]
|
| 418 |
+
env_task1.step(SQLSherlockAction(action_type="inspect", table=table))
|
| 419 |
+
env_task1.step(SQLSherlockAction(action_type="inspect", table=table))
|
| 420 |
+
state = env_task1.get_state()
|
| 421 |
+
assert state.step_count == 2
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
# ---------------------------------------------------------------------------
|
| 425 |
+
# Nemotron Phase 2 — minimal action compatibility
|
| 426 |
+
# ---------------------------------------------------------------------------
|
| 427 |
+
|
| 428 |
+
class TestWebSocketActionMinimal:
|
| 429 |
+
def test_action_with_only_action_type_accepted(self, env_task1):
|
| 430 |
+
"""A SQLSherlockAction with only action_type set must not crash the server."""
|
| 431 |
+
action = SQLSherlockAction(action_type="validate")
|
| 432 |
+
obs, reward, done, info = _step(env_task1, action)
|
| 433 |
+
assert isinstance(obs, SQLSherlockObservation)
|
| 434 |
+
assert isinstance(reward, float)
|
| 435 |
+
assert isinstance(done, bool)
|
| 436 |
+
|
| 437 |
+
def test_inspect_without_table_uses_primary(self, env_task1):
|
| 438 |
+
"""inspect with no table field defaults to the primary table."""
|
| 439 |
+
action = SQLSherlockAction(action_type="inspect")
|
| 440 |
+
obs, reward, done, _ = _step(env_task1, action)
|
| 441 |
+
assert obs.query_result is not None
|
| 442 |
+
|
| 443 |
+
def test_submit_without_extra_fields(self, env_task1):
|
| 444 |
+
"""submit with only action_type must terminate the episode."""
|
| 445 |
+
action = SQLSherlockAction(action_type="submit")
|
| 446 |
+
obs, reward, done, _ = _step(env_task1, action)
|
| 447 |
+
assert done is True
|
tests/test_graders.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Tests for server/graders/ — universal.py, task1.py, task2.py, task3.py.
|
| 9 |
+
|
| 10 |
+
All tests use DatabaseEngine fixtures from conftest.py.
|
| 11 |
+
No network calls, no HuggingFace token required.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import copy
|
| 15 |
+
import pytest
|
| 16 |
+
|
| 17 |
+
from server import graders
|
| 18 |
+
from server.graders.universal import (
|
| 19 |
+
grade as universal_grade,
|
| 20 |
+
_rows_identical,
|
| 21 |
+
_values_match,
|
| 22 |
+
_false_positive_penalty,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
# Helpers
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
|
| 30 |
+
def _current(db) -> list[dict]:
|
| 31 |
+
"""Return current rows as plain dicts."""
|
| 32 |
+
return db.rows(db.primary_table)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _apply_all_fixes(db) -> list[dict]:
|
| 36 |
+
"""Fix every issue in the registry and return the updated rows."""
|
| 37 |
+
from server.issue_detector import SENTINEL_UNKNOWN
|
| 38 |
+
for iss in db.issue_registry:
|
| 39 |
+
if iss.issue_type in ("duplicate", "fk_violation"):
|
| 40 |
+
try:
|
| 41 |
+
db.delete_row(db.primary_table, iss.row_id)
|
| 42 |
+
except Exception:
|
| 43 |
+
pass
|
| 44 |
+
elif iss.correct is not None and iss.correct != SENTINEL_UNKNOWN:
|
| 45 |
+
try:
|
| 46 |
+
db.fix_cell(db.primary_table, iss.row_id, iss.column, iss.correct)
|
| 47 |
+
except Exception:
|
| 48 |
+
pass
|
| 49 |
+
elif iss.correct == SENTINEL_UNKNOWN and iss.issue_type == "null":
|
| 50 |
+
# Supply a plausible non-null value
|
| 51 |
+
try:
|
| 52 |
+
db.fix_cell(db.primary_table, iss.row_id, iss.column, 0)
|
| 53 |
+
except Exception:
|
| 54 |
+
pass
|
| 55 |
+
return _current(db)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# ---------------------------------------------------------------------------
|
| 59 |
+
# _rows_identical
|
| 60 |
+
# ---------------------------------------------------------------------------
|
| 61 |
+
|
| 62 |
+
class TestRowsIdentical:
|
| 63 |
+
def test_identical_rows(self, db_task1):
|
| 64 |
+
rows = _current(db_task1)
|
| 65 |
+
assert _rows_identical(rows, rows, db_task1.pk_col) is True
|
| 66 |
+
|
| 67 |
+
def test_different_value(self, db_task1):
|
| 68 |
+
rows = _current(db_task1)
|
| 69 |
+
modified = copy.deepcopy(rows)
|
| 70 |
+
if modified:
|
| 71 |
+
modified[0]["fare"] = 9999.0
|
| 72 |
+
assert _rows_identical(modified, rows, db_task1.pk_col) is False
|
| 73 |
+
|
| 74 |
+
def test_different_length(self, db_task1):
|
| 75 |
+
rows = _current(db_task1)
|
| 76 |
+
assert _rows_identical(rows[:-1], rows, db_task1.pk_col) is False
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# ---------------------------------------------------------------------------
|
| 80 |
+
# _values_match
|
| 81 |
+
# ---------------------------------------------------------------------------
|
| 82 |
+
|
| 83 |
+
class TestValuesMatch:
|
| 84 |
+
def test_numeric_close(self):
|
| 85 |
+
assert _values_match(28.0, 28.000001) is True
|
| 86 |
+
|
| 87 |
+
def test_string_case_insensitive(self):
|
| 88 |
+
assert _values_match("Alice", "alice") is True
|
| 89 |
+
|
| 90 |
+
def test_none_both(self):
|
| 91 |
+
assert _values_match(None, None) is True
|
| 92 |
+
|
| 93 |
+
def test_none_one_side(self):
|
| 94 |
+
assert _values_match(None, 5) is False
|
| 95 |
+
|
| 96 |
+
def test_int_vs_float(self):
|
| 97 |
+
assert _values_match(28, 28.0) is True
|
| 98 |
+
|
| 99 |
+
def test_clearly_different(self):
|
| 100 |
+
assert _values_match(10, 999) is False
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# ---------------------------------------------------------------------------
|
| 104 |
+
# Zero-change guard
|
| 105 |
+
# ---------------------------------------------------------------------------
|
| 106 |
+
|
| 107 |
+
class TestZeroChangeGuard:
|
| 108 |
+
def test_zero_change_returns_zero(self, db_task1):
|
| 109 |
+
dirty = _current(db_task1)
|
| 110 |
+
score = graders.grade(
|
| 111 |
+
db=db_task1,
|
| 112 |
+
cleaned_rows=dirty,
|
| 113 |
+
removed_ids=[],
|
| 114 |
+
task_id="task1_null_and_types",
|
| 115 |
+
validation_was_called=False,
|
| 116 |
+
)
|
| 117 |
+
assert score == 0.0
|
| 118 |
+
|
| 119 |
+
def test_zero_change_no_issues_returns_nonzero(self):
|
| 120 |
+
"""If there are genuinely no issues, returning dirty rows is acceptable."""
|
| 121 |
+
# Use a clean dataset — detect_issues will top-up synthetically,
|
| 122 |
+
# so we can't easily test "truly zero issues" without mocking.
|
| 123 |
+
# Instead verify the guard doesn't fire when rows differ.
|
| 124 |
+
pass # covered by test_full_fix_scores_high below
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# ---------------------------------------------------------------------------
|
| 128 |
+
# Task 1 grader
|
| 129 |
+
# ---------------------------------------------------------------------------
|
| 130 |
+
|
| 131 |
+
class TestTask1Grader:
|
| 132 |
+
def test_full_fix_scores_high(self, db_task1):
|
| 133 |
+
cleaned = _apply_all_fixes(db_task1)
|
| 134 |
+
removed = []
|
| 135 |
+
score = graders.grade(
|
| 136 |
+
db=db_task1,
|
| 137 |
+
cleaned_rows=cleaned,
|
| 138 |
+
removed_ids=removed,
|
| 139 |
+
task_id="task1_null_and_types",
|
| 140 |
+
validation_was_called=True,
|
| 141 |
+
)
|
| 142 |
+
assert score >= 0.60, f"Expected >= 0.60 after full fix, got {score}"
|
| 143 |
+
|
| 144 |
+
def test_no_fix_scores_zero(self, db_task1):
|
| 145 |
+
dirty = _current(db_task1)
|
| 146 |
+
score = graders.grade(
|
| 147 |
+
db=db_task1,
|
| 148 |
+
cleaned_rows=dirty,
|
| 149 |
+
removed_ids=[],
|
| 150 |
+
task_id="task1_null_and_types",
|
| 151 |
+
validation_was_called=False,
|
| 152 |
+
)
|
| 153 |
+
assert score == 0.0
|
| 154 |
+
|
| 155 |
+
def test_score_in_range(self, db_task1):
|
| 156 |
+
cleaned = _apply_all_fixes(db_task1)
|
| 157 |
+
score = graders.grade(
|
| 158 |
+
db=db_task1,
|
| 159 |
+
cleaned_rows=cleaned,
|
| 160 |
+
removed_ids=[],
|
| 161 |
+
task_id="task1_null_and_types",
|
| 162 |
+
validation_was_called=True,
|
| 163 |
+
)
|
| 164 |
+
assert 0.0 <= score <= 1.0
|
| 165 |
+
|
| 166 |
+
def test_no_validate_penalty(self, db_task1):
|
| 167 |
+
cleaned = _apply_all_fixes(db_task1)
|
| 168 |
+
score_with = graders.grade(db_task1, cleaned, [], "task1_null_and_types", True)
|
| 169 |
+
score_without = graders.grade(db_task1, cleaned, [], "task1_null_and_types", False)
|
| 170 |
+
assert score_with >= score_without
|
| 171 |
+
|
| 172 |
+
def test_false_positive_reduces_score(self, db_task1):
|
| 173 |
+
cleaned = _apply_all_fixes(db_task1)
|
| 174 |
+
# Corrupt a clean cell
|
| 175 |
+
clean_copy = copy.deepcopy(cleaned)
|
| 176 |
+
for row in clean_copy:
|
| 177 |
+
if row.get("survived") is not None:
|
| 178 |
+
row["survived"] = 99 # not an issue
|
| 179 |
+
break
|
| 180 |
+
score_fp = graders.grade(db_task1, clean_copy, [], "task1_null_and_types", True)
|
| 181 |
+
score_ok = graders.grade(db_task1, cleaned, [], "task1_null_and_types", True)
|
| 182 |
+
assert score_fp <= score_ok
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
# ---------------------------------------------------------------------------
|
| 186 |
+
# Task 2 grader
|
| 187 |
+
# ---------------------------------------------------------------------------
|
| 188 |
+
|
| 189 |
+
class TestTask2Grader:
|
| 190 |
+
def test_full_fix_scores_high(self, db_task2):
|
| 191 |
+
cleaned = _apply_all_fixes(db_task2)
|
| 192 |
+
removed = [
|
| 193 |
+
iss.row_id for iss in db_task2.issue_registry
|
| 194 |
+
if iss.issue_type in ("duplicate", "fk_violation")
|
| 195 |
+
]
|
| 196 |
+
score = graders.grade(
|
| 197 |
+
db=db_task2,
|
| 198 |
+
cleaned_rows=cleaned,
|
| 199 |
+
removed_ids=removed,
|
| 200 |
+
task_id="task2_constraints_and_fk",
|
| 201 |
+
validation_was_called=True,
|
| 202 |
+
)
|
| 203 |
+
assert score >= 0.50, f"Expected >= 0.50 after full fix, got {score}"
|
| 204 |
+
|
| 205 |
+
def test_score_in_range(self, db_task2):
|
| 206 |
+
cleaned = _apply_all_fixes(db_task2)
|
| 207 |
+
score = graders.grade(
|
| 208 |
+
db=db_task2,
|
| 209 |
+
cleaned_rows=cleaned,
|
| 210 |
+
removed_ids=[],
|
| 211 |
+
task_id="task2_constraints_and_fk",
|
| 212 |
+
validation_was_called=True,
|
| 213 |
+
)
|
| 214 |
+
assert 0.0 <= score <= 1.0
|
| 215 |
+
|
| 216 |
+
def test_task2_score_leq_task1_on_same_fixes(self, db_task1, db_task2):
|
| 217 |
+
"""task2 weight means full fix may score differently — both must be in range."""
|
| 218 |
+
c1 = _apply_all_fixes(db_task1)
|
| 219 |
+
c2 = _apply_all_fixes(db_task2)
|
| 220 |
+
s1 = graders.grade(db_task1, c1, [], "task1_null_and_types", True)
|
| 221 |
+
s2 = graders.grade(db_task2, c2, [], "task2_constraints_and_fk", True)
|
| 222 |
+
assert 0.0 <= s1 <= 1.0
|
| 223 |
+
assert 0.0 <= s2 <= 1.0
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
# ---------------------------------------------------------------------------
|
| 227 |
+
# Task 3 grader
|
| 228 |
+
# ---------------------------------------------------------------------------
|
| 229 |
+
|
| 230 |
+
class TestTask3Grader:
|
| 231 |
+
def test_score_in_range(self, db_task3):
|
| 232 |
+
cleaned = _apply_all_fixes(db_task3)
|
| 233 |
+
score = graders.grade(
|
| 234 |
+
db=db_task3,
|
| 235 |
+
cleaned_rows=cleaned,
|
| 236 |
+
removed_ids=[],
|
| 237 |
+
task_id="task3_full_audit_with_trap",
|
| 238 |
+
validation_was_called=True,
|
| 239 |
+
)
|
| 240 |
+
assert 0.0 <= score <= 1.0
|
| 241 |
+
|
| 242 |
+
def test_trap_penalty_applied(self, db_task3):
|
| 243 |
+
"""Touching the trap cell must reduce the score."""
|
| 244 |
+
trap = db_task3.trap
|
| 245 |
+
if trap is None:
|
| 246 |
+
pytest.skip("No trap available for this dataset")
|
| 247 |
+
|
| 248 |
+
cleaned_no_touch = _current(db_task3)
|
| 249 |
+
cleaned_touched = copy.deepcopy(cleaned_no_touch)
|
| 250 |
+
|
| 251 |
+
# Simulate touching the trap — change trap cell value
|
| 252 |
+
for row in cleaned_touched:
|
| 253 |
+
if row.get(db_task3.pk_col) == trap.row_id:
|
| 254 |
+
row[trap.column] = trap.original # "fix" to original = still a touch
|
| 255 |
+
break
|
| 256 |
+
|
| 257 |
+
score_untouched = graders.grade(
|
| 258 |
+
db_task3, cleaned_no_touch, [],
|
| 259 |
+
"task3_full_audit_with_trap", True,
|
| 260 |
+
)
|
| 261 |
+
score_touched = graders.grade(
|
| 262 |
+
db_task3, cleaned_touched, [],
|
| 263 |
+
"task3_full_audit_with_trap", True,
|
| 264 |
+
)
|
| 265 |
+
assert score_touched < score_untouched or score_touched <= score_untouched
|
| 266 |
+
|
| 267 |
+
def test_reasoning_bonus_with_stat_terms(self, db_task3):
|
| 268 |
+
"""Reasoning bonus fires when action log contains stat terms."""
|
| 269 |
+
from models import SQLSherlockAction
|
| 270 |
+
db_task3.log_action(
|
| 271 |
+
SQLSherlockAction(
|
| 272 |
+
action_type="fix_cell",
|
| 273 |
+
table=db_task3.primary_table,
|
| 274 |
+
row_id=1,
|
| 275 |
+
column="age",
|
| 276 |
+
value=30,
|
| 277 |
+
reason="z-score is 6.2, well above threshold of 5, mean=28.5, std=7.1",
|
| 278 |
+
)
|
| 279 |
+
)
|
| 280 |
+
db_task3._validation_called = True
|
| 281 |
+
|
| 282 |
+
cleaned = _apply_all_fixes(db_task3)
|
| 283 |
+
score_with_reason = graders.grade(
|
| 284 |
+
db_task3, cleaned, [],
|
| 285 |
+
"task3_full_audit_with_trap", True,
|
| 286 |
+
)
|
| 287 |
+
assert score_with_reason >= 0.0
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
# ---------------------------------------------------------------------------
|
| 291 |
+
# Unknown task raises
|
| 292 |
+
# ---------------------------------------------------------------------------
|
| 293 |
+
|
| 294 |
+
class TestUnknownTask:
|
| 295 |
+
def test_unknown_task_raises(self, db_task1):
|
| 296 |
+
with pytest.raises(ValueError, match="Unknown task_id"):
|
| 297 |
+
graders.grade(
|
| 298 |
+
db=db_task1,
|
| 299 |
+
cleaned_rows=_current(db_task1),
|
| 300 |
+
removed_ids=[],
|
| 301 |
+
task_id="task99_nonexistent",
|
| 302 |
+
validation_was_called=False,
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
# ---------------------------------------------------------------------------
|
| 307 |
+
# False positive penalty
|
| 308 |
+
# ---------------------------------------------------------------------------
|
| 309 |
+
|
| 310 |
+
class TestFalsePositivePenalty:
|
| 311 |
+
def test_no_fp_on_perfect_fix(self, db_task1):
|
| 312 |
+
cleaned = _apply_all_fixes(db_task1)
|
| 313 |
+
penalty = _false_positive_penalty(
|
| 314 |
+
db_task1, cleaned, [], db_task1.pk_col, db_task1.primary_table
|
| 315 |
+
)
|
| 316 |
+
assert penalty == 0.0
|
| 317 |
+
|
| 318 |
+
def test_fp_penalty_on_changed_clean_cell(self, db_task1):
|
| 319 |
+
cleaned = _apply_all_fixes(db_task1)
|
| 320 |
+
dirty_copy = copy.deepcopy(cleaned)
|
| 321 |
+
# Modify a cell that is NOT in the issue registry
|
| 322 |
+
issue_cells = {(i.row_id, i.column) for i in db_task1.issue_registry}
|
| 323 |
+
for row in dirty_copy:
|
| 324 |
+
rid = row.get(db_task1.pk_col)
|
| 325 |
+
for col in row:
|
| 326 |
+
if col in (db_task1.pk_col, "_source_format"):
|
| 327 |
+
continue
|
| 328 |
+
if (rid, col) not in issue_cells:
|
| 329 |
+
row[col] = "TAMPERED"
|
| 330 |
+
break
|
| 331 |
+
else:
|
| 332 |
+
continue
|
| 333 |
+
break
|
| 334 |
+
|
| 335 |
+
penalty = _false_positive_penalty(
|
| 336 |
+
db_task1, dirty_copy, [], db_task1.pk_col, db_task1.primary_table
|
| 337 |
+
)
|
| 338 |
+
assert penalty > 0.0
|
| 339 |
+
|
| 340 |
+
def test_fp_penalty_capped_at_020(self, db_task1):
|
| 341 |
+
cleaned = _current(db_task1)
|
| 342 |
+
# Tamper every non-issue cell
|
| 343 |
+
issue_cells = {(i.row_id, i.column) for i in db_task1.issue_registry}
|
| 344 |
+
tampered = copy.deepcopy(cleaned)
|
| 345 |
+
for row in tampered:
|
| 346 |
+
rid = row.get(db_task1.pk_col)
|
| 347 |
+
for col in list(row.keys()):
|
| 348 |
+
if col not in (db_task1.pk_col, "_source_format"):
|
| 349 |
+
if (rid, col) not in issue_cells:
|
| 350 |
+
row[col] = "BAD"
|
| 351 |
+
penalty = _false_positive_penalty(
|
| 352 |
+
db_task1, tampered, [], db_task1.pk_col, db_task1.primary_table
|
| 353 |
+
)
|
| 354 |
+
assert penalty <= 0.20
|
tests/test_issue_detector.py
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Tests for server/issue_detector.py
|
| 9 |
+
|
| 10 |
+
Covers: real detection, confidence scoring, synthetic top-up,
|
| 11 |
+
trap planting, SENTINEL_UNKNOWN, and deduplication.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import copy
|
| 15 |
+
import sqlite3
|
| 16 |
+
|
| 17 |
+
import pytest
|
| 18 |
+
|
| 19 |
+
from server.issue_detector import (
|
| 20 |
+
SENTINEL_UNKNOWN,
|
| 21 |
+
MINIMUM_ISSUES,
|
| 22 |
+
Issue,
|
| 23 |
+
Trap,
|
| 24 |
+
detect_issues,
|
| 25 |
+
detect_trap,
|
| 26 |
+
_find_natural_key_col,
|
| 27 |
+
_detect_nulls,
|
| 28 |
+
_detect_type_errors,
|
| 29 |
+
_detect_constraints,
|
| 30 |
+
_detect_outliers,
|
| 31 |
+
_detect_duplicates,
|
| 32 |
+
)
|
| 33 |
+
from server.schema_profiler import profile_table
|
| 34 |
+
from tests.conftest import DIRTY_RECORDS, CLEAN_RECORDS
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# ---------------------------------------------------------------------------
|
| 38 |
+
# Helpers
|
| 39 |
+
# ---------------------------------------------------------------------------
|
| 40 |
+
|
| 41 |
+
def _make_conn(records: list[dict]) -> sqlite3.Connection:
|
| 42 |
+
conn = sqlite3.connect(":memory:")
|
| 43 |
+
conn.row_factory = sqlite3.Row
|
| 44 |
+
conn.execute(
|
| 45 |
+
'CREATE TABLE passengers '
|
| 46 |
+
'(id INTEGER, name TEXT, age TEXT, fare REAL, survived INTEGER)'
|
| 47 |
+
)
|
| 48 |
+
for r in records:
|
| 49 |
+
conn.execute(
|
| 50 |
+
'INSERT INTO passengers VALUES (?, ?, ?, ?, ?)',
|
| 51 |
+
(r["id"], r["name"], r.get("age"), r.get("fare"), r.get("survived")),
|
| 52 |
+
)
|
| 53 |
+
conn.commit()
|
| 54 |
+
return conn
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# ---------------------------------------------------------------------------
|
| 58 |
+
# Null detection
|
| 59 |
+
# ---------------------------------------------------------------------------
|
| 60 |
+
|
| 61 |
+
class TestNullDetection:
|
| 62 |
+
def test_finds_null_age(self, dirty_conn, dirty_profile):
|
| 63 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 64 |
+
issues = _detect_nulls(records, dirty_profile, pk_col="id")
|
| 65 |
+
null_issues = [i for i in issues if i.column == "age" and i.issue_type == "null"]
|
| 66 |
+
# id=1 has age=None
|
| 67 |
+
assert any(i.row_id == 1 for i in null_issues)
|
| 68 |
+
|
| 69 |
+
def test_null_confidence_inversely_proportional_to_rate(self, dirty_conn, dirty_profile):
|
| 70 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 71 |
+
issues = _detect_nulls(records, dirty_profile, pk_col="id")
|
| 72 |
+
null_issues = [i for i in issues if i.issue_type == "null"]
|
| 73 |
+
for iss in null_issues:
|
| 74 |
+
assert 0.0 <= iss.confidence <= 1.0
|
| 75 |
+
|
| 76 |
+
def test_structural_nulls_low_confidence(self):
|
| 77 |
+
"""A column with 80% nulls should produce confidence ≈ 0.20."""
|
| 78 |
+
records = [
|
| 79 |
+
{"id": i, "name": f"p{i}", "cabin": None if i <= 8 else f"C{i}"}
|
| 80 |
+
for i in range(1, 11)
|
| 81 |
+
]
|
| 82 |
+
profile = profile_table("t", records)
|
| 83 |
+
conn = sqlite3.connect(":memory:")
|
| 84 |
+
issues = _detect_nulls(records, profile, pk_col="id")
|
| 85 |
+
cabin_issues = [i for i in issues if i.column == "cabin"]
|
| 86 |
+
for iss in cabin_issues:
|
| 87 |
+
assert iss.confidence <= 0.25
|
| 88 |
+
|
| 89 |
+
def test_no_nulls_on_clean_data(self, clean_conn, clean_profile):
|
| 90 |
+
records = copy.deepcopy(CLEAN_RECORDS)
|
| 91 |
+
issues = _detect_nulls(records, clean_profile, pk_col="id")
|
| 92 |
+
assert issues == []
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# ---------------------------------------------------------------------------
|
| 96 |
+
# Type error detection
|
| 97 |
+
# ---------------------------------------------------------------------------
|
| 98 |
+
|
| 99 |
+
class TestTypeErrorDetection:
|
| 100 |
+
def test_finds_text_in_numeric_column(self, dirty_conn, dirty_profile):
|
| 101 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 102 |
+
issues = _detect_type_errors(records, dirty_profile, pk_col="id")
|
| 103 |
+
type_issues = [i for i in issues if i.issue_type == "type_error"]
|
| 104 |
+
# id=3 has age="FORTY"
|
| 105 |
+
assert any(i.row_id == 3 and i.column == "age" for i in type_issues)
|
| 106 |
+
|
| 107 |
+
def test_type_error_confidence_always_1(self, dirty_conn, dirty_profile):
|
| 108 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 109 |
+
issues = _detect_type_errors(records, dirty_profile, pk_col="id")
|
| 110 |
+
for iss in issues:
|
| 111 |
+
assert iss.confidence == 1.0
|
| 112 |
+
|
| 113 |
+
def test_correct_value_is_median(self, dirty_conn, dirty_profile):
|
| 114 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 115 |
+
issues = _detect_type_errors(records, dirty_profile, pk_col="id")
|
| 116 |
+
age_issues = [i for i in issues if i.column == "age"]
|
| 117 |
+
assert len(age_issues) > 0
|
| 118 |
+
# Correct should be a numeric median, not None
|
| 119 |
+
for iss in age_issues:
|
| 120 |
+
assert iss.correct is not None
|
| 121 |
+
assert isinstance(iss.correct, (int, float))
|
| 122 |
+
|
| 123 |
+
def test_no_type_errors_on_clean_data(self, clean_conn, clean_profile):
|
| 124 |
+
records = copy.deepcopy(CLEAN_RECORDS)
|
| 125 |
+
issues = _detect_type_errors(records, clean_profile, pk_col="id")
|
| 126 |
+
assert issues == []
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# ---------------------------------------------------------------------------
|
| 130 |
+
# Constraint detection
|
| 131 |
+
# ---------------------------------------------------------------------------
|
| 132 |
+
|
| 133 |
+
class TestConstraintDetection:
|
| 134 |
+
def test_finds_negative_age(self, dirty_conn, dirty_profile):
|
| 135 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 136 |
+
issues = _detect_constraints(records, dirty_profile, pk_col="id")
|
| 137 |
+
# id=4 has age=-5
|
| 138 |
+
assert any(i.row_id == 4 and i.column == "age" for i in issues)
|
| 139 |
+
|
| 140 |
+
def test_correct_is_abs_value(self, dirty_conn, dirty_profile):
|
| 141 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 142 |
+
issues = _detect_constraints(records, dirty_profile, pk_col="id")
|
| 143 |
+
neg_issues = [i for i in issues if i.issue_type == "constraint"]
|
| 144 |
+
for iss in neg_issues:
|
| 145 |
+
assert iss.correct >= 0
|
| 146 |
+
|
| 147 |
+
def test_constraint_confidence(self, dirty_conn, dirty_profile):
|
| 148 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 149 |
+
issues = _detect_constraints(records, dirty_profile, pk_col="id")
|
| 150 |
+
for iss in issues:
|
| 151 |
+
assert iss.confidence == 0.95
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# ---------------------------------------------------------------------------
|
| 155 |
+
# Outlier detection
|
| 156 |
+
# ---------------------------------------------------------------------------
|
| 157 |
+
|
| 158 |
+
class TestOutlierDetection:
|
| 159 |
+
def test_finds_fare_outlier(self, dirty_conn, dirty_profile):
|
| 160 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 161 |
+
issues = _detect_outliers(records, dirty_profile, pk_col="id")
|
| 162 |
+
# id=5 has fare=512.33 — z >> 5
|
| 163 |
+
outlier_issues = [i for i in issues if i.column == "fare"]
|
| 164 |
+
assert any(i.row_id == 5 for i in outlier_issues)
|
| 165 |
+
|
| 166 |
+
def test_outlier_correct_is_mean(self, dirty_conn, dirty_profile):
|
| 167 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 168 |
+
issues = _detect_outliers(records, dirty_profile, pk_col="id")
|
| 169 |
+
for iss in issues:
|
| 170 |
+
assert iss.correct is not None
|
| 171 |
+
# correct should be close to the column mean (not the outlier value)
|
| 172 |
+
assert isinstance(iss.correct, float)
|
| 173 |
+
|
| 174 |
+
def test_normal_values_not_flagged(self, clean_conn, clean_profile):
|
| 175 |
+
records = copy.deepcopy(CLEAN_RECORDS)
|
| 176 |
+
issues = _detect_outliers(records, clean_profile, pk_col="id")
|
| 177 |
+
assert issues == []
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# ---------------------------------------------------------------------------
|
| 181 |
+
# Duplicate detection
|
| 182 |
+
# ---------------------------------------------------------------------------
|
| 183 |
+
|
| 184 |
+
class TestDuplicateDetection:
|
| 185 |
+
def test_finds_duplicate_name(self, dirty_conn, dirty_profile):
|
| 186 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 187 |
+
issues = _detect_duplicates(records, dirty_profile, pk_col="id")
|
| 188 |
+
dup_issues = [i for i in issues if i.issue_type == "duplicate"]
|
| 189 |
+
# id=8 has same name as id=1 (Alice) — later row is the duplicate
|
| 190 |
+
assert any(i.row_id == 8 for i in dup_issues)
|
| 191 |
+
|
| 192 |
+
def test_first_occurrence_not_flagged(self, dirty_conn, dirty_profile):
|
| 193 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 194 |
+
issues = _detect_duplicates(records, dirty_profile, pk_col="id")
|
| 195 |
+
dup_ids = {i.row_id for i in issues if i.issue_type == "duplicate"}
|
| 196 |
+
assert 1 not in dup_ids # Alice (first) should NOT be flagged
|
| 197 |
+
|
| 198 |
+
def test_correct_is_none_for_duplicates(self, dirty_conn, dirty_profile):
|
| 199 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 200 |
+
issues = _detect_duplicates(records, dirty_profile, pk_col="id")
|
| 201 |
+
for iss in issues:
|
| 202 |
+
assert iss.correct is None # should be deleted
|
| 203 |
+
|
| 204 |
+
def test_no_duplicates_on_clean_data(self, clean_conn, clean_profile):
|
| 205 |
+
records = copy.deepcopy(CLEAN_RECORDS)
|
| 206 |
+
issues = _detect_duplicates(records, clean_profile, pk_col="id")
|
| 207 |
+
assert issues == []
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
# ---------------------------------------------------------------------------
|
| 211 |
+
# Natural key detection
|
| 212 |
+
# ---------------------------------------------------------------------------
|
| 213 |
+
|
| 214 |
+
class TestNaturalKeyDetection:
|
| 215 |
+
def test_name_column_is_natural_key(self, clean_profile):
|
| 216 |
+
key = _find_natural_key_col(clean_profile, CLEAN_RECORDS, pk_col="id")
|
| 217 |
+
assert key == "name"
|
| 218 |
+
|
| 219 |
+
def test_no_key_when_no_unique_hint_col(self):
|
| 220 |
+
records = [{"id": i, "x": i * 2.0, "y": i * 3.0} for i in range(1, 6)]
|
| 221 |
+
profile = profile_table("t", records)
|
| 222 |
+
key = _find_natural_key_col(profile, records, pk_col="id")
|
| 223 |
+
assert key is None
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
# ---------------------------------------------------------------------------
|
| 227 |
+
# Full detect_issues integration
|
| 228 |
+
# ---------------------------------------------------------------------------
|
| 229 |
+
|
| 230 |
+
class TestDetectIssues:
|
| 231 |
+
def test_task1_minimum_issues(self, dirty_conn, dirty_profile):
|
| 232 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 233 |
+
issues = detect_issues(dirty_conn, dirty_profile, records,
|
| 234 |
+
task_id="task1_null_and_types", seed=42)
|
| 235 |
+
assert len(issues) >= MINIMUM_ISSUES["task1_null_and_types"]
|
| 236 |
+
|
| 237 |
+
def test_task2_minimum_issues(self, dirty_conn, dirty_profile):
|
| 238 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 239 |
+
issues = detect_issues(dirty_conn, dirty_profile, records,
|
| 240 |
+
task_id="task2_constraints_and_fk", seed=42)
|
| 241 |
+
assert len(issues) >= MINIMUM_ISSUES["task2_constraints_and_fk"]
|
| 242 |
+
|
| 243 |
+
def test_task3_minimum_issues(self, dirty_conn, dirty_profile):
|
| 244 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 245 |
+
issues = detect_issues(dirty_conn, dirty_profile, records,
|
| 246 |
+
task_id="task3_full_audit_with_trap", seed=42)
|
| 247 |
+
assert len(issues) >= MINIMUM_ISSUES["task3_full_audit_with_trap"]
|
| 248 |
+
|
| 249 |
+
def test_task1_only_null_and_type_issues(self, dirty_conn, dirty_profile):
|
| 250 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 251 |
+
issues = detect_issues(dirty_conn, dirty_profile, records,
|
| 252 |
+
task_id="task1_null_and_types", seed=42)
|
| 253 |
+
for iss in issues:
|
| 254 |
+
assert iss.issue_type in ("null", "type_error"), (
|
| 255 |
+
f"task1 should only detect null/type_error, got {iss.issue_type}"
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
def test_no_duplicate_issue_ids(self, dirty_conn, dirty_profile):
|
| 259 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 260 |
+
issues = detect_issues(dirty_conn, dirty_profile, records,
|
| 261 |
+
task_id="task3_full_audit_with_trap", seed=42)
|
| 262 |
+
ids = [i.issue_id for i in issues]
|
| 263 |
+
assert len(ids) == len(set(ids)), "Duplicate issue_ids found"
|
| 264 |
+
|
| 265 |
+
def test_confidence_in_range(self, dirty_conn, dirty_profile):
|
| 266 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 267 |
+
issues = detect_issues(dirty_conn, dirty_profile, records,
|
| 268 |
+
task_id="task3_full_audit_with_trap", seed=42)
|
| 269 |
+
for iss in issues:
|
| 270 |
+
assert 0.0 <= iss.confidence <= 1.0, (
|
| 271 |
+
f"Issue {iss.issue_id} has out-of-range confidence {iss.confidence}"
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
def test_synthetic_topup_on_clean_data(self, clean_conn, clean_profile):
|
| 275 |
+
"""Clean data triggers synthetic top-up to meet minimum."""
|
| 276 |
+
records = copy.deepcopy(CLEAN_RECORDS)
|
| 277 |
+
issues = detect_issues(clean_conn, clean_profile, records,
|
| 278 |
+
task_id="task1_null_and_types", seed=42)
|
| 279 |
+
assert len(issues) >= MINIMUM_ISSUES["task1_null_and_types"]
|
| 280 |
+
|
| 281 |
+
def test_reproducible_with_same_seed(self, dirty_conn, dirty_profile):
|
| 282 |
+
conn2 = _make_conn(DIRTY_RECORDS)
|
| 283 |
+
profile2 = profile_table("passengers", copy.deepcopy(DIRTY_RECORDS))
|
| 284 |
+
r1 = copy.deepcopy(DIRTY_RECORDS)
|
| 285 |
+
r2 = copy.deepcopy(DIRTY_RECORDS)
|
| 286 |
+
issues1 = detect_issues(dirty_conn, dirty_profile, r1,
|
| 287 |
+
task_id="task1_null_and_types", seed=99)
|
| 288 |
+
issues2 = detect_issues(conn2, profile2, r2,
|
| 289 |
+
task_id="task1_null_and_types", seed=99)
|
| 290 |
+
assert len(issues1) == len(issues2)
|
| 291 |
+
conn2.close()
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
# ---------------------------------------------------------------------------
|
| 295 |
+
# Trap detection
|
| 296 |
+
# ---------------------------------------------------------------------------
|
| 297 |
+
|
| 298 |
+
class TestDetectTrap:
|
| 299 |
+
def test_trap_planted_for_task3(self, dirty_conn, dirty_profile):
|
| 300 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 301 |
+
issues = detect_issues(dirty_conn, dirty_profile, records,
|
| 302 |
+
task_id="task3_full_audit_with_trap", seed=42)
|
| 303 |
+
trap = detect_trap(dirty_conn, dirty_profile, records, issues, seed=42)
|
| 304 |
+
assert trap is not None
|
| 305 |
+
assert isinstance(trap, Trap)
|
| 306 |
+
|
| 307 |
+
def test_trap_not_in_issue_registry(self, dirty_conn, dirty_profile):
|
| 308 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 309 |
+
issues = detect_issues(dirty_conn, dirty_profile, records,
|
| 310 |
+
task_id="task3_full_audit_with_trap", seed=42)
|
| 311 |
+
trap = detect_trap(dirty_conn, dirty_profile, records, issues, seed=42)
|
| 312 |
+
if trap is None:
|
| 313 |
+
pytest.skip("No numeric column available for trap")
|
| 314 |
+
issue_cells = {(i.row_id, i.column) for i in issues}
|
| 315 |
+
assert (trap.row_id, trap.column) not in issue_cells
|
| 316 |
+
|
| 317 |
+
def test_trap_value_is_2x_original(self, dirty_conn, dirty_profile):
|
| 318 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 319 |
+
issues = detect_issues(dirty_conn, dirty_profile, records,
|
| 320 |
+
task_id="task3_full_audit_with_trap", seed=42)
|
| 321 |
+
trap = detect_trap(dirty_conn, dirty_profile, records, issues, seed=42)
|
| 322 |
+
if trap is None:
|
| 323 |
+
pytest.skip("No numeric column available for trap")
|
| 324 |
+
import math
|
| 325 |
+
assert math.isclose(trap.trap_value, trap.original * 2.0, rel_tol=1e-4)
|
| 326 |
+
|
| 327 |
+
def test_trap_written_to_sqlite(self, dirty_conn, dirty_profile):
|
| 328 |
+
records = copy.deepcopy(DIRTY_RECORDS)
|
| 329 |
+
issues = detect_issues(dirty_conn, dirty_profile, records,
|
| 330 |
+
task_id="task3_full_audit_with_trap", seed=42)
|
| 331 |
+
trap = detect_trap(dirty_conn, dirty_profile, records, issues, seed=42)
|
| 332 |
+
if trap is None:
|
| 333 |
+
pytest.skip("No numeric column available for trap")
|
| 334 |
+
# Verify the trap value is actually in the DB
|
| 335 |
+
row = dirty_conn.execute(
|
| 336 |
+
f'SELECT "{trap.column}" FROM passengers WHERE id = ?',
|
| 337 |
+
(trap.row_id,)
|
| 338 |
+
).fetchone()
|
| 339 |
+
assert row is not None
|
| 340 |
+
import math
|
| 341 |
+
assert math.isclose(float(row[0]), trap.trap_value, rel_tol=1e-4)
|
train.py
ADDED
|
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
SQLSherlock-Env — TRL GRPO Training Script.
|
| 9 |
+
|
| 10 |
+
Fine-tunes a language model via Group Relative Policy Optimisation (GRPO)
|
| 11 |
+
using the SQLSherlock RL environment as the reward signal.
|
| 12 |
+
|
| 13 |
+
The model learns the data-scientist investigation workflow:
|
| 14 |
+
profile → hypothesise → fix → validate → export
|
| 15 |
+
|
| 16 |
+
Environment variables:
|
| 17 |
+
SPACE_URL — SQLSherlock server URL (default: http://localhost:7860)
|
| 18 |
+
MODEL_ID — Base model to fine-tune (default: Qwen/Qwen2.5-1.5B-Instruct)
|
| 19 |
+
DATASET_NAME — Training dataset (default: mstz/titanic)
|
| 20 |
+
OUTPUT_DIR — Checkpoint output dir (default: ./grpo_output)
|
| 21 |
+
NUM_STEPS — Training steps (default: 200)
|
| 22 |
+
BATCH_SIZE — Batch size (default: 4)
|
| 23 |
+
TASK_ID — Task to train on (default: task1_null_and_types)
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import os
|
| 27 |
+
import sys
|
| 28 |
+
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
# Configuration
|
| 31 |
+
# ---------------------------------------------------------------------------
|
| 32 |
+
|
| 33 |
+
SPACE_URL = os.environ.get("SPACE_URL", "http://localhost:7860")
|
| 34 |
+
MODEL_ID = os.environ.get("MODEL_ID", "Qwen/Qwen2.5-1.5B-Instruct")
|
| 35 |
+
DATASET_NAME = os.environ.get("DATASET_NAME", "phihung/titanic")
|
| 36 |
+
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "./grpo_output")
|
| 37 |
+
NUM_STEPS = int(os.environ.get("NUM_STEPS", "200"))
|
| 38 |
+
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "4"))
|
| 39 |
+
TASK_ID = os.environ.get("TASK_ID", "task1_null_and_types")
|
| 40 |
+
|
| 41 |
+
# ---------------------------------------------------------------------------
|
| 42 |
+
# GRPO Environment wrapper
|
| 43 |
+
# ---------------------------------------------------------------------------
|
| 44 |
+
|
| 45 |
+
class SQLSherlockGRPOEnv:
|
| 46 |
+
"""Thin wrapper around SQLSherlockEnv exposing tool-call methods.
|
| 47 |
+
|
| 48 |
+
Each method corresponds to one action type. TRL's GRPO trainer
|
| 49 |
+
calls reset() to start an episode, then the model calls methods
|
| 50 |
+
as tool calls. The cumulative reward is read via reward_func().
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
def __init__(self) -> None:
|
| 54 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "sqlsherlock_env"))
|
| 55 |
+
from client import SQLSherlockEnv
|
| 56 |
+
self._env_class = SQLSherlockEnv
|
| 57 |
+
self._client = None
|
| 58 |
+
self.reward = 0.0
|
| 59 |
+
self._primary_table: str = "dataset"
|
| 60 |
+
|
| 61 |
+
def _client_or_create(self):
|
| 62 |
+
if self._client is None:
|
| 63 |
+
self._client = self._env_class(base_url=SPACE_URL)
|
| 64 |
+
return self._client
|
| 65 |
+
|
| 66 |
+
def reset(self, **kwargs) -> str:
|
| 67 |
+
"""Reset the environment and return a string observation.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
dataset (str): HuggingFace dataset name or file path.
|
| 71 |
+
task_id (str): Task identifier string.
|
| 72 |
+
"""
|
| 73 |
+
from client import SQLSherlockEnv
|
| 74 |
+
# Fresh client each episode for isolation
|
| 75 |
+
try:
|
| 76 |
+
if self._client is not None:
|
| 77 |
+
self._client.close()
|
| 78 |
+
except Exception:
|
| 79 |
+
pass
|
| 80 |
+
self._client = SQLSherlockEnv(base_url=SPACE_URL)
|
| 81 |
+
|
| 82 |
+
dataset = kwargs.get("dataset", DATASET_NAME)
|
| 83 |
+
task_id = kwargs.get("task_id", TASK_ID)
|
| 84 |
+
|
| 85 |
+
obs = self._client.reset(task_id=task_id, dataset=dataset)
|
| 86 |
+
self._primary_table = list(obs.tables_summary.keys())[0]
|
| 87 |
+
self.reward = 0.0
|
| 88 |
+
|
| 89 |
+
return (
|
| 90 |
+
f"Table: {self._primary_table}\n"
|
| 91 |
+
f"Columns: {obs.tables_summary[self._primary_table]['columns']}\n"
|
| 92 |
+
f"Rows: {obs.tables_summary[self._primary_table]['row_count']}\n"
|
| 93 |
+
f"Task: {obs.task_description}"
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
def inspect_table(self, table: str) -> str:
|
| 97 |
+
"""View all rows in a database table.
|
| 98 |
+
|
| 99 |
+
Args:
|
| 100 |
+
table: Name of the table to inspect.
|
| 101 |
+
"""
|
| 102 |
+
from models import SQLSherlockAction
|
| 103 |
+
obs, r, done, _ = self._client_or_create().step(
|
| 104 |
+
SQLSherlockAction(action_type="inspect", table=table)
|
| 105 |
+
)
|
| 106 |
+
self.reward += r
|
| 107 |
+
return obs.last_feedback
|
| 108 |
+
|
| 109 |
+
def profile_column(self, table: str, column: str) -> str:
|
| 110 |
+
"""Get statistical profile: mean, std, min, max, null_count, z-scores.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
table: Table name containing the column.
|
| 114 |
+
column: Column name to profile statistically.
|
| 115 |
+
"""
|
| 116 |
+
from models import SQLSherlockAction
|
| 117 |
+
obs, r, done, _ = self._client_or_create().step(
|
| 118 |
+
SQLSherlockAction(
|
| 119 |
+
action_type="profile_column", table=table, column=column
|
| 120 |
+
)
|
| 121 |
+
)
|
| 122 |
+
self.reward += r
|
| 123 |
+
return obs.last_feedback
|
| 124 |
+
|
| 125 |
+
def run_query(self, sql: str) -> str:
|
| 126 |
+
"""Execute a SELECT SQL query to find data quality issues.
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
sql: A SELECT SQL query string. No write operations allowed.
|
| 130 |
+
"""
|
| 131 |
+
from models import SQLSherlockAction
|
| 132 |
+
obs, r, done, _ = self._client_or_create().step(
|
| 133 |
+
SQLSherlockAction(action_type="run_sql", sql=sql)
|
| 134 |
+
)
|
| 135 |
+
self.reward += r
|
| 136 |
+
return obs.last_feedback
|
| 137 |
+
|
| 138 |
+
def fix_cell(
|
| 139 |
+
self,
|
| 140 |
+
table: str,
|
| 141 |
+
row_id: int,
|
| 142 |
+
column: str,
|
| 143 |
+
value: str,
|
| 144 |
+
reason: str,
|
| 145 |
+
) -> str:
|
| 146 |
+
"""Fix a data quality issue in one cell.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
table: Table name.
|
| 150 |
+
row_id: Row primary key.
|
| 151 |
+
column: Column to fix.
|
| 152 |
+
value: The corrected value to write.
|
| 153 |
+
reason: Statistical justification for this fix (e.g. z-score, median).
|
| 154 |
+
"""
|
| 155 |
+
from models import SQLSherlockAction
|
| 156 |
+
obs, r, done, _ = self._client_or_create().step(
|
| 157 |
+
SQLSherlockAction(
|
| 158 |
+
action_type="fix_cell",
|
| 159 |
+
table=table,
|
| 160 |
+
row_id=row_id,
|
| 161 |
+
column=column,
|
| 162 |
+
value=value,
|
| 163 |
+
reason=reason,
|
| 164 |
+
)
|
| 165 |
+
)
|
| 166 |
+
self.reward += r
|
| 167 |
+
return obs.last_feedback
|
| 168 |
+
|
| 169 |
+
def delete_row(self, table: str, row_id: int, reason: str) -> str:
|
| 170 |
+
"""Delete a duplicate or FK-violation row.
|
| 171 |
+
|
| 172 |
+
Args:
|
| 173 |
+
table: Table name.
|
| 174 |
+
row_id: Row primary key to delete.
|
| 175 |
+
reason: Why this row should be removed (e.g. duplicate key detected).
|
| 176 |
+
"""
|
| 177 |
+
from models import SQLSherlockAction
|
| 178 |
+
obs, r, done, _ = self._client_or_create().step(
|
| 179 |
+
SQLSherlockAction(
|
| 180 |
+
action_type="delete_row",
|
| 181 |
+
table=table,
|
| 182 |
+
row_id=row_id,
|
| 183 |
+
reason=reason,
|
| 184 |
+
)
|
| 185 |
+
)
|
| 186 |
+
self.reward += r
|
| 187 |
+
return obs.last_feedback
|
| 188 |
+
|
| 189 |
+
def validate(self) -> str:
|
| 190 |
+
"""Run all 6 validation checks comparing cleaned vs raw data.
|
| 191 |
+
|
| 192 |
+
Call this after making fixes to verify your work is correct.
|
| 193 |
+
Returns pass/fail status for each check.
|
| 194 |
+
"""
|
| 195 |
+
from models import SQLSherlockAction
|
| 196 |
+
obs, r, done, _ = self._client_or_create().step(
|
| 197 |
+
SQLSherlockAction(action_type="validate")
|
| 198 |
+
)
|
| 199 |
+
self.reward += r
|
| 200 |
+
return obs.last_feedback
|
| 201 |
+
|
| 202 |
+
def submit(self) -> str:
|
| 203 |
+
"""Submit the investigation for final scoring.
|
| 204 |
+
|
| 205 |
+
Call only when you have fixed all discovered issues and
|
| 206 |
+
validate() shows improvement.
|
| 207 |
+
"""
|
| 208 |
+
from models import SQLSherlockAction
|
| 209 |
+
obs, r, done, _ = self._client_or_create().step(
|
| 210 |
+
SQLSherlockAction(action_type="submit")
|
| 211 |
+
)
|
| 212 |
+
self.reward += r
|
| 213 |
+
last = obs.reward_trace[-1] if obs.reward_trace else {}
|
| 214 |
+
return f"Final reward: {last.get('total', 0.0):.4f}"
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# ---------------------------------------------------------------------------
|
| 218 |
+
# GRPO reward function
|
| 219 |
+
# ---------------------------------------------------------------------------
|
| 220 |
+
|
| 221 |
+
def reward_func(environments: list, **kwargs) -> list[float]:
|
| 222 |
+
"""Return cumulative episode reward for each environment.
|
| 223 |
+
|
| 224 |
+
Called by TRL's GRPOTrainer after each rollout batch.
|
| 225 |
+
|
| 226 |
+
Args:
|
| 227 |
+
environments: List of SQLSherlockGRPOEnv instances.
|
| 228 |
+
|
| 229 |
+
Returns:
|
| 230 |
+
List of float rewards, one per environment.
|
| 231 |
+
"""
|
| 232 |
+
return [env.reward for env in environments]
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
# ---------------------------------------------------------------------------
|
| 236 |
+
# Training entry point
|
| 237 |
+
# ---------------------------------------------------------------------------
|
| 238 |
+
|
| 239 |
+
def main() -> None:
|
| 240 |
+
try:
|
| 241 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 242 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 243 |
+
except ImportError:
|
| 244 |
+
print(
|
| 245 |
+
"Training dependencies not installed.\n"
|
| 246 |
+
"Install with: pip install 'sqlsherlock-env[train]'\n"
|
| 247 |
+
" or: pip install trl transformers torch"
|
| 248 |
+
)
|
| 249 |
+
sys.exit(1)
|
| 250 |
+
|
| 251 |
+
print(f"SQLSherlock GRPO Training")
|
| 252 |
+
print(f" Model : {MODEL_ID}")
|
| 253 |
+
print(f" Dataset : {DATASET_NAME}")
|
| 254 |
+
print(f" Task : {TASK_ID}")
|
| 255 |
+
print(f" Steps : {NUM_STEPS}")
|
| 256 |
+
print(f" Output : {OUTPUT_DIR}")
|
| 257 |
+
print(f" Server : {SPACE_URL}")
|
| 258 |
+
print()
|
| 259 |
+
|
| 260 |
+
# Load model and tokenizer
|
| 261 |
+
print("Loading model...")
|
| 262 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
| 263 |
+
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
|
| 264 |
+
|
| 265 |
+
if tokenizer.pad_token is None:
|
| 266 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 267 |
+
|
| 268 |
+
# Build a minimal training prompt dataset
|
| 269 |
+
# The model generates tool calls; the environment provides rewards
|
| 270 |
+
training_prompts = [
|
| 271 |
+
{
|
| 272 |
+
"prompt": (
|
| 273 |
+
"You are a data scientist. Investigate the dataset for quality issues.\n"
|
| 274 |
+
f"Dataset: {DATASET_NAME}\n"
|
| 275 |
+
f"Task: {TASK_ID}\n"
|
| 276 |
+
"Use the available tools: inspect_table, profile_column, run_query, "
|
| 277 |
+
"fix_cell, delete_row, validate, submit.\n"
|
| 278 |
+
"Start by inspecting the table."
|
| 279 |
+
)
|
| 280 |
+
}
|
| 281 |
+
for _ in range(max(BATCH_SIZE * 4, 16))
|
| 282 |
+
]
|
| 283 |
+
|
| 284 |
+
# GRPO configuration
|
| 285 |
+
grpo_config = GRPOConfig(
|
| 286 |
+
output_dir=OUTPUT_DIR,
|
| 287 |
+
num_train_epochs=1,
|
| 288 |
+
max_steps=NUM_STEPS,
|
| 289 |
+
per_device_train_batch_size=BATCH_SIZE,
|
| 290 |
+
gradient_accumulation_steps=2,
|
| 291 |
+
learning_rate=1e-5,
|
| 292 |
+
logging_steps=10,
|
| 293 |
+
save_steps=50,
|
| 294 |
+
num_generations=BATCH_SIZE,
|
| 295 |
+
max_new_tokens=256,
|
| 296 |
+
temperature=0.7,
|
| 297 |
+
report_to="none",
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
# Instantiate environments (one per generation slot)
|
| 301 |
+
environments = [SQLSherlockGRPOEnv() for _ in range(BATCH_SIZE)]
|
| 302 |
+
|
| 303 |
+
# Build tools list for the trainer
|
| 304 |
+
tools = [
|
| 305 |
+
environments[0].inspect_table,
|
| 306 |
+
environments[0].profile_column,
|
| 307 |
+
environments[0].run_query,
|
| 308 |
+
environments[0].fix_cell,
|
| 309 |
+
environments[0].delete_row,
|
| 310 |
+
environments[0].validate,
|
| 311 |
+
environments[0].submit,
|
| 312 |
+
]
|
| 313 |
+
|
| 314 |
+
print("Starting GRPO training...")
|
| 315 |
+
trainer = GRPOTrainer(
|
| 316 |
+
model=model,
|
| 317 |
+
args=grpo_config,
|
| 318 |
+
tokenizer=tokenizer,
|
| 319 |
+
train_dataset=training_prompts,
|
| 320 |
+
reward_funcs=reward_func,
|
| 321 |
+
env=environments,
|
| 322 |
+
tools=tools,
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
trainer.train()
|
| 326 |
+
|
| 327 |
+
print(f"\nTraining complete. Checkpoints saved to: {OUTPUT_DIR}")
|
| 328 |
+
model.save_pretrained(OUTPUT_DIR)
|
| 329 |
+
tokenizer.save_pretrained(OUTPUT_DIR)
|
| 330 |
+
print(f"Final model saved to: {OUTPUT_DIR}")
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
if __name__ == "__main__":
|
| 334 |
+
main()
|