samrat-rm committed on
Commit
ae1e803
Β·
1 Parent(s): 1288c52

feat: implementing judge LLM which contributes to 15% of scoring

Browse files
Files changed (2) hide show
  1. inference.py +28 -5
  2. llm_judge.py +94 -0
inference.py CHANGED
@@ -30,6 +30,7 @@ load_dotenv()
30
  from openai import OpenAI
31
 
32
  from client import WhyDidItFailEnv
 
33
  from models import WhyDidItFailAction
34
  from server.scenarios import SCENARIOS
35
 
@@ -112,8 +113,8 @@ def _get_action(client: OpenAI, step: int, obs_summary: str, history: List[str])
112
  except Exception as exc:
113
  print(f" [DEBUG] parse error: {exc}", flush=True)
114
  if step <= 2:
115
- return WhyDidItFailAction(action_type="inspect_logs", diagnosis=None, suggested_fix=None)
116
- return WhyDidItFailAction(action_type="submit_diagnosis", diagnosis="unknown", suggested_fix=None)
117
 
118
  # ── episode runner ────────────────────────────────────────────────────────────
119
 
@@ -133,7 +134,15 @@ async def run_episode(env: WhyDidItFailEnv, client: OpenAI, scenario_key: str) -
133
  obs = result.observation
134
  reward = result.reward or 0.0
135
  done = result.done
136
- act_str = action.model_dump_json(exclude_none=True)
 
 
 
 
 
 
 
 
137
 
138
  rewards.append(reward)
139
  history.append(f"Step {step}: {act_str} β†’ reward={reward:.2f} | {obs.feedback}")
@@ -142,8 +151,22 @@ async def run_episode(env: WhyDidItFailEnv, client: OpenAI, scenario_key: str) -
142
  if done:
143
  break
144
 
145
- # Final score = reward on submit_diagnosis (last reward)
146
- score = rewards[-1] if rewards else 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  success = score >= SUCCESS_THRESHOLD
148
  return {"scenario_key": scenario_key, "score": score, "steps": len(rewards), "success": success}
149
 
 
30
  from openai import OpenAI
31
 
32
  from client import WhyDidItFailEnv
33
+ from llm_judge import judge as llm_judge
34
  from models import WhyDidItFailAction
35
  from server.scenarios import SCENARIOS
36
 
 
113
  except Exception as exc:
114
  print(f" [DEBUG] parse error: {exc}", flush=True)
115
  if step <= 2:
116
+ return WhyDidItFailAction(action_type="inspect_logs", diagnosis=None, suggested_fix=None,reasoning=None)
117
+ return WhyDidItFailAction(action_type="submit_diagnosis", diagnosis="unknown", suggested_fix=None,reasoning=None)
118
 
119
  # ── episode runner ────────────────────────────────────────────────────────────
120
 
 
134
  obs = result.observation
135
  reward = result.reward or 0.0
136
  done = result.done
137
+ act_str = action.model_dump_json(exclude_none=True, exclude_defaults=True)
138
+
139
+ if action.action_type in ("inspect_logs", "inspect_config", "inspect_gradients"):
140
+ source = action.action_type.replace("inspect_", "")
141
+ if source not in inspection_order:
142
+ inspection_order.append(source)
143
+
144
+ if action.action_type == "submit_diagnosis":
145
+ submit_action = action # judge runs after loop β€” WebSocket is closed by then
146
 
147
  rewards.append(reward)
148
  history.append(f"Step {step}: {act_str} β†’ reward={reward:.2f} | {obs.feedback}")
 
151
  if done:
152
  break
153
 
154
+ # WebSocket is closed β€” safe to call the judge now
155
+ keyword_score = rewards[-1] if rewards else 0.0
156
+ judge_score = 0.0
157
+ if submit_action is not None:
158
+ judge_score = llm_judge(
159
+ client=client,
160
+ model=MODEL_NAME,
161
+ diagnosis=submit_action.diagnosis or "",
162
+ reasoning=submit_action.reasoning,
163
+ suggested_fix=submit_action.suggested_fix,
164
+ scenario=SCENARIOS[scenario_key],
165
+ inspection_order=inspection_order,
166
+ )
167
+ score = round(0.85 * keyword_score + 0.15 * judge_score, 4)
168
+ print(f" [JUDGE] scenario={scenario_key} keyword={keyword_score:.3f} reasoning={judge_score:.3f} total={score:.3f}", flush=True)
169
+
170
  success = score >= SUCCESS_THRESHOLD
171
  return {"scenario_key": scenario_key, "score": score, "steps": len(rewards), "success": success}
172
 
llm_judge.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLM Judge β€” reasoning quality scorer for WhyDidItFail.
3
+
4
+ Called from inference.py after submit_diagnosis.
5
+ Uses the same OpenAI-compatible client and model as the agent.
6
+ Returns a normalized score in [0.0, 1.0] representing reasoning quality.
7
+ Returns 0.0 silently if reasoning is absent or the call fails.
8
+
9
+ Scoring criteria (0–5 each, total 0–15 β†’ normalized to 0.0–1.0):
10
+ evidence_grounding β€” does the reasoning cite specific observed values?
11
+ causal_chain β€” does it connect evidence to the failure mode logically?
12
+ fix_rationale β€” is the fix justified by the evidence?
13
+
14
+ Final score in inference.py:
15
+ total = 0.85 * keyword_score + 0.15 * judge_score β†’ always in [0.0, 1.0]
16
+ """
17
+
18
+ import json
19
+
20
+ from openai import OpenAI
21
+
22
+
23
+ def _build_prompt(
24
+ diagnosis: str,
25
+ suggested_fix: str | None,
26
+ reasoning: str,
27
+ scenario: dict,
28
+ inspection_order: list[str],
29
+ ) -> str:
30
+ seen: dict = {}
31
+ if "logs" in inspection_order:
32
+ seen["training_logs"] = scenario.get("logs", [])
33
+ if "config" in inspection_order:
34
+ seen["config"] = scenario.get("config", {})
35
+ if "gradients" in inspection_order:
36
+ seen["gradient_norms"] = scenario.get("gradient_norms", None)
37
+
38
+ return f"""You are evaluating the reasoning of an ML debugging agent.
39
+
40
+ Agent submission:
41
+ Diagnosis: {diagnosis}
42
+ Suggested fix: {suggested_fix or "none provided"}
43
+ Reasoning: {reasoning}
44
+
45
+ Data the agent had access to:
46
+ {json.dumps(seen, indent=2)}
47
+
48
+ Score the reasoning (integers only):
49
+ evidence_grounding (0-5): Does the reasoning cite specific values from the data above?
50
+ causal_chain (0-5): Does it logically connect that evidence to the diagnosed failure mode?
51
+ fix_rationale (0-5): Is the fix directly justified by the evidence and diagnosis?
52
+
53
+ Respond with JSON only, no explanation:
54
+ {{"evidence_grounding": <int>, "causal_chain": <int>, "fix_rationale": <int>}}"""
55
+
56
+
57
+ def judge(
58
+ client: OpenAI,
59
+ model: str,
60
+ diagnosis: str,
61
+ reasoning: str | None,
62
+ suggested_fix: str | None,
63
+ scenario: dict,
64
+ inspection_order: list[str],
65
+ ) -> float:
66
+ """Score reasoning quality. Returns 0.0–1.0. Returns 0.0 if reasoning absent or call fails."""
67
+ if not reasoning or not reasoning.strip():
68
+ return 0.0
69
+
70
+ try:
71
+ completion = client.chat.completions.create(
72
+ model=model,
73
+ messages=[
74
+ {"role": "user", "content": _build_prompt(
75
+ diagnosis, suggested_fix, reasoning, scenario, inspection_order
76
+ )},
77
+ ],
78
+ temperature=0.0,
79
+ max_tokens=64,
80
+ )
81
+ text = (completion.choices[0].message.content or "").strip()
82
+ data = json.loads(text)
83
+
84
+ raw = (
85
+ data.get("evidence_grounding", 0)
86
+ + data.get("causal_chain", 0)
87
+ + data.get("fix_rationale", 0)
88
+ )
89
+ # normalize: raw 0–15 β†’ 0.0–1.0
90
+ return round(max(0, min(15, raw)) / 15, 4)
91
+
92
+ except Exception as exc:
93
+ print(f" [JUDGE] failed: {exc}", flush=True)
94
+ return 0.0