Spaces:
Running
Running
| """ | |
| kaggle/build_notebook.py | |
| ======================== | |
| Programmatically (re)builds `train_ermap_grpo_kaggle.ipynb` from scratch. | |
| Why a builder script? | |
| -------------------- | |
| The hand-edited notebook drifted into a fragile state across many sessions: | |
| mixed early-stop / fixed-budget params, stale install snippets, dead pre-flight | |
| checks, etc. This script is the single source of truth — run it once and the | |
| notebook is regenerated as a clean, deterministic v3 layout. | |
| Run: | |
| python kaggle/build_notebook.py | |
| Output: | |
| kaggle/train_ermap_grpo_kaggle.ipynb (overwritten) | |
| kaggle/KAGGLE_QUICKSTART.md (overwritten) | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import textwrap | |
| from pathlib import Path | |
| # --------------------------------------------------------------------------- | |
| # Cell helpers | |
| # --------------------------------------------------------------------------- | |
def md_cell(text: str) -> dict:
    """Return a notebook markdown-cell dict whose source is ``text``.

    The text is passed through ``_split_keep_newlines`` so the cell's
    ``source`` is a list of lines. Keys are inserted in a fixed order
    (cell_type, metadata, source) so serialized output stays stable.
    """
    cell = {"cell_type": "markdown", "metadata": {}}
    cell["source"] = _split_keep_newlines(text)
    return cell
def code_cell(text: str) -> dict:
    """Return a notebook code-cell dict whose source is ``text``.

    The cell is emitted un-executed: ``execution_count`` is ``None`` and
    ``outputs`` is empty. Key order is fixed so serialized output stays stable.
    """
    source_lines = _split_keep_newlines(text)
    return dict(
        cell_type="code",
        execution_count=None,
        metadata={},
        outputs=[],
        source=source_lines,
    )
| def _split_keep_newlines(text: str) -> list[str]: | |
| """Notebook 'source' fields expect each line to terminate with '\n' | |
| except the last one. Splitting like this keeps `git diff` clean when | |
| the notebook is regenerated.""" | |
| text = textwrap.dedent(text).lstrip("\n") | |
| if not text.endswith("\n"): | |
| text = text + "\n" | |
| lines = text.splitlines(keepends=True) | |
| if lines: | |
| # The last line should NOT have a trailing newline (Jupyter convention). | |
| if lines[-1].endswith("\n"): | |
| lines[-1] = lines[-1].rstrip("\n") | |
| return lines | |
| # --------------------------------------------------------------------------- | |
| # Cell sources | |
| # --------------------------------------------------------------------------- | |
| CELL_01_TITLE = """\ | |
| # ER-MAP β Doctor Agent GRPO Training (Kaggle Free-Tier Β· v3 stable) | |
| Trains the **Doctor LLM** (Llama-3.1-8B-Instruct, 4-bit + LoRA r=16) via GRPO | |
| with a 3-phase curriculum on Kaggle's free GPU. Designed to survive Kaggle's | |
| pre-baked image quirks (numpy / Pillow ABI mismatches, torch + torchvision | |
| CUDA-major mismatches, transient `unsloth_zoo` upgrades). | |
| ## TL;DR β How to run this notebook | |
| 1. **Notebook settings (right sidebar):** | |
| - Accelerator: **GPU T4 Γ2** (or P100) | |
| - Internet: **On** | |
| - Persistence: Files only | |
| 2. **Kaggle Secrets** (Add-ons β Secrets): | |
| - **Required:** `GROQ_NURSE_API_KEY`, `GROQ_PATIENT_API_KEY`, | |
| `GROQ_EMPATHY_JUDGE_API_KEY`, `GROQ_MEDICAL_JUDGE_API_KEY`, `HF_TOKEN` | |
| - **Optional:** `WANDB_API_KEY` | |
| 3. **Run cells 2 → 3 (sanity + REPAIR).** When cell 3 prints | |
| `RESTART REQUIRED`, click **Run → Restart kernel**, then resume from cell 5. | |
| 4. **Run cells 5 → 11 (verify + configure + dry-run + pre-flight).** Each cell | |
| should print an `OK` line before moving on. | |
| 5. **Run cell 13 (the long training cell, 4–6 hours).** | |
| 6. **Run cells 14 → 17 (final push + plots + inference smoke-test).** | |
| ## Curriculum + reward thresholds (this run) | |
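| Constant per-phase rolling-avg-reward bars. In early-stop mode, a bar sustained | |
| for **3 consecutive GRPO groups** triggers either a phase promotion or | |
| end-of-training; in the default fixed-budget mode (cell 9) the bars are | |
| observational only. | |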
| | Phase | Reward target (sustained Γ3 groups) | Action when met | | |
| |---|---|---| | |
| | 1 β Tool Mastery | `+1.2` | force-promote to Phase 2 | | |
| | 2 β Clinical Reasoning | `+1.1` | force-promote to Phase 3 | | |
| | 3 β Empathetic Negotiation | `+1.0` | END TRAINING | | |
| Why these numbers? The un-trained 8B Doctor's baseline on the same env is | |
| `P1=+0.76, P2=+0.59, P3=+0.39`. Targets of `+1.2 / +1.1 / +1.0` correspond | |
| to roughly `1.6Γ / 1.9Γ / 2.6Γ` improvement over baseline β a meaningful | |
| signal but reachable inside Kaggle's 12 h session limit. | |
| """ | |
| CELL_02_SANITY = """\ | |
| # === CELL 2 β Sanity check (GPU + disk + python + internet) === | |
| # Run this FIRST. If any check fails, fix it before running the REPAIR cell. | |
| import os, shutil, subprocess, sys, socket | |
| print("--- GPU ---") | |
| try: | |
| print(subprocess.check_output( | |
| ["nvidia-smi", "--query-gpu=name,memory.total,memory.free", "--format=csv"], | |
| timeout=10, | |
| ).decode()) | |
| except Exception as e: | |
| print(f"nvidia-smi failed: {e}") | |
| print("-> Set Accelerator to 'GPU T4 x2' in the right sidebar.") | |
| print("--- Disk (/kaggle/working) ---") | |
| total, used, free = shutil.disk_usage("/kaggle/working") | |
| print(f" total={total/1e9:5.1f} GB | used={used/1e9:5.1f} GB | free={free/1e9:5.1f} GB") | |
| if free < 8 * 1e9: | |
| print(" WARNING: free disk < 8 GB β repair cell may fail. " | |
| "Consider 'Run > Restart and clear cell outputs' to reset /tmp.") | |
| print("--- Python ---") | |
| print(f" python={sys.version.split()[0]} | exe={sys.executable}") | |
| print("--- Internet (api.groq.com:443) ---") | |
| try: | |
| socket.create_connection(("api.groq.com", 443), timeout=5).close() | |
| print(" reachable") | |
| except Exception as e: | |
| print(f" UNREACHABLE: {e}") | |
| print(" -> Settings (right sidebar) -> Internet -> ON") | |
| """ | |
# Source of notebook cell 3: reinstalls the pinned GPU stack with pip and then
# verifies it in a subprocess, so the kernel itself never imports the packages.
# The cell imports everything it uses, so it also runs in a kernel where
# cell 2 was never executed.
CELL_03_REPAIR = """\
# === CELL 3 — REPAIR CELL (idempotent full environment rebuild) ===
# Single source of truth for ER-MAP's GPU stack. Safe to re-run. After it
# finishes you'll see one of two final lines:
#
#   RESTART REQUIRED -> Run -> Restart kernel, then resume from cell 5
#   REPAIR OK -> proceed directly to cell 5
#
# Note: this cell only runs shell commands and one isolated subprocess.
# It deliberately does NOT `import torch / numpy / Pillow / unsloth` in the
# kernel, so re-running it after a botched install does not poison further
# attempts.
print("=" * 72); print(" CELL 3 — REPAIR"); print("=" * 72)

# 1. Clean caches (Kaggle's /kaggle/working is only 20 GB — installs
#    routinely fill it after a few re-runs).
print("[1/6] Cleaning pip + tmp + HF dataset caches...")
get_ipython().system('pip cache purge -q || true')
get_ipython().system('rm -rf /tmp/* /root/.cache/pip /root/.cache/huggingface/datasets 2>/dev/null || true')

# 2. Pin torch + torchvision to the cu128 wheel (matches Kaggle's CUDA 12.8
#    base image). DON'T let pip pull a generic CUDA-13 build — that breaks
#    bitsandbytes (libnvJitLink.so.13 missing) and torchvision (CUDA-major
#    mismatch RuntimeError at import time).
print("[2/6] Installing torch==2.10.0 + torchvision==0.25.0 (cu128)...")
get_ipython().system('pip install -q --no-cache-dir --force-reinstall '
                     'torch==2.10.0 torchvision==0.25.0 '
                     '--index-url https://download.pytorch.org/whl/cu128')

# Write a pip constraints file so subsequent installs (bnb, unsloth, trl, etc.)
# can NEVER pull a different torch from default PyPI. Without this, step 3's
# `--force-reinstall bitsandbytes` and step 4's `unsloth` upgrade re-resolve
# torch from PyPI (currently 2.11.0), which breaks the cu128 torchvision pair.
#
# Also pin numpy to whatever Kaggle's kernel already has loaded — Kaggle's
# image puts numpy at /usr/lib/python3/dist-packages while pip writes to
# /usr/local/lib/python3.12/dist-packages, so any version drift between the
# two paths trips unsloth_zoo's strict loaded-vs-installed check at import.
import subprocess as _sp
# `sys` is imported here (not only in step 6) so the cell also works in a
# kernel where cell 2 was never run.
import sys
_kernel_numpy = _sp.check_output(
    [sys.executable, "-c", "import numpy; print(numpy.__version__)"],
    text=True,
).strip()
print(f" detected kernel numpy = {_kernel_numpy} (will pin)")
with open("/tmp/ermap_constraints.txt", "w") as _cf:
    _cf.write(f"torch==2.10.0\\ntorchvision==0.25.0\\nnumpy=={_kernel_numpy}\\n")

# 3. Reinstall bitsandbytes against the now-pinned torch.
#    --no-deps because bnb just needs torch at RUNTIME (it dlopens torch's
#    C++ lib) — its install-time deps don't include torch.
print("[3/6] Reinstalling bitsandbytes (--no-deps to preserve torch)...")
get_ipython().system('pip install -q --no-cache-dir --force-reinstall --no-deps bitsandbytes')

# 4. Upgrade unsloth + unsloth_zoo + trl in lockstep. unsloth and
#    unsloth_zoo are released as a matched pair; if pip pulls a fresh
#    unsloth_zoo against an old unsloth you get
#    ImportError: cannot import name 'create_gradient_checkpointing_buffer'
#    The constraint file blocks them from moving torch.
print("[4/6] Upgrading unsloth + unsloth_zoo + trl (constrained)...")
get_ipython().system('pip install -q --upgrade --no-cache-dir '
                     '-c /tmp/ermap_constraints.txt '
                     'unsloth unsloth_zoo "trl>=0.18.2"')

# 5. ER-MAP runtime deps that aren't pre-installed on Kaggle.
print("[5/6] Installing ER-MAP runtime deps (constrained)...")
get_ipython().system('pip install -q --no-cache-dir '
                     '-c /tmp/ermap_constraints.txt '
                     '"groq>=0.18.0" "huggingface_hub>=0.25.0" '
                     '"gymnasium>=0.29.0" "openenv-core>=0.1.0"')

# 5b. Realign the pip-managed numpy with whatever the Kaggle kernel actually
#     has loaded. This force-rewrites /usr/local/lib/.../numpy at the exact
#     version reported by the running interpreter, so importlib.metadata
#     and `numpy.__version__` agree even if Kaggle ships its base numpy at
#     a different dist-packages path.
print(f"[5b/6] Realigning pip-managed numpy to {_kernel_numpy}...")
get_ipython().system(
    f'pip install -q --force-reinstall --no-deps "numpy=={_kernel_numpy}"'
)

# 6. Verify in a SUBPROCESS (so the parent kernel never imports any of these
#    while pip is mid-flight, which is what causes the
#    'numpy was upgraded mid-session (loaded: X, installed: Y)' RuntimeError
#    we kept hitting before).
print("[6/6] Verifying via subprocess...")
import subprocess, sys, json
verify_script = r'''
import json, sys
out = {"ok": True, "details": {}, "errors": []}
try:
    import importlib.metadata as md
    for pkg in ("torch", "torchvision", "bitsandbytes", "unsloth", "unsloth_zoo",
                "trl", "transformers", "peft", "accelerate", "groq",
                "huggingface_hub", "gymnasium", "numpy", "Pillow"):
        try:
            out["details"][pkg + "_installed"] = md.version(pkg)
        except md.PackageNotFoundError:
            out["details"][pkg + "_installed"] = None
    import torch, torchvision, numpy as np, PIL, unsloth, unsloth_zoo, bitsandbytes, trl
    out["details"]["torch_loaded"] = torch.__version__
    out["details"]["torch_cuda"] = torch.version.cuda
    out["details"]["cuda_available"] = bool(torch.cuda.is_available())
    out["details"]["gpu_count"] = int(torch.cuda.device_count())
    out["details"]["torchvision_loaded"] = torchvision.__version__
    out["details"]["numpy_loaded"] = np.__version__
    out["details"]["pillow_loaded"] = PIL.__version__
    out["details"]["unsloth_loaded"] = unsloth.__version__
    out["details"]["unsloth_zoo_loaded"] = unsloth_zoo.__version__
    out["details"]["bitsandbytes_loaded"] = bitsandbytes.__version__
    out["details"]["trl_loaded"] = trl.__version__
    # Cross-check loaded-vs-installed for the C-extension libs that bit us
    # on every previous run.
    for pkg, loaded_key, installed_key in [
        ("numpy", "numpy_loaded", "numpy_installed"),
        ("Pillow", "pillow_loaded", "Pillow_installed"),
        ("torch", "torch_loaded", "torch_installed"),
    ]:
        loaded = out["details"].get(loaded_key)
        installed = out["details"].get(installed_key)
        if loaded and installed and loaded != installed:
            # Strip any local-version suffix (e.g. '+cu128') before compare.
            if loaded.split("+")[0] != installed.split("+")[0]:
                out["errors"].append(
                    f"{pkg} mismatch: loaded={loaded} installed={installed}"
                )
except Exception as e:
    out["ok"] = False
    out["errors"].append(f"{type(e).__name__}: {e}")
print(json.dumps(out, default=str))
'''.lstrip()
res = subprocess.run([sys.executable, "-c", verify_script],
                     capture_output=True, text=True, timeout=180)
print(res.stdout if res.stdout else "<no stdout>")
if res.stderr:
    print("---- subprocess stderr ----"); print(res.stderr)

# Parse the LAST line of stdout (others are prints from package init).
try:
    last = res.stdout.strip().splitlines()[-1]
    parsed = json.loads(last)
except Exception:
    parsed = {"ok": False, "errors": ["could not parse verification output"]}
ok = parsed.get("ok") and not parsed.get("errors")
d = parsed.get("details", {})
print("\\n" + "=" * 72)
if ok:
    print(" REPAIR OK")
    print(f" torch : {d.get('torch_loaded')} (CUDA {d.get('torch_cuda')})")
    print(f" torchvision : {d.get('torchvision_loaded')}")
    print(f" bitsandbytes: {d.get('bitsandbytes_loaded')}")
    print(f" unsloth : {d.get('unsloth_loaded')} | unsloth_zoo: {d.get('unsloth_zoo_loaded')}")
    print(f" trl : {d.get('trl_loaded')}")
    print(f" numpy : {d.get('numpy_loaded')} | Pillow: {d.get('pillow_loaded')}")
    print(f" GPUs : {d.get('gpu_count')} (cuda_available={d.get('cuda_available')})")
    print()
    print(" -> If this kernel previously imported torch/numpy/Pillow/unsloth,")
    print(" RESTART NOW (Run -> Restart kernel) before continuing to cell 5.")
    print(" If this is a fresh kernel, you can proceed directly.")
else:
    print(" RESTART REQUIRED — issues detected:")
    for e in parsed.get("errors", []):
        print(f" - {e}")
    print()
    print(" Action: Run -> Restart kernel, then re-run from cell 2.")
print("=" * 72)
"""
| CELL_04_RESTART = """\ | |
| ## β Restart kernel here if cell 3 said `RESTART REQUIRED` | |
| Click **Run β Restart kernel** (or **Run β Restart & clear cell outputs**), | |
| then resume from **cell 5**. Skipping the restart will produce ABI mismatch | |
| errors at the first GPU op. | |
| If cell 3 said `REPAIR OK` AND this is a fresh kernel that hasn't imported | |
| torch/numpy/Pillow/unsloth yet, you can proceed to cell 5 directly. | |
| """ | |
| CELL_05_VERIFY = """\ | |
| # === CELL 5 β Post-restart verify (this kernel can import everything) === | |
| import importlib.metadata as md | |
| print("--- Loaded versions in this kernel ---") | |
| import torch, numpy, PIL, torchvision, unsloth, unsloth_zoo, bitsandbytes, trl, transformers, peft | |
| versions = { | |
| "torch": torch.__version__, | |
| "torchvision": torchvision.__version__, | |
| "numpy": numpy.__version__, | |
| "Pillow": PIL.__version__, | |
| "unsloth": unsloth.__version__, | |
| "unsloth_zoo": unsloth_zoo.__version__, | |
| "bitsandbytes": bitsandbytes.__version__, | |
| "trl": trl.__version__, | |
| "transformers": transformers.__version__, | |
| "peft": peft.__version__, | |
| } | |
| all_ok = True | |
| for k, v in versions.items(): | |
| try: | |
| inst = md.version(k) | |
| except md.PackageNotFoundError: | |
| inst = "(not installed)" | |
| # Tolerate local version suffixes like '+cu128' | |
| flag = "OK" if inst.split("+")[0] == v.split("+")[0] else f"MISMATCH (installed={inst})" | |
| if "MISMATCH" in flag: | |
| all_ok = False | |
| print(f" {k:14s}: loaded={v:20s} [{flag}]") | |
| print() | |
| print(f" CUDA available : {torch.cuda.is_available()}") | |
| print(f" GPU count : {torch.cuda.device_count()}") | |
| if torch.cuda.is_available(): | |
| for i in range(torch.cuda.device_count()): | |
| p = torch.cuda.get_device_properties(i) | |
| print(f" GPU {i} : {p.name} ({p.total_memory/1e9:.1f} GB)") | |
| print() | |
| print("OK" if all_ok else "NOT OK β re-run cell 3 and restart kernel.") | |
| """ | |
| CELL_06_REPO = """\ | |
| # === CELL 6 β Mount the ER-MAP repo into /kaggle/working === | |
| import os, subprocess, sys | |
| # OPTION A: clone a public GitHub fork (preferred). Edit GIT_URL. | |
| GIT_URL = "https://github.com/<your-fork>/Meta_Finals.git" | |
| BRANCH = "main" | |
| REPO_ROOT = "/kaggle/working/Meta_Finals" | |
| # OPTION B: Kaggle Dataset upload β set this if you uploaded the repo | |
| # as a Kaggle Dataset named "ermap-source" (Add Data -> Upload). | |
| DATASET_DIR = "/kaggle/input/ermap-source" | |
| if not os.path.isdir(f"{REPO_ROOT}/ER_MAP"): | |
| if "<your-fork>" not in GIT_URL: | |
| print(f"Cloning {GIT_URL}@{BRANCH} -> {REPO_ROOT}...") | |
| out = subprocess.run( | |
| ["git", "clone", "--depth", "1", "-b", BRANCH, GIT_URL, REPO_ROOT], | |
| capture_output=True, text=True, | |
| ) | |
| print(out.stdout); print(out.stderr) | |
| elif os.path.isdir(DATASET_DIR): | |
| print(f"Copying {DATASET_DIR} -> {REPO_ROOT}...") | |
| import shutil | |
| shutil.copytree(DATASET_DIR, REPO_ROOT, dirs_exist_ok=True) | |
| assert os.path.isdir(f"{REPO_ROOT}/ER_MAP"), ( | |
| "Repo not found.\\n" | |
| " - Edit GIT_URL above to your GitHub fork, OR\\n" | |
| " - Upload the repo as a Kaggle Dataset named 'ermap-source' (Add Data -> Upload)." | |
| ) | |
| sys.path.insert(0, REPO_ROOT) | |
| sys.path.insert(0, f"{REPO_ROOT}/kaggle") | |
| print(f"OK. Repo at {REPO_ROOT}") | |
| """ | |
| CELL_07_SECRETS = """\ | |
| # === CELL 7 β Wire Kaggle Secrets into env vars === | |
| import os | |
| from kaggle_helpers import load_kaggle_secrets, kaggle_env_summary | |
| load_kaggle_secrets() | |
| kaggle_env_summary() | |
| # Hard fail if no Groq key β training would silently use mock LLMs. | |
| assert any(os.environ.get(k) for k in ( | |
| "GROQ_NURSE_API_KEY", "GROQ_PATIENT_API_KEY", | |
| "GROQ_EMPATHY_JUDGE_API_KEY", "GROQ_MEDICAL_JUDGE_API_KEY", | |
| "GROQ_API_KEY", | |
| )), ("No Groq key found in Kaggle Secrets. " | |
| "Add at least GROQ_NURSE_API_KEY in Add-ons -> Secrets.") | |
| print("OK β at least one Groq key is wired.") | |
| """ | |
| CELL_08_HF = """\ | |
| # === CELL 8 β Hugging Face Hub config (for checkpoint backup) === | |
| import os | |
| from kaggle_helpers import push_checkpoint_to_hub, download_checkpoint_from_hub | |
| # EDIT the line below to your HF model id (e.g. "udayd/ermap-doctor-lora"). | |
| HF_PUSH_REPO = "<your-username>/ermap-doctor-lora" | |
| # To resume from a previous run, paste the same repo id here. Empty = fresh. | |
| HF_RESUME_REPO = "" | |
| RESUME_DIR = "/kaggle/working/checkpoints/resume" | |
| if HF_RESUME_REPO: | |
| download_checkpoint_from_hub(HF_RESUME_REPO, RESUME_DIR) | |
| contents = os.listdir(RESUME_DIR) if os.path.isdir(RESUME_DIR) else [] | |
| print(f"Resume dir: {contents or '(empty)'}") | |
| else: | |
| print("Starting fresh β no resume.") | |
| if "<your-username>" in HF_PUSH_REPO: | |
| print("\\nWARNING: HF_PUSH_REPO still has <your-username> placeholder.") | |
| print(" Checkpoints will NOT be pushed to HF Hub.") | |
| print(" Edit the cell above and re-run before training if you want backups.") | |
| """ | |
# Source of notebook cell 9: defines the GRPO hyperparameters, the curriculum
# budgets/targets and the env vars that route each Groq role to a model.
# NUM_EPISODES falls back to a fixed cap when PHASE_EPISODE_BUDGETS is None,
# so the early-stop mode the comments describe does not crash the cell.
CELL_09_HYPERPARAMS = """\
# === CELL 9 — GRPO hyperparameters ===
import os

MODEL_NAME = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
GROUP_SIZE = 2
LEARNING_RATE = 5e-6
# KL_BETA = 0.0 -> SKIP reference model load (saves ~5 GB VRAM on T4).
# T4 only has 15 GB and one 4-bit Llama-3.1-8B + LoRA + activations + gradients
# already eats ~10 GB; loading a second 4-bit copy as the KL reference OOMs
# the GRPO backward pass. Set this to 0.04 only if you've upgraded to an
# A100 / H100 with >= 24 GB VRAM.
KL_BETA = 0.0
OUTPUT_DIR = "/kaggle/working/er_map_grpo_checkpoints"
PUSH_EVERY_EPS = 20
USE_WANDB = False # WANDB conflicts with protobuf 7 on Kaggle base image

# --- Curriculum mode: FIXED-BUDGET (recommended for Kaggle T4) -------------
# A fixed per-phase episode budget gives you a clean, predictable reward-
# growth curve and bounds your wall-clock. With GROUP_SIZE=2 above and an
# observed ~3 min / episode (Groq-dominated), 75 episodes ≈ 4 hours.
#
# Set PHASE_EPISODE_BUDGETS to None to fall back to early-stopping mode,
# which terminates each phase the moment its reward target is hit (faster
# but less predictable). When PHASE_EPISODE_BUDGETS is set, EARLY_STOP_ENABLED
# is automatically forced to False inside train() — the reward targets below
# become observational only (logged on the plots, not used for promotion).
PHASE_EPISODE_BUDGETS = {1: 20, 2: 25, 3: 30} # 20 + 25 + 30 = 75 episodes
# Episode cap used only when PHASE_EPISODE_BUDGETS is None (early-stop mode).
FALLBACK_NUM_EPISODES = 75
NUM_EPISODES = (
    sum(PHASE_EPISODE_BUDGETS.values())
    if PHASE_EPISODE_BUDGETS
    else FALLBACK_NUM_EPISODES
)

# --- Per-phase reward thresholds (observational under fixed-budget) --------
# Plotted as horizontal target lines on the reward-growth chart so you can
# see at a glance whether each phase actually crossed its target.
EARLY_STOP_ENABLED = False # ignored when PHASE_EPISODE_BUDGETS is set
PHASE_REWARD_TARGETS = {1: 1.2, 2: 1.1, 3: 1.0}
PHASE_MIN_WIN_RATE = 0.20
CONVERGENCE_WINDOW = 3

# --- Per-episode budget controls (read by triage_env) ----------------------
os.environ["ERMAP_MAX_EPISODE_STEPS"] = "20"
os.environ["ERMAP_MAX_INTERNAL_EXCHANGES"] = "5"

# --- Groq traffic-shaping (8B for actors, 70B for judges) ------------------
# High-volume conversational roles (Nurse + Patient) on the 8B-instant pool
# (500K TPD, 14,400 RPD); the two judges stay on 70B-versatile because their
# grading quality directly shapes the reward signal.
os.environ["ERMAP_NURSE_MODEL"] = "llama-3.1-8b-instant"
os.environ["ERMAP_PATIENT_MODEL"] = "llama-3.1-8b-instant"
os.environ["ERMAP_EMPATHY_JUDGE_MODEL"] = "llama-3.3-70b-versatile"
os.environ["ERMAP_MEDICAL_JUDGE_MODEL"] = "llama-3.3-70b-versatile"

print("Hyperparameters set:")
print(f" NUM_EPISODES = {NUM_EPISODES}")
print(f" GROUP_SIZE = {GROUP_SIZE}")
print(f" PHASE_EPISODE_BUDGETS = {PHASE_EPISODE_BUDGETS}")
print(f" PHASE_REWARD_TARGETS = {PHASE_REWARD_TARGETS} (observational)")
print(f" EARLY_STOP_ENABLED = {EARLY_STOP_ENABLED} (forced off under fixed budget)")
print(f" KL_BETA = {KL_BETA} (0.0 -> skip ref model, T4-safe)")
print(f" Nurse / Patient = llama-3.1-8b-instant (actors, high-volume)")
print(f" Empathy / Med Judge = llama-3.3-70b-versatile (graders, quality)")
"""
| CELL_10_PREFLIGHT = """\ | |
| # === CELL 10 β Pre-flight: Groq routing + key liveness === | |
| # Verifies that: | |
| # - each role is routed to the model you set in cell 9, and | |
| # - each role's Groq key actually answers a 1-token "PING" prompt. | |
| import os | |
| from ER_MAP.envs.api_router import AgentRouter | |
| router = AgentRouter() | |
| expected = { | |
| "nurse": "llama-3.1-8b-instant", | |
| "patient": "llama-3.1-8b-instant", | |
| "empathy_judge": "llama-3.3-70b-versatile", | |
| "medical_judge": "llama-3.3-70b-versatile", | |
| } | |
| print("=" * 60); print(" PRE-FLIGHT β Groq routing + smoke test"); print("=" * 60) | |
| all_pass = True | |
| for role, exp in expected.items(): | |
| actual = router._models.get(role, "?") | |
| routing_ok = (actual == exp) | |
| client = router._clients.get(role) | |
| if client is None: | |
| print(f" [SKIP] {role:14s} -> no Groq client (key missing)") | |
| all_pass = False | |
| continue | |
| try: | |
| resp = client.chat.completions.create( | |
| model=exp, | |
| messages=[{"role": "user", "content": "Reply with exactly: PING"}], | |
| max_tokens=4, temperature=0, | |
| ) | |
| api_ok = "PING" in (resp.choices[0].message.content or "").upper() | |
| err = "" | |
| except Exception as e: | |
| api_ok = False | |
| err = f" ({type(e).__name__}: {str(e)[:80]})" | |
| flag = "PASS" if (routing_ok and api_ok) else "FAIL" | |
| if flag == "FAIL": | |
| all_pass = False | |
| print(f" [{flag}] {role:14s} -> {actual:30s} " | |
| f"routing={'ok' if routing_ok else 'WRONG'}, " | |
| f"api={'ok' if api_ok else 'fail'}{err}") | |
| print("=" * 60) | |
| print("OK" if all_pass else "NOT OK β fix routing/keys before training.") | |
| print("=" * 60) | |
| assert all_pass, "Pre-flight failed; do not proceed to training." | |
| """ | |
| CELL_11_DRYRUN = """\ | |
| # === CELL 11 β Dry-run smoke test (no GPU, no model load) === | |
| # Verifies the curriculum scheduler + reward verifier + per-phase early-stop | |
| # wiring before we burn GPU minutes on the real run. | |
| from ER_MAP.training.train_grpo import train | |
| _ = train( | |
| num_episodes=8, | |
| group_size=2, | |
| model_name=MODEL_NAME, | |
| learning_rate=LEARNING_RATE, | |
| kl_beta=KL_BETA, | |
| output_dir="/kaggle/working/_dryrun", | |
| dry_run=True, | |
| phase_reward_targets=PHASE_REWARD_TARGETS, | |
| phase_min_win_rate=PHASE_MIN_WIN_RATE, | |
| convergence_window=CONVERGENCE_WINDOW, | |
| early_stop=EARLY_STOP_ENABLED, | |
| ) | |
| print("\\nDry-run OK β scheduler + verifier + per-phase early-stop wiring is healthy.") | |
| """ | |
| CELL_12_HOOK = """\ | |
| # === CELL 12 β Wire periodic HF Hub push into training === | |
| # We monkey-patch save_lora_adapters so every checkpoint dump also pushes | |
| # the LoRA adapter to HF Hub. Failures are non-fatal β training keeps | |
| # running even if a push fails (e.g. transient HF 502). | |
| from ER_MAP.training import train_grpo as _tg | |
| _original_save = _tg.save_lora_adapters | |
| def save_lora_adapters_with_push(model, tokenizer, output_dir): | |
| _original_save(model, tokenizer, output_dir) | |
| if HF_PUSH_REPO and "<your-username>" not in HF_PUSH_REPO: | |
| try: | |
| push_checkpoint_to_hub( | |
| output_dir, HF_PUSH_REPO, | |
| commit_message=f"checkpoint @ {os.path.basename(output_dir)}", | |
| ) | |
| except Exception as e: | |
| print(f" [hub-push] non-fatal failure: {e}") | |
| _tg.save_lora_adapters = save_lora_adapters_with_push | |
| print("Hub-push hook installed.") | |
| """ | |
# Markdown shown above the training cell. The per-phase numbers below must
# mirror PHASE_EPISODE_BUDGETS in CELL_09_HYPERPARAMS ({1: 20, 2: 25, 3: 30},
# 75 episodes); wall-clock figures follow from the 5–10 min per-group estimate.
CELL_13_TRAIN_MD = """\
## 13 · Run real training (fixed-budget curriculum, ~5 hours)

**Mode:** fixed-budget (`PHASE_EPISODE_BUDGETS = {1: 20, 2: 25, 3: 30}` in cell 9).
The reward thresholds in `PHASE_REWARD_TARGETS` are **observational only** —
they're plotted as horizontal target lines on the reward curve, but they do
NOT cause early termination. Each phase runs its full episode budget so the
final reward-growth chart shows clean, monotone progression.

**Estimated wall-clock on Kaggle T4 (×1 active GPU):**
- ~2–4 min per episode (Doctor.generate + 4–8 × Groq API calls per turn)
- ~1–2 min amortized per GRPO update (G=2 trajectories)
- **Per-group ≈ 5–10 min** (2 episodes + 1 update)

| Phase | Episodes (this run) | GRPO updates | Wall-clock estimate |
|---|---|---|---|
| 1 — Tool Mastery | **20** | 10 | ~50 min – 1.7 h |
| 2 — Clinical Reasoning | **25** | ~12 | ~1.0 – 2.1 h |
| 3 — Empathetic Negotiation | **30** | 15 | ~1.3 – 2.5 h |
| **Total** | **75** | **~37** | **~3.1 – 6.3 h** |

Checkpoints are pushed to HF Hub every `PUSH_EVERY_EPS=20` episodes, so if
the Kaggle session expires mid-run you can resume in a fresh session via
`HF_RESUME_REPO` in cell 8.

> **Want even faster?** Drop `PHASE_EPISODE_BUDGETS` to `{1: 10, 2: 15, 3: 25}`
> in cell 9 (50 episodes total, ~2.0 – 4.0 h). The curve will be choppier but
> still shows phase transitions cleanly.
>
> **Want adaptive (early-stop)?** Set `PHASE_EPISODE_BUDGETS = None` in cell 9
> and `EARLY_STOP_ENABLED = True`; each phase will end the moment its reward
> target is sustained for `CONVERGENCE_WINDOW=3` consecutive groups.
"""
# Source of notebook cell 13: launches the real training run with the values
# defined in cell 9. The Groq key is resolved once, in precedence order:
# GROQ_NURSE_API_KEY, "nurse", GROQ_API_KEY, "groq", else the empty string.
CELL_13_TRAIN = """\
# === CELL 13 — REAL TRAINING (4-6 h cell, fixed-budget curriculum) ===
# First non-empty key wins: role-specific nurse key, then the shared key.
_groq_key = (
    os.environ.get("GROQ_NURSE_API_KEY")
    or os.environ.get("nurse")
    or os.environ.get("GROQ_API_KEY")
    or os.environ.get("groq", "")
)
metrics = train(
    num_episodes=NUM_EPISODES,
    group_size=GROUP_SIZE,
    model_name=MODEL_NAME,
    groq_api_key=_groq_key,
    learning_rate=LEARNING_RATE,
    kl_beta=KL_BETA,
    use_wandb=USE_WANDB,
    output_dir=OUTPUT_DIR,
    dry_run=False,
    phase_reward_targets=PHASE_REWARD_TARGETS,
    phase_min_win_rate=PHASE_MIN_WIN_RATE,
    convergence_window=CONVERGENCE_WINDOW,
    early_stop=EARLY_STOP_ENABLED,
    phase_episode_budgets=PHASE_EPISODE_BUDGETS, # None -> early-stop mode
)
print(f"\\nTraining returned {len(metrics)} metric records.")
"""
| CELL_14_FINAL_PUSH = """\ | |
| # === CELL 14 β Final push: adapters + merged fp16 === | |
| FINAL_LORA_DIR = f"{OUTPUT_DIR}/final_lora" | |
| FINAL_MERGED_DIR = f"{OUTPUT_DIR}/final_merged_fp16" | |
| if HF_PUSH_REPO and "<your-username>" not in HF_PUSH_REPO: | |
| push_checkpoint_to_hub(FINAL_LORA_DIR, HF_PUSH_REPO, | |
| commit_message="final LoRA adapter") | |
| if os.path.isdir(FINAL_MERGED_DIR): | |
| push_checkpoint_to_hub(FINAL_MERGED_DIR, f"{HF_PUSH_REPO}-merged", | |
| commit_message="final merged fp16") | |
| print(f"Final checkpoints pushed: https://huggingface.co/{HF_PUSH_REPO}") | |
| else: | |
| print("HF_PUSH_REPO not configured β skipping final push.") | |
| """ | |
| CELL_15_PLOTS_MD = """\ | |
| ## 15 Β· Per-phase training graphs (one dashboard per curriculum phase) | |
| We render a 6-panel dashboard for **every phase that contains episodes**, | |
| plus a cross-phase overview and a phase-comparison bar chart. All PNGs are | |
| written to `er_map_grpo_checkpoints/plots/` and uploaded to HF Hub in the | |
| next cell so they survive Kaggle session expiry. | |
| Each per-phase dashboard contains: | |
| 1. **Reward growth** β raw scatter + rolling mean (w=10) + verified rolling mean | |
| 2. **Rolling win rate** β w=20 win-rate evolution within the phase | |
| 3. **Outcome distribution over time** β stacked bars (WIN/PARTIAL/INCORRECT/AMA_LOSS/FATAL_LOSS) | |
| 4. **Reward components** β mean of each component (process / treatment / empathy / labs / etc.) | |
| 5. **GRPO update stats** β loss + KL divergence per group update | |
| 6. **Episode length distribution** β histogram of step counts | |
| """ | |
| CELL_15_PLOTS = """\ | |
| # === CELL 15 β Per-phase training dashboards === | |
| from ER_MAP.plotting import plot_per_phase_dashboards | |
| from IPython.display import Image, display, Markdown | |
| PLOTS_DIR = f"{OUTPUT_DIR}/plots" | |
| written = plot_per_phase_dashboards( | |
| metrics_path=f"{OUTPUT_DIR}/training_metrics.json", | |
| output_dir=PLOTS_DIR, | |
| ) | |
| print(f"Saved {len(written)} chart(s) to {PLOTS_DIR}:") | |
| for name, path in written.items(): | |
| size_kb = os.path.getsize(path) / 1024 | |
| print(f" {name:<28s} -> {path} ({size_kb:.0f} KB)") | |
| # Display each chart inline so the operator sees them without leaving Kaggle. | |
| ordered = (sorted(k for k in written if k.startswith("phase")) | |
| + ["all_phases_overview", "all_phases_comparison"]) | |
| for key in ordered: | |
| if key not in written: | |
| continue | |
| display(Markdown(f"### {key.replace('_', ' ').title()}")) | |
| display(Image(filename=written[key])) | |
| """ | |
# Code cell 16: uploads the plots directory to the Hugging Face repo when
# HF_PUSH_REPO has been filled in (i.e. no longer contains the
# "<your-username>" placeholder); otherwise it only prints a notice.
# NOTE(review): `HF_PUSH_REPO`, `PLOTS_DIR` and `push_checkpoint_to_hub` are
# expected to come from earlier notebook cells — confirm.
CELL_16_PUSH_PLOTS = """\
# === CELL 16 — Push plots to HF Hub ===
if HF_PUSH_REPO and "<your-username>" not in HF_PUSH_REPO:
    push_checkpoint_to_hub(PLOTS_DIR, HF_PUSH_REPO,
                           commit_message="per-phase training plots")
    print(f"Plots pushed: https://huggingface.co/{HF_PUSH_REPO}/tree/main")
else:
    print("HF_PUSH_REPO not configured — plots stay only in /kaggle/working/.")
"""
# Markdown cell introducing the optional inference smoke-test (cell 17).
CELL_17_INFER_MD = """\
## 17 · (Optional) Inference smoke-test on the trained model

Catches the classic 'merge path looked OK but the saved model emits garbage'
failure mode before the demo.
"""
# Code cell 17: reloads the base model, attaches the trained LoRA adapter and
# prints three generated Doctor actions for one fixed observation.
# NOTE(review): `MODEL_NAME` and `FINAL_LORA_DIR` are expected to be defined
# by earlier notebook cells — confirm.
# The doubled backslash in "\\n" is deliberate: this outer literal is not raw,
# so the generated cell ends up containing the two characters `\n`.
CELL_17_INFER = """\
# === CELL 17 — Inference smoke-test on the trained model ===
from ER_MAP.training.train_grpo import generate_doctor_action, load_model_and_tokenizer
from peft import PeftModel

base_model, tok = load_model_and_tokenizer(model_name=MODEL_NAME)
trained = PeftModel.from_pretrained(base_model, FINAL_LORA_DIR)

test_obs = (
    '{"event":"episode_start","nurse_experience":"veteran",'
    '"message":"Patient with chest pain, HR 120, BP 90/60, vague history.",'
    '"soap_summary":{}}'
)
for i in range(3):
    print(f"\\n--- Sample {i+1} ---")
    print(generate_doctor_action(trained, tok, test_obs, max_new_tokens=160))
"""
| # --------------------------------------------------------------------------- | |
| # Quickstart markdown (sibling file) | |
| # --------------------------------------------------------------------------- | |
# Full text of the sibling cheat sheet: `main()` writes this string verbatim
# to `KAGGLE_QUICKSTART.md` next to the notebook. Keep the cell numbers in the
# tables in step with the cell list assembled in `build_notebook()`.
QUICKSTART_MD = """\
# Kaggle Quickstart — ER-MAP GRPO Training (v3 stable)

The Kaggle notebook is in `kaggle/train_ermap_grpo_kaggle.ipynb`. This file
is the cheat sheet for running it end-to-end without the dependency hell
that bit us in earlier attempts.

## 0. Prerequisites (one-time)

1. **GitHub fork** of this repo. The notebook clones from a public fork at
   cell 6 — edit `GIT_URL`. Alternatively, upload the repo as a Kaggle
   Dataset named `ermap-source` (Add Data → Upload).
2. **Hugging Face write token** (`HF_TOKEN`) for pushing the trained
   adapter. Create at https://huggingface.co/settings/tokens (fine-grained,
   write access on a single model repo is enough).
3. **Five Groq keys** (one each for Nurse / Patient / Empathy Judge /
   Medical Judge / shared fallback). Free-tier accounts are fine; the
   per-account limits multiply across keys.

## 1. Create the Kaggle notebook

1. Sign in to https://www.kaggle.com/code → **New Notebook**.
2. Right sidebar:
   - Accelerator: **GPU T4 ×2** (or P100)
   - Internet: **On**
   - Persistence: Files only
3. **File → Upload Notebook** → choose `kaggle/train_ermap_grpo_kaggle.ipynb`
   from this repo.

## 2. Add Kaggle Secrets

Add-ons → Secrets → Add a new secret. Required labels (exactly):

| Label | Value |
|---|---|
| `GROQ_NURSE_API_KEY` | your nurse Groq key |
| `GROQ_PATIENT_API_KEY` | your patient Groq key |
| `GROQ_EMPATHY_JUDGE_API_KEY` | your empathy-judge Groq key |
| `GROQ_MEDICAL_JUDGE_API_KEY` | your medical-judge Groq key |
| `HF_TOKEN` | your HF write token |
| `WANDB_API_KEY` *(optional)* | your W&B key (skip — disabled by default) |

The notebook reads them via `kaggle_helpers.load_kaggle_secrets()` and
exports them as env vars.

## 3. Edit two placeholders in the notebook

- **Cell 6:** `GIT_URL = "https://github.com/<your-fork>/Meta_Finals.git"`
- **Cell 8:** `HF_PUSH_REPO = "<your-username>/ermap-doctor-lora"`

If you uploaded the repo as a Kaggle Dataset instead, leave `GIT_URL` as the
placeholder — cell 6 will detect `/kaggle/input/ermap-source` and copy from
there.

## 4. Run order (the only sequence that works)

| Cell | What it does | Expected output |
|---|---|---|
| 2 | GPU + disk + python + internet sanity check | GPU listed, disk free > 8 GB |
| 3 | **REPAIR** — pin torch 2.10 cu128, reinstall bitsandbytes, upgrade unsloth | `REPAIR OK` (or `RESTART REQUIRED`) |
| **(restart)** | If cell 3 said RESTART REQUIRED → Run → Restart kernel | — |
| 5 | Post-restart import verify | All `OK`, GPUs listed |
| 6 | Clone / mount the repo | `OK. Repo at /kaggle/working/Meta_Finals` |
| 7 | Wire Kaggle Secrets → env vars | `OK — at least one Groq key is wired` |
| 8 | HF Hub config | `Starting fresh — no resume.` |
| 9 | Hyperparameters (P1=+1.2, P2=+1.1, P3=+1.0) | thresholds printed |
| 10 | **Pre-flight** — Groq routing + 4× PING | 4× `[PASS]`, then `OK` |
| 11 | Dry-run smoke test (no GPU) | `Dry-run OK` |
| 12 | Wire HF push hook | `Hub-push hook installed.` |
| 13 | **REAL TRAINING** (4–6 h) | per-group rolling stats, eventual `EARLY STOP` |
| 14 | Final push to HF | `Final checkpoints pushed: https://huggingface.co/...` |
| 15 | Per-phase plots | 5 PNGs displayed inline |
| 16 | Push plots to HF | `Plots pushed: ...` |
| 17 | Inference smoke-test (optional) | 3 sample Doctor actions printed |

## 5. Common failures & fixes

| Symptom | Root cause | Fix |
|---|---|---|
| `numpy was upgraded mid-session` | numpy import poisoned by a previous cell | Restart kernel, re-run from cell 3 |
| `Pillow incompatible with torchvision` | Pillow ABI mismatch | Restart kernel, re-run from cell 3 |
| `PyTorch and torchvision compiled with different CUDA major` | torch upgraded to cu13 by a transient resolve | Re-run cell 3 (it pins cu128) and restart |
| `cannot import name 'create_gradient_checkpointing_buffer'` | unsloth ↔ unsloth_zoo version drift | Re-run cell 3 (upgrades both in lockstep) |
| `libnvJitLink.so.13 missing` | bitsandbytes built against different CUDA | Re-run cell 3 (force-reinstalls bitsandbytes after torch pin) |
| Disk usage > quota | Kaggle's 20 GB working partition fills up | First line of cell 3 cleans `/tmp` and pip cache |
| Pre-flight `[FAIL]` for a role | Groq key dead / quota exceeded | Generate a new key in console.groq.com → update Kaggle Secret → re-run cell 7+10 |
| `[FAIL]` says `routing=WRONG` | env var not set when `AgentRouter()` was constructed | Re-run cell 9 BEFORE cell 10 |
| Training freezes at episode 1 for >10 min | Doctor.generate hung; Unsloth import broke silently | Check cell 5 output for `unsloth` line; restart kernel and re-run cell 3 if missing |

## 6. What the trained model gives you

After cell 13 finishes (or hits the 12 h Kaggle session cap), you have:

- `OUTPUT_DIR/final_lora/` — LoRA adapter weights (~50 MB), pushed to
  `HF_PUSH_REPO`
- `OUTPUT_DIR/final_merged_fp16/` — full Llama-3.1-8B fp16 merge with the
  adapter applied (~16 GB), pushed to `HF_PUSH_REPO-merged`
- `OUTPUT_DIR/training_metrics.json` — per-episode rewards, outcomes,
  rolling stats — input for the per-phase plots
- `OUTPUT_DIR/plots/*.png` — 5 dashboards (one per phase + cross-phase
  overview + comparison bar)

Use the LoRA adapter for the demo (quick to load, runs on a 4050 6 GB at
~30 tok/s); use the merged fp16 if you need to host on a Vercel/HF Space
without `peft`.
"""
| # --------------------------------------------------------------------------- | |
| # Build the notebook | |
| # --------------------------------------------------------------------------- | |
def build_notebook() -> dict:
    """Assemble the complete notebook document as a plain dict.

    The cells are emitted in the fixed order below (the trailing comment on
    each entry is its zero-based index in the list). Every cell is then given
    a deterministic ``id`` derived from that index: the notebook declares
    ``nbformat_minor: 5``, and the 4.5 schema requires a unique ``id`` on
    every cell. Positional ids keep regenerated notebooks byte-stable.

    Returns:
        A dict with the ``cells``, ``metadata``, ``nbformat`` and
        ``nbformat_minor`` keys, ready to be serialised with ``json.dumps``.
    """
    cells = [
        md_cell(CELL_01_TITLE),            # 0
        code_cell(CELL_02_SANITY),         # 1
        code_cell(CELL_03_REPAIR),         # 2
        md_cell(CELL_04_RESTART),          # 3
        code_cell(CELL_05_VERIFY),         # 4
        code_cell(CELL_06_REPO),           # 5
        code_cell(CELL_07_SECRETS),        # 6
        code_cell(CELL_08_HF),             # 7
        code_cell(CELL_09_HYPERPARAMS),    # 8
        code_cell(CELL_10_PREFLIGHT),      # 9
        code_cell(CELL_11_DRYRUN),         # 10
        code_cell(CELL_12_HOOK),           # 11
        md_cell(CELL_13_TRAIN_MD),         # 12
        code_cell(CELL_13_TRAIN),          # 13
        code_cell(CELL_14_FINAL_PUSH),     # 14
        md_cell(CELL_15_PLOTS_MD),         # 15
        code_cell(CELL_15_PLOTS),          # 16
        code_cell(CELL_16_PUSH_PLOTS),     # 17
        md_cell(CELL_17_INFER_MD),         # 18
        code_cell(CELL_17_INFER),          # 19
    ]

    # nbformat 4.5 cell ids must be unique and match [a-zA-Z0-9-_]{1,64}.
    # setdefault() leaves an id alone if a cell helper already supplied one.
    for index, cell in enumerate(cells):
        cell.setdefault("id", f"cell-{index:02d}")

    return {
        "cells": cells,
        "metadata": {
            "kernelspec": {
                "display_name": "Python 3",
                "language": "python",
                "name": "python3",
            },
            "language_info": {
                "name": "python",
                "version": "3.10",
            },
        },
        "nbformat": 4,
        "nbformat_minor": 5,
    }
def main() -> None:
    """Regenerate the notebook and the quickstart beside this script.

    Both files are overwritten in the directory containing this module, and a
    one-line summary is printed for each file written.
    """
    out_dir = Path(__file__).parent
    nb_path = out_dir / "train_ermap_grpo_kaggle.ipynb"
    qs_path = out_dir / "KAGGLE_QUICKSTART.md"

    nb = build_notebook()
    serialized = json.dumps(nb, indent=1, ensure_ascii=False)
    nb_path.write_text(serialized, encoding="utf-8")
    qs_path.write_text(QUICKSTART_MD, encoding="utf-8")

    # Tally cells by type for the summary line.
    cell_types = [cell["cell_type"] for cell in nb["cells"]]
    n_md = cell_types.count("markdown")
    n_code = cell_types.count("code")
    print(f"Wrote {nb_path} ({len(cell_types)} cells: {n_md} md / {n_code} code)")
    print(f"Wrote {qs_path} ({len(QUICKSTART_MD.splitlines())} lines)")
| if __name__ == "__main__": | |
| main() | |