Hemanth Kunta committed on
Commit
aa25459
·
1 Parent(s): ae0d0fa

Fix invalid rewards and Space query guards

Browse files
Files changed (3) hide show
  1. env/app.py +1 -1
  2. openenv.yaml +2 -2
  3. space_app.py +39 -0
env/app.py CHANGED
@@ -101,7 +101,7 @@ def step(payload: dict):
101
  result = engine.execute(action.sql)
102
  if isinstance(result, str) and result.startswith("ERROR"):
103
  obs = _make_observation(task, state, engine, table_names, None, result, None)
104
- reward = Reward(value=-0.1, breakdown=_zero_breakdown(destructive=-0.1), done=False, info={"error": result})
105
  else:
106
  state.query_credits -= 1
107
  obs = _make_observation(task, state, engine, table_names, result if isinstance(result, list) else None, None, None)
 
101
  result = engine.execute(action.sql)
102
  if isinstance(result, str) and result.startswith("ERROR"):
103
  obs = _make_observation(task, state, engine, table_names, None, result, None)
104
+ reward = Reward(value=0.0, breakdown=_zero_breakdown(), done=False, info={"error": result})
105
  else:
106
  state.query_credits -= 1
107
  obs = _make_observation(task, state, engine, table_names, result if isinstance(result, list) else None, None, None)
openenv.yaml CHANGED
@@ -70,7 +70,7 @@ observation_space:
70
  last_action_error: "string | null"
71
  last_fix_score: "float | null"
72
 
73
- reward_range: [-0.1, 1.25]
74
 
75
  reward_design:
76
  audit_score: "0.0–1.0, Brier-adjusted per finding confidence"
@@ -78,7 +78,7 @@ reward_design:
78
  valid_query_finds_issue: "+0.1 for valid SQL that surfaces NULLs, duplicates, or other clear audit evidence"
79
  budget_bonus: "up to +0.10 for early report submission"
80
  fix_bonus: "up to +0.25 for correct fix_sql repairs"
81
- destructive_sql_penalty: -0.1
82
 
83
  api:
84
  reset: "POST /reset {task_id: int, seed: int}"
 
70
  last_action_error: "string | null"
71
  last_fix_score: "float | null"
72
 
73
+ reward_range: [0.0, 1.25]
74
 
75
  reward_design:
76
  audit_score: "0.0–1.0, Brier-adjusted per finding confidence"
 
78
  valid_query_finds_issue: "+0.1 for valid SQL that surfaces NULLs, duplicates, or other clear audit evidence"
79
  budget_bonus: "up to +0.10 for early report submission"
80
  fix_bonus: "up to +0.25 for correct fix_sql repairs"
81
+ invalid_sql_penalty: 0.0
82
 
83
  api:
84
  reset: "POST /reset {task_id: int, seed: int}"
space_app.py CHANGED
@@ -72,6 +72,36 @@ def heuristic_queries(task_id: int) -> list[str]:
72
  ]
73
 
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  def normalize_command(text: str) -> str:
76
  return (text or "").strip()
77
 
@@ -143,6 +173,11 @@ def run_query(sql_text: str, current_obs: dict[str, Any] | None, chat: list[dict
143
  chat = chat + [{"role": "assistant", "content": "Send a SQL query first."}]
144
  return chat, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs
145
 
 
 
 
 
 
146
  out = SESSION.step({"action": {"action_type": "query", "sql": sql}})
147
  obs = out.get("observation")
148
  reward = out.get("reward")
@@ -180,6 +215,10 @@ def auto_audit(current_obs: dict[str, Any] | None, chat: list[dict[str, str]]):
180
  obs = current_obs
181
  reward = None
182
  for sql in queries:
 
 
 
 
183
  out = SESSION.step({"action": {"action_type": "query", "sql": sql}})
184
  obs = out.get("observation")
185
  reward = out.get("reward")
 
72
  ]
73
 
74
 
75
+ def current_tables(obs: dict[str, Any] | None) -> set[str]:
76
+ tables = (obs or {}).get("tables") or {}
77
+ return {str(name).lower() for name in tables.keys()}
78
+
79
+
80
def referenced_tables(sql_text: str) -> set[str]:
    """Return the lower-cased table names that follow FROM or JOIN in a query.

    This is a lightweight regex scan, not a SQL parser. To avoid flagging
    things that are not tables it ignores:

    * text inside comments and single-quoted string literals,
    * the FROM keyword used as argument syntax (EXTRACT/TRIM/SUBSTRING/
      OVERLAY calls and ``IS DISTINCT FROM``),
    * table-valued function calls such as ``FROM json_each(...)``,
    * names defined by the query itself as common table expressions.

    Quoted identifiers and the second and later entries of a
    comma-separated FROM list are not detected.

    Args:
        sql_text: The SQL text to inspect; None or empty yields an empty set.

    Returns:
        Unqualified (schema prefix dropped), lower-cased table names.
    """
    sql = sql_text or ""

    # Blank out comments and string literals so that words inside them
    # (e.g. 'copied from backup') are not mistaken for SQL keywords.
    sql = re.sub(r"--[^\n]*|/\*.*?\*/|'(?:[^']|'')*'", " ", sql, flags=re.DOTALL)

    # Remove FROM keywords that are part of an expression rather than a
    # table source, e.g. EXTRACT(year FROM ts) or a IS DISTINCT FROM b.
    sql = re.sub(
        r"\bdistinct\s+from\b|\b(?:extract|trim|substring|overlay)\s*\([^()]*?\bfrom\b",
        " ",
        sql,
        flags=re.IGNORECASE,
    )

    # Names introduced as `name [(cols)] AS (` are CTEs defined by the query
    # itself, so they must not be reported as external tables.
    cte_names = {
        match.group(1).lower()
        for match in re.finditer(
            r"\b([a-zA-Z_]\w*)\s*(?:\([^()]*\)\s*)?\bas\s*(?:(?:not\s+)?materialized\s*)?\(",
            sql,
            flags=re.IGNORECASE,
        )
    }

    refs: set[str] = set()
    # The trailing lookaheads require the full identifier to be consumed and
    # reject identifiers immediately followed by "(", i.e. function calls.
    for match in re.finditer(
        r"\b(?:from|join)\s+([a-zA-Z_][\w.]*)(?![\w.])(?!\s*\()",
        sql,
        flags=re.IGNORECASE,
    ):
        # Keep only the last dotted segment: "main.Orders" -> "orders".
        identifier = match.group(1).split(".")[-1].lower()
        if identifier:
            refs.add(identifier)
    return refs - cte_names
89
+
90
+
91
def validate_query_tables(sql_text: str, obs: dict[str, Any] | None) -> str | None:
    """Return a user-facing message if the query names tables the task lacks.

    Args:
        sql_text: The SQL query about to be executed.
        obs: The current observation, used as the source of known tables.

    Returns:
        None when the observation lists no tables, when no table references
        are found in the query, or when every referenced table is known;
        otherwise a message listing the available and the unknown tables.
    """
    allowed = current_tables(obs)
    # Without a known table list there is nothing to validate against, so
    # the query text is not scanned at all in that case.
    refs = referenced_tables(sql_text) if allowed else set()
    unknown = sorted(refs - allowed)
    if not unknown:
        return None
    available = ", ".join(sorted(allowed))
    return f"This task only exposes: {available}. Please query one of those tables instead of: {', '.join(unknown)}."
103
+
104
+
105
  def normalize_command(text: str) -> str:
106
  return (text or "").strip()
107
 
 
173
  chat = chat + [{"role": "assistant", "content": "Send a SQL query first."}]
174
  return chat, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs
175
 
176
+ table_error = validate_query_tables(sql, current_obs)
177
+ if table_error:
178
+ chat = chat + [{"role": "assistant", "content": table_error}]
179
+ return chat, format_observation(current_obs), session_status(current_obs), format_reward({"value": 0.0, "done": False}), current_obs
180
+
181
  out = SESSION.step({"action": {"action_type": "query", "sql": sql}})
182
  obs = out.get("observation")
183
  reward = out.get("reward")
 
215
  obs = current_obs
216
  reward = None
217
  for sql in queries:
218
+ table_error = validate_query_tables(sql, obs)
219
+ if table_error:
220
+ running_chat.append({"role": "assistant", "content": table_error})
221
+ continue
222
  out = SESSION.step({"action": {"action_type": "query", "sql": sql}})
223
  obs = out.get("observation")
224
  reward = out.get("reward")