Spaces:
Sleeping
chore(kaggle): rebuild notebook v3 + clean dev-scratch files
Browse files* New `kaggle/build_notebook.py` - single source of truth that regenerates
the Kaggle notebook + KAGGLE_QUICKSTART.md from scratch (run once whenever
the layout drifts).
* New `kaggle/KAGGLE_QUICKSTART.md` - concise step-by-step run order with
troubleshooting table for every dependency-hell symptom we hit.
* `kaggle/train_ermap_grpo_kaggle.ipynb` rebuilt as a clean 20-cell layout:
- new constant per-phase reward thresholds (P1=+1.2 / P2=+1.1 / P3=+1.0)
- idempotent REPAIR cell (pins torch 2.10 cu128 + bnb + unsloth/zoo + trl,
verifies in a subprocess so the kernel never gets poisoned mid-install)
- pre-flight Groq routing + PING smoke test using router._models / _clients
- explicit dry-run + HF-push hook + per-phase dashboards + final push.
* `train_grpo.py`: add optional `phase_episode_budgets` (fixed-budget
curriculum mode) alongside the existing reward-threshold early-stop.
Fully backward compatible (default behaviour unchanged); CLI flags
--phase{1,2,3}-budget.
* Remove dev-scratch files: `_smoke_dead_keys.py`, `ER_MAP/_verify.py`,
`ER_MAP/_replot.py`, `kaggle/KAGGLE.md` (replaced by KAGGLE_QUICKSTART).
Made-with: Cursor
- ER_MAP/_replot.py +0 -11
- ER_MAP/_verify.py +0 -15
- ER_MAP/training/train_grpo.py +94 -1
- _smoke_dead_keys.py +0 -333
- kaggle/KAGGLE.md +0 -265
- kaggle/KAGGLE_QUICKSTART.md +104 -0
- kaggle/build_notebook.py +880 -0
- kaggle/train_ermap_grpo_kaggle.ipynb +478 -271
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
"""Quick script to regenerate reward curve from saved eval_results.json"""
|
| 2 |
-
import json
|
| 3 |
-
import sys
|
| 4 |
-
sys.path.insert(0, ".")
|
| 5 |
-
from ER_MAP.evaluate import plot_reward_curve
|
| 6 |
-
|
| 7 |
-
with open("d:/Meta_Finals/ER_MAP/eval_results.json") as f:
|
| 8 |
-
results = json.load(f)
|
| 9 |
-
|
| 10 |
-
plot_reward_curve(results, "d:/Meta_Finals/ER_MAP/reward_curve.png")
|
| 11 |
-
print("Done!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
from ER_MAP.envs.randomizer import DISEASE_POOL, DIFFICULTY_TIERS, generate_ground_truth
|
| 2 |
-
|
| 3 |
-
print(f"=== {len(DISEASE_POOL)} DISEASES ===")
|
| 4 |
-
for d in DISEASE_POOL:
|
| 5 |
-
print(f" {d['true_disease']}")
|
| 6 |
-
|
| 7 |
-
print()
|
| 8 |
-
print("=== DIFFICULTY TIERS ===")
|
| 9 |
-
for tier in ["easy", "medium", "hard"]:
|
| 10 |
-
gt = generate_ground_truth(difficulty=tier)
|
| 11 |
-
p = gt["patient"]
|
| 12 |
-
print(f" {tier.upper():8s} | compliance: {p['compliance']:20s} | comm: {p['communication']:20s} | {gt['disease']['true_disease']}")
|
| 13 |
-
|
| 14 |
-
combos = 3 * 4 * 4 * 4 * 4 * 3 * 3 * 3 * 3 * 15
|
| 15 |
-
print(f"\nTotal unique scenario combinations: {combos:,}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -686,9 +686,47 @@ def train(
|
|
| 686 |
phase_min_win_rate: float = 0.20,
|
| 687 |
convergence_window: int = 3,
|
| 688 |
early_stop: bool = True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 689 |
):
|
| 690 |
if phase_reward_targets is None:
|
| 691 |
phase_reward_targets = {1: 1.5, 2: 1.2, 3: 1.0}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 692 |
"""
|
| 693 |
Main GRPO training loop with curriculum scheduling.
|
| 694 |
|
|
@@ -758,7 +796,17 @@ def train(
|
|
| 758 |
|
| 759 |
logger.info(f"\nStarting GRPO training for up to {num_episodes} episodes "
|
| 760 |
f"(={num_episodes // group_size} GRPO updates)")
|
| 761 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 762 |
logger.info(
|
| 763 |
f" Early-stop ON: per-phase reward thresholds (sustained for "
|
| 764 |
f"{convergence_window} groups, win-rate >= {phase_min_win_rate:.0%}):"
|
|
@@ -955,6 +1003,32 @@ def train(
|
|
| 955 |
f"Phase Episodes: {s['phase_episodes']}"
|
| 956 |
)
|
| 957 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 958 |
# --- Per-phase early-stop / promotion check ------------------------
|
| 959 |
# Maintain a buffer of the last `convergence_window` GRPO groups
|
| 960 |
# with their (phase, rolling_avg, rolling_win). When ALL N entries
|
|
@@ -1099,8 +1173,26 @@ if __name__ == "__main__":
|
|
| 1099 |
parser.add_argument("--no-early-stop", action="store_true",
|
| 1100 |
help="Disable early-stop (always run all configured episodes)")
|
| 1101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1102 |
args = parser.parse_args()
|
| 1103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1104 |
train(
|
| 1105 |
num_episodes=args.episodes,
|
| 1106 |
group_size=args.group_size,
|
|
@@ -1119,4 +1211,5 @@ if __name__ == "__main__":
|
|
| 1119 |
phase_min_win_rate=args.phase_min_win_rate,
|
| 1120 |
convergence_window=args.convergence_window,
|
| 1121 |
early_stop=not args.no_early_stop,
|
|
|
|
| 1122 |
)
|
|
|
|
| 686 |
phase_min_win_rate: float = 0.20,
|
| 687 |
convergence_window: int = 3,
|
| 688 |
early_stop: bool = True,
|
| 689 |
+
# ---------------- Fixed-budget curriculum (alternative mode) -----------
|
| 690 |
+
# When set, training advances phases at FIXED episode counts instead of
|
| 691 |
+
# via the reward-target early-stop. Useful when you want a clean
|
| 692 |
+
# reward-growth curve over a known wall-clock budget. Example:
|
| 693 |
+
# phase_episode_budgets = {1: 20, 2: 30, 3: 50} # 100 episodes total
|
| 694 |
+
# When this is provided, `early_stop` is forced to False (the reward
|
| 695 |
+
# thresholds become observational, logged for plots only) and
|
| 696 |
+
# `num_episodes` is auto-set to sum(phase_episode_budgets.values()) if
|
| 697 |
+
# the caller passed a smaller / inconsistent value.
|
| 698 |
+
phase_episode_budgets: Optional[Dict[int, int]] = None,
|
| 699 |
):
|
| 700 |
if phase_reward_targets is None:
|
| 701 |
phase_reward_targets = {1: 1.5, 2: 1.2, 3: 1.0}
|
| 702 |
+
|
| 703 |
+
# Fixed-budget mode overrides early-stop and aligns num_episodes.
|
| 704 |
+
fixed_budget_mode = phase_episode_budgets is not None and len(phase_episode_budgets) > 0
|
| 705 |
+
if fixed_budget_mode:
|
| 706 |
+
# Sanity: must have all 3 phases keyed, all positive ints
|
| 707 |
+
missing = [p for p in (1, 2, 3) if p not in phase_episode_budgets]
|
| 708 |
+
if missing:
|
| 709 |
+
raise ValueError(
|
| 710 |
+
f"phase_episode_budgets must include all phases (1,2,3); missing: {missing}"
|
| 711 |
+
)
|
| 712 |
+
for _p, _n in phase_episode_budgets.items():
|
| 713 |
+
if not isinstance(_n, int) or _n <= 0:
|
| 714 |
+
raise ValueError(
|
| 715 |
+
f"phase_episode_budgets[{_p}] must be a positive int, got {_n!r}"
|
| 716 |
+
)
|
| 717 |
+
budget_sum = sum(phase_episode_budgets.values())
|
| 718 |
+
if num_episodes != budget_sum:
|
| 719 |
+
logger.info(
|
| 720 |
+
f"[Fixed-budget] num_episodes ({num_episodes}) overridden to "
|
| 721 |
+
f"sum(phase_episode_budgets) = {budget_sum}"
|
| 722 |
+
)
|
| 723 |
+
num_episodes = budget_sum
|
| 724 |
+
if early_stop:
|
| 725 |
+
logger.info(
|
| 726 |
+
"[Fixed-budget] early_stop=True is incompatible with fixed budgets; "
|
| 727 |
+
"disabling early_stop. Reward targets will still be logged for plots."
|
| 728 |
+
)
|
| 729 |
+
early_stop = False
|
| 730 |
"""
|
| 731 |
Main GRPO training loop with curriculum scheduling.
|
| 732 |
|
|
|
|
| 796 |
|
| 797 |
logger.info(f"\nStarting GRPO training for up to {num_episodes} episodes "
|
| 798 |
f"(={num_episodes // group_size} GRPO updates)")
|
| 799 |
+
if fixed_budget_mode:
|
| 800 |
+
logger.info(
|
| 801 |
+
" Fixed-budget curriculum: phases advance at fixed episode counts."
|
| 802 |
+
)
|
| 803 |
+
for _pid in sorted(phase_episode_budgets.keys()):
|
| 804 |
+
logger.info(
|
| 805 |
+
f" Phase {_pid}: {phase_episode_budgets[_pid]} episodes "
|
| 806 |
+
f"(target avg-reward {phase_reward_targets.get(_pid, float('nan')):+.2f}, "
|
| 807 |
+
f"observational only)"
|
| 808 |
+
)
|
| 809 |
+
elif early_stop:
|
| 810 |
logger.info(
|
| 811 |
f" Early-stop ON: per-phase reward thresholds (sustained for "
|
| 812 |
f"{convergence_window} groups, win-rate >= {phase_min_win_rate:.0%}):"
|
|
|
|
| 1003 |
f"Phase Episodes: {s['phase_episodes']}"
|
| 1004 |
)
|
| 1005 |
|
| 1006 |
+
# --- Fixed-budget phase transition ---------------------------------
|
| 1007 |
+
# When the operator pre-allocates per-phase episode budgets (e.g.
|
| 1008 |
+
# P1=20, P2=30, P3=50), advance at the boundaries regardless of
|
| 1009 |
+
# reward. Phase 3 budget exhaustion lets the outer-loop
|
| 1010 |
+
# `num_episodes` cap end training naturally.
|
| 1011 |
+
if fixed_budget_mode:
|
| 1012 |
+
current_phase = s["phase"]
|
| 1013 |
+
budget = phase_episode_budgets.get(current_phase, 0)
|
| 1014 |
+
if (
|
| 1015 |
+
current_phase < 3
|
| 1016 |
+
and s["phase_episodes"] >= budget
|
| 1017 |
+
):
|
| 1018 |
+
promoted = scheduler.force_promote(
|
| 1019 |
+
reason=(
|
| 1020 |
+
f"fixed-budget: completed {s['phase_episodes']} episodes "
|
| 1021 |
+
f"in Phase {current_phase} (budget={budget})"
|
| 1022 |
+
)
|
| 1023 |
+
)
|
| 1024 |
+
if promoted:
|
| 1025 |
+
new_phase = scheduler.phase_id
|
| 1026 |
+
logger.info(
|
| 1027 |
+
f" [Fixed-budget] Phase {current_phase} budget exhausted "
|
| 1028 |
+
f"-> Phase {new_phase}: {scheduler.current_phase.name} "
|
| 1029 |
+
f"({phase_episode_budgets.get(new_phase, '?')} episodes allocated)"
|
| 1030 |
+
)
|
| 1031 |
+
|
| 1032 |
# --- Per-phase early-stop / promotion check ------------------------
|
| 1033 |
# Maintain a buffer of the last `convergence_window` GRPO groups
|
| 1034 |
# with their (phase, rolling_avg, rolling_win). When ALL N entries
|
|
|
|
| 1173 |
parser.add_argument("--no-early-stop", action="store_true",
|
| 1174 |
help="Disable early-stop (always run all configured episodes)")
|
| 1175 |
|
| 1176 |
+
# Fixed-budget curriculum (mutually exclusive with early-stop)
|
| 1177 |
+
parser.add_argument("--phase1-budget", type=int, default=None,
|
| 1178 |
+
help="Fixed episode budget for Phase 1 (Tool Mastery)")
|
| 1179 |
+
parser.add_argument("--phase2-budget", type=int, default=None,
|
| 1180 |
+
help="Fixed episode budget for Phase 2 (Clinical Reasoning)")
|
| 1181 |
+
parser.add_argument("--phase3-budget", type=int, default=None,
|
| 1182 |
+
help="Fixed episode budget for Phase 3 (Empathetic Negotiation)")
|
| 1183 |
+
|
| 1184 |
args = parser.parse_args()
|
| 1185 |
|
| 1186 |
+
_budgets = None
|
| 1187 |
+
if any(b is not None for b in (args.phase1_budget, args.phase2_budget, args.phase3_budget)):
|
| 1188 |
+
if not all(b is not None for b in (args.phase1_budget, args.phase2_budget, args.phase3_budget)):
|
| 1189 |
+
parser.error("--phase{1,2,3}-budget must all be set together")
|
| 1190 |
+
_budgets = {
|
| 1191 |
+
1: args.phase1_budget,
|
| 1192 |
+
2: args.phase2_budget,
|
| 1193 |
+
3: args.phase3_budget,
|
| 1194 |
+
}
|
| 1195 |
+
|
| 1196 |
train(
|
| 1197 |
num_episodes=args.episodes,
|
| 1198 |
group_size=args.group_size,
|
|
|
|
| 1211 |
phase_min_win_rate=args.phase_min_win_rate,
|
| 1212 |
convergence_window=args.convergence_window,
|
| 1213 |
early_stop=not args.no_early_stop,
|
| 1214 |
+
phase_episode_budgets=_budgets,
|
| 1215 |
)
|
|
@@ -1,333 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Smoke test: simulate the EXACT failure mode from the user's last log
|
| 3 |
-
(Patient + Nurse keys revoked, Doctor + Judge keys alive) and verify
|
| 4 |
-
that:
|
| 5 |
-
|
| 6 |
-
1. AgentRouter.query falls back to a live judge client
|
| 7 |
-
2. DoctorBrain's chain advances past the dead Doctor key (we'll also
|
| 8 |
-
simulate Doctor revocation)
|
| 9 |
-
3. TTS emotion adapter gets disabled after first 401 (no spam)
|
| 10 |
-
|
| 11 |
-
Runs from the repo root: ``python _smoke_dead_keys.py``
|
| 12 |
-
"""
|
| 13 |
-
from __future__ import annotations
|
| 14 |
-
|
| 15 |
-
import os
|
| 16 |
-
import sys
|
| 17 |
-
import importlib
|
| 18 |
-
|
| 19 |
-
REPO_ROOT = os.path.abspath(os.path.dirname(__file__))
|
| 20 |
-
if REPO_ROOT not in sys.path:
|
| 21 |
-
sys.path.insert(0, REPO_ROOT)
|
| 22 |
-
|
| 23 |
-
CHECKS_PASSED = 0
|
| 24 |
-
CHECKS_FAILED = 0
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def check(label, ok, detail=""):
|
| 28 |
-
global CHECKS_PASSED, CHECKS_FAILED
|
| 29 |
-
tag = "PASS" if ok else "FAIL"
|
| 30 |
-
if ok:
|
| 31 |
-
CHECKS_PASSED += 1
|
| 32 |
-
else:
|
| 33 |
-
CHECKS_FAILED += 1
|
| 34 |
-
line = f" [{tag}] {label}"
|
| 35 |
-
if detail:
|
| 36 |
-
line += f" -- {detail}"
|
| 37 |
-
print(line, flush=True)
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
# ---------------------------------------------------------------------------
|
| 41 |
-
# 1. AgentRouter fallback chain (Patient + Nurse dead, judges alive)
|
| 42 |
-
# ---------------------------------------------------------------------------
|
| 43 |
-
print("\n--- Test 1: AgentRouter.query with patient+nurse dead ---", flush=True)
|
| 44 |
-
|
| 45 |
-
# Inject env vars BEFORE importing dashboard so demo defaults still get set.
|
| 46 |
-
os.environ.setdefault("GROQ_DOCTOR_API_KEY", "gsk_dummy_doctor")
|
| 47 |
-
os.environ.setdefault("GROQ_NURSE_API_KEY", "gsk_dummy_nurse")
|
| 48 |
-
os.environ.setdefault("GROQ_PATIENT_API_KEY", "gsk_dummy_patient")
|
| 49 |
-
os.environ.setdefault("GROQ_EMPATHY_JUDGE_API_KEY", "gsk_dummy_judge")
|
| 50 |
-
os.environ.setdefault("GROQ_MEDICAL_JUDGE_API_KEY", "gsk_dummy_judge")
|
| 51 |
-
|
| 52 |
-
from ER_MAP.envs import api_router as _api_router_mod # noqa: E402
|
| 53 |
-
|
| 54 |
-
class _MockResp:
|
| 55 |
-
def __init__(self, content):
|
| 56 |
-
self.choices = [type("C", (), {"message": type("M", (), {"content": content})()})()]
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
class _MockClient:
|
| 60 |
-
"""Mock Groq client that either succeeds or raises a 401."""
|
| 61 |
-
|
| 62 |
-
def __init__(self, name, dead=False, payload='{"action":"speak","content":"OK"}'):
|
| 63 |
-
self.name = name
|
| 64 |
-
self.dead = dead
|
| 65 |
-
self.payload = payload
|
| 66 |
-
self.calls = 0
|
| 67 |
-
self.chat = type("Chat", (), {"completions": self})()
|
| 68 |
-
|
| 69 |
-
def create(self, **kw):
|
| 70 |
-
self.calls += 1
|
| 71 |
-
if self.dead:
|
| 72 |
-
raise Exception(
|
| 73 |
-
f"Error code: 401 - {{'error': {{'message': 'Invalid API Key', "
|
| 74 |
-
f"'type': 'invalid_request_error', 'code': 'invalid_api_key'}}}}"
|
| 75 |
-
)
|
| 76 |
-
return _MockResp(self.payload)
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
router = _api_router_mod.AgentRouter(
|
| 80 |
-
api_key="x",
|
| 81 |
-
nurse_api_key="x",
|
| 82 |
-
patient_api_key="x",
|
| 83 |
-
empathy_judge_api_key="x",
|
| 84 |
-
medical_judge_api_key="x",
|
| 85 |
-
)
|
| 86 |
-
|
| 87 |
-
# Patch all 4 clients with mocks: patient + nurse are dead, judges are alive.
|
| 88 |
-
mock_clients = {
|
| 89 |
-
"nurse": _MockClient("nurse", dead=True),
|
| 90 |
-
"patient": _MockClient("patient", dead=True),
|
| 91 |
-
"empathy_judge": _MockClient("empathy_judge", dead=False),
|
| 92 |
-
"medical_judge": _MockClient("medical_judge", dead=False),
|
| 93 |
-
}
|
| 94 |
-
router._clients = mock_clients
|
| 95 |
-
router._dead_clients = set() # let runtime detect deadness through the cascade
|
| 96 |
-
|
| 97 |
-
# Query as nurse β should walk nurse -> patient -> medical_judge and succeed.
|
| 98 |
-
result = router.query("nurse", "system", [{"role": "user", "content": "hi"}])
|
| 99 |
-
|
| 100 |
-
check(
|
| 101 |
-
"router.query('nurse') falls through to a live judge",
|
| 102 |
-
result.get("action") == "speak",
|
| 103 |
-
f"got {result}",
|
| 104 |
-
)
|
| 105 |
-
check(
|
| 106 |
-
"nurse client was attempted",
|
| 107 |
-
mock_clients["nurse"].calls == 1,
|
| 108 |
-
f"calls={mock_clients['nurse'].calls}",
|
| 109 |
-
)
|
| 110 |
-
check(
|
| 111 |
-
"patient client was attempted (next in chain)",
|
| 112 |
-
mock_clients["patient"].calls == 1,
|
| 113 |
-
f"calls={mock_clients['patient'].calls}",
|
| 114 |
-
)
|
| 115 |
-
check(
|
| 116 |
-
"medical_judge client served the request",
|
| 117 |
-
mock_clients["medical_judge"].calls == 1,
|
| 118 |
-
f"calls={mock_clients['medical_judge'].calls}",
|
| 119 |
-
)
|
| 120 |
-
check(
|
| 121 |
-
"empathy_judge NOT called once medical_judge succeeded",
|
| 122 |
-
mock_clients["empathy_judge"].calls == 0,
|
| 123 |
-
f"calls={mock_clients['empathy_judge'].calls}",
|
| 124 |
-
)
|
| 125 |
-
check(
|
| 126 |
-
"nurse marked dead in router state",
|
| 127 |
-
"nurse" in router._dead_clients,
|
| 128 |
-
)
|
| 129 |
-
check(
|
| 130 |
-
"patient marked dead in router state",
|
| 131 |
-
"patient" in router._dead_clients,
|
| 132 |
-
)
|
| 133 |
-
|
| 134 |
-
# Subsequent queries should skip dead clients entirely.
|
| 135 |
-
mock_clients["medical_judge"].calls = 0
|
| 136 |
-
mock_clients["empathy_judge"].calls = 0
|
| 137 |
-
mock_clients["nurse"].calls = 0
|
| 138 |
-
mock_clients["patient"].calls = 0
|
| 139 |
-
result2 = router.query("patient", "system", [{"role": "user", "content": "hi"}])
|
| 140 |
-
check(
|
| 141 |
-
"second query (patient) skips dead clients",
|
| 142 |
-
mock_clients["nurse"].calls == 0 and mock_clients["patient"].calls == 0,
|
| 143 |
-
)
|
| 144 |
-
check(
|
| 145 |
-
"second query reaches a live judge",
|
| 146 |
-
result2.get("action") == "speak",
|
| 147 |
-
f"got {result2}",
|
| 148 |
-
)
|
| 149 |
-
|
| 150 |
-
# ---------------------------------------------------------------------------
|
| 151 |
-
# 2. DoctorBrain key chain
|
| 152 |
-
# ---------------------------------------------------------------------------
|
| 153 |
-
print("\n--- Test 2: DoctorBrain with primary key dead, fallback alive ---", flush=True)
|
| 154 |
-
|
| 155 |
-
from ER_MAP import dashboard as _dash # noqa: E402
|
| 156 |
-
|
| 157 |
-
class _DoctorMockChat:
|
| 158 |
-
def __init__(self, owner):
|
| 159 |
-
self.owner = owner
|
| 160 |
-
self.completions = self
|
| 161 |
-
|
| 162 |
-
def create(self, **kw):
|
| 163 |
-
self.owner.calls += 1
|
| 164 |
-
if self.owner.dead:
|
| 165 |
-
raise Exception("Error code: 401 - {'error': {'code': 'invalid_api_key'}}")
|
| 166 |
-
return _MockResp('{"action":"read_soap","content":"check the chart first"}')
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
class _DoctorMockGroq:
|
| 170 |
-
def __init__(self, dead=False):
|
| 171 |
-
self.dead = dead
|
| 172 |
-
self.calls = 0
|
| 173 |
-
self.chat = _DoctorMockChat(self)
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
# Build a brain with 3 keys: primary dead, second dead, third alive.
|
| 177 |
-
brain = _dash.DoctorBrain(
|
| 178 |
-
api_key="key1",
|
| 179 |
-
fallback_api_keys=["key2", "key3"],
|
| 180 |
-
model="llama-3.1-8b-instant",
|
| 181 |
-
)
|
| 182 |
-
|
| 183 |
-
# Replace each entry's client with our mock.
|
| 184 |
-
brain._chain[0]["client"] = _DoctorMockGroq(dead=True) # key1 dead
|
| 185 |
-
brain._chain[1]["client"] = _DoctorMockGroq(dead=True) # key2 dead
|
| 186 |
-
brain._chain[2]["client"] = _DoctorMockGroq(dead=False) # key3 alive
|
| 187 |
-
|
| 188 |
-
reply = brain.decide("Patient is here. Vitals pending.")
|
| 189 |
-
check(
|
| 190 |
-
"DoctorBrain walks past 2 dead keys and uses the 3rd",
|
| 191 |
-
'"action":"read_soap"' in reply or "'action': 'read_soap'" in reply,
|
| 192 |
-
f"reply={reply[:120]}",
|
| 193 |
-
)
|
| 194 |
-
check(
|
| 195 |
-
"key1 marked dead",
|
| 196 |
-
brain._chain[0]["dead"] is True,
|
| 197 |
-
)
|
| 198 |
-
check(
|
| 199 |
-
"key2 marked dead",
|
| 200 |
-
brain._chain[1]["dead"] is True,
|
| 201 |
-
)
|
| 202 |
-
check(
|
| 203 |
-
"key3 still alive",
|
| 204 |
-
brain._chain[2]["dead"] is False,
|
| 205 |
-
)
|
| 206 |
-
check(
|
| 207 |
-
"key3 actually answered (call count)",
|
| 208 |
-
brain._chain[2]["client"].calls == 1,
|
| 209 |
-
f"calls={brain._chain[2]['client'].calls}",
|
| 210 |
-
)
|
| 211 |
-
|
| 212 |
-
# Second decide() should jump straight to key3 β no retries on the dead ones.
|
| 213 |
-
brain._chain[0]["client"].calls = 0
|
| 214 |
-
brain._chain[1]["client"].calls = 0
|
| 215 |
-
brain._chain[2]["client"].calls = 0
|
| 216 |
-
brain.decide("Now consider next step.")
|
| 217 |
-
check(
|
| 218 |
-
"second decide() skips dead keys (no extra calls on key1/key2)",
|
| 219 |
-
brain._chain[0]["client"].calls == 0 and brain._chain[1]["client"].calls == 0,
|
| 220 |
-
)
|
| 221 |
-
check(
|
| 222 |
-
"second decide() served by key3 again",
|
| 223 |
-
brain._chain[2]["client"].calls == 1,
|
| 224 |
-
)
|
| 225 |
-
|
| 226 |
-
# All 3 dead β falls back to _smart_fallback_action (no crash).
|
| 227 |
-
brain2 = _dash.DoctorBrain(
|
| 228 |
-
api_key="k1",
|
| 229 |
-
fallback_api_keys=["k2"],
|
| 230 |
-
model="llama-3.1-8b-instant",
|
| 231 |
-
)
|
| 232 |
-
brain2._chain[0]["client"] = _DoctorMockGroq(dead=True)
|
| 233 |
-
brain2._chain[1]["client"] = _DoctorMockGroq(dead=True)
|
| 234 |
-
reply3 = brain2.decide("Patient is here.")
|
| 235 |
-
check(
|
| 236 |
-
"all keys dead -> _smart_fallback_action returns valid JSON",
|
| 237 |
-
reply3.startswith("{") and ('"tool"' in reply3 or '"action"' in reply3),
|
| 238 |
-
f"reply={reply3[:120]}",
|
| 239 |
-
)
|
| 240 |
-
|
| 241 |
-
# ---------------------------------------------------------------------------
|
| 242 |
-
# 3. TTS emotion adapter auto-disable on 401
|
| 243 |
-
# ---------------------------------------------------------------------------
|
| 244 |
-
print("\n--- Test 3: TTS emotion adapter shuts down after first 401 ---", flush=True)
|
| 245 |
-
|
| 246 |
-
from ER_MAP import tts_engine as _tts # noqa: E402
|
| 247 |
-
|
| 248 |
-
# Make sure ElevenLabs is forced off so we don't hit network.
|
| 249 |
-
os.environ["ERMAP_DISABLE_ELEVENLABS"] = "1"
|
| 250 |
-
|
| 251 |
-
eng = _tts.TTSEngine(elevenlabs_api_key="", groq_api_key="dummy")
|
| 252 |
-
|
| 253 |
-
# Replace its Groq client with a mock that always raises 401.
|
| 254 |
-
class _AlwaysAuthFail:
|
| 255 |
-
def __init__(self):
|
| 256 |
-
self.calls = 0
|
| 257 |
-
self.chat = self
|
| 258 |
-
self.completions = self
|
| 259 |
-
|
| 260 |
-
def create(self, **kw):
|
| 261 |
-
self.calls += 1
|
| 262 |
-
raise Exception("Error code: 401 - {'error': {'code': 'invalid_api_key'}}")
|
| 263 |
-
|
| 264 |
-
mock_groq = _AlwaysAuthFail()
|
| 265 |
-
eng._groq_client = mock_groq
|
| 266 |
-
|
| 267 |
-
# Trigger the adapter: status helper should report auth=True the first time.
|
| 268 |
-
text1, auth1 = _tts._emotionalize_with_status(
|
| 269 |
-
"Patient please describe your symptoms in detail.",
|
| 270 |
-
"patient_anxious_panicked",
|
| 271 |
-
eng._groq_client,
|
| 272 |
-
eng._groq_model,
|
| 273 |
-
)
|
| 274 |
-
check(
|
| 275 |
-
"first call hits Groq and observes 401",
|
| 276 |
-
auth1 is True and mock_groq.calls == 1,
|
| 277 |
-
f"auth={auth1} calls={mock_groq.calls}",
|
| 278 |
-
)
|
| 279 |
-
check(
|
| 280 |
-
"first call still returns usable text via regex fallback",
|
| 281 |
-
isinstance(text1, str) and len(text1) > 5,
|
| 282 |
-
f"text={text1[:80]}",
|
| 283 |
-
)
|
| 284 |
-
|
| 285 |
-
# Simulate the engine setting its dead flag and verify subsequent passes
|
| 286 |
-
# never hit Groq again.
|
| 287 |
-
eng._emotion_adapter_dead = True
|
| 288 |
-
mock_groq.calls = 0
|
| 289 |
-
|
| 290 |
-
# Run the same code path the engine uses internally:
|
| 291 |
-
if eng._emotion_adapter_dead:
|
| 292 |
-
# Engine bypasses the LLM call entirely β no Groq invocation.
|
| 293 |
-
fallback_only = _tts._fallback_emotion_transform(
|
| 294 |
-
"Patient please describe your symptoms in detail.",
|
| 295 |
-
"patient_anxious_panicked",
|
| 296 |
-
)
|
| 297 |
-
fallback_calls = mock_groq.calls
|
| 298 |
-
else:
|
| 299 |
-
fallback_calls = -1
|
| 300 |
-
check(
|
| 301 |
-
"engine bypasses Groq once emotion adapter marked dead",
|
| 302 |
-
fallback_calls == 0,
|
| 303 |
-
f"calls after mark-dead={fallback_calls}",
|
| 304 |
-
)
|
| 305 |
-
check(
|
| 306 |
-
"regex fallback still produces speech",
|
| 307 |
-
isinstance(fallback_only, str) and len(fallback_only) > 5,
|
| 308 |
-
f"text={fallback_only[:80]}",
|
| 309 |
-
)
|
| 310 |
-
|
| 311 |
-
# ---------------------------------------------------------------------------
|
| 312 |
-
# 4. Health probe smoke (returns DEAD_AUTH on a junk key without crashing)
|
| 313 |
-
# ---------------------------------------------------------------------------
|
| 314 |
-
print("\n--- Test 4: _probe_groq_key handles an invalid key gracefully ---", flush=True)
|
| 315 |
-
|
| 316 |
-
status, detail = _dash._probe_groq_key("gsk_definitely_invalid_key", "llama-3.1-8b-instant", timeout_s=4.0)
|
| 317 |
-
check(
|
| 318 |
-
"probe returns DEAD_AUTH for invalid key",
|
| 319 |
-
status == "DEAD_AUTH",
|
| 320 |
-
f"status={status} detail={detail}",
|
| 321 |
-
)
|
| 322 |
-
|
| 323 |
-
status_missing, _ = _dash._probe_groq_key("", "llama-3.1-8b-instant")
|
| 324 |
-
check(
|
| 325 |
-
"probe returns MISSING for empty key (no network call)",
|
| 326 |
-
status_missing == "MISSING",
|
| 327 |
-
)
|
| 328 |
-
|
| 329 |
-
# ---------------------------------------------------------------------------
|
| 330 |
-
print("\n" + "=" * 60, flush=True)
|
| 331 |
-
print(f" RESULT: {CHECKS_PASSED} passed, {CHECKS_FAILED} failed", flush=True)
|
| 332 |
-
print("=" * 60, flush=True)
|
| 333 |
-
sys.exit(0 if CHECKS_FAILED == 0 else 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,265 +0,0 @@
|
|
| 1 |
-
# Training ER-MAP on Kaggle Free Tier
|
| 2 |
-
|
| 3 |
-
This guide walks you through training the ER-MAP **Doctor agent** with GRPO + 3-phase curriculum learning on Kaggle's free GPU tier β **zero dollars**, **30 GPU-hours/week**, **single Tesla T4 16 GB**.
|
| 4 |
-
|
| 5 |
-
## TL;DR β fastest path to a converged Doctor
|
| 6 |
-
|
| 7 |
-
1. **Fork** this repo on GitHub (it must be reachable from inside the Kaggle kernel).
|
| 8 |
-
2. Get **5 Groq API keys** from https://console.groq.com/keys (one per role gives you 5x the daily quota; you can also use one key for everything if you don't mind sharing the rate-limit budget).
|
| 9 |
-
3. Get one **HF write token** from https://huggingface.co/settings/tokens (fine-grained, scope: `write` to your own repos) β needed so checkpoints survive the 12-hour Kaggle session limit.
|
| 10 |
-
4. **New Notebook on Kaggle** β Settings β **Accelerator: GPU T4 x2** β **Internet: On**.
|
| 11 |
-
5. Add the secrets in the right sidebar (Add-ons β Secrets):
|
| 12 |
-
- `GROQ_NURSE_API_KEY`, `GROQ_PATIENT_API_KEY`, `GROQ_EMPATHY_JUDGE_API_KEY`, `GROQ_MEDICAL_JUDGE_API_KEY`
|
| 13 |
-
- `HF_TOKEN`
|
| 14 |
-
- *(optional)* `WANDB_API_KEY`
|
| 15 |
-
6. Open `kaggle/train_ermap_grpo_kaggle.ipynb` from your fork inside Kaggle (File β Import Notebook β URL).
|
| 16 |
-
7. Edit the two URLs in cell 2 (`GIT_URL`) and cell 5 (`HF_PUSH_REPO`) to your fork / username.
|
| 17 |
-
8. **Run All**.
|
| 18 |
-
|
| 19 |
-
Training **stops automatically** the instant the Doctor sustains a phase-specific reward bar for **3 consecutive GRPO groups** β `+1.5` in Phase 1 (force-promote), `+1.2` in Phase 2 (force-promote), `+1.0` in Phase 3 (END). This is the "train until optimal rewards are constantly received" guarantee β see the *Train-until-optimal* section below. `NUM_EPISODES=120` is just a hard cap; healthy runs converge between episodes 70-130 (~6-11 h on T4 Γ2).
|
| 20 |
-
|
| 21 |
-
You'll see one full 6-panel dashboard PNG **per curriculum phase** land in `/kaggle/working/er_map_grpo_checkpoints/plots/` after training finishes (`phase1_dashboard.png`, `phase2_dashboard.png`, `phase3_dashboard.png`, plus `all_phases_overview.png` and `all_phases_comparison.png`), and your final LoRA adapter will be sitting on Hugging Face Hub at `<you>/ermap-doctor-lora`.
|
| 22 |
-
|
| 23 |
-
**What each per-phase dashboard shows:**
|
| 24 |
-
|
| 25 |
-
| Panel | What it tells you |
|
| 26 |
-
| --- | --- |
|
| 27 |
-
| Reward growth | raw episode reward + rolling mean + verified rolling mean |
|
| 28 |
-
| Rolling win rate (w=20) | did the policy actually get better in this phase? |
|
| 29 |
-
| Outcome distribution over time | stacked WIN/PARTIAL/INCORRECT/AMA_LOSS/FATAL_LOSS bars per ~5-episode bin |
|
| 30 |
-
| Reward components | mean of every reward component (process / treatment / empathy / labs / etc.) |
|
| 31 |
-
| GRPO update stats | per-group loss + KL β should *not* explode |
|
| 32 |
-
| Episode length | histogram of step counts β should rise from Phase 1 to Phase 3 |
|
| 33 |
-
|
| 34 |
-
---
|
| 35 |
-
|
| 36 |
-
## Hardware feasibility
|
| 37 |
-
|
| 38 |
-
| Resource | Kaggle Free Tier | What we use | Headroom |
|
| 39 |
-
|---|---|---|---|
|
| 40 |
-
| GPU | Tesla T4 16 GB | Llama-3.1-8B-4bit + LoRA(r=16) β 7-9 GB | ~50% free |
|
| 41 |
-
| RAM | 13 GB system | base model + tokenizer + buffers β 5 GB | OK |
|
| 42 |
-
| Disk | 73 GB | repo + checkpoints + cache β 10 GB | OK |
|
| 43 |
-
| Session | 12 h max | typical full Phase-1+early-Phase-2 = 6-8 h | OK |
|
| 44 |
-
| Weekly | 30 GPU-h | one full curriculum run + a re-run = ~15-20 h | OK |
|
| 45 |
-
| Internet | allowed | Groq calls per env step | OK |
|
| 46 |
-
|
| 47 |
-
**Why Llama-3.1-8B over Qwen-3-4B (the other train_grpo.py default)?**
|
| 48 |
-
- 8B reasons noticeably better on multi-turn clinical dialogue
|
| 49 |
-
- 4-bit quant brings it to 5 GB β still fits on T4 with LoRA
|
| 50 |
-
- Groq hosts the same 8B (8B-instant) so the deployed inference path matches the training distribution exactly
|
| 51 |
-
|
| 52 |
-
If you ever need to fall back to a smaller model (e.g. for a P100 session), edit `MODEL_NAME` in cell 5 of the notebook to `unsloth/Qwen2.5-3B-Instruct-bnb-4bit`. Everything else stays the same.
|
| 53 |
-
|
| 54 |
-
---
|
| 55 |
-
|
| 56 |
-
## Two ways to get the source onto Kaggle
|
| 57 |
-
|
| 58 |
-
### Option A β public GitHub fork (recommended)
|
| 59 |
-
|
| 60 |
-
In cell 2 of the notebook:
|
| 61 |
-
```python
|
| 62 |
-
GIT_URL = "https://github.com/YOUR_USERNAME/Meta_Finals.git"
|
| 63 |
-
BRANCH = "main"
|
| 64 |
-
```
|
| 65 |
-
The cell does a shallow clone into `/kaggle/working/Meta_Finals` and you're done.
|
| 66 |
-
|
| 67 |
-
### Option B β upload as a Kaggle Dataset (no GitHub needed)
|
| 68 |
-
|
| 69 |
-
1. Locally:
|
| 70 |
-
```bash
|
| 71 |
-
cd D:/Meta_Finals
|
| 72 |
-
# Exclude heavy/regenerable folders before zipping.
|
| 73 |
-
tar --exclude='.git' --exclude='__pycache__' --exclude='*.ipynb_checkpoints' \
|
| 74 |
-
--exclude='er_map_grpo_checkpoints' \
|
| 75 |
-
-czf ermap-source.tar.gz .
|
| 76 |
-
```
|
| 77 |
-
2. Kaggle β **Datasets** β **New Dataset** β upload `ermap-source.tar.gz` β name it **`ermap-source`** β save.
|
| 78 |
-
3. In your training notebook β right sidebar β **+ Add Data** β search for **`ermap-source`** β Add.
|
| 79 |
-
4. Cell 2 of the notebook detects `/kaggle/input/ermap-source/` and copies it into `/kaggle/working/Meta_Finals` automatically.
|
| 80 |
-
|
| 81 |
-
Use Option B when:
|
| 82 |
-
- Your fork is private and you don't want to expose the repo
|
| 83 |
-
- You have local edits not yet pushed
|
| 84 |
-
- Bandwidth from Kaggle to GitHub is flaky
|
| 85 |
-
|
| 86 |
-
---
|
| 87 |
-
|
| 88 |
-
## What happens when the 12-hour session ends mid-training
|
| 89 |
-
|
| 90 |
-
Without intervention you'd lose everything. The notebook prevents this:
|
| 91 |
-
|
| 92 |
-
1. **Periodic HF Hub push.** Cell 7 monkey-patches `save_lora_adapters()` so every checkpoint saved by the GRPO loop also pushes to your `HF_PUSH_REPO`. The training loop checkpoints every `group_size Γ 5` episodes (so every 10 episodes when `GROUP_SIZE=2`).
|
| 93 |
-
2. **Resume on the next session.** Set `HF_RESUME_REPO` in cell 4 of the notebook on the *new* Kaggle session. The latest LoRA adapter is downloaded to `/kaggle/working/checkpoints/resume/` before training starts β but **the current `train_grpo.py` doesn't auto-load this folder yet**; for now use it as a manual recovery (load the adapter and continue training in code). A future PR will wire the auto-resume into `load_model_and_tokenizer`.
|
| 94 |
-
|
| 95 |
-
In practice: a single 12-hour session is usually enough to clear Phase 1 and produce publishable per-phase dashboards, so resume is the safety net rather than the main path.
|
| 96 |
-
|
| 97 |
-
> **Re-render plots from any saved metrics file** (locally or in another Kaggle session):
|
| 98 |
-
> ```bash
|
| 99 |
-
> python -m ER_MAP.plotting \
|
| 100 |
-
> --metrics er_map_grpo_checkpoints/training_metrics.json \
|
| 101 |
-
> --out er_map_grpo_checkpoints/plots
|
| 102 |
-
> ```
|
| 103 |
-
> This is the same call the notebook makes β handy if you want to regenerate the charts after training, or restyle them without re-running training.
|
| 104 |
-
|
| 105 |
-
---
|
| 106 |
-
|
| 107 |
-
## Per-role Groq keys vs. one shared key
|
| 108 |
-
|
| 109 |
-
The dashboard ships with 4 distinct Groq clients (Nurse, Patient, Empathy Judge, Medical Judge) and a fallback chain that walks across all four if any fails auth. Per-key budgets are *shared* on Groq's free tier (limits are per-account, not per-key) β but the model split below buys you real headroom because **each model has its own daily pool**.
|
| 110 |
-
|
| 111 |
-
### Default model assignment (traffic-shaping)
|
| 112 |
-
|
| 113 |
-
| Role | Model | Free-tier pool | Why |
|
| 114 |
-
|---|---|---|---|
|
| 115 |
-
| Nurse | `llama-3.1-8b-instant` | 14 400 RPD / 500K TPD | high-volume (every env step) |
|
| 116 |
-
| Patient | `llama-3.1-8b-instant` | shared 8B pool | high-volume (every env step) |
|
| 117 |
-
| Empathy Judge | `llama-3.3-70b-versatile` | 1 000 RPD / 100K TPD | grading quality directly shapes reward |
|
| 118 |
-
| Medical Judge | `llama-3.3-70b-versatile` | shared 70B pool | grading quality directly shapes reward |
|
| 119 |
-
|
| 120 |
-
Quick budget check for **one full 120-episode training run**:
|
| 121 |
-
|
| 122 |
-
| Pool | Estimated calls/run | Daily ceiling | Headroom |
|
| 123 |
-
|---|---|---|---|
|
| 124 |
-
| 8B-instant (Nurse + Patient) | ~2 880 | 14 400 RPD | ~5x |
|
| 125 |
-
| 70B-versatile (judges) | ~720 | 1 000 RPD | ~1.4x |
|
| 126 |
-
|
| 127 |
-
You can do **one training run per day per account** comfortably. If you need to retry inside the same day, drop one of the two judges to 8B-instant temporarily β the reward signal degrades a little, but training keeps moving.
|
| 128 |
-
|
| 129 |
-
If you only have **one** Groq key total, set just `GROQ_API_KEY` as a Kaggle Secret. Everything still works β the AgentRouter falls back to the same client for all roles, and the per-model budgets still split traffic across pools.
|
| 130 |
-
|
| 131 |
-
---
|
| 132 |
-
|
| 133 |
-
## What the reward-growth curve should look like
|
| 134 |
-
|
| 135 |
-
If training is healthy, after ~80 episodes you should see:
|
| 136 |
-
|
| 137 |
-
- **Rolling avg reward** climbs from β -0.4 (random baseline) toward +1.5+ (the early-stop target)
|
| 138 |
-
- **Rolling win rate** climbs from ~10% to 40%+
|
| 139 |
-
- A **vertical red dashed line** marks the Phase 1 β Phase 2 promotion (typically episode 30-60), and a second one marks Phase 2 β Phase 3 (typically episode 60-90)
|
| 140 |
-
- KL divergence stays in `[0.005, 0.05]` β if it spikes above 0.5 the model is drifting (lower `LEARNING_RATE` and re-run)
|
| 141 |
-
|
| 142 |
-
If the curve is flat or trending down:
|
| 143 |
-
- Check that Groq is actually responding (look for `Groq API error` lines in the log)
|
| 144 |
-
- Check that `rewards.std()` is non-zero across the group (cell logs print `adv_std=`; if it's < 1e-6 GRPO skips the update)
|
| 145 |
-
- Drop `GROUP_SIZE` from 2 β 1? **Don't** β group size 1 = no advantage signal = no GRPO update. Keep G β₯ 2.
|
| 146 |
-
|
| 147 |
-
---
|
| 148 |
-
|
| 149 |
-
## Train-until-optimal β per-phase reward thresholds
|
| 150 |
-
|
| 151 |
-
> *"I want training until certain optimal rewards are constantly received."*
|
| 152 |
-
|
| 153 |
-
After every GRPO update the loop maintains a **rolling buffer of the last `CONVERGENCE_WINDOW=3` groups**. When all 3 entries are in the *same* current phase AND each has `rolling_avg_reward >= PHASE_REWARD_TARGETS[current_phase]`, the loop reacts:
|
| 154 |
-
|
| 155 |
-
| Current phase | When buffer qualifies | Effect |
|
| 156 |
-
|---|---|---|
|
| 157 |
-
| Phase 1 (Tool Mastery) | sustained `+1.5` for 3 groups | force-promote to Phase 2, clear buffer |
|
| 158 |
-
| Phase 2 (Clinical Reasoning) | sustained `+1.2` for 3 groups | force-promote to Phase 3, clear buffer |
|
| 159 |
-
| Phase 3 (Empathetic Negotiation) | sustained `+1.0` for 3 groups | **END TRAINING** |
|
| 160 |
-
|
| 161 |
-
The buffer is cleared after each promotion so stale entries cannot pre-satisfy the next phase's bar. A soft `PHASE_MIN_WIN_RATE=0.20` floor prevents stopping on partial-credit-only runs.
|
| 162 |
-
|
| 163 |
-
If even one group in the window slips below the bar, the counter resets β guaranteeing the policy is *constantly* hitting the target, not transiently. `NUM_EPISODES` becomes a hard safety cap, not a fixed budget.
|
| 164 |
-
|
| 165 |
-
### Why the targets descend from `1.5` β `1.2` β `1.0`
|
| 166 |
-
|
| 167 |
-
The phases are not equally easy to score on. Looking at the reward function:
|
| 168 |
-
|
| 169 |
-
| Phase | Best-case reward (clean win) | Realistic clean-policy mean |
|
| 170 |
-
|---|---|---|
|
| 171 |
-
| 1 β easy patient, clean SOAP | `+2.0` (full terminal_win) | `+1.6 .. +1.8` |
|
| 172 |
-
| 2 β mixed compliance + noisy SOAP | `+1.7` (terminal_win - some lab noise) | `+1.2 .. +1.4` |
|
| 173 |
-
| 3 β full persona randomization + consent costs | `+1.4` (terminal_win - empathy/AMA penalties) | `+1.0 .. +1.2` |
|
| 174 |
-
|
| 175 |
-
So requiring `+1.5` in Phase 1 demonstrates real tool mastery (not just floor-grazing), while requiring `+1.0` in Phase 3 is genuinely hard β no random policy ever sustains it.
|
| 176 |
-
|
| 177 |
-
### Defaults (in the notebook, section 5)
|
| 178 |
-
|
| 179 |
-
```python
|
| 180 |
-
EARLY_STOP_ENABLED = True
|
| 181 |
-
PHASE_REWARD_TARGETS = {1: 1.5, 2: 1.2, 3: 1.0}
|
| 182 |
-
PHASE_MIN_WIN_RATE = 0.20 # soft floor
|
| 183 |
-
CONVERGENCE_WINDOW = 3
|
| 184 |
-
```
|
| 185 |
-
|
| 186 |
-
### Reading the per-phase telemetry
|
| 187 |
-
|
| 188 |
-
After every GRPO group the log prints:
|
| 189 |
-
|
| 190 |
-
```
|
| 191 |
-
[Scheduler] Phase 2 (Clinical Reasoning) | Win Rate: 42.0% | Avg Reward: +1.18 | Phase Episodes: 14
|
| 192 |
-
[EarlyStop] Phase 2 target avg-reward >= +1.20: qualified 2/3 recent groups (need all 3 -> promote)
|
| 193 |
-
```
|
| 194 |
-
|
| 195 |
-
When the Phase-2 buffer fills with 3 qualifying groups:
|
| 196 |
-
|
| 197 |
-
```
|
| 198 |
-
[Scheduler] force_promote() called: sustained rolling-avg-reward +1.20 for 3 consecutive groups in Phase 2
|
| 199 |
-
************************************************************
|
| 200 |
-
CURRICULUM PROMOTION: Clinical Reasoning -> Empathetic Negotiation
|
| 201 |
-
************************************************************
|
| 202 |
-
```
|
| 203 |
-
|
| 204 |
-
When Phase 3 finally converges:
|
| 205 |
-
|
| 206 |
-
```
|
| 207 |
-
************************************************************
|
| 208 |
-
EARLY STOP: Phase 3 convergence reached after 92 episodes
|
| 209 |
-
Last 3 groups all sustained:
|
| 210 |
-
rolling_avg_reward >= +1.00
|
| 211 |
-
rolling_win_rate >= 20%
|
| 212 |
-
in Phase 3 (Empathetic Negotiation)
|
| 213 |
-
************************************************************
|
| 214 |
-
```
|
| 215 |
-
|
| 216 |
-
β¦and the loop exits cleanly into the final-save / final-push / plotting cells.
|
| 217 |
-
|
| 218 |
-
### Per-phase wall-clock estimates on Kaggle T4 Γ2
|
| 219 |
-
|
| 220 |
-
| Phase | Typical episodes to hit target | Wall-clock | Why |
|
| 221 |
-
|---|---|---|---|
|
| 222 |
-
| 1 | 16 β 30 episodes (8 β 15 groups) | **~1.5 β 2.5 h** | Easy patients + clean SOAP; tool format is the only real lift. |
|
| 223 |
-
| 2 | 24 β 40 episodes (12 β 20 groups) | **~2.0 β 3.5 h** | Most policy improvement happens here. |
|
| 224 |
-
| 3 | 30 β 60 episodes (15 β 30 groups) | **~2.5 β 5.0 h** | Empathy + consent costs make `+1.0` genuinely hard. |
|
| 225 |
-
| **Total** | 70 β 130 episodes | **~6 β 11 h** | Fits the 12 h GPU session with ~1 h margin. |
|
| 226 |
-
|
| 227 |
-
Per-group wall-clock β 8 β 12 min on T4 (depending on episode length); per-episode β 3 β 5 min for env rollout + β 1 β 2 min amortized for the GRPO update.
|
| 228 |
-
|
| 229 |
-
### Tuning suggestions
|
| 230 |
-
|
| 231 |
-
| Goal | What to change |
|
| 232 |
-
|---|---|
|
| 233 |
-
| Smoke run (converge fast on a weak policy) | `PHASE_REWARD_TARGETS={1: 0.5, 2: 0.4, 3: 0.3}`, `CONVERGENCE_WINDOW=2` |
|
| 234 |
-
| Hackathon-grade Doctor | keep defaults |
|
| 235 |
-
| Aim for SOTA on this benchmark | `PHASE_REWARD_TARGETS={1: 1.7, 2: 1.5, 3: 1.3}`, `CONVERGENCE_WINDOW=5` |
|
| 236 |
-
| Disable entirely (run full 120 episodes regardless) | `EARLY_STOP_ENABLED=False` |
|
| 237 |
-
| Resuming from a partial run | targets are unchanged β the buffer is rebuilt from the new session's groups, so nothing weird happens |
|
| 238 |
-
|
| 239 |
-
### What NOT to do
|
| 240 |
-
|
| 241 |
-
- Don't set `CONVERGENCE_WINDOW=1`. A single lucky group can pass any bar; you'll promote out of Phase 1 the instant the first easy patient is correctly discharged.
|
| 242 |
-
- Don't lower the Phase-1 target below `+0.8`. The built-in scheduler already promotes Phase 1 β Phase 2 at `win_rate >= 40% AND avg_reward >= +0.3`, so a Phase-1 reward bar below that is dead code.
|
| 243 |
-
- Don't raise the Phase-3 target above `+1.5`. The reward ceiling on a Phase-3 episode (after empathy/consent costs) is around `+1.4 .. +1.6`; sustaining `+1.5+` for 3 consecutive groups is essentially unachievable in 12 h.
|
| 244 |
-
|
| 245 |
-
---
|
| 246 |
-
|
| 247 |
-
## Common Kaggle gotchas
|
| 248 |
-
|
| 249 |
-
| Symptom | Fix |
|
| 250 |
-
|---|---|
|
| 251 |
-
| `Groq API error: 401 invalid_api_key` | Regenerate the key (Groq auto-revokes keys posted publicly). Update the Kaggle Secret. |
|
| 252 |
-
| `OutOfMemoryError` on T4 | Drop `MAX_SEQ_LENGTH` from 2048 to 1536 inside `load_model_and_tokenizer`, or switch to `unsloth/Qwen2.5-3B-Instruct-bnb-4bit`. |
|
| 253 |
-
| `unsloth import` failed | Restart kernel after `pip install` β Unsloth pins `xformers` versions and the running kernel keeps the old import cached. |
|
| 254 |
-
| Checkpoints not appearing on HF Hub | Verify `HF_PUSH_REPO` doesn't still contain the `<your-username>` placeholder, and that `HF_TOKEN` has `write` scope. |
|
| 255 |
-
| "Internet off" warning | Right sidebar β Settings β toggle Internet to **On**. (Default is off for new accounts.) |
|
| 256 |
-
|
| 257 |
-
---
|
| 258 |
-
|
| 259 |
-
## Cost summary
|
| 260 |
-
|
| 261 |
-
- **Kaggle**: free
|
| 262 |
-
- **Groq API (training)**: free (within free-tier daily quotas, ~2 000 calls per full run)
|
| 263 |
-
- **Hugging Face Hub**: free for the LoRA adapter (~50 MB) + free for the merged fp16 (~16 GB on a public repo, free up to 1 TB total)
|
| 264 |
-
- **Wandb**: free for personal projects
|
| 265 |
-
- **Total**: $0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Kaggle Quickstart β ER-MAP GRPO Training (v3 stable)
|
| 2 |
+
|
| 3 |
+
The Kaggle notebook is in `kaggle/train_ermap_grpo_kaggle.ipynb`. This file
|
| 4 |
+
is the cheat sheet for running it end-to-end without the dependency hell
|
| 5 |
+
that bit us in earlier attempts.
|
| 6 |
+
|
| 7 |
+
## 0. Prerequisites (one-time)
|
| 8 |
+
|
| 9 |
+
1. **GitHub fork** of this repo. The notebook clones from a public fork at
|
| 10 |
+
cell 6 β edit `GIT_URL`. Alternatively, upload the repo as a Kaggle
|
| 11 |
+
Dataset named `ermap-source` (Add Data β Upload).
|
| 12 |
+
2. **Hugging Face write token** (`HF_TOKEN`) for pushing the trained
|
| 13 |
+
adapter. Create at https://huggingface.co/settings/tokens (fine-grained,
|
| 14 |
+
write access on a single model repo is enough).
|
| 15 |
+
3. **Five Groq keys** (one each for Nurse / Patient / Empathy Judge /
|
| 16 |
+
Medical Judge / shared fallback). Free-tier accounts are fine; the
|
| 17 |
+
per-account limits multiply across keys.
|
| 18 |
+
|
| 19 |
+
## 1. Create the Kaggle notebook
|
| 20 |
+
|
| 21 |
+
1. Sign in to https://www.kaggle.com/code β **New Notebook**.
|
| 22 |
+
2. Right sidebar:
|
| 23 |
+
- Accelerator: **GPU T4 Γ2** (or P100)
|
| 24 |
+
- Internet: **On**
|
| 25 |
+
- Persistence: Files only
|
| 26 |
+
3. **File β Upload Notebook** β choose `kaggle/train_ermap_grpo_kaggle.ipynb`
|
| 27 |
+
from this repo.
|
| 28 |
+
|
| 29 |
+
## 2. Add Kaggle Secrets
|
| 30 |
+
|
| 31 |
+
Add-ons β Secrets β Add a new secret. Required labels (exactly):
|
| 32 |
+
|
| 33 |
+
| Label | Value |
|
| 34 |
+
|---|---|
|
| 35 |
+
| `GROQ_NURSE_API_KEY` | your nurse Groq key |
|
| 36 |
+
| `GROQ_PATIENT_API_KEY` | your patient Groq key |
|
| 37 |
+
| `GROQ_EMPATHY_JUDGE_API_KEY` | your empathy-judge Groq key |
|
| 38 |
+
| `GROQ_MEDICAL_JUDGE_API_KEY` | your medical-judge Groq key |
|
| 39 |
+
| `HF_TOKEN` | your HF write token |
|
| 40 |
+
| `WANDB_API_KEY` *(optional)* | your W&B key (skip β disabled by default) |
|
| 41 |
+
|
| 42 |
+
The notebook reads them via `kaggle_helpers.load_kaggle_secrets()` and
|
| 43 |
+
exports them as env vars.
|
| 44 |
+
|
| 45 |
+
## 3. Edit two placeholders in the notebook
|
| 46 |
+
|
| 47 |
+
- **Cell 6:** `GIT_URL = "https://github.com/<your-fork>/Meta_Finals.git"`
|
| 48 |
+
- **Cell 8:** `HF_PUSH_REPO = "<your-username>/ermap-doctor-lora"`
|
| 49 |
+
|
| 50 |
+
If you uploaded the repo as a Kaggle Dataset instead, leave `GIT_URL` as the
|
| 51 |
+
placeholder β cell 6 will detect `/kaggle/input/ermap-source` and copy from
|
| 52 |
+
there.
|
| 53 |
+
|
| 54 |
+
## 4. Run order (the only sequence that works)
|
| 55 |
+
|
| 56 |
+
| Cell | What it does | Expected output |
|
| 57 |
+
|---|---|---|
|
| 58 |
+
| 2 | GPU + disk + python + internet sanity check | GPU listed, disk free > 8 GB |
|
| 59 |
+
| 3 | **REPAIR** β pin torch 2.10 cu128, reinstall bitsandbytes, upgrade unsloth | `REPAIR OK` (or `RESTART REQUIRED`) |
|
| 60 |
+
| **(restart)** | If cell 3 said RESTART REQUIRED β Run β Restart kernel | β |
|
| 61 |
+
| 5 | Post-restart import verify | All `OK`, GPUs listed |
|
| 62 |
+
| 6 | Clone / mount the repo | `OK. Repo at /kaggle/working/Meta_Finals` |
|
| 63 |
+
| 7 | Wire Kaggle Secrets β env vars | `OK β at least one Groq key is wired` |
|
| 64 |
+
| 8 | HF Hub config | `Starting fresh β no resume.` |
|
| 65 |
+
| 9 | Hyperparameters (P1=+1.2, P2=+1.1, P3=+1.0) | thresholds printed |
|
| 66 |
+
| 10 | **Pre-flight** β Groq routing + 4Γ PING | 4Γ `[PASS]`, then `OK` |
|
| 67 |
+
| 11 | Dry-run smoke test (no GPU) | `Dry-run OK` |
|
| 68 |
+
| 12 | Wire HF push hook | `Hub-push hook installed.` |
|
| 69 |
+
| 13 | **REAL TRAINING** (4β6 h) | per-group rolling stats, eventual `EARLY STOP` |
|
| 70 |
+
| 14 | Final push to HF | `Final checkpoints pushed: https://huggingface.co/...` |
|
| 71 |
+
| 15 | Per-phase plots | 5 PNGs displayed inline |
|
| 72 |
+
| 16 | Push plots to HF | `Plots pushed: ...` |
|
| 73 |
+
| 17 | Inference smoke-test (optional) | 3 sample Doctor actions printed |
|
| 74 |
+
|
| 75 |
+
## 5. Common failures & fixes
|
| 76 |
+
|
| 77 |
+
| Symptom | Root cause | Fix |
|
| 78 |
+
|---|---|---|
|
| 79 |
+
| `numpy was upgraded mid-session` | numpy import poisoned by a previous cell | Restart kernel, re-run from cell 3 |
|
| 80 |
+
| `Pillow incompatible with torchvision` | Pillow ABI mismatch | Restart kernel, re-run from cell 3 |
|
| 81 |
+
| `PyTorch and torchvision compiled with different CUDA major` | torch upgraded to cu13 by a transient resolve | Re-run cell 3 (it pins cu128) and restart |
|
| 82 |
+
| `cannot import name 'create_gradient_checkpointing_buffer'` | unsloth β unsloth_zoo version drift | Re-run cell 3 (upgrades both in lockstep) |
|
| 83 |
+
| `libnvJitLink.so.13 missing` | bitsandbytes built against different CUDA | Re-run cell 3 (force-reinstalls bitsandbytes after torch pin) |
|
| 84 |
+
| Disk usage > quota | Kaggle's 20 GB working partition fills up | First line of cell 3 cleans `/tmp` and pip cache |
|
| 85 |
+
| Pre-flight `[FAIL]` for a role | Groq key dead / quota exceeded | Generate a new key in console.groq.com β update Kaggle Secret β re-run cell 7+10 |
|
| 86 |
+
| `[FAIL]` says `routing=WRONG` | env var not set when `AgentRouter()` was constructed | Re-run cell 9 BEFORE cell 10 |
|
| 87 |
+
| Training freezes at episode 1 for >10 min | Doctor.generate hung; Unsloth import broke silently | Check cell 5 output for `unsloth` line; restart kernel and re-run cell 3 if missing |
|
| 88 |
+
|
| 89 |
+
## 6. What the trained model gives you
|
| 90 |
+
|
| 91 |
+
After cell 13 finishes (or hits the 12 h Kaggle session cap), you have:
|
| 92 |
+
|
| 93 |
+
- `OUTPUT_DIR/final_lora/` β LoRA adapter weights (~50 MB), pushed to
|
| 94 |
+
`HF_PUSH_REPO`
|
| 95 |
+
- `OUTPUT_DIR/final_merged_fp16/` β full Llama-3.1-8B fp16 merge with the
|
| 96 |
+
adapter applied (~16 GB), pushed to `HF_PUSH_REPO-merged`
|
| 97 |
+
- `OUTPUT_DIR/training_metrics.json` β per-episode rewards, outcomes,
|
| 98 |
+
rolling stats β input for the per-phase plots
|
| 99 |
+
- `OUTPUT_DIR/plots/*.png` β 5 dashboards (one per phase + cross-phase
|
| 100 |
+
overview + comparison bar)
|
| 101 |
+
|
| 102 |
+
Use the LoRA adapter for the demo (quick to load, runs on a 4050 6 GB at
|
| 103 |
+
~30 tok/s); use the merged fp16 if you need to host on a Vercel/HF Space
|
| 104 |
+
without `peft`.
|
|
@@ -0,0 +1,880 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
kaggle/build_notebook.py
|
| 3 |
+
========================
|
| 4 |
+
Programmatically (re)builds `train_ermap_grpo_kaggle.ipynb` from scratch.
|
| 5 |
+
|
| 6 |
+
Why a builder script?
|
| 7 |
+
--------------------
|
| 8 |
+
The hand-edited notebook drifted into a fragile state across many sessions:
|
| 9 |
+
mixed early-stop / fixed-budget params, stale install snippets, dead pre-flight
|
| 10 |
+
checks, etc. This script is the single source of truth β run it once and the
|
| 11 |
+
notebook is regenerated as a clean, deterministic v3 layout.
|
| 12 |
+
|
| 13 |
+
Run:
|
| 14 |
+
python kaggle/build_notebook.py
|
| 15 |
+
|
| 16 |
+
Output:
|
| 17 |
+
kaggle/train_ermap_grpo_kaggle.ipynb (overwritten)
|
| 18 |
+
kaggle/KAGGLE_QUICKSTART.md (overwritten)
|
| 19 |
+
"""
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import json
|
| 23 |
+
import textwrap
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# ---------------------------------------------------------------------------
|
| 28 |
+
# Cell helpers
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
|
| 31 |
+
def md_cell(text: str) -> dict:
|
| 32 |
+
return {
|
| 33 |
+
"cell_type": "markdown",
|
| 34 |
+
"metadata": {},
|
| 35 |
+
"source": _split_keep_newlines(text),
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def code_cell(text: str) -> dict:
|
| 40 |
+
return {
|
| 41 |
+
"cell_type": "code",
|
| 42 |
+
"execution_count": None,
|
| 43 |
+
"metadata": {},
|
| 44 |
+
"outputs": [],
|
| 45 |
+
"source": _split_keep_newlines(text),
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _split_keep_newlines(text: str) -> list[str]:
|
| 50 |
+
"""Notebook 'source' fields expect each line to terminate with '\n'
|
| 51 |
+
except the last one. Splitting like this keeps `git diff` clean when
|
| 52 |
+
the notebook is regenerated."""
|
| 53 |
+
text = textwrap.dedent(text).lstrip("\n")
|
| 54 |
+
if not text.endswith("\n"):
|
| 55 |
+
text = text + "\n"
|
| 56 |
+
lines = text.splitlines(keepends=True)
|
| 57 |
+
if lines:
|
| 58 |
+
# The last line should NOT have a trailing newline (Jupyter convention).
|
| 59 |
+
if lines[-1].endswith("\n"):
|
| 60 |
+
lines[-1] = lines[-1].rstrip("\n")
|
| 61 |
+
return lines
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# ---------------------------------------------------------------------------
|
| 65 |
+
# Cell sources
|
| 66 |
+
# ---------------------------------------------------------------------------
|
| 67 |
+
|
| 68 |
+
CELL_01_TITLE = """\
|
| 69 |
+
# ER-MAP β Doctor Agent GRPO Training (Kaggle Free-Tier Β· v3 stable)
|
| 70 |
+
|
| 71 |
+
Trains the **Doctor LLM** (Llama-3.1-8B-Instruct, 4-bit + LoRA r=16) via GRPO
|
| 72 |
+
with a 3-phase curriculum on Kaggle's free GPU. Designed to survive Kaggle's
|
| 73 |
+
pre-baked image quirks (numpy / Pillow ABI mismatches, torch + torchvision
|
| 74 |
+
CUDA-major mismatches, transient `unsloth_zoo` upgrades).
|
| 75 |
+
|
| 76 |
+
## TL;DR β How to run this notebook
|
| 77 |
+
|
| 78 |
+
1. **Notebook settings (right sidebar):**
|
| 79 |
+
- Accelerator: **GPU T4 Γ2** (or P100)
|
| 80 |
+
- Internet: **On**
|
| 81 |
+
- Persistence: Files only
|
| 82 |
+
2. **Kaggle Secrets** (Add-ons β Secrets):
|
| 83 |
+
- **Required:** `GROQ_NURSE_API_KEY`, `GROQ_PATIENT_API_KEY`,
|
| 84 |
+
`GROQ_EMPATHY_JUDGE_API_KEY`, `GROQ_MEDICAL_JUDGE_API_KEY`, `HF_TOKEN`
|
| 85 |
+
- **Optional:** `WANDB_API_KEY`
|
| 86 |
+
3. **Run cells 2 β 3 (sanity + REPAIR).** When cell 3 prints
|
| 87 |
+
`RESTART REQUIRED`, click **Run β Restart kernel**, then resume from cell 5.
|
| 88 |
+
4. **Run cells 5 β 11 (verify + configure + dry-run + pre-flight).** Each cell
|
| 89 |
+
should print an `OK` line before moving on.
|
| 90 |
+
5. **Run cell 13 (the long training cell, 4β6 hours).**
|
| 91 |
+
6. **Run cells 14 β 17 (final push + plots + inference smoke-test).**
|
| 92 |
+
|
| 93 |
+
## Curriculum + reward thresholds (this run)
|
| 94 |
+
|
| 95 |
+
Constant per-phase rolling-avg-reward bars; sustained for **3 consecutive
|
| 96 |
+
GRPO groups** triggers either a phase promotion or end-of-training.
|
| 97 |
+
|
| 98 |
+
| Phase | Reward target (sustained Γ3 groups) | Action when met |
|
| 99 |
+
|---|---|---|
|
| 100 |
+
| 1 β Tool Mastery | `+1.2` | force-promote to Phase 2 |
|
| 101 |
+
| 2 β Clinical Reasoning | `+1.1` | force-promote to Phase 3 |
|
| 102 |
+
| 3 β Empathetic Negotiation | `+1.0` | END TRAINING |
|
| 103 |
+
|
| 104 |
+
Why these numbers? The un-trained 8B Doctor's baseline on the same env is
|
| 105 |
+
`P1=+0.76, P2=+0.59, P3=+0.39`. Targets of `+1.2 / +1.1 / +1.0` correspond
|
| 106 |
+
to roughly `1.6Γ / 1.9Γ / 2.6Γ` improvement over baseline β a meaningful
|
| 107 |
+
signal but reachable inside Kaggle's 12 h session limit.
|
| 108 |
+
"""
|
| 109 |
+
|
| 110 |
+
CELL_02_SANITY = """\
|
| 111 |
+
# === CELL 2 β Sanity check (GPU + disk + python + internet) ===
|
| 112 |
+
# Run this FIRST. If any check fails, fix it before running the REPAIR cell.
|
| 113 |
+
|
| 114 |
+
import os, shutil, subprocess, sys, socket
|
| 115 |
+
|
| 116 |
+
print("--- GPU ---")
|
| 117 |
+
try:
|
| 118 |
+
print(subprocess.check_output(
|
| 119 |
+
["nvidia-smi", "--query-gpu=name,memory.total,memory.free", "--format=csv"],
|
| 120 |
+
timeout=10,
|
| 121 |
+
).decode())
|
| 122 |
+
except Exception as e:
|
| 123 |
+
print(f"nvidia-smi failed: {e}")
|
| 124 |
+
print("-> Set Accelerator to 'GPU T4 x2' in the right sidebar.")
|
| 125 |
+
|
| 126 |
+
print("--- Disk (/kaggle/working) ---")
|
| 127 |
+
total, used, free = shutil.disk_usage("/kaggle/working")
|
| 128 |
+
print(f" total={total/1e9:5.1f} GB | used={used/1e9:5.1f} GB | free={free/1e9:5.1f} GB")
|
| 129 |
+
if free < 8 * 1e9:
|
| 130 |
+
print(" WARNING: free disk < 8 GB β repair cell may fail. "
|
| 131 |
+
"Consider 'Run > Restart and clear cell outputs' to reset /tmp.")
|
| 132 |
+
|
| 133 |
+
print("--- Python ---")
|
| 134 |
+
print(f" python={sys.version.split()[0]} | exe={sys.executable}")
|
| 135 |
+
|
| 136 |
+
print("--- Internet (api.groq.com:443) ---")
|
| 137 |
+
try:
|
| 138 |
+
socket.create_connection(("api.groq.com", 443), timeout=5).close()
|
| 139 |
+
print(" reachable")
|
| 140 |
+
except Exception as e:
|
| 141 |
+
print(f" UNREACHABLE: {e}")
|
| 142 |
+
print(" -> Settings (right sidebar) -> Internet -> ON")
|
| 143 |
+
"""
|
| 144 |
+
|
| 145 |
+
CELL_03_REPAIR = """\
|
| 146 |
+
# === CELL 3 β REPAIR CELL (idempotent full environment rebuild) ===
|
| 147 |
+
# Single source of truth for ER-MAP's GPU stack. Safe to re-run. After it
|
| 148 |
+
# finishes you'll see one of two final lines:
|
| 149 |
+
#
|
| 150 |
+
# RESTART REQUIRED -> Run -> Restart kernel, then resume from cell 5
|
| 151 |
+
# REPAIR OK -> proceed directly to cell 5
|
| 152 |
+
#
|
| 153 |
+
# Note: this cell only runs shell commands and one isolated subprocess.
|
| 154 |
+
# It deliberately does NOT `import torch / numpy / Pillow / unsloth` in the
|
| 155 |
+
# kernel, so re-running it after a botched install does not poison further
|
| 156 |
+
# attempts.
|
| 157 |
+
|
| 158 |
+
print("=" * 72); print(" CELL 3 β REPAIR"); print("=" * 72)
|
| 159 |
+
|
| 160 |
+
# 1. Clean caches (Kaggle's /kaggle/working is only 20 GB β installs
|
| 161 |
+
# routinely fill it after a few re-runs).
|
| 162 |
+
print("[1/6] Cleaning pip + tmp + HF dataset caches...")
|
| 163 |
+
get_ipython().system('pip cache purge -q || true')
|
| 164 |
+
get_ipython().system('rm -rf /tmp/* /root/.cache/pip /root/.cache/huggingface/datasets 2>/dev/null || true')
|
| 165 |
+
|
| 166 |
+
# 2. Pin torch + torchvision to the cu128 wheel (matches Kaggle's CUDA 12.8
|
| 167 |
+
# base image). DON'T let pip pull a generic CUDA-13 build β that breaks
|
| 168 |
+
# bitsandbytes (libnvJitLink.so.13 missing) and torchvision (CUDA-major
|
| 169 |
+
# mismatch RuntimeError at import time).
|
| 170 |
+
print("[2/6] Installing torch==2.10.0 + torchvision==0.25.0 (cu128)...")
|
| 171 |
+
get_ipython().system('pip install -q --no-cache-dir --force-reinstall '
|
| 172 |
+
'torch==2.10.0 torchvision==0.25.0 '
|
| 173 |
+
'--index-url https://download.pytorch.org/whl/cu128')
|
| 174 |
+
|
| 175 |
+
# 3. Reinstall bitsandbytes against the now-pinned torch.
|
| 176 |
+
print("[3/6] Reinstalling bitsandbytes...")
|
| 177 |
+
get_ipython().system('pip install -q --no-cache-dir --force-reinstall bitsandbytes')
|
| 178 |
+
|
| 179 |
+
# 4. Upgrade unsloth + unsloth_zoo + trl in lockstep. unsloth and
|
| 180 |
+
# unsloth_zoo are released as a matched pair; if pip pulls a fresh
|
| 181 |
+
# unsloth_zoo against an old unsloth you get
|
| 182 |
+
# ImportError: cannot import name 'create_gradient_checkpointing_buffer'
|
| 183 |
+
print("[4/6] Upgrading unsloth + unsloth_zoo + trl...")
|
| 184 |
+
get_ipython().system('pip install -q --upgrade --no-cache-dir '
|
| 185 |
+
'unsloth unsloth_zoo "trl>=0.18.2"')
|
| 186 |
+
|
| 187 |
+
# 5. ER-MAP runtime deps that aren't pre-installed on Kaggle.
|
| 188 |
+
print("[5/6] Installing ER-MAP runtime deps...")
|
| 189 |
+
get_ipython().system('pip install -q --no-cache-dir '
|
| 190 |
+
'"groq>=0.18.0" "huggingface_hub>=0.25.0" '
|
| 191 |
+
'"gymnasium>=0.29.0" "openenv-core>=0.1.0"')
|
| 192 |
+
|
| 193 |
+
# 6. Verify in a SUBPROCESS (so the parent kernel never imports any of these
|
| 194 |
+
# while pip is mid-flight, which is what causes the
|
| 195 |
+
# 'numpy was upgraded mid-session (loaded: X, installed: Y)' RuntimeError
|
| 196 |
+
# we kept hitting before).
|
| 197 |
+
print("[6/6] Verifying via subprocess...")
|
| 198 |
+
import subprocess, sys, json
|
| 199 |
+
|
| 200 |
+
verify_script = r'''
|
| 201 |
+
import json, sys
|
| 202 |
+
out = {"ok": True, "details": {}, "errors": []}
|
| 203 |
+
try:
|
| 204 |
+
import importlib.metadata as md
|
| 205 |
+
for pkg in ("torch", "torchvision", "bitsandbytes", "unsloth", "unsloth_zoo",
|
| 206 |
+
"trl", "transformers", "peft", "accelerate", "groq",
|
| 207 |
+
"huggingface_hub", "gymnasium", "numpy", "Pillow"):
|
| 208 |
+
try:
|
| 209 |
+
out["details"][pkg + "_installed"] = md.version(pkg)
|
| 210 |
+
except md.PackageNotFoundError:
|
| 211 |
+
out["details"][pkg + "_installed"] = None
|
| 212 |
+
|
| 213 |
+
import torch, torchvision, numpy as np, PIL, unsloth, unsloth_zoo, bitsandbytes, trl
|
| 214 |
+
out["details"]["torch_loaded"] = torch.__version__
|
| 215 |
+
out["details"]["torch_cuda"] = torch.version.cuda
|
| 216 |
+
out["details"]["cuda_available"] = bool(torch.cuda.is_available())
|
| 217 |
+
out["details"]["gpu_count"] = int(torch.cuda.device_count())
|
| 218 |
+
out["details"]["torchvision_loaded"] = torchvision.__version__
|
| 219 |
+
out["details"]["numpy_loaded"] = np.__version__
|
| 220 |
+
out["details"]["pillow_loaded"] = PIL.__version__
|
| 221 |
+
out["details"]["unsloth_loaded"] = unsloth.__version__
|
| 222 |
+
out["details"]["unsloth_zoo_loaded"] = unsloth_zoo.__version__
|
| 223 |
+
out["details"]["bitsandbytes_loaded"] = bitsandbytes.__version__
|
| 224 |
+
out["details"]["trl_loaded"] = trl.__version__
|
| 225 |
+
|
| 226 |
+
# Cross-check loaded-vs-installed for the C-extension libs that bit us
|
| 227 |
+
# on every previous run.
|
| 228 |
+
for pkg, loaded_key, installed_key in [
|
| 229 |
+
("numpy", "numpy_loaded", "numpy_installed"),
|
| 230 |
+
("Pillow", "pillow_loaded", "Pillow_installed"),
|
| 231 |
+
("torch", "torch_loaded", "torch_installed"),
|
| 232 |
+
]:
|
| 233 |
+
loaded = out["details"].get(loaded_key)
|
| 234 |
+
installed = out["details"].get(installed_key)
|
| 235 |
+
if loaded and installed and loaded != installed:
|
| 236 |
+
# Strip any local-version suffix (e.g. '+cu128') before compare.
|
| 237 |
+
if loaded.split("+")[0] != installed.split("+")[0]:
|
| 238 |
+
out["errors"].append(
|
| 239 |
+
f"{pkg} mismatch: loaded={loaded} installed={installed}"
|
| 240 |
+
)
|
| 241 |
+
except Exception as e:
|
| 242 |
+
out["ok"] = False
|
| 243 |
+
out["errors"].append(f"{type(e).__name__}: {e}")
|
| 244 |
+
print(json.dumps(out, default=str))
|
| 245 |
+
'''.lstrip()
|
| 246 |
+
|
| 247 |
+
res = subprocess.run([sys.executable, "-c", verify_script],
|
| 248 |
+
capture_output=True, text=True, timeout=180)
|
| 249 |
+
print(res.stdout if res.stdout else "<no stdout>")
|
| 250 |
+
if res.stderr:
|
| 251 |
+
print("---- subprocess stderr ----"); print(res.stderr)
|
| 252 |
+
|
| 253 |
+
# Parse the LAST line of stdout (others are prints from package init).
|
| 254 |
+
try:
|
| 255 |
+
last = res.stdout.strip().splitlines()[-1]
|
| 256 |
+
parsed = json.loads(last)
|
| 257 |
+
except Exception:
|
| 258 |
+
parsed = {"ok": False, "errors": ["could not parse verification output"]}
|
| 259 |
+
|
| 260 |
+
ok = parsed.get("ok") and not parsed.get("errors")
|
| 261 |
+
d = parsed.get("details", {})
|
| 262 |
+
|
| 263 |
+
print("\n" + "=" * 72)
|
| 264 |
+
if ok:
|
| 265 |
+
print(" REPAIR OK")
|
| 266 |
+
print(f" torch : {d.get('torch_loaded')} (CUDA {d.get('torch_cuda')})")
|
| 267 |
+
print(f" torchvision : {d.get('torchvision_loaded')}")
|
| 268 |
+
print(f" bitsandbytes: {d.get('bitsandbytes_loaded')}")
|
| 269 |
+
print(f" unsloth : {d.get('unsloth_loaded')} | unsloth_zoo: {d.get('unsloth_zoo_loaded')}")
|
| 270 |
+
print(f" trl : {d.get('trl_loaded')}")
|
| 271 |
+
print(f" numpy : {d.get('numpy_loaded')} | Pillow: {d.get('pillow_loaded')}")
|
| 272 |
+
print(f" GPUs : {d.get('gpu_count')} (cuda_available={d.get('cuda_available')})")
|
| 273 |
+
print()
|
| 274 |
+
print(" -> If this kernel previously imported torch/numpy/Pillow/unsloth,")
|
| 275 |
+
print(" RESTART NOW (Run -> Restart kernel) before continuing to cell 5.")
|
| 276 |
+
print(" If this is a fresh kernel, you can proceed directly.")
|
| 277 |
+
else:
|
| 278 |
+
print(" RESTART REQUIRED β issues detected:")
|
| 279 |
+
for e in parsed.get("errors", []):
|
| 280 |
+
print(f" - {e}")
|
| 281 |
+
print()
|
| 282 |
+
print(" Action: Run -> Restart kernel, then re-run from cell 2.")
|
| 283 |
+
print("=" * 72)
|
| 284 |
+
"""
|
| 285 |
+
|
| 286 |
+
CELL_04_RESTART = """\
|
| 287 |
+
## β Restart kernel here if cell 3 said `RESTART REQUIRED`
|
| 288 |
+
|
| 289 |
+
Click **Run β Restart kernel** (or **Run β Restart & clear cell outputs**),
|
| 290 |
+
then resume from **cell 5**. Skipping the restart will produce ABI mismatch
|
| 291 |
+
errors at the first GPU op.
|
| 292 |
+
|
| 293 |
+
If cell 3 said `REPAIR OK` AND this is a fresh kernel that hasn't imported
|
| 294 |
+
torch/numpy/Pillow/unsloth yet, you can proceed to cell 5 directly.
|
| 295 |
+
"""
|
| 296 |
+
|
| 297 |
+
CELL_05_VERIFY = """\
|
| 298 |
+
# === CELL 5 β Post-restart verify (this kernel can import everything) ===
|
| 299 |
+
import importlib.metadata as md
|
| 300 |
+
|
| 301 |
+
print("--- Loaded versions in this kernel ---")
|
| 302 |
+
import torch, numpy, PIL, torchvision, unsloth, unsloth_zoo, bitsandbytes, trl, transformers, peft
|
| 303 |
+
|
| 304 |
+
versions = {
|
| 305 |
+
"torch": torch.__version__,
|
| 306 |
+
"torchvision": torchvision.__version__,
|
| 307 |
+
"numpy": numpy.__version__,
|
| 308 |
+
"Pillow": PIL.__version__,
|
| 309 |
+
"unsloth": unsloth.__version__,
|
| 310 |
+
"unsloth_zoo": unsloth_zoo.__version__,
|
| 311 |
+
"bitsandbytes": bitsandbytes.__version__,
|
| 312 |
+
"trl": trl.__version__,
|
| 313 |
+
"transformers": transformers.__version__,
|
| 314 |
+
"peft": peft.__version__,
|
| 315 |
+
}
|
| 316 |
+
all_ok = True
|
| 317 |
+
for k, v in versions.items():
|
| 318 |
+
try:
|
| 319 |
+
inst = md.version(k)
|
| 320 |
+
except md.PackageNotFoundError:
|
| 321 |
+
inst = "(not installed)"
|
| 322 |
+
# Tolerate local version suffixes like '+cu128'
|
| 323 |
+
flag = "OK" if inst.split("+")[0] == v.split("+")[0] else f"MISMATCH (installed={inst})"
|
| 324 |
+
if "MISMATCH" in flag:
|
| 325 |
+
all_ok = False
|
| 326 |
+
print(f" {k:14s}: loaded={v:20s} [{flag}]")
|
| 327 |
+
|
| 328 |
+
print()
|
| 329 |
+
print(f" CUDA available : {torch.cuda.is_available()}")
|
| 330 |
+
print(f" GPU count : {torch.cuda.device_count()}")
|
| 331 |
+
if torch.cuda.is_available():
|
| 332 |
+
for i in range(torch.cuda.device_count()):
|
| 333 |
+
p = torch.cuda.get_device_properties(i)
|
| 334 |
+
print(f" GPU {i} : {p.name} ({p.total_memory/1e9:.1f} GB)")
|
| 335 |
+
|
| 336 |
+
print()
|
| 337 |
+
print("OK" if all_ok else "NOT OK β re-run cell 3 and restart kernel.")
|
| 338 |
+
"""
|
| 339 |
+
|
| 340 |
+
CELL_06_REPO = """\
|
| 341 |
+
# === CELL 6 β Mount the ER-MAP repo into /kaggle/working ===
|
| 342 |
+
import os, subprocess, sys
|
| 343 |
+
|
| 344 |
+
# OPTION A: clone a public GitHub fork (preferred). Edit GIT_URL.
|
| 345 |
+
GIT_URL = "https://github.com/<your-fork>/Meta_Finals.git"
|
| 346 |
+
BRANCH = "main"
|
| 347 |
+
REPO_ROOT = "/kaggle/working/Meta_Finals"
|
| 348 |
+
|
| 349 |
+
# OPTION B: Kaggle Dataset upload β set this if you uploaded the repo
|
| 350 |
+
# as a Kaggle Dataset named "ermap-source" (Add Data -> Upload).
|
| 351 |
+
DATASET_DIR = "/kaggle/input/ermap-source"
|
| 352 |
+
|
| 353 |
+
if not os.path.isdir(f"{REPO_ROOT}/ER_MAP"):
|
| 354 |
+
if "<your-fork>" not in GIT_URL:
|
| 355 |
+
print(f"Cloning {GIT_URL}@{BRANCH} -> {REPO_ROOT}...")
|
| 356 |
+
out = subprocess.run(
|
| 357 |
+
["git", "clone", "--depth", "1", "-b", BRANCH, GIT_URL, REPO_ROOT],
|
| 358 |
+
capture_output=True, text=True,
|
| 359 |
+
)
|
| 360 |
+
print(out.stdout); print(out.stderr)
|
| 361 |
+
elif os.path.isdir(DATASET_DIR):
|
| 362 |
+
print(f"Copying {DATASET_DIR} -> {REPO_ROOT}...")
|
| 363 |
+
import shutil
|
| 364 |
+
shutil.copytree(DATASET_DIR, REPO_ROOT, dirs_exist_ok=True)
|
| 365 |
+
|
| 366 |
+
assert os.path.isdir(f"{REPO_ROOT}/ER_MAP"), (
|
| 367 |
+
"Repo not found.\\n"
|
| 368 |
+
" - Edit GIT_URL above to your GitHub fork, OR\\n"
|
| 369 |
+
" - Upload the repo as a Kaggle Dataset named 'ermap-source' (Add Data -> Upload)."
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
sys.path.insert(0, REPO_ROOT)
|
| 373 |
+
sys.path.insert(0, f"{REPO_ROOT}/kaggle")
|
| 374 |
+
print(f"OK. Repo at {REPO_ROOT}")
|
| 375 |
+
"""
|
| 376 |
+
|
| 377 |
+
CELL_07_SECRETS = """\
|
| 378 |
+
# === CELL 7 β Wire Kaggle Secrets into env vars ===
|
| 379 |
+
import os
|
| 380 |
+
from kaggle_helpers import load_kaggle_secrets, kaggle_env_summary
|
| 381 |
+
|
| 382 |
+
load_kaggle_secrets()
|
| 383 |
+
kaggle_env_summary()
|
| 384 |
+
|
| 385 |
+
# Hard fail if no Groq key β training would silently use mock LLMs.
|
| 386 |
+
assert any(os.environ.get(k) for k in (
|
| 387 |
+
"GROQ_NURSE_API_KEY", "GROQ_PATIENT_API_KEY",
|
| 388 |
+
"GROQ_EMPATHY_JUDGE_API_KEY", "GROQ_MEDICAL_JUDGE_API_KEY",
|
| 389 |
+
"GROQ_API_KEY",
|
| 390 |
+
)), ("No Groq key found in Kaggle Secrets. "
|
| 391 |
+
"Add at least GROQ_NURSE_API_KEY in Add-ons -> Secrets.")
|
| 392 |
+
print("OK β at least one Groq key is wired.")
|
| 393 |
+
"""
|
| 394 |
+
|
| 395 |
+
CELL_08_HF = """\
|
| 396 |
+
# === CELL 8 β Hugging Face Hub config (for checkpoint backup) ===
|
| 397 |
+
import os
|
| 398 |
+
from kaggle_helpers import push_checkpoint_to_hub, download_checkpoint_from_hub
|
| 399 |
+
|
| 400 |
+
# EDIT the line below to your HF model id (e.g. "udayd/ermap-doctor-lora").
|
| 401 |
+
HF_PUSH_REPO = "<your-username>/ermap-doctor-lora"
|
| 402 |
+
# To resume from a previous run, paste the same repo id here. Empty = fresh.
|
| 403 |
+
HF_RESUME_REPO = ""
|
| 404 |
+
|
| 405 |
+
RESUME_DIR = "/kaggle/working/checkpoints/resume"
|
| 406 |
+
if HF_RESUME_REPO:
|
| 407 |
+
download_checkpoint_from_hub(HF_RESUME_REPO, RESUME_DIR)
|
| 408 |
+
contents = os.listdir(RESUME_DIR) if os.path.isdir(RESUME_DIR) else []
|
| 409 |
+
print(f"Resume dir: {contents or '(empty)'}")
|
| 410 |
+
else:
|
| 411 |
+
print("Starting fresh β no resume.")
|
| 412 |
+
|
| 413 |
+
if "<your-username>" in HF_PUSH_REPO:
|
| 414 |
+
print("\\nWARNING: HF_PUSH_REPO still has <your-username> placeholder.")
|
| 415 |
+
print(" Checkpoints will NOT be pushed to HF Hub.")
|
| 416 |
+
print(" Edit the cell above and re-run before training if you want backups.")
|
| 417 |
+
"""
|
| 418 |
+
|
| 419 |
+
CELL_09_HYPERPARAMS = """\
|
| 420 |
+
# === CELL 9 β GRPO hyperparameters ===
|
| 421 |
+
import os
|
| 422 |
+
|
| 423 |
+
MODEL_NAME = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
|
| 424 |
+
GROUP_SIZE = 2
|
| 425 |
+
LEARNING_RATE = 5e-6
|
| 426 |
+
KL_BETA = 0.04
|
| 427 |
+
OUTPUT_DIR = "/kaggle/working/er_map_grpo_checkpoints"
|
| 428 |
+
PUSH_EVERY_EPS = 20
|
| 429 |
+
USE_WANDB = False # WANDB conflicts with protobuf 7 on Kaggle base image
|
| 430 |
+
NUM_EPISODES = 200 # hard cap; early-stop usually finishes first
|
| 431 |
+
|
| 432 |
+
# --- Per-phase reward thresholds (constant for this run) -------------------
|
| 433 |
+
# After every GRPO update we look at the last CONVERGENCE_WINDOW groups; if
|
| 434 |
+
# ALL of them belong to the same current phase AND each has
|
| 435 |
+
# rolling_avg_reward >= PHASE_REWARD_TARGETS[current_phase] AND
|
| 436 |
+
# rolling_win_rate >= PHASE_MIN_WIN_RATE, we either:
|
| 437 |
+
# - force-promote to the next phase (Phase 1 / Phase 2), OR
|
| 438 |
+
# - terminate training (Phase 3).
|
| 439 |
+
EARLY_STOP_ENABLED = True
|
| 440 |
+
PHASE_REWARD_TARGETS = {1: 1.2, 2: 1.1, 3: 1.0}
|
| 441 |
+
PHASE_MIN_WIN_RATE = 0.20
|
| 442 |
+
CONVERGENCE_WINDOW = 3
|
| 443 |
+
|
| 444 |
+
# --- Per-episode budget controls (read by triage_env) ----------------------
|
| 445 |
+
os.environ["ERMAP_MAX_EPISODE_STEPS"] = "20"
|
| 446 |
+
os.environ["ERMAP_MAX_INTERNAL_EXCHANGES"] = "5"
|
| 447 |
+
|
| 448 |
+
# --- Groq traffic-shaping (8B for actors, 70B for judges) ------------------
|
| 449 |
+
# High-volume conversational roles (Nurse + Patient) on the 8B-instant pool
|
| 450 |
+
# (500K TPD, 14,400 RPD); the two judges stay on 70B-versatile because their
|
| 451 |
+
# grading quality directly shapes the reward signal.
|
| 452 |
+
os.environ["ERMAP_NURSE_MODEL"] = "llama-3.1-8b-instant"
|
| 453 |
+
os.environ["ERMAP_PATIENT_MODEL"] = "llama-3.1-8b-instant"
|
| 454 |
+
os.environ["ERMAP_EMPATHY_JUDGE_MODEL"] = "llama-3.3-70b-versatile"
|
| 455 |
+
os.environ["ERMAP_MEDICAL_JUDGE_MODEL"] = "llama-3.3-70b-versatile"
|
| 456 |
+
|
| 457 |
+
print("Hyperparameters set:")
|
| 458 |
+
print(f" NUM_EPISODES = {NUM_EPISODES}")
|
| 459 |
+
print(f" GROUP_SIZE = {GROUP_SIZE}")
|
| 460 |
+
print(f" PHASE_REWARD_TARGETS = {PHASE_REWARD_TARGETS}")
|
| 461 |
+
print(f" PHASE_MIN_WIN_RATE = {PHASE_MIN_WIN_RATE}")
|
| 462 |
+
print(f" CONVERGENCE_WINDOW = {CONVERGENCE_WINDOW}")
|
| 463 |
+
print(f" Nurse / Patient = llama-3.1-8b-instant (actors, high-volume)")
|
| 464 |
+
print(f" Empathy / Med Judge = llama-3.3-70b-versatile (graders, quality)")
|
| 465 |
+
"""
|
| 466 |
+
|
| 467 |
+
CELL_10_PREFLIGHT = """\
|
| 468 |
+
# === CELL 10 β Pre-flight: Groq routing + key liveness ===
|
| 469 |
+
# Verifies that:
|
| 470 |
+
# - each role is routed to the model you set in cell 9, and
|
| 471 |
+
# - each role's Groq key actually answers a 1-token "PING" prompt.
|
| 472 |
+
|
| 473 |
+
import os
|
| 474 |
+
from ER_MAP.envs.api_router import AgentRouter
|
| 475 |
+
|
| 476 |
+
router = AgentRouter()
|
| 477 |
+
expected = {
|
| 478 |
+
"nurse": "llama-3.1-8b-instant",
|
| 479 |
+
"patient": "llama-3.1-8b-instant",
|
| 480 |
+
"empathy_judge": "llama-3.3-70b-versatile",
|
| 481 |
+
"medical_judge": "llama-3.3-70b-versatile",
|
| 482 |
+
}
|
| 483 |
+
|
| 484 |
+
print("=" * 60); print(" PRE-FLIGHT β Groq routing + smoke test"); print("=" * 60)
|
| 485 |
+
all_pass = True
|
| 486 |
+
for role, exp in expected.items():
|
| 487 |
+
actual = router._models.get(role, "?")
|
| 488 |
+
routing_ok = (actual == exp)
|
| 489 |
+
client = router._clients.get(role)
|
| 490 |
+
|
| 491 |
+
if client is None:
|
| 492 |
+
print(f" [SKIP] {role:14s} -> no Groq client (key missing)")
|
| 493 |
+
all_pass = False
|
| 494 |
+
continue
|
| 495 |
+
|
| 496 |
+
try:
|
| 497 |
+
resp = client.chat.completions.create(
|
| 498 |
+
model=exp,
|
| 499 |
+
messages=[{"role": "user", "content": "Reply with exactly: PING"}],
|
| 500 |
+
max_tokens=4, temperature=0,
|
| 501 |
+
)
|
| 502 |
+
api_ok = "PING" in (resp.choices[0].message.content or "").upper()
|
| 503 |
+
err = ""
|
| 504 |
+
except Exception as e:
|
| 505 |
+
api_ok = False
|
| 506 |
+
err = f" ({type(e).__name__}: {str(e)[:80]})"
|
| 507 |
+
|
| 508 |
+
flag = "PASS" if (routing_ok and api_ok) else "FAIL"
|
| 509 |
+
if flag == "FAIL":
|
| 510 |
+
all_pass = False
|
| 511 |
+
print(f" [{flag}] {role:14s} -> {actual:30s} "
|
| 512 |
+
f"routing={'ok' if routing_ok else 'WRONG'}, "
|
| 513 |
+
f"api={'ok' if api_ok else 'fail'}{err}")
|
| 514 |
+
|
| 515 |
+
print("=" * 60)
|
| 516 |
+
print("OK" if all_pass else "NOT OK β fix routing/keys before training.")
|
| 517 |
+
print("=" * 60)
|
| 518 |
+
assert all_pass, "Pre-flight failed; do not proceed to training."
|
| 519 |
+
"""
|
| 520 |
+
|
| 521 |
+
CELL_11_DRYRUN = """\
|
| 522 |
+
# === CELL 11 β Dry-run smoke test (no GPU, no model load) ===
|
| 523 |
+
# Verifies the curriculum scheduler + reward verifier + per-phase early-stop
|
| 524 |
+
# wiring before we burn GPU minutes on the real run.
|
| 525 |
+
|
| 526 |
+
from ER_MAP.training.train_grpo import train
|
| 527 |
+
|
| 528 |
+
_ = train(
|
| 529 |
+
num_episodes=8,
|
| 530 |
+
group_size=2,
|
| 531 |
+
model_name=MODEL_NAME,
|
| 532 |
+
learning_rate=LEARNING_RATE,
|
| 533 |
+
kl_beta=KL_BETA,
|
| 534 |
+
output_dir="/kaggle/working/_dryrun",
|
| 535 |
+
dry_run=True,
|
| 536 |
+
phase_reward_targets=PHASE_REWARD_TARGETS,
|
| 537 |
+
phase_min_win_rate=PHASE_MIN_WIN_RATE,
|
| 538 |
+
convergence_window=CONVERGENCE_WINDOW,
|
| 539 |
+
early_stop=EARLY_STOP_ENABLED,
|
| 540 |
+
)
|
| 541 |
+
print("\\nDry-run OK β scheduler + verifier + per-phase early-stop wiring is healthy.")
|
| 542 |
+
"""
|
| 543 |
+
|
| 544 |
+
CELL_12_HOOK = """\
|
| 545 |
+
# === CELL 12 β Wire periodic HF Hub push into training ===
|
| 546 |
+
# We monkey-patch save_lora_adapters so every checkpoint dump also pushes
|
| 547 |
+
# the LoRA adapter to HF Hub. Failures are non-fatal β training keeps
|
| 548 |
+
# running even if a push fails (e.g. transient HF 502).
|
| 549 |
+
|
| 550 |
+
from ER_MAP.training import train_grpo as _tg
|
| 551 |
+
_original_save = _tg.save_lora_adapters
|
| 552 |
+
|
| 553 |
+
def save_lora_adapters_with_push(model, tokenizer, output_dir):
|
| 554 |
+
_original_save(model, tokenizer, output_dir)
|
| 555 |
+
if HF_PUSH_REPO and "<your-username>" not in HF_PUSH_REPO:
|
| 556 |
+
try:
|
| 557 |
+
push_checkpoint_to_hub(
|
| 558 |
+
output_dir, HF_PUSH_REPO,
|
| 559 |
+
commit_message=f"checkpoint @ {os.path.basename(output_dir)}",
|
| 560 |
+
)
|
| 561 |
+
except Exception as e:
|
| 562 |
+
print(f" [hub-push] non-fatal failure: {e}")
|
| 563 |
+
|
| 564 |
+
_tg.save_lora_adapters = save_lora_adapters_with_push
|
| 565 |
+
print("Hub-push hook installed.")
|
| 566 |
+
"""
|
| 567 |
+
|
| 568 |
+
CELL_13_TRAIN_MD = """\
|
| 569 |
+
## 13 Β· Run real training (the 4β6 hour cell)
|
| 570 |
+
|
| 571 |
+
**Estimated wall-clock on Kaggle T4 Γ2:**
|
| 572 |
+
|
| 573 |
+
- ~3β5 min per episode (6β14 env steps Γ Doctor.generate + 4β8 Γ Groq calls)
|
| 574 |
+
- ~1β2 min amortized per GRPO update (G=2 trajectories Γ response-token log-probs)
|
| 575 |
+
- **Per-group β 8β12 min** (2 episodes + 1 update)
|
| 576 |
+
|
| 577 |
+
| Phase | Typical episodes to reach target | Wall-clock |
|
| 578 |
+
|---|---|---|
|
| 579 |
+
| 1 (target `+1.2` Γ 3) | 12 β 24 episodes (6 β 12 groups) | ~1.0 β 2.0 h |
|
| 580 |
+
| 2 (target `+1.1` Γ 3) | 16 β 32 episodes (8 β 16 groups) | ~1.5 β 2.5 h |
|
| 581 |
+
| 3 (target `+1.0` Γ 3) | 20 β 50 episodes (10 β 25 groups) | ~2.0 β 4.0 h |
|
| 582 |
+
| **Total** | 50 β 100 episodes | **~4.5 β 8.5 h** |
|
| 583 |
+
|
| 584 |
+
If `NUM_EPISODES=200` is exhausted before Phase 3 converges, training
|
| 585 |
+
stops at the cap and the latest LoRA checkpoint is on HF Hub already
|
| 586 |
+
(we push every 20 episodes), so resume in a fresh session via
|
| 587 |
+
`HF_RESUME_REPO` in cell 8.
|
| 588 |
+
"""
|
| 589 |
+
|
| 590 |
+
CELL_13_TRAIN = """\
|
| 591 |
+
# === CELL 13 β REAL TRAINING (4-6 h cell) ===
|
| 592 |
+
metrics = train(
|
| 593 |
+
num_episodes=NUM_EPISODES,
|
| 594 |
+
group_size=GROUP_SIZE,
|
| 595 |
+
model_name=MODEL_NAME,
|
| 596 |
+
groq_api_key=os.environ.get("GROQ_NURSE_API_KEY", "")
|
| 597 |
+
or os.environ.get("GROQ_API_KEY", ""),
|
| 598 |
+
learning_rate=LEARNING_RATE,
|
| 599 |
+
kl_beta=KL_BETA,
|
| 600 |
+
use_wandb=USE_WANDB,
|
| 601 |
+
output_dir=OUTPUT_DIR,
|
| 602 |
+
dry_run=False,
|
| 603 |
+
phase_reward_targets=PHASE_REWARD_TARGETS,
|
| 604 |
+
phase_min_win_rate=PHASE_MIN_WIN_RATE,
|
| 605 |
+
convergence_window=CONVERGENCE_WINDOW,
|
| 606 |
+
early_stop=EARLY_STOP_ENABLED,
|
| 607 |
+
)
|
| 608 |
+
print(f"\\nTraining returned {len(metrics)} metric records.")
|
| 609 |
+
"""
|
| 610 |
+
|
| 611 |
+
CELL_14_FINAL_PUSH = """\
|
| 612 |
+
# === CELL 14 β Final push: adapters + merged fp16 ===
|
| 613 |
+
FINAL_LORA_DIR = f"{OUTPUT_DIR}/final_lora"
|
| 614 |
+
FINAL_MERGED_DIR = f"{OUTPUT_DIR}/final_merged_fp16"
|
| 615 |
+
|
| 616 |
+
if HF_PUSH_REPO and "<your-username>" not in HF_PUSH_REPO:
|
| 617 |
+
push_checkpoint_to_hub(FINAL_LORA_DIR, HF_PUSH_REPO,
|
| 618 |
+
commit_message="final LoRA adapter")
|
| 619 |
+
if os.path.isdir(FINAL_MERGED_DIR):
|
| 620 |
+
push_checkpoint_to_hub(FINAL_MERGED_DIR, f"{HF_PUSH_REPO}-merged",
|
| 621 |
+
commit_message="final merged fp16")
|
| 622 |
+
print(f"Final checkpoints pushed: https://huggingface.co/{HF_PUSH_REPO}")
|
| 623 |
+
else:
|
| 624 |
+
print("HF_PUSH_REPO not configured β skipping final push.")
|
| 625 |
+
"""
|
| 626 |
+
|
| 627 |
+
CELL_15_PLOTS_MD = """\
|
| 628 |
+
## 15 Β· Per-phase training graphs (one dashboard per curriculum phase)
|
| 629 |
+
|
| 630 |
+
We render a 6-panel dashboard for **every phase that contains episodes**,
|
| 631 |
+
plus a cross-phase overview and a phase-comparison bar chart. All PNGs are
|
| 632 |
+
written to `er_map_grpo_checkpoints/plots/` and uploaded to HF Hub in the
|
| 633 |
+
next cell so they survive Kaggle session expiry.
|
| 634 |
+
|
| 635 |
+
Each per-phase dashboard contains:
|
| 636 |
+
|
| 637 |
+
1. **Reward growth** β raw scatter + rolling mean (w=10) + verified rolling mean
|
| 638 |
+
2. **Rolling win rate** β w=20 win-rate evolution within the phase
|
| 639 |
+
3. **Outcome distribution over time** β stacked bars (WIN/PARTIAL/INCORRECT/AMA_LOSS/FATAL_LOSS)
|
| 640 |
+
4. **Reward components** β mean of each component (process / treatment / empathy / labs / etc.)
|
| 641 |
+
5. **GRPO update stats** β loss + KL divergence per group update
|
| 642 |
+
6. **Episode length distribution** β histogram of step counts
|
| 643 |
+
"""
|
| 644 |
+
|
| 645 |
+
CELL_15_PLOTS = """\
|
| 646 |
+
# === CELL 15 β Per-phase training dashboards ===
|
| 647 |
+
from ER_MAP.plotting import plot_per_phase_dashboards
|
| 648 |
+
from IPython.display import Image, display, Markdown
|
| 649 |
+
|
| 650 |
+
PLOTS_DIR = f"{OUTPUT_DIR}/plots"
|
| 651 |
+
written = plot_per_phase_dashboards(
|
| 652 |
+
metrics_path=f"{OUTPUT_DIR}/training_metrics.json",
|
| 653 |
+
output_dir=PLOTS_DIR,
|
| 654 |
+
)
|
| 655 |
+
|
| 656 |
+
print(f"Saved {len(written)} chart(s) to {PLOTS_DIR}:")
|
| 657 |
+
for name, path in written.items():
|
| 658 |
+
size_kb = os.path.getsize(path) / 1024
|
| 659 |
+
print(f" {name:<28s} -> {path} ({size_kb:.0f} KB)")
|
| 660 |
+
|
| 661 |
+
# Display each chart inline so the operator sees them without leaving Kaggle.
|
| 662 |
+
ordered = (sorted(k for k in written if k.startswith("phase"))
|
| 663 |
+
+ ["all_phases_overview", "all_phases_comparison"])
|
| 664 |
+
for key in ordered:
|
| 665 |
+
if key not in written:
|
| 666 |
+
continue
|
| 667 |
+
display(Markdown(f"### {key.replace('_', ' ').title()}"))
|
| 668 |
+
display(Image(filename=written[key]))
|
| 669 |
+
"""
|
| 670 |
+
|
| 671 |
+
CELL_16_PUSH_PLOTS = """\
|
| 672 |
+
# === CELL 16 β Push plots to HF Hub ===
|
| 673 |
+
if HF_PUSH_REPO and "<your-username>" not in HF_PUSH_REPO:
|
| 674 |
+
push_checkpoint_to_hub(PLOTS_DIR, HF_PUSH_REPO,
|
| 675 |
+
commit_message="per-phase training plots")
|
| 676 |
+
print(f"Plots pushed: https://huggingface.co/{HF_PUSH_REPO}/tree/main")
|
| 677 |
+
else:
|
| 678 |
+
print("HF_PUSH_REPO not configured β plots stay only in /kaggle/working/.")
|
| 679 |
+
"""
|
| 680 |
+
|
| 681 |
+
CELL_17_INFER_MD = """\
|
| 682 |
+
## 17 Β· (Optional) Inference smoke-test on the trained model
|
| 683 |
+
|
| 684 |
+
Catches the classic 'merge path looked OK but the saved model emits garbage'
|
| 685 |
+
failure mode before the demo.
|
| 686 |
+
"""
|
| 687 |
+
|
| 688 |
+
CELL_17_INFER = """\
|
| 689 |
+
# === CELL 17 β Inference smoke-test on the trained model ===
|
| 690 |
+
from ER_MAP.training.train_grpo import generate_doctor_action, load_model_and_tokenizer
|
| 691 |
+
from peft import PeftModel
|
| 692 |
+
|
| 693 |
+
base_model, tok = load_model_and_tokenizer(model_name=MODEL_NAME)
|
| 694 |
+
trained = PeftModel.from_pretrained(base_model, FINAL_LORA_DIR)
|
| 695 |
+
|
| 696 |
+
test_obs = (
|
| 697 |
+
'{"event":"episode_start","nurse_experience":"veteran",'
|
| 698 |
+
'"message":"Patient with chest pain, HR 120, BP 90/60, vague history.",'
|
| 699 |
+
'"soap_summary":{}}'
|
| 700 |
+
)
|
| 701 |
+
for i in range(3):
|
| 702 |
+
print(f"\\n--- Sample {i+1} ---")
|
| 703 |
+
print(generate_doctor_action(trained, tok, test_obs, max_new_tokens=160))
|
| 704 |
+
"""
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
# ---------------------------------------------------------------------------
|
| 708 |
+
# Quickstart markdown (sibling file)
|
| 709 |
+
# ---------------------------------------------------------------------------
|
| 710 |
+
|
| 711 |
+
QUICKSTART_MD = """\
|
| 712 |
+
# Kaggle Quickstart β ER-MAP GRPO Training (v3 stable)
|
| 713 |
+
|
| 714 |
+
The Kaggle notebook is in `kaggle/train_ermap_grpo_kaggle.ipynb`. This file
|
| 715 |
+
is the cheat sheet for running it end-to-end without the dependency hell
|
| 716 |
+
that bit us in earlier attempts.
|
| 717 |
+
|
| 718 |
+
## 0. Prerequisites (one-time)
|
| 719 |
+
|
| 720 |
+
1. **GitHub fork** of this repo. The notebook clones from a public fork at
|
| 721 |
+
cell 6 β edit `GIT_URL`. Alternatively, upload the repo as a Kaggle
|
| 722 |
+
Dataset named `ermap-source` (Add Data β Upload).
|
| 723 |
+
2. **Hugging Face write token** (`HF_TOKEN`) for pushing the trained
|
| 724 |
+
adapter. Create at https://huggingface.co/settings/tokens (fine-grained,
|
| 725 |
+
write access on a single model repo is enough).
|
| 726 |
+
3. **Five Groq keys** (one each for Nurse / Patient / Empathy Judge /
|
| 727 |
+
Medical Judge / shared fallback). Free-tier accounts are fine; the
|
| 728 |
+
per-account limits multiply across keys.
|
| 729 |
+
|
| 730 |
+
## 1. Create the Kaggle notebook
|
| 731 |
+
|
| 732 |
+
1. Sign in to https://www.kaggle.com/code β **New Notebook**.
|
| 733 |
+
2. Right sidebar:
|
| 734 |
+
- Accelerator: **GPU T4 Γ2** (or P100)
|
| 735 |
+
- Internet: **On**
|
| 736 |
+
- Persistence: Files only
|
| 737 |
+
3. **File β Upload Notebook** β choose `kaggle/train_ermap_grpo_kaggle.ipynb`
|
| 738 |
+
from this repo.
|
| 739 |
+
|
| 740 |
+
## 2. Add Kaggle Secrets
|
| 741 |
+
|
| 742 |
+
Add-ons β Secrets β Add a new secret. Required labels (exactly):
|
| 743 |
+
|
| 744 |
+
| Label | Value |
|
| 745 |
+
|---|---|
|
| 746 |
+
| `GROQ_NURSE_API_KEY` | your nurse Groq key |
|
| 747 |
+
| `GROQ_PATIENT_API_KEY` | your patient Groq key |
|
| 748 |
+
| `GROQ_EMPATHY_JUDGE_API_KEY` | your empathy-judge Groq key |
|
| 749 |
+
| `GROQ_MEDICAL_JUDGE_API_KEY` | your medical-judge Groq key |
|
| 750 |
+
| `HF_TOKEN` | your HF write token |
|
| 751 |
+
| `WANDB_API_KEY` *(optional)* | your W&B key (skip β disabled by default) |
|
| 752 |
+
|
| 753 |
+
The notebook reads them via `kaggle_helpers.load_kaggle_secrets()` and
|
| 754 |
+
exports them as env vars.
|
| 755 |
+
|
| 756 |
+
## 3. Edit two placeholders in the notebook
|
| 757 |
+
|
| 758 |
+
- **Cell 6:** `GIT_URL = "https://github.com/<your-fork>/Meta_Finals.git"`
|
| 759 |
+
- **Cell 8:** `HF_PUSH_REPO = "<your-username>/ermap-doctor-lora"`
|
| 760 |
+
|
| 761 |
+
If you uploaded the repo as a Kaggle Dataset instead, leave `GIT_URL` as the
|
| 762 |
+
placeholder β cell 6 will detect `/kaggle/input/ermap-source` and copy from
|
| 763 |
+
there.
|
| 764 |
+
|
| 765 |
+
## 4. Run order (the only sequence that works)
|
| 766 |
+
|
| 767 |
+
| Cell | What it does | Expected output |
|
| 768 |
+
|---|---|---|
|
| 769 |
+
| 2 | GPU + disk + python + internet sanity check | GPU listed, disk free > 8 GB |
|
| 770 |
+
| 3 | **REPAIR** β pin torch 2.10 cu128, reinstall bitsandbytes, upgrade unsloth | `REPAIR OK` (or `RESTART REQUIRED`) |
|
| 771 |
+
| **(restart)** | If cell 3 said RESTART REQUIRED β Run β Restart kernel | β |
|
| 772 |
+
| 5 | Post-restart import verify | All `OK`, GPUs listed |
|
| 773 |
+
| 6 | Clone / mount the repo | `OK. Repo at /kaggle/working/Meta_Finals` |
|
| 774 |
+
| 7 | Wire Kaggle Secrets β env vars | `OK β at least one Groq key is wired` |
|
| 775 |
+
| 8 | HF Hub config | `Starting fresh β no resume.` |
|
| 776 |
+
| 9 | Hyperparameters (P1=+1.2, P2=+1.1, P3=+1.0) | thresholds printed |
|
| 777 |
+
| 10 | **Pre-flight** β Groq routing + 4Γ PING | 4Γ `[PASS]`, then `OK` |
|
| 778 |
+
| 11 | Dry-run smoke test (no GPU) | `Dry-run OK` |
|
| 779 |
+
| 12 | Wire HF push hook | `Hub-push hook installed.` |
|
| 780 |
+
| 13 | **REAL TRAINING** (4β6 h) | per-group rolling stats, eventual `EARLY STOP` |
|
| 781 |
+
| 14 | Final push to HF | `Final checkpoints pushed: https://huggingface.co/...` |
|
| 782 |
+
| 15 | Per-phase plots | 5 PNGs displayed inline |
|
| 783 |
+
| 16 | Push plots to HF | `Plots pushed: ...` |
|
| 784 |
+
| 17 | Inference smoke-test (optional) | 3 sample Doctor actions printed |
|
| 785 |
+
|
| 786 |
+
## 5. Common failures & fixes
|
| 787 |
+
|
| 788 |
+
| Symptom | Root cause | Fix |
|
| 789 |
+
|---|---|---|
|
| 790 |
+
| `numpy was upgraded mid-session` | numpy import poisoned by a previous cell | Restart kernel, re-run from cell 3 |
|
| 791 |
+
| `Pillow incompatible with torchvision` | Pillow ABI mismatch | Restart kernel, re-run from cell 3 |
|
| 792 |
+
| `PyTorch and torchvision compiled with different CUDA major` | torch upgraded to cu13 by a transient resolve | Re-run cell 3 (it pins cu128) and restart |
|
| 793 |
+
| `cannot import name 'create_gradient_checkpointing_buffer'` | unsloth β unsloth_zoo version drift | Re-run cell 3 (upgrades both in lockstep) |
|
| 794 |
+
| `libnvJitLink.so.13 missing` | bitsandbytes built against different CUDA | Re-run cell 3 (force-reinstalls bitsandbytes after torch pin) |
|
| 795 |
+
| Disk usage > quota | Kaggle's 20 GB working partition fills up | First line of cell 3 cleans `/tmp` and pip cache |
|
| 796 |
+
| Pre-flight `[FAIL]` for a role | Groq key dead / quota exceeded | Generate a new key in console.groq.com β update Kaggle Secret β re-run cell 7+10 |
|
| 797 |
+
| `[FAIL]` says `routing=WRONG` | env var not set when `AgentRouter()` was constructed | Re-run cell 9 BEFORE cell 10 |
|
| 798 |
+
| Training freezes at episode 1 for >10 min | Doctor.generate hung; Unsloth import broke silently | Check cell 5 output for `unsloth` line; restart kernel and re-run cell 3 if missing |
|
| 799 |
+
|
| 800 |
+
## 6. What the trained model gives you
|
| 801 |
+
|
| 802 |
+
After cell 13 finishes (or hits the 12 h Kaggle session cap), you have:
|
| 803 |
+
|
| 804 |
+
- `OUTPUT_DIR/final_lora/` β LoRA adapter weights (~50 MB), pushed to
|
| 805 |
+
`HF_PUSH_REPO`
|
| 806 |
+
- `OUTPUT_DIR/final_merged_fp16/` β full Llama-3.1-8B fp16 merge with the
|
| 807 |
+
adapter applied (~16 GB), pushed to `HF_PUSH_REPO-merged`
|
| 808 |
+
- `OUTPUT_DIR/training_metrics.json` β per-episode rewards, outcomes,
|
| 809 |
+
rolling stats β input for the per-phase plots
|
| 810 |
+
- `OUTPUT_DIR/plots/*.png` β 5 dashboards (one per phase + cross-phase
|
| 811 |
+
overview + comparison bar)
|
| 812 |
+
|
| 813 |
+
Use the LoRA adapter for the demo (quick to load, runs on a 4050 6 GB at
|
| 814 |
+
~30 tok/s); use the merged fp16 if you need to host on a Vercel/HF Space
|
| 815 |
+
without `peft`.
|
| 816 |
+
"""
|
| 817 |
+
|
| 818 |
+
|
| 819 |
+
# ---------------------------------------------------------------------------
|
| 820 |
+
# Build the notebook
|
| 821 |
+
# ---------------------------------------------------------------------------
|
| 822 |
+
|
| 823 |
+
def build_notebook() -> dict:
|
| 824 |
+
cells = [
|
| 825 |
+
md_cell(CELL_01_TITLE), # 0
|
| 826 |
+
code_cell(CELL_02_SANITY), # 1
|
| 827 |
+
code_cell(CELL_03_REPAIR), # 2
|
| 828 |
+
md_cell(CELL_04_RESTART), # 3
|
| 829 |
+
code_cell(CELL_05_VERIFY), # 4
|
| 830 |
+
code_cell(CELL_06_REPO), # 5
|
| 831 |
+
code_cell(CELL_07_SECRETS), # 6
|
| 832 |
+
code_cell(CELL_08_HF), # 7
|
| 833 |
+
code_cell(CELL_09_HYPERPARAMS), # 8
|
| 834 |
+
code_cell(CELL_10_PREFLIGHT), # 9
|
| 835 |
+
code_cell(CELL_11_DRYRUN), # 10
|
| 836 |
+
code_cell(CELL_12_HOOK), # 11
|
| 837 |
+
md_cell(CELL_13_TRAIN_MD), # 12
|
| 838 |
+
code_cell(CELL_13_TRAIN), # 13
|
| 839 |
+
code_cell(CELL_14_FINAL_PUSH), # 14
|
| 840 |
+
md_cell(CELL_15_PLOTS_MD), # 15
|
| 841 |
+
code_cell(CELL_15_PLOTS), # 16
|
| 842 |
+
code_cell(CELL_16_PUSH_PLOTS), # 17
|
| 843 |
+
md_cell(CELL_17_INFER_MD), # 18
|
| 844 |
+
code_cell(CELL_17_INFER), # 19
|
| 845 |
+
]
|
| 846 |
+
return {
|
| 847 |
+
"cells": cells,
|
| 848 |
+
"metadata": {
|
| 849 |
+
"kernelspec": {
|
| 850 |
+
"display_name": "Python 3",
|
| 851 |
+
"language": "python",
|
| 852 |
+
"name": "python3",
|
| 853 |
+
},
|
| 854 |
+
"language_info": {
|
| 855 |
+
"name": "python",
|
| 856 |
+
"version": "3.10",
|
| 857 |
+
},
|
| 858 |
+
},
|
| 859 |
+
"nbformat": 4,
|
| 860 |
+
"nbformat_minor": 5,
|
| 861 |
+
}
|
| 862 |
+
|
| 863 |
+
|
| 864 |
+
def main() -> None:
|
| 865 |
+
here = Path(__file__).parent
|
| 866 |
+
nb_path = here / "train_ermap_grpo_kaggle.ipynb"
|
| 867 |
+
qs_path = here / "KAGGLE_QUICKSTART.md"
|
| 868 |
+
|
| 869 |
+
nb = build_notebook()
|
| 870 |
+
nb_path.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
|
| 871 |
+
qs_path.write_text(QUICKSTART_MD, encoding="utf-8")
|
| 872 |
+
|
| 873 |
+
n_md = sum(1 for c in nb["cells"] if c["cell_type"] == "markdown")
|
| 874 |
+
n_code = sum(1 for c in nb["cells"] if c["cell_type"] == "code")
|
| 875 |
+
print(f"Wrote {nb_path} ({len(nb['cells'])} cells: {n_md} md / {n_code} code)")
|
| 876 |
+
print(f"Wrote {qs_path} ({len(QUICKSTART_MD.splitlines())} lines)")
|
| 877 |
+
|
| 878 |
+
|
| 879 |
+
if __name__ == "__main__":
|
| 880 |
+
main()
|
|
@@ -4,33 +4,45 @@
|
|
| 4 |
"cell_type": "markdown",
|
| 5 |
"metadata": {},
|
| 6 |
"source": [
|
| 7 |
-
"# ER-MAP β Doctor Agent GRPO Training (Kaggle Free-Tier)\n",
|
| 8 |
-
"\n",
|
| 9 |
-
"**
|
| 10 |
-
"\n",
|
| 11 |
-
"
|
| 12 |
-
"
|
| 13 |
-
"
|
| 14 |
-
"
|
| 15 |
-
"
|
| 16 |
-
"
|
| 17 |
-
"\n",
|
| 18 |
-
"
|
| 19 |
-
"-
|
| 20 |
-
"
|
| 21 |
-
"-
|
| 22 |
-
"\n",
|
| 23 |
-
"**
|
| 24 |
-
"
|
| 25 |
-
"
|
| 26 |
-
"-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
]
|
| 35 |
},
|
| 36 |
{
|
|
@@ -39,8 +51,38 @@
|
|
| 39 |
"metadata": {},
|
| 40 |
"outputs": [],
|
| 41 |
"source": [
|
| 42 |
-
"
|
| 43 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
]
|
| 45 |
},
|
| 46 |
{
|
|
@@ -49,47 +91,159 @@
|
|
| 49 |
"metadata": {},
|
| 50 |
"outputs": [],
|
| 51 |
"source": [
|
| 52 |
-
"#
|
| 53 |
-
"#
|
| 54 |
-
"#
|
| 55 |
-
"#
|
| 56 |
-
"
|
| 57 |
-
"
|
| 58 |
-
"
|
| 59 |
-
"\n",
|
| 60 |
-
"import
|
| 61 |
-
"
|
| 62 |
-
"
|
| 63 |
-
"
|
| 64 |
-
"
|
| 65 |
-
"
|
| 66 |
-
"\n",
|
| 67 |
-
"#
|
| 68 |
-
"
|
| 69 |
-
"
|
| 70 |
-
"
|
| 71 |
-
"
|
| 72 |
-
"#
|
| 73 |
-
"
|
| 74 |
-
"
|
| 75 |
-
" import
|
| 76 |
-
"
|
| 77 |
-
"
|
| 78 |
-
"\n",
|
| 79 |
-
"
|
| 80 |
-
"
|
| 81 |
-
"
|
| 82 |
-
")\n",
|
| 83 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
]
|
| 85 |
},
|
| 86 |
{
|
| 87 |
"cell_type": "markdown",
|
| 88 |
"metadata": {},
|
| 89 |
"source": [
|
| 90 |
-
"##
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
"\n",
|
| 92 |
-
"
|
|
|
|
| 93 |
]
|
| 94 |
},
|
| 95 |
{
|
|
@@ -98,26 +252,46 @@
|
|
| 98 |
"metadata": {},
|
| 99 |
"outputs": [],
|
| 100 |
"source": [
|
| 101 |
-
"#
|
| 102 |
-
"
|
| 103 |
-
"
|
| 104 |
-
"
|
| 105 |
-
"
|
| 106 |
-
"
|
| 107 |
-
"
|
| 108 |
-
"
|
| 109 |
-
"
|
| 110 |
-
"
|
| 111 |
-
"
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
"
|
| 119 |
-
"\n",
|
| 120 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
]
|
| 122 |
},
|
| 123 |
{
|
|
@@ -126,28 +300,63 @@
|
|
| 126 |
"metadata": {},
|
| 127 |
"outputs": [],
|
| 128 |
"source": [
|
| 129 |
-
"
|
| 130 |
-
"
|
| 131 |
-
"
|
| 132 |
-
"\n",
|
| 133 |
-
"
|
| 134 |
-
"
|
| 135 |
-
"
|
| 136 |
-
"
|
| 137 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
")\n",
|
| 139 |
"\n",
|
| 140 |
-
"
|
| 141 |
-
"
|
|
|
|
| 142 |
]
|
| 143 |
},
|
| 144 |
{
|
| 145 |
-
"cell_type": "
|
|
|
|
| 146 |
"metadata": {},
|
|
|
|
| 147 |
"source": [
|
| 148 |
-
"#
|
|
|
|
|
|
|
| 149 |
"\n",
|
| 150 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
]
|
| 152 |
},
|
| 153 |
{
|
|
@@ -156,68 +365,27 @@
|
|
| 156 |
"metadata": {},
|
| 157 |
"outputs": [],
|
| 158 |
"source": [
|
| 159 |
-
"
|
| 160 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
"\n",
|
| 162 |
"RESUME_DIR = \"/kaggle/working/checkpoints/resume\"\n",
|
| 163 |
"if HF_RESUME_REPO:\n",
|
| 164 |
" download_checkpoint_from_hub(HF_RESUME_REPO, RESUME_DIR)\n",
|
| 165 |
-
"
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
"cell_type": "markdown",
|
| 170 |
-
"metadata": {},
|
| 171 |
-
"source": [
|
| 172 |
-
"## 5 Β· Configure the GRPO run\n",
|
| 173 |
-
"\n",
|
| 174 |
-
"The defaults below are tuned for **one 12-hour Kaggle session** on a single T4. They produce a clean upward reward-growth curve through Phase 1 + early Phase 2; if you have a second session, lower `--episodes` is fine because LoRA adapters resume cleanly via Step 4 above.\n",
|
| 175 |
-
"\n",
|
| 176 |
-
"| Parameter | Value | Reason |\n",
|
| 177 |
-
"|---|---|---|\n",
|
| 178 |
-
"| Model | `unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit` | Llama-3-family small-tier, 4-bit ~5 GB |\n",
|
| 179 |
-
"| LoRA rank | 16 | balances expressivity vs speed on T4 |\n",
|
| 180 |
-
"| Group size G | 2 | Kaggle's T4 fits G=2 comfortably; G=4 needs 30+ min/group |\n",
|
| 181 |
-
"| Episodes (cap) | 120 | hard cap; early-stop usually finishes first |\n",
|
| 182 |
-
"| LR | 5e-6 | conservative, prevents catastrophic forgetting on small group |\n",
|
| 183 |
-
"| KL beta | 0.04 | matches the paper's recipe; restrains drift from base policy |\n",
|
| 184 |
-
"| Max episode steps | 20 | matches `triage_env.py` default |\n",
|
| 185 |
-
"| Internal exchanges | 5 | shorter than default (8) to fit within 12 h budget |\n",
|
| 186 |
-
"\n",
|
| 187 |
-
"### Train-until-optimal (per-phase reward thresholds)\n",
|
| 188 |
-
"\n",
|
| 189 |
-
"Training **never** runs for a fixed episode budget. After every GRPO update we look at the last `CONVERGENCE_WINDOW=3` groups; if **all three** belong to the same current phase AND each has `rolling_avg_reward >= PHASE_REWARD_TARGETS[current_phase]`, we either:\n",
|
| 190 |
-
"\n",
|
| 191 |
-
"- **Phase 1 / Phase 2** β force-promote to the next curriculum phase (the buffer is then cleared so stale entries don't satisfy the next phase's check).\n",
|
| 192 |
-
"- **Phase 3** β terminate training (this is the 'optimal rewards constantly received' criterion).\n",
|
| 193 |
-
"\n",
|
| 194 |
-
"Why per-phase, not a single global bar? The phases are not equally difficult β Phase 1 wins are worth ~`+2.0` on the reward scale (full terminal_win on a clean SOAP) while Phase 3 wins routinely cost `~0.5` in consent / empathy friction even when the diagnosis is correct. A single global `+1.5` would either gate Phase 3 too aggressively or pass Phase 1 with garbage Phase-2 behavior.\n",
|
| 195 |
-
"\n",
|
| 196 |
-
"| Phase | Default target | Why this number | Action when met |\n",
|
| 197 |
-
"|---|---|---|---|\n",
|
| 198 |
-
"| 1 β Tool Mastery | `+1.5` | A Phase-1 episode that uses tools cleanly + discharges with the correct treatment lands at `+1.6 .. +2.0`. Sustaining `+1.5` means the model has tool-format down. | force-promote to Phase 2 |\n",
|
| 199 |
-
"| 2 β Clinical Reasoning | `+1.2` | Phase 2 adds noisy SOAP and mixed compliance. A solid clinician policy lands at `+1.2 .. +1.5`. | force-promote to Phase 3 |\n",
|
| 200 |
-
"| 3 β Empathetic Negotiation | `+1.0` | Phase 3 imposes empathy + consent costs (`-0.3..-0.6` per episode even on wins). Sustained `+1.0` here is genuinely hard and is the hackathon success criterion. | END TRAINING |\n",
|
| 201 |
-
"\n",
|
| 202 |
-
"| Knob | Default | Meaning |\n",
|
| 203 |
-
"|---|---|---|\n",
|
| 204 |
-
"| `PHASE_REWARD_TARGETS` | `{1: 1.5, 2: 1.2, 3: 1.0}` | per-phase sustained rolling-avg-reward bar |\n",
|
| 205 |
-
"| `PHASE_MIN_WIN_RATE` | `0.20` | soft floor on rolling win rate (sanity check) |\n",
|
| 206 |
-
"| `CONVERGENCE_WINDOW` | `3` | how many consecutive groups must hit the bar |\n",
|
| 207 |
-
"| `EARLY_STOP_ENABLED` | `True` | set `False` to always burn the full `NUM_EPISODES` budget |\n",
|
| 208 |
-
"\n",
|
| 209 |
-
"### Estimated wall-clock per phase on Kaggle T4 Γ2\n",
|
| 210 |
-
"\n",
|
| 211 |
-
"Each episode = 6β14 env steps Γ (Doctor.generate β 2β3 s) + 4β8 Groq calls (β 0.4β1.0 s each). One GRPO update over `G=2` trajectories = 1 forward + 1 backward over response tokens β 60β120 s on T4. Net per-group wall-clock β **8β12 minutes**.\n",
|
| 212 |
-
"\n",
|
| 213 |
-
"| Phase | Typical episodes to reach target | Wall-clock | Notes |\n",
|
| 214 |
-
"|---|---|---|---|\n",
|
| 215 |
-
"| 1 (target `+1.5` Γ 3) | 16 β 30 episodes (8 β 15 groups) | **~1.5 β 2.5 h** | Easy patients + clean SOAP β tool-format is the only thing the model has to learn. |\n",
|
| 216 |
-
"| 2 (target `+1.2` Γ 3) | 24 β 40 episodes (12 β 20 groups) | **~2.0 β 3.5 h** | Mixed-compliance patients + noisy SOAP. Bulk of the policy improvement happens here. |\n",
|
| 217 |
-
"| 3 (target `+1.0` Γ 3) | 30 β 60 episodes (15 β 30 groups) | **~2.5 β 5.0 h** | Hard patients + empathy/consent costs. May not converge in 12 h on a fresh base; that's why `NUM_EPISODES=120` is the hard cap. |\n",
|
| 218 |
-
"| **Total** | 70 β 130 episodes | **~6 β 11 h** | Fits inside Kaggle's 12 h GPU session with ~1 h margin. |\n",
|
| 219 |
"\n",
|
| 220 |
-
"
|
|
|
|
|
|
|
|
|
|
| 221 |
]
|
| 222 |
},
|
| 223 |
{
|
|
@@ -226,65 +394,110 @@
|
|
| 226 |
"metadata": {},
|
| 227 |
"outputs": [],
|
| 228 |
"source": [
|
| 229 |
-
"#
|
|
|
|
|
|
|
| 230 |
"MODEL_NAME = \"unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit\"\n",
|
| 231 |
-
"NUM_EPISODES = 120 # HARD CAP; early-stop usually finishes first\n",
|
| 232 |
"GROUP_SIZE = 2\n",
|
| 233 |
"LEARNING_RATE = 5e-6\n",
|
| 234 |
"KL_BETA = 0.04\n",
|
| 235 |
"OUTPUT_DIR = \"/kaggle/working/er_map_grpo_checkpoints\"\n",
|
| 236 |
"PUSH_EVERY_EPS = 20\n",
|
| 237 |
-
"USE_WANDB =
|
| 238 |
-
"\n",
|
| 239 |
-
"
|
| 240 |
-
"#
|
| 241 |
-
"#
|
| 242 |
-
"#
|
| 243 |
-
"#
|
|
|
|
|
|
|
| 244 |
"# - terminate training (Phase 3).\n",
|
| 245 |
-
"
|
| 246 |
-
"
|
| 247 |
-
"
|
| 248 |
-
"
|
| 249 |
-
"
|
| 250 |
-
"
|
| 251 |
-
"CONVERGENCE_WINDOW = 3 # 3 consecutive groups must qualify\n",
|
| 252 |
-
"\n",
|
| 253 |
-
"# --- Per-episode budget controls (passed via env vars) ---\n",
|
| 254 |
"os.environ[\"ERMAP_MAX_EPISODE_STEPS\"] = \"20\"\n",
|
| 255 |
"os.environ[\"ERMAP_MAX_INTERNAL_EXCHANGES\"] = \"5\"\n",
|
| 256 |
-
"
|
| 257 |
-
"#
|
| 258 |
-
"#
|
| 259 |
-
"#
|
| 260 |
-
"#
|
| 261 |
-
"# versatile because their grading quality directly shapes the reward.\n",
|
| 262 |
"os.environ[\"ERMAP_NURSE_MODEL\"] = \"llama-3.1-8b-instant\"\n",
|
| 263 |
"os.environ[\"ERMAP_PATIENT_MODEL\"] = \"llama-3.1-8b-instant\"\n",
|
| 264 |
"os.environ[\"ERMAP_EMPATHY_JUDGE_MODEL\"] = \"llama-3.3-70b-versatile\"\n",
|
| 265 |
"os.environ[\"ERMAP_MEDICAL_JUDGE_MODEL\"] = \"llama-3.3-70b-versatile\"\n",
|
| 266 |
"\n",
|
| 267 |
-
"
|
| 268 |
-
"
|
| 269 |
-
"
|
| 270 |
-
"
|
| 271 |
-
"
|
| 272 |
-
"
|
| 273 |
-
"
|
| 274 |
-
"
|
| 275 |
-
" ]\n",
|
| 276 |
-
"), (\"No Groq key found in Kaggle Secrets β add at least \"\n",
|
| 277 |
-
" \"GROQ_NURSE_API_KEY before running training.\")\n",
|
| 278 |
-
"print(\"Hyperparameters and env vars set.\")"
|
| 279 |
]
|
| 280 |
},
|
| 281 |
{
|
| 282 |
-
"cell_type": "
|
|
|
|
| 283 |
"metadata": {},
|
|
|
|
| 284 |
"source": [
|
| 285 |
-
"#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
"\n",
|
| 287 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
]
|
| 289 |
},
|
| 290 |
{
|
|
@@ -293,6 +506,10 @@
|
|
| 293 |
"metadata": {},
|
| 294 |
"outputs": [],
|
| 295 |
"source": [
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
"from ER_MAP.training.train_grpo import train\n",
|
| 297 |
"\n",
|
| 298 |
"_ = train(\n",
|
|
@@ -303,17 +520,12 @@
|
|
| 303 |
" kl_beta=KL_BETA,\n",
|
| 304 |
" output_dir=\"/kaggle/working/_dryrun\",\n",
|
| 305 |
" dry_run=True,\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
")\n",
|
| 307 |
-
"print(\"
|
| 308 |
-
]
|
| 309 |
-
},
|
| 310 |
-
{
|
| 311 |
-
"cell_type": "markdown",
|
| 312 |
-
"metadata": {},
|
| 313 |
-
"source": [
|
| 314 |
-
"## 7 Β· Wire periodic HF-Hub push into the training loop\n",
|
| 315 |
-
"\n",
|
| 316 |
-
"We monkey-patch `save_lora_adapters` so every checkpoint dump also pushes to HF Hub. Failures are non-fatal β training keeps running even if the push fails."
|
| 317 |
]
|
| 318 |
},
|
| 319 |
{
|
|
@@ -322,18 +534,20 @@
|
|
| 322 |
"metadata": {},
|
| 323 |
"outputs": [],
|
| 324 |
"source": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
"from ER_MAP.training import train_grpo as _tg\n",
|
| 326 |
"_original_save = _tg.save_lora_adapters\n",
|
| 327 |
-
"_episode_marker = {\"n\": 0}\n",
|
| 328 |
"\n",
|
| 329 |
"def save_lora_adapters_with_push(model, tokenizer, output_dir):\n",
|
| 330 |
" _original_save(model, tokenizer, output_dir)\n",
|
| 331 |
-
" _episode_marker[\"n\"] += 1\n",
|
| 332 |
" if HF_PUSH_REPO and \"<your-username>\" not in HF_PUSH_REPO:\n",
|
| 333 |
" try:\n",
|
| 334 |
" push_checkpoint_to_hub(\n",
|
| 335 |
-
" output_dir,\n",
|
| 336 |
-
" HF_PUSH_REPO,\n",
|
| 337 |
" commit_message=f\"checkpoint @ {os.path.basename(output_dir)}\",\n",
|
| 338 |
" )\n",
|
| 339 |
" except Exception as e:\n",
|
|
@@ -347,18 +561,25 @@
|
|
| 347 |
"cell_type": "markdown",
|
| 348 |
"metadata": {},
|
| 349 |
"source": [
|
| 350 |
-
"##
|
| 351 |
-
"\n",
|
| 352 |
-
"
|
| 353 |
-
"\n",
|
| 354 |
-
"- ~3
|
| 355 |
-
"- ~1
|
| 356 |
-
"- **Per-group
|
| 357 |
-
"
|
| 358 |
-
"
|
| 359 |
-
"-
|
| 360 |
-
"
|
| 361 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
]
|
| 363 |
},
|
| 364 |
{
|
|
@@ -367,31 +588,24 @@
|
|
| 367 |
"metadata": {},
|
| 368 |
"outputs": [],
|
| 369 |
"source": [
|
|
|
|
| 370 |
"metrics = train(\n",
|
| 371 |
" num_episodes=NUM_EPISODES,\n",
|
| 372 |
" group_size=GROUP_SIZE,\n",
|
| 373 |
" model_name=MODEL_NAME,\n",
|
| 374 |
-
" groq_api_key=os.environ.get(\"GROQ_NURSE_API_KEY\", \"\")
|
|
|
|
| 375 |
" learning_rate=LEARNING_RATE,\n",
|
| 376 |
" kl_beta=KL_BETA,\n",
|
| 377 |
" use_wandb=USE_WANDB,\n",
|
| 378 |
" output_dir=OUTPUT_DIR,\n",
|
| 379 |
" dry_run=False,\n",
|
| 380 |
-
" # ----- Per-phase early-stop ('train until optimal rewards are constantly received') -----\n",
|
| 381 |
" phase_reward_targets=PHASE_REWARD_TARGETS,\n",
|
| 382 |
" phase_min_win_rate=PHASE_MIN_WIN_RATE,\n",
|
| 383 |
" convergence_window=CONVERGENCE_WINDOW,\n",
|
| 384 |
" early_stop=EARLY_STOP_ENABLED,\n",
|
| 385 |
-
")"
|
| 386 |
-
|
| 387 |
-
},
|
| 388 |
-
{
|
| 389 |
-
"cell_type": "markdown",
|
| 390 |
-
"metadata": {},
|
| 391 |
-
"source": [
|
| 392 |
-
"## 9 Β· Final push: adapters + merged fp16 weights\n",
|
| 393 |
-
"\n",
|
| 394 |
-
"The training loop already wrote `final_lora/` and `final_merged_fp16/` to `OUTPUT_DIR`. We push both to HF Hub so you can serve them from Vercel / a HF Space without re-running training."
|
| 395 |
]
|
| 396 |
},
|
| 397 |
{
|
|
@@ -400,20 +614,17 @@
|
|
| 400 |
"metadata": {},
|
| 401 |
"outputs": [],
|
| 402 |
"source": [
|
|
|
|
| 403 |
"FINAL_LORA_DIR = f\"{OUTPUT_DIR}/final_lora\"\n",
|
| 404 |
"FINAL_MERGED_DIR = f\"{OUTPUT_DIR}/final_merged_fp16\"\n",
|
| 405 |
"\n",
|
| 406 |
"if HF_PUSH_REPO and \"<your-username>\" not in HF_PUSH_REPO:\n",
|
| 407 |
-
" push_checkpoint_to_hub(\n",
|
| 408 |
-
"
|
| 409 |
-
" commit_message=\"final LoRA adapter\",\n",
|
| 410 |
-
" )\n",
|
| 411 |
" if os.path.isdir(FINAL_MERGED_DIR):\n",
|
| 412 |
-
" push_checkpoint_to_hub(\n",
|
| 413 |
-
"
|
| 414 |
-
"
|
| 415 |
-
" )\n",
|
| 416 |
-
" print(\"Final checkpoints pushed.\")\n",
|
| 417 |
"else:\n",
|
| 418 |
" print(\"HF_PUSH_REPO not configured β skipping final push.\")"
|
| 419 |
]
|
|
@@ -422,16 +633,20 @@
|
|
| 422 |
"cell_type": "markdown",
|
| 423 |
"metadata": {},
|
| 424 |
"source": [
|
| 425 |
-
"##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 426 |
"\n",
|
| 427 |
-
"
|
| 428 |
"\n",
|
| 429 |
-
"**Each per-phase dashboard contains:**\n",
|
| 430 |
"1. **Reward growth** β raw scatter + rolling mean (w=10) + verified rolling mean\n",
|
| 431 |
-
"2. **Rolling win rate** β w=20 win
|
| 432 |
-
"3. **Outcome distribution over time** β stacked bars (WIN/PARTIAL/INCORRECT/AMA_LOSS/FATAL_LOSS)
|
| 433 |
-
"4. **Reward components** β mean of each component (process / treatment / empathy / labs / etc.)
|
| 434 |
-
"5. **GRPO update
|
| 435 |
"6. **Episode length distribution** β histogram of step counts"
|
| 436 |
]
|
| 437 |
},
|
|
@@ -441,6 +656,7 @@
|
|
| 441 |
"metadata": {},
|
| 442 |
"outputs": [],
|
| 443 |
"source": [
|
|
|
|
| 444 |
"from ER_MAP.plotting import plot_per_phase_dashboards\n",
|
| 445 |
"from IPython.display import Image, display, Markdown\n",
|
| 446 |
"\n",
|
|
@@ -450,54 +666,44 @@
|
|
| 450 |
" output_dir=PLOTS_DIR,\n",
|
| 451 |
")\n",
|
| 452 |
"\n",
|
| 453 |
-
"print(f\"Saved {len(written)} chart(s):\")\n",
|
| 454 |
"for name, path in written.items():\n",
|
| 455 |
" size_kb = os.path.getsize(path) / 1024\n",
|
| 456 |
" print(f\" {name:<28s} -> {path} ({size_kb:.0f} KB)\")\n",
|
| 457 |
"\n",
|
| 458 |
-
"# Display each chart inline
|
| 459 |
-
"
|
| 460 |
-
"
|
| 461 |
-
"
|
| 462 |
-
" sorted(k for k in written if k.startswith(\"phase\")) +\n",
|
| 463 |
-
" [\"all_phases_overview\", \"all_phases_comparison\"]\n",
|
| 464 |
-
")\n",
|
| 465 |
-
"for key in ordered_keys:\n",
|
| 466 |
" if key not in written:\n",
|
| 467 |
" continue\n",
|
| 468 |
" display(Markdown(f\"### {key.replace('_', ' ').title()}\"))\n",
|
| 469 |
" display(Image(filename=written[key]))"
|
| 470 |
]
|
| 471 |
},
|
| 472 |
-
{
|
| 473 |
-
"cell_type": "markdown",
|
| 474 |
-
"metadata": {},
|
| 475 |
-
"source": [
|
| 476 |
-
"## 10b Β· Push the plots to HF Hub (so they survive session expiry)"
|
| 477 |
-
]
|
| 478 |
-
},
|
| 479 |
{
|
| 480 |
"cell_type": "code",
|
| 481 |
"execution_count": null,
|
| 482 |
"metadata": {},
|
| 483 |
"outputs": [],
|
| 484 |
"source": [
|
|
|
|
| 485 |
"if HF_PUSH_REPO and \"<your-username>\" not in HF_PUSH_REPO:\n",
|
| 486 |
-
" push_checkpoint_to_hub(\n",
|
| 487 |
-
"
|
| 488 |
-
"
|
| 489 |
-
" )\n",
|
| 490 |
"else:\n",
|
| 491 |
-
" print(\"HF_PUSH_REPO not configured
|
| 492 |
]
|
| 493 |
},
|
| 494 |
{
|
| 495 |
"cell_type": "markdown",
|
| 496 |
"metadata": {},
|
| 497 |
"source": [
|
| 498 |
-
"##
|
| 499 |
"\n",
|
| 500 |
-
"Catches the classic
|
|
|
|
| 501 |
]
|
| 502 |
},
|
| 503 |
{
|
|
@@ -506,6 +712,7 @@
|
|
| 506 |
"metadata": {},
|
| 507 |
"outputs": [],
|
| 508 |
"source": [
|
|
|
|
| 509 |
"from ER_MAP.training.train_grpo import generate_doctor_action, load_model_and_tokenizer\n",
|
| 510 |
"from peft import PeftModel\n",
|
| 511 |
"\n",
|
|
@@ -536,4 +743,4 @@
|
|
| 536 |
},
|
| 537 |
"nbformat": 4,
|
| 538 |
"nbformat_minor": 5
|
| 539 |
-
}
|
|
|
|
| 4 |
"cell_type": "markdown",
|
| 5 |
"metadata": {},
|
| 6 |
"source": [
|
| 7 |
+
"# ER-MAP β Doctor Agent GRPO Training (Kaggle Free-Tier Β· v3 stable)\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"Trains the **Doctor LLM** (Llama-3.1-8B-Instruct, 4-bit + LoRA r=16) via GRPO\n",
|
| 10 |
+
"with a 3-phase curriculum on Kaggle's free GPU. Designed to survive Kaggle's\n",
|
| 11 |
+
"pre-baked image quirks (numpy / Pillow ABI mismatches, torch + torchvision\n",
|
| 12 |
+
"CUDA-major mismatches, transient `unsloth_zoo` upgrades).\n",
|
| 13 |
+
"\n",
|
| 14 |
+
"## TL;DR β How to run this notebook\n",
|
| 15 |
+
"\n",
|
| 16 |
+
"1. **Notebook settings (right sidebar):**\n",
|
| 17 |
+
" - Accelerator: **GPU T4 Γ2** (or P100)\n",
|
| 18 |
+
" - Internet: **On**\n",
|
| 19 |
+
" - Persistence: Files only\n",
|
| 20 |
+
"2. **Kaggle Secrets** (Add-ons β Secrets):\n",
|
| 21 |
+
" - **Required:** `GROQ_NURSE_API_KEY`, `GROQ_PATIENT_API_KEY`,\n",
|
| 22 |
+
" `GROQ_EMPATHY_JUDGE_API_KEY`, `GROQ_MEDICAL_JUDGE_API_KEY`, `HF_TOKEN`\n",
|
| 23 |
+
" - **Optional:** `WANDB_API_KEY`\n",
|
| 24 |
+
"3. **Run cells 2 β 3 (sanity + REPAIR).** When cell 3 prints\n",
|
| 25 |
+
" `RESTART REQUIRED`, click **Run β Restart kernel**, then resume from cell 5.\n",
|
| 26 |
+
"4. **Run cells 5 β 11 (verify + configure + dry-run + pre-flight).** Each cell\n",
|
| 27 |
+
" should print an `OK` line before moving on.\n",
|
| 28 |
+
"5. **Run cell 13 (the long training cell, 4β6 hours).**\n",
|
| 29 |
+
"6. **Run cells 14 β 17 (final push + plots + inference smoke-test).**\n",
|
| 30 |
+
"\n",
|
| 31 |
+
"## Curriculum + reward thresholds (this run)\n",
|
| 32 |
+
"\n",
|
| 33 |
+
"Constant per-phase rolling-avg-reward bars; sustained for **3 consecutive\n",
|
| 34 |
+
"GRPO groups** triggers either a phase promotion or end-of-training.\n",
|
| 35 |
+
"\n",
|
| 36 |
+
"| Phase | Reward target (sustained Γ3 groups) | Action when met |\n",
|
| 37 |
+
"|---|---|---|\n",
|
| 38 |
+
"| 1 β Tool Mastery | `+1.2` | force-promote to Phase 2 |\n",
|
| 39 |
+
"| 2 β Clinical Reasoning | `+1.1` | force-promote to Phase 3 |\n",
|
| 40 |
+
"| 3 β Empathetic Negotiation | `+1.0` | END TRAINING |\n",
|
| 41 |
+
"\n",
|
| 42 |
+
"Why these numbers? The un-trained 8B Doctor's baseline on the same env is\n",
|
| 43 |
+
"`P1=+0.76, P2=+0.59, P3=+0.39`. Targets of `+1.2 / +1.1 / +1.0` correspond\n",
|
| 44 |
+
"to roughly `1.6Γ / 1.9Γ / 2.6Γ` improvement over baseline β a meaningful\n",
|
| 45 |
+
"signal but reachable inside Kaggle's 12 h session limit."
|
| 46 |
]
|
| 47 |
},
|
| 48 |
{
|
|
|
|
| 51 |
"metadata": {},
|
| 52 |
"outputs": [],
|
| 53 |
"source": [
|
| 54 |
+
"# === CELL 2 β Sanity check (GPU + disk + python + internet) ===\n",
|
| 55 |
+
"# Run this FIRST. If any check fails, fix it before running the REPAIR cell.\n",
|
| 56 |
+
"\n",
|
| 57 |
+
"import os, shutil, subprocess, sys, socket\n",
|
| 58 |
+
"\n",
|
| 59 |
+
"print(\"--- GPU ---\")\n",
|
| 60 |
+
"try:\n",
|
| 61 |
+
" print(subprocess.check_output(\n",
|
| 62 |
+
" [\"nvidia-smi\", \"--query-gpu=name,memory.total,memory.free\", \"--format=csv\"],\n",
|
| 63 |
+
" timeout=10,\n",
|
| 64 |
+
" ).decode())\n",
|
| 65 |
+
"except Exception as e:\n",
|
| 66 |
+
" print(f\"nvidia-smi failed: {e}\")\n",
|
| 67 |
+
" print(\"-> Set Accelerator to 'GPU T4 x2' in the right sidebar.\")\n",
|
| 68 |
+
"\n",
|
| 69 |
+
"print(\"--- Disk (/kaggle/working) ---\")\n",
|
| 70 |
+
"total, used, free = shutil.disk_usage(\"/kaggle/working\")\n",
|
| 71 |
+
"print(f\" total={total/1e9:5.1f} GB | used={used/1e9:5.1f} GB | free={free/1e9:5.1f} GB\")\n",
|
| 72 |
+
"if free < 8 * 1e9:\n",
|
| 73 |
+
" print(\" WARNING: free disk < 8 GB β repair cell may fail. \"\n",
|
| 74 |
+
" \"Consider 'Run > Restart and clear cell outputs' to reset /tmp.\")\n",
|
| 75 |
+
"\n",
|
| 76 |
+
"print(\"--- Python ---\")\n",
|
| 77 |
+
"print(f\" python={sys.version.split()[0]} | exe={sys.executable}\")\n",
|
| 78 |
+
"\n",
|
| 79 |
+
"print(\"--- Internet (api.groq.com:443) ---\")\n",
|
| 80 |
+
"try:\n",
|
| 81 |
+
" socket.create_connection((\"api.groq.com\", 443), timeout=5).close()\n",
|
| 82 |
+
" print(\" reachable\")\n",
|
| 83 |
+
"except Exception as e:\n",
|
| 84 |
+
" print(f\" UNREACHABLE: {e}\")\n",
|
| 85 |
+
" print(\" -> Settings (right sidebar) -> Internet -> ON\")"
|
| 86 |
]
|
| 87 |
},
|
| 88 |
{
|
|
|
|
| 91 |
"metadata": {},
|
| 92 |
"outputs": [],
|
| 93 |
"source": [
|
| 94 |
+
"# === CELL 3 β REPAIR CELL (idempotent full environment rebuild) ===\n",
|
| 95 |
+
"# Single source of truth for ER-MAP's GPU stack. Safe to re-run. After it\n",
|
| 96 |
+
"# finishes you'll see one of two final lines:\n",
|
| 97 |
+
"#\n",
|
| 98 |
+
"# RESTART REQUIRED -> Run -> Restart kernel, then resume from cell 5\n",
|
| 99 |
+
"# REPAIR OK -> proceed directly to cell 5\n",
|
| 100 |
+
"#\n",
|
| 101 |
+
"# Note: this cell only runs shell commands and one isolated subprocess.\n",
|
| 102 |
+
"# It deliberately does NOT `import torch / numpy / Pillow / unsloth` in the\n",
|
| 103 |
+
"# kernel, so re-running it after a botched install does not poison further\n",
|
| 104 |
+
"# attempts.\n",
|
| 105 |
+
"\n",
|
| 106 |
+
"print(\"=\" * 72); print(\" CELL 3 β REPAIR\"); print(\"=\" * 72)\n",
|
| 107 |
+
"\n",
|
| 108 |
+
"# 1. Clean caches (Kaggle's /kaggle/working is only 20 GB β installs\n",
|
| 109 |
+
"# routinely fill it after a few re-runs).\n",
|
| 110 |
+
"print(\"[1/6] Cleaning pip + tmp + HF dataset caches...\")\n",
|
| 111 |
+
"get_ipython().system('pip cache purge -q || true')\n",
|
| 112 |
+
"get_ipython().system('rm -rf /tmp/* /root/.cache/pip /root/.cache/huggingface/datasets 2>/dev/null || true')\n",
|
| 113 |
+
"\n",
|
| 114 |
+
"# 2. Pin torch + torchvision to the cu128 wheel (matches Kaggle's CUDA 12.8\n",
|
| 115 |
+
"# base image). DON'T let pip pull a generic CUDA-13 build β that breaks\n",
|
| 116 |
+
"# bitsandbytes (libnvJitLink.so.13 missing) and torchvision (CUDA-major\n",
|
| 117 |
+
"# mismatch RuntimeError at import time).\n",
|
| 118 |
+
"print(\"[2/6] Installing torch==2.10.0 + torchvision==0.25.0 (cu128)...\")\n",
|
| 119 |
+
"get_ipython().system('pip install -q --no-cache-dir --force-reinstall '\n",
|
| 120 |
+
" 'torch==2.10.0 torchvision==0.25.0 '\n",
|
| 121 |
+
" '--index-url https://download.pytorch.org/whl/cu128')\n",
|
| 122 |
+
"\n",
|
| 123 |
+
"# 3. Reinstall bitsandbytes against the now-pinned torch.\n",
|
| 124 |
+
"print(\"[3/6] Reinstalling bitsandbytes...\")\n",
|
| 125 |
+
"get_ipython().system('pip install -q --no-cache-dir --force-reinstall bitsandbytes')\n",
|
| 126 |
+
"\n",
|
| 127 |
+
"# 4. Upgrade unsloth + unsloth_zoo + trl in lockstep. unsloth and\n",
|
| 128 |
+
"# unsloth_zoo are released as a matched pair; if pip pulls a fresh\n",
|
| 129 |
+
"# unsloth_zoo against an old unsloth you get\n",
|
| 130 |
+
"# ImportError: cannot import name 'create_gradient_checkpointing_buffer'\n",
|
| 131 |
+
"print(\"[4/6] Upgrading unsloth + unsloth_zoo + trl...\")\n",
|
| 132 |
+
"get_ipython().system('pip install -q --upgrade --no-cache-dir '\n",
|
| 133 |
+
" 'unsloth unsloth_zoo \"trl>=0.18.2\"')\n",
|
| 134 |
+
"\n",
|
| 135 |
+
"# 5. ER-MAP runtime deps that aren't pre-installed on Kaggle.\n",
|
| 136 |
+
"print(\"[5/6] Installing ER-MAP runtime deps...\")\n",
|
| 137 |
+
"get_ipython().system('pip install -q --no-cache-dir '\n",
|
| 138 |
+
" '\"groq>=0.18.0\" \"huggingface_hub>=0.25.0\" '\n",
|
| 139 |
+
" '\"gymnasium>=0.29.0\" \"openenv-core>=0.1.0\"')\n",
|
| 140 |
+
"\n",
|
| 141 |
+
"# 6. Verify in a SUBPROCESS (so the parent kernel never imports any of these\n",
|
| 142 |
+
"# while pip is mid-flight, which is what causes the\n",
|
| 143 |
+
"# 'numpy was upgraded mid-session (loaded: X, installed: Y)' RuntimeError\n",
|
| 144 |
+
"# we kept hitting before).\n",
|
| 145 |
+
"print(\"[6/6] Verifying via subprocess...\")\n",
|
| 146 |
+
"import subprocess, sys, json\n",
|
| 147 |
+
"\n",
|
| 148 |
+
"verify_script = r'''\n",
|
| 149 |
+
"import json, sys\n",
|
| 150 |
+
"out = {\"ok\": True, \"details\": {}, \"errors\": []}\n",
|
| 151 |
+
"try:\n",
|
| 152 |
+
" import importlib.metadata as md\n",
|
| 153 |
+
" for pkg in (\"torch\", \"torchvision\", \"bitsandbytes\", \"unsloth\", \"unsloth_zoo\",\n",
|
| 154 |
+
" \"trl\", \"transformers\", \"peft\", \"accelerate\", \"groq\",\n",
|
| 155 |
+
" \"huggingface_hub\", \"gymnasium\", \"numpy\", \"Pillow\"):\n",
|
| 156 |
+
" try:\n",
|
| 157 |
+
" out[\"details\"][pkg + \"_installed\"] = md.version(pkg)\n",
|
| 158 |
+
" except md.PackageNotFoundError:\n",
|
| 159 |
+
" out[\"details\"][pkg + \"_installed\"] = None\n",
|
| 160 |
+
"\n",
|
| 161 |
+
" import torch, torchvision, numpy as np, PIL, unsloth, unsloth_zoo, bitsandbytes, trl\n",
|
| 162 |
+
" out[\"details\"][\"torch_loaded\"] = torch.__version__\n",
|
| 163 |
+
" out[\"details\"][\"torch_cuda\"] = torch.version.cuda\n",
|
| 164 |
+
" out[\"details\"][\"cuda_available\"] = bool(torch.cuda.is_available())\n",
|
| 165 |
+
" out[\"details\"][\"gpu_count\"] = int(torch.cuda.device_count())\n",
|
| 166 |
+
" out[\"details\"][\"torchvision_loaded\"] = torchvision.__version__\n",
|
| 167 |
+
" out[\"details\"][\"numpy_loaded\"] = np.__version__\n",
|
| 168 |
+
" out[\"details\"][\"pillow_loaded\"] = PIL.__version__\n",
|
| 169 |
+
" out[\"details\"][\"unsloth_loaded\"] = unsloth.__version__\n",
|
| 170 |
+
" out[\"details\"][\"unsloth_zoo_loaded\"] = unsloth_zoo.__version__\n",
|
| 171 |
+
" out[\"details\"][\"bitsandbytes_loaded\"] = bitsandbytes.__version__\n",
|
| 172 |
+
" out[\"details\"][\"trl_loaded\"] = trl.__version__\n",
|
| 173 |
+
"\n",
|
| 174 |
+
" # Cross-check loaded-vs-installed for the C-extension libs that bit us\n",
|
| 175 |
+
" # on every previous run.\n",
|
| 176 |
+
" for pkg, loaded_key, installed_key in [\n",
|
| 177 |
+
" (\"numpy\", \"numpy_loaded\", \"numpy_installed\"),\n",
|
| 178 |
+
" (\"Pillow\", \"pillow_loaded\", \"Pillow_installed\"),\n",
|
| 179 |
+
" (\"torch\", \"torch_loaded\", \"torch_installed\"),\n",
|
| 180 |
+
" ]:\n",
|
| 181 |
+
" loaded = out[\"details\"].get(loaded_key)\n",
|
| 182 |
+
" installed = out[\"details\"].get(installed_key)\n",
|
| 183 |
+
" if loaded and installed and loaded != installed:\n",
|
| 184 |
+
" # Strip any local-version suffix (e.g. '+cu128') before compare.\n",
|
| 185 |
+
" if loaded.split(\"+\")[0] != installed.split(\"+\")[0]:\n",
|
| 186 |
+
" out[\"errors\"].append(\n",
|
| 187 |
+
" f\"{pkg} mismatch: loaded={loaded} installed={installed}\"\n",
|
| 188 |
+
" )\n",
|
| 189 |
+
"except Exception as e:\n",
|
| 190 |
+
" out[\"ok\"] = False\n",
|
| 191 |
+
" out[\"errors\"].append(f\"{type(e).__name__}: {e}\")\n",
|
| 192 |
+
"print(json.dumps(out, default=str))\n",
|
| 193 |
+
"'''.lstrip()\n",
|
| 194 |
+
"\n",
|
| 195 |
+
"res = subprocess.run([sys.executable, \"-c\", verify_script],\n",
|
| 196 |
+
" capture_output=True, text=True, timeout=180)\n",
|
| 197 |
+
"print(res.stdout if res.stdout else \"<no stdout>\")\n",
|
| 198 |
+
"if res.stderr:\n",
|
| 199 |
+
" print(\"---- subprocess stderr ----\"); print(res.stderr)\n",
|
| 200 |
+
"\n",
|
| 201 |
+
"# Parse the LAST line of stdout (others are prints from package init).\n",
|
| 202 |
+
"try:\n",
|
| 203 |
+
" last = res.stdout.strip().splitlines()[-1]\n",
|
| 204 |
+
" parsed = json.loads(last)\n",
|
| 205 |
+
"except Exception:\n",
|
| 206 |
+
" parsed = {\"ok\": False, \"errors\": [\"could not parse verification output\"]}\n",
|
| 207 |
+
"\n",
|
| 208 |
+
"ok = parsed.get(\"ok\") and not parsed.get(\"errors\")\n",
|
| 209 |
+
"d = parsed.get(\"details\", {})\n",
|
| 210 |
+
"\n",
|
| 211 |
+
"print(\"\n",
|
| 212 |
+
"\" + \"=\" * 72)\n",
|
| 213 |
+
"if ok:\n",
|
| 214 |
+
" print(\" REPAIR OK\")\n",
|
| 215 |
+
" print(f\" torch : {d.get('torch_loaded')} (CUDA {d.get('torch_cuda')})\")\n",
|
| 216 |
+
" print(f\" torchvision : {d.get('torchvision_loaded')}\")\n",
|
| 217 |
+
" print(f\" bitsandbytes: {d.get('bitsandbytes_loaded')}\")\n",
|
| 218 |
+
" print(f\" unsloth : {d.get('unsloth_loaded')} | unsloth_zoo: {d.get('unsloth_zoo_loaded')}\")\n",
|
| 219 |
+
" print(f\" trl : {d.get('trl_loaded')}\")\n",
|
| 220 |
+
" print(f\" numpy : {d.get('numpy_loaded')} | Pillow: {d.get('pillow_loaded')}\")\n",
|
| 221 |
+
" print(f\" GPUs : {d.get('gpu_count')} (cuda_available={d.get('cuda_available')})\")\n",
|
| 222 |
+
" print()\n",
|
| 223 |
+
" print(\" -> If this kernel previously imported torch/numpy/Pillow/unsloth,\")\n",
|
| 224 |
+
" print(\" RESTART NOW (Run -> Restart kernel) before continuing to cell 5.\")\n",
|
| 225 |
+
" print(\" If this is a fresh kernel, you can proceed directly.\")\n",
|
| 226 |
+
"else:\n",
|
| 227 |
+
" print(\" RESTART REQUIRED β issues detected:\")\n",
|
| 228 |
+
" for e in parsed.get(\"errors\", []):\n",
|
| 229 |
+
" print(f\" - {e}\")\n",
|
| 230 |
+
" print()\n",
|
| 231 |
+
" print(\" Action: Run -> Restart kernel, then re-run from cell 2.\")\n",
|
| 232 |
+
"print(\"=\" * 72)"
|
| 233 |
]
|
| 234 |
},
|
| 235 |
{
|
| 236 |
"cell_type": "markdown",
|
| 237 |
"metadata": {},
|
| 238 |
"source": [
|
| 239 |
+
"## β Restart kernel here if cell 3 said `RESTART REQUIRED`\n",
|
| 240 |
+
"\n",
|
| 241 |
+
"Click **Run β Restart kernel** (or **Run β Restart & clear cell outputs**),\n",
|
| 242 |
+
"then resume from **cell 5**. Skipping the restart will produce ABI mismatch\n",
|
| 243 |
+
"errors at the first GPU op.\n",
|
| 244 |
"\n",
|
| 245 |
+
"If cell 3 said `REPAIR OK` AND this is a fresh kernel that hasn't imported\n",
|
| 246 |
+
"torch/numpy/Pillow/unsloth yet, you can proceed to cell 5 directly."
|
| 247 |
]
|
| 248 |
},
|
| 249 |
{
|
|
|
|
| 252 |
"metadata": {},
|
| 253 |
"outputs": [],
|
| 254 |
"source": [
|
| 255 |
+
"# === CELL 5 β Post-restart verify (this kernel can import everything) ===\n",
|
| 256 |
+
"import importlib.metadata as md\n",
|
| 257 |
+
"\n",
|
| 258 |
+
"print(\"--- Loaded versions in this kernel ---\")\n",
|
| 259 |
+
"import torch, numpy, PIL, torchvision, unsloth, unsloth_zoo, bitsandbytes, trl, transformers, peft\n",
|
| 260 |
+
"\n",
|
| 261 |
+
"versions = {\n",
|
| 262 |
+
" \"torch\": torch.__version__,\n",
|
| 263 |
+
" \"torchvision\": torchvision.__version__,\n",
|
| 264 |
+
" \"numpy\": numpy.__version__,\n",
|
| 265 |
+
" \"Pillow\": PIL.__version__,\n",
|
| 266 |
+
" \"unsloth\": unsloth.__version__,\n",
|
| 267 |
+
" \"unsloth_zoo\": unsloth_zoo.__version__,\n",
|
| 268 |
+
" \"bitsandbytes\": bitsandbytes.__version__,\n",
|
| 269 |
+
" \"trl\": trl.__version__,\n",
|
| 270 |
+
" \"transformers\": transformers.__version__,\n",
|
| 271 |
+
" \"peft\": peft.__version__,\n",
|
| 272 |
+
"}\n",
|
| 273 |
+
"all_ok = True\n",
|
| 274 |
+
"for k, v in versions.items():\n",
|
| 275 |
+
" try:\n",
|
| 276 |
+
" inst = md.version(k)\n",
|
| 277 |
+
" except md.PackageNotFoundError:\n",
|
| 278 |
+
" inst = \"(not installed)\"\n",
|
| 279 |
+
" # Tolerate local version suffixes like '+cu128'\n",
|
| 280 |
+
" flag = \"OK\" if inst.split(\"+\")[0] == v.split(\"+\")[0] else f\"MISMATCH (installed={inst})\"\n",
|
| 281 |
+
" if \"MISMATCH\" in flag:\n",
|
| 282 |
+
" all_ok = False\n",
|
| 283 |
+
" print(f\" {k:14s}: loaded={v:20s} [{flag}]\")\n",
|
| 284 |
+
"\n",
|
| 285 |
+
"print()\n",
|
| 286 |
+
"print(f\" CUDA available : {torch.cuda.is_available()}\")\n",
|
| 287 |
+
"print(f\" GPU count : {torch.cuda.device_count()}\")\n",
|
| 288 |
+
"if torch.cuda.is_available():\n",
|
| 289 |
+
" for i in range(torch.cuda.device_count()):\n",
|
| 290 |
+
" p = torch.cuda.get_device_properties(i)\n",
|
| 291 |
+
" print(f\" GPU {i} : {p.name} ({p.total_memory/1e9:.1f} GB)\")\n",
|
| 292 |
+
"\n",
|
| 293 |
+
"print()\n",
|
| 294 |
+
"print(\"OK\" if all_ok else \"NOT OK β re-run cell 3 and restart kernel.\")"
|
| 295 |
]
|
| 296 |
},
|
| 297 |
{
|
|
|
|
| 300 |
"metadata": {},
|
| 301 |
"outputs": [],
|
| 302 |
"source": [
|
| 303 |
+
"# === CELL 6 β Mount the ER-MAP repo into /kaggle/working ===\n",
|
| 304 |
+
"import os, subprocess, sys\n",
|
| 305 |
+
"\n",
|
| 306 |
+
"# OPTION A: clone a public GitHub fork (preferred). Edit GIT_URL.\n",
|
| 307 |
+
"GIT_URL = \"https://github.com/<your-fork>/Meta_Finals.git\"\n",
|
| 308 |
+
"BRANCH = \"main\"\n",
|
| 309 |
+
"REPO_ROOT = \"/kaggle/working/Meta_Finals\"\n",
|
| 310 |
+
"\n",
|
| 311 |
+
"# OPTION B: Kaggle Dataset upload β set this if you uploaded the repo\n",
|
| 312 |
+
"# as a Kaggle Dataset named \"ermap-source\" (Add Data -> Upload).\n",
|
| 313 |
+
"DATASET_DIR = \"/kaggle/input/ermap-source\"\n",
|
| 314 |
+
"\n",
|
| 315 |
+
"if not os.path.isdir(f\"{REPO_ROOT}/ER_MAP\"):\n",
|
| 316 |
+
" if \"<your-fork>\" not in GIT_URL:\n",
|
| 317 |
+
" print(f\"Cloning {GIT_URL}@{BRANCH} -> {REPO_ROOT}...\")\n",
|
| 318 |
+
" out = subprocess.run(\n",
|
| 319 |
+
" [\"git\", \"clone\", \"--depth\", \"1\", \"-b\", BRANCH, GIT_URL, REPO_ROOT],\n",
|
| 320 |
+
" capture_output=True, text=True,\n",
|
| 321 |
+
" )\n",
|
| 322 |
+
" print(out.stdout); print(out.stderr)\n",
|
| 323 |
+
" elif os.path.isdir(DATASET_DIR):\n",
|
| 324 |
+
" print(f\"Copying {DATASET_DIR} -> {REPO_ROOT}...\")\n",
|
| 325 |
+
" import shutil\n",
|
| 326 |
+
" shutil.copytree(DATASET_DIR, REPO_ROOT, dirs_exist_ok=True)\n",
|
| 327 |
+
"\n",
|
| 328 |
+
"assert os.path.isdir(f\"{REPO_ROOT}/ER_MAP\"), (\n",
|
| 329 |
+
" \"Repo not found.\\n\"\n",
|
| 330 |
+
" \" - Edit GIT_URL above to your GitHub fork, OR\\n\"\n",
|
| 331 |
+
" \" - Upload the repo as a Kaggle Dataset named 'ermap-source' (Add Data -> Upload).\"\n",
|
| 332 |
")\n",
|
| 333 |
"\n",
|
| 334 |
+
"sys.path.insert(0, REPO_ROOT)\n",
|
| 335 |
+
"sys.path.insert(0, f\"{REPO_ROOT}/kaggle\")\n",
|
| 336 |
+
"print(f\"OK. Repo at {REPO_ROOT}\")"
|
| 337 |
]
|
| 338 |
},
|
| 339 |
{
|
| 340 |
+
"cell_type": "code",
|
| 341 |
+
"execution_count": null,
|
| 342 |
"metadata": {},
|
| 343 |
+
"outputs": [],
|
| 344 |
"source": [
|
| 345 |
+
"# === CELL 7 β Wire Kaggle Secrets into env vars ===\n",
|
| 346 |
+
"import os\n",
|
| 347 |
+
"from kaggle_helpers import load_kaggle_secrets, kaggle_env_summary\n",
|
| 348 |
"\n",
|
| 349 |
+
"load_kaggle_secrets()\n",
|
| 350 |
+
"kaggle_env_summary()\n",
|
| 351 |
+
"\n",
|
| 352 |
+
"# Hard fail if no Groq key β training would silently use mock LLMs.\n",
|
| 353 |
+
"assert any(os.environ.get(k) for k in (\n",
|
| 354 |
+
" \"GROQ_NURSE_API_KEY\", \"GROQ_PATIENT_API_KEY\",\n",
|
| 355 |
+
" \"GROQ_EMPATHY_JUDGE_API_KEY\", \"GROQ_MEDICAL_JUDGE_API_KEY\",\n",
|
| 356 |
+
" \"GROQ_API_KEY\",\n",
|
| 357 |
+
")), (\"No Groq key found in Kaggle Secrets. \"\n",
|
| 358 |
+
" \"Add at least GROQ_NURSE_API_KEY in Add-ons -> Secrets.\")\n",
|
| 359 |
+
"print(\"OK β at least one Groq key is wired.\")"
|
| 360 |
]
|
| 361 |
},
|
| 362 |
{
|
|
|
|
| 365 |
"metadata": {},
|
| 366 |
"outputs": [],
|
| 367 |
"source": [
|
| 368 |
+
"# === CELL 8 β Hugging Face Hub config (for checkpoint backup) ===\n",
|
| 369 |
+
"import os\n",
|
| 370 |
+
"from kaggle_helpers import push_checkpoint_to_hub, download_checkpoint_from_hub\n",
|
| 371 |
+
"\n",
|
| 372 |
+
"# EDIT the line below to your HF model id (e.g. \"udayd/ermap-doctor-lora\").\n",
|
| 373 |
+
"HF_PUSH_REPO = \"<your-username>/ermap-doctor-lora\"\n",
|
| 374 |
+
"# To resume from a previous run, paste the same repo id here. Empty = fresh.\n",
|
| 375 |
+
"HF_RESUME_REPO = \"\"\n",
|
| 376 |
"\n",
|
| 377 |
"RESUME_DIR = \"/kaggle/working/checkpoints/resume\"\n",
|
| 378 |
"if HF_RESUME_REPO:\n",
|
| 379 |
" download_checkpoint_from_hub(HF_RESUME_REPO, RESUME_DIR)\n",
|
| 380 |
+
" contents = os.listdir(RESUME_DIR) if os.path.isdir(RESUME_DIR) else []\n",
|
| 381 |
+
" print(f\"Resume dir: {contents or '(empty)'}\")\n",
|
| 382 |
+
"else:\n",
|
| 383 |
+
" print(\"Starting fresh β no resume.\")\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
"\n",
|
| 385 |
+
"if \"<your-username>\" in HF_PUSH_REPO:\n",
|
| 386 |
+
" print(\"\\nWARNING: HF_PUSH_REPO still has <your-username> placeholder.\")\n",
|
| 387 |
+
" print(\" Checkpoints will NOT be pushed to HF Hub.\")\n",
|
| 388 |
+
" print(\" Edit the cell above and re-run before training if you want backups.\")"
|
| 389 |
]
|
| 390 |
},
|
| 391 |
{
|
|
|
|
| 394 |
"metadata": {},
|
| 395 |
"outputs": [],
|
| 396 |
"source": [
|
| 397 |
+
"# === CELL 9 β GRPO hyperparameters ===\n",
|
| 398 |
+
"import os\n",
|
| 399 |
+
"\n",
|
| 400 |
"MODEL_NAME = \"unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit\"\n",
|
|
|
|
| 401 |
"GROUP_SIZE = 2\n",
|
| 402 |
"LEARNING_RATE = 5e-6\n",
|
| 403 |
"KL_BETA = 0.04\n",
|
| 404 |
"OUTPUT_DIR = \"/kaggle/working/er_map_grpo_checkpoints\"\n",
|
| 405 |
"PUSH_EVERY_EPS = 20\n",
|
| 406 |
+
"USE_WANDB = False # WANDB conflicts with protobuf 7 on Kaggle base image\n",
|
| 407 |
+
"NUM_EPISODES = 200 # hard cap; early-stop usually finishes first\n",
|
| 408 |
+
"\n",
|
| 409 |
+
"# --- Per-phase reward thresholds (constant for this run) -------------------\n",
|
| 410 |
+
"# After every GRPO update we look at the last CONVERGENCE_WINDOW groups; if\n",
|
| 411 |
+
"# ALL of them belong to the same current phase AND each has\n",
|
| 412 |
+
"# rolling_avg_reward >= PHASE_REWARD_TARGETS[current_phase] AND\n",
|
| 413 |
+
"# rolling_win_rate >= PHASE_MIN_WIN_RATE, we either:\n",
|
| 414 |
+
"# - force-promote to the next phase (Phase 1 / Phase 2), OR\n",
|
| 415 |
"# - terminate training (Phase 3).\n",
|
| 416 |
+
"EARLY_STOP_ENABLED = True\n",
|
| 417 |
+
"PHASE_REWARD_TARGETS = {1: 1.2, 2: 1.1, 3: 1.0}\n",
|
| 418 |
+
"PHASE_MIN_WIN_RATE = 0.20\n",
|
| 419 |
+
"CONVERGENCE_WINDOW = 3\n",
|
| 420 |
+
"\n",
|
| 421 |
+
"# --- Per-episode budget controls (read by triage_env) ----------------------\n",
|
|
|
|
|
|
|
|
|
|
| 422 |
"os.environ[\"ERMAP_MAX_EPISODE_STEPS\"] = \"20\"\n",
|
| 423 |
"os.environ[\"ERMAP_MAX_INTERNAL_EXCHANGES\"] = \"5\"\n",
|
| 424 |
+
"\n",
|
| 425 |
+
"# --- Groq traffic-shaping (8B for actors, 70B for judges) ------------------\n",
|
| 426 |
+
"# High-volume conversational roles (Nurse + Patient) on the 8B-instant pool\n",
|
| 427 |
+
"# (500K TPD, 14,400 RPD); the two judges stay on 70B-versatile because their\n",
|
| 428 |
+
"# grading quality directly shapes the reward signal.\n",
|
|
|
|
| 429 |
"os.environ[\"ERMAP_NURSE_MODEL\"] = \"llama-3.1-8b-instant\"\n",
|
| 430 |
"os.environ[\"ERMAP_PATIENT_MODEL\"] = \"llama-3.1-8b-instant\"\n",
|
| 431 |
"os.environ[\"ERMAP_EMPATHY_JUDGE_MODEL\"] = \"llama-3.3-70b-versatile\"\n",
|
| 432 |
"os.environ[\"ERMAP_MEDICAL_JUDGE_MODEL\"] = \"llama-3.3-70b-versatile\"\n",
|
| 433 |
"\n",
|
| 434 |
+
"print(\"Hyperparameters set:\")\n",
|
| 435 |
+
"print(f\" NUM_EPISODES = {NUM_EPISODES}\")\n",
|
| 436 |
+
"print(f\" GROUP_SIZE = {GROUP_SIZE}\")\n",
|
| 437 |
+
"print(f\" PHASE_REWARD_TARGETS = {PHASE_REWARD_TARGETS}\")\n",
|
| 438 |
+
"print(f\" PHASE_MIN_WIN_RATE = {PHASE_MIN_WIN_RATE}\")\n",
|
| 439 |
+
"print(f\" CONVERGENCE_WINDOW = {CONVERGENCE_WINDOW}\")\n",
|
| 440 |
+
"print(f\" Nurse / Patient = llama-3.1-8b-instant (actors, high-volume)\")\n",
|
| 441 |
+
"print(f\" Empathy / Med Judge = llama-3.3-70b-versatile (graders, quality)\")"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 442 |
]
|
| 443 |
},
|
| 444 |
{
|
| 445 |
+
"cell_type": "code",
|
| 446 |
+
"execution_count": null,
|
| 447 |
"metadata": {},
|
| 448 |
+
"outputs": [],
|
| 449 |
"source": [
|
| 450 |
+
"# === CELL 10 β Pre-flight: Groq routing + key liveness ===\n",
|
| 451 |
+
"# Verifies that:\n",
|
| 452 |
+
"# - each role is routed to the model you set in cell 9, and\n",
|
| 453 |
+
"# - each role's Groq key actually answers a 1-token \"PING\" prompt.\n",
|
| 454 |
+
"\n",
|
| 455 |
+
"import os\n",
|
| 456 |
+
"from ER_MAP.envs.api_router import AgentRouter\n",
|
| 457 |
+
"\n",
|
| 458 |
+
"router = AgentRouter()\n",
|
| 459 |
+
"expected = {\n",
|
| 460 |
+
" \"nurse\": \"llama-3.1-8b-instant\",\n",
|
| 461 |
+
" \"patient\": \"llama-3.1-8b-instant\",\n",
|
| 462 |
+
" \"empathy_judge\": \"llama-3.3-70b-versatile\",\n",
|
| 463 |
+
" \"medical_judge\": \"llama-3.3-70b-versatile\",\n",
|
| 464 |
+
"}\n",
|
| 465 |
+
"\n",
|
| 466 |
+
"print(\"=\" * 60); print(\" PRE-FLIGHT β Groq routing + smoke test\"); print(\"=\" * 60)\n",
|
| 467 |
+
"all_pass = True\n",
|
| 468 |
+
"for role, exp in expected.items():\n",
|
| 469 |
+
" actual = router._models.get(role, \"?\")\n",
|
| 470 |
+
" routing_ok = (actual == exp)\n",
|
| 471 |
+
" client = router._clients.get(role)\n",
|
| 472 |
+
"\n",
|
| 473 |
+
" if client is None:\n",
|
| 474 |
+
" print(f\" [SKIP] {role:14s} -> no Groq client (key missing)\")\n",
|
| 475 |
+
" all_pass = False\n",
|
| 476 |
+
" continue\n",
|
| 477 |
"\n",
|
| 478 |
+
" try:\n",
|
| 479 |
+
" resp = client.chat.completions.create(\n",
|
| 480 |
+
" model=exp,\n",
|
| 481 |
+
" messages=[{\"role\": \"user\", \"content\": \"Reply with exactly: PING\"}],\n",
|
| 482 |
+
" max_tokens=4, temperature=0,\n",
|
| 483 |
+
" )\n",
|
| 484 |
+
" api_ok = \"PING\" in (resp.choices[0].message.content or \"\").upper()\n",
|
| 485 |
+
" err = \"\"\n",
|
| 486 |
+
" except Exception as e:\n",
|
| 487 |
+
" api_ok = False\n",
|
| 488 |
+
" err = f\" ({type(e).__name__}: {str(e)[:80]})\"\n",
|
| 489 |
+
"\n",
|
| 490 |
+
" flag = \"PASS\" if (routing_ok and api_ok) else \"FAIL\"\n",
|
| 491 |
+
" if flag == \"FAIL\":\n",
|
| 492 |
+
" all_pass = False\n",
|
| 493 |
+
" print(f\" [{flag}] {role:14s} -> {actual:30s} \"\n",
|
| 494 |
+
" f\"routing={'ok' if routing_ok else 'WRONG'}, \"\n",
|
| 495 |
+
" f\"api={'ok' if api_ok else 'fail'}{err}\")\n",
|
| 496 |
+
"\n",
|
| 497 |
+
"print(\"=\" * 60)\n",
|
| 498 |
+
"print(\"OK\" if all_pass else \"NOT OK β fix routing/keys before training.\")\n",
|
| 499 |
+
"print(\"=\" * 60)\n",
|
| 500 |
+
"assert all_pass, \"Pre-flight failed; do not proceed to training.\""
|
| 501 |
]
|
| 502 |
},
|
| 503 |
{
|
|
|
|
| 506 |
"metadata": {},
|
| 507 |
"outputs": [],
|
| 508 |
"source": [
|
| 509 |
+
"# === CELL 11 β Dry-run smoke test (no GPU, no model load) ===\n",
|
| 510 |
+
"# Verifies the curriculum scheduler + reward verifier + per-phase early-stop\n",
|
| 511 |
+
"# wiring before we burn GPU minutes on the real run.\n",
|
| 512 |
+
"\n",
|
| 513 |
"from ER_MAP.training.train_grpo import train\n",
|
| 514 |
"\n",
|
| 515 |
"_ = train(\n",
|
|
|
|
| 520 |
" kl_beta=KL_BETA,\n",
|
| 521 |
" output_dir=\"/kaggle/working/_dryrun\",\n",
|
| 522 |
" dry_run=True,\n",
|
| 523 |
+
" phase_reward_targets=PHASE_REWARD_TARGETS,\n",
|
| 524 |
+
" phase_min_win_rate=PHASE_MIN_WIN_RATE,\n",
|
| 525 |
+
" convergence_window=CONVERGENCE_WINDOW,\n",
|
| 526 |
+
" early_stop=EARLY_STOP_ENABLED,\n",
|
| 527 |
")\n",
|
| 528 |
+
"print(\"\\nDry-run OK β scheduler + verifier + per-phase early-stop wiring is healthy.\")"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 529 |
]
|
| 530 |
},
|
| 531 |
{
|
|
|
|
| 534 |
"metadata": {},
|
| 535 |
"outputs": [],
|
| 536 |
"source": [
|
| 537 |
+
"# === CELL 12 β Wire periodic HF Hub push into training ===\n",
|
| 538 |
+
"# We monkey-patch save_lora_adapters so every checkpoint dump also pushes\n",
|
| 539 |
+
"# the LoRA adapter to HF Hub. Failures are non-fatal β training keeps\n",
|
| 540 |
+
"# running even if a push fails (e.g. transient HF 502).\n",
|
| 541 |
+
"\n",
|
| 542 |
"from ER_MAP.training import train_grpo as _tg\n",
|
| 543 |
"_original_save = _tg.save_lora_adapters\n",
|
|
|
|
| 544 |
"\n",
|
| 545 |
"def save_lora_adapters_with_push(model, tokenizer, output_dir):\n",
|
| 546 |
" _original_save(model, tokenizer, output_dir)\n",
|
|
|
|
| 547 |
" if HF_PUSH_REPO and \"<your-username>\" not in HF_PUSH_REPO:\n",
|
| 548 |
" try:\n",
|
| 549 |
" push_checkpoint_to_hub(\n",
|
| 550 |
+
" output_dir, HF_PUSH_REPO,\n",
|
|
|
|
| 551 |
" commit_message=f\"checkpoint @ {os.path.basename(output_dir)}\",\n",
|
| 552 |
" )\n",
|
| 553 |
" except Exception as e:\n",
|
|
|
|
| 561 |
"cell_type": "markdown",
|
| 562 |
"metadata": {},
|
| 563 |
"source": [
|
| 564 |
+
"## 13 Β· Run real training (the 4β6 hour cell)\n",
|
| 565 |
+
"\n",
|
| 566 |
+
"**Estimated wall-clock on Kaggle T4 Γ2:**\n",
|
| 567 |
+
"\n",
|
| 568 |
+
"- ~3β5 min per episode (6β14 env steps Γ Doctor.generate + 4β8 Γ Groq calls)\n",
|
| 569 |
+
"- ~1β2 min amortized per GRPO update (G=2 trajectories Γ response-token log-probs)\n",
|
| 570 |
+
"- **Per-group β 8β12 min** (2 episodes + 1 update)\n",
|
| 571 |
+
"\n",
|
| 572 |
+
"| Phase | Typical episodes to reach target | Wall-clock |\n",
|
| 573 |
+
"|---|---|---|\n",
|
| 574 |
+
"| 1 (target `+1.2` Γ 3) | 12 β 24 episodes (6 β 12 groups) | ~1.0 β 2.0 h |\n",
|
| 575 |
+
"| 2 (target `+1.1` Γ 3) | 16 β 32 episodes (8 β 16 groups) | ~1.5 β 2.5 h |\n",
|
| 576 |
+
"| 3 (target `+1.0` Γ 3) | 20 β 50 episodes (10 β 25 groups) | ~2.0 β 4.0 h |\n",
|
| 577 |
+
"| **Total** | 50 β 100 episodes | **~4.5 β 8.5 h** |\n",
|
| 578 |
+
"\n",
|
| 579 |
+
"If `NUM_EPISODES=200` is exhausted before Phase 3 converges, training\n",
|
| 580 |
+
"stops at the cap and the latest LoRA checkpoint is on HF Hub already\n",
|
| 581 |
+
"(we push every 20 episodes), so resume in a fresh session via\n",
|
| 582 |
+
"`HF_RESUME_REPO` in cell 8."
|
| 583 |
]
|
| 584 |
},
|
| 585 |
{
|
|
|
|
| 588 |
"metadata": {},
|
| 589 |
"outputs": [],
|
| 590 |
"source": [
|
| 591 |
+
"# === CELL 13 β REAL TRAINING (4-6 h cell) ===\n",
|
| 592 |
"metrics = train(\n",
|
| 593 |
" num_episodes=NUM_EPISODES,\n",
|
| 594 |
" group_size=GROUP_SIZE,\n",
|
| 595 |
" model_name=MODEL_NAME,\n",
|
| 596 |
+
" groq_api_key=os.environ.get(\"GROQ_NURSE_API_KEY\", \"\")\n",
|
| 597 |
+
" or os.environ.get(\"GROQ_API_KEY\", \"\"),\n",
|
| 598 |
" learning_rate=LEARNING_RATE,\n",
|
| 599 |
" kl_beta=KL_BETA,\n",
|
| 600 |
" use_wandb=USE_WANDB,\n",
|
| 601 |
" output_dir=OUTPUT_DIR,\n",
|
| 602 |
" dry_run=False,\n",
|
|
|
|
| 603 |
" phase_reward_targets=PHASE_REWARD_TARGETS,\n",
|
| 604 |
" phase_min_win_rate=PHASE_MIN_WIN_RATE,\n",
|
| 605 |
" convergence_window=CONVERGENCE_WINDOW,\n",
|
| 606 |
" early_stop=EARLY_STOP_ENABLED,\n",
|
| 607 |
+
")\n",
|
| 608 |
+
"print(f\"\\nTraining returned {len(metrics)} metric records.\")"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 609 |
]
|
| 610 |
},
|
| 611 |
{
|
|
|
|
| 614 |
"metadata": {},
|
| 615 |
"outputs": [],
|
| 616 |
"source": [
|
| 617 |
+
"# === CELL 14 β Final push: adapters + merged fp16 ===\n",
|
| 618 |
"FINAL_LORA_DIR = f\"{OUTPUT_DIR}/final_lora\"\n",
|
| 619 |
"FINAL_MERGED_DIR = f\"{OUTPUT_DIR}/final_merged_fp16\"\n",
|
| 620 |
"\n",
|
| 621 |
"if HF_PUSH_REPO and \"<your-username>\" not in HF_PUSH_REPO:\n",
|
| 622 |
+
" push_checkpoint_to_hub(FINAL_LORA_DIR, HF_PUSH_REPO,\n",
|
| 623 |
+
" commit_message=\"final LoRA adapter\")\n",
|
|
|
|
|
|
|
| 624 |
" if os.path.isdir(FINAL_MERGED_DIR):\n",
|
| 625 |
+
" push_checkpoint_to_hub(FINAL_MERGED_DIR, f\"{HF_PUSH_REPO}-merged\",\n",
|
| 626 |
+
" commit_message=\"final merged fp16\")\n",
|
| 627 |
+
" print(f\"Final checkpoints pushed: https://huggingface.co/{HF_PUSH_REPO}\")\n",
|
|
|
|
|
|
|
| 628 |
"else:\n",
|
| 629 |
" print(\"HF_PUSH_REPO not configured β skipping final push.\")"
|
| 630 |
]
|
|
|
|
| 633 |
"cell_type": "markdown",
|
| 634 |
"metadata": {},
|
| 635 |
"source": [
|
| 636 |
+
"## 15 Β· Per-phase training graphs (one dashboard per curriculum phase)\n",
|
| 637 |
+
"\n",
|
| 638 |
+
"We render a 6-panel dashboard for **every phase that contains episodes**,\n",
|
| 639 |
+
"plus a cross-phase overview and a phase-comparison bar chart. All PNGs are\n",
|
| 640 |
+
"written to `er_map_grpo_checkpoints/plots/` and uploaded to HF Hub in the\n",
|
| 641 |
+
"next cell so they survive Kaggle session expiry.\n",
|
| 642 |
"\n",
|
| 643 |
+
"Each per-phase dashboard contains:\n",
|
| 644 |
"\n",
|
|
|
|
| 645 |
"1. **Reward growth** β raw scatter + rolling mean (w=10) + verified rolling mean\n",
|
| 646 |
+
"2. **Rolling win rate** β w=20 win-rate evolution within the phase\n",
|
| 647 |
+
"3. **Outcome distribution over time** β stacked bars (WIN/PARTIAL/INCORRECT/AMA_LOSS/FATAL_LOSS)\n",
|
| 648 |
+
"4. **Reward components** β mean of each component (process / treatment / empathy / labs / etc.)\n",
|
| 649 |
+
"5. **GRPO update stats** β loss + KL divergence per group update\n",
|
| 650 |
"6. **Episode length distribution** β histogram of step counts"
|
| 651 |
]
|
| 652 |
},
|
|
|
|
| 656 |
"metadata": {},
|
| 657 |
"outputs": [],
|
| 658 |
"source": [
|
| 659 |
+
"# === CELL 15 β Per-phase training dashboards ===\n",
|
| 660 |
"from ER_MAP.plotting import plot_per_phase_dashboards\n",
|
| 661 |
"from IPython.display import Image, display, Markdown\n",
|
| 662 |
"\n",
|
|
|
|
| 666 |
" output_dir=PLOTS_DIR,\n",
|
| 667 |
")\n",
|
| 668 |
"\n",
|
| 669 |
+
"print(f\"Saved {len(written)} chart(s) to {PLOTS_DIR}:\")\n",
|
| 670 |
"for name, path in written.items():\n",
|
| 671 |
" size_kb = os.path.getsize(path) / 1024\n",
|
| 672 |
" print(f\" {name:<28s} -> {path} ({size_kb:.0f} KB)\")\n",
|
| 673 |
"\n",
|
| 674 |
+
"# Display each chart inline so the operator sees them without leaving Kaggle.\n",
|
| 675 |
+
"ordered = (sorted(k for k in written if k.startswith(\"phase\"))\n",
|
| 676 |
+
" + [\"all_phases_overview\", \"all_phases_comparison\"])\n",
|
| 677 |
+
"for key in ordered:\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 678 |
" if key not in written:\n",
|
| 679 |
" continue\n",
|
| 680 |
" display(Markdown(f\"### {key.replace('_', ' ').title()}\"))\n",
|
| 681 |
" display(Image(filename=written[key]))"
|
| 682 |
]
|
| 683 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 684 |
{
|
| 685 |
"cell_type": "code",
|
| 686 |
"execution_count": null,
|
| 687 |
"metadata": {},
|
| 688 |
"outputs": [],
|
| 689 |
"source": [
|
| 690 |
+
"# === CELL 16 β Push plots to HF Hub ===\n",
|
| 691 |
"if HF_PUSH_REPO and \"<your-username>\" not in HF_PUSH_REPO:\n",
|
| 692 |
+
" push_checkpoint_to_hub(PLOTS_DIR, HF_PUSH_REPO,\n",
|
| 693 |
+
" commit_message=\"per-phase training plots\")\n",
|
| 694 |
+
" print(f\"Plots pushed: https://huggingface.co/{HF_PUSH_REPO}/tree/main\")\n",
|
|
|
|
| 695 |
"else:\n",
|
| 696 |
+
" print(\"HF_PUSH_REPO not configured β plots stay only in /kaggle/working/.\")"
|
| 697 |
]
|
| 698 |
},
|
| 699 |
{
|
| 700 |
"cell_type": "markdown",
|
| 701 |
"metadata": {},
|
| 702 |
"source": [
|
| 703 |
+
"## 17 Β· (Optional) Inference smoke-test on the trained model\n",
|
| 704 |
"\n",
|
| 705 |
+
"Catches the classic 'merge path looked OK but the saved model emits garbage'\n",
|
| 706 |
+
"failure mode before the demo."
|
| 707 |
]
|
| 708 |
},
|
| 709 |
{
|
|
|
|
| 712 |
"metadata": {},
|
| 713 |
"outputs": [],
|
| 714 |
"source": [
|
| 715 |
+
"# === CELL 17 β Inference smoke-test on the trained model ===\n",
|
| 716 |
"from ER_MAP.training.train_grpo import generate_doctor_action, load_model_and_tokenizer\n",
|
| 717 |
"from peft import PeftModel\n",
|
| 718 |
"\n",
|
|
|
|
| 743 |
},
|
| 744 |
"nbformat": 4,
|
| 745 |
"nbformat_minor": 5
|
| 746 |
+
}
|