Swethaditya commited on
Commit
4551d9d
Β·
1 Parent(s): 3018ee0

minor fixes

Browse files
README.md CHANGED
@@ -2,7 +2,7 @@
2
  title: SQLSherlock Env
3
  emoji: πŸ”
4
  colorFrom: indigo
5
- colorTo: cyan
6
  sdk: docker
7
  app_port: 7860
8
  tags:
 
2
  title: SQLSherlock Env
3
  emoji: πŸ”
4
  colorFrom: indigo
5
+ colorTo: blue
6
  sdk: docker
7
  app_port: 7860
8
  tags:
inference.py CHANGED
@@ -45,11 +45,11 @@ SPACE_URL = os.getenv("SPACE_URL", "http://localhost:7860")
45
  # Optional β€” if you use from_docker_image():
46
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
47
 
48
- # Use full environment max_steps β€” no artificial restriction
49
  STEP_BUDGETS: dict[str, int] = {
50
- "task1_null_and_types": 20, # env max_steps = 20
51
- "task2_constraints_and_fk": 25, # env max_steps = 25
52
- "task3_full_audit_with_trap": 30, # env max_steps = 30
53
  }
54
 
55
  TASKS = [
 
45
  # Optional β€” if you use from_docker_image():
46
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
47
 
48
+ # Full environment max_steps β€” agent gets maximum room to clean
49
  STEP_BUDGETS: dict[str, int] = {
50
+ "task1_null_and_types": 30, # env max_steps = 30
51
+ "task2_constraints_and_fk": 40, # env max_steps = 40
52
+ "task3_full_audit_with_trap": 50, # env max_steps = 50
53
  }
54
 
55
  TASKS = [
openenv.yaml CHANGED
@@ -24,7 +24,7 @@ tasks:
24
  - id: task1_null_and_types
25
  name: "Null and type error repair"
26
  difficulty: easy
27
- max_steps: 20
28
  description: >
29
  Find and fix null values and type errors in the primary table.
30
  Profile columns, identify anomalies, fix with reasoning,
@@ -33,7 +33,7 @@ tasks:
33
  - id: task2_constraints_and_fk
34
  name: "Constraint and FK integrity"
35
  difficulty: medium
36
- max_steps: 25
37
  description: >
38
  Everything in Task 1 plus constraint violations
39
  (negative values in must-be-positive columns) and FK
@@ -42,7 +42,7 @@ tasks:
42
  - id: task3_full_audit_with_trap
43
  name: "Full statistical audit with trap"
44
  difficulty: hard
45
- max_steps: 30
46
  description: >
47
  Full audit including statistical outliers. TRAP WARNING:
48
  one numeric value looks suspicious but is legitimate.
 
24
  - id: task1_null_and_types
25
  name: "Null and type error repair"
26
  difficulty: easy
27
+ max_steps: 30
28
  description: >
29
  Find and fix null values and type errors in the primary table.
30
  Profile columns, identify anomalies, fix with reasoning,
 
33
  - id: task2_constraints_and_fk
34
  name: "Constraint and FK integrity"
35
  difficulty: medium
36
+ max_steps: 40
37
  description: >
38
  Everything in Task 1 plus constraint violations
39
  (negative values in must-be-positive columns) and FK
 
42
  - id: task3_full_audit_with_trap
43
  name: "Full statistical audit with trap"
44
  difficulty: hard
45
+ max_steps: 50
46
  description: >
47
  Full audit including statistical outliers. TRAP WARNING:
48
  one numeric value looks suspicious but is legitimate.
sqlsherlock_env/server/environment.py CHANGED
@@ -33,7 +33,7 @@ TASKS: list[dict] = [
33
  "id": "task1_null_and_types",
34
  "name": "Null and type error repair",
35
  "difficulty": "easy",
36
- "max_steps": 20,
37
  "description": (
38
  "Find and fix null values and type errors in the primary table. "
39
  "Profile columns, identify anomalies, fix with reasoning, "
@@ -44,7 +44,7 @@ TASKS: list[dict] = [
44
  "id": "task2_constraints_and_fk",
45
  "name": "Constraint and FK integrity",
46
  "difficulty": "medium",
47
- "max_steps": 25,
48
  "description": (
49
  "Everything in Task 1 plus constraint violations "
50
  "(negative values in must-be-positive columns) and FK "
@@ -55,7 +55,7 @@ TASKS: list[dict] = [
55
  "id": "task3_full_audit_with_trap",
56
  "name": "Full statistical audit with trap",
57
  "difficulty": "hard",
58
- "max_steps": 30,
59
  "description": (
60
  "Full audit including statistical outliers. TRAP WARNING: "
61
  "one numeric value looks suspicious but is legitimate. "
@@ -100,21 +100,10 @@ class SQLSherlockEnvironment(Environment):
100
  Raises:
101
  ValueError: If dataset or task_id is missing/invalid.
102
  """
103
- dataset = kwargs.get("dataset", "")
104
- task_id = kwargs.get("task_id", "")
105
- seed = int(kwargs.get("seed", 42))
106
  max_rows = int(kwargs.get("max_rows", 500))
107
-
108
- if not dataset or not dataset.strip():
109
- raise ValueError(
110
- "reset() requires 'dataset' keyword argument. "
111
- "Provide a file path, HuggingFace dataset name, or raw CSV text."
112
- )
113
- if not task_id or not task_id.strip():
114
- raise ValueError(
115
- "reset() requires 'task_id' keyword argument. "
116
- f"Valid tasks: {sorted(_TASK_MAP.keys())}"
117
- )
118
  if task_id not in _TASK_MAP:
119
  raise ValueError(
120
  f"Unknown task_id '{task_id}'. "
 
33
  "id": "task1_null_and_types",
34
  "name": "Null and type error repair",
35
  "difficulty": "easy",
36
+ "max_steps": 30,
37
  "description": (
38
  "Find and fix null values and type errors in the primary table. "
39
  "Profile columns, identify anomalies, fix with reasoning, "
 
44
  "id": "task2_constraints_and_fk",
45
  "name": "Constraint and FK integrity",
46
  "difficulty": "medium",
47
+ "max_steps": 40,
48
  "description": (
49
  "Everything in Task 1 plus constraint violations "
50
  "(negative values in must-be-positive columns) and FK "
 
55
  "id": "task3_full_audit_with_trap",
56
  "name": "Full statistical audit with trap",
57
  "difficulty": "hard",
58
+ "max_steps": 50,
59
  "description": (
60
  "Full audit including statistical outliers. TRAP WARNING: "
61
  "one numeric value looks suspicious but is legitimate. "
 
100
  Raises:
101
  ValueError: If dataset or task_id is missing/invalid.
102
  """
103
+ dataset = kwargs.get("dataset", "") or "phihung/titanic"
104
+ task_id = kwargs.get("task_id", "") or "task1_null_and_types"
105
+ seed = int(kwargs.get("seed", 42))
106
  max_rows = int(kwargs.get("max_rows", 500))
 
 
 
 
 
 
 
 
 
 
 
107
  if task_id not in _TASK_MAP:
108
  raise ValueError(
109
  f"Unknown task_id '{task_id}'. "
tests/test_environment.py CHANGED
@@ -73,9 +73,9 @@ class TestTasksCatalogue:
73
 
74
  def test_max_steps_values(self):
75
  step_map = {t["id"]: t["max_steps"] for t in TASKS}
76
- assert step_map["task1_null_and_types"] == 20
77
- assert step_map["task2_constraints_and_fk"] == 25
78
- assert step_map["task3_full_audit_with_trap"] == 30
79
 
80
 
81
  # ---------------------------------------------------------------------------
@@ -99,13 +99,16 @@ class TestReset:
99
  obs = env.reset(dataset=RAW_CSV_TEXT, task_id="task1_null_and_types")
100
  assert obs.step == 0
101
 
102
- def test_reset_no_dataset_raises(self, env):
103
- with pytest.raises(ValueError, match="dataset"):
104
- env.reset(dataset="", task_id="task1_null_and_types")
 
 
105
 
106
- def test_reset_no_task_raises(self, env):
107
- with pytest.raises(ValueError, match="task_id"):
108
- env.reset(dataset=RAW_CSV_TEXT, task_id="")
 
109
 
110
  def test_reset_invalid_task_raises(self, env):
111
  with pytest.raises(ValueError, match="Unknown task_id"):
@@ -391,7 +394,7 @@ class TestMaxSteps:
391
  env.reset(dataset=RAW_CSV_TEXT, task_id="task1_null_and_types")
392
  table = list(env._db.table_names())[0]
393
  done = False
394
- for _ in range(25): # more than max_steps=20
395
  _, _, done, _ = _step(env,
396
  SQLSherlockAction(action_type="inspect", table=table)
397
  )
 
73
 
74
  def test_max_steps_values(self):
75
  step_map = {t["id"]: t["max_steps"] for t in TASKS}
76
+ assert step_map["task1_null_and_types"] == 30
77
+ assert step_map["task2_constraints_and_fk"] == 40
78
+ assert step_map["task3_full_audit_with_trap"] == 50
79
 
80
 
81
  # ---------------------------------------------------------------------------
 
99
  obs = env.reset(dataset=RAW_CSV_TEXT, task_id="task1_null_and_types")
100
  assert obs.step == 0
101
 
102
+ def test_reset_no_dataset_uses_default(self, env):
103
+ """Empty dataset defaults to phihung/titanic."""
104
+ obs = env.reset(dataset="", task_id="task1_null_and_types")
105
+ assert isinstance(obs, SQLSherlockObservation)
106
+ assert len(obs.tables_summary) > 0
107
 
108
+ def test_reset_no_task_uses_default(self, env):
109
+ """Empty task_id defaults to task1_null_and_types."""
110
+ obs = env.reset(dataset=RAW_CSV_TEXT, task_id="")
111
+ assert isinstance(obs, SQLSherlockObservation)
112
 
113
  def test_reset_invalid_task_raises(self, env):
114
  with pytest.raises(ValueError, match="Unknown task_id"):
 
394
  env.reset(dataset=RAW_CSV_TEXT, task_id="task1_null_and_types")
395
  table = list(env._db.table_names())[0]
396
  done = False
397
+ for _ in range(35): # more than max_steps=30
398
  _, _, done, _ = _step(env,
399
  SQLSherlockAction(action_type="inspect", table=table)
400
  )