ashishbaberwal committed on
Commit
e225fd7
·
1 Parent(s): 375024b

New Final

Browse files
Files changed (1) hide show
  1. inference.py +37 -6
inference.py CHANGED
@@ -34,6 +34,24 @@ FALLBACK_ACTION = json.dumps({
34
  "final_decision": "changes_requested"
35
  })
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  def log_start(task: str, env_name: str, model: str, max_steps: int, seed) -> None:
39
  payload = {
@@ -533,16 +551,18 @@ def run_episode(env, agent, task_id: str, max_steps: int, seed=None) -> Dict[str
533
  print(f"Episode error: {error_message}", flush=True)
534
 
535
  try:
536
- final_score = env.get_task_score()
537
  except Exception:
538
- final_score = 0.0
 
 
539
 
540
  try:
541
  diagnostics = env.summary()
542
  except Exception:
543
  diagnostics = {}
544
 
545
- success = final_score >= 0.7 and error_message is None
546
  log_end(success=success, steps=step, score=final_score, rewards=rewards)
547
 
548
  return {
@@ -575,7 +595,13 @@ def run_batch(env, agent, task_ids: List[str], max_steps: int, output: str):
575
  all_results.append(result)
576
  except Exception as e:
577
  print(f"Error on task {task_id}: {e}", flush=True)
578
- all_results.append({"task_id": task_id, "task_score": 0.0, "total_reward": 0.0, "f1": 0.0, "error": str(e)})
 
 
 
 
 
 
579
 
580
  valid = [r for r in all_results if "error" not in r or r.get("error") is None]
581
  avg_score = sum(r["task_score"] for r in valid) / max(1, len(valid))
@@ -618,7 +644,7 @@ def main() -> int:
618
  return 1
619
 
620
  parser = argparse.ArgumentParser(description="Run code review agent")
621
- parser.add_argument("--task-id", type=str, default="bug_detection_easy_1")
622
  parser.add_argument("--max-steps", type=int, default=50)
623
  parser.add_argument("--output", type=str, default="baseline_results.json")
624
  parser.add_argument("--batch", action="store_true")
@@ -646,6 +672,11 @@ def main() -> int:
646
  task_ids = [t["task_id"] for t in TaskDefinitions.get_all_tasks()]
647
  run_batch(env, agent, task_ids, args.max_steps, args.output)
648
  else:
 
 
 
 
 
649
  result = run_episode(env, agent, args.task_id, args.max_steps, seed=args.seed)
650
 
651
  print("\n" + "=" * 60, flush=True)
@@ -668,7 +699,7 @@ def main() -> int:
668
  fallback = {
669
  "task_id": getattr(args, "task_id", "unknown"),
670
  "total_reward": 0.0,
671
- "task_score": 0.0,
672
  "steps": 0,
673
  "error": str(e),
674
  "model": MODEL_NAME,
 
34
  "final_decision": "changes_requested"
35
  })
36
 
37
+ DEFAULT_CORE_TASKS = [
38
+ "bug_detection_easy_1",
39
+ "memory_leak_medium_1",
40
+ "security_hard_1",
41
+ ]
42
+
43
+
44
+ def _open_interval_score(value: float, epsilon: float = 1e-4) -> float:
45
+ try:
46
+ numeric = float(value)
47
+ except Exception:
48
+ numeric = 0.0
49
+ if numeric <= 0.0:
50
+ return epsilon
51
+ if numeric >= 1.0:
52
+ return 1.0 - epsilon
53
+ return numeric
54
+
55
 
56
  def log_start(task: str, env_name: str, model: str, max_steps: int, seed) -> None:
57
  payload = {
 
551
  print(f"Episode error: {error_message}", flush=True)
552
 
553
  try:
554
+ final_score_raw = env.get_task_score()
555
  except Exception:
556
+ final_score_raw = 0.0
557
+
558
+ final_score = _open_interval_score(final_score_raw)
559
 
560
  try:
561
  diagnostics = env.summary()
562
  except Exception:
563
  diagnostics = {}
564
 
565
+ success = final_score_raw >= 0.7 and error_message is None
566
  log_end(success=success, steps=step, score=final_score, rewards=rewards)
567
 
568
  return {
 
595
  all_results.append(result)
596
  except Exception as e:
597
  print(f"Error on task {task_id}: {e}", flush=True)
598
+ all_results.append({
599
+ "task_id": task_id,
600
+ "task_score": round(_open_interval_score(0.0), 4),
601
+ "total_reward": 0.0,
602
+ "f1": 0.0,
603
+ "error": str(e),
604
+ })
605
 
606
  valid = [r for r in all_results if "error" not in r or r.get("error") is None]
607
  avg_score = sum(r["task_score"] for r in valid) / max(1, len(valid))
 
644
  return 1
645
 
646
  parser = argparse.ArgumentParser(description="Run code review agent")
647
+ parser.add_argument("--task-id", type=str, default="__AUTO__")
648
  parser.add_argument("--max-steps", type=int, default=50)
649
  parser.add_argument("--output", type=str, default="baseline_results.json")
650
  parser.add_argument("--batch", action="store_true")
 
672
  task_ids = [t["task_id"] for t in TaskDefinitions.get_all_tasks()]
673
  run_batch(env, agent, task_ids, args.max_steps, args.output)
674
  else:
675
+ if args.task_id == "__AUTO__":
676
+ print("No --task-id provided; running core 3-task baseline.", flush=True)
677
+ run_batch(env, agent, DEFAULT_CORE_TASKS, args.max_steps, args.output)
678
+ return 0
679
+
680
  result = run_episode(env, agent, args.task_id, args.max_steps, seed=args.seed)
681
 
682
  print("\n" + "=" * 60, flush=True)
 
699
  fallback = {
700
  "task_id": getattr(args, "task_id", "unknown"),
701
  "total_reward": 0.0,
702
+ "task_score": round(_open_interval_score(0.0), 4),
703
  "steps": 0,
704
  "error": str(e),
705
  "model": MODEL_NAME,