Spaces:
Sleeping
Sleeping
Hemanth Kunta committed on
Commit ·
cb70147
1
Parent(s): 94595e2
final fix for phase 2 stdout score range
Browse files- inference.py +42 -18
inference.py
CHANGED
|
@@ -29,6 +29,7 @@ MAX_TOKENS = 1000
|
|
| 29 |
MAX_AUDIT_STEPS = 9
|
| 30 |
FIX_STEPS = 3
|
| 31 |
WALL_LIMIT = 15 * 60
|
|
|
|
| 32 |
|
| 33 |
SYSTEM_PROMPT = """You are a SQL Data Auditor.
|
| 34 |
|
|
@@ -110,6 +111,24 @@ def emit_block(kind: str, **fields) -> None:
|
|
| 110 |
print(" ".join(parts), flush=True)
|
| 111 |
|
| 112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
def parse_action(text: str) -> dict:
|
| 114 |
raw = (text or "").strip()
|
| 115 |
raw = raw.replace("```json", "").replace("```", "").strip()
|
|
@@ -476,24 +495,26 @@ def run_task_hybrid(task_id: int, global_start: float) -> float:
|
|
| 476 |
print(f"Tables: {list(obs['tables'].keys())} | Credits: {obs['query_credits_remaining']}")
|
| 477 |
|
| 478 |
if time.time() - global_start > WALL_LIMIT - 60:
|
| 479 |
-
|
|
|
|
|
|
|
| 480 |
|
| 481 |
evidence, base_report = build_probe_report(task_id)
|
| 482 |
final_report = llm_refine_report(task_id, obs, evidence, base_report)
|
| 483 |
final_report = normalize_report(final_report)
|
| 484 |
|
| 485 |
out = submit(final_report)
|
| 486 |
-
score =
|
| 487 |
emit_block("STEP", task=task_id, step=1, reward=score, action="submit_report")
|
| 488 |
|
| 489 |
# Optional harmless fix step for bonus phase behavior parity.
|
| 490 |
try:
|
| 491 |
fix = call_env("step", {"action": {"action_type": "fix_sql", "sql": "UPDATE orders SET order_total = order_total WHERE 1=0"}})
|
| 492 |
-
score =
|
| 493 |
emit_block("STEP", task=task_id, step=2, reward=score, action="fix_sql")
|
| 494 |
except Exception:
|
| 495 |
pass
|
| 496 |
-
print(f" Episode done. Final score: {score
|
| 497 |
emit_block("END", task=task_id, score=score, steps=2)
|
| 498 |
return score
|
| 499 |
|
|
@@ -598,17 +619,17 @@ def run_task_heuristic(task_id: int) -> float:
|
|
| 598 |
}
|
| 599 |
|
| 600 |
out = submit(report)
|
| 601 |
-
score =
|
| 602 |
-
print(f" audit score: {score
|
| 603 |
emit_block("STEP", task=task_id, step=1, reward=score, action="submit_report")
|
| 604 |
# One no-op fix to demonstrate fix phase behavior.
|
| 605 |
try:
|
| 606 |
fix = call_env("step", {"action": {"action_type": "fix_sql", "sql": "UPDATE orders SET order_total = order_total WHERE 1=0"}})
|
| 607 |
-
score =
|
| 608 |
emit_block("STEP", task=task_id, step=2, reward=score, action="fix_sql")
|
| 609 |
except Exception:
|
| 610 |
pass
|
| 611 |
-
print(f" final score: {score
|
| 612 |
emit_block("END", task=task_id, score=score, steps=2)
|
| 613 |
return score
|
| 614 |
|
|
@@ -623,7 +644,7 @@ def run_task(task_id: int, global_start: float) -> float:
|
|
| 623 |
print(f"Tables: {list(obs['tables'].keys())} | Credits: {obs['query_credits_remaining']}")
|
| 624 |
|
| 625 |
history = []
|
| 626 |
-
final_score = 0.0
|
| 627 |
total_steps = MAX_AUDIT_STEPS + FIX_STEPS
|
| 628 |
|
| 629 |
for step in range(1, total_steps + 1):
|
|
@@ -682,11 +703,11 @@ Return next action JSON only."""
|
|
| 682 |
reward = step_result.get("reward", {})
|
| 683 |
|
| 684 |
history.append({"step": step, "action": action.get("action_type", "unknown")})
|
| 685 |
-
final_score =
|
| 686 |
emit_block("STEP", task=task_id, step=step, reward=final_score, action=action.get("action_type", "unknown"))
|
| 687 |
|
| 688 |
if reward.get("done"):
|
| 689 |
-
print(f" Episode done. Final score: {final_score
|
| 690 |
emit_block("END", task=task_id, score=final_score, steps=step)
|
| 691 |
return final_score
|
| 692 |
|
|
@@ -704,7 +725,7 @@ Return next action JSON only."""
|
|
| 704 |
}
|
| 705 |
try:
|
| 706 |
result = call_env("step", {"action": empty_report})
|
| 707 |
-
final_score =
|
| 708 |
except Exception:
|
| 709 |
pass
|
| 710 |
emit_block("END", task=task_id, score=final_score, steps=total_steps)
|
|
@@ -741,20 +762,23 @@ def main():
|
|
| 741 |
|
| 742 |
for task_id in [1, 2, 3, 4]:
|
| 743 |
if time.time() - global_start > WALL_LIMIT - 120:
|
| 744 |
-
|
|
|
|
|
|
|
|
|
|
| 745 |
continue
|
| 746 |
if use_heuristic:
|
| 747 |
-
scores[f"task_{task_id}"] = run_task_heuristic(task_id)
|
| 748 |
else:
|
| 749 |
-
scores[f"task_{task_id}"] = run_task_hybrid(task_id, global_start)
|
| 750 |
|
| 751 |
print("\n" + "=" * 60)
|
| 752 |
print("BASELINE RESULTS (seed=42)")
|
| 753 |
print("=" * 60)
|
| 754 |
for k, v in scores.items():
|
| 755 |
-
print(f" {k}: {v
|
| 756 |
-
mean = sum(scores.values()) / max(len(scores), 1)
|
| 757 |
-
print(f" mean: {mean
|
| 758 |
print(f" total wall time: {(time.time() - global_start) / 60:.1f} min")
|
| 759 |
if not use_heuristic and all(v <= 0.0 for v in scores.values()):
|
| 760 |
print("WARNING: LLM mode ran but all scores are 0.0. Check model connectivity and prompt behavior.")
|
|
|
|
| 29 |
MAX_AUDIT_STEPS = 9
|
| 30 |
FIX_STEPS = 3
|
| 31 |
WALL_LIMIT = 15 * 60
|
| 32 |
+
SCORE_EPS = 1e-6
|
| 33 |
|
| 34 |
SYSTEM_PROMPT = """You are a SQL Data Auditor.
|
| 35 |
|
|
|
|
| 111 |
print(" ".join(parts), flush=True)
|
| 112 |
|
| 113 |
|
| 114 |
+
def strict_score(value: float | int | str | None, default: float = SCORE_EPS) -> float:
    """Clamp a score into the strict open interval (0, 1) for validator compatibility.

    Args:
        value: Raw score; anything ``float()`` accepts. ``None`` or any other
            unconvertible value falls back to ``default``.
        default: Fallback used when ``value`` cannot be converted. If the
            fallback is itself unconvertible, ``SCORE_EPS`` is used instead
            of letting the conversion error escape.

    Returns:
        ``SCORE_EPS`` when the score is not greater than 0,
        ``1.0 - SCORE_EPS`` when it is not less than 1,
        otherwise the converted score unchanged.
    """
    # float() raises TypeError (None, dict, list), ValueError (bad string)
    # or OverflowError (huge int); nothing else is expected here.
    conversion_errors = (TypeError, ValueError, OverflowError)
    try:
        v = float(value)
    except conversion_errors:
        try:
            v = float(default)
        except conversion_errors:
            # A bad fallback must not crash the score sanitizer itself.
            v = SCORE_EPS
    # Written as negated comparisons on purpose: NaN fails `v > 0.0`,
    # so it is mapped to the lower bound instead of leaking through.
    if not (v > 0.0):
        return SCORE_EPS
    if not (v < 1.0):
        return 1.0 - SCORE_EPS
    return v
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def score_text(value: float | int | str | None, default: float = SCORE_EPS) -> str:
    """Format a score with six decimals so the text never reads 0.000000 or 1.000000.

    Args:
        value: Raw score, sanitized through ``strict_score``.
        default: Fallback forwarded to ``strict_score`` for unconvertible values.

    Returns:
        The score rendered with ``:.6f``, pinned to the range
        ``[SCORE_EPS, 1.0 - SCORE_EPS]`` before formatting.
    """
    v = strict_score(value, default=default)
    # strict_score returns any value strictly inside (0, 1) unchanged, so
    # something like 0.9999996 would round to "1.000000" and 2e-7 to
    # "0.000000" here. Pin to the epsilon bounds before rounding.
    # NOTE(review): this relies on SCORE_EPS being representable in six
    # decimals (1e-6 in the visible diff) — revisit if it gets smaller.
    v = min(max(v, SCORE_EPS), 1.0 - SCORE_EPS)
    return f"{v:.6f}"
|
| 130 |
+
|
| 131 |
+
|
| 132 |
def parse_action(text: str) -> dict:
|
| 133 |
raw = (text or "").strip()
|
| 134 |
raw = raw.replace("```json", "").replace("```", "").strip()
|
|
|
|
| 495 |
print(f"Tables: {list(obs['tables'].keys())} | Credits: {obs['query_credits_remaining']}")
|
| 496 |
|
| 497 |
if time.time() - global_start > WALL_LIMIT - 60:
|
| 498 |
+
score = strict_score(0.0)
|
| 499 |
+
emit_block("END", task=task_id, score=score, steps=0)
|
| 500 |
+
return score
|
| 501 |
|
| 502 |
evidence, base_report = build_probe_report(task_id)
|
| 503 |
final_report = llm_refine_report(task_id, obs, evidence, base_report)
|
| 504 |
final_report = normalize_report(final_report)
|
| 505 |
|
| 506 |
out = submit(final_report)
|
| 507 |
+
score = strict_score(out.get("reward", {}).get("value", 0.0))
|
| 508 |
emit_block("STEP", task=task_id, step=1, reward=score, action="submit_report")
|
| 509 |
|
| 510 |
# Optional harmless fix step for bonus phase behavior parity.
|
| 511 |
try:
|
| 512 |
fix = call_env("step", {"action": {"action_type": "fix_sql", "sql": "UPDATE orders SET order_total = order_total WHERE 1=0"}})
|
| 513 |
+
score = strict_score(fix.get("reward", {}).get("value", score), default=score)
|
| 514 |
emit_block("STEP", task=task_id, step=2, reward=score, action="fix_sql")
|
| 515 |
except Exception:
|
| 516 |
pass
|
| 517 |
+
print(f" Episode done. Final score: {score_text(score, default=score)}")
|
| 518 |
emit_block("END", task=task_id, score=score, steps=2)
|
| 519 |
return score
|
| 520 |
|
|
|
|
| 619 |
}
|
| 620 |
|
| 621 |
out = submit(report)
|
| 622 |
+
score = strict_score(out.get("reward", {}).get("value", 0.0))
|
| 623 |
+
print(f" audit score: {score_text(score, default=score)}")
|
| 624 |
emit_block("STEP", task=task_id, step=1, reward=score, action="submit_report")
|
| 625 |
# One no-op fix to demonstrate fix phase behavior.
|
| 626 |
try:
|
| 627 |
fix = call_env("step", {"action": {"action_type": "fix_sql", "sql": "UPDATE orders SET order_total = order_total WHERE 1=0"}})
|
| 628 |
+
score = strict_score(fix.get("reward", {}).get("value", score), default=score)
|
| 629 |
emit_block("STEP", task=task_id, step=2, reward=score, action="fix_sql")
|
| 630 |
except Exception:
|
| 631 |
pass
|
| 632 |
+
print(f" final score: {score_text(score, default=score)}")
|
| 633 |
emit_block("END", task=task_id, score=score, steps=2)
|
| 634 |
return score
|
| 635 |
|
|
|
|
| 644 |
print(f"Tables: {list(obs['tables'].keys())} | Credits: {obs['query_credits_remaining']}")
|
| 645 |
|
| 646 |
history = []
|
| 647 |
+
final_score = strict_score(0.0)
|
| 648 |
total_steps = MAX_AUDIT_STEPS + FIX_STEPS
|
| 649 |
|
| 650 |
for step in range(1, total_steps + 1):
|
|
|
|
| 703 |
reward = step_result.get("reward", {})
|
| 704 |
|
| 705 |
history.append({"step": step, "action": action.get("action_type", "unknown")})
|
| 706 |
+
final_score = strict_score(reward.get("value", final_score), default=final_score)
|
| 707 |
emit_block("STEP", task=task_id, step=step, reward=final_score, action=action.get("action_type", "unknown"))
|
| 708 |
|
| 709 |
if reward.get("done"):
|
| 710 |
+
print(f" Episode done. Final score: {score_text(final_score, default=final_score)}")
|
| 711 |
emit_block("END", task=task_id, score=final_score, steps=step)
|
| 712 |
return final_score
|
| 713 |
|
|
|
|
| 725 |
}
|
| 726 |
try:
|
| 727 |
result = call_env("step", {"action": empty_report})
|
| 728 |
+
final_score = strict_score(result.get("reward", {}).get("value", final_score), default=final_score)
|
| 729 |
except Exception:
|
| 730 |
pass
|
| 731 |
emit_block("END", task=task_id, score=final_score, steps=total_steps)
|
|
|
|
| 762 |
|
| 763 |
for task_id in [1, 2, 3, 4]:
|
| 764 |
if time.time() - global_start > WALL_LIMIT - 120:
|
| 765 |
+
score = strict_score(0.0)
|
| 766 |
+
emit_block("START", task=task_id, mode="skipped", seed=SEED)
|
| 767 |
+
emit_block("END", task=task_id, score=score, steps=0)
|
| 768 |
+
scores[f"task_{task_id}"] = score
|
| 769 |
continue
|
| 770 |
if use_heuristic:
|
| 771 |
+
scores[f"task_{task_id}"] = strict_score(run_task_heuristic(task_id))
|
| 772 |
else:
|
| 773 |
+
scores[f"task_{task_id}"] = strict_score(run_task_hybrid(task_id, global_start))
|
| 774 |
|
| 775 |
print("\n" + "=" * 60)
|
| 776 |
print("BASELINE RESULTS (seed=42)")
|
| 777 |
print("=" * 60)
|
| 778 |
for k, v in scores.items():
|
| 779 |
+
print(f" {k}: {score_text(v, default=v)}")
|
| 780 |
+
mean = strict_score(sum(scores.values()) / max(len(scores), 1))
|
| 781 |
+
print(f" mean: {score_text(mean, default=mean)}")
|
| 782 |
print(f" total wall time: {(time.time() - global_start) / 60:.1f} min")
|
| 783 |
if not use_heuristic and all(v <= 0.0 for v in scores.values()):
|
| 784 |
print("WARNING: LLM mode ran but all scores are 0.0. Check model connectivity and prompt behavior.")
|