Multi-Agentic / kaggle /build_notebook.py
Uddiii's picture
feat: add support for lowercase Hugging Face Space secrets
63726b6
"""
kaggle/build_notebook.py
========================
Programmatically (re)builds `train_ermap_grpo_kaggle.ipynb` from scratch.
Why a builder script?
--------------------
The hand-edited notebook drifted into a fragile state across many sessions:
mixed early-stop / fixed-budget params, stale install snippets, dead pre-flight
checks, etc. This script is the single source of truth β€” run it once and the
notebook is regenerated as a clean, deterministic v3 layout.
Run:
python kaggle/build_notebook.py
Output:
kaggle/train_ermap_grpo_kaggle.ipynb (overwritten)
kaggle/KAGGLE_QUICKSTART.md (overwritten)
"""
from __future__ import annotations
import json
import textwrap
from pathlib import Path
# ---------------------------------------------------------------------------
# Cell helpers
# ---------------------------------------------------------------------------
def md_cell(text: str) -> dict:
return {
"cell_type": "markdown",
"metadata": {},
"source": _split_keep_newlines(text),
}
def code_cell(text: str) -> dict:
return {
"cell_type": "code",
"execution_count": None,
"metadata": {},
"outputs": [],
"source": _split_keep_newlines(text),
}
def _split_keep_newlines(text: str) -> list[str]:
"""Notebook 'source' fields expect each line to terminate with '\n'
except the last one. Splitting like this keeps `git diff` clean when
the notebook is regenerated."""
text = textwrap.dedent(text).lstrip("\n")
if not text.endswith("\n"):
text = text + "\n"
lines = text.splitlines(keepends=True)
if lines:
# The last line should NOT have a trailing newline (Jupyter convention).
if lines[-1].endswith("\n"):
lines[-1] = lines[-1].rstrip("\n")
return lines
# ---------------------------------------------------------------------------
# Cell sources
# ---------------------------------------------------------------------------
CELL_01_TITLE = """\
# ER-MAP β€” Doctor Agent GRPO Training (Kaggle Free-Tier Β· v3 stable)
Trains the **Doctor LLM** (Llama-3.1-8B-Instruct, 4-bit + LoRA r=16) via GRPO
with a 3-phase curriculum on Kaggle's free GPU. Designed to survive Kaggle's
pre-baked image quirks (numpy / Pillow ABI mismatches, torch + torchvision
CUDA-major mismatches, transient `unsloth_zoo` upgrades).
## TL;DR β€” How to run this notebook
1. **Notebook settings (right sidebar):**
- Accelerator: **GPU T4 Γ—2** (or P100)
- Internet: **On**
- Persistence: Files only
2. **Kaggle Secrets** (Add-ons β†’ Secrets):
- **Required:** `GROQ_NURSE_API_KEY`, `GROQ_PATIENT_API_KEY`,
`GROQ_EMPATHY_JUDGE_API_KEY`, `GROQ_MEDICAL_JUDGE_API_KEY`, `HF_TOKEN`
- **Optional:** `WANDB_API_KEY`
3. **Run cells 2 β†’ 3 (sanity + REPAIR).** When cell 3 prints
`RESTART REQUIRED`, click **Run β†’ Restart kernel**, then resume from cell 5.
4. **Run cells 5 β†’ 11 (verify + configure + dry-run + pre-flight).** Each cell
should print an `OK` line before moving on.
5. **Run cell 13 (the long training cell, 4–6 hours).**
6. **Run cells 14 β†’ 17 (final push + plots + inference smoke-test).**
## Curriculum + reward thresholds (this run)
Constant per-phase rolling-avg-reward bars; sustained for **3 consecutive
GRPO groups** triggers either a phase promotion or end-of-training.
| Phase | Reward target (sustained Γ—3 groups) | Action when met |
|---|---|---|
| 1 β€” Tool Mastery | `+1.2` | force-promote to Phase 2 |
| 2 β€” Clinical Reasoning | `+1.1` | force-promote to Phase 3 |
| 3 β€” Empathetic Negotiation | `+1.0` | END TRAINING |
Why these numbers? The un-trained 8B Doctor's baseline on the same env is
`P1=+0.76, P2=+0.59, P3=+0.39`. Targets of `+1.2 / +1.1 / +1.0` correspond
to roughly `1.6Γ— / 1.9Γ— / 2.6Γ—` improvement over baseline β€” a meaningful
signal but reachable inside Kaggle's 12 h session limit.
"""
CELL_02_SANITY = """\
# === CELL 2 β€” Sanity check (GPU + disk + python + internet) ===
# Run this FIRST. If any check fails, fix it before running the REPAIR cell.
import os, shutil, subprocess, sys, socket
print("--- GPU ---")
try:
print(subprocess.check_output(
["nvidia-smi", "--query-gpu=name,memory.total,memory.free", "--format=csv"],
timeout=10,
).decode())
except Exception as e:
print(f"nvidia-smi failed: {e}")
print("-> Set Accelerator to 'GPU T4 x2' in the right sidebar.")
print("--- Disk (/kaggle/working) ---")
total, used, free = shutil.disk_usage("/kaggle/working")
print(f" total={total/1e9:5.1f} GB | used={used/1e9:5.1f} GB | free={free/1e9:5.1f} GB")
if free < 8 * 1e9:
print(" WARNING: free disk < 8 GB β€” repair cell may fail. "
"Consider 'Run > Restart and clear cell outputs' to reset /tmp.")
print("--- Python ---")
print(f" python={sys.version.split()[0]} | exe={sys.executable}")
print("--- Internet (api.groq.com:443) ---")
try:
socket.create_connection(("api.groq.com", 443), timeout=5).close()
print(" reachable")
except Exception as e:
print(f" UNREACHABLE: {e}")
print(" -> Settings (right sidebar) -> Internet -> ON")
"""
CELL_03_REPAIR = """\
# === CELL 3 β€” REPAIR CELL (idempotent full environment rebuild) ===
# Single source of truth for ER-MAP's GPU stack. Safe to re-run. After it
# finishes you'll see one of two final lines:
#
# RESTART REQUIRED -> Run -> Restart kernel, then resume from cell 5
# REPAIR OK -> proceed directly to cell 5
#
# Note: this cell only runs shell commands and one isolated subprocess.
# It deliberately does NOT `import torch / numpy / Pillow / unsloth` in the
# kernel, so re-running it after a botched install does not poison further
# attempts.
print("=" * 72); print(" CELL 3 β€” REPAIR"); print("=" * 72)
# 1. Clean caches (Kaggle's /kaggle/working is only 20 GB β€” installs
# routinely fill it after a few re-runs).
print("[1/6] Cleaning pip + tmp + HF dataset caches...")
get_ipython().system('pip cache purge -q || true')
get_ipython().system('rm -rf /tmp/* /root/.cache/pip /root/.cache/huggingface/datasets 2>/dev/null || true')
# 2. Pin torch + torchvision to the cu128 wheel (matches Kaggle's CUDA 12.8
# base image). DON'T let pip pull a generic CUDA-13 build β€” that breaks
# bitsandbytes (libnvJitLink.so.13 missing) and torchvision (CUDA-major
# mismatch RuntimeError at import time).
print("[2/6] Installing torch==2.10.0 + torchvision==0.25.0 (cu128)...")
get_ipython().system('pip install -q --no-cache-dir --force-reinstall '
'torch==2.10.0 torchvision==0.25.0 '
'--index-url https://download.pytorch.org/whl/cu128')
# Write a pip constraints file so subsequent installs (bnb, unsloth, trl, etc.)
# can NEVER pull a different torch from default PyPI. Without this, step 3's
# `--force-reinstall bitsandbytes` and step 4's `unsloth` upgrade re-resolve
# torch from PyPI (currently 2.11.0), which breaks the cu128 torchvision pair.
#
# Also pin numpy to whatever Kaggle's kernel already has loaded β€” Kaggle's
# image puts numpy at /usr/lib/python3/dist-packages while pip writes to
# /usr/local/lib/python3.12/dist-packages, so any version drift between the
# two paths trips unsloth_zoo's strict loaded-vs-installed check at import.
import subprocess as _sp
_kernel_numpy = _sp.check_output(
[sys.executable, "-c", "import numpy; print(numpy.__version__)"],
text=True,
).strip()
print(f" detected kernel numpy = {_kernel_numpy} (will pin)")
with open("/tmp/ermap_constraints.txt", "w") as _cf:
_cf.write(f"torch==2.10.0\\ntorchvision==0.25.0\\nnumpy=={_kernel_numpy}\\n")
# 3. Reinstall bitsandbytes against the now-pinned torch.
# --no-deps because bnb just needs torch at RUNTIME (it dlopens torch's
# C++ lib) β€” its install-time deps don't include torch.
print("[3/6] Reinstalling bitsandbytes (--no-deps to preserve torch)...")
get_ipython().system('pip install -q --no-cache-dir --force-reinstall --no-deps bitsandbytes')
# 4. Upgrade unsloth + unsloth_zoo + trl in lockstep. unsloth and
# unsloth_zoo are released as a matched pair; if pip pulls a fresh
# unsloth_zoo against an old unsloth you get
# ImportError: cannot import name 'create_gradient_checkpointing_buffer'
# The constraint file blocks them from moving torch.
print("[4/6] Upgrading unsloth + unsloth_zoo + trl (constrained)...")
get_ipython().system('pip install -q --upgrade --no-cache-dir '
'-c /tmp/ermap_constraints.txt '
'unsloth unsloth_zoo "trl>=0.18.2"')
# 5. ER-MAP runtime deps that aren't pre-installed on Kaggle.
print("[5/6] Installing ER-MAP runtime deps (constrained)...")
get_ipython().system('pip install -q --no-cache-dir '
'-c /tmp/ermap_constraints.txt '
'"groq>=0.18.0" "huggingface_hub>=0.25.0" '
'"gymnasium>=0.29.0" "openenv-core>=0.1.0"')
# 5b. Realign the pip-managed numpy with whatever the Kaggle kernel actually
# has loaded. This force-rewrites /usr/local/lib/.../numpy at the exact
# version reported by the running interpreter, so importlib.metadata
# and `numpy.__version__` agree even if Kaggle ships its base numpy at
# a different dist-packages path.
print(f"[5b/6] Realigning pip-managed numpy to {_kernel_numpy}...")
get_ipython().system(
f'pip install -q --force-reinstall --no-deps "numpy=={_kernel_numpy}"'
)
# 6. Verify in a SUBPROCESS (so the parent kernel never imports any of these
# while pip is mid-flight, which is what causes the
# 'numpy was upgraded mid-session (loaded: X, installed: Y)' RuntimeError
# we kept hitting before).
print("[6/6] Verifying via subprocess...")
import subprocess, sys, json
verify_script = r'''
import json, sys
out = {"ok": True, "details": {}, "errors": []}
try:
import importlib.metadata as md
for pkg in ("torch", "torchvision", "bitsandbytes", "unsloth", "unsloth_zoo",
"trl", "transformers", "peft", "accelerate", "groq",
"huggingface_hub", "gymnasium", "numpy", "Pillow"):
try:
out["details"][pkg + "_installed"] = md.version(pkg)
except md.PackageNotFoundError:
out["details"][pkg + "_installed"] = None
import torch, torchvision, numpy as np, PIL, unsloth, unsloth_zoo, bitsandbytes, trl
out["details"]["torch_loaded"] = torch.__version__
out["details"]["torch_cuda"] = torch.version.cuda
out["details"]["cuda_available"] = bool(torch.cuda.is_available())
out["details"]["gpu_count"] = int(torch.cuda.device_count())
out["details"]["torchvision_loaded"] = torchvision.__version__
out["details"]["numpy_loaded"] = np.__version__
out["details"]["pillow_loaded"] = PIL.__version__
out["details"]["unsloth_loaded"] = unsloth.__version__
out["details"]["unsloth_zoo_loaded"] = unsloth_zoo.__version__
out["details"]["bitsandbytes_loaded"] = bitsandbytes.__version__
out["details"]["trl_loaded"] = trl.__version__
# Cross-check loaded-vs-installed for the C-extension libs that bit us
# on every previous run.
for pkg, loaded_key, installed_key in [
("numpy", "numpy_loaded", "numpy_installed"),
("Pillow", "pillow_loaded", "Pillow_installed"),
("torch", "torch_loaded", "torch_installed"),
]:
loaded = out["details"].get(loaded_key)
installed = out["details"].get(installed_key)
if loaded and installed and loaded != installed:
# Strip any local-version suffix (e.g. '+cu128') before compare.
if loaded.split("+")[0] != installed.split("+")[0]:
out["errors"].append(
f"{pkg} mismatch: loaded={loaded} installed={installed}"
)
except Exception as e:
out["ok"] = False
out["errors"].append(f"{type(e).__name__}: {e}")
print(json.dumps(out, default=str))
'''.lstrip()
res = subprocess.run([sys.executable, "-c", verify_script],
capture_output=True, text=True, timeout=180)
print(res.stdout if res.stdout else "<no stdout>")
if res.stderr:
print("---- subprocess stderr ----"); print(res.stderr)
# Parse the LAST line of stdout (others are prints from package init).
try:
last = res.stdout.strip().splitlines()[-1]
parsed = json.loads(last)
except Exception:
parsed = {"ok": False, "errors": ["could not parse verification output"]}
ok = parsed.get("ok") and not parsed.get("errors")
d = parsed.get("details", {})
print("\\n" + "=" * 72)
if ok:
print(" REPAIR OK")
print(f" torch : {d.get('torch_loaded')} (CUDA {d.get('torch_cuda')})")
print(f" torchvision : {d.get('torchvision_loaded')}")
print(f" bitsandbytes: {d.get('bitsandbytes_loaded')}")
print(f" unsloth : {d.get('unsloth_loaded')} | unsloth_zoo: {d.get('unsloth_zoo_loaded')}")
print(f" trl : {d.get('trl_loaded')}")
print(f" numpy : {d.get('numpy_loaded')} | Pillow: {d.get('pillow_loaded')}")
print(f" GPUs : {d.get('gpu_count')} (cuda_available={d.get('cuda_available')})")
print()
print(" -> If this kernel previously imported torch/numpy/Pillow/unsloth,")
print(" RESTART NOW (Run -> Restart kernel) before continuing to cell 5.")
print(" If this is a fresh kernel, you can proceed directly.")
else:
print(" RESTART REQUIRED β€” issues detected:")
for e in parsed.get("errors", []):
print(f" - {e}")
print()
print(" Action: Run -> Restart kernel, then re-run from cell 2.")
print("=" * 72)
"""
CELL_04_RESTART = """\
## ⚠ Restart kernel here if cell 3 said `RESTART REQUIRED`
Click **Run β†’ Restart kernel** (or **Run β†’ Restart & clear cell outputs**),
then resume from **cell 5**. Skipping the restart will produce ABI mismatch
errors at the first GPU op.
If cell 3 said `REPAIR OK` AND this is a fresh kernel that hasn't imported
torch/numpy/Pillow/unsloth yet, you can proceed to cell 5 directly.
"""
CELL_05_VERIFY = """\
# === CELL 5 β€” Post-restart verify (this kernel can import everything) ===
import importlib.metadata as md
print("--- Loaded versions in this kernel ---")
import torch, numpy, PIL, torchvision, unsloth, unsloth_zoo, bitsandbytes, trl, transformers, peft
versions = {
"torch": torch.__version__,
"torchvision": torchvision.__version__,
"numpy": numpy.__version__,
"Pillow": PIL.__version__,
"unsloth": unsloth.__version__,
"unsloth_zoo": unsloth_zoo.__version__,
"bitsandbytes": bitsandbytes.__version__,
"trl": trl.__version__,
"transformers": transformers.__version__,
"peft": peft.__version__,
}
all_ok = True
for k, v in versions.items():
try:
inst = md.version(k)
except md.PackageNotFoundError:
inst = "(not installed)"
# Tolerate local version suffixes like '+cu128'
flag = "OK" if inst.split("+")[0] == v.split("+")[0] else f"MISMATCH (installed={inst})"
if "MISMATCH" in flag:
all_ok = False
print(f" {k:14s}: loaded={v:20s} [{flag}]")
print()
print(f" CUDA available : {torch.cuda.is_available()}")
print(f" GPU count : {torch.cuda.device_count()}")
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
p = torch.cuda.get_device_properties(i)
print(f" GPU {i} : {p.name} ({p.total_memory/1e9:.1f} GB)")
print()
print("OK" if all_ok else "NOT OK β€” re-run cell 3 and restart kernel.")
"""
CELL_06_REPO = """\
# === CELL 6 β€” Mount the ER-MAP repo into /kaggle/working ===
import os, subprocess, sys
# OPTION A: clone a public GitHub fork (preferred). Edit GIT_URL.
GIT_URL = "https://github.com/<your-fork>/Meta_Finals.git"
BRANCH = "main"
REPO_ROOT = "/kaggle/working/Meta_Finals"
# OPTION B: Kaggle Dataset upload β€” set this if you uploaded the repo
# as a Kaggle Dataset named "ermap-source" (Add Data -> Upload).
DATASET_DIR = "/kaggle/input/ermap-source"
if not os.path.isdir(f"{REPO_ROOT}/ER_MAP"):
if "<your-fork>" not in GIT_URL:
print(f"Cloning {GIT_URL}@{BRANCH} -> {REPO_ROOT}...")
out = subprocess.run(
["git", "clone", "--depth", "1", "-b", BRANCH, GIT_URL, REPO_ROOT],
capture_output=True, text=True,
)
print(out.stdout); print(out.stderr)
elif os.path.isdir(DATASET_DIR):
print(f"Copying {DATASET_DIR} -> {REPO_ROOT}...")
import shutil
shutil.copytree(DATASET_DIR, REPO_ROOT, dirs_exist_ok=True)
assert os.path.isdir(f"{REPO_ROOT}/ER_MAP"), (
"Repo not found.\\n"
" - Edit GIT_URL above to your GitHub fork, OR\\n"
" - Upload the repo as a Kaggle Dataset named 'ermap-source' (Add Data -> Upload)."
)
sys.path.insert(0, REPO_ROOT)
sys.path.insert(0, f"{REPO_ROOT}/kaggle")
print(f"OK. Repo at {REPO_ROOT}")
"""
CELL_07_SECRETS = """\
# === CELL 7 β€” Wire Kaggle Secrets into env vars ===
import os
from kaggle_helpers import load_kaggle_secrets, kaggle_env_summary
load_kaggle_secrets()
kaggle_env_summary()
# Hard fail if no Groq key β€” training would silently use mock LLMs.
assert any(os.environ.get(k) for k in (
"GROQ_NURSE_API_KEY", "GROQ_PATIENT_API_KEY",
"GROQ_EMPATHY_JUDGE_API_KEY", "GROQ_MEDICAL_JUDGE_API_KEY",
"GROQ_API_KEY",
)), ("No Groq key found in Kaggle Secrets. "
"Add at least GROQ_NURSE_API_KEY in Add-ons -> Secrets.")
print("OK β€” at least one Groq key is wired.")
"""
CELL_08_HF = """\
# === CELL 8 β€” Hugging Face Hub config (for checkpoint backup) ===
import os
from kaggle_helpers import push_checkpoint_to_hub, download_checkpoint_from_hub
# EDIT the line below to your HF model id (e.g. "udayd/ermap-doctor-lora").
HF_PUSH_REPO = "<your-username>/ermap-doctor-lora"
# To resume from a previous run, paste the same repo id here. Empty = fresh.
HF_RESUME_REPO = ""
RESUME_DIR = "/kaggle/working/checkpoints/resume"
if HF_RESUME_REPO:
download_checkpoint_from_hub(HF_RESUME_REPO, RESUME_DIR)
contents = os.listdir(RESUME_DIR) if os.path.isdir(RESUME_DIR) else []
print(f"Resume dir: {contents or '(empty)'}")
else:
print("Starting fresh β€” no resume.")
if "<your-username>" in HF_PUSH_REPO:
print("\\nWARNING: HF_PUSH_REPO still has <your-username> placeholder.")
print(" Checkpoints will NOT be pushed to HF Hub.")
print(" Edit the cell above and re-run before training if you want backups.")
"""
CELL_09_HYPERPARAMS = """\
# === CELL 9 β€” GRPO hyperparameters ===
import os
MODEL_NAME = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
GROUP_SIZE = 2
LEARNING_RATE = 5e-6
# KL_BETA = 0.0 -> SKIP reference model load (saves ~5 GB VRAM on T4).
# T4 only has 15 GB and one 4-bit Llama-3.1-8B + LoRA + activations + gradients
# already eats ~10 GB; loading a second 4-bit copy as the KL reference OOMs
# the GRPO backward pass. Set this to 0.04 only if you've upgraded to an
# A100 / H100 with >= 24 GB VRAM.
KL_BETA = 0.0
OUTPUT_DIR = "/kaggle/working/er_map_grpo_checkpoints"
PUSH_EVERY_EPS = 20
USE_WANDB = False # WANDB conflicts with protobuf 7 on Kaggle base image
# --- Curriculum mode: FIXED-BUDGET (recommended for Kaggle T4) -------------
# A fixed per-phase episode budget gives you a clean, predictable reward-
# growth curve and bounds your wall-clock. With GROUP_SIZE=2 below and an
# observed ~3 min / episode (Groq-dominated), 100 episodes β‰ˆ 5 hours.
#
# Set PHASE_EPISODE_BUDGETS to None to fall back to early-stopping mode,
# which terminates each phase the moment its reward target is hit (faster
# but less predictable). When PHASE_EPISODE_BUDGETS is set, EARLY_STOP_ENABLED
# is automatically forced to False inside train() β€” the reward targets below
# become observational only (logged on the plots, not used for promotion).
PHASE_EPISODE_BUDGETS = {1: 20, 2: 25, 3: 30} # 20 + 25 + 30 = 75 episodes
NUM_EPISODES = sum(PHASE_EPISODE_BUDGETS.values()) # = 75
# --- Per-phase reward thresholds (observational under fixed-budget) --------
# Plotted as horizontal target lines on the reward-growth chart so you can
# see at a glance whether each phase actually crossed its target.
EARLY_STOP_ENABLED = False # ignored when PHASE_EPISODE_BUDGETS is set
PHASE_REWARD_TARGETS = {1: 1.2, 2: 1.1, 3: 1.0}
PHASE_MIN_WIN_RATE = 0.20
CONVERGENCE_WINDOW = 3
# --- Per-episode budget controls (read by triage_env) ----------------------
os.environ["ERMAP_MAX_EPISODE_STEPS"] = "20"
os.environ["ERMAP_MAX_INTERNAL_EXCHANGES"] = "5"
# --- Groq traffic-shaping (8B for actors, 70B for judges) ------------------
# High-volume conversational roles (Nurse + Patient) on the 8B-instant pool
# (500K TPD, 14,400 RPD); the two judges stay on 70B-versatile because their
# grading quality directly shapes the reward signal.
os.environ["ERMAP_NURSE_MODEL"] = "llama-3.1-8b-instant"
os.environ["ERMAP_PATIENT_MODEL"] = "llama-3.1-8b-instant"
os.environ["ERMAP_EMPATHY_JUDGE_MODEL"] = "llama-3.3-70b-versatile"
os.environ["ERMAP_MEDICAL_JUDGE_MODEL"] = "llama-3.3-70b-versatile"
print("Hyperparameters set:")
print(f" NUM_EPISODES = {NUM_EPISODES}")
print(f" GROUP_SIZE = {GROUP_SIZE}")
print(f" PHASE_EPISODE_BUDGETS = {PHASE_EPISODE_BUDGETS}")
print(f" PHASE_REWARD_TARGETS = {PHASE_REWARD_TARGETS} (observational)")
print(f" EARLY_STOP_ENABLED = {EARLY_STOP_ENABLED} (forced off under fixed budget)")
print(f" KL_BETA = {KL_BETA} (0.0 -> skip ref model, T4-safe)")
print(f" Nurse / Patient = llama-3.1-8b-instant (actors, high-volume)")
print(f" Empathy / Med Judge = llama-3.3-70b-versatile (graders, quality)")
"""
CELL_10_PREFLIGHT = """\
# === CELL 10 β€” Pre-flight: Groq routing + key liveness ===
# Verifies that:
# - each role is routed to the model you set in cell 9, and
# - each role's Groq key actually answers a 1-token "PING" prompt.
import os
from ER_MAP.envs.api_router import AgentRouter
router = AgentRouter()
expected = {
"nurse": "llama-3.1-8b-instant",
"patient": "llama-3.1-8b-instant",
"empathy_judge": "llama-3.3-70b-versatile",
"medical_judge": "llama-3.3-70b-versatile",
}
print("=" * 60); print(" PRE-FLIGHT β€” Groq routing + smoke test"); print("=" * 60)
all_pass = True
for role, exp in expected.items():
actual = router._models.get(role, "?")
routing_ok = (actual == exp)
client = router._clients.get(role)
if client is None:
print(f" [SKIP] {role:14s} -> no Groq client (key missing)")
all_pass = False
continue
try:
resp = client.chat.completions.create(
model=exp,
messages=[{"role": "user", "content": "Reply with exactly: PING"}],
max_tokens=4, temperature=0,
)
api_ok = "PING" in (resp.choices[0].message.content or "").upper()
err = ""
except Exception as e:
api_ok = False
err = f" ({type(e).__name__}: {str(e)[:80]})"
flag = "PASS" if (routing_ok and api_ok) else "FAIL"
if flag == "FAIL":
all_pass = False
print(f" [{flag}] {role:14s} -> {actual:30s} "
f"routing={'ok' if routing_ok else 'WRONG'}, "
f"api={'ok' if api_ok else 'fail'}{err}")
print("=" * 60)
print("OK" if all_pass else "NOT OK β€” fix routing/keys before training.")
print("=" * 60)
assert all_pass, "Pre-flight failed; do not proceed to training."
"""
CELL_11_DRYRUN = """\
# === CELL 11 β€” Dry-run smoke test (no GPU, no model load) ===
# Verifies the curriculum scheduler + reward verifier + per-phase early-stop
# wiring before we burn GPU minutes on the real run.
from ER_MAP.training.train_grpo import train
_ = train(
num_episodes=8,
group_size=2,
model_name=MODEL_NAME,
learning_rate=LEARNING_RATE,
kl_beta=KL_BETA,
output_dir="/kaggle/working/_dryrun",
dry_run=True,
phase_reward_targets=PHASE_REWARD_TARGETS,
phase_min_win_rate=PHASE_MIN_WIN_RATE,
convergence_window=CONVERGENCE_WINDOW,
early_stop=EARLY_STOP_ENABLED,
)
print("\\nDry-run OK β€” scheduler + verifier + per-phase early-stop wiring is healthy.")
"""
CELL_12_HOOK = """\
# === CELL 12 β€” Wire periodic HF Hub push into training ===
# We monkey-patch save_lora_adapters so every checkpoint dump also pushes
# the LoRA adapter to HF Hub. Failures are non-fatal β€” training keeps
# running even if a push fails (e.g. transient HF 502).
from ER_MAP.training import train_grpo as _tg
_original_save = _tg.save_lora_adapters
def save_lora_adapters_with_push(model, tokenizer, output_dir):
_original_save(model, tokenizer, output_dir)
if HF_PUSH_REPO and "<your-username>" not in HF_PUSH_REPO:
try:
push_checkpoint_to_hub(
output_dir, HF_PUSH_REPO,
commit_message=f"checkpoint @ {os.path.basename(output_dir)}",
)
except Exception as e:
print(f" [hub-push] non-fatal failure: {e}")
_tg.save_lora_adapters = save_lora_adapters_with_push
print("Hub-push hook installed.")
"""
CELL_13_TRAIN_MD = """\
## 13 Β· Run real training (fixed-budget curriculum, ~5 hours)
**Mode:** fixed-budget (`PHASE_EPISODE_BUDGETS = {1: 20, 2: 30, 3: 50}` in cell 9).
The reward thresholds in `PHASE_REWARD_TARGETS` are **observational only** β€”
they're plotted as horizontal target lines on the reward curve, but they do
NOT cause early termination. Each phase runs its full episode budget so the
final reward-growth chart shows clean, monotone progression.
**Estimated wall-clock on Kaggle T4 (Γ—1 active GPU):**
- ~2–4 min per episode (Doctor.generate + 4–8 Γ— Groq API calls per turn)
- ~1–2 min amortized per GRPO update (G=2 trajectories)
- **Per-group β‰ˆ 5–10 min** (2 episodes + 1 update)
| Phase | Episodes (this run) | GRPO updates | Wall-clock estimate |
|---|---|---|---|
| 1 β€” Tool Mastery | **20** | 10 | ~50 min – 1.7 h |
| 2 β€” Clinical Reasoning | **30** | 15 | ~1.3 – 2.5 h |
| 3 β€” Empathetic Negotiation | **50** | 25 | ~2.0 – 4.0 h |
| **Total** | **100** | **50** | **~4.0 – 8.0 h** |
Checkpoints are pushed to HF Hub every `PUSH_EVERY_EPS=20` episodes, so if
the Kaggle session expires mid-run you can resume in a fresh session via
`HF_RESUME_REPO` in cell 8.
> **Want even faster?** Drop `PHASE_EPISODE_BUDGETS` to `{1: 10, 2: 15, 3: 25}`
> in cell 9 (50 episodes total, ~2.0 – 4.0 h). The curve will be choppier but
> still shows phase transitions cleanly.
>
> **Want adaptive (early-stop)?** Set `PHASE_EPISODE_BUDGETS = None` in cell 9
> and `EARLY_STOP_ENABLED = True`; each phase will end the moment its reward
> target is sustained for `CONVERGENCE_WINDOW=3` consecutive groups.
"""
CELL_13_TRAIN = """\
# === CELL 13 β€” REAL TRAINING (4-6 h cell, fixed-budget curriculum) ===
metrics = train(
num_episodes=NUM_EPISODES,
group_size=GROUP_SIZE,
model_name=MODEL_NAME,
groq_api_key=((os.environ.get("GROQ_NURSE_API_KEY") or os.environ.get("nurse")) or os.environ.get("nurse", ""))
or ((os.environ.get("GROQ_API_KEY") or os.environ.get("groq")) or os.environ.get("groq", "")),
learning_rate=LEARNING_RATE,
kl_beta=KL_BETA,
use_wandb=USE_WANDB,
output_dir=OUTPUT_DIR,
dry_run=False,
phase_reward_targets=PHASE_REWARD_TARGETS,
phase_min_win_rate=PHASE_MIN_WIN_RATE,
convergence_window=CONVERGENCE_WINDOW,
early_stop=EARLY_STOP_ENABLED,
phase_episode_budgets=PHASE_EPISODE_BUDGETS, # None -> early-stop mode
)
print(f"\\nTraining returned {len(metrics)} metric records.")
"""
CELL_14_FINAL_PUSH = """\
# === CELL 14 β€” Final push: adapters + merged fp16 ===
FINAL_LORA_DIR = f"{OUTPUT_DIR}/final_lora"
FINAL_MERGED_DIR = f"{OUTPUT_DIR}/final_merged_fp16"
if HF_PUSH_REPO and "<your-username>" not in HF_PUSH_REPO:
push_checkpoint_to_hub(FINAL_LORA_DIR, HF_PUSH_REPO,
commit_message="final LoRA adapter")
if os.path.isdir(FINAL_MERGED_DIR):
push_checkpoint_to_hub(FINAL_MERGED_DIR, f"{HF_PUSH_REPO}-merged",
commit_message="final merged fp16")
print(f"Final checkpoints pushed: https://huggingface.co/{HF_PUSH_REPO}")
else:
print("HF_PUSH_REPO not configured β€” skipping final push.")
"""
CELL_15_PLOTS_MD = """\
## 15 Β· Per-phase training graphs (one dashboard per curriculum phase)
We render a 6-panel dashboard for **every phase that contains episodes**,
plus a cross-phase overview and a phase-comparison bar chart. All PNGs are
written to `er_map_grpo_checkpoints/plots/` and uploaded to HF Hub in the
next cell so they survive Kaggle session expiry.
Each per-phase dashboard contains:
1. **Reward growth** β€” raw scatter + rolling mean (w=10) + verified rolling mean
2. **Rolling win rate** β€” w=20 win-rate evolution within the phase
3. **Outcome distribution over time** β€” stacked bars (WIN/PARTIAL/INCORRECT/AMA_LOSS/FATAL_LOSS)
4. **Reward components** β€” mean of each component (process / treatment / empathy / labs / etc.)
5. **GRPO update stats** β€” loss + KL divergence per group update
6. **Episode length distribution** β€” histogram of step counts
"""
CELL_15_PLOTS = """\
# === CELL 15 β€” Per-phase training dashboards ===
from ER_MAP.plotting import plot_per_phase_dashboards
from IPython.display import Image, display, Markdown
PLOTS_DIR = f"{OUTPUT_DIR}/plots"
written = plot_per_phase_dashboards(
metrics_path=f"{OUTPUT_DIR}/training_metrics.json",
output_dir=PLOTS_DIR,
)
print(f"Saved {len(written)} chart(s) to {PLOTS_DIR}:")
for name, path in written.items():
size_kb = os.path.getsize(path) / 1024
print(f" {name:<28s} -> {path} ({size_kb:.0f} KB)")
# Display each chart inline so the operator sees them without leaving Kaggle.
ordered = (sorted(k for k in written if k.startswith("phase"))
+ ["all_phases_overview", "all_phases_comparison"])
for key in ordered:
if key not in written:
continue
display(Markdown(f"### {key.replace('_', ' ').title()}"))
display(Image(filename=written[key]))
"""
CELL_16_PUSH_PLOTS = """\
# === CELL 16 β€” Push plots to HF Hub ===
if HF_PUSH_REPO and "<your-username>" not in HF_PUSH_REPO:
push_checkpoint_to_hub(PLOTS_DIR, HF_PUSH_REPO,
commit_message="per-phase training plots")
print(f"Plots pushed: https://huggingface.co/{HF_PUSH_REPO}/tree/main")
else:
print("HF_PUSH_REPO not configured β€” plots stay only in /kaggle/working/.")
"""
CELL_17_INFER_MD = """\
## 17 Β· (Optional) Inference smoke-test on the trained model
Catches the classic 'merge path looked OK but the saved model emits garbage'
failure mode before the demo.
"""
CELL_17_INFER = """\
# === CELL 17 β€” Inference smoke-test on the trained model ===
from ER_MAP.training.train_grpo import generate_doctor_action, load_model_and_tokenizer
from peft import PeftModel
base_model, tok = load_model_and_tokenizer(model_name=MODEL_NAME)
trained = PeftModel.from_pretrained(base_model, FINAL_LORA_DIR)
test_obs = (
'{"event":"episode_start","nurse_experience":"veteran",'
'"message":"Patient with chest pain, HR 120, BP 90/60, vague history.",'
'"soap_summary":{}}'
)
for i in range(3):
print(f"\\n--- Sample {i+1} ---")
print(generate_doctor_action(trained, tok, test_obs, max_new_tokens=160))
"""
# ---------------------------------------------------------------------------
# Quickstart markdown (sibling file)
# ---------------------------------------------------------------------------
QUICKSTART_MD = """\
# Kaggle Quickstart β€” ER-MAP GRPO Training (v3 stable)
The Kaggle notebook is in `kaggle/train_ermap_grpo_kaggle.ipynb`. This file
is the cheat sheet for running it end-to-end without the dependency hell
that bit us in earlier attempts.
## 0. Prerequisites (one-time)
1. **GitHub fork** of this repo. The notebook clones from a public fork at
cell 6 β€” edit `GIT_URL`. Alternatively, upload the repo as a Kaggle
Dataset named `ermap-source` (Add Data β†’ Upload).
2. **Hugging Face write token** (`HF_TOKEN`) for pushing the trained
adapter. Create at https://huggingface.co/settings/tokens (fine-grained,
write access on a single model repo is enough).
3. **Five Groq keys** (one each for Nurse / Patient / Empathy Judge /
Medical Judge / shared fallback). Free-tier accounts are fine; the
per-account limits multiply across keys.
## 1. Create the Kaggle notebook
1. Sign in to https://www.kaggle.com/code β†’ **New Notebook**.
2. Right sidebar:
- Accelerator: **GPU T4 Γ—2** (or P100)
- Internet: **On**
- Persistence: Files only
3. **File β†’ Upload Notebook** β†’ choose `kaggle/train_ermap_grpo_kaggle.ipynb`
from this repo.
## 2. Add Kaggle Secrets
Add-ons β†’ Secrets β†’ Add a new secret. Required labels (exactly):
| Label | Value |
|---|---|
| `GROQ_NURSE_API_KEY` | your nurse Groq key |
| `GROQ_PATIENT_API_KEY` | your patient Groq key |
| `GROQ_EMPATHY_JUDGE_API_KEY` | your empathy-judge Groq key |
| `GROQ_MEDICAL_JUDGE_API_KEY` | your medical-judge Groq key |
| `HF_TOKEN` | your HF write token |
| `WANDB_API_KEY` *(optional)* | your W&B key (skip β€” disabled by default) |
The notebook reads them via `kaggle_helpers.load_kaggle_secrets()` and
exports them as env vars.
## 3. Edit two placeholders in the notebook
- **Cell 6:** `GIT_URL = "https://github.com/<your-fork>/Meta_Finals.git"`
- **Cell 8:** `HF_PUSH_REPO = "<your-username>/ermap-doctor-lora"`
If you uploaded the repo as a Kaggle Dataset instead, leave `GIT_URL` as the
placeholder β€” cell 6 will detect `/kaggle/input/ermap-source` and copy from
there.
## 4. Run order (the only sequence that works)
| Cell | What it does | Expected output |
|---|---|---|
| 2 | GPU + disk + python + internet sanity check | GPU listed, disk free > 8 GB |
| 3 | **REPAIR** β€” pin torch 2.10 cu128, reinstall bitsandbytes, upgrade unsloth | `REPAIR OK` (or `RESTART REQUIRED`) |
| **(restart)** | If cell 3 said RESTART REQUIRED β†’ Run β†’ Restart kernel | β€” |
| 5 | Post-restart import verify | All `OK`, GPUs listed |
| 6 | Clone / mount the repo | `OK. Repo at /kaggle/working/Meta_Finals` |
| 7 | Wire Kaggle Secrets β†’ env vars | `OK β€” at least one Groq key is wired` |
| 8 | HF Hub config | `Starting fresh β€” no resume.` |
| 9 | Hyperparameters (P1=+1.2, P2=+1.1, P3=+1.0) | thresholds printed |
| 10 | **Pre-flight** β€” Groq routing + 4Γ— PING | 4Γ— `[PASS]`, then `OK` |
| 11 | Dry-run smoke test (no GPU) | `Dry-run OK` |
| 12 | Wire HF push hook | `Hub-push hook installed.` |
| 13 | **REAL TRAINING** (4–6 h) | per-group rolling stats, eventual `EARLY STOP` |
| 14 | Final push to HF | `Final checkpoints pushed: https://huggingface.co/...` |
| 15 | Per-phase plots | 5 PNGs displayed inline |
| 16 | Push plots to HF | `Plots pushed: ...` |
| 17 | Inference smoke-test (optional) | 3 sample Doctor actions printed |
## 5. Common failures & fixes
| Symptom | Root cause | Fix |
|---|---|---|
| `numpy was upgraded mid-session` | numpy import poisoned by a previous cell | Restart kernel, re-run from cell 3 |
| `Pillow incompatible with torchvision` | Pillow ABI mismatch | Restart kernel, re-run from cell 3 |
| `PyTorch and torchvision compiled with different CUDA major` | torch upgraded to cu13 by a transient resolve | Re-run cell 3 (it pins cu128) and restart |
| `cannot import name 'create_gradient_checkpointing_buffer'` | unsloth ↔ unsloth_zoo version drift | Re-run cell 3 (upgrades both in lockstep) |
| `libnvJitLink.so.13 missing` | bitsandbytes built against different CUDA | Re-run cell 3 (force-reinstalls bitsandbytes after torch pin) |
| Disk usage > quota | Kaggle's 20 GB working partition fills up | First line of cell 3 cleans `/tmp` and pip cache |
| Pre-flight `[FAIL]` for a role | Groq key dead / quota exceeded | Generate a new key in console.groq.com β†’ update Kaggle Secret β†’ re-run cell 7+10 |
| `[FAIL]` says `routing=WRONG` | env var not set when `AgentRouter()` was constructed | Re-run cell 9 BEFORE cell 10 |
| Training freezes at episode 1 for >10 min | Doctor.generate hung; Unsloth import broke silently | Check cell 5 output for `unsloth` line; restart kernel and re-run cell 3 if missing |
## 6. What the trained model gives you
After cell 13 finishes (or hits the 12 h Kaggle session cap), you have:
- `OUTPUT_DIR/final_lora/` β€” LoRA adapter weights (~50 MB), pushed to
`HF_PUSH_REPO`
- `OUTPUT_DIR/final_merged_fp16/` β€” full Llama-3.1-8B fp16 merge with the
adapter applied (~16 GB), pushed to `HF_PUSH_REPO-merged`
- `OUTPUT_DIR/training_metrics.json` β€” per-episode rewards, outcomes,
rolling stats β€” input for the per-phase plots
- `OUTPUT_DIR/plots/*.png` β€” 5 dashboards (one per phase + cross-phase
overview + comparison bar)
Use the LoRA adapter for the demo (quick to load, runs on a 4050 6 GB at
~30 tok/s); use the merged fp16 if you need to host on a Vercel/HF Space
without `peft`.
"""
# ---------------------------------------------------------------------------
# Build the notebook
# ---------------------------------------------------------------------------
def build_notebook() -> dict:
cells = [
md_cell(CELL_01_TITLE), # 0
code_cell(CELL_02_SANITY), # 1
code_cell(CELL_03_REPAIR), # 2
md_cell(CELL_04_RESTART), # 3
code_cell(CELL_05_VERIFY), # 4
code_cell(CELL_06_REPO), # 5
code_cell(CELL_07_SECRETS), # 6
code_cell(CELL_08_HF), # 7
code_cell(CELL_09_HYPERPARAMS), # 8
code_cell(CELL_10_PREFLIGHT), # 9
code_cell(CELL_11_DRYRUN), # 10
code_cell(CELL_12_HOOK), # 11
md_cell(CELL_13_TRAIN_MD), # 12
code_cell(CELL_13_TRAIN), # 13
code_cell(CELL_14_FINAL_PUSH), # 14
md_cell(CELL_15_PLOTS_MD), # 15
code_cell(CELL_15_PLOTS), # 16
code_cell(CELL_16_PUSH_PLOTS), # 17
md_cell(CELL_17_INFER_MD), # 18
code_cell(CELL_17_INFER), # 19
]
return {
"cells": cells,
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3",
},
"language_info": {
"name": "python",
"version": "3.10",
},
},
"nbformat": 4,
"nbformat_minor": 5,
}
def main() -> None:
here = Path(__file__).parent
nb_path = here / "train_ermap_grpo_kaggle.ipynb"
qs_path = here / "KAGGLE_QUICKSTART.md"
nb = build_notebook()
nb_path.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
qs_path.write_text(QUICKSTART_MD, encoding="utf-8")
n_md = sum(1 for c in nb["cells"] if c["cell_type"] == "markdown")
n_code = sum(1 for c in nb["cells"] if c["cell_type"] == "code")
print(f"Wrote {nb_path} ({len(nb['cells'])} cells: {n_md} md / {n_code} code)")
print(f"Wrote {qs_path} ({len(QUICKSTART_MD.splitlines())} lines)")
if __name__ == "__main__":
main()