Uddiii committed on
Commit
2df5c63
·
1 Parent(s): 71a0a91

chore(kaggle): rebuild notebook v3 + clean dev-scratch files

Browse files

* New `kaggle/build_notebook.py` - single source of truth that regenerates
the Kaggle notebook + KAGGLE_QUICKSTART.md from scratch (run once whenever
the layout drifts).
* New `kaggle/KAGGLE_QUICKSTART.md` - concise step-by-step run order with
troubleshooting table for every dependency-hell symptom we hit.
* `kaggle/train_ermap_grpo_kaggle.ipynb` rebuilt as a clean 20-cell layout:
- new constant per-phase reward thresholds (P1=+1.2 / P2=+1.1 / P3=+1.0)
- idempotent REPAIR cell (pins torch 2.10 cu128 + bnb + unsloth/zoo + trl,
verifies in a subprocess so the kernel never gets poisoned mid-install)
- pre-flight Groq routing + PING smoke test using router._models / _clients
- explicit dry-run + HF-push hook + per-phase dashboards + final push.
* `train_grpo.py`: add optional `phase_episode_budgets` (fixed-budget
curriculum mode) alongside the existing reward-threshold early-stop.
Fully backward compatible (default behaviour unchanged); CLI flags
--phase{1,2,3}-budget.
* Remove dev-scratch files: `_smoke_dead_keys.py`, `ER_MAP/_verify.py`,
`ER_MAP/_replot.py`, `kaggle/KAGGLE.md` (replaced by KAGGLE_QUICKSTART).

Made-with: Cursor

ER_MAP/_replot.py DELETED
@@ -1,11 +0,0 @@
1
- """Quick script to regenerate reward curve from saved eval_results.json"""
2
- import json
3
- import sys
4
- sys.path.insert(0, ".")
5
- from ER_MAP.evaluate import plot_reward_curve
6
-
7
- with open("d:/Meta_Finals/ER_MAP/eval_results.json") as f:
8
- results = json.load(f)
9
-
10
- plot_reward_curve(results, "d:/Meta_Finals/ER_MAP/reward_curve.png")
11
- print("Done!")
 
 
 
 
 
 
 
 
 
 
 
 
ER_MAP/_verify.py DELETED
@@ -1,15 +0,0 @@
1
- from ER_MAP.envs.randomizer import DISEASE_POOL, DIFFICULTY_TIERS, generate_ground_truth
2
-
3
- print(f"=== {len(DISEASE_POOL)} DISEASES ===")
4
- for d in DISEASE_POOL:
5
- print(f" {d['true_disease']}")
6
-
7
- print()
8
- print("=== DIFFICULTY TIERS ===")
9
- for tier in ["easy", "medium", "hard"]:
10
- gt = generate_ground_truth(difficulty=tier)
11
- p = gt["patient"]
12
- print(f" {tier.upper():8s} | compliance: {p['compliance']:20s} | comm: {p['communication']:20s} | {gt['disease']['true_disease']}")
13
-
14
- combos = 3 * 4 * 4 * 4 * 4 * 3 * 3 * 3 * 3 * 15
15
- print(f"\nTotal unique scenario combinations: {combos:,}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ER_MAP/training/train_grpo.py CHANGED
@@ -686,9 +686,47 @@ def train(
686
  phase_min_win_rate: float = 0.20,
687
  convergence_window: int = 3,
688
  early_stop: bool = True,
 
 
 
 
 
 
 
 
 
 
689
  ):
690
  if phase_reward_targets is None:
691
  phase_reward_targets = {1: 1.5, 2: 1.2, 3: 1.0}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
692
  """
693
  Main GRPO training loop with curriculum scheduling.
694
 
@@ -758,7 +796,17 @@ def train(
758
 
759
  logger.info(f"\nStarting GRPO training for up to {num_episodes} episodes "
760
  f"(={num_episodes // group_size} GRPO updates)")
761
- if early_stop:
 
 
 
 
 
 
 
 
 
 
762
  logger.info(
763
  f" Early-stop ON: per-phase reward thresholds (sustained for "
764
  f"{convergence_window} groups, win-rate >= {phase_min_win_rate:.0%}):"
@@ -955,6 +1003,32 @@ def train(
955
  f"Phase Episodes: {s['phase_episodes']}"
956
  )
957
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
958
  # --- Per-phase early-stop / promotion check ------------------------
959
  # Maintain a buffer of the last `convergence_window` GRPO groups
960
  # with their (phase, rolling_avg, rolling_win). When ALL N entries
@@ -1099,8 +1173,26 @@ if __name__ == "__main__":
1099
  parser.add_argument("--no-early-stop", action="store_true",
1100
  help="Disable early-stop (always run all configured episodes)")
1101
 
 
 
 
 
 
 
 
 
1102
  args = parser.parse_args()
1103
 
 
 
 
 
 
 
 
 
 
 
1104
  train(
1105
  num_episodes=args.episodes,
1106
  group_size=args.group_size,
@@ -1119,4 +1211,5 @@ if __name__ == "__main__":
1119
  phase_min_win_rate=args.phase_min_win_rate,
1120
  convergence_window=args.convergence_window,
1121
  early_stop=not args.no_early_stop,
 
1122
  )
 
686
  phase_min_win_rate: float = 0.20,
687
  convergence_window: int = 3,
688
  early_stop: bool = True,
689
+ # ---------------- Fixed-budget curriculum (alternative mode) -----------
690
+ # When set, training advances phases at FIXED episode counts instead of
691
+ # via the reward-target early-stop. Useful when you want a clean
692
+ # reward-growth curve over a known wall-clock budget. Example:
693
+ # phase_episode_budgets = {1: 20, 2: 30, 3: 50} # 100 episodes total
694
+ # When this is provided, `early_stop` is forced to False (the reward
695
+ # thresholds become observational, logged for plots only) and
696
+ # `num_episodes` is auto-set to sum(phase_episode_budgets.values()) if
697
+ # the caller passed a smaller / inconsistent value.
698
+ phase_episode_budgets: Optional[Dict[int, int]] = None,
699
  ):
700
  if phase_reward_targets is None:
701
  phase_reward_targets = {1: 1.5, 2: 1.2, 3: 1.0}
702
+
703
+ # Fixed-budget mode overrides early-stop and aligns num_episodes.
704
+ fixed_budget_mode = phase_episode_budgets is not None and len(phase_episode_budgets) > 0
705
+ if fixed_budget_mode:
706
+ # Sanity: must have all 3 phases keyed, all positive ints
707
+ missing = [p for p in (1, 2, 3) if p not in phase_episode_budgets]
708
+ if missing:
709
+ raise ValueError(
710
+ f"phase_episode_budgets must include all phases (1,2,3); missing: {missing}"
711
+ )
712
+ for _p, _n in phase_episode_budgets.items():
713
+ if not isinstance(_n, int) or _n <= 0:
714
+ raise ValueError(
715
+ f"phase_episode_budgets[{_p}] must be a positive int, got {_n!r}"
716
+ )
717
+ budget_sum = sum(phase_episode_budgets.values())
718
+ if num_episodes != budget_sum:
719
+ logger.info(
720
+ f"[Fixed-budget] num_episodes ({num_episodes}) overridden to "
721
+ f"sum(phase_episode_budgets) = {budget_sum}"
722
+ )
723
+ num_episodes = budget_sum
724
+ if early_stop:
725
+ logger.info(
726
+ "[Fixed-budget] early_stop=True is incompatible with fixed budgets; "
727
+ "disabling early_stop. Reward targets will still be logged for plots."
728
+ )
729
+ early_stop = False
730
  """
731
  Main GRPO training loop with curriculum scheduling.
732
 
 
796
 
797
  logger.info(f"\nStarting GRPO training for up to {num_episodes} episodes "
798
  f"(={num_episodes // group_size} GRPO updates)")
799
+ if fixed_budget_mode:
800
+ logger.info(
801
+ " Fixed-budget curriculum: phases advance at fixed episode counts."
802
+ )
803
+ for _pid in sorted(phase_episode_budgets.keys()):
804
+ logger.info(
805
+ f" Phase {_pid}: {phase_episode_budgets[_pid]} episodes "
806
+ f"(target avg-reward {phase_reward_targets.get(_pid, float('nan')):+.2f}, "
807
+ f"observational only)"
808
+ )
809
+ elif early_stop:
810
  logger.info(
811
  f" Early-stop ON: per-phase reward thresholds (sustained for "
812
  f"{convergence_window} groups, win-rate >= {phase_min_win_rate:.0%}):"
 
1003
  f"Phase Episodes: {s['phase_episodes']}"
1004
  )
1005
 
1006
+ # --- Fixed-budget phase transition ---------------------------------
1007
+ # When the operator pre-allocates per-phase episode budgets (e.g.
1008
+ # P1=20, P2=30, P3=50), advance at the boundaries regardless of
1009
+ # reward. Phase 3 budget exhaustion lets the outer-loop
1010
+ # `num_episodes` cap end training naturally.
1011
+ if fixed_budget_mode:
1012
+ current_phase = s["phase"]
1013
+ budget = phase_episode_budgets.get(current_phase, 0)
1014
+ if (
1015
+ current_phase < 3
1016
+ and s["phase_episodes"] >= budget
1017
+ ):
1018
+ promoted = scheduler.force_promote(
1019
+ reason=(
1020
+ f"fixed-budget: completed {s['phase_episodes']} episodes "
1021
+ f"in Phase {current_phase} (budget={budget})"
1022
+ )
1023
+ )
1024
+ if promoted:
1025
+ new_phase = scheduler.phase_id
1026
+ logger.info(
1027
+ f" [Fixed-budget] Phase {current_phase} budget exhausted "
1028
+ f"-> Phase {new_phase}: {scheduler.current_phase.name} "
1029
+ f"({phase_episode_budgets.get(new_phase, '?')} episodes allocated)"
1030
+ )
1031
+
1032
  # --- Per-phase early-stop / promotion check ------------------------
1033
  # Maintain a buffer of the last `convergence_window` GRPO groups
1034
  # with their (phase, rolling_avg, rolling_win). When ALL N entries
 
1173
  parser.add_argument("--no-early-stop", action="store_true",
1174
  help="Disable early-stop (always run all configured episodes)")
1175
 
1176
+ # Fixed-budget curriculum (mutually exclusive with early-stop)
1177
+ parser.add_argument("--phase1-budget", type=int, default=None,
1178
+ help="Fixed episode budget for Phase 1 (Tool Mastery)")
1179
+ parser.add_argument("--phase2-budget", type=int, default=None,
1180
+ help="Fixed episode budget for Phase 2 (Clinical Reasoning)")
1181
+ parser.add_argument("--phase3-budget", type=int, default=None,
1182
+ help="Fixed episode budget for Phase 3 (Empathetic Negotiation)")
1183
+
1184
  args = parser.parse_args()
1185
 
1186
+ _budgets = None
1187
+ if any(b is not None for b in (args.phase1_budget, args.phase2_budget, args.phase3_budget)):
1188
+ if not all(b is not None for b in (args.phase1_budget, args.phase2_budget, args.phase3_budget)):
1189
+ parser.error("--phase{1,2,3}-budget must all be set together")
1190
+ _budgets = {
1191
+ 1: args.phase1_budget,
1192
+ 2: args.phase2_budget,
1193
+ 3: args.phase3_budget,
1194
+ }
1195
+
1196
  train(
1197
  num_episodes=args.episodes,
1198
  group_size=args.group_size,
 
1211
  phase_min_win_rate=args.phase_min_win_rate,
1212
  convergence_window=args.convergence_window,
1213
  early_stop=not args.no_early_stop,
1214
+ phase_episode_budgets=_budgets,
1215
  )
_smoke_dead_keys.py DELETED
@@ -1,333 +0,0 @@
1
- """
2
- Smoke test: simulate the EXACT failure mode from the user's last log
3
- (Patient + Nurse keys revoked, Doctor + Judge keys alive) and verify
4
- that:
5
-
6
- 1. AgentRouter.query falls back to a live judge client
7
- 2. DoctorBrain's chain advances past the dead Doctor key (we'll also
8
- simulate Doctor revocation)
9
- 3. TTS emotion adapter gets disabled after first 401 (no spam)
10
-
11
- Runs from the repo root: ``python _smoke_dead_keys.py``
12
- """
13
- from __future__ import annotations
14
-
15
- import os
16
- import sys
17
- import importlib
18
-
19
- REPO_ROOT = os.path.abspath(os.path.dirname(__file__))
20
- if REPO_ROOT not in sys.path:
21
- sys.path.insert(0, REPO_ROOT)
22
-
23
- CHECKS_PASSED = 0
24
- CHECKS_FAILED = 0
25
-
26
-
27
- def check(label, ok, detail=""):
28
- global CHECKS_PASSED, CHECKS_FAILED
29
- tag = "PASS" if ok else "FAIL"
30
- if ok:
31
- CHECKS_PASSED += 1
32
- else:
33
- CHECKS_FAILED += 1
34
- line = f" [{tag}] {label}"
35
- if detail:
36
- line += f" -- {detail}"
37
- print(line, flush=True)
38
-
39
-
40
- # ---------------------------------------------------------------------------
41
- # 1. AgentRouter fallback chain (Patient + Nurse dead, judges alive)
42
- # ---------------------------------------------------------------------------
43
- print("\n--- Test 1: AgentRouter.query with patient+nurse dead ---", flush=True)
44
-
45
- # Inject env vars BEFORE importing dashboard so demo defaults still get set.
46
- os.environ.setdefault("GROQ_DOCTOR_API_KEY", "gsk_dummy_doctor")
47
- os.environ.setdefault("GROQ_NURSE_API_KEY", "gsk_dummy_nurse")
48
- os.environ.setdefault("GROQ_PATIENT_API_KEY", "gsk_dummy_patient")
49
- os.environ.setdefault("GROQ_EMPATHY_JUDGE_API_KEY", "gsk_dummy_judge")
50
- os.environ.setdefault("GROQ_MEDICAL_JUDGE_API_KEY", "gsk_dummy_judge")
51
-
52
- from ER_MAP.envs import api_router as _api_router_mod # noqa: E402
53
-
54
- class _MockResp:
55
- def __init__(self, content):
56
- self.choices = [type("C", (), {"message": type("M", (), {"content": content})()})()]
57
-
58
-
59
- class _MockClient:
60
- """Mock Groq client that either succeeds or raises a 401."""
61
-
62
- def __init__(self, name, dead=False, payload='{"action":"speak","content":"OK"}'):
63
- self.name = name
64
- self.dead = dead
65
- self.payload = payload
66
- self.calls = 0
67
- self.chat = type("Chat", (), {"completions": self})()
68
-
69
- def create(self, **kw):
70
- self.calls += 1
71
- if self.dead:
72
- raise Exception(
73
- f"Error code: 401 - {{'error': {{'message': 'Invalid API Key', "
74
- f"'type': 'invalid_request_error', 'code': 'invalid_api_key'}}}}"
75
- )
76
- return _MockResp(self.payload)
77
-
78
-
79
- router = _api_router_mod.AgentRouter(
80
- api_key="x",
81
- nurse_api_key="x",
82
- patient_api_key="x",
83
- empathy_judge_api_key="x",
84
- medical_judge_api_key="x",
85
- )
86
-
87
- # Patch all 4 clients with mocks: patient + nurse are dead, judges are alive.
88
- mock_clients = {
89
- "nurse": _MockClient("nurse", dead=True),
90
- "patient": _MockClient("patient", dead=True),
91
- "empathy_judge": _MockClient("empathy_judge", dead=False),
92
- "medical_judge": _MockClient("medical_judge", dead=False),
93
- }
94
- router._clients = mock_clients
95
- router._dead_clients = set() # let runtime detect deadness through the cascade
96
-
97
- # Query as nurse — should walk nurse -> patient -> medical_judge and succeed.
98
- result = router.query("nurse", "system", [{"role": "user", "content": "hi"}])
99
-
100
- check(
101
- "router.query('nurse') falls through to a live judge",
102
- result.get("action") == "speak",
103
- f"got {result}",
104
- )
105
- check(
106
- "nurse client was attempted",
107
- mock_clients["nurse"].calls == 1,
108
- f"calls={mock_clients['nurse'].calls}",
109
- )
110
- check(
111
- "patient client was attempted (next in chain)",
112
- mock_clients["patient"].calls == 1,
113
- f"calls={mock_clients['patient'].calls}",
114
- )
115
- check(
116
- "medical_judge client served the request",
117
- mock_clients["medical_judge"].calls == 1,
118
- f"calls={mock_clients['medical_judge'].calls}",
119
- )
120
- check(
121
- "empathy_judge NOT called once medical_judge succeeded",
122
- mock_clients["empathy_judge"].calls == 0,
123
- f"calls={mock_clients['empathy_judge'].calls}",
124
- )
125
- check(
126
- "nurse marked dead in router state",
127
- "nurse" in router._dead_clients,
128
- )
129
- check(
130
- "patient marked dead in router state",
131
- "patient" in router._dead_clients,
132
- )
133
-
134
- # Subsequent queries should skip dead clients entirely.
135
- mock_clients["medical_judge"].calls = 0
136
- mock_clients["empathy_judge"].calls = 0
137
- mock_clients["nurse"].calls = 0
138
- mock_clients["patient"].calls = 0
139
- result2 = router.query("patient", "system", [{"role": "user", "content": "hi"}])
140
- check(
141
- "second query (patient) skips dead clients",
142
- mock_clients["nurse"].calls == 0 and mock_clients["patient"].calls == 0,
143
- )
144
- check(
145
- "second query reaches a live judge",
146
- result2.get("action") == "speak",
147
- f"got {result2}",
148
- )
149
-
150
- # ---------------------------------------------------------------------------
151
- # 2. DoctorBrain key chain
152
- # ---------------------------------------------------------------------------
153
- print("\n--- Test 2: DoctorBrain with primary key dead, fallback alive ---", flush=True)
154
-
155
- from ER_MAP import dashboard as _dash # noqa: E402
156
-
157
- class _DoctorMockChat:
158
- def __init__(self, owner):
159
- self.owner = owner
160
- self.completions = self
161
-
162
- def create(self, **kw):
163
- self.owner.calls += 1
164
- if self.owner.dead:
165
- raise Exception("Error code: 401 - {'error': {'code': 'invalid_api_key'}}")
166
- return _MockResp('{"action":"read_soap","content":"check the chart first"}')
167
-
168
-
169
- class _DoctorMockGroq:
170
- def __init__(self, dead=False):
171
- self.dead = dead
172
- self.calls = 0
173
- self.chat = _DoctorMockChat(self)
174
-
175
-
176
- # Build a brain with 3 keys: primary dead, second dead, third alive.
177
- brain = _dash.DoctorBrain(
178
- api_key="key1",
179
- fallback_api_keys=["key2", "key3"],
180
- model="llama-3.1-8b-instant",
181
- )
182
-
183
- # Replace each entry's client with our mock.
184
- brain._chain[0]["client"] = _DoctorMockGroq(dead=True) # key1 dead
185
- brain._chain[1]["client"] = _DoctorMockGroq(dead=True) # key2 dead
186
- brain._chain[2]["client"] = _DoctorMockGroq(dead=False) # key3 alive
187
-
188
- reply = brain.decide("Patient is here. Vitals pending.")
189
- check(
190
- "DoctorBrain walks past 2 dead keys and uses the 3rd",
191
- '"action":"read_soap"' in reply or "'action': 'read_soap'" in reply,
192
- f"reply={reply[:120]}",
193
- )
194
- check(
195
- "key1 marked dead",
196
- brain._chain[0]["dead"] is True,
197
- )
198
- check(
199
- "key2 marked dead",
200
- brain._chain[1]["dead"] is True,
201
- )
202
- check(
203
- "key3 still alive",
204
- brain._chain[2]["dead"] is False,
205
- )
206
- check(
207
- "key3 actually answered (call count)",
208
- brain._chain[2]["client"].calls == 1,
209
- f"calls={brain._chain[2]['client'].calls}",
210
- )
211
-
212
- # Second decide() should jump straight to key3 — no retries on the dead ones.
213
- brain._chain[0]["client"].calls = 0
214
- brain._chain[1]["client"].calls = 0
215
- brain._chain[2]["client"].calls = 0
216
- brain.decide("Now consider next step.")
217
- check(
218
- "second decide() skips dead keys (no extra calls on key1/key2)",
219
- brain._chain[0]["client"].calls == 0 and brain._chain[1]["client"].calls == 0,
220
- )
221
- check(
222
- "second decide() served by key3 again",
223
- brain._chain[2]["client"].calls == 1,
224
- )
225
-
226
- # All 3 dead → falls back to _smart_fallback_action (no crash).
227
- brain2 = _dash.DoctorBrain(
228
- api_key="k1",
229
- fallback_api_keys=["k2"],
230
- model="llama-3.1-8b-instant",
231
- )
232
- brain2._chain[0]["client"] = _DoctorMockGroq(dead=True)
233
- brain2._chain[1]["client"] = _DoctorMockGroq(dead=True)
234
- reply3 = brain2.decide("Patient is here.")
235
- check(
236
- "all keys dead -> _smart_fallback_action returns valid JSON",
237
- reply3.startswith("{") and ('"tool"' in reply3 or '"action"' in reply3),
238
- f"reply={reply3[:120]}",
239
- )
240
-
241
- # ---------------------------------------------------------------------------
242
- # 3. TTS emotion adapter auto-disable on 401
243
- # ---------------------------------------------------------------------------
244
- print("\n--- Test 3: TTS emotion adapter shuts down after first 401 ---", flush=True)
245
-
246
- from ER_MAP import tts_engine as _tts # noqa: E402
247
-
248
- # Make sure ElevenLabs is forced off so we don't hit network.
249
- os.environ["ERMAP_DISABLE_ELEVENLABS"] = "1"
250
-
251
- eng = _tts.TTSEngine(elevenlabs_api_key="", groq_api_key="dummy")
252
-
253
- # Replace its Groq client with a mock that always raises 401.
254
- class _AlwaysAuthFail:
255
- def __init__(self):
256
- self.calls = 0
257
- self.chat = self
258
- self.completions = self
259
-
260
- def create(self, **kw):
261
- self.calls += 1
262
- raise Exception("Error code: 401 - {'error': {'code': 'invalid_api_key'}}")
263
-
264
- mock_groq = _AlwaysAuthFail()
265
- eng._groq_client = mock_groq
266
-
267
- # Trigger the adapter: status helper should report auth=True the first time.
268
- text1, auth1 = _tts._emotionalize_with_status(
269
- "Patient please describe your symptoms in detail.",
270
- "patient_anxious_panicked",
271
- eng._groq_client,
272
- eng._groq_model,
273
- )
274
- check(
275
- "first call hits Groq and observes 401",
276
- auth1 is True and mock_groq.calls == 1,
277
- f"auth={auth1} calls={mock_groq.calls}",
278
- )
279
- check(
280
- "first call still returns usable text via regex fallback",
281
- isinstance(text1, str) and len(text1) > 5,
282
- f"text={text1[:80]}",
283
- )
284
-
285
- # Simulate the engine setting its dead flag and verify subsequent passes
286
- # never hit Groq again.
287
- eng._emotion_adapter_dead = True
288
- mock_groq.calls = 0
289
-
290
- # Run the same code path the engine uses internally:
291
- if eng._emotion_adapter_dead:
292
- # Engine bypasses the LLM call entirely → no Groq invocation.
293
- fallback_only = _tts._fallback_emotion_transform(
294
- "Patient please describe your symptoms in detail.",
295
- "patient_anxious_panicked",
296
- )
297
- fallback_calls = mock_groq.calls
298
- else:
299
- fallback_calls = -1
300
- check(
301
- "engine bypasses Groq once emotion adapter marked dead",
302
- fallback_calls == 0,
303
- f"calls after mark-dead={fallback_calls}",
304
- )
305
- check(
306
- "regex fallback still produces speech",
307
- isinstance(fallback_only, str) and len(fallback_only) > 5,
308
- f"text={fallback_only[:80]}",
309
- )
310
-
311
- # ---------------------------------------------------------------------------
312
- # 4. Health probe smoke (returns DEAD_AUTH on a junk key without crashing)
313
- # ---------------------------------------------------------------------------
314
- print("\n--- Test 4: _probe_groq_key handles an invalid key gracefully ---", flush=True)
315
-
316
- status, detail = _dash._probe_groq_key("gsk_definitely_invalid_key", "llama-3.1-8b-instant", timeout_s=4.0)
317
- check(
318
- "probe returns DEAD_AUTH for invalid key",
319
- status == "DEAD_AUTH",
320
- f"status={status} detail={detail}",
321
- )
322
-
323
- status_missing, _ = _dash._probe_groq_key("", "llama-3.1-8b-instant")
324
- check(
325
- "probe returns MISSING for empty key (no network call)",
326
- status_missing == "MISSING",
327
- )
328
-
329
- # ---------------------------------------------------------------------------
330
- print("\n" + "=" * 60, flush=True)
331
- print(f" RESULT: {CHECKS_PASSED} passed, {CHECKS_FAILED} failed", flush=True)
332
- print("=" * 60, flush=True)
333
- sys.exit(0 if CHECKS_FAILED == 0 else 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
kaggle/KAGGLE.md DELETED
@@ -1,265 +0,0 @@
1
- # Training ER-MAP on Kaggle Free Tier
2
-
3
- This guide walks you through training the ER-MAP **Doctor agent** with GRPO + 3-phase curriculum learning on Kaggle's free GPU tier — **zero dollars**, **30 GPU-hours/week**, **single Tesla T4 16 GB**.
4
-
5
- ## TL;DR β€” fastest path to a converged Doctor
6
-
7
- 1. **Fork** this repo on GitHub (it must be reachable from inside the Kaggle kernel).
8
- 2. Get **5 Groq API keys** from https://console.groq.com/keys (one per role gives you 5x the daily quota; you can also use one key for everything if you don't mind sharing the rate-limit budget).
9
- 3. Get one **HF write token** from https://huggingface.co/settings/tokens (fine-grained, scope: `write` to your own repos) β€” needed so checkpoints survive the 12-hour Kaggle session limit.
10
- 4. **New Notebook on Kaggle** β†’ Settings β†’ **Accelerator: GPU T4 x2** β†’ **Internet: On**.
11
- 5. Add the secrets in the right sidebar (Add-ons β†’ Secrets):
12
- - `GROQ_NURSE_API_KEY`, `GROQ_PATIENT_API_KEY`, `GROQ_EMPATHY_JUDGE_API_KEY`, `GROQ_MEDICAL_JUDGE_API_KEY`
13
- - `HF_TOKEN`
14
- - *(optional)* `WANDB_API_KEY`
15
- 6. Open `kaggle/train_ermap_grpo_kaggle.ipynb` from your fork inside Kaggle (File β†’ Import Notebook β†’ URL).
16
- 7. Edit the two URLs in cell 2 (`GIT_URL`) and cell 5 (`HF_PUSH_REPO`) to your fork / username.
17
- 8. **Run All**.
18
-
19
- Training **stops automatically** the instant the Doctor sustains a phase-specific reward bar for **3 consecutive GRPO groups** — `+1.5` in Phase 1 (force-promote), `+1.2` in Phase 2 (force-promote), `+1.0` in Phase 3 (END). This is the "train until optimal rewards are constantly received" guarantee — see the *Train-until-optimal* section below. `NUM_EPISODES=120` is just a hard cap; healthy runs converge after roughly 70-130 episodes (~6-11 h on T4 ×2), so the slowest runs will hit the cap first.
20
-
21
- You'll see one full 6-panel dashboard PNG **per curriculum phase** land in `/kaggle/working/er_map_grpo_checkpoints/plots/` after training finishes (`phase1_dashboard.png`, `phase2_dashboard.png`, `phase3_dashboard.png`, plus `all_phases_overview.png` and `all_phases_comparison.png`), and your final LoRA adapter will be sitting on Hugging Face Hub at `<you>/ermap-doctor-lora`.
22
-
23
- **What each per-phase dashboard shows:**
24
-
25
- | Panel | What it tells you |
26
- | --- | --- |
27
- | Reward growth | raw episode reward + rolling mean + verified rolling mean |
28
- | Rolling win rate (w=20) | did the policy actually get better in this phase? |
29
- | Outcome distribution over time | stacked WIN/PARTIAL/INCORRECT/AMA_LOSS/FATAL_LOSS bars per ~5-episode bin |
30
- | Reward components | mean of every reward component (process / treatment / empathy / labs / etc.) |
31
- | GRPO update stats | per-group loss + KL β€” should *not* explode |
32
- | Episode length | histogram of step counts β€” should rise from Phase 1 to Phase 3 |
33
-
34
- ---
35
-
36
- ## Hardware feasibility
37
-
38
- | Resource | Kaggle Free Tier | What we use | Headroom |
39
- |---|---|---|---|
40
- | GPU | Tesla T4 16 GB | Llama-3.1-8B-4bit + LoRA(r=16) β‰ˆ 7-9 GB | ~50% free |
41
- | RAM | 13 GB system | base model + tokenizer + buffers β‰ˆ 5 GB | OK |
42
- | Disk | 73 GB | repo + checkpoints + cache β‰ˆ 10 GB | OK |
43
- | Session | 12 h max | typical full Phase-1+early-Phase-2 = 6-8 h | OK |
44
- | Weekly | 30 GPU-h | one full curriculum run + a re-run = ~15-20 h | OK |
45
- | Internet | allowed | Groq calls per env step | OK |
46
-
47
- **Why Llama-3.1-8B over Qwen-3-4B (the other train_grpo.py default)?**
48
- - 8B reasons noticeably better on multi-turn clinical dialogue
49
- - 4-bit quant brings it to 5 GB β€” still fits on T4 with LoRA
50
- - Groq hosts the same 8B (8B-instant) so the deployed inference path matches the training distribution exactly
51
-
52
- If you ever need to fall back to a smaller model (e.g. for a P100 session), edit `MODEL_NAME` in cell 5 of the notebook to `unsloth/Qwen2.5-3B-Instruct-bnb-4bit`. Everything else stays the same.
53
-
54
- ---
55
-
56
- ## Two ways to get the source onto Kaggle
57
-
58
- ### Option A β€” public GitHub fork (recommended)
59
-
60
- In cell 2 of the notebook:
61
- ```python
62
- GIT_URL = "https://github.com/YOUR_USERNAME/Meta_Finals.git"
63
- BRANCH = "main"
64
- ```
65
- The cell does a shallow clone into `/kaggle/working/Meta_Finals` and you're done.
66
-
67
- ### Option B β€” upload as a Kaggle Dataset (no GitHub needed)
68
-
69
- 1. Locally:
70
- ```bash
71
- cd D:/Meta_Finals
72
- # Exclude heavy/regenerable folders before zipping.
73
- tar --exclude='.git' --exclude='__pycache__' --exclude='*.ipynb_checkpoints' \
74
- --exclude='er_map_grpo_checkpoints' \
75
- -czf ermap-source.tar.gz .
76
- ```
77
- 2. Kaggle β†’ **Datasets** β†’ **New Dataset** β†’ upload `ermap-source.tar.gz` β†’ name it **`ermap-source`** β†’ save.
78
- 3. In your training notebook β†’ right sidebar β†’ **+ Add Data** β†’ search for **`ermap-source`** β†’ Add.
79
- 4. Cell 2 of the notebook detects `/kaggle/input/ermap-source/` and copies it into `/kaggle/working/Meta_Finals` automatically.
80
-
81
- Use Option B when:
82
- - Your fork is private and you don't want to expose the repo
83
- - You have local edits not yet pushed
84
- - Bandwidth from Kaggle to GitHub is flaky
85
-
86
- ---
87
-
88
- ## What happens when the 12-hour session ends mid-training
89
-
90
- Without intervention you'd lose everything. The notebook prevents this:
91
-
92
- 1. **Periodic HF Hub push.** Cell 7 monkey-patches `save_lora_adapters()` so every checkpoint saved by the GRPO loop also pushes to your `HF_PUSH_REPO`. The training loop checkpoints every `group_size Γ— 5` episodes (so every 10 episodes when `GROUP_SIZE=2`).
93
- 2. **Resume on the next session.** Set `HF_RESUME_REPO` in cell 4 of the notebook on the *new* Kaggle session. The latest LoRA adapter is downloaded to `/kaggle/working/checkpoints/resume/` before training starts β€” but **the current `train_grpo.py` doesn't auto-load this folder yet**; for now use it as a manual recovery (load the adapter and continue training in code). A future PR will wire the auto-resume into `load_model_and_tokenizer`.
94
-
95
- In practice: a single 12-hour session is usually enough to clear Phase 1 and produce publishable per-phase dashboards, so resume is the safety net rather than the main path.
96
-
97
- > **Re-render plots from any saved metrics file** (locally or in another Kaggle session):
98
- > ```bash
99
- > python -m ER_MAP.plotting \
100
- > --metrics er_map_grpo_checkpoints/training_metrics.json \
101
- > --out er_map_grpo_checkpoints/plots
102
- > ```
103
- > This is the same call the notebook makes β€” handy if you want to regenerate the charts after training, or restyle them without re-running training.
104
-
105
- ---
106
-
107
- ## Per-role Groq keys vs. one shared key
108
-
109
- The dashboard ships with 4 distinct Groq clients (Nurse, Patient, Empathy Judge, Medical Judge) and a fallback chain that walks across all four if any fails auth. Per-key budgets are *shared* on Groq's free tier (limits are per-account, not per-key) β€” but the model split below buys you real headroom because **each model has its own daily pool**.
110
-
111
- ### Default model assignment (traffic-shaping)
112
-
113
- | Role | Model | Free-tier pool | Why |
114
- |---|---|---|---|
115
- | Nurse | `llama-3.1-8b-instant` | 14 400 RPD / 500K TPD | high-volume (every env step) |
116
- | Patient | `llama-3.1-8b-instant` | shared 8B pool | high-volume (every env step) |
117
- | Empathy Judge | `llama-3.3-70b-versatile` | 1 000 RPD / 100K TPD | grading quality directly shapes reward |
118
- | Medical Judge | `llama-3.3-70b-versatile` | shared 70B pool | grading quality directly shapes reward |
119
-
120
- Quick budget check for **one full 120-episode training run**:
121
-
122
- | Pool | Estimated calls/run | Daily ceiling | Headroom |
123
- |---|---|---|---|
124
- | 8B-instant (Nurse + Patient) | ~2 880 | 14 400 RPD | ~5x |
125
- | 70B-versatile (judges) | ~720 | 1 000 RPD | ~1.4x |
126
-
127
- You can do **one training run per day per account** comfortably. If you need to retry inside the same day, drop one of the two judges to 8B-instant temporarily β€” the reward signal degrades a little, but training keeps moving.
128
-
129
- If you only have **one** Groq key total, set just `GROQ_API_KEY` as a Kaggle Secret. Everything still works β€” the AgentRouter falls back to the same client for all roles, and the per-model budgets still split traffic across pools.
130
-
131
- ---
132
-
133
- ## What the reward-growth curve should look like
134
-
135
- If training is healthy, after ~80 episodes you should see:
136
-
137
- - **Rolling avg reward** climbs from β‰ˆ -0.4 (random baseline) toward +1.5+ (the early-stop target)
138
- - **Rolling win rate** climbs from ~10% to 40%+
139
- - A **vertical red dashed line** marks the Phase 1 β†’ Phase 2 promotion (typically episode 30-60), and a second one marks Phase 2 β†’ Phase 3 (typically episode 60-90)
140
- - KL divergence stays in `[0.005, 0.05]` β€” if it spikes above 0.5 the model is drifting (lower `LEARNING_RATE` and re-run)
141
-
142
- If the curve is flat or trending down:
143
- - Check that Groq is actually responding (look for `Groq API error` lines in the log)
144
- - Check that `rewards.std()` is non-zero across the group (cell logs print `adv_std=`; if it's < 1e-6 GRPO skips the update)
145
- - Drop `GROUP_SIZE` from 2 β†’ 1? **Don't** β€” group size 1 = no advantage signal = no GRPO update. Keep G β‰₯ 2.
146
-
147
- ---
148
-
149
- ## Train-until-optimal β€” per-phase reward thresholds
150
-
151
- > *"I want training until certain optimal rewards are constantly received."*
152
-
153
- After every GRPO update the loop maintains a **rolling buffer of the last `CONVERGENCE_WINDOW=3` groups**. When all 3 entries are in the *same* current phase AND each has `rolling_avg_reward >= PHASE_REWARD_TARGETS[current_phase]`, the loop reacts:
154
-
155
- | Current phase | When buffer qualifies | Effect |
156
- |---|---|---|
157
- | Phase 1 (Tool Mastery) | sustained `+1.5` for 3 groups | force-promote to Phase 2, clear buffer |
158
- | Phase 2 (Clinical Reasoning) | sustained `+1.2` for 3 groups | force-promote to Phase 3, clear buffer |
159
- | Phase 3 (Empathetic Negotiation) | sustained `+1.0` for 3 groups | **END TRAINING** |
160
-
161
- The buffer is cleared after each promotion so stale entries cannot pre-satisfy the next phase's bar. A soft `PHASE_MIN_WIN_RATE=0.20` floor prevents stopping on partial-credit-only runs.
162
-
163
- If even one group in the window slips below the bar, the counter resets β€” guaranteeing the policy is *constantly* hitting the target, not transiently. `NUM_EPISODES` becomes a hard safety cap, not a fixed budget.
164
-
165
- ### Why the targets descend from `1.5` β†’ `1.2` β†’ `1.0`
166
-
167
- The phases are not equally easy to score on. Looking at the reward function:
168
-
169
- | Phase | Best-case reward (clean win) | Realistic clean-policy mean |
170
- |---|---|---|
171
- | 1 β€” easy patient, clean SOAP | `+2.0` (full terminal_win) | `+1.6 .. +1.8` |
172
- | 2 β€” mixed compliance + noisy SOAP | `+1.7` (terminal_win - some lab noise) | `+1.2 .. +1.4` |
173
- | 3 β€” full persona randomization + consent costs | `+1.4` (terminal_win - empathy/AMA penalties) | `+1.0 .. +1.2` |
174
-
175
- So requiring `+1.5` in Phase 1 demonstrates real tool mastery (not just floor-grazing), while requiring `+1.0` in Phase 3 is genuinely hard β€” no random policy ever sustains it.
176
-
177
- ### Defaults (in the notebook, section 5)
178
-
179
- ```python
180
- EARLY_STOP_ENABLED = True
181
- PHASE_REWARD_TARGETS = {1: 1.5, 2: 1.2, 3: 1.0}
182
- PHASE_MIN_WIN_RATE = 0.20 # soft floor
183
- CONVERGENCE_WINDOW = 3
184
- ```
185
-
186
- ### Reading the per-phase telemetry
187
-
188
- After every GRPO group the log prints:
189
-
190
- ```
191
- [Scheduler] Phase 2 (Clinical Reasoning) | Win Rate: 42.0% | Avg Reward: +1.18 | Phase Episodes: 14
192
- [EarlyStop] Phase 2 target avg-reward >= +1.20: qualified 2/3 recent groups (need all 3 -> promote)
193
- ```
194
-
195
- When the Phase-2 buffer fills with 3 qualifying groups:
196
-
197
- ```
198
- [Scheduler] force_promote() called: sustained rolling-avg-reward +1.20 for 3 consecutive groups in Phase 2
199
- ************************************************************
200
- CURRICULUM PROMOTION: Clinical Reasoning -> Empathetic Negotiation
201
- ************************************************************
202
- ```
203
-
204
- When Phase 3 finally converges:
205
-
206
- ```
207
- ************************************************************
208
- EARLY STOP: Phase 3 convergence reached after 92 episodes
209
- Last 3 groups all sustained:
210
- rolling_avg_reward >= +1.00
211
- rolling_win_rate >= 20%
212
- in Phase 3 (Empathetic Negotiation)
213
- ************************************************************
214
- ```
215
-
216
- …and the loop exits cleanly into the final-save / final-push / plotting cells.
217
-
218
- ### Per-phase wall-clock estimates on Kaggle T4 Γ—2
219
-
220
- | Phase | Typical episodes to hit target | Wall-clock | Why |
221
- |---|---|---|---|
222
- | 1 | 16 – 30 episodes (8 – 15 groups) | **~1.5 – 2.5 h** | Easy patients + clean SOAP; tool format is the only real lift. |
223
- | 2 | 24 – 40 episodes (12 – 20 groups) | **~2.0 – 3.5 h** | Most policy improvement happens here. |
224
- | 3 | 30 – 60 episodes (15 – 30 groups) | **~2.5 – 5.0 h** | Empathy + consent costs make `+1.0` genuinely hard. |
225
- | **Total** | 70 – 130 episodes | **~6 – 11 h** | Fits the 12 h GPU session with ~1 h margin. |
226
-
227
- Per-group wall-clock β‰ˆ 8 – 12 min on T4 (depending on episode length); per-episode β‰ˆ 3 – 5 min for env rollout + β‰ˆ 1 – 2 min amortized for the GRPO update.
228
-
229
- ### Tuning suggestions
230
-
231
- | Goal | What to change |
232
- |---|---|
233
- | Smoke run (converge fast on a weak policy) | `PHASE_REWARD_TARGETS={1: 0.5, 2: 0.4, 3: 0.3}`, `CONVERGENCE_WINDOW=2` |
234
- | Hackathon-grade Doctor | keep defaults |
235
- | Aim for SOTA on this benchmark | `PHASE_REWARD_TARGETS={1: 1.7, 2: 1.5, 3: 1.3}`, `CONVERGENCE_WINDOW=5` |
236
- | Disable entirely (run full 120 episodes regardless) | `EARLY_STOP_ENABLED=False` |
237
- | Resuming from a partial run | targets are unchanged β€” the buffer is rebuilt from the new session's groups, so nothing weird happens |
238
-
239
- ### What NOT to do
240
-
241
- - Don't set `CONVERGENCE_WINDOW=1`. A single lucky group can pass any bar; you'll promote out of Phase 1 the instant the first easy patient is correctly discharged.
242
- - Don't lower the Phase-1 target below `+0.8`. The built-in scheduler already promotes Phase 1 β†’ Phase 2 at `win_rate >= 40% AND avg_reward >= +0.3`, so a Phase-1 reward bar below that is dead code.
243
- - Don't raise the Phase-3 target above `+1.5`. The reward ceiling on a Phase-3 episode (after empathy/consent costs) is around `+1.4 .. +1.6`; sustaining `+1.5+` for 3 consecutive groups is essentially unachievable in 12 h.
244
-
245
- ---
246
-
247
- ## Common Kaggle gotchas
248
-
249
- | Symptom | Fix |
250
- |---|---|
251
- | `Groq API error: 401 invalid_api_key` | Regenerate the key (Groq auto-revokes keys posted publicly). Update the Kaggle Secret. |
252
- | `OutOfMemoryError` on T4 | Drop `MAX_SEQ_LENGTH` from 2048 to 1536 inside `load_model_and_tokenizer`, or switch to `unsloth/Qwen2.5-3B-Instruct-bnb-4bit`. |
253
- | `unsloth import` failed | Restart kernel after `pip install` β€” Unsloth pins `xformers` versions and the running kernel keeps the old import cached. |
254
- | Checkpoints not appearing on HF Hub | Verify `HF_PUSH_REPO` doesn't still contain the `<your-username>` placeholder, and that `HF_TOKEN` has `write` scope. |
255
- | "Internet off" warning | Right sidebar β†’ Settings β†’ toggle Internet to **On**. (Default is off for new accounts.) |
256
-
257
- ---
258
-
259
- ## Cost summary
260
-
261
- - **Kaggle**: free
262
- - **Groq API (training)**: free (within free-tier daily quotas, ~2 000 calls per full run)
263
- - **Hugging Face Hub**: free for the LoRA adapter (~50 MB) + free for the merged fp16 (~16 GB on a public repo, free up to 1 TB total)
264
- - **Wandb**: free for personal projects
265
- - **Total**: $0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
kaggle/KAGGLE_QUICKSTART.md ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Kaggle Quickstart — ER-MAP GRPO Training (v3 stable)
2
+
3
+ The Kaggle notebook is in `kaggle/train_ermap_grpo_kaggle.ipynb`. This file
4
+ is the cheat sheet for running it end-to-end without the dependency hell
5
+ that bit us in earlier attempts.
6
+
7
+ ## 0. Prerequisites (one-time)
8
+
9
+ 1. **GitHub fork** of this repo. The notebook clones from a public fork at
10
+ cell 6 — edit `GIT_URL`. Alternatively, upload the repo as a Kaggle
11
+ Dataset named `ermap-source` (Add Data → Upload).
12
+ 2. **Hugging Face write token** (`HF_TOKEN`) for pushing the trained
13
+ adapter. Create at https://huggingface.co/settings/tokens (fine-grained,
14
+ write access on a single model repo is enough).
15
+ 3. **Five Groq keys** (one each for Nurse / Patient / Empathy Judge /
16
+ Medical Judge / shared fallback). Free-tier accounts are fine; Groq's
17
+ limits are per account, so quota only multiplies across keys that belong to different accounts.
18
+
19
+ ## 1. Create the Kaggle notebook
20
+
21
+ 1. Sign in to https://www.kaggle.com/code β†’ **New Notebook**.
22
+ 2. Right sidebar:
23
+ - Accelerator: **GPU T4 ×2** (or P100)
24
+ - Internet: **On**
25
+ - Persistence: Files only
26
+ 3. **File β†’ Upload Notebook** β†’ choose `kaggle/train_ermap_grpo_kaggle.ipynb`
27
+ from this repo.
28
+
29
+ ## 2. Add Kaggle Secrets
30
+
31
+ Add-ons β†’ Secrets β†’ Add a new secret. Required labels (exactly):
32
+
33
+ | Label | Value |
34
+ |---|---|
35
+ | `GROQ_NURSE_API_KEY` | your nurse Groq key |
36
+ | `GROQ_PATIENT_API_KEY` | your patient Groq key |
37
+ | `GROQ_EMPATHY_JUDGE_API_KEY` | your empathy-judge Groq key |
38
+ | `GROQ_MEDICAL_JUDGE_API_KEY` | your medical-judge Groq key |
39
+ | `HF_TOKEN` | your HF write token |
40
+ | `WANDB_API_KEY` *(optional)* | your W&B key (skip β€” disabled by default) |
41
+
42
+ The notebook reads them via `kaggle_helpers.load_kaggle_secrets()` and
43
+ exports them as env vars.
44
+
45
+ ## 3. Edit two placeholders in the notebook
46
+
47
+ - **Cell 6:** `GIT_URL = "https://github.com/<your-fork>/Meta_Finals.git"`
48
+ - **Cell 8:** `HF_PUSH_REPO = "<your-username>/ermap-doctor-lora"`
49
+
50
+ If you uploaded the repo as a Kaggle Dataset instead, leave `GIT_URL` as the
51
+ placeholder β€” cell 6 will detect `/kaggle/input/ermap-source` and copy from
52
+ there.
53
+
54
+ ## 4. Run order (the only sequence that works)
55
+
56
+ | Cell | What it does | Expected output |
57
+ |---|---|---|
58
+ | 2 | GPU + disk + python + internet sanity check | GPU listed, disk free > 8 GB |
59
+ | 3 | **REPAIR** — pin torch 2.10 cu128, reinstall bitsandbytes, upgrade unsloth | `REPAIR OK` (or `RESTART REQUIRED`) |
60
+ | **(restart)** | If cell 3 said RESTART REQUIRED → Run → Restart kernel | — |
61
+ | 5 | Post-restart import verify | All `OK`, GPUs listed |
62
+ | 6 | Clone / mount the repo | `OK. Repo at /kaggle/working/Meta_Finals` |
63
+ | 7 | Wire Kaggle Secrets → env vars | `OK — at least one Groq key is wired` |
64
+ | 8 | HF Hub config | `Starting fresh — no resume.` |
65
+ | 9 | Hyperparameters (P1=+1.2, P2=+1.1, P3=+1.0) | thresholds printed |
66
+ | 10 | **Pre-flight** — Groq routing + 4× PING | 4× `[PASS]`, then `OK` |
67
+ | 11 | Dry-run smoke test (no GPU) | `Dry-run OK` |
68
+ | 12 | Wire HF push hook | `Hub-push hook installed.` |
69
+ | 13 | **REAL TRAINING** (4–6 h) | per-group rolling stats, eventual `EARLY STOP` |
70
+ | 14 | Final push to HF | `Final checkpoints pushed: https://huggingface.co/...` |
71
+ | 15 | Per-phase plots | 5 PNGs displayed inline |
72
+ | 16 | Push plots to HF | `Plots pushed: ...` |
73
+ | 17 | Inference smoke-test (optional) | 3 sample Doctor actions printed |
74
+
75
+ ## 5. Common failures & fixes
76
+
77
+ | Symptom | Root cause | Fix |
78
+ |---|---|---|
79
+ | `numpy was upgraded mid-session` | numpy import poisoned by a previous cell | Restart kernel, re-run from cell 3 |
80
+ | `Pillow incompatible with torchvision` | Pillow ABI mismatch | Restart kernel, re-run from cell 3 |
81
+ | `PyTorch and torchvision compiled with different CUDA major` | torch upgraded to cu13 by a transient resolve | Re-run cell 3 (it pins cu128) and restart |
82
+ | `cannot import name 'create_gradient_checkpointing_buffer'` | unsloth ↔ unsloth_zoo version drift | Re-run cell 3 (upgrades both in lockstep) |
83
+ | `libnvJitLink.so.13 missing` | bitsandbytes built against different CUDA | Re-run cell 3 (force-reinstalls bitsandbytes after torch pin) |
84
+ | Disk usage > quota | Kaggle's 20 GB working partition fills up | Re-run cell 3 (its first step cleans `/tmp` and the pip cache) |
85
+ | Pre-flight `[FAIL]` for a role | Groq key dead / quota exceeded | Generate a new key in console.groq.com → update the Kaggle Secret → re-run cells 7 and 10 |
86
+ | `[FAIL]` says `routing=WRONG` | env var not set when `AgentRouter()` was constructed | Re-run cell 9 BEFORE cell 10 |
87
+ | Training freezes at episode 1 for >10 min | Doctor.generate hung; Unsloth import broke silently | Check cell 5 output for `unsloth` line; restart kernel and re-run cell 3 if missing |
88
+
89
+ ## 6. What the trained model gives you
90
+
91
+ After cell 13 finishes (or hits the 12 h Kaggle session cap), you have:
92
+
93
+ - `OUTPUT_DIR/final_lora/` — LoRA adapter weights (~50 MB), pushed to
94
+ `HF_PUSH_REPO`
95
+ - `OUTPUT_DIR/final_merged_fp16/` — full Llama-3.1-8B fp16 merge with the
96
+ adapter applied (~16 GB), pushed to `HF_PUSH_REPO-merged`
97
+ - `OUTPUT_DIR/training_metrics.json` — per-episode rewards, outcomes, and
98
+ rolling stats; this is the input for the per-phase plots
99
+ - `OUTPUT_DIR/plots/*.png` — 5 PNGs (three per-phase dashboards, plus a
100
+ cross-phase overview and a comparison bar chart)
101
+
102
+ Use the LoRA adapter for the demo (quick to load, runs on an RTX 4050 with 6 GB of VRAM at
103
+ ~30 tok/s); use the merged fp16 if you need to host on a Vercel/HF Space
104
+ without `peft`.
kaggle/build_notebook.py ADDED
@@ -0,0 +1,880 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ kaggle/build_notebook.py
3
+ ========================
4
+ Programmatically (re)builds `train_ermap_grpo_kaggle.ipynb` from scratch.
5
+
6
+ Why a builder script?
7
+ --------------------
8
+ The hand-edited notebook drifted into a fragile state across many sessions:
9
+ mixed early-stop / fixed-budget params, stale install snippets, dead pre-flight
10
+ checks, etc. This script is the single source of truth β€” run it once and the
11
+ notebook is regenerated as a clean, deterministic v3 layout.
12
+
13
+ Run:
14
+ python kaggle/build_notebook.py
15
+
16
+ Output:
17
+ kaggle/train_ermap_grpo_kaggle.ipynb (overwritten)
18
+ kaggle/KAGGLE_QUICKSTART.md (overwritten)
19
+ """
20
+ from __future__ import annotations
21
+
22
+ import json
23
+ import textwrap
24
+ from pathlib import Path
25
+
26
+
27
+ # ---------------------------------------------------------------------------
28
+ # Cell helpers
29
+ # ---------------------------------------------------------------------------
30
+
31
def md_cell(text: str) -> dict:
    """Return a notebook markdown-cell dict whose ``source`` is built from *text*.

    *text* is passed through ``_split_keep_newlines`` to produce the line
    list stored under ``"source"``; ``"metadata"`` is an empty dict.
    """
    cell = {"cell_type": "markdown", "metadata": {}}
    # Added last so the key order stays cell_type, metadata, source.
    cell["source"] = _split_keep_newlines(text)
    return cell
37
+
38
+
39
def code_cell(text: str) -> dict:
    """Return a not-yet-executed notebook code-cell dict built from *text*.

    The cell has no execution count and no outputs; ``"source"`` is the line
    list produced by ``_split_keep_newlines`` from *text*.
    """
    cell = dict(
        cell_type="code",
        execution_count=None,
        metadata={},
        outputs=[],
    )
    # Added last so the key order matches the other fields' declaration order.
    cell["source"] = _split_keep_newlines(text)
    return cell
47
+
48
+
49
+ def _split_keep_newlines(text: str) -> list[str]:
50
+ """Notebook 'source' fields expect each line to terminate with '\n'
51
+ except the last one. Splitting like this keeps `git diff` clean when
52
+ the notebook is regenerated."""
53
+ text = textwrap.dedent(text).lstrip("\n")
54
+ if not text.endswith("\n"):
55
+ text = text + "\n"
56
+ lines = text.splitlines(keepends=True)
57
+ if lines:
58
+ # The last line should NOT have a trailing newline (Jupyter convention).
59
+ if lines[-1].endswith("\n"):
60
+ lines[-1] = lines[-1].rstrip("\n")
61
+ return lines
62
+
63
+
64
+ # ---------------------------------------------------------------------------
65
+ # Cell sources
66
+ # ---------------------------------------------------------------------------
67
+
68
# Markdown source of notebook cell 1: title, run order, and the per-phase
# reward thresholds. The run order below must list every code cell that has to
# be executed before training, including cell 12 (HF push hook).
CELL_01_TITLE = """\
# ER-MAP — Doctor Agent GRPO Training (Kaggle Free-Tier · v3 stable)

Trains the **Doctor LLM** (Llama-3.1-8B-Instruct, 4-bit + LoRA r=16) via GRPO
with a 3-phase curriculum on Kaggle's free GPU. Designed to survive Kaggle's
pre-baked image quirks (numpy / Pillow ABI mismatches, torch + torchvision
CUDA-major mismatches, transient `unsloth_zoo` upgrades).

## TL;DR — How to run this notebook

1. **Notebook settings (right sidebar):**
   - Accelerator: **GPU T4 ×2** (or P100)
   - Internet: **On**
   - Persistence: Files only
2. **Kaggle Secrets** (Add-ons → Secrets):
   - **Required:** `GROQ_NURSE_API_KEY`, `GROQ_PATIENT_API_KEY`,
     `GROQ_EMPATHY_JUDGE_API_KEY`, `GROQ_MEDICAL_JUDGE_API_KEY`, `HF_TOKEN`
   - **Optional:** `WANDB_API_KEY`
3. **Run cells 2 → 3 (sanity + REPAIR).** When cell 3 prints
   `RESTART REQUIRED`, click **Run → Restart kernel**, then resume from cell 5.
4. **Run cells 5 → 12 (verify + configure + pre-flight + dry-run + HF push hook).**
   Each cell should finish without errors (most print an `OK` line) before
   moving on.
5. **Run cell 13 (the long training cell, 4–6 hours).**
6. **Run cells 14 → 17 (final push + plots + inference smoke-test).**

## Curriculum + reward thresholds (this run)

Constant per-phase rolling-avg-reward bars; sustained for **3 consecutive
GRPO groups** triggers either a phase promotion or end-of-training.

| Phase | Reward target (sustained ×3 groups) | Action when met |
|---|---|---|
| 1 — Tool Mastery | `+1.2` | force-promote to Phase 2 |
| 2 — Clinical Reasoning | `+1.1` | force-promote to Phase 3 |
| 3 — Empathetic Negotiation | `+1.0` | END TRAINING |

Why these numbers? The un-trained 8B Doctor's baseline on the same env is
`P1=+0.76, P2=+0.59, P3=+0.39`. Targets of `+1.2 / +1.1 / +1.0` correspond
to roughly `1.6× / 1.9× / 2.6×` improvement over baseline — a meaningful
signal but reachable inside Kaggle's 12 h session limit.
"""
109
+
110
# Python source of notebook cell 2: pre-install sanity checks (GPU, disk,
# interpreter, network). Every probe is guarded so that one failing check is
# reported and the remaining checks still run.
CELL_02_SANITY = """\
# === CELL 2 — Sanity check (GPU + disk + python + internet) ===
# Run this FIRST. If any check fails, fix it before running the REPAIR cell.

import os, shutil, subprocess, sys, socket

print("--- GPU ---")
try:
    print(subprocess.check_output(
        ["nvidia-smi", "--query-gpu=name,memory.total,memory.free", "--format=csv"],
        timeout=10,
    ).decode())
except Exception as e:
    print(f"nvidia-smi failed: {e}")
    print("-> Set Accelerator to 'GPU T4 x2' in the right sidebar.")

print("--- Disk (/kaggle/working) ---")
try:
    total, used, free = shutil.disk_usage("/kaggle/working")
except OSError as e:
    # Report instead of aborting, so the Python and internet checks still run.
    print(f" disk check failed: {e}")
    print(" -> /kaggle/working is not available; is this a Kaggle kernel?")
else:
    print(f" total={total/1e9:5.1f} GB | used={used/1e9:5.1f} GB | free={free/1e9:5.1f} GB")
    if free < 8 * 1e9:
        print(" WARNING: free disk < 8 GB — repair cell may fail. "
              "Consider 'Run > Restart and clear cell outputs' to reset /tmp.")

print("--- Python ---")
print(f" python={sys.version.split()[0]} | exe={sys.executable}")

print("--- Internet (api.groq.com:443) ---")
try:
    socket.create_connection(("api.groq.com", 443), timeout=5).close()
    print(" reachable")
except Exception as e:
    print(f" UNREACHABLE: {e}")
    print(" -> Settings (right sidebar) -> Internet -> ON")
"""
144
+
145
+ CELL_03_REPAIR = """\
146
+ # === CELL 3 β€” REPAIR CELL (idempotent full environment rebuild) ===
147
+ # Single source of truth for ER-MAP's GPU stack. Safe to re-run. After it
148
+ # finishes you'll see one of two final lines:
149
+ #
150
+ # RESTART REQUIRED -> Run -> Restart kernel, then resume from cell 5
151
+ # REPAIR OK -> proceed directly to cell 5
152
+ #
153
+ # Note: this cell only runs shell commands and one isolated subprocess.
154
+ # It deliberately does NOT `import torch / numpy / Pillow / unsloth` in the
155
+ # kernel, so re-running it after a botched install does not poison further
156
+ # attempts.
157
+
158
+ print("=" * 72); print(" CELL 3 β€” REPAIR"); print("=" * 72)
159
+
160
+ # 1. Clean caches (Kaggle's /kaggle/working is only 20 GB β€” installs
161
+ # routinely fill it after a few re-runs).
162
+ print("[1/6] Cleaning pip + tmp + HF dataset caches...")
163
+ get_ipython().system('pip cache purge -q || true')
164
+ get_ipython().system('rm -rf /tmp/* /root/.cache/pip /root/.cache/huggingface/datasets 2>/dev/null || true')
165
+
166
+ # 2. Pin torch + torchvision to the cu128 wheel (matches Kaggle's CUDA 12.8
167
+ # base image). DON'T let pip pull a generic CUDA-13 build β€” that breaks
168
+ # bitsandbytes (libnvJitLink.so.13 missing) and torchvision (CUDA-major
169
+ # mismatch RuntimeError at import time).
170
+ print("[2/6] Installing torch==2.10.0 + torchvision==0.25.0 (cu128)...")
171
+ get_ipython().system('pip install -q --no-cache-dir --force-reinstall '
172
+ 'torch==2.10.0 torchvision==0.25.0 '
173
+ '--index-url https://download.pytorch.org/whl/cu128')
174
+
175
+ # 3. Reinstall bitsandbytes against the now-pinned torch.
176
+ print("[3/6] Reinstalling bitsandbytes...")
177
+ get_ipython().system('pip install -q --no-cache-dir --force-reinstall bitsandbytes')
178
+
179
+ # 4. Upgrade unsloth + unsloth_zoo + trl in lockstep. unsloth and
180
+ # unsloth_zoo are released as a matched pair; if pip pulls a fresh
181
+ # unsloth_zoo against an old unsloth you get
182
+ # ImportError: cannot import name 'create_gradient_checkpointing_buffer'
183
+ print("[4/6] Upgrading unsloth + unsloth_zoo + trl...")
184
+ get_ipython().system('pip install -q --upgrade --no-cache-dir '
185
+ 'unsloth unsloth_zoo "trl>=0.18.2"')
186
+
187
+ # 5. ER-MAP runtime deps that aren't pre-installed on Kaggle.
188
+ print("[5/6] Installing ER-MAP runtime deps...")
189
+ get_ipython().system('pip install -q --no-cache-dir '
190
+ '"groq>=0.18.0" "huggingface_hub>=0.25.0" '
191
+ '"gymnasium>=0.29.0" "openenv-core>=0.1.0"')
192
+
193
+ # 6. Verify in a SUBPROCESS (so the parent kernel never imports any of these
194
+ # while pip is mid-flight, which is what causes the
195
+ # 'numpy was upgraded mid-session (loaded: X, installed: Y)' RuntimeError
196
+ # we kept hitting before).
197
+ print("[6/6] Verifying via subprocess...")
198
+ import subprocess, sys, json
199
+
200
+ verify_script = r'''
201
+ import json, sys
202
+ out = {"ok": True, "details": {}, "errors": []}
203
+ try:
204
+ import importlib.metadata as md
205
+ for pkg in ("torch", "torchvision", "bitsandbytes", "unsloth", "unsloth_zoo",
206
+ "trl", "transformers", "peft", "accelerate", "groq",
207
+ "huggingface_hub", "gymnasium", "numpy", "Pillow"):
208
+ try:
209
+ out["details"][pkg + "_installed"] = md.version(pkg)
210
+ except md.PackageNotFoundError:
211
+ out["details"][pkg + "_installed"] = None
212
+
213
+ import torch, torchvision, numpy as np, PIL, unsloth, unsloth_zoo, bitsandbytes, trl
214
+ out["details"]["torch_loaded"] = torch.__version__
215
+ out["details"]["torch_cuda"] = torch.version.cuda
216
+ out["details"]["cuda_available"] = bool(torch.cuda.is_available())
217
+ out["details"]["gpu_count"] = int(torch.cuda.device_count())
218
+ out["details"]["torchvision_loaded"] = torchvision.__version__
219
+ out["details"]["numpy_loaded"] = np.__version__
220
+ out["details"]["pillow_loaded"] = PIL.__version__
221
+ out["details"]["unsloth_loaded"] = unsloth.__version__
222
+ out["details"]["unsloth_zoo_loaded"] = unsloth_zoo.__version__
223
+ out["details"]["bitsandbytes_loaded"] = bitsandbytes.__version__
224
+ out["details"]["trl_loaded"] = trl.__version__
225
+
226
+ # Cross-check loaded-vs-installed for the C-extension libs that bit us
227
+ # on every previous run.
228
+ for pkg, loaded_key, installed_key in [
229
+ ("numpy", "numpy_loaded", "numpy_installed"),
230
+ ("Pillow", "pillow_loaded", "Pillow_installed"),
231
+ ("torch", "torch_loaded", "torch_installed"),
232
+ ]:
233
+ loaded = out["details"].get(loaded_key)
234
+ installed = out["details"].get(installed_key)
235
+ if loaded and installed and loaded != installed:
236
+ # Strip any local-version suffix (e.g. '+cu128') before compare.
237
+ if loaded.split("+")[0] != installed.split("+")[0]:
238
+ out["errors"].append(
239
+ f"{pkg} mismatch: loaded={loaded} installed={installed}"
240
+ )
241
+ except Exception as e:
242
+ out["ok"] = False
243
+ out["errors"].append(f"{type(e).__name__}: {e}")
244
+ print(json.dumps(out, default=str))
245
+ '''.lstrip()
246
+
247
+ res = subprocess.run([sys.executable, "-c", verify_script],
248
+ capture_output=True, text=True, timeout=180)
249
+ print(res.stdout if res.stdout else "<no stdout>")
250
+ if res.stderr:
251
+ print("---- subprocess stderr ----"); print(res.stderr)
252
+
253
+ # Parse the LAST line of stdout (others are prints from package init).
254
+ try:
255
+ last = res.stdout.strip().splitlines()[-1]
256
+ parsed = json.loads(last)
257
+ except Exception:
258
+ parsed = {"ok": False, "errors": ["could not parse verification output"]}
259
+
260
+ ok = parsed.get("ok") and not parsed.get("errors")
261
+ d = parsed.get("details", {})
262
+
263
+ print("\n" + "=" * 72)
264
+ if ok:
265
+ print(" REPAIR OK")
266
+ print(f" torch : {d.get('torch_loaded')} (CUDA {d.get('torch_cuda')})")
267
+ print(f" torchvision : {d.get('torchvision_loaded')}")
268
+ print(f" bitsandbytes: {d.get('bitsandbytes_loaded')}")
269
+ print(f" unsloth : {d.get('unsloth_loaded')} | unsloth_zoo: {d.get('unsloth_zoo_loaded')}")
270
+ print(f" trl : {d.get('trl_loaded')}")
271
+ print(f" numpy : {d.get('numpy_loaded')} | Pillow: {d.get('pillow_loaded')}")
272
+ print(f" GPUs : {d.get('gpu_count')} (cuda_available={d.get('cuda_available')})")
273
+ print()
274
+ print(" -> If this kernel previously imported torch/numpy/Pillow/unsloth,")
275
+ print(" RESTART NOW (Run -> Restart kernel) before continuing to cell 5.")
276
+ print(" If this is a fresh kernel, you can proceed directly.")
277
+ else:
278
+ print(" RESTART REQUIRED β€” issues detected:")
279
+ for e in parsed.get("errors", []):
280
+ print(f" - {e}")
281
+ print()
282
+ print(" Action: Run -> Restart kernel, then re-run from cell 2.")
283
+ print("=" * 72)
284
+ """
285
+
286
+ CELL_04_RESTART = """\
287
+ ## ⚠ Restart kernel here if cell 3 said `RESTART REQUIRED`
288
+
289
+ Click **Run → Restart kernel** (or **Run → Restart & clear cell outputs**),
290
+ then resume from **cell 5**. Skipping the restart will produce ABI mismatch
291
+ errors at the first GPU op.
292
+
293
+ If cell 3 said `REPAIR OK` AND this is a fresh kernel that hasn't imported
294
+ torch/numpy/Pillow/unsloth yet, you can proceed to cell 5 directly.
295
+ """
296
+
297
+ CELL_05_VERIFY = """\
298
+ # === CELL 5 β€” Post-restart verify (this kernel can import everything) ===
299
+ import importlib.metadata as md
300
+
301
+ print("--- Loaded versions in this kernel ---")
302
+ import torch, numpy, PIL, torchvision, unsloth, unsloth_zoo, bitsandbytes, trl, transformers, peft
303
+
304
+ versions = {
305
+ "torch": torch.__version__,
306
+ "torchvision": torchvision.__version__,
307
+ "numpy": numpy.__version__,
308
+ "Pillow": PIL.__version__,
309
+ "unsloth": unsloth.__version__,
310
+ "unsloth_zoo": unsloth_zoo.__version__,
311
+ "bitsandbytes": bitsandbytes.__version__,
312
+ "trl": trl.__version__,
313
+ "transformers": transformers.__version__,
314
+ "peft": peft.__version__,
315
+ }
316
+ all_ok = True
317
+ for k, v in versions.items():
318
+ try:
319
+ inst = md.version(k)
320
+ except md.PackageNotFoundError:
321
+ inst = "(not installed)"
322
+ # Tolerate local version suffixes like '+cu128'
323
+ flag = "OK" if inst.split("+")[0] == v.split("+")[0] else f"MISMATCH (installed={inst})"
324
+ if "MISMATCH" in flag:
325
+ all_ok = False
326
+ print(f" {k:14s}: loaded={v:20s} [{flag}]")
327
+
328
+ print()
329
+ print(f" CUDA available : {torch.cuda.is_available()}")
330
+ print(f" GPU count : {torch.cuda.device_count()}")
331
+ if torch.cuda.is_available():
332
+ for i in range(torch.cuda.device_count()):
333
+ p = torch.cuda.get_device_properties(i)
334
+ print(f" GPU {i} : {p.name} ({p.total_memory/1e9:.1f} GB)")
335
+
336
+ print()
337
+ print("OK" if all_ok else "NOT OK β€” re-run cell 3 and restart kernel.")
338
+ """
339
+
340
+ CELL_06_REPO = """\
341
+ # === CELL 6 β€” Mount the ER-MAP repo into /kaggle/working ===
342
+ import os, subprocess, sys
343
+
344
+ # OPTION A: clone a public GitHub fork (preferred). Edit GIT_URL.
345
+ GIT_URL = "https://github.com/<your-fork>/Meta_Finals.git"
346
+ BRANCH = "main"
347
+ REPO_ROOT = "/kaggle/working/Meta_Finals"
348
+
349
+ # OPTION B: Kaggle Dataset upload β€” set this if you uploaded the repo
350
+ # as a Kaggle Dataset named "ermap-source" (Add Data -> Upload).
351
+ DATASET_DIR = "/kaggle/input/ermap-source"
352
+
353
+ if not os.path.isdir(f"{REPO_ROOT}/ER_MAP"):
354
+ if "<your-fork>" not in GIT_URL:
355
+ print(f"Cloning {GIT_URL}@{BRANCH} -> {REPO_ROOT}...")
356
+ out = subprocess.run(
357
+ ["git", "clone", "--depth", "1", "-b", BRANCH, GIT_URL, REPO_ROOT],
358
+ capture_output=True, text=True,
359
+ )
360
+ print(out.stdout); print(out.stderr)
361
+ elif os.path.isdir(DATASET_DIR):
362
+ print(f"Copying {DATASET_DIR} -> {REPO_ROOT}...")
363
+ import shutil
364
+ shutil.copytree(DATASET_DIR, REPO_ROOT, dirs_exist_ok=True)
365
+
366
+ assert os.path.isdir(f"{REPO_ROOT}/ER_MAP"), (
367
+ "Repo not found.\\n"
368
+ " - Edit GIT_URL above to your GitHub fork, OR\\n"
369
+ " - Upload the repo as a Kaggle Dataset named 'ermap-source' (Add Data -> Upload)."
370
+ )
371
+
372
+ sys.path.insert(0, REPO_ROOT)
373
+ sys.path.insert(0, f"{REPO_ROOT}/kaggle")
374
+ print(f"OK. Repo at {REPO_ROOT}")
375
+ """
376
+
377
+ CELL_07_SECRETS = """\
378
+ # === CELL 7 β€” Wire Kaggle Secrets into env vars ===
379
+ import os
380
+ from kaggle_helpers import load_kaggle_secrets, kaggle_env_summary
381
+
382
+ load_kaggle_secrets()
383
+ kaggle_env_summary()
384
+
385
+ # Hard fail if no Groq key β€” training would silently use mock LLMs.
386
+ assert any(os.environ.get(k) for k in (
387
+ "GROQ_NURSE_API_KEY", "GROQ_PATIENT_API_KEY",
388
+ "GROQ_EMPATHY_JUDGE_API_KEY", "GROQ_MEDICAL_JUDGE_API_KEY",
389
+ "GROQ_API_KEY",
390
+ )), ("No Groq key found in Kaggle Secrets. "
391
+ "Add at least GROQ_NURSE_API_KEY in Add-ons -> Secrets.")
392
+ print("OK β€” at least one Groq key is wired.")
393
+ """
394
+
395
+ CELL_08_HF = """\
396
+ # === CELL 8 β€” Hugging Face Hub config (for checkpoint backup) ===
397
+ import os
398
+ from kaggle_helpers import push_checkpoint_to_hub, download_checkpoint_from_hub
399
+
400
+ # EDIT the line below to your HF model id (e.g. "udayd/ermap-doctor-lora").
401
+ HF_PUSH_REPO = "<your-username>/ermap-doctor-lora"
402
+ # To resume from a previous run, paste the same repo id here. Empty = fresh.
403
+ HF_RESUME_REPO = ""
404
+
405
+ RESUME_DIR = "/kaggle/working/checkpoints/resume"
406
+ if HF_RESUME_REPO:
407
+ download_checkpoint_from_hub(HF_RESUME_REPO, RESUME_DIR)
408
+ contents = os.listdir(RESUME_DIR) if os.path.isdir(RESUME_DIR) else []
409
+ print(f"Resume dir: {contents or '(empty)'}")
410
+ else:
411
+ print("Starting fresh β€” no resume.")
412
+
413
+ if "<your-username>" in HF_PUSH_REPO:
414
+ print("\\nWARNING: HF_PUSH_REPO still has <your-username> placeholder.")
415
+ print(" Checkpoints will NOT be pushed to HF Hub.")
416
+ print(" Edit the cell above and re-run before training if you want backups.")
417
+ """
418
+
419
+ CELL_09_HYPERPARAMS = """\
420
+ # === CELL 9 β€” GRPO hyperparameters ===
421
+ import os
422
+
423
+ MODEL_NAME = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
424
+ GROUP_SIZE = 2
425
+ LEARNING_RATE = 5e-6
426
+ KL_BETA = 0.04
427
+ OUTPUT_DIR = "/kaggle/working/er_map_grpo_checkpoints"
428
+ PUSH_EVERY_EPS = 20
429
+ USE_WANDB = False # WANDB conflicts with protobuf 7 on Kaggle base image
430
+ NUM_EPISODES = 200 # hard cap; early-stop usually finishes first
431
+
432
+ # --- Per-phase reward thresholds (constant for this run) -------------------
433
+ # After every GRPO update we look at the last CONVERGENCE_WINDOW groups; if
434
+ # ALL of them belong to the same current phase AND each has
435
+ # rolling_avg_reward >= PHASE_REWARD_TARGETS[current_phase] AND
436
+ # rolling_win_rate >= PHASE_MIN_WIN_RATE, we either:
437
+ # - force-promote to the next phase (Phase 1 / Phase 2), OR
438
+ # - terminate training (Phase 3).
439
+ EARLY_STOP_ENABLED = True
440
+ PHASE_REWARD_TARGETS = {1: 1.2, 2: 1.1, 3: 1.0}
441
+ PHASE_MIN_WIN_RATE = 0.20
442
+ CONVERGENCE_WINDOW = 3
443
+
444
+ # --- Per-episode budget controls (read by triage_env) ----------------------
445
+ os.environ["ERMAP_MAX_EPISODE_STEPS"] = "20"
446
+ os.environ["ERMAP_MAX_INTERNAL_EXCHANGES"] = "5"
447
+
448
+ # --- Groq traffic-shaping (8B for actors, 70B for judges) ------------------
449
+ # High-volume conversational roles (Nurse + Patient) on the 8B-instant pool
450
+ # (500K TPD, 14,400 RPD); the two judges stay on 70B-versatile because their
451
+ # grading quality directly shapes the reward signal.
452
+ os.environ["ERMAP_NURSE_MODEL"] = "llama-3.1-8b-instant"
453
+ os.environ["ERMAP_PATIENT_MODEL"] = "llama-3.1-8b-instant"
454
+ os.environ["ERMAP_EMPATHY_JUDGE_MODEL"] = "llama-3.3-70b-versatile"
455
+ os.environ["ERMAP_MEDICAL_JUDGE_MODEL"] = "llama-3.3-70b-versatile"
456
+
457
+ print("Hyperparameters set:")
458
+ print(f" NUM_EPISODES = {NUM_EPISODES}")
459
+ print(f" GROUP_SIZE = {GROUP_SIZE}")
460
+ print(f" PHASE_REWARD_TARGETS = {PHASE_REWARD_TARGETS}")
461
+ print(f" PHASE_MIN_WIN_RATE = {PHASE_MIN_WIN_RATE}")
462
+ print(f" CONVERGENCE_WINDOW = {CONVERGENCE_WINDOW}")
463
+ print(f" Nurse / Patient = llama-3.1-8b-instant (actors, high-volume)")
464
+ print(f" Empathy / Med Judge = llama-3.3-70b-versatile (graders, quality)")
465
+ """
466
+
467
+ CELL_10_PREFLIGHT = """\
468
+ # === CELL 10 β€” Pre-flight: Groq routing + key liveness ===
469
+ # Verifies that:
470
+ # - each role is routed to the model you set in cell 9, and
471
+ # - each role's Groq key actually answers a 1-token "PING" prompt.
472
+
473
+ import os
474
+ from ER_MAP.envs.api_router import AgentRouter
475
+
476
+ router = AgentRouter()
477
+ expected = {
478
+ "nurse": "llama-3.1-8b-instant",
479
+ "patient": "llama-3.1-8b-instant",
480
+ "empathy_judge": "llama-3.3-70b-versatile",
481
+ "medical_judge": "llama-3.3-70b-versatile",
482
+ }
483
+
484
+ print("=" * 60); print(" PRE-FLIGHT β€” Groq routing + smoke test"); print("=" * 60)
485
+ all_pass = True
486
+ for role, exp in expected.items():
487
+ actual = router._models.get(role, "?")
488
+ routing_ok = (actual == exp)
489
+ client = router._clients.get(role)
490
+
491
+ if client is None:
492
+ print(f" [SKIP] {role:14s} -> no Groq client (key missing)")
493
+ all_pass = False
494
+ continue
495
+
496
+ try:
497
+ resp = client.chat.completions.create(
498
+ model=exp,
499
+ messages=[{"role": "user", "content": "Reply with exactly: PING"}],
500
+ max_tokens=4, temperature=0,
501
+ )
502
+ api_ok = "PING" in (resp.choices[0].message.content or "").upper()
503
+ err = ""
504
+ except Exception as e:
505
+ api_ok = False
506
+ err = f" ({type(e).__name__}: {str(e)[:80]})"
507
+
508
+ flag = "PASS" if (routing_ok and api_ok) else "FAIL"
509
+ if flag == "FAIL":
510
+ all_pass = False
511
+ print(f" [{flag}] {role:14s} -> {actual:30s} "
512
+ f"routing={'ok' if routing_ok else 'WRONG'}, "
513
+ f"api={'ok' if api_ok else 'fail'}{err}")
514
+
515
+ print("=" * 60)
516
+ print("OK" if all_pass else "NOT OK β€” fix routing/keys before training.")
517
+ print("=" * 60)
518
+ assert all_pass, "Pre-flight failed; do not proceed to training."
519
+ """
520
+
521
+ CELL_11_DRYRUN = """\
522
+ # === CELL 11 β€” Dry-run smoke test (no GPU, no model load) ===
523
+ # Verifies the curriculum scheduler + reward verifier + per-phase early-stop
524
+ # wiring before we burn GPU minutes on the real run.
525
+
526
+ from ER_MAP.training.train_grpo import train
527
+
528
+ _ = train(
529
+ num_episodes=8,
530
+ group_size=2,
531
+ model_name=MODEL_NAME,
532
+ learning_rate=LEARNING_RATE,
533
+ kl_beta=KL_BETA,
534
+ output_dir="/kaggle/working/_dryrun",
535
+ dry_run=True,
536
+ phase_reward_targets=PHASE_REWARD_TARGETS,
537
+ phase_min_win_rate=PHASE_MIN_WIN_RATE,
538
+ convergence_window=CONVERGENCE_WINDOW,
539
+ early_stop=EARLY_STOP_ENABLED,
540
+ )
541
+ print("\\nDry-run OK β€” scheduler + verifier + per-phase early-stop wiring is healthy.")
542
+ """
543
+
544
+ CELL_12_HOOK = """\
545
+ # === CELL 12 β€” Wire periodic HF Hub push into training ===
546
+ # We monkey-patch save_lora_adapters so every checkpoint dump also pushes
547
+ # the LoRA adapter to HF Hub. Failures are non-fatal β€” training keeps
548
+ # running even if a push fails (e.g. transient HF 502).
549
+
550
+ from ER_MAP.training import train_grpo as _tg
551
+ _original_save = _tg.save_lora_adapters
552
+
553
+ def save_lora_adapters_with_push(model, tokenizer, output_dir):
554
+ _original_save(model, tokenizer, output_dir)
555
+ if HF_PUSH_REPO and "<your-username>" not in HF_PUSH_REPO:
556
+ try:
557
+ push_checkpoint_to_hub(
558
+ output_dir, HF_PUSH_REPO,
559
+ commit_message=f"checkpoint @ {os.path.basename(output_dir)}",
560
+ )
561
+ except Exception as e:
562
+ print(f" [hub-push] non-fatal failure: {e}")
563
+
564
+ _tg.save_lora_adapters = save_lora_adapters_with_push
565
+ print("Hub-push hook installed.")
566
+ """
567
+
568
+ CELL_13_TRAIN_MD = """\
569
+ ## 13 · Run real training (the 4–6 hour cell)
570
+
571
+ **Estimated wall-clock on Kaggle T4 ×2:**
572
+
573
+ - ~3–5 min per episode (6–14 env steps × Doctor.generate + 4–8 × Groq calls)
574
+ - ~1–2 min amortized per GRPO update (G=2 trajectories × response-token log-probs)
575
+ - **Per-group ≈ 8–12 min** (2 episodes + 1 update)
576
+
577
+ | Phase | Typical episodes to reach target | Wall-clock |
578
+ |---|---|---|
579
+ | 1 (target `+1.2` Γ— 3) | 12 – 24 episodes (6 – 12 groups) | ~1.0 – 2.0 h |
580
+ | 2 (target `+1.1` Γ— 3) | 16 – 32 episodes (8 – 16 groups) | ~1.5 – 2.5 h |
581
+ | 3 (target `+1.0` Γ— 3) | 20 – 50 episodes (10 – 25 groups) | ~2.0 – 4.0 h |
582
+ | **Total** | 50 – 100 episodes | **~4.5 – 8.5 h** |
583
+
584
+ If `NUM_EPISODES=200` is exhausted before Phase 3 converges, training
585
+ stops at the cap and the latest LoRA checkpoint is on HF Hub already
586
+ (we push every 20 episodes), so resume in a fresh session via
587
+ `HF_RESUME_REPO` in cell 8.
588
+ """
589
+
590
+ CELL_13_TRAIN = """\
591
+ # === CELL 13 β€” REAL TRAINING (4-6 h cell) ===
592
+ metrics = train(
593
+ num_episodes=NUM_EPISODES,
594
+ group_size=GROUP_SIZE,
595
+ model_name=MODEL_NAME,
596
+ groq_api_key=os.environ.get("GROQ_NURSE_API_KEY", "")
597
+ or os.environ.get("GROQ_API_KEY", ""),
598
+ learning_rate=LEARNING_RATE,
599
+ kl_beta=KL_BETA,
600
+ use_wandb=USE_WANDB,
601
+ output_dir=OUTPUT_DIR,
602
+ dry_run=False,
603
+ phase_reward_targets=PHASE_REWARD_TARGETS,
604
+ phase_min_win_rate=PHASE_MIN_WIN_RATE,
605
+ convergence_window=CONVERGENCE_WINDOW,
606
+ early_stop=EARLY_STOP_ENABLED,
607
+ )
608
+ print(f"\\nTraining returned {len(metrics)} metric records.")
609
+ """
610
+
611
+ CELL_14_FINAL_PUSH = """\
612
+ # === CELL 14 β€” Final push: adapters + merged fp16 ===
613
+ FINAL_LORA_DIR = f"{OUTPUT_DIR}/final_lora"
614
+ FINAL_MERGED_DIR = f"{OUTPUT_DIR}/final_merged_fp16"
615
+
616
+ if HF_PUSH_REPO and "<your-username>" not in HF_PUSH_REPO:
617
+ push_checkpoint_to_hub(FINAL_LORA_DIR, HF_PUSH_REPO,
618
+ commit_message="final LoRA adapter")
619
+ if os.path.isdir(FINAL_MERGED_DIR):
620
+ push_checkpoint_to_hub(FINAL_MERGED_DIR, f"{HF_PUSH_REPO}-merged",
621
+ commit_message="final merged fp16")
622
+ print(f"Final checkpoints pushed: https://huggingface.co/{HF_PUSH_REPO}")
623
+ else:
624
+ print("HF_PUSH_REPO not configured β€” skipping final push.")
625
+ """
626
+
627
+ CELL_15_PLOTS_MD = """\
628
+ ## 15 · Per-phase training graphs (one dashboard per curriculum phase)
629
+
630
+ We render a 6-panel dashboard for **every phase that contains episodes**,
631
+ plus a cross-phase overview and a phase-comparison bar chart. All PNGs are
632
+ written to `er_map_grpo_checkpoints/plots/` and uploaded to HF Hub in the
633
+ next cell so they survive Kaggle session expiry.
634
+
635
+ Each per-phase dashboard contains:
636
+
637
+ 1. **Reward growth** — raw scatter + rolling mean (w=10) + verified rolling mean
638
+ 2. **Rolling win rate** β€” w=20 win-rate evolution within the phase
639
+ 3. **Outcome distribution over time** β€” stacked bars (WIN/PARTIAL/INCORRECT/AMA_LOSS/FATAL_LOSS)
640
+ 4. **Reward components** β€” mean of each component (process / treatment / empathy / labs / etc.)
641
+ 5. **GRPO update stats** β€” loss + KL divergence per group update
642
+ 6. **Episode length distribution** β€” histogram of step counts
643
+ """
644
+
645
+ CELL_15_PLOTS = """\
646
+ # === CELL 15 β€” Per-phase training dashboards ===
647
+ from ER_MAP.plotting import plot_per_phase_dashboards
648
+ from IPython.display import Image, display, Markdown
649
+
650
+ PLOTS_DIR = f"{OUTPUT_DIR}/plots"
651
+ written = plot_per_phase_dashboards(
652
+ metrics_path=f"{OUTPUT_DIR}/training_metrics.json",
653
+ output_dir=PLOTS_DIR,
654
+ )
655
+
656
+ print(f"Saved {len(written)} chart(s) to {PLOTS_DIR}:")
657
+ for name, path in written.items():
658
+ size_kb = os.path.getsize(path) / 1024
659
+ print(f" {name:<28s} -> {path} ({size_kb:.0f} KB)")
660
+
661
+ # Display each chart inline so the operator sees them without leaving Kaggle.
662
+ ordered = (sorted(k for k in written if k.startswith("phase"))
663
+ + ["all_phases_overview", "all_phases_comparison"])
664
+ for key in ordered:
665
+ if key not in written:
666
+ continue
667
+ display(Markdown(f"### {key.replace('_', ' ').title()}"))
668
+ display(Image(filename=written[key]))
669
+ """
670
+
671
+ CELL_16_PUSH_PLOTS = """\
672
+ # === CELL 16 β€” Push plots to HF Hub ===
673
+ if HF_PUSH_REPO and "<your-username>" not in HF_PUSH_REPO:
674
+ push_checkpoint_to_hub(PLOTS_DIR, HF_PUSH_REPO,
675
+ commit_message="per-phase training plots")
676
+ print(f"Plots pushed: https://huggingface.co/{HF_PUSH_REPO}/tree/main")
677
+ else:
678
+ print("HF_PUSH_REPO not configured β€” plots stay only in /kaggle/working/.")
679
+ """
680
+
681
+ CELL_17_INFER_MD = """\
682
+ ## 17 · (Optional) Inference smoke-test on the trained model
683
+
684
+ Catches the classic 'merge path looked OK but the saved model emits garbage'
685
+ failure mode before the demo.
686
+ """
687
+
688
+ CELL_17_INFER = """\
689
+ # === CELL 17 β€” Inference smoke-test on the trained model ===
690
+ from ER_MAP.training.train_grpo import generate_doctor_action, load_model_and_tokenizer
691
+ from peft import PeftModel
692
+
693
+ base_model, tok = load_model_and_tokenizer(model_name=MODEL_NAME)
694
+ trained = PeftModel.from_pretrained(base_model, FINAL_LORA_DIR)
695
+
696
+ test_obs = (
697
+ '{"event":"episode_start","nurse_experience":"veteran",'
698
+ '"message":"Patient with chest pain, HR 120, BP 90/60, vague history.",'
699
+ '"soap_summary":{}}'
700
+ )
701
+ for i in range(3):
702
+ print(f"\\n--- Sample {i+1} ---")
703
+ print(generate_doctor_action(trained, tok, test_obs, max_new_tokens=160))
704
+ """
705
+
706
+
707
+ # ---------------------------------------------------------------------------
708
+ # Quickstart markdown (sibling file)
709
+ # ---------------------------------------------------------------------------
710
+
711
+ QUICKSTART_MD = """\
712
+ # Kaggle Quickstart — ER-MAP GRPO Training (v3 stable)
713
+
714
+ The Kaggle notebook is in `kaggle/train_ermap_grpo_kaggle.ipynb`. This file
715
+ is the cheat sheet for running it end-to-end without the dependency hell
716
+ that bit us in earlier attempts.
717
+
718
+ ## 0. Prerequisites (one-time)
719
+
720
+ 1. **GitHub fork** of this repo. The notebook clones from a public fork at
721
+ cell 6 β€” edit `GIT_URL`. Alternatively, upload the repo as a Kaggle
722
+ Dataset named `ermap-source` (Add Data β†’ Upload).
723
+ 2. **Hugging Face write token** (`HF_TOKEN`) for pushing the trained
724
+ adapter. Create at https://huggingface.co/settings/tokens (fine-grained,
725
+ write access on a single model repo is enough).
726
+ 3. **Five Groq keys** (one each for Nurse / Patient / Empathy Judge /
727
+ Medical Judge / shared fallback). Free-tier accounts are fine; the
728
+ per-account limits multiply across keys.
729
+
730
+ ## 1. Create the Kaggle notebook
731
+
732
+ 1. Sign in to https://www.kaggle.com/code β†’ **New Notebook**.
733
+ 2. Right sidebar:
734
+ - Accelerator: **GPU T4 Γ—2** (or P100)
735
+ - Internet: **On**
736
+ - Persistence: Files only
737
+ 3. **File β†’ Upload Notebook** β†’ choose `kaggle/train_ermap_grpo_kaggle.ipynb`
738
+ from this repo.
739
+
740
+ ## 2. Add Kaggle Secrets
741
+
742
+ Add-ons β†’ Secrets β†’ Add a new secret. Required labels (exactly):
743
+
744
+ | Label | Value |
745
+ |---|---|
746
+ | `GROQ_NURSE_API_KEY` | your nurse Groq key |
747
+ | `GROQ_PATIENT_API_KEY` | your patient Groq key |
748
+ | `GROQ_EMPATHY_JUDGE_API_KEY` | your empathy-judge Groq key |
749
+ | `GROQ_MEDICAL_JUDGE_API_KEY` | your medical-judge Groq key |
750
+ | `HF_TOKEN` | your HF write token |
751
+ | `WANDB_API_KEY` *(optional)* | your W&B key (skip β€” disabled by default) |
752
+
753
+ The notebook reads them via `kaggle_helpers.load_kaggle_secrets()` and
754
+ exports them as env vars.
755
+
756
+ ## 3. Edit two placeholders in the notebook
757
+
758
+ - **Cell 6:** `GIT_URL = "https://github.com/<your-fork>/Meta_Finals.git"`
759
+ - **Cell 8:** `HF_PUSH_REPO = "<your-username>/ermap-doctor-lora"`
760
+
761
+ If you uploaded the repo as a Kaggle Dataset instead, leave `GIT_URL` as the
762
+ placeholder β€” cell 6 will detect `/kaggle/input/ermap-source` and copy from
763
+ there.
764
+
765
+ ## 4. Run order (the only sequence that works)
766
+
767
+ | Cell | What it does | Expected output |
768
+ |---|---|---|
769
+ | 2 | GPU + disk + python + internet sanity check | GPU listed, disk free > 8 GB |
770
+ | 3 | **REPAIR** — pin torch 2.10 cu128, reinstall bitsandbytes, upgrade unsloth | `REPAIR OK` (or `RESTART REQUIRED`) |
771
+ | **(restart)** | If cell 3 said RESTART REQUIRED → Run → Restart kernel | — |
772
+ | 5 | Post-restart import verify | All `OK`, GPUs listed |
773
+ | 6 | Clone / mount the repo | `OK. Repo at /kaggle/working/Meta_Finals` |
774
+ | 7 | Wire Kaggle Secrets β†’ env vars | `OK β€” at least one Groq key is wired` |
775
+ | 8 | HF Hub config | `Starting fresh β€” no resume.` |
776
+ | 9 | Hyperparameters (P1=+1.2, P2=+1.1, P3=+1.0) | thresholds printed |
777
+ | 10 | **Pre-flight** — Groq routing + 4× PING | 4× `[PASS]`, then `OK` |
778
+ | 11 | Dry-run smoke test (no GPU) | `Dry-run OK` |
779
+ | 12 | Wire HF push hook | `Hub-push hook installed.` |
780
+ | 13 | **REAL TRAINING** (4–6 h) | per-group rolling stats, eventual `EARLY STOP` |
781
+ | 14 | Final push to HF | `Final checkpoints pushed: https://huggingface.co/...` |
782
+ | 15 | Per-phase plots | 5 PNGs displayed inline |
783
+ | 16 | Push plots to HF | `Plots pushed: ...` |
784
+ | 17 | Inference smoke-test (optional) | 3 sample Doctor actions printed |
785
+
786
+ ## 5. Common failures & fixes
787
+
788
+ | Symptom | Root cause | Fix |
789
+ |---|---|---|
790
+ | `numpy was upgraded mid-session` | numpy import poisoned by a previous cell | Restart kernel, re-run from cell 3 |
791
+ | `Pillow incompatible with torchvision` | Pillow ABI mismatch | Restart kernel, re-run from cell 3 |
792
+ | `PyTorch and torchvision compiled with different CUDA major` | torch upgraded to cu13 by a transient resolve | Re-run cell 3 (it pins cu128) and restart |
793
+ | `cannot import name 'create_gradient_checkpointing_buffer'` | unsloth ↔ unsloth_zoo version drift | Re-run cell 3 (upgrades both in lockstep) |
794
+ | `libnvJitLink.so.13 missing` | bitsandbytes built against different CUDA | Re-run cell 3 (force-reinstalls bitsandbytes after torch pin) |
795
+ | Disk usage > quota | Kaggle's 20 GB working partition fills up | Re-run cell 3 (step 1 purges `/tmp`, the pip cache and the HF datasets cache) |
796
+ | Pre-flight `[FAIL]` for a role | Groq key dead / quota exceeded | Generate a new key in console.groq.com → update Kaggle Secret → re-run cells 7 and 10 |
797
+ | `[FAIL]` says `routing=WRONG` | env var not set when `AgentRouter()` was constructed | Re-run cell 9 BEFORE cell 10 |
798
+ | Training freezes at episode 1 for >10 min | Doctor.generate hung; Unsloth import broke silently | Check cell 5 output for `unsloth` line; restart kernel and re-run cell 3 if missing |
799
+
800
+ ## 6. What the trained model gives you
801
+
802
+ After cell 13 finishes (or hits the 12 h Kaggle session cap), you have:
803
+
804
+ - `OUTPUT_DIR/final_lora/` β€” LoRA adapter weights (~50 MB), pushed to
805
+ `HF_PUSH_REPO`
806
+ - `OUTPUT_DIR/final_merged_fp16/` β€” full Llama-3.1-8B fp16 merge with the
807
+ adapter applied (~16 GB), pushed to `HF_PUSH_REPO-merged`
808
+ - `OUTPUT_DIR/training_metrics.json` β€” per-episode rewards, outcomes,
809
+ rolling stats β€” input for the per-phase plots
810
+ - `OUTPUT_DIR/plots/*.png` β€” 5 dashboards (one per phase + cross-phase
811
+ overview + comparison bar)
812
+
813
+ Use the LoRA adapter for the demo (quick to load, runs on a 4050 6 GB at
814
+ ~30 tok/s); use the merged fp16 if you need to host on a Vercel/HF Space
815
+ without `peft`.
816
+ """
817
+
818
+
819
+ # ---------------------------------------------------------------------------
820
+ # Build the notebook
821
+ # ---------------------------------------------------------------------------
822
+
823
def build_notebook() -> dict:
    """Assemble the notebook document from the module's CELL_* templates.

    Returns:
        A dict with the ordered ``cells`` list (each entry is whatever
        ``md_cell`` / ``code_cell`` produce for its template), Python 3
        kernelspec and language metadata, and ``nbformat`` 4 /
        ``nbformat_minor`` 5.
    """
    # (factory, template) pairs in notebook order; the position in this
    # tuple is the 0-based cell index in the generated notebook.
    layout = (
        (md_cell, CELL_01_TITLE),
        (code_cell, CELL_02_SANITY),
        (code_cell, CELL_03_REPAIR),
        (md_cell, CELL_04_RESTART),
        (code_cell, CELL_05_VERIFY),
        (code_cell, CELL_06_REPO),
        (code_cell, CELL_07_SECRETS),
        (code_cell, CELL_08_HF),
        (code_cell, CELL_09_HYPERPARAMS),
        (code_cell, CELL_10_PREFLIGHT),
        (code_cell, CELL_11_DRYRUN),
        (code_cell, CELL_12_HOOK),
        (md_cell, CELL_13_TRAIN_MD),
        (code_cell, CELL_13_TRAIN),
        (code_cell, CELL_14_FINAL_PUSH),
        (md_cell, CELL_15_PLOTS_MD),
        (code_cell, CELL_15_PLOTS),
        (code_cell, CELL_16_PUSH_PLOTS),
        (md_cell, CELL_17_INFER_MD),
        (code_cell, CELL_17_INFER),
    )

    kernelspec = {
        "display_name": "Python 3",
        "language": "python",
        "name": "python3",
    }
    language_info = {
        "name": "python",
        "version": "3.10",
    }

    return {
        "cells": [make_cell(template) for make_cell, template in layout],
        "metadata": {
            "kernelspec": kernelspec,
            "language_info": language_info,
        },
        "nbformat": 4,
        "nbformat_minor": 5,
    }
862
+
863
+
864
def main() -> None:
    """Regenerate the Kaggle notebook and the quickstart guide.

    Both files are written next to this script, UTF-8 encoded: the notebook
    as JSON from ``build_notebook()`` and the quickstart from
    ``QUICKSTART_MD``. A one-line summary of each written file is printed.
    """
    out_dir = Path(__file__).parent
    nb_path = out_dir / "train_ermap_grpo_kaggle.ipynb"
    qs_path = out_dir / "KAGGLE_QUICKSTART.md"

    nb = build_notebook()
    # ensure_ascii=False keeps non-ASCII characters in the cell text as-is
    # instead of \uXXXX escapes.
    serialized = json.dumps(nb, indent=1, ensure_ascii=False)
    nb_path.write_text(serialized, encoding="utf-8")
    qs_path.write_text(QUICKSTART_MD, encoding="utf-8")

    cell_types = [cell["cell_type"] for cell in nb["cells"]]
    n_md = cell_types.count("markdown")
    n_code = cell_types.count("code")
    print(f"Wrote {nb_path} ({len(cell_types)} cells: {n_md} md / {n_code} code)")
    print(f"Wrote {qs_path} ({len(QUICKSTART_MD.splitlines())} lines)")
877
+
878
+
879
+ if __name__ == "__main__":
880
+ main()
kaggle/train_ermap_grpo_kaggle.ipynb CHANGED
@@ -4,33 +4,45 @@
4
  "cell_type": "markdown",
5
  "metadata": {},
6
  "source": [
7
- "# ER-MAP β€” Doctor Agent GRPO Training (Kaggle Free-Tier)\n",
8
- "\n",
9
- "**Target hardware:** Tesla T4 16 GB (or P100 16 GB) β€” Kaggle's free GPU.\n",
10
- "\n",
11
- "**What this notebook does:**\n",
12
- "1. Clones / mounts the ER-MAP repo\n",
13
- "2. Installs the missing pieces (Unsloth, TRL, Groq, HF Hub) on top of Kaggle's pre-baked PyTorch image\n",
14
- "3. Loads Llama-3.1-8B in 4-bit + LoRA(r=16) via Unsloth (~7 GB VRAM)\n",
15
- "4. Runs the manual GRPO loop from `ER_MAP/training/train_grpo.py` with 3-phase curriculum learning\n",
16
- "5. Pushes LoRA adapter checkpoints to a Hugging Face Hub repo every 20 episodes so the 12-hour Kaggle session limit doesn't lose progress\n",
17
- "\n",
18
- "**Required Kaggle Secrets** (Add-ons β†’ Secrets):\n",
19
- "- `GROQ_NURSE_API_KEY`, `GROQ_PATIENT_API_KEY`, `GROQ_EMPATHY_JUDGE_API_KEY`, `GROQ_MEDICAL_JUDGE_API_KEY` β€” for the multi-agent env actors and judges\n",
20
- "- `HF_TOKEN` β€” to push checkpoints (use a fine-grained write token)\n",
21
- "- `WANDB_API_KEY` β€” *optional*, for the reward-growth chart\n",
22
- "\n",
23
- "**Notebook settings (right sidebar):**\n",
24
- "- Accelerator: **GPU T4 x2** (or P100)\n",
25
- "- Internet: **On** *(Groq calls require this)*\n",
26
- "- Persistence: Files only"
27
- ]
28
- },
29
- {
30
- "cell_type": "markdown",
31
- "metadata": {},
32
- "source": [
33
- "## 1 Β· Sanity check the GPU + clone the repo"
 
 
 
 
 
 
 
 
 
 
 
 
34
  ]
35
  },
36
  {
@@ -39,8 +51,38 @@
39
  "metadata": {},
40
  "outputs": [],
41
  "source": [
42
- "!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv\n",
43
- "!python -c \"import torch; print('torch', torch.__version__, 'cuda', torch.cuda.is_available())\""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  ]
45
  },
46
  {
@@ -49,47 +91,159 @@
49
  "metadata": {},
50
  "outputs": [],
51
  "source": [
52
- "# --- OPTION A: clone the public repo (preferred) ----------------------\n",
53
- "# Replace <your-github-fork> with your actual fork URL. Public fork\n",
54
- "# works without any token. For a private repo, set HF_TOKEN OR pass\n",
55
- "# a GH PAT via Kaggle Secrets.\n",
56
- "GIT_URL = \"https://github.com/<your-github-fork>/Meta_Finals.git\"\n",
57
- "BRANCH = \"main\"\n",
58
- "REPO_ROOT = \"/kaggle/working/Meta_Finals\"\n",
59
- "\n",
60
- "import os, subprocess\n",
61
- "if not os.path.isdir(f\"{REPO_ROOT}/ER_MAP\") and \"<your-github-fork>\" not in GIT_URL:\n",
62
- " print(subprocess.check_output(\n",
63
- " [\"git\", \"clone\", \"--depth\", \"1\", \"-b\", BRANCH, GIT_URL, REPO_ROOT],\n",
64
- " stderr=subprocess.STDOUT,\n",
65
- " ).decode())\n",
66
- "\n",
67
- "# --- OPTION B: dataset upload (if you don't want to push to GitHub) ---\n",
68
- "# 1. Locally: zip the repo (excluding .git, checkpoints, __pycache__).\n",
69
- "# 2. Kaggle: New Dataset -> upload the zip -> name it `ermap-source`.\n",
70
- "# 3. This notebook: Add Data -> ermap-source.\n",
71
- "# 4. Run the next cell to copy /kaggle/input/ermap-source/ into\n",
72
- "# /kaggle/working/Meta_Finals (writeable).\n",
73
- "DATASET_DIR = \"/kaggle/input/ermap-source\"\n",
74
- "if not os.path.isdir(f\"{REPO_ROOT}/ER_MAP\") and os.path.isdir(DATASET_DIR):\n",
75
- " import shutil\n",
76
- " shutil.copytree(DATASET_DIR, REPO_ROOT, dirs_exist_ok=True)\n",
77
- " print(f\"Copied {DATASET_DIR} -> {REPO_ROOT}\")\n",
78
- "\n",
79
- "assert os.path.isdir(f\"{REPO_ROOT}/ER_MAP\"), (\n",
80
- " \"Repo not found. Either set GIT_URL above (Option A) or upload the \"\n",
81
- " \"repo as a Kaggle Dataset named 'ermap-source' (Option B).\"\n",
82
- ")\n",
83
- "print(\"Repo ready at\", REPO_ROOT)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  ]
85
  },
86
  {
87
  "cell_type": "markdown",
88
  "metadata": {},
89
  "source": [
90
- "## 2 Β· Install the missing dependencies\n",
 
 
 
 
91
  "\n",
92
- "Kaggle's GPU image already ships with PyTorch 2.x + CUDA 12 + transformers + accelerate + peft + bitsandbytes. We only add Unsloth (which pins matching xformers/triton), TRL, Gymnasium, Groq SDK, and the HF Hub client."
 
93
  ]
94
  },
95
  {
@@ -98,26 +252,46 @@
98
  "metadata": {},
99
  "outputs": [],
100
  "source": [
101
- "# --upgrade is critical: Kaggle's pre-baked layer often ships an\n",
102
- "# OLD `unsloth` paired with whatever fresh `unsloth_zoo` pip pulled\n",
103
- "# this morning, and the import then fails with:\n",
104
- "# ImportError: cannot import name 'create_gradient_checkpointing_buffer'\n",
105
- "# Forcing both packages to upgrade in one resolve pass keeps them in lockstep.\n",
106
- "!pip install -q --upgrade -r {REPO_ROOT}/kaggle/requirements_kaggle.txt\n",
107
- "# Sanity check the unsloth import β€” it's the most fragile dep on Kaggle.\n",
108
- "# If you see the gradient_checkpointing ImportError below, run:\n",
109
- "# !pip install --upgrade --force-reinstall --no-cache-dir unsloth unsloth_zoo\n",
110
- "# in a NEW cell, then RESTART the kernel and re-run from cell 2.\n",
111
- "!python -c \"import unsloth, unsloth_zoo; print('unsloth', unsloth.__version__, '| unsloth_zoo', unsloth_zoo.__version__)\""
112
- ]
113
- },
114
- {
115
- "cell_type": "markdown",
116
- "metadata": {},
117
- "source": [
118
- "## 3 Β· Wire Kaggle Secrets into env vars\n",
119
- "\n",
120
- "ER-MAP reads `GROQ_NURSE_API_KEY` / `GROQ_PATIENT_API_KEY` / etc. directly from `os.environ`. The helper below copies your Kaggle Secrets into those env vars in one shot."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  ]
122
  },
123
  {
@@ -126,28 +300,63 @@
126
  "metadata": {},
127
  "outputs": [],
128
  "source": [
129
- "import sys, os\n",
130
- "sys.path.insert(0, REPO_ROOT) # so we can import ER_MAP\n",
131
- "sys.path.insert(0, f\"{REPO_ROOT}/kaggle\") # so we can import kaggle_helpers\n",
132
- "\n",
133
- "from kaggle_helpers import (\n",
134
- " load_kaggle_secrets,\n",
135
- " kaggle_env_summary,\n",
136
- " push_checkpoint_to_hub,\n",
137
- " download_checkpoint_from_hub,\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  ")\n",
139
  "\n",
140
- "load_kaggle_secrets()\n",
141
- "kaggle_env_summary()"
 
142
  ]
143
  },
144
  {
145
- "cell_type": "markdown",
 
146
  "metadata": {},
 
147
  "source": [
148
- "## 4 Β· (Optional) Resume from a previous Kaggle session\n",
 
 
149
  "\n",
150
- "If you've trained before and pushed an adapter to HF Hub, set `HF_RESUME_REPO` and run the cell to pull the latest LoRA adapter into `/kaggle/working/checkpoints/resume/`. The training cell will pick it up automatically."
 
 
 
 
 
 
 
 
 
 
151
  ]
152
  },
153
  {
@@ -156,68 +365,27 @@
156
  "metadata": {},
157
  "outputs": [],
158
  "source": [
159
- "HF_PUSH_REPO = \"<your-username>/ermap-doctor-lora\" # where checkpoints will be pushed\n",
160
- "HF_RESUME_REPO = \"\" # e.g. \"<your-username>/ermap-doctor-lora\"; leave empty to start fresh\n",
 
 
 
 
 
 
161
  "\n",
162
  "RESUME_DIR = \"/kaggle/working/checkpoints/resume\"\n",
163
  "if HF_RESUME_REPO:\n",
164
  " download_checkpoint_from_hub(HF_RESUME_REPO, RESUME_DIR)\n",
165
- " print(\"Resume dir contents:\", os.listdir(RESUME_DIR) if os.path.isdir(RESUME_DIR) else \"(empty)\")"
166
- ]
167
- },
168
- {
169
- "cell_type": "markdown",
170
- "metadata": {},
171
- "source": [
172
- "## 5 Β· Configure the GRPO run\n",
173
- "\n",
174
- "The defaults below are tuned for **one 12-hour Kaggle session** on a single T4. They produce a clean upward reward-growth curve through Phase 1 + early Phase 2; if you have a second session, lower `--episodes` is fine because LoRA adapters resume cleanly via Step 4 above.\n",
175
- "\n",
176
- "| Parameter | Value | Reason |\n",
177
- "|---|---|---|\n",
178
- "| Model | `unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit` | Llama-3-family small-tier, 4-bit ~5 GB |\n",
179
- "| LoRA rank | 16 | balances expressivity vs speed on T4 |\n",
180
- "| Group size G | 2 | Kaggle's T4 fits G=2 comfortably; G=4 needs 30+ min/group |\n",
181
- "| Episodes (cap) | 120 | hard cap; early-stop usually finishes first |\n",
182
- "| LR | 5e-6 | conservative, prevents catastrophic forgetting on small group |\n",
183
- "| KL beta | 0.04 | matches the paper's recipe; restrains drift from base policy |\n",
184
- "| Max episode steps | 20 | matches `triage_env.py` default |\n",
185
- "| Internal exchanges | 5 | shorter than default (8) to fit within 12 h budget |\n",
186
- "\n",
187
- "### Train-until-optimal (per-phase reward thresholds)\n",
188
- "\n",
189
- "Training **never** runs for a fixed episode budget. After every GRPO update we look at the last `CONVERGENCE_WINDOW=3` groups; if **all three** belong to the same current phase AND each has `rolling_avg_reward >= PHASE_REWARD_TARGETS[current_phase]`, we either:\n",
190
- "\n",
191
- "- **Phase 1 / Phase 2** β†’ force-promote to the next curriculum phase (the buffer is then cleared so stale entries don't satisfy the next phase's check).\n",
192
- "- **Phase 3** β†’ terminate training (this is the 'optimal rewards constantly received' criterion).\n",
193
- "\n",
194
- "Why per-phase, not a single global bar? The phases are not equally difficult β€” Phase 1 wins are worth ~`+2.0` on the reward scale (full terminal_win on a clean SOAP) while Phase 3 wins routinely cost `~0.5` in consent / empathy friction even when the diagnosis is correct. A single global `+1.5` would either gate Phase 3 too aggressively or pass Phase 1 with garbage Phase-2 behavior.\n",
195
- "\n",
196
- "| Phase | Default target | Why this number | Action when met |\n",
197
- "|---|---|---|---|\n",
198
- "| 1 β€” Tool Mastery | `+1.5` | A Phase-1 episode that uses tools cleanly + discharges with the correct treatment lands at `+1.6 .. +2.0`. Sustaining `+1.5` means the model has tool-format down. | force-promote to Phase 2 |\n",
199
- "| 2 β€” Clinical Reasoning | `+1.2` | Phase 2 adds noisy SOAP and mixed compliance. A solid clinician policy lands at `+1.2 .. +1.5`. | force-promote to Phase 3 |\n",
200
- "| 3 β€” Empathetic Negotiation | `+1.0` | Phase 3 imposes empathy + consent costs (`-0.3..-0.6` per episode even on wins). Sustained `+1.0` here is genuinely hard and is the hackathon success criterion. | END TRAINING |\n",
201
- "\n",
202
- "| Knob | Default | Meaning |\n",
203
- "|---|---|---|\n",
204
- "| `PHASE_REWARD_TARGETS` | `{1: 1.5, 2: 1.2, 3: 1.0}` | per-phase sustained rolling-avg-reward bar |\n",
205
- "| `PHASE_MIN_WIN_RATE` | `0.20` | soft floor on rolling win rate (sanity check) |\n",
206
- "| `CONVERGENCE_WINDOW` | `3` | how many consecutive groups must hit the bar |\n",
207
- "| `EARLY_STOP_ENABLED` | `True` | set `False` to always burn the full `NUM_EPISODES` budget |\n",
208
- "\n",
209
- "### Estimated wall-clock per phase on Kaggle T4 Γ—2\n",
210
- "\n",
211
- "Each episode = 6–14 env steps Γ— (Doctor.generate β‰ˆ 2–3 s) + 4–8 Groq calls (β‰ˆ 0.4–1.0 s each). One GRPO update over `G=2` trajectories = 1 forward + 1 backward over response tokens β‰ˆ 60–120 s on T4. Net per-group wall-clock β‰ˆ **8–12 minutes**.\n",
212
- "\n",
213
- "| Phase | Typical episodes to reach target | Wall-clock | Notes |\n",
214
- "|---|---|---|---|\n",
215
- "| 1 (target `+1.5` Γ— 3) | 16 – 30 episodes (8 – 15 groups) | **~1.5 – 2.5 h** | Easy patients + clean SOAP β€” tool-format is the only thing the model has to learn. |\n",
216
- "| 2 (target `+1.2` Γ— 3) | 24 – 40 episodes (12 – 20 groups) | **~2.0 – 3.5 h** | Mixed-compliance patients + noisy SOAP. Bulk of the policy improvement happens here. |\n",
217
- "| 3 (target `+1.0` Γ— 3) | 30 – 60 episodes (15 – 30 groups) | **~2.5 – 5.0 h** | Hard patients + empathy/consent costs. May not converge in 12 h on a fresh base; that's why `NUM_EPISODES=120` is the hard cap. |\n",
218
- "| **Total** | 70 – 130 episodes | **~6 – 11 h** | Fits inside Kaggle's 12 h GPU session with ~1 h margin. |\n",
219
  "\n",
220
- "If Phase 3 doesn't converge before the 12 h limit, your latest LoRA checkpoint is already on HF Hub (we push every 20 episodes), so just resume in a fresh session via `HF_RESUME_REPO`."
 
 
 
221
  ]
222
  },
223
  {
@@ -226,65 +394,110 @@
226
  "metadata": {},
227
  "outputs": [],
228
  "source": [
229
- "# --- Training hyperparameters ---\n",
 
 
230
  "MODEL_NAME = \"unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit\"\n",
231
- "NUM_EPISODES = 120 # HARD CAP; early-stop usually finishes first\n",
232
  "GROUP_SIZE = 2\n",
233
  "LEARNING_RATE = 5e-6\n",
234
  "KL_BETA = 0.04\n",
235
  "OUTPUT_DIR = \"/kaggle/working/er_map_grpo_checkpoints\"\n",
236
  "PUSH_EVERY_EPS = 20\n",
237
- "USE_WANDB = bool(os.environ.get(\"WANDB_API_KEY\"))\n",
238
- "\n",
239
- "# --- Early-stopping (per-phase reward thresholds) ---\n",
240
- "# After every GRPO update, we check the last CONVERGENCE_WINDOW groups.\n",
241
- "# If ALL of them are in the SAME current phase AND each has\n",
242
- "# rolling_avg_reward >= PHASE_REWARD_TARGETS[current_phase], we either:\n",
243
- "# - force-promote to the next phase (Phase 1, Phase 2), OR\n",
 
 
244
  "# - terminate training (Phase 3).\n",
245
- "# Baseline (un-trained Groq Doctor) avg reward by phase:\n",
246
- "# P1=+0.76, P2=+0.59, P3=+0.39\n",
247
- "# So the +1.5/+1.2/+1.0 bar = 2.0x / 2.0x / 2.6x improvement.\n",
248
- "EARLY_STOP_ENABLED = True\n",
249
- "PHASE_REWARD_TARGETS = {1: 1.5, 2: 1.2, 3: 1.0}\n",
250
- "PHASE_MIN_WIN_RATE = 0.20 # soft floor; +1.0 reward implies >=20% wins\n",
251
- "CONVERGENCE_WINDOW = 3 # 3 consecutive groups must qualify\n",
252
- "\n",
253
- "# --- Per-episode budget controls (passed via env vars) ---\n",
254
  "os.environ[\"ERMAP_MAX_EPISODE_STEPS\"] = \"20\"\n",
255
  "os.environ[\"ERMAP_MAX_INTERNAL_EXCHANGES\"] = \"5\"\n",
256
- "# Doctor-on-Kaggle is the LOCAL trained model, NOT a Groq call. The\n",
257
- "# Doctor's Groq key is therefore unused here, but Nurse / Patient /\n",
258
- "# Empathy Judge / Medical Judge all hit Groq once per env step.\n",
259
- "# Traffic-shaping: high-volume roleplay agents (Nurse + Patient) on the\n",
260
- "# 8B-instant pool (500K TPD, 14,400 RPD); the two judges stay on 70B-\n",
261
- "# versatile because their grading quality directly shapes the reward.\n",
262
  "os.environ[\"ERMAP_NURSE_MODEL\"] = \"llama-3.1-8b-instant\"\n",
263
  "os.environ[\"ERMAP_PATIENT_MODEL\"] = \"llama-3.1-8b-instant\"\n",
264
  "os.environ[\"ERMAP_EMPATHY_JUDGE_MODEL\"] = \"llama-3.3-70b-versatile\"\n",
265
  "os.environ[\"ERMAP_MEDICAL_JUDGE_MODEL\"] = \"llama-3.3-70b-versatile\"\n",
266
  "\n",
267
- "# Sanity: at least one Groq key must be present, otherwise the env\n",
268
- "# falls back to mock responses and the trained model won't see\n",
269
- "# realistic dialogue.\n",
270
- "assert any(\n",
271
- " os.environ.get(k) for k in [\n",
272
- " \"GROQ_NURSE_API_KEY\", \"GROQ_PATIENT_API_KEY\",\n",
273
- " \"GROQ_EMPATHY_JUDGE_API_KEY\", \"GROQ_MEDICAL_JUDGE_API_KEY\",\n",
274
- " \"GROQ_API_KEY\",\n",
275
- " ]\n",
276
- "), (\"No Groq key found in Kaggle Secrets β€” add at least \"\n",
277
- " \"GROQ_NURSE_API_KEY before running training.\")\n",
278
- "print(\"Hyperparameters and env vars set.\")"
279
  ]
280
  },
281
  {
282
- "cell_type": "markdown",
 
283
  "metadata": {},
 
284
  "source": [
285
- "## 6 Β· Dry-run smoke test (no GPU, no model load)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
  "\n",
287
- "Verifies the curriculum scheduler + reward verifier + metrics logger are wired correctly **before** burning GPU minutes on a real run."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  ]
289
  },
290
  {
@@ -293,6 +506,10 @@
293
  "metadata": {},
294
  "outputs": [],
295
  "source": [
 
 
 
 
296
  "from ER_MAP.training.train_grpo import train\n",
297
  "\n",
298
  "_ = train(\n",
@@ -303,17 +520,12 @@
303
  " kl_beta=KL_BETA,\n",
304
  " output_dir=\"/kaggle/working/_dryrun\",\n",
305
  " dry_run=True,\n",
 
 
 
 
306
  ")\n",
307
- "print(\"Dry-run OK β€” scheduler + verifier wiring is healthy.\")"
308
- ]
309
- },
310
- {
311
- "cell_type": "markdown",
312
- "metadata": {},
313
- "source": [
314
- "## 7 Β· Wire periodic HF-Hub push into the training loop\n",
315
- "\n",
316
- "We monkey-patch `save_lora_adapters` so every checkpoint dump also pushes to HF Hub. Failures are non-fatal β€” training keeps running even if the push fails."
317
  ]
318
  },
319
  {
@@ -322,18 +534,20 @@
322
  "metadata": {},
323
  "outputs": [],
324
  "source": [
 
 
 
 
 
325
  "from ER_MAP.training import train_grpo as _tg\n",
326
  "_original_save = _tg.save_lora_adapters\n",
327
- "_episode_marker = {\"n\": 0}\n",
328
  "\n",
329
  "def save_lora_adapters_with_push(model, tokenizer, output_dir):\n",
330
  " _original_save(model, tokenizer, output_dir)\n",
331
- " _episode_marker[\"n\"] += 1\n",
332
  " if HF_PUSH_REPO and \"<your-username>\" not in HF_PUSH_REPO:\n",
333
  " try:\n",
334
  " push_checkpoint_to_hub(\n",
335
- " output_dir,\n",
336
- " HF_PUSH_REPO,\n",
337
  " commit_message=f\"checkpoint @ {os.path.basename(output_dir)}\",\n",
338
  " )\n",
339
  " except Exception as e:\n",
@@ -347,18 +561,25 @@
347
  "cell_type": "markdown",
348
  "metadata": {},
349
  "source": [
350
- "## 8 Β· Run real training (the 6-11 hour cell)\n",
351
- "\n",
352
- "With the per-phase early-stop targets `{1: +1.5, 2: +1.2, 3: +1.0}` set above, expect:\n",
353
- "\n",
354
- "- ~3-5 minutes per episode (6-14 env steps Γ— Doctor.generate + 4-8 Γ— Groq calls)\n",
355
- "- ~1-2 minutes amortized per GRPO update (G=2 trajectories Γ— response-token log-probs)\n",
356
- "- **Per-group wall-clock β‰ˆ 8-12 min** (2 episodes + 1 update)\n",
357
- "- **Phase 1 β†’ Phase 2 force-promote** typically lands at **episode 16-30** (sustained `+1.5` Γ— 3 groups)\n",
358
- "- **Phase 2 β†’ Phase 3 force-promote** typically lands at **episode 40-70**\n",
359
- "- **Phase 3 EARLY STOP** typically lands at **episode 70-130** (sustained `+1.0` Γ— 3 groups)\n",
360
- "- Reward-growth signal (rolling avg) becomes visible after ~episode 20\n",
361
- "- If `NUM_EPISODES=120` is exhausted before Phase 3 converges, training stops at the cap and the latest checkpoint is on HF Hub β€” resume in a fresh session via `HF_RESUME_REPO`."
 
 
 
 
 
 
 
362
  ]
363
  },
364
  {
@@ -367,31 +588,24 @@
367
  "metadata": {},
368
  "outputs": [],
369
  "source": [
 
370
  "metrics = train(\n",
371
  " num_episodes=NUM_EPISODES,\n",
372
  " group_size=GROUP_SIZE,\n",
373
  " model_name=MODEL_NAME,\n",
374
- " groq_api_key=os.environ.get(\"GROQ_NURSE_API_KEY\", \"\") or os.environ.get(\"GROQ_API_KEY\", \"\"),\n",
 
375
  " learning_rate=LEARNING_RATE,\n",
376
  " kl_beta=KL_BETA,\n",
377
  " use_wandb=USE_WANDB,\n",
378
  " output_dir=OUTPUT_DIR,\n",
379
  " dry_run=False,\n",
380
- " # ----- Per-phase early-stop ('train until optimal rewards are constantly received') -----\n",
381
  " phase_reward_targets=PHASE_REWARD_TARGETS,\n",
382
  " phase_min_win_rate=PHASE_MIN_WIN_RATE,\n",
383
  " convergence_window=CONVERGENCE_WINDOW,\n",
384
  " early_stop=EARLY_STOP_ENABLED,\n",
385
- ")"
386
- ]
387
- },
388
- {
389
- "cell_type": "markdown",
390
- "metadata": {},
391
- "source": [
392
- "## 9 Β· Final push: adapters + merged fp16 weights\n",
393
- "\n",
394
- "The training loop already wrote `final_lora/` and `final_merged_fp16/` to `OUTPUT_DIR`. We push both to HF Hub so you can serve them from Vercel / a HF Space without re-running training."
395
  ]
396
  },
397
  {
@@ -400,20 +614,17 @@
400
  "metadata": {},
401
  "outputs": [],
402
  "source": [
 
403
  "FINAL_LORA_DIR = f\"{OUTPUT_DIR}/final_lora\"\n",
404
  "FINAL_MERGED_DIR = f\"{OUTPUT_DIR}/final_merged_fp16\"\n",
405
  "\n",
406
  "if HF_PUSH_REPO and \"<your-username>\" not in HF_PUSH_REPO:\n",
407
- " push_checkpoint_to_hub(\n",
408
- " FINAL_LORA_DIR, HF_PUSH_REPO,\n",
409
- " commit_message=\"final LoRA adapter\",\n",
410
- " )\n",
411
  " if os.path.isdir(FINAL_MERGED_DIR):\n",
412
- " push_checkpoint_to_hub(\n",
413
- " FINAL_MERGED_DIR, f\"{HF_PUSH_REPO}-merged\",\n",
414
- " commit_message=\"final merged fp16\",\n",
415
- " )\n",
416
- " print(\"Final checkpoints pushed.\")\n",
417
  "else:\n",
418
  " print(\"HF_PUSH_REPO not configured β€” skipping final push.\")"
419
  ]
@@ -422,16 +633,20 @@
422
  "cell_type": "markdown",
423
  "metadata": {},
424
  "source": [
425
- "## 10 Β· Per-phase training graphs (one dashboard per curriculum phase)\n",
 
 
 
 
 
426
  "\n",
427
- "We render a complete 6-panel dashboard for every phase that contains episodes, plus a cross-phase overview and a phase-comparison bar chart. All PNGs are written to `er_map_grpo_checkpoints/plots/` and uploaded to HF Hub at the end of the notebook so they survive session expiry.\n",
428
  "\n",
429
- "**Each per-phase dashboard contains:**\n",
430
  "1. **Reward growth** β€” raw scatter + rolling mean (w=10) + verified rolling mean\n",
431
- "2. **Rolling win rate** β€” w=20 win rate evolution within the phase\n",
432
- "3. **Outcome distribution over time** β€” stacked bars (WIN/PARTIAL/INCORRECT/AMA_LOSS/FATAL_LOSS) per episode bin\n",
433
- "4. **Reward components** β€” mean of each component (process / treatment / empathy / labs / etc.) within the phase\n",
434
- "5. **GRPO update statistics** β€” loss + KL divergence per group update\n",
435
  "6. **Episode length distribution** β€” histogram of step counts"
436
  ]
437
  },
@@ -441,6 +656,7 @@
441
  "metadata": {},
442
  "outputs": [],
443
  "source": [
 
444
  "from ER_MAP.plotting import plot_per_phase_dashboards\n",
445
  "from IPython.display import Image, display, Markdown\n",
446
  "\n",
@@ -450,54 +666,44 @@
450
  " output_dir=PLOTS_DIR,\n",
451
  ")\n",
452
  "\n",
453
- "print(f\"Saved {len(written)} chart(s):\")\n",
454
  "for name, path in written.items():\n",
455
  " size_kb = os.path.getsize(path) / 1024\n",
456
  " print(f\" {name:<28s} -> {path} ({size_kb:.0f} KB)\")\n",
457
  "\n",
458
- "# Display each chart inline in the notebook so the operator sees them\n",
459
- "# without leaving Kaggle. Order: per-phase dashboards first (1, 2, 3),\n",
460
- "# then the cross-phase overview, then the bar comparison.\n",
461
- "ordered_keys = (\n",
462
- " sorted(k for k in written if k.startswith(\"phase\")) +\n",
463
- " [\"all_phases_overview\", \"all_phases_comparison\"]\n",
464
- ")\n",
465
- "for key in ordered_keys:\n",
466
  " if key not in written:\n",
467
  " continue\n",
468
  " display(Markdown(f\"### {key.replace('_', ' ').title()}\"))\n",
469
  " display(Image(filename=written[key]))"
470
  ]
471
  },
472
- {
473
- "cell_type": "markdown",
474
- "metadata": {},
475
- "source": [
476
- "## 10b Β· Push the plots to HF Hub (so they survive session expiry)"
477
- ]
478
- },
479
  {
480
  "cell_type": "code",
481
  "execution_count": null,
482
  "metadata": {},
483
  "outputs": [],
484
  "source": [
 
485
  "if HF_PUSH_REPO and \"<your-username>\" not in HF_PUSH_REPO:\n",
486
- " push_checkpoint_to_hub(\n",
487
- " PLOTS_DIR, HF_PUSH_REPO,\n",
488
- " commit_message=\"per-phase training plots\",\n",
489
- " )\n",
490
  "else:\n",
491
- " print(\"HF_PUSH_REPO not configured \u2014 plots stay only in /kaggle/working/.\")"
492
  ]
493
  },
494
  {
495
  "cell_type": "markdown",
496
  "metadata": {},
497
  "source": [
498
- "## 11 Β· (Optional) Inference smoke-test on the trained model\n",
499
  "\n",
500
- "Catches the classic \"merge path looked OK but the saved model emits garbage\" failure mode before the demo."
 
501
  ]
502
  },
503
  {
@@ -506,6 +712,7 @@
506
  "metadata": {},
507
  "outputs": [],
508
  "source": [
 
509
  "from ER_MAP.training.train_grpo import generate_doctor_action, load_model_and_tokenizer\n",
510
  "from peft import PeftModel\n",
511
  "\n",
@@ -536,4 +743,4 @@
536
  },
537
  "nbformat": 4,
538
  "nbformat_minor": 5
539
- }
 
4
  "cell_type": "markdown",
5
  "metadata": {},
6
  "source": [
7
+ "# ER-MAP — Doctor Agent GRPO Training (Kaggle Free-Tier · v3 stable)\n",
8
+ "\n",
9
+ "Trains the **Doctor LLM** (Llama-3.1-8B-Instruct, 4-bit + LoRA r=16) via GRPO\n",
10
+ "with a 3-phase curriculum on Kaggle's free GPU. Designed to survive Kaggle's\n",
11
+ "pre-baked image quirks (numpy / Pillow ABI mismatches, torch + torchvision\n",
12
+ "CUDA-major mismatches, transient `unsloth_zoo` upgrades).\n",
13
+ "\n",
14
+ "## TL;DR — How to run this notebook\n",
15
+ "\n",
16
+ "1. **Notebook settings (right sidebar):**\n",
17
+ " - Accelerator: **GPU T4 ×2** (or P100)\n",
18
+ " - Internet: **On**\n",
19
+ " - Persistence: Files only\n",
20
+ "2. **Kaggle Secrets** (Add-ons → Secrets):\n",
21
+ " - **Required:** `GROQ_NURSE_API_KEY`, `GROQ_PATIENT_API_KEY`,\n",
22
+ " `GROQ_EMPATHY_JUDGE_API_KEY`, `GROQ_MEDICAL_JUDGE_API_KEY`, `HF_TOKEN`\n",
23
+ " - **Optional:** `WANDB_API_KEY`\n",
24
+ "3. **Run cells 2 → 3 (sanity + REPAIR).** When cell 3 prints\n",
25
+ " `RESTART REQUIRED`, click **Run → Restart kernel**, then resume from cell 5.\n",
26
+ "4. **Run cells 5 → 11 (verify + configure + dry-run + pre-flight).** Each cell\n",
27
+ " should print an `OK` line before moving on.\n",
28
+ "5. **Run cell 13 (the long training cell, 4–6 hours).**\n",
29
+ "6. **Run cells 14 → 17 (final push + plots + inference smoke-test).**\n",
30
+ "\n",
31
+ "## Curriculum + reward thresholds (this run)\n",
32
+ "\n",
33
+ "Constant per-phase rolling-avg-reward bars; sustained for **3 consecutive\n",
34
+ "GRPO groups** triggers either a phase promotion or end-of-training.\n",
35
+ "\n",
36
+ "| Phase | Reward target (sustained ×3 groups) | Action when met |\n",
37
+ "|---|---|---|\n",
38
+ "| 1 — Tool Mastery | `+1.2` | force-promote to Phase 2 |\n",
39
+ "| 2 — Clinical Reasoning | `+1.1` | force-promote to Phase 3 |\n",
40
+ "| 3 — Empathetic Negotiation | `+1.0` | END TRAINING |\n",
41
+ "\n",
42
+ "Why these numbers? The un-trained 8B Doctor's baseline on the same env is\n",
43
+ "`P1=+0.76, P2=+0.59, P3=+0.39`. Targets of `+1.2 / +1.1 / +1.0` correspond\n",
44
+ "to roughly `1.6× / 1.9× / 2.6×` improvement over baseline — a meaningful\n",
45
+ "signal but reachable inside Kaggle's 12 h session limit."
46
  ]
47
  },
48
  {
 
51
  "metadata": {},
52
  "outputs": [],
53
  "source": [
54
+ "# === CELL 2 β€” Sanity check (GPU + disk + python + internet) ===\n",
55
+ "# Run this FIRST. If any check fails, fix it before running the REPAIR cell.\n",
56
+ "\n",
57
+ "import os, shutil, subprocess, sys, socket\n",
58
+ "\n",
59
+ "print(\"--- GPU ---\")\n",
60
+ "try:\n",
61
+ " print(subprocess.check_output(\n",
62
+ " [\"nvidia-smi\", \"--query-gpu=name,memory.total,memory.free\", \"--format=csv\"],\n",
63
+ " timeout=10,\n",
64
+ " ).decode())\n",
65
+ "except Exception as e:\n",
66
+ " print(f\"nvidia-smi failed: {e}\")\n",
67
+ " print(\"-> Set Accelerator to 'GPU T4 x2' in the right sidebar.\")\n",
68
+ "\n",
69
+ "print(\"--- Disk (/kaggle/working) ---\")\n",
70
+ "total, used, free = shutil.disk_usage(\"/kaggle/working\")\n",
71
+ "print(f\" total={total/1e9:5.1f} GB | used={used/1e9:5.1f} GB | free={free/1e9:5.1f} GB\")\n",
72
+ "if free < 8 * 1e9:\n",
73
+ " print(\" WARNING: free disk < 8 GB β€” repair cell may fail. \"\n",
74
+ " \"Consider 'Run > Restart and clear cell outputs' to reset /tmp.\")\n",
75
+ "\n",
76
+ "print(\"--- Python ---\")\n",
77
+ "print(f\" python={sys.version.split()[0]} | exe={sys.executable}\")\n",
78
+ "\n",
79
+ "print(\"--- Internet (api.groq.com:443) ---\")\n",
80
+ "try:\n",
81
+ " socket.create_connection((\"api.groq.com\", 443), timeout=5).close()\n",
82
+ " print(\" reachable\")\n",
83
+ "except Exception as e:\n",
84
+ " print(f\" UNREACHABLE: {e}\")\n",
85
+ " print(\" -> Settings (right sidebar) -> Internet -> ON\")"
86
  ]
87
  },
88
  {
 
91
  "metadata": {},
92
  "outputs": [],
93
  "source": [
94
+ "# === CELL 3 β€” REPAIR CELL (idempotent full environment rebuild) ===\n",
95
+ "# Single source of truth for ER-MAP's GPU stack. Safe to re-run. After it\n",
96
+ "# finishes you'll see one of two final lines:\n",
97
+ "#\n",
98
+ "# RESTART REQUIRED -> Run -> Restart kernel, then resume from cell 5\n",
99
+ "# REPAIR OK -> proceed directly to cell 5\n",
100
+ "#\n",
101
+ "# Note: this cell only runs shell commands and one isolated subprocess.\n",
102
+ "# It deliberately does NOT `import torch / numpy / Pillow / unsloth` in the\n",
103
+ "# kernel, so re-running it after a botched install does not poison further\n",
104
+ "# attempts.\n",
105
+ "\n",
106
+ "print(\"=\" * 72); print(\" CELL 3 β€” REPAIR\"); print(\"=\" * 72)\n",
107
+ "\n",
108
+ "# 1. Clean caches (Kaggle's /kaggle/working is only 20 GB β€” installs\n",
109
+ "# routinely fill it after a few re-runs).\n",
110
+ "print(\"[1/6] Cleaning pip + tmp + HF dataset caches...\")\n",
111
+ "get_ipython().system('pip cache purge -q || true')\n",
112
+ "get_ipython().system('rm -rf /tmp/* /root/.cache/pip /root/.cache/huggingface/datasets 2>/dev/null || true')\n",
113
+ "\n",
114
+ "# 2. Pin torch + torchvision to the cu128 wheel (matches Kaggle's CUDA 12.8\n",
115
+ "# base image). DON'T let pip pull a generic CUDA-13 build β€” that breaks\n",
116
+ "# bitsandbytes (libnvJitLink.so.13 missing) and torchvision (CUDA-major\n",
117
+ "# mismatch RuntimeError at import time).\n",
118
+ "print(\"[2/6] Installing torch==2.10.0 + torchvision==0.25.0 (cu128)...\")\n",
119
+ "get_ipython().system('pip install -q --no-cache-dir --force-reinstall '\n",
120
+ " 'torch==2.10.0 torchvision==0.25.0 '\n",
121
+ " '--index-url https://download.pytorch.org/whl/cu128')\n",
122
+ "\n",
123
+ "# 3. Reinstall bitsandbytes against the now-pinned torch.\n",
124
+ "print(\"[3/6] Reinstalling bitsandbytes...\")\n",
125
+ "get_ipython().system('pip install -q --no-cache-dir --force-reinstall bitsandbytes')\n",
126
+ "\n",
127
+ "# 4. Upgrade unsloth + unsloth_zoo + trl in lockstep. unsloth and\n",
128
+ "# unsloth_zoo are released as a matched pair; if pip pulls a fresh\n",
129
+ "# unsloth_zoo against an old unsloth you get\n",
130
+ "# ImportError: cannot import name 'create_gradient_checkpointing_buffer'\n",
131
+ "print(\"[4/6] Upgrading unsloth + unsloth_zoo + trl...\")\n",
132
+ "get_ipython().system('pip install -q --upgrade --no-cache-dir '\n",
133
+ " 'unsloth unsloth_zoo \"trl>=0.18.2\"')\n",
134
+ "\n",
135
+ "# 5. ER-MAP runtime deps that aren't pre-installed on Kaggle.\n",
136
+ "print(\"[5/6] Installing ER-MAP runtime deps...\")\n",
137
+ "get_ipython().system('pip install -q --no-cache-dir '\n",
138
+ " '\"groq>=0.18.0\" \"huggingface_hub>=0.25.0\" '\n",
139
+ " '\"gymnasium>=0.29.0\" \"openenv-core>=0.1.0\"')\n",
140
+ "\n",
141
+ "# 6. Verify in a SUBPROCESS (so the parent kernel never imports any of these\n",
142
+ "# while pip is mid-flight, which is what causes the\n",
143
+ "# 'numpy was upgraded mid-session (loaded: X, installed: Y)' RuntimeError\n",
144
+ "# we kept hitting before).\n",
145
+ "print(\"[6/6] Verifying via subprocess...\")\n",
146
+ "import subprocess, sys, json\n",
147
+ "\n",
148
+ "verify_script = r'''\n",
149
+ "import json, sys\n",
150
+ "out = {\"ok\": True, \"details\": {}, \"errors\": []}\n",
151
+ "try:\n",
152
+ " import importlib.metadata as md\n",
153
+ " for pkg in (\"torch\", \"torchvision\", \"bitsandbytes\", \"unsloth\", \"unsloth_zoo\",\n",
154
+ " \"trl\", \"transformers\", \"peft\", \"accelerate\", \"groq\",\n",
155
+ " \"huggingface_hub\", \"gymnasium\", \"numpy\", \"Pillow\"):\n",
156
+ " try:\n",
157
+ " out[\"details\"][pkg + \"_installed\"] = md.version(pkg)\n",
158
+ " except md.PackageNotFoundError:\n",
159
+ " out[\"details\"][pkg + \"_installed\"] = None\n",
160
+ "\n",
161
+ " import torch, torchvision, numpy as np, PIL, unsloth, unsloth_zoo, bitsandbytes, trl\n",
162
+ " out[\"details\"][\"torch_loaded\"] = torch.__version__\n",
163
+ " out[\"details\"][\"torch_cuda\"] = torch.version.cuda\n",
164
+ " out[\"details\"][\"cuda_available\"] = bool(torch.cuda.is_available())\n",
165
+ " out[\"details\"][\"gpu_count\"] = int(torch.cuda.device_count())\n",
166
+ " out[\"details\"][\"torchvision_loaded\"] = torchvision.__version__\n",
167
+ " out[\"details\"][\"numpy_loaded\"] = np.__version__\n",
168
+ " out[\"details\"][\"pillow_loaded\"] = PIL.__version__\n",
169
+ " out[\"details\"][\"unsloth_loaded\"] = unsloth.__version__\n",
170
+ " out[\"details\"][\"unsloth_zoo_loaded\"] = unsloth_zoo.__version__\n",
171
+ " out[\"details\"][\"bitsandbytes_loaded\"] = bitsandbytes.__version__\n",
172
+ " out[\"details\"][\"trl_loaded\"] = trl.__version__\n",
173
+ "\n",
174
+ " # Cross-check loaded-vs-installed for the C-extension libs that bit us\n",
175
+ " # on every previous run.\n",
176
+ " for pkg, loaded_key, installed_key in [\n",
177
+ " (\"numpy\", \"numpy_loaded\", \"numpy_installed\"),\n",
178
+ " (\"Pillow\", \"pillow_loaded\", \"Pillow_installed\"),\n",
179
+ " (\"torch\", \"torch_loaded\", \"torch_installed\"),\n",
180
+ " ]:\n",
181
+ " loaded = out[\"details\"].get(loaded_key)\n",
182
+ " installed = out[\"details\"].get(installed_key)\n",
183
+ " if loaded and installed and loaded != installed:\n",
184
+ " # Strip any local-version suffix (e.g. '+cu128') before compare.\n",
185
+ " if loaded.split(\"+\")[0] != installed.split(\"+\")[0]:\n",
186
+ " out[\"errors\"].append(\n",
187
+ " f\"{pkg} mismatch: loaded={loaded} installed={installed}\"\n",
188
+ " )\n",
189
+ "except Exception as e:\n",
190
+ " out[\"ok\"] = False\n",
191
+ " out[\"errors\"].append(f\"{type(e).__name__}: {e}\")\n",
192
+ "print(json.dumps(out, default=str))\n",
193
+ "'''.lstrip()\n",
194
+ "\n",
195
+ "res = subprocess.run([sys.executable, \"-c\", verify_script],\n",
196
+ " capture_output=True, text=True, timeout=180)\n",
197
+ "print(res.stdout if res.stdout else \"<no stdout>\")\n",
198
+ "if res.stderr:\n",
199
+ " print(\"---- subprocess stderr ----\"); print(res.stderr)\n",
200
+ "\n",
201
+ "# Parse the LAST line of stdout (others are prints from package init).\n",
202
+ "try:\n",
203
+ " last = res.stdout.strip().splitlines()[-1]\n",
204
+ " parsed = json.loads(last)\n",
205
+ "except Exception:\n",
206
+ " parsed = {\"ok\": False, \"errors\": [\"could not parse verification output\"]}\n",
207
+ "\n",
208
+ "ok = parsed.get(\"ok\") and not parsed.get(\"errors\")\n",
209
+ "d = parsed.get(\"details\", {})\n",
210
+ "\n",
211
+ "print()\n",
212
+ "print(\"=\" * 72)\n",
213
+ "if ok:\n",
214
+ " print(\" REPAIR OK\")\n",
215
+ " print(f\" torch : {d.get('torch_loaded')} (CUDA {d.get('torch_cuda')})\")\n",
216
+ " print(f\" torchvision : {d.get('torchvision_loaded')}\")\n",
217
+ " print(f\" bitsandbytes: {d.get('bitsandbytes_loaded')}\")\n",
218
+ " print(f\" unsloth : {d.get('unsloth_loaded')} | unsloth_zoo: {d.get('unsloth_zoo_loaded')}\")\n",
219
+ " print(f\" trl : {d.get('trl_loaded')}\")\n",
220
+ " print(f\" numpy : {d.get('numpy_loaded')} | Pillow: {d.get('pillow_loaded')}\")\n",
221
+ " print(f\" GPUs : {d.get('gpu_count')} (cuda_available={d.get('cuda_available')})\")\n",
222
+ " print()\n",
223
+ " print(\" -> If this kernel previously imported torch/numpy/Pillow/unsloth,\")\n",
224
+ " print(\" RESTART NOW (Run -> Restart kernel) before continuing to cell 5.\")\n",
225
+ " print(\" If this is a fresh kernel, you can proceed directly.\")\n",
226
+ "else:\n",
227
+ " print(\" RESTART REQUIRED β€” issues detected:\")\n",
228
+ " for e in parsed.get(\"errors\", []):\n",
229
+ " print(f\" - {e}\")\n",
230
+ " print()\n",
231
+ " print(\" Action: Run -> Restart kernel, then re-run from cell 2.\")\n",
232
+ "print(\"=\" * 72)"
233
  ]
234
  },
235
  {
236
  "cell_type": "markdown",
237
  "metadata": {},
238
  "source": [
239
+ "## ⚠ Restart kernel here if cell 3 said `RESTART REQUIRED`\n",
240
+ "\n",
241
+ "Click **Run → Restart kernel** (or **Run → Restart & clear cell outputs**),\n",
242
+ "then resume from **cell 5**. Skipping the restart will produce ABI mismatch\n",
243
+ "errors at the first GPU op.\n",
244
  "\n",
245
+ "If cell 3 said `REPAIR OK` AND this is a fresh kernel that hasn't imported\n",
246
+ "torch/numpy/Pillow/unsloth yet, you can proceed to cell 5 directly."
247
  ]
248
  },
249
  {
 
252
  "metadata": {},
253
  "outputs": [],
254
  "source": [
255
+ "# === CELL 5 β€” Post-restart verify (this kernel can import everything) ===\n",
256
+ "import importlib.metadata as md\n",
257
+ "\n",
258
+ "print(\"--- Loaded versions in this kernel ---\")\n",
259
+ "import torch, numpy, PIL, torchvision, unsloth, unsloth_zoo, bitsandbytes, trl, transformers, peft\n",
260
+ "\n",
261
+ "versions = {\n",
262
+ " \"torch\": torch.__version__,\n",
263
+ " \"torchvision\": torchvision.__version__,\n",
264
+ " \"numpy\": numpy.__version__,\n",
265
+ " \"Pillow\": PIL.__version__,\n",
266
+ " \"unsloth\": unsloth.__version__,\n",
267
+ " \"unsloth_zoo\": unsloth_zoo.__version__,\n",
268
+ " \"bitsandbytes\": bitsandbytes.__version__,\n",
269
+ " \"trl\": trl.__version__,\n",
270
+ " \"transformers\": transformers.__version__,\n",
271
+ " \"peft\": peft.__version__,\n",
272
+ "}\n",
273
+ "all_ok = True\n",
274
+ "for k, v in versions.items():\n",
275
+ " try:\n",
276
+ " inst = md.version(k)\n",
277
+ " except md.PackageNotFoundError:\n",
278
+ " inst = \"(not installed)\"\n",
279
+ " # Tolerate local version suffixes like '+cu128'\n",
280
+ " flag = \"OK\" if inst.split(\"+\")[0] == v.split(\"+\")[0] else f\"MISMATCH (installed={inst})\"\n",
281
+ " if \"MISMATCH\" in flag:\n",
282
+ " all_ok = False\n",
283
+ " print(f\" {k:14s}: loaded={v:20s} [{flag}]\")\n",
284
+ "\n",
285
+ "print()\n",
286
+ "print(f\" CUDA available : {torch.cuda.is_available()}\")\n",
287
+ "print(f\" GPU count : {torch.cuda.device_count()}\")\n",
288
+ "if torch.cuda.is_available():\n",
289
+ " for i in range(torch.cuda.device_count()):\n",
290
+ " p = torch.cuda.get_device_properties(i)\n",
291
+ " print(f\" GPU {i} : {p.name} ({p.total_memory/1e9:.1f} GB)\")\n",
292
+ "\n",
293
+ "print()\n",
294
+ "print(\"OK\" if all_ok else \"NOT OK β€” re-run cell 3 and restart kernel.\")"
295
  ]
296
  },
297
  {
 
300
  "metadata": {},
301
  "outputs": [],
302
  "source": [
303
+ "# === CELL 6 β€” Mount the ER-MAP repo into /kaggle/working ===\n",
304
+ "import os, subprocess, sys\n",
305
+ "\n",
306
+ "# OPTION A: clone a public GitHub fork (preferred). Edit GIT_URL.\n",
307
+ "GIT_URL = \"https://github.com/<your-fork>/Meta_Finals.git\"\n",
308
+ "BRANCH = \"main\"\n",
309
+ "REPO_ROOT = \"/kaggle/working/Meta_Finals\"\n",
310
+ "\n",
311
+ "# OPTION B: Kaggle Dataset upload β€” set this if you uploaded the repo\n",
312
+ "# as a Kaggle Dataset named \"ermap-source\" (Add Data -> Upload).\n",
313
+ "DATASET_DIR = \"/kaggle/input/ermap-source\"\n",
314
+ "\n",
315
+ "if not os.path.isdir(f\"{REPO_ROOT}/ER_MAP\"):\n",
316
+ " if \"<your-fork>\" not in GIT_URL:\n",
317
+ " print(f\"Cloning {GIT_URL}@{BRANCH} -> {REPO_ROOT}...\")\n",
318
+ " out = subprocess.run(\n",
319
+ " [\"git\", \"clone\", \"--depth\", \"1\", \"-b\", BRANCH, GIT_URL, REPO_ROOT],\n",
320
+ " capture_output=True, text=True,\n",
321
+ " )\n",
322
+ " print(out.stdout); print(out.stderr)\n",
323
+ " elif os.path.isdir(DATASET_DIR):\n",
324
+ " print(f\"Copying {DATASET_DIR} -> {REPO_ROOT}...\")\n",
325
+ " import shutil\n",
326
+ " shutil.copytree(DATASET_DIR, REPO_ROOT, dirs_exist_ok=True)\n",
327
+ "\n",
328
+ "assert os.path.isdir(f\"{REPO_ROOT}/ER_MAP\"), (\n",
329
+ " \"Repo not found.\\n\"\n",
330
+ " \" - Edit GIT_URL above to your GitHub fork, OR\\n\"\n",
331
+ " \" - Upload the repo as a Kaggle Dataset named 'ermap-source' (Add Data -> Upload).\"\n",
332
  ")\n",
333
  "\n",
334
+ "sys.path.insert(0, REPO_ROOT)\n",
335
+ "sys.path.insert(0, f\"{REPO_ROOT}/kaggle\")\n",
336
+ "print(f\"OK. Repo at {REPO_ROOT}\")"
337
  ]
338
  },
339
  {
340
+ "cell_type": "code",
341
+ "execution_count": null,
342
  "metadata": {},
343
+ "outputs": [],
344
  "source": [
345
+ "# === CELL 7 β€” Wire Kaggle Secrets into env vars ===\n",
346
+ "import os\n",
347
+ "from kaggle_helpers import load_kaggle_secrets, kaggle_env_summary\n",
348
  "\n",
349
+ "load_kaggle_secrets()\n",
350
+ "kaggle_env_summary()\n",
351
+ "\n",
352
+ "# Hard fail if no Groq key β€” training would silently use mock LLMs.\n",
353
+ "assert any(os.environ.get(k) for k in (\n",
354
+ " \"GROQ_NURSE_API_KEY\", \"GROQ_PATIENT_API_KEY\",\n",
355
+ " \"GROQ_EMPATHY_JUDGE_API_KEY\", \"GROQ_MEDICAL_JUDGE_API_KEY\",\n",
356
+ " \"GROQ_API_KEY\",\n",
357
+ ")), (\"No Groq key found in Kaggle Secrets. \"\n",
358
+ " \"Add at least GROQ_NURSE_API_KEY in Add-ons -> Secrets.\")\n",
359
+ "print(\"OK β€” at least one Groq key is wired.\")"
360
  ]
361
  },
362
  {
 
365
  "metadata": {},
366
  "outputs": [],
367
  "source": [
368
+ "# === CELL 8 β€” Hugging Face Hub config (for checkpoint backup) ===\n",
369
+ "import os\n",
370
+ "from kaggle_helpers import push_checkpoint_to_hub, download_checkpoint_from_hub\n",
371
+ "\n",
372
+ "# EDIT the line below to your HF model id (e.g. \"udayd/ermap-doctor-lora\").\n",
373
+ "HF_PUSH_REPO = \"<your-username>/ermap-doctor-lora\"\n",
374
+ "# To resume from a previous run, paste the same repo id here. Empty = fresh.\n",
375
+ "HF_RESUME_REPO = \"\"\n",
376
  "\n",
377
  "RESUME_DIR = \"/kaggle/working/checkpoints/resume\"\n",
378
  "if HF_RESUME_REPO:\n",
379
  " download_checkpoint_from_hub(HF_RESUME_REPO, RESUME_DIR)\n",
380
+ " contents = os.listdir(RESUME_DIR) if os.path.isdir(RESUME_DIR) else []\n",
381
+ " print(f\"Resume dir: {contents or '(empty)'}\")\n",
382
+ "else:\n",
383
+ " print(\"Starting fresh β€” no resume.\")\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  "\n",
385
+ "if \"<your-username>\" in HF_PUSH_REPO:\n",
386
+ " print(\"\\nWARNING: HF_PUSH_REPO still has <your-username> placeholder.\")\n",
387
+ " print(\" Checkpoints will NOT be pushed to HF Hub.\")\n",
388
+ " print(\" Edit the cell above and re-run before training if you want backups.\")"
389
  ]
390
  },
391
  {
 
394
  "metadata": {},
395
  "outputs": [],
396
  "source": [
397
+ "# === CELL 9 β€” GRPO hyperparameters ===\n",
398
+ "import os\n",
399
+ "\n",
400
  "MODEL_NAME = \"unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit\"\n",
 
401
  "GROUP_SIZE = 2\n",
402
  "LEARNING_RATE = 5e-6\n",
403
  "KL_BETA = 0.04\n",
404
  "OUTPUT_DIR = \"/kaggle/working/er_map_grpo_checkpoints\"\n",
405
  "PUSH_EVERY_EPS = 20\n",
406
+ "USE_WANDB = False # WANDB conflicts with protobuf 7 on Kaggle base image\n",
407
+ "NUM_EPISODES = 200 # hard cap; early-stop usually finishes first\n",
408
+ "\n",
409
+ "# --- Per-phase reward thresholds (constant for this run) -------------------\n",
410
+ "# After every GRPO update we look at the last CONVERGENCE_WINDOW groups; if\n",
411
+ "# ALL of them belong to the same current phase AND each has\n",
412
+ "# rolling_avg_reward >= PHASE_REWARD_TARGETS[current_phase] AND\n",
413
+ "# rolling_win_rate >= PHASE_MIN_WIN_RATE, we either:\n",
414
+ "# - force-promote to the next phase (Phase 1 / Phase 2), OR\n",
415
  "# - terminate training (Phase 3).\n",
416
+ "EARLY_STOP_ENABLED = True\n",
417
+ "PHASE_REWARD_TARGETS = {1: 1.2, 2: 1.1, 3: 1.0}\n",
418
+ "PHASE_MIN_WIN_RATE = 0.20\n",
419
+ "CONVERGENCE_WINDOW = 3\n",
420
+ "\n",
421
+ "# --- Per-episode budget controls (read by triage_env) ----------------------\n",
 
 
 
422
  "os.environ[\"ERMAP_MAX_EPISODE_STEPS\"] = \"20\"\n",
423
  "os.environ[\"ERMAP_MAX_INTERNAL_EXCHANGES\"] = \"5\"\n",
424
+ "\n",
425
+ "# --- Groq traffic-shaping (8B for actors, 70B for judges) ------------------\n",
426
+ "# High-volume conversational roles (Nurse + Patient) on the 8B-instant pool\n",
427
+ "# (500K TPD, 14,400 RPD); the two judges stay on 70B-versatile because their\n",
428
+ "# grading quality directly shapes the reward signal.\n",
 
429
  "os.environ[\"ERMAP_NURSE_MODEL\"] = \"llama-3.1-8b-instant\"\n",
430
  "os.environ[\"ERMAP_PATIENT_MODEL\"] = \"llama-3.1-8b-instant\"\n",
431
  "os.environ[\"ERMAP_EMPATHY_JUDGE_MODEL\"] = \"llama-3.3-70b-versatile\"\n",
432
  "os.environ[\"ERMAP_MEDICAL_JUDGE_MODEL\"] = \"llama-3.3-70b-versatile\"\n",
433
  "\n",
434
+ "print(\"Hyperparameters set:\")\n",
435
+ "print(f\" NUM_EPISODES = {NUM_EPISODES}\")\n",
436
+ "print(f\" GROUP_SIZE = {GROUP_SIZE}\")\n",
437
+ "print(f\" PHASE_REWARD_TARGETS = {PHASE_REWARD_TARGETS}\")\n",
438
+ "print(f\" PHASE_MIN_WIN_RATE = {PHASE_MIN_WIN_RATE}\")\n",
439
+ "print(f\" CONVERGENCE_WINDOW = {CONVERGENCE_WINDOW}\")\n",
440
+ "print(f\" Nurse / Patient = llama-3.1-8b-instant (actors, high-volume)\")\n",
441
+ "print(f\" Empathy / Med Judge = llama-3.3-70b-versatile (graders, quality)\")"
 
 
 
 
442
  ]
443
  },
444
  {
445
+ "cell_type": "code",
446
+ "execution_count": null,
447
  "metadata": {},
448
+ "outputs": [],
449
  "source": [
450
+ "# === CELL 10 β€” Pre-flight: Groq routing + key liveness ===\n",
451
+ "# Verifies that:\n",
452
+ "# - each role is routed to the model you set in cell 9, and\n",
453
+ "# - each role's Groq key actually answers a 1-token \"PING\" prompt.\n",
454
+ "\n",
455
+ "import os\n",
456
+ "from ER_MAP.envs.api_router import AgentRouter\n",
457
+ "\n",
458
+ "router = AgentRouter()\n",
459
+ "expected = {\n",
460
+ " \"nurse\": \"llama-3.1-8b-instant\",\n",
461
+ " \"patient\": \"llama-3.1-8b-instant\",\n",
462
+ " \"empathy_judge\": \"llama-3.3-70b-versatile\",\n",
463
+ " \"medical_judge\": \"llama-3.3-70b-versatile\",\n",
464
+ "}\n",
465
+ "\n",
466
+ "print(\"=\" * 60); print(\" PRE-FLIGHT β€” Groq routing + smoke test\"); print(\"=\" * 60)\n",
467
+ "all_pass = True\n",
468
+ "for role, exp in expected.items():\n",
469
+ " actual = router._models.get(role, \"?\")\n",
470
+ " routing_ok = (actual == exp)\n",
471
+ " client = router._clients.get(role)\n",
472
+ "\n",
473
+ " if client is None:\n",
474
+ " print(f\" [SKIP] {role:14s} -> no Groq client (key missing)\")\n",
475
+ " all_pass = False\n",
476
+ " continue\n",
477
  "\n",
478
+ " try:\n",
479
+ " resp = client.chat.completions.create(\n",
480
+ " model=exp,\n",
481
+ " messages=[{\"role\": \"user\", \"content\": \"Reply with exactly: PING\"}],\n",
482
+ " max_tokens=4, temperature=0,\n",
483
+ " )\n",
484
+ " api_ok = \"PING\" in (resp.choices[0].message.content or \"\").upper()\n",
485
+ " err = \"\"\n",
486
+ " except Exception as e:\n",
487
+ " api_ok = False\n",
488
+ " err = f\" ({type(e).__name__}: {str(e)[:80]})\"\n",
489
+ "\n",
490
+ " flag = \"PASS\" if (routing_ok and api_ok) else \"FAIL\"\n",
491
+ " if flag == \"FAIL\":\n",
492
+ " all_pass = False\n",
493
+ " print(f\" [{flag}] {role:14s} -> {actual:30s} \"\n",
494
+ " f\"routing={'ok' if routing_ok else 'WRONG'}, \"\n",
495
+ " f\"api={'ok' if api_ok else 'fail'}{err}\")\n",
496
+ "\n",
497
+ "print(\"=\" * 60)\n",
498
+ "print(\"OK\" if all_pass else \"NOT OK β€” fix routing/keys before training.\")\n",
499
+ "print(\"=\" * 60)\n",
500
+ "assert all_pass, \"Pre-flight failed; do not proceed to training.\""
501
  ]
502
  },
503
  {
 
506
  "metadata": {},
507
  "outputs": [],
508
  "source": [
509
+ "# === CELL 11 β€” Dry-run smoke test (no GPU, no model load) ===\n",
510
+ "# Verifies the curriculum scheduler + reward verifier + per-phase early-stop\n",
511
+ "# wiring before we burn GPU minutes on the real run.\n",
512
+ "\n",
513
  "from ER_MAP.training.train_grpo import train\n",
514
  "\n",
515
  "_ = train(\n",
 
520
  " kl_beta=KL_BETA,\n",
521
  " output_dir=\"/kaggle/working/_dryrun\",\n",
522
  " dry_run=True,\n",
523
+ " phase_reward_targets=PHASE_REWARD_TARGETS,\n",
524
+ " phase_min_win_rate=PHASE_MIN_WIN_RATE,\n",
525
+ " convergence_window=CONVERGENCE_WINDOW,\n",
526
+ " early_stop=EARLY_STOP_ENABLED,\n",
527
  ")\n",
528
+ "print(\"\\nDry-run OK β€” scheduler + verifier + per-phase early-stop wiring is healthy.\")"
 
 
 
 
 
 
 
 
 
529
  ]
530
  },
531
  {
 
534
  "metadata": {},
535
  "outputs": [],
536
  "source": [
537
+ "# === CELL 12 β€” Wire periodic HF Hub push into training ===\n",
538
+ "# We monkey-patch save_lora_adapters so every checkpoint dump also pushes\n",
539
+ "# the LoRA adapter to HF Hub. Failures are non-fatal β€” training keeps\n",
540
+ "# running even if a push fails (e.g. transient HF 502).\n",
541
+ "\n",
542
  "from ER_MAP.training import train_grpo as _tg\n",
543
  "_original_save = _tg.save_lora_adapters\n",
 
544
  "\n",
545
  "def save_lora_adapters_with_push(model, tokenizer, output_dir):\n",
546
  " _original_save(model, tokenizer, output_dir)\n",
 
547
  " if HF_PUSH_REPO and \"<your-username>\" not in HF_PUSH_REPO:\n",
548
  " try:\n",
549
  " push_checkpoint_to_hub(\n",
550
+ " output_dir, HF_PUSH_REPO,\n",
 
551
  " commit_message=f\"checkpoint @ {os.path.basename(output_dir)}\",\n",
552
  " )\n",
553
  " except Exception as e:\n",
 
561
  "cell_type": "markdown",
562
  "metadata": {},
563
  "source": [
564
+ "## 13 · Run real training (the 4–6 hour cell)\n",
565
+ "\n",
566
+ "**Estimated wall-clock on Kaggle T4 ×2:**\n",
567
+ "\n",
568
+ "- ~3–5 min per episode (6–14 env steps × Doctor.generate + 4–8 × Groq calls)\n",
569
+ "- ~1–2 min amortized per GRPO update (G=2 trajectories × response-token log-probs)\n",
570
+ "- **Per-group ≈ 8–12 min** (2 episodes + 1 update)\n",
571
+ "\n",
572
+ "| Phase | Typical episodes to reach target | Wall-clock |\n",
573
+ "|---|---|---|\n",
574
+ "| 1 (target `+1.2` × 3) | 12 – 24 episodes (6 – 12 groups) | ~1.0 – 2.0 h |\n",
575
+ "| 2 (target `+1.1` × 3) | 16 – 32 episodes (8 – 16 groups) | ~1.5 – 2.5 h |\n",
576
+ "| 3 (target `+1.0` × 3) | 20 – 50 episodes (10 – 25 groups) | ~2.0 – 4.0 h |\n",
577
+ "| **Total** | 50 – 100 episodes | **~4.5 – 8.5 h** |\n",
578
+ "\n",
579
+ "If `NUM_EPISODES=200` is exhausted before Phase 3 converges, training\n",
580
+ "stops at the cap and the latest LoRA checkpoint is on HF Hub already\n",
581
+ "(we push every 20 episodes), so resume in a fresh session via\n",
582
+ "`HF_RESUME_REPO` in cell 8."
583
  ]
584
  },
585
  {
 
588
  "metadata": {},
589
  "outputs": [],
590
  "source": [
591
+ "# === CELL 13 β€” REAL TRAINING (4-6 h cell) ===\n",
592
  "metrics = train(\n",
593
  " num_episodes=NUM_EPISODES,\n",
594
  " group_size=GROUP_SIZE,\n",
595
  " model_name=MODEL_NAME,\n",
596
+ " groq_api_key=os.environ.get(\"GROQ_NURSE_API_KEY\", \"\")\n",
597
+ " or os.environ.get(\"GROQ_API_KEY\", \"\"),\n",
598
  " learning_rate=LEARNING_RATE,\n",
599
  " kl_beta=KL_BETA,\n",
600
  " use_wandb=USE_WANDB,\n",
601
  " output_dir=OUTPUT_DIR,\n",
602
  " dry_run=False,\n",
 
603
  " phase_reward_targets=PHASE_REWARD_TARGETS,\n",
604
  " phase_min_win_rate=PHASE_MIN_WIN_RATE,\n",
605
  " convergence_window=CONVERGENCE_WINDOW,\n",
606
  " early_stop=EARLY_STOP_ENABLED,\n",
607
+ ")\n",
608
+ "print(f\"\\nTraining returned {len(metrics)} metric records.\")"
 
 
 
 
 
 
 
 
609
  ]
610
  },
611
  {
 
614
  "metadata": {},
615
  "outputs": [],
616
  "source": [
617
+ "# === CELL 14 β€” Final push: adapters + merged fp16 ===\n",
618
  "FINAL_LORA_DIR = f\"{OUTPUT_DIR}/final_lora\"\n",
619
  "FINAL_MERGED_DIR = f\"{OUTPUT_DIR}/final_merged_fp16\"\n",
620
  "\n",
621
  "if HF_PUSH_REPO and \"<your-username>\" not in HF_PUSH_REPO:\n",
622
+ " push_checkpoint_to_hub(FINAL_LORA_DIR, HF_PUSH_REPO,\n",
623
+ " commit_message=\"final LoRA adapter\")\n",
 
 
624
  " if os.path.isdir(FINAL_MERGED_DIR):\n",
625
+ " push_checkpoint_to_hub(FINAL_MERGED_DIR, f\"{HF_PUSH_REPO}-merged\",\n",
626
+ " commit_message=\"final merged fp16\")\n",
627
+ " print(f\"Final checkpoints pushed: https://huggingface.co/{HF_PUSH_REPO}\")\n",
 
 
628
  "else:\n",
629
  " print(\"HF_PUSH_REPO not configured β€” skipping final push.\")"
630
  ]
 
633
  "cell_type": "markdown",
634
  "metadata": {},
635
  "source": [
636
+ "## 15 · Per-phase training graphs (one dashboard per curriculum phase)\n",
637
+ "\n",
638
+ "We render a 6-panel dashboard for **every phase that contains episodes**,\n",
639
+ "plus a cross-phase overview and a phase-comparison bar chart. All PNGs are\n",
640
+ "written to `er_map_grpo_checkpoints/plots/` and uploaded to HF Hub in the\n",
641
+ "next cell so they survive Kaggle session expiry.\n",
642
  "\n",
643
+ "Each per-phase dashboard contains:\n",
644
  "\n",
 
645
  "1. **Reward growth** β€” raw scatter + rolling mean (w=10) + verified rolling mean\n",
646
+ "2. **Rolling win rate** — w=20 win-rate evolution within the phase\n",
647
+ "3. **Outcome distribution over time** — stacked bars (WIN/PARTIAL/INCORRECT/AMA_LOSS/FATAL_LOSS)\n",
648
+ "4. **Reward components** — mean of each component (process / treatment / empathy / labs / etc.)\n",
649
+ "5. **GRPO update stats** — loss + KL divergence per group update\n",
650
  "6. **Episode length distribution** β€” histogram of step counts"
651
  ]
652
  },
 
656
  "metadata": {},
657
  "outputs": [],
658
  "source": [
659
+ "# === CELL 15 β€” Per-phase training dashboards ===\n",
660
  "from ER_MAP.plotting import plot_per_phase_dashboards\n",
661
  "from IPython.display import Image, display, Markdown\n",
662
  "\n",
 
666
  " output_dir=PLOTS_DIR,\n",
667
  ")\n",
668
  "\n",
669
+ "print(f\"Saved {len(written)} chart(s) to {PLOTS_DIR}:\")\n",
670
  "for name, path in written.items():\n",
671
  " size_kb = os.path.getsize(path) / 1024\n",
672
  " print(f\" {name:<28s} -> {path} ({size_kb:.0f} KB)\")\n",
673
  "\n",
674
+ "# Display each chart inline so the operator sees them without leaving Kaggle.\n",
675
+ "ordered = (sorted(k for k in written if k.startswith(\"phase\"))\n",
676
+ " + [\"all_phases_overview\", \"all_phases_comparison\"])\n",
677
+ "for key in ordered:\n",
 
 
 
 
678
  " if key not in written:\n",
679
  " continue\n",
680
  " display(Markdown(f\"### {key.replace('_', ' ').title()}\"))\n",
681
  " display(Image(filename=written[key]))"
682
  ]
683
  },
 
 
 
 
 
 
 
684
  {
685
  "cell_type": "code",
686
  "execution_count": null,
687
  "metadata": {},
688
  "outputs": [],
689
  "source": [
690
+ "# === CELL 16 β€” Push plots to HF Hub ===\n",
691
  "if HF_PUSH_REPO and \"<your-username>\" not in HF_PUSH_REPO:\n",
692
+ " push_checkpoint_to_hub(PLOTS_DIR, HF_PUSH_REPO,\n",
693
+ " commit_message=\"per-phase training plots\")\n",
694
+ " print(f\"Plots pushed: https://huggingface.co/{HF_PUSH_REPO}/tree/main\")\n",
 
695
  "else:\n",
696
+ " print(\"HF_PUSH_REPO not configured β€” plots stay only in /kaggle/working/.\")"
697
  ]
698
  },
699
  {
700
  "cell_type": "markdown",
701
  "metadata": {},
702
  "source": [
703
+ "## 17 · (Optional) Inference smoke-test on the trained model\n",
704
  "\n",
705
+ "Catches the classic 'merge path looked OK but the saved model emits garbage'\n",
706
+ "failure mode before the demo."
707
  ]
708
  },
709
  {
 
712
  "metadata": {},
713
  "outputs": [],
714
  "source": [
715
+ "# === CELL 17 β€” Inference smoke-test on the trained model ===\n",
716
  "from ER_MAP.training.train_grpo import generate_doctor_action, load_model_and_tokenizer\n",
717
  "from peft import PeftModel\n",
718
  "\n",
 
743
  },
744
  "nbformat": 4,
745
  "nbformat_minor": 5
746
+ }