Spaces:
Sleeping
Sleeping
"""
ER-MAP CLEAN TRAINING LAUNCH (Kaggle, T4-safe, 75-episode fixed budget)

Self-contained, idempotent, foolproof. Replaces the old Cell 9 / Cell 11 /
Cell 13 sequence with ONE cell that:

  1. Force-pulls the repo to origin/main (picks up any new fix commits)
  2. Drops the cached ER_MAP modules so the next import picks up the fresh
     on-disk version
  3. Asserts the train_grpo patches are live in the running module (kl-gate
     + use_kl loss branch + phase_episode_budgets parameter, plus the
     per-step backward, inference/training mode swap and LoRA-config fixes)
  4. Sets all hyperparameters EXPLICITLY — does not depend on any earlier
     cell's globals being correct
  5. Frees VRAM aggressively and asserts >= 6 GB free before launch
  6. Runs the Groq pre-flight (routing + 4-key liveness) and asserts all PASS
  7. Calls train() with phase_episode_budgets={1: 20, 2: 25, 3: 30}

Usage from a Kaggle notebook cell:

    exec(open("/kaggle/working/Meta_Finals/kaggle/clean_launch.py").read())

That one line is all you paste. Press play. Walk away for ~4 hours.
"""
import os, sys, gc, subprocess, importlib  # noqa: E401

# The CUDA allocator settings are seeded here, ahead of everything else in the
# script, because (per the original author's note) they must be in the
# environment before the very first `import torch` in the kernel — once the
# allocator is initialized the variables have no effect.
# setdefault() is used, so a value already present in the environment is kept.
for _alloc_var in ("PYTORCH_ALLOC_CONF", "PYTORCH_CUDA_ALLOC_CONF"):
    os.environ.setdefault(_alloc_var, "expandable_segments:True")
# =============================================================================
# 1. Force repo to latest commit on origin/main
# =============================================================================
REPO_ROOT = "/kaggle/working/Meta_Finals"
print("[1/7] Updating repo to origin/main...")

# Every git invocation targets the checkout at REPO_ROOT via `git -C`.
_git = ("git", "-C", REPO_ROOT)

# Fetch, then hard-reset the working tree to origin/main. check=True makes a
# non-zero git exit status raise CalledProcessError and stop the script.
for _git_args in (("fetch", "origin"), ("reset", "--hard", "origin/main")):
    subprocess.run([*_git, *_git_args], check=True)

# Print the commit now checked out. Purely informational, so the exit status
# of this call is intentionally not checked.
subprocess.run([*_git, "log", "-1", "--oneline"])
# =============================================================================
# 2. Drop cached ER_MAP modules so import picks up the latest disk version
# =============================================================================
print("\n[2/7] Dropping cached modules...")

# Iterate over a snapshot of the keys: sys.modules is mutated inside the loop.
for _m in list(sys.modules):
    if _m.startswith("ER_MAP"):
        del sys.modules[_m]

# Make the repo importable, ahead of any other copy that may be on the path.
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

# The checkout on disk may have gained new module files while this interpreter
# was running. Reset the import finders' caches so the next import is
# guaranteed to notice them.
importlib.invalidate_caches()
# =============================================================================
# 3. Verify all required patches are live in the running module
# =============================================================================
print("\n[3/7] Verifying patches...")
import inspect
import re

import ER_MAP.training.train_grpo as tg


def _require(ok, message):
    """Raise AssertionError(message) unless *ok* is truthy.

    Used instead of a bare ``assert`` statement so the checks below still run
    when Python is started with -O / PYTHONOPTIMIZE, which strips asserts.
    The exception type is the same one ``assert`` would raise.
    """
    if not ok:
        raise AssertionError(message)


# All checks below are textual: they look for marker strings in the source of
# the freshly imported functions.
_train_src = inspect.getsource(tg.train)
_require(
    "if kl_beta > 0.0:" in _train_src,
    "FAIL: train() missing kl_beta gate. Pull the latest commit on origin/main.",
)
_require(
    "phase_episode_budgets" in _train_src,
    "FAIL: train() missing phase_episode_budgets support.",
)
# co_varnames holds the function's parameter AND local variable names, so this
# passes if `use_kl` is either.
_require(
    "use_kl" in tg.manual_grpo_step.__code__.co_varnames,
    "FAIL: manual_grpo_step missing 'use_kl' branch.",
)

# Per-step backward (T4 OOM fix): the loss must be scaled and backward()-ed
# inside the trajectory loop, NOT accumulated into one big graph.
_step_src = inspect.getsource(tg.manual_grpo_step)
_require(
    "scaled.backward()" in _step_src and "n_steps_total" in _step_src,
    "FAIL: manual_grpo_step still accumulates loss across all steps "
    "(graph stays alive in VRAM -> OOM during update). "
    "Pull the latest commit on origin/main.",
)

# Inference/Training mode swap (the silent ~7 GB VRAM leak fix).
# NOTE(review): the check looks for "_to_training" while the messages say
# "for_training" — confirm which identifier train() actually uses.
_require(
    "for_inference" in _train_src and "_to_training" in _train_src,
    "FAIL: train() missing for_inference/for_training mode swap. "
    "Pull the latest commit on origin/main.",
)

# LoRA dropout=0 + attention-only modules (Unsloth fast path).
_loader_src = inspect.getsource(tg.load_model_and_tokenizer)
# A plain substring test for "lora_dropout=0" would also accept
# "lora_dropout=0.05". Require the value to be exactly 0 / 0. / 0.0…, i.e. not
# followed by a further digit or decimal point; whitespace around "=" is OK.
_require(
    re.search(r"lora_dropout\s*=\s*0(?:\.0*)?(?![\d.])", _loader_src) is not None,
    "FAIL: lora_dropout still > 0 (disables Unsloth fast LoRA kernels).",
)
# NOTE(review): this only proves the attention-module list is present; it
# would still pass if MLP modules were appended after "o_proj".
_require(
    '"q_proj", "k_proj", "v_proj", "o_proj"' in _loader_src,
    "FAIL: LoRA targets still include MLP modules (too many trainable params for T4).",
)

print(" OK — kl_beta gate live")
print(" OK — phase_episode_budgets supported")
print(" OK — use_kl branch in loss function")
print(" OK — per-step backward (memory-bounded GRPO update)")
print(" OK — for_inference/for_training mode swap (no checkpointing leak)")
print(" OK — lora_dropout=0, attention-only LoRA (Unsloth fast path)")
# =============================================================================
# 4. EXPLICIT hyperparameters — does not rely on any previous cell's globals
# =============================================================================
print("\n[4/7] Setting hyperparameters (explicit, no Cell 9 dependency)...")

# --- Model / optimisation ----------------------------------------------------
MODEL_NAME = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
GROUP_SIZE = 2
LEARNING_RATE = 5e-6
KL_BETA = 0.0  # T4-safe: skip reference model load (saves ~5 GB VRAM)

# --- Curriculum --------------------------------------------------------------
PHASE_EPISODE_BUDGETS = {1: 20, 2: 25, 3: 30}  # 75 episodes total
NUM_EPISODES = sum(PHASE_EPISODE_BUDGETS.values())
PHASE_REWARD_TARGETS = {1: 1.2, 2: 1.1, 3: 1.0}  # observational only
PHASE_MIN_WIN_RATE = 0.20
CONVERGENCE_WINDOW = 3
EARLY_STOP_ENABLED = False  # forced off by train() under fixed-budget anyway

# --- Output ------------------------------------------------------------------
OUTPUT_DIR = "/kaggle/working/er_map_grpo_checkpoints"

# --- Environment -------------------------------------------------------------
# Unlike the setdefault() calls at the top of the script, these assignments
# overwrite whatever value is currently in the environment.
os.environ.update(
    {
        # Anti-fragmentation for the GRPO backward pass on T4 (re-asserted;
        # the real set must happen at the top of this script, before the
        # first torch import).
        "PYTORCH_ALLOC_CONF": "expandable_segments:True",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        # Groq traffic shaping — 8B for actors, 70B for judges
        "ERMAP_NURSE_MODEL": "llama-3.1-8b-instant",
        "ERMAP_PATIENT_MODEL": "llama-3.1-8b-instant",
        "ERMAP_EMPATHY_JUDGE_MODEL": "llama-3.3-70b-versatile",
        "ERMAP_MEDICAL_JUDGE_MODEL": "llama-3.3-70b-versatile",
        # Episode budget controls (read by triage_env)
        "ERMAP_MAX_EPISODE_STEPS": "20",
        "ERMAP_MAX_INTERNAL_EXCHANGES": "5",
        # Doctor generation length — 128 keeps KV cache + activations small
        # enough for the GRPO update to fit alongside the 8B model on a
        # 15.6 GB T4.
        "ERMAP_DOCTOR_MAX_NEW_TOKENS": "128",
    }
)

# Echo the effective configuration: (label, value, trailing note).
for _label, _value, _note in (
    ("NUM_EPISODES", NUM_EPISODES, ""),
    ("PHASE_EPISODE_BUDGETS", PHASE_EPISODE_BUDGETS, ""),
    ("GROUP_SIZE", GROUP_SIZE, ""),
    ("KL_BETA", KL_BETA, " (skip ref model)"),
    ("PHASE_REWARD_TARGETS", PHASE_REWARD_TARGETS, " (observational)"),
):
    print(f" {_label} = {_value}{_note}")
# =============================================================================
# 5. Free VRAM and assert headroom for the model load
# =============================================================================
print("\n[5/7] Freeing VRAM...")
import torch  # noqa: E402

# Remove these names from the global namespace if a previous run left them
# there, so the objects they referenced become collectable. pop() with a
# default is a no-op for names that are absent.
for _name in ("model", "tokenizer", "ref_model", "optimizer"):
    globals().pop(_name, None)
gc.collect()

# Fail with an actionable message when no GPU is visible, instead of letting
# mem_get_info() below raise a low-level CUDA error.
if not torch.cuda.is_available():
    raise AssertionError(
        "FAIL: torch.cuda.is_available() is False — no CUDA device is visible. "
        "Enable a GPU accelerator for this notebook session and restart the kernel."
    )

torch.cuda.empty_cache()
torch.cuda.ipc_collect()

# mem_get_info returns (free_bytes, total_bytes) for device 0.
_free, _total = torch.cuda.mem_get_info(0)
print(f" VRAM free: {_free/1e9:.2f} / {_total/1e9:.2f} GB")

# Raised explicitly (rather than via `assert`) so the headroom check also runs
# under -O; the exception type is unchanged.
if _free / 1e9 < 6.0:
    raise AssertionError(
        f"FAIL: only {_free/1e9:.2f} GB free; need >= 6 GB. "
        "Do Run -> Restart kernel, then re-run Cell 6 (mount), Cell 7 (secrets), "
        "and this cell. The kernel has unrecoverable VRAM fragmentation."
    )
# =============================================================================
# 6. Groq pre-flight (routing + 4-key liveness)
# =============================================================================
print("\n[6/7] Pre-flight: Groq routing + key liveness...")
from ER_MAP.envs.api_router import AgentRouter  # noqa: E402

_router = AgentRouter()

# Model each role is expected to be routed to.
_actor_model = "llama-3.1-8b-instant"
_judge_model = "llama-3.3-70b-versatile"
_expected = {
    "nurse": _actor_model,
    "patient": _actor_model,
    "empathy_judge": _judge_model,
    "medical_judge": _judge_model,
}

# Roles that did not pass; the pre-flight succeeds only if this stays empty.
_failed_roles = []
for _role, _want in _expected.items():
    _have = _router._models.get(_role, "?")
    _client = _router._clients.get(_role)

    # No client for this role: reported as SKIP, but it still fails the
    # pre-flight.
    if _client is None:
        print(f" [SKIP] {_role:14s} -> no Groq client (key missing)")
        _failed_roles.append(_role)
        continue

    # Liveness probe: a tiny completion against the EXPECTED model. Any
    # exception (including one raised while reading the reply) counts as a
    # failed probe and is summarised in the output line.
    _err = ""
    try:
        _reply = _client.chat.completions.create(
            model=_want,
            messages=[{"role": "user", "content": "Reply with exactly: PING"}],
            max_tokens=4,
            temperature=0,
        )
        _api_ok = "PING" in (_reply.choices[0].message.content or "").upper()
    except Exception as _exc:
        _api_ok = False
        _err = f" ({type(_exc).__name__}: {str(_exc)[:80]})"

    # PASS needs both: the router maps the role to the expected model AND the
    # probe call answered.
    _role_ok = _have == _want and _api_ok
    _flag = "PASS" if _role_ok else "FAIL"
    print(f" [{_flag}] {_role:14s} | model={_have:25s} | api_ok={_api_ok}{_err}")
    if not _role_ok:
        _failed_roles.append(_role)

_all_pass = not _failed_roles
assert _all_pass, "Pre-flight FAILED. Re-run Cell 7 (secrets) and Cell 6 (repo)."
# =============================================================================
# 7. LAUNCH — fixed-budget GRPO training
# =============================================================================
# The banner numbers are derived from PHASE_EPISODE_BUDGETS / NUM_EPISODES so
# they cannot drift from the values actually passed to train().
_phase_labels = {
    1: "Tool Mastery",
    2: "Clinical Reasoning",
    3: "Empathetic Negotiation",
}
print(f"\n[7/7] Launching GRPO training ({NUM_EPISODES} episodes, fixed budget)...")
print("=" * 72)
for _phase in sorted(PHASE_EPISODE_BUDGETS):
    _label = _phase_labels.get(_phase, "unnamed")
    print(f" Phase {_phase} ({_label}) : {PHASE_EPISODE_BUDGETS[_phase]} episodes")
print(f" Total : {NUM_EPISODES} episodes (~3-5 hours on T4)")
print(" HF Hub backup : every 20 episodes")
print("=" * 72)

# Groq key lookup, first non-empty value wins: the nurse-specific key (under
# either of its two names), then the generic key (either name), else "".
_groq_api_key = (
    os.environ.get("GROQ_NURSE_API_KEY")
    or os.environ.get("nurse")
    or os.environ.get("GROQ_API_KEY")
    or os.environ.get("groq")
    or ""
)

metrics = tg.train(
    num_episodes=NUM_EPISODES,
    group_size=GROUP_SIZE,
    model_name=MODEL_NAME,
    groq_api_key=_groq_api_key,
    learning_rate=LEARNING_RATE,
    kl_beta=KL_BETA,
    use_wandb=False,
    output_dir=OUTPUT_DIR,
    dry_run=False,
    phase_reward_targets=PHASE_REWARD_TARGETS,
    phase_min_win_rate=PHASE_MIN_WIN_RATE,
    convergence_window=CONVERGENCE_WINDOW,
    early_stop=EARLY_STOP_ENABLED,
    phase_episode_budgets=PHASE_EPISODE_BUDGETS,
)

print("=" * 72)
print(f"\nTRAINING COMPLETE — {len(metrics)} metric records collected.")
print(f"Final LoRA adapter: {OUTPUT_DIR}/final_lora")
print("Plots will be rendered by Cell 15 (run it next).")