Hemanth Kunta commited on
Commit
cb70147
·
1 Parent(s): 94595e2

final fix for phase 2 stdout score range

Browse files
Files changed (1) hide show
  1. inference.py +42 -18
inference.py CHANGED
@@ -29,6 +29,7 @@ MAX_TOKENS = 1000
29
  MAX_AUDIT_STEPS = 9
30
  FIX_STEPS = 3
31
  WALL_LIMIT = 15 * 60
 
32
 
33
  SYSTEM_PROMPT = """You are a SQL Data Auditor.
34
 
@@ -110,6 +111,24 @@ def emit_block(kind: str, **fields) -> None:
110
  print(" ".join(parts), flush=True)
111
 
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  def parse_action(text: str) -> dict:
114
  raw = (text or "").strip()
115
  raw = raw.replace("```json", "").replace("```", "").strip()
@@ -476,24 +495,26 @@ def run_task_hybrid(task_id: int, global_start: float) -> float:
476
  print(f"Tables: {list(obs['tables'].keys())} | Credits: {obs['query_credits_remaining']}")
477
 
478
  if time.time() - global_start > WALL_LIMIT - 60:
479
- return 0.0
 
 
480
 
481
  evidence, base_report = build_probe_report(task_id)
482
  final_report = llm_refine_report(task_id, obs, evidence, base_report)
483
  final_report = normalize_report(final_report)
484
 
485
  out = submit(final_report)
486
- score = float(out.get("reward", {}).get("value", 0.0))
487
  emit_block("STEP", task=task_id, step=1, reward=score, action="submit_report")
488
 
489
  # Optional harmless fix step for bonus phase behavior parity.
490
  try:
491
  fix = call_env("step", {"action": {"action_type": "fix_sql", "sql": "UPDATE orders SET order_total = order_total WHERE 1=0"}})
492
- score = float(fix.get("reward", {}).get("value", score))
493
  emit_block("STEP", task=task_id, step=2, reward=score, action="fix_sql")
494
  except Exception:
495
  pass
496
- print(f" Episode done. Final score: {score:.3f}")
497
  emit_block("END", task=task_id, score=score, steps=2)
498
  return score
499
 
@@ -598,17 +619,17 @@ def run_task_heuristic(task_id: int) -> float:
598
  }
599
 
600
  out = submit(report)
601
- score = float(out.get("reward", {}).get("value", 0.0))
602
- print(f" audit score: {score:.3f}")
603
  emit_block("STEP", task=task_id, step=1, reward=score, action="submit_report")
604
  # One no-op fix to demonstrate fix phase behavior.
605
  try:
606
  fix = call_env("step", {"action": {"action_type": "fix_sql", "sql": "UPDATE orders SET order_total = order_total WHERE 1=0"}})
607
- score = float(fix.get("reward", {}).get("value", score))
608
  emit_block("STEP", task=task_id, step=2, reward=score, action="fix_sql")
609
  except Exception:
610
  pass
611
- print(f" final score: {score:.3f}")
612
  emit_block("END", task=task_id, score=score, steps=2)
613
  return score
614
 
@@ -623,7 +644,7 @@ def run_task(task_id: int, global_start: float) -> float:
623
  print(f"Tables: {list(obs['tables'].keys())} | Credits: {obs['query_credits_remaining']}")
624
 
625
  history = []
626
- final_score = 0.0
627
  total_steps = MAX_AUDIT_STEPS + FIX_STEPS
628
 
629
  for step in range(1, total_steps + 1):
@@ -682,11 +703,11 @@ Return next action JSON only."""
682
  reward = step_result.get("reward", {})
683
 
684
  history.append({"step": step, "action": action.get("action_type", "unknown")})
685
- final_score = float(reward.get("value", final_score))
686
  emit_block("STEP", task=task_id, step=step, reward=final_score, action=action.get("action_type", "unknown"))
687
 
688
  if reward.get("done"):
689
- print(f" Episode done. Final score: {final_score:.3f}")
690
  emit_block("END", task=task_id, score=final_score, steps=step)
691
  return final_score
692
 
@@ -704,7 +725,7 @@ Return next action JSON only."""
704
  }
705
  try:
706
  result = call_env("step", {"action": empty_report})
707
- final_score = float(result.get("reward", {}).get("value", final_score))
708
  except Exception:
709
  pass
710
  emit_block("END", task=task_id, score=final_score, steps=total_steps)
@@ -741,20 +762,23 @@ def main():
741
 
742
  for task_id in [1, 2, 3, 4]:
743
  if time.time() - global_start > WALL_LIMIT - 120:
744
- scores[f"task_{task_id}"] = 0.0
 
 
 
745
  continue
746
  if use_heuristic:
747
- scores[f"task_{task_id}"] = run_task_heuristic(task_id)
748
  else:
749
- scores[f"task_{task_id}"] = run_task_hybrid(task_id, global_start)
750
 
751
  print("\n" + "=" * 60)
752
  print("BASELINE RESULTS (seed=42)")
753
  print("=" * 60)
754
  for k, v in scores.items():
755
- print(f" {k}: {v:.3f}")
756
- mean = sum(scores.values()) / max(len(scores), 1)
757
- print(f" mean: {mean:.3f}")
758
  print(f" total wall time: {(time.time() - global_start) / 60:.1f} min")
759
  if not use_heuristic and all(v <= 0.0 for v in scores.values()):
760
  print("WARNING: LLM mode ran but all scores are 0.0. Check model connectivity and prompt behavior.")
 
29
  MAX_AUDIT_STEPS = 9
30
  FIX_STEPS = 3
31
  WALL_LIMIT = 15 * 60
32
+ SCORE_EPS = 1e-6
33
 
34
  SYSTEM_PROMPT = """You are a SQL Data Auditor.
35
 
 
111
  print(" ".join(parts), flush=True)
112
 
113
 
114
+ def strict_score(value: float | int | str | None, default: float = SCORE_EPS) -> float:
115
+ """Clamp score into strict open interval (0,1) for validator compatibility."""
116
+ try:
117
+ v = float(value)
118
+ except Exception:
119
+ v = float(default)
120
+ if not (v > 0.0):
121
+ return SCORE_EPS
122
+ if not (v < 1.0):
123
+ return 1.0 - SCORE_EPS
124
+ return v
125
+
126
+
127
+ def score_text(value: float | int | str | None, default: float = SCORE_EPS) -> str:
128
+ """Stable string formatting for printed score lines without rounding to 1.000."""
129
+ return f"{strict_score(value, default=default):.6f}"
130
+
131
+
132
  def parse_action(text: str) -> dict:
133
  raw = (text or "").strip()
134
  raw = raw.replace("```json", "").replace("```", "").strip()
 
495
  print(f"Tables: {list(obs['tables'].keys())} | Credits: {obs['query_credits_remaining']}")
496
 
497
  if time.time() - global_start > WALL_LIMIT - 60:
498
+ score = strict_score(0.0)
499
+ emit_block("END", task=task_id, score=score, steps=0)
500
+ return score
501
 
502
  evidence, base_report = build_probe_report(task_id)
503
  final_report = llm_refine_report(task_id, obs, evidence, base_report)
504
  final_report = normalize_report(final_report)
505
 
506
  out = submit(final_report)
507
+ score = strict_score(out.get("reward", {}).get("value", 0.0))
508
  emit_block("STEP", task=task_id, step=1, reward=score, action="submit_report")
509
 
510
  # Optional harmless fix step for bonus phase behavior parity.
511
  try:
512
  fix = call_env("step", {"action": {"action_type": "fix_sql", "sql": "UPDATE orders SET order_total = order_total WHERE 1=0"}})
513
+ score = strict_score(fix.get("reward", {}).get("value", score), default=score)
514
  emit_block("STEP", task=task_id, step=2, reward=score, action="fix_sql")
515
  except Exception:
516
  pass
517
+ print(f" Episode done. Final score: {score_text(score, default=score)}")
518
  emit_block("END", task=task_id, score=score, steps=2)
519
  return score
520
 
 
619
  }
620
 
621
  out = submit(report)
622
+ score = strict_score(out.get("reward", {}).get("value", 0.0))
623
+ print(f" audit score: {score_text(score, default=score)}")
624
  emit_block("STEP", task=task_id, step=1, reward=score, action="submit_report")
625
  # One no-op fix to demonstrate fix phase behavior.
626
  try:
627
  fix = call_env("step", {"action": {"action_type": "fix_sql", "sql": "UPDATE orders SET order_total = order_total WHERE 1=0"}})
628
+ score = strict_score(fix.get("reward", {}).get("value", score), default=score)
629
  emit_block("STEP", task=task_id, step=2, reward=score, action="fix_sql")
630
  except Exception:
631
  pass
632
+ print(f" final score: {score_text(score, default=score)}")
633
  emit_block("END", task=task_id, score=score, steps=2)
634
  return score
635
 
 
644
  print(f"Tables: {list(obs['tables'].keys())} | Credits: {obs['query_credits_remaining']}")
645
 
646
  history = []
647
+ final_score = strict_score(0.0)
648
  total_steps = MAX_AUDIT_STEPS + FIX_STEPS
649
 
650
  for step in range(1, total_steps + 1):
 
703
  reward = step_result.get("reward", {})
704
 
705
  history.append({"step": step, "action": action.get("action_type", "unknown")})
706
+ final_score = strict_score(reward.get("value", final_score), default=final_score)
707
  emit_block("STEP", task=task_id, step=step, reward=final_score, action=action.get("action_type", "unknown"))
708
 
709
  if reward.get("done"):
710
+ print(f" Episode done. Final score: {score_text(final_score, default=final_score)}")
711
  emit_block("END", task=task_id, score=final_score, steps=step)
712
  return final_score
713
 
 
725
  }
726
  try:
727
  result = call_env("step", {"action": empty_report})
728
+ final_score = strict_score(result.get("reward", {}).get("value", final_score), default=final_score)
729
  except Exception:
730
  pass
731
  emit_block("END", task=task_id, score=final_score, steps=total_steps)
 
762
 
763
  for task_id in [1, 2, 3, 4]:
764
  if time.time() - global_start > WALL_LIMIT - 120:
765
+ score = strict_score(0.0)
766
+ emit_block("START", task=task_id, mode="skipped", seed=SEED)
767
+ emit_block("END", task=task_id, score=score, steps=0)
768
+ scores[f"task_{task_id}"] = score
769
  continue
770
  if use_heuristic:
771
+ scores[f"task_{task_id}"] = strict_score(run_task_heuristic(task_id))
772
  else:
773
+ scores[f"task_{task_id}"] = strict_score(run_task_hybrid(task_id, global_start))
774
 
775
  print("\n" + "=" * 60)
776
  print("BASELINE RESULTS (seed=42)")
777
  print("=" * 60)
778
  for k, v in scores.items():
779
+ print(f" {k}: {score_text(v, default=v)}")
780
+ mean = strict_score(sum(scores.values()) / max(len(scores), 1))
781
+ print(f" mean: {score_text(mean, default=mean)}")
782
  print(f" total wall time: {(time.time() - global_start) / 60:.1f} min")
783
  if not use_heuristic and all(v <= 0.0 for v in scores.values()):
784
  print("WARNING: LLM mode ran but all scores are 0.0. Check model connectivity and prompt behavior.")