Hemanth Kunta commited on
Commit
94595e2
·
1 Parent(s): 84b0d00

final local fixes

Browse files
__pycache__/space_app.cpython-311.pyc CHANGED
Binary files a/__pycache__/space_app.cpython-311.pyc and b/__pycache__/space_app.cpython-311.pyc differ
 
env/__pycache__/app.cpython-311.pyc CHANGED
Binary files a/env/__pycache__/app.cpython-311.pyc and b/env/__pycache__/app.cpython-311.pyc differ
 
env/app.py CHANGED
@@ -104,7 +104,8 @@ def step(payload: dict):
104
  else:
105
  state.query_credits -= 1
106
  obs = _make_observation(task, state, engine, table_names, result if isinstance(result, list) else None, None, None)
107
- reward = Reward(value=0.0, breakdown=_zero_breakdown(), done=False, info={})
 
108
  return _step_response(obs, reward)
109
 
110
  if action.action_type == "submit_report":
@@ -213,3 +214,27 @@ def _zero_breakdown(destructive: float = 0.0) -> RewardBreakdown:
213
  fix_verification_bonus=destructive,
214
  total=destructive,
215
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  else:
105
  state.query_credits -= 1
106
  obs = _make_observation(task, state, engine, table_names, result if isinstance(result, list) else None, None, None)
107
+ query_reward = _query_reward(action.sql, result if isinstance(result, list) else None)
108
+ reward = Reward(value=query_reward, breakdown=_zero_breakdown(), done=False, info={})
109
  return _step_response(obs, reward)
110
 
111
  if action.action_type == "submit_report":
 
214
  fix_verification_bonus=destructive,
215
  total=destructive,
216
  )
217
+
218
+
219
+ def _query_reward(sql: str, result: list[dict[str, Any]] | None) -> float:
220
+ """Provide small positive rewards for valid exploration and stronger credit for finding obvious issues."""
221
+ rows = result or []
222
+ if not rows:
223
+ return 0.01
224
+
225
+ sql_lower = (sql or "").lower()
226
+ hot_query = any(keyword in sql_lower for keyword in ("null", "dup", "duplicate", "group by", "having", "count(", "is null"))
227
+
228
+ def row_has_signal(row: dict[str, Any]) -> bool:
229
+ for value in row.values():
230
+ if value is None:
231
+ return True
232
+ if isinstance(value, (int, float)) and value > 0:
233
+ return True
234
+ if isinstance(value, str) and value.strip().lower() in {"null", "n/a", "na", "unknown", "none", "-", ""}:
235
+ return True
236
+ return False
237
+
238
+ if hot_query and any(row_has_signal(row) for row in rows):
239
+ return 0.1
240
+ return 0.01
inference.py CHANGED
@@ -21,6 +21,7 @@ MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.2-3B-Instruct")
21
 
22
  client: OpenAI | None = None
23
  FORCE_HEURISTIC = os.environ.get("FORCE_HEURISTIC", "0") == "1"
 
24
 
25
  SEED = int(os.environ.get("SEED", "42"))
26
  TEMPERATURE = 0.1
@@ -29,7 +30,16 @@ MAX_AUDIT_STEPS = 9
29
  FIX_STEPS = 3
30
  WALL_LIMIT = 15 * 60
31
 
32
- SYSTEM_PROMPT = """You are a data quality auditor AI agent. You investigate dirty SQL datasets.
 
 
 
 
 
 
 
 
 
33
 
34
  AVAILABLE ACTIONS (respond with JSON only, no extra text):
35
 
@@ -112,7 +122,24 @@ def parse_action(text: str) -> dict:
112
  return json.loads(m.group())
113
  except Exception:
114
  pass
115
- return {"action_type": "query", "sql": "SELECT 1 AS fallback"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
 
118
  def normalize_report(report: dict | None) -> dict:
@@ -204,14 +231,14 @@ def coerce_action(raw: str, task_id: int, step: int, total_steps: int) -> dict:
204
  # Close episode safely near step limit.
205
  if step >= total_steps - 1:
206
  return fallback_submit_action(task_id)
207
- return {"action_type": "query", "sql": "SELECT 1 AS fallback"}
208
 
209
  if at == "query":
210
  sql = str(parsed.get("sql", "")).strip()
211
  if not sql:
212
  if step >= total_steps - 1:
213
  return fallback_submit_action(task_id)
214
- return {"action_type": "query", "sql": "SELECT 1 AS fallback"}
215
  if step >= total_steps - 1:
216
  return fallback_submit_action(task_id)
217
  return {"action_type": "query", "sql": sql}
@@ -222,7 +249,7 @@ def coerce_action(raw: str, task_id: int, step: int, total_steps: int) -> dict:
222
  # fix_sql is allowed only in fix phase after submit; avoid using it in audit loop.
223
  if step >= total_steps - 1:
224
  return fallback_submit_action(task_id)
225
- return {"action_type": "query", "sql": "SELECT 1 AS fallback"}
226
 
227
 
228
  def llm_ready() -> tuple[bool, str]:
@@ -273,19 +300,28 @@ def _extract_json_object(text: str) -> dict | None:
273
  def llm_refine_report(task_id: int, obs: dict, evidence: dict, base_report: dict) -> dict:
274
  if client is None:
275
  return base_report
 
276
  prompt = {
277
  "task_id": task_id,
278
  "task_description": obs.get("task_description", ""),
279
  "tables": obs.get("tables", {}),
 
280
  "evidence": evidence,
281
  "base_report": base_report,
282
- "instruction": "Return ONLY a valid JSON object for report with same schema fields. Keep numeric values grounded in evidence.",
283
  }
284
  try:
285
  c = client.chat.completions.create(
286
  model=MODEL_NAME,
287
  messages=[
288
- {"role": "system", "content": "You are a strict JSON report formatter for data quality audits."},
 
 
 
 
 
 
 
289
  {"role": "user", "content": json.dumps(prompt)},
290
  ],
291
  temperature=0.0,
 
21
 
22
  client: OpenAI | None = None
23
  FORCE_HEURISTIC = os.environ.get("FORCE_HEURISTIC", "0") == "1"
24
+ FALLBACK_SQL = "SELECT 1 AS fallback"
25
 
26
  SEED = int(os.environ.get("SEED", "42"))
27
  TEMPERATURE = 0.1
 
30
  FIX_STEPS = 3
31
  WALL_LIMIT = 15 * 60
32
 
33
+ SYSTEM_PROMPT = """You are a SQL Data Auditor.
34
+
35
+ CRITICAL RULES:
36
+ - Only reason about and reference tables listed in the current observation.
37
+ - Current available tables will be provided in the user message; never query or invent tables outside that list.
38
+ - Never invent table names.
39
+ - When producing JSON, return valid JSON only.
40
+ - When producing SQL, return a single raw SELECT statement only.
41
+
42
+ You investigate dirty SQL datasets.
43
 
44
  AVAILABLE ACTIONS (respond with JSON only, no extra text):
45
 
 
122
  return json.loads(m.group())
123
  except Exception:
124
  pass
125
+ return {"action_type": "query", "sql": FALLBACK_SQL}
126
+
127
+
128
+ def parse_model_action(response_text: str) -> str:
129
+ """Extract a raw SQL query from a model response, tolerating markdown and accidental JSON."""
130
+ clean_text = re.sub(r"```sql|```", "", (response_text or "")).strip()
131
+
132
+ if clean_text.startswith("{"):
133
+ try:
134
+ data = json.loads(clean_text)
135
+ return str(data.get("query") or data.get("sql") or FALLBACK_SQL)
136
+ except Exception:
137
+ pass
138
+
139
+ if clean_text.upper().startswith("SELECT"):
140
+ return clean_text
141
+
142
+ return FALLBACK_SQL
143
 
144
 
145
  def normalize_report(report: dict | None) -> dict:
 
231
  # Close episode safely near step limit.
232
  if step >= total_steps - 1:
233
  return fallback_submit_action(task_id)
234
+ return {"action_type": "query", "sql": parse_model_action(raw)}
235
 
236
  if at == "query":
237
  sql = str(parsed.get("sql", "")).strip()
238
  if not sql:
239
  if step >= total_steps - 1:
240
  return fallback_submit_action(task_id)
241
+ return {"action_type": "query", "sql": parse_model_action(raw)}
242
  if step >= total_steps - 1:
243
  return fallback_submit_action(task_id)
244
  return {"action_type": "query", "sql": sql}
 
249
  # fix_sql is allowed only in fix phase after submit; avoid using it in audit loop.
250
  if step >= total_steps - 1:
251
  return fallback_submit_action(task_id)
252
+ return {"action_type": "query", "sql": parse_model_action(raw)}
253
 
254
 
255
  def llm_ready() -> tuple[bool, str]:
 
300
  def llm_refine_report(task_id: int, obs: dict, evidence: dict, base_report: dict) -> dict:
301
  if client is None:
302
  return base_report
303
+ table_names = ", ".join(sorted((obs.get("tables", {}) or {}).keys())) or "<none>"
304
  prompt = {
305
  "task_id": task_id,
306
  "task_description": obs.get("task_description", ""),
307
  "tables": obs.get("tables", {}),
308
+ "current_available_tables": list((obs.get("tables", {}) or {}).keys()),
309
  "evidence": evidence,
310
  "base_report": base_report,
311
+ "instruction": "Return ONLY a valid JSON object for report with same schema fields. Keep numeric values grounded in evidence and use only the listed tables.",
312
  }
313
  try:
314
  c = client.chat.completions.create(
315
  model=MODEL_NAME,
316
  messages=[
317
+ {
318
+ "role": "system",
319
+ "content": (
320
+ "You are a strict JSON report formatter for data quality audits. "
321
+ f"Only use the current observation's tables: {table_names}. "
322
+ "Do not invent tables. Do not change numeric evidence except to preserve it faithfully."
323
+ ),
324
+ },
325
  {"role": "user", "content": json.dumps(prompt)},
326
  ],
327
  temperature=0.0,
openenv.yaml CHANGED
@@ -74,6 +74,8 @@ reward_range: [-0.1, 1.25]
74
 
75
  reward_design:
76
  audit_score: "0.0–1.0, Brier-adjusted per finding confidence"
 
 
77
  budget_bonus: "up to +0.10 for early report submission"
78
  fix_bonus: "up to +0.25 for correct fix_sql repairs"
79
  destructive_sql_penalty: -0.1
 
74
 
75
  reward_design:
76
  audit_score: "0.0–1.0, Brier-adjusted per finding confidence"
77
+ valid_query_no_signal: "+0.01 for syntactically valid exploratory SQL that returns no obvious issue signal"
78
+ valid_query_finds_issue: "+0.1 for valid SQL that surfaces NULLs, duplicates, or other clear audit evidence"
79
  budget_bonus: "up to +0.10 for early report submission"
80
  fix_bonus: "up to +0.25 for correct fix_sql repairs"
81
  destructive_sql_penalty: -0.1
outputs/agent_memory.json CHANGED
The diff for this file is too large to render. See raw diff