Spaces:
Sleeping
Sleeping
Hemanth Kunta commited on
Commit ·
aa25459
1
Parent(s): ae0d0fa
Fix invalid rewards and Space query guards
Browse files- env/app.py +1 -1
- openenv.yaml +2 -2
- space_app.py +39 -0
env/app.py
CHANGED
|
@@ -101,7 +101,7 @@ def step(payload: dict):
|
|
| 101 |
result = engine.execute(action.sql)
|
| 102 |
if isinstance(result, str) and result.startswith("ERROR"):
|
| 103 |
obs = _make_observation(task, state, engine, table_names, None, result, None)
|
| 104 |
-
reward = Reward(value=
|
| 105 |
else:
|
| 106 |
state.query_credits -= 1
|
| 107 |
obs = _make_observation(task, state, engine, table_names, result if isinstance(result, list) else None, None, None)
|
|
|
|
| 101 |
result = engine.execute(action.sql)
|
| 102 |
if isinstance(result, str) and result.startswith("ERROR"):
|
| 103 |
obs = _make_observation(task, state, engine, table_names, None, result, None)
|
| 104 |
+
reward = Reward(value=0.0, breakdown=_zero_breakdown(), done=False, info={"error": result})
|
| 105 |
else:
|
| 106 |
state.query_credits -= 1
|
| 107 |
obs = _make_observation(task, state, engine, table_names, result if isinstance(result, list) else None, None, None)
|
openenv.yaml
CHANGED
|
@@ -70,7 +70,7 @@ observation_space:
|
|
| 70 |
last_action_error: "string | null"
|
| 71 |
last_fix_score: "float | null"
|
| 72 |
|
| 73 |
-
reward_range: [
|
| 74 |
|
| 75 |
reward_design:
|
| 76 |
audit_score: "0.0–1.0, Brier-adjusted per finding confidence"
|
|
@@ -78,7 +78,7 @@ reward_design:
|
|
| 78 |
valid_query_finds_issue: "+0.1 for valid SQL that surfaces NULLs, duplicates, or other clear audit evidence"
|
| 79 |
budget_bonus: "up to +0.10 for early report submission"
|
| 80 |
fix_bonus: "up to +0.25 for correct fix_sql repairs"
|
| 81 |
-
|
| 82 |
|
| 83 |
api:
|
| 84 |
reset: "POST /reset {task_id: int, seed: int}"
|
|
|
|
| 70 |
last_action_error: "string | null"
|
| 71 |
last_fix_score: "float | null"
|
| 72 |
|
| 73 |
+
reward_range: [0.0, 1.25]
|
| 74 |
|
| 75 |
reward_design:
|
| 76 |
audit_score: "0.0–1.0, Brier-adjusted per finding confidence"
|
|
|
|
| 78 |
valid_query_finds_issue: "+0.1 for valid SQL that surfaces NULLs, duplicates, or other clear audit evidence"
|
| 79 |
budget_bonus: "up to +0.10 for early report submission"
|
| 80 |
fix_bonus: "up to +0.25 for correct fix_sql repairs"
|
| 81 |
+
invalid_sql_penalty: 0.0
|
| 82 |
|
| 83 |
api:
|
| 84 |
reset: "POST /reset {task_id: int, seed: int}"
|
space_app.py
CHANGED
|
@@ -72,6 +72,36 @@ def heuristic_queries(task_id: int) -> list[str]:
|
|
| 72 |
]
|
| 73 |
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
def normalize_command(text: str) -> str:
|
| 76 |
return (text or "").strip()
|
| 77 |
|
|
@@ -143,6 +173,11 @@ def run_query(sql_text: str, current_obs: dict[str, Any] | None, chat: list[dict
|
|
| 143 |
chat = chat + [{"role": "assistant", "content": "Send a SQL query first."}]
|
| 144 |
return chat, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs
|
| 145 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
out = SESSION.step({"action": {"action_type": "query", "sql": sql}})
|
| 147 |
obs = out.get("observation")
|
| 148 |
reward = out.get("reward")
|
|
@@ -180,6 +215,10 @@ def auto_audit(current_obs: dict[str, Any] | None, chat: list[dict[str, str]]):
|
|
| 180 |
obs = current_obs
|
| 181 |
reward = None
|
| 182 |
for sql in queries:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
out = SESSION.step({"action": {"action_type": "query", "sql": sql}})
|
| 184 |
obs = out.get("observation")
|
| 185 |
reward = out.get("reward")
|
|
|
|
| 72 |
]
|
| 73 |
|
| 74 |
|
| 75 |
+
def current_tables(obs: dict[str, Any] | None) -> set[str]:
|
| 76 |
+
tables = (obs or {}).get("tables") or {}
|
| 77 |
+
return {str(name).lower() for name in tables.keys()}
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def referenced_tables(sql_text: str) -> set[str]:
|
| 81 |
+
sql = normalize_command(sql_text)
|
| 82 |
+
matches = re.finditer(r"\b(?:from|join)\s+([a-zA-Z_][\w\.]*)", sql, flags=re.IGNORECASE)
|
| 83 |
+
refs: set[str] = set()
|
| 84 |
+
for match in matches:
|
| 85 |
+
identifier = match.group(1).split(".")[-1].lower()
|
| 86 |
+
if identifier:
|
| 87 |
+
refs.add(identifier)
|
| 88 |
+
return refs
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def validate_query_tables(sql_text: str, obs: dict[str, Any] | None) -> str | None:
|
| 92 |
+
allowed = current_tables(obs)
|
| 93 |
+
if not allowed:
|
| 94 |
+
return None
|
| 95 |
+
refs = referenced_tables(sql_text)
|
| 96 |
+
if not refs:
|
| 97 |
+
return None
|
| 98 |
+
unknown = sorted(refs - allowed)
|
| 99 |
+
if unknown:
|
| 100 |
+
available = ", ".join(sorted(allowed))
|
| 101 |
+
return f"This task only exposes: {available}. Please query one of those tables instead of: {', '.join(unknown)}."
|
| 102 |
+
return None
|
| 103 |
+
|
| 104 |
+
|
| 105 |
def normalize_command(text: str) -> str:
|
| 106 |
return (text or "").strip()
|
| 107 |
|
|
|
|
| 173 |
chat = chat + [{"role": "assistant", "content": "Send a SQL query first."}]
|
| 174 |
return chat, format_observation(current_obs), session_status(current_obs), format_reward({}), current_obs
|
| 175 |
|
| 176 |
+
table_error = validate_query_tables(sql, current_obs)
|
| 177 |
+
if table_error:
|
| 178 |
+
chat = chat + [{"role": "assistant", "content": table_error}]
|
| 179 |
+
return chat, format_observation(current_obs), session_status(current_obs), format_reward({"value": 0.0, "done": False}), current_obs
|
| 180 |
+
|
| 181 |
out = SESSION.step({"action": {"action_type": "query", "sql": sql}})
|
| 182 |
obs = out.get("observation")
|
| 183 |
reward = out.get("reward")
|
|
|
|
| 215 |
obs = current_obs
|
| 216 |
reward = None
|
| 217 |
for sql in queries:
|
| 218 |
+
table_error = validate_query_tables(sql, obs)
|
| 219 |
+
if table_error:
|
| 220 |
+
running_chat.append({"role": "assistant", "content": table_error})
|
| 221 |
+
continue
|
| 222 |
out = SESSION.step({"action": {"action_type": "query", "sql": sql}})
|
| 223 |
obs = out.get("observation")
|
| 224 |
reward = out.get("reward")
|