Spaces:
Sleeping
Sleeping
Commit ·
4551d9d
1
Parent(s): 3018ee0
minor fixes
Browse files- README.md +1 -1
- inference.py +4 -4
- openenv.yaml +3 -3
- sqlsherlock_env/server/environment.py +6 -17
- tests/test_environment.py +13 -10
README.md
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
title: SQLSherlock Env
|
| 3 |
emoji: π
|
| 4 |
colorFrom: indigo
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
app_port: 7860
|
| 8 |
tags:
|
|
|
|
| 2 |
title: SQLSherlock Env
|
| 3 |
emoji: π
|
| 4 |
colorFrom: indigo
|
| 5 |
+
colorTo: blue
|
| 6 |
sdk: docker
|
| 7 |
app_port: 7860
|
| 8 |
tags:
|
inference.py
CHANGED
|
@@ -45,11 +45,11 @@ SPACE_URL = os.getenv("SPACE_URL", "http://localhost:7860")
|
|
| 45 |
# Optional — if you use from_docker_image():
|
| 46 |
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
|
| 47 |
|
| 48 |
-
#
|
| 49 |
STEP_BUDGETS: dict[str, int] = {
|
| 50 |
-
"task1_null_and_types":
|
| 51 |
-
"task2_constraints_and_fk":
|
| 52 |
-
"task3_full_audit_with_trap":
|
| 53 |
}
|
| 54 |
|
| 55 |
TASKS = [
|
|
|
|
| 45 |
# Optional — if you use from_docker_image():
|
| 46 |
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
|
| 47 |
|
| 48 |
+
# Full environment max_steps — agent gets maximum room to clean
|
| 49 |
STEP_BUDGETS: dict[str, int] = {
|
| 50 |
+
"task1_null_and_types": 30, # env max_steps = 30
|
| 51 |
+
"task2_constraints_and_fk": 40, # env max_steps = 40
|
| 52 |
+
"task3_full_audit_with_trap": 50, # env max_steps = 50
|
| 53 |
}
|
| 54 |
|
| 55 |
TASKS = [
|
openenv.yaml
CHANGED
|
@@ -24,7 +24,7 @@ tasks:
|
|
| 24 |
- id: task1_null_and_types
|
| 25 |
name: "Null and type error repair"
|
| 26 |
difficulty: easy
|
| 27 |
-
max_steps:
|
| 28 |
description: >
|
| 29 |
Find and fix null values and type errors in the primary table.
|
| 30 |
Profile columns, identify anomalies, fix with reasoning,
|
|
@@ -33,7 +33,7 @@ tasks:
|
|
| 33 |
- id: task2_constraints_and_fk
|
| 34 |
name: "Constraint and FK integrity"
|
| 35 |
difficulty: medium
|
| 36 |
-
max_steps:
|
| 37 |
description: >
|
| 38 |
Everything in Task 1 plus constraint violations
|
| 39 |
(negative values in must-be-positive columns) and FK
|
|
@@ -42,7 +42,7 @@ tasks:
|
|
| 42 |
- id: task3_full_audit_with_trap
|
| 43 |
name: "Full statistical audit with trap"
|
| 44 |
difficulty: hard
|
| 45 |
-
max_steps:
|
| 46 |
description: >
|
| 47 |
Full audit including statistical outliers. TRAP WARNING:
|
| 48 |
one numeric value looks suspicious but is legitimate.
|
|
|
|
| 24 |
- id: task1_null_and_types
|
| 25 |
name: "Null and type error repair"
|
| 26 |
difficulty: easy
|
| 27 |
+
max_steps: 30
|
| 28 |
description: >
|
| 29 |
Find and fix null values and type errors in the primary table.
|
| 30 |
Profile columns, identify anomalies, fix with reasoning,
|
|
|
|
| 33 |
- id: task2_constraints_and_fk
|
| 34 |
name: "Constraint and FK integrity"
|
| 35 |
difficulty: medium
|
| 36 |
+
max_steps: 40
|
| 37 |
description: >
|
| 38 |
Everything in Task 1 plus constraint violations
|
| 39 |
(negative values in must-be-positive columns) and FK
|
|
|
|
| 42 |
- id: task3_full_audit_with_trap
|
| 43 |
name: "Full statistical audit with trap"
|
| 44 |
difficulty: hard
|
| 45 |
+
max_steps: 50
|
| 46 |
description: >
|
| 47 |
Full audit including statistical outliers. TRAP WARNING:
|
| 48 |
one numeric value looks suspicious but is legitimate.
|
sqlsherlock_env/server/environment.py
CHANGED
|
@@ -33,7 +33,7 @@ TASKS: list[dict] = [
|
|
| 33 |
"id": "task1_null_and_types",
|
| 34 |
"name": "Null and type error repair",
|
| 35 |
"difficulty": "easy",
|
| 36 |
-
"max_steps":
|
| 37 |
"description": (
|
| 38 |
"Find and fix null values and type errors in the primary table. "
|
| 39 |
"Profile columns, identify anomalies, fix with reasoning, "
|
|
@@ -44,7 +44,7 @@ TASKS: list[dict] = [
|
|
| 44 |
"id": "task2_constraints_and_fk",
|
| 45 |
"name": "Constraint and FK integrity",
|
| 46 |
"difficulty": "medium",
|
| 47 |
-
"max_steps":
|
| 48 |
"description": (
|
| 49 |
"Everything in Task 1 plus constraint violations "
|
| 50 |
"(negative values in must-be-positive columns) and FK "
|
|
@@ -55,7 +55,7 @@ TASKS: list[dict] = [
|
|
| 55 |
"id": "task3_full_audit_with_trap",
|
| 56 |
"name": "Full statistical audit with trap",
|
| 57 |
"difficulty": "hard",
|
| 58 |
-
"max_steps":
|
| 59 |
"description": (
|
| 60 |
"Full audit including statistical outliers. TRAP WARNING: "
|
| 61 |
"one numeric value looks suspicious but is legitimate. "
|
|
@@ -100,21 +100,10 @@ class SQLSherlockEnvironment(Environment):
|
|
| 100 |
Raises:
|
| 101 |
ValueError: If dataset or task_id is missing/invalid.
|
| 102 |
"""
|
| 103 |
-
dataset
|
| 104 |
-
task_id
|
| 105 |
-
seed
|
| 106 |
max_rows = int(kwargs.get("max_rows", 500))
|
| 107 |
-
|
| 108 |
-
if not dataset or not dataset.strip():
|
| 109 |
-
raise ValueError(
|
| 110 |
-
"reset() requires 'dataset' keyword argument. "
|
| 111 |
-
"Provide a file path, HuggingFace dataset name, or raw CSV text."
|
| 112 |
-
)
|
| 113 |
-
if not task_id or not task_id.strip():
|
| 114 |
-
raise ValueError(
|
| 115 |
-
"reset() requires 'task_id' keyword argument. "
|
| 116 |
-
f"Valid tasks: {sorted(_TASK_MAP.keys())}"
|
| 117 |
-
)
|
| 118 |
if task_id not in _TASK_MAP:
|
| 119 |
raise ValueError(
|
| 120 |
f"Unknown task_id '{task_id}'. "
|
|
|
|
| 33 |
"id": "task1_null_and_types",
|
| 34 |
"name": "Null and type error repair",
|
| 35 |
"difficulty": "easy",
|
| 36 |
+
"max_steps": 30,
|
| 37 |
"description": (
|
| 38 |
"Find and fix null values and type errors in the primary table. "
|
| 39 |
"Profile columns, identify anomalies, fix with reasoning, "
|
|
|
|
| 44 |
"id": "task2_constraints_and_fk",
|
| 45 |
"name": "Constraint and FK integrity",
|
| 46 |
"difficulty": "medium",
|
| 47 |
+
"max_steps": 40,
|
| 48 |
"description": (
|
| 49 |
"Everything in Task 1 plus constraint violations "
|
| 50 |
"(negative values in must-be-positive columns) and FK "
|
|
|
|
| 55 |
"id": "task3_full_audit_with_trap",
|
| 56 |
"name": "Full statistical audit with trap",
|
| 57 |
"difficulty": "hard",
|
| 58 |
+
"max_steps": 50,
|
| 59 |
"description": (
|
| 60 |
"Full audit including statistical outliers. TRAP WARNING: "
|
| 61 |
"one numeric value looks suspicious but is legitimate. "
|
|
|
|
| 100 |
Raises:
|
| 101 |
ValueError: If dataset or task_id is missing/invalid.
|
| 102 |
"""
|
| 103 |
+
dataset = kwargs.get("dataset", "") or "phihung/titanic"
|
| 104 |
+
task_id = kwargs.get("task_id", "") or "task1_null_and_types"
|
| 105 |
+
seed = int(kwargs.get("seed", 42))
|
| 106 |
max_rows = int(kwargs.get("max_rows", 500))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
if task_id not in _TASK_MAP:
|
| 108 |
raise ValueError(
|
| 109 |
f"Unknown task_id '{task_id}'. "
|
tests/test_environment.py
CHANGED
|
@@ -73,9 +73,9 @@ class TestTasksCatalogue:
|
|
| 73 |
|
| 74 |
def test_max_steps_values(self):
|
| 75 |
step_map = {t["id"]: t["max_steps"] for t in TASKS}
|
| 76 |
-
assert step_map["task1_null_and_types"] ==
|
| 77 |
-
assert step_map["task2_constraints_and_fk"] ==
|
| 78 |
-
assert step_map["task3_full_audit_with_trap"] ==
|
| 79 |
|
| 80 |
|
| 81 |
# ---------------------------------------------------------------------------
|
|
@@ -99,13 +99,16 @@ class TestReset:
|
|
| 99 |
obs = env.reset(dataset=RAW_CSV_TEXT, task_id="task1_null_and_types")
|
| 100 |
assert obs.step == 0
|
| 101 |
|
| 102 |
-
def
|
| 103 |
-
|
| 104 |
-
|
|
|
|
|
|
|
| 105 |
|
| 106 |
-
def
|
| 107 |
-
|
| 108 |
-
|
|
|
|
| 109 |
|
| 110 |
def test_reset_invalid_task_raises(self, env):
|
| 111 |
with pytest.raises(ValueError, match="Unknown task_id"):
|
|
@@ -391,7 +394,7 @@ class TestMaxSteps:
|
|
| 391 |
env.reset(dataset=RAW_CSV_TEXT, task_id="task1_null_and_types")
|
| 392 |
table = list(env._db.table_names())[0]
|
| 393 |
done = False
|
| 394 |
-
for _ in range(
|
| 395 |
_, _, done, _ = _step(env,
|
| 396 |
SQLSherlockAction(action_type="inspect", table=table)
|
| 397 |
)
|
|
|
|
| 73 |
|
| 74 |
def test_max_steps_values(self):
|
| 75 |
step_map = {t["id"]: t["max_steps"] for t in TASKS}
|
| 76 |
+
assert step_map["task1_null_and_types"] == 30
|
| 77 |
+
assert step_map["task2_constraints_and_fk"] == 40
|
| 78 |
+
assert step_map["task3_full_audit_with_trap"] == 50
|
| 79 |
|
| 80 |
|
| 81 |
# ---------------------------------------------------------------------------
|
|
|
|
| 99 |
obs = env.reset(dataset=RAW_CSV_TEXT, task_id="task1_null_and_types")
|
| 100 |
assert obs.step == 0
|
| 101 |
|
| 102 |
+
def test_reset_no_dataset_uses_default(self, env):
|
| 103 |
+
"""Empty dataset defaults to phihung/titanic."""
|
| 104 |
+
obs = env.reset(dataset="", task_id="task1_null_and_types")
|
| 105 |
+
assert isinstance(obs, SQLSherlockObservation)
|
| 106 |
+
assert len(obs.tables_summary) > 0
|
| 107 |
|
| 108 |
+
def test_reset_no_task_uses_default(self, env):
|
| 109 |
+
"""Empty task_id defaults to task1_null_and_types."""
|
| 110 |
+
obs = env.reset(dataset=RAW_CSV_TEXT, task_id="")
|
| 111 |
+
assert isinstance(obs, SQLSherlockObservation)
|
| 112 |
|
| 113 |
def test_reset_invalid_task_raises(self, env):
|
| 114 |
with pytest.raises(ValueError, match="Unknown task_id"):
|
|
|
|
| 394 |
env.reset(dataset=RAW_CSV_TEXT, task_id="task1_null_and_types")
|
| 395 |
table = list(env._db.table_names())[0]
|
| 396 |
done = False
|
| 397 |
+
for _ in range(35): # more than max_steps=30
|
| 398 |
_, _, done, _ = _step(env,
|
| 399 |
SQLSherlockAction(action_type="inspect", table=table)
|
| 400 |
)
|