# Multi-Agentic / kaggle / clean_launch.py
# Author: Uddiii
# Commit 63726b6: feat: add support for lowercase Hugging Face Space secrets
"""
ER-MAP CLEAN TRAINING LAUNCH (Kaggle, T4-safe, 75-episode fixed budget)
Self-contained, idempotent, foolproof. Replaces the old Cell 9 / Cell 11 /
Cell 13 sequence with ONE cell that:
1. Force-pulls the repo to origin/main (picks up any new fix commits)
2. Drops the cached ER_MAP modules so the next import picks up the fresh code on disk
3. Asserts the train_grpo patches are live in the running module (kl_beta gate,
use_kl loss branch, phase_episode_budgets parameter, per-step backward,
inference/training mode swap, attention-only LoRA with dropout=0)
4. Sets all hyperparameters EXPLICITLY — does not depend on any earlier
cell's globals being correct
5. Frees VRAM aggressively and asserts >= 6 GB free before launch
6. Runs the Groq pre-flight (routing + 4-key liveness) and asserts all PASS
7. Calls train() with phase_episode_budgets={1: 20, 2: 25, 3: 30}
Usage from a Kaggle notebook cell:
exec(open("/kaggle/working/Meta_Finals/kaggle/clean_launch.py").read())
That one line is all you paste. Press play, then walk away for ~3-5 hours (T4).
"""
import os, sys, gc, subprocess, importlib # noqa: E401
# Set the CUDA allocator config FIRST, before this script imports torch.
# These are assigned unconditionally (not setdefault): section 4 re-asserts the
# same values unconditionally anyway, so setdefault only meant that a different
# pre-existing value would win for the first torch import and lose afterwards.
_ALLOC_CONF = "expandable_segments:True"
os.environ["PYTORCH_ALLOC_CONF"] = _ALLOC_CONF
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = _ALLOC_CONF
# This file is exec()'d into a live kernel. If an earlier cell already
# initialised CUDA, the allocator has most likely read its config already and
# the variables above may have no effect -- say so instead of failing silently.
_torch_mod = sys.modules.get("torch")
if _torch_mod is not None and _torch_mod.cuda.is_initialized():
    print("[warn] CUDA was already initialised in this kernel; "
          "PYTORCH_CUDA_ALLOC_CONF may not take effect until a kernel restart.")
# =============================================================================
# 1. Force repo to latest commit on origin/main
# =============================================================================
REPO_ROOT = "/kaggle/working/Meta_Finals"
print("[1/7] Updating repo to origin/main...")
# Each entry: (git sub-command arguments, whether a non-zero exit must abort).
# fetch + hard reset are mandatory; the final `git log` only echoes the
# checked-out commit, so its exit status is deliberately ignored.
for _git_args, _must_succeed in (
    (("fetch", "origin"), True),
    (("reset", "--hard", "origin/main"), True),
    (("log", "-1", "--oneline"), False),
):
    subprocess.run(["git", "-C", REPO_ROOT, *_git_args], check=_must_succeed)
# =============================================================================
# 2. Drop cached ER_MAP modules so import picks up the latest disk version
# =============================================================================
print("\n[2/7] Dropping cached modules...")
# Snapshot the keys first: sys.modules must not change size while iterated.
for _m in [_name for _name in sys.modules if _name.startswith("ER_MAP")]:
    del sys.modules[_m]
# The hard reset in step 1 may have ADDED new .py files. Path-entry finders
# cache directory listings, so without invalidating those caches a module
# created since the interpreter started may not be found by the next import.
importlib.invalidate_caches()
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)
# =============================================================================
# 3. Verify all required patches are live in the running module
# =============================================================================
# All checks below are TEXTUAL: they search the source of the freshly imported
# functions for marker strings, so they prove the patched code is on disk and
# loaded, not that it behaves correctly.
print("\n[3/7] Verifying patches...")
import inspect
import re
import ER_MAP.training.train_grpo as tg


def _require_patch(condition, message):
    """Raise ``AssertionError(message)`` unless *condition* is truthy.

    Used instead of bare ``assert`` statements, which Python strips under
    ``-O``; this verification must never be silently skipped.
    """
    if not condition:
        raise AssertionError(message)


_train_src = inspect.getsource(tg.train)
_require_patch(
    "if kl_beta > 0.0:" in _train_src,
    "FAIL: train() missing kl_beta gate. Pull the latest commit on origin/main.",
)
_require_patch(
    "phase_episode_budgets" in _train_src,
    "FAIL: train() missing phase_episode_budgets support.",
)
# inspect.getsource() unwraps decorated functions (follows __wrapped__); do
# the same here so a functools.wraps-style decorator does not hide the real
# function's local names behind the wrapper's (*args, **kwargs).
_require_patch(
    "use_kl" in inspect.unwrap(tg.manual_grpo_step).__code__.co_varnames,
    "FAIL: manual_grpo_step missing 'use_kl' branch.",
)
# Per-step backward (T4 OOM fix): the loss must be scaled and backward()-ed
# inside the trajectory loop, NOT accumulated into one big graph.
_step_src = inspect.getsource(tg.manual_grpo_step)
_require_patch(
    "scaled.backward()" in _step_src and "n_steps_total" in _step_src,
    "FAIL: manual_grpo_step still accumulates loss across all steps "
    "(graph stays alive in VRAM -> OOM during update). "
    "Pull the latest commit on origin/main.",
)
# Inference/Training mode swap (the silent ~7 GB VRAM leak fix).
_require_patch(
    "for_inference" in _train_src and "_to_training" in _train_src,
    "FAIL: train() missing for_inference/for_training mode swap. "
    "Pull the latest commit on origin/main.",
)
# LoRA dropout=0 + attention-only modules (Unsloth fast path).
_loader_src = inspect.getsource(tg.load_model_and_tokenizer)
# Parse the numeric literal: a plain substring test for "lora_dropout=0"
# would also accept "lora_dropout=0.05" / "0.1" -- the very values this
# check exists to reject. Every literal assignment found must be zero.
_dropouts = [
    float(_v)
    for _v in re.findall(r"lora_dropout\s*=\s*([0-9]*\.?[0-9]+)", _loader_src)
]
_require_patch(
    bool(_dropouts) and all(_d == 0.0 for _d in _dropouts),
    "FAIL: lora_dropout still > 0 (disables Unsloth fast LoRA kernels).",
)
# The attention projections must be present AND, when a literal
# target_modules list/tuple is visible, it must not also contain the MLP
# projections (appending them to the attention list used to pass unnoticed).
_mlp_names = ("gate_proj", "up_proj", "down_proj")
_targets = re.search(r"target_modules\s*=\s*[\[(]([^\])]*)[\])]", _loader_src)
_require_patch(
    '"q_proj", "k_proj", "v_proj", "o_proj"' in _loader_src
    and not (_targets and any(_n in _targets.group(1) for _n in _mlp_names)),
    "FAIL: LoRA targets still include MLP modules (too many trainable params for T4).",
)
print(" OK — kl_beta gate live")
print(" OK — phase_episode_budgets supported")
print(" OK — use_kl branch in loss function")
print(" OK — per-step backward (memory-bounded GRPO update)")
print(" OK — for_inference/for_training mode swap (no checkpointing leak)")
print(" OK — lora_dropout=0, attention-only LoRA (Unsloth fast path)")
# =============================================================================
# 4. EXPLICIT hyperparameters — does not rely on any previous cell's globals
# =============================================================================
print("\n[4/7] Setting hyperparameters (explicit, no Cell 9 dependency)...")

# --- Model / optimiser ---------------------------------------------------------
MODEL_NAME = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
GROUP_SIZE = 2
LEARNING_RATE = 5e-6
# 0.0 is the T4-safe setting: skips loading the reference model (~5 GB VRAM).
KL_BETA = 0.0

# --- Curriculum budget ---------------------------------------------------------
PHASE_EPISODE_BUDGETS = {1: 20, 2: 25, 3: 30}
NUM_EPISODES = sum(PHASE_EPISODE_BUDGETS.values())  # 75 with the budgets above
PHASE_REWARD_TARGETS = {1: 1.2, 2: 1.1, 3: 1.0}  # observational only
PHASE_MIN_WIN_RATE = 0.20
CONVERGENCE_WINDOW = 3
EARLY_STOP_ENABLED = False  # forced off by train() under fixed-budget anyway
OUTPUT_DIR = "/kaggle/working/er_map_grpo_checkpoints"

# --- Process environment (applied in this order) -----------------------------
_LAUNCH_ENV = {
    # Anti-fragmentation for the GRPO backward pass on T4. Re-asserted only:
    # to matter, it must already be set at the top of this script, before the
    # first torch import.
    "PYTORCH_ALLOC_CONF": "expandable_segments:True",
    "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    # Groq traffic shaping: 8B for the actors, 70B for the judges.
    "ERMAP_NURSE_MODEL": "llama-3.1-8b-instant",
    "ERMAP_PATIENT_MODEL": "llama-3.1-8b-instant",
    "ERMAP_EMPATHY_JUDGE_MODEL": "llama-3.3-70b-versatile",
    "ERMAP_MEDICAL_JUDGE_MODEL": "llama-3.3-70b-versatile",
    # Episode budget controls (read by triage_env).
    "ERMAP_MAX_EPISODE_STEPS": "20",
    "ERMAP_MAX_INTERNAL_EXCHANGES": "5",
    # Doctor generation length: 128 keeps KV cache + activations small enough
    # for the GRPO update to fit alongside the 8B model on a 15.6 GB T4.
    "ERMAP_DOCTOR_MAX_NEW_TOKENS": "128",
}
os.environ.update(_LAUNCH_ENV)

# Echo the key settings: (label, value, trailing note).
for _label, _value, _note in (
    ("NUM_EPISODES", NUM_EPISODES, ""),
    ("PHASE_EPISODE_BUDGETS", PHASE_EPISODE_BUDGETS, ""),
    ("GROUP_SIZE", GROUP_SIZE, ""),
    ("KL_BETA", KL_BETA, " (skip ref model)"),
    ("PHASE_REWARD_TARGETS", PHASE_REWARD_TARGETS, " (observational)"),
):
    print(f" {_label} = {_value}{_note}")
# =============================================================================
# 5. Free VRAM and assert headroom for the model load
# =============================================================================
print("\n[5/7] Freeing VRAM...")
import torch  # noqa: E402

# Fail with an actionable message instead of an opaque error from
# mem_get_info() when the session has no GPU attached.
if not torch.cuda.is_available():
    raise AssertionError(
        "FAIL: no CUDA device is visible to torch. Enable a GPU (T4) "
        "accelerator for this notebook session, then re-run."
    )

# This file is exec()'d into the notebook namespace, so globals() is that
# namespace: drop large objects a previous run may have left there.
# pop(..., None) is a no-op for names that are not defined.
for _name in ("model", "tokenizer", "ref_model", "optimizer"):
    globals().pop(_name, None)

# If an earlier train() call died (e.g. CUDA OOM), the interactive shell
# typically keeps that exception in sys.last_*; its traceback frames still
# reference the model/optimizer tensors, so their VRAM cannot be reclaimed.
# Clearing them gives up post-mortem debugging (%debug) of that exception.
for _attr in ("last_type", "last_value", "last_traceback", "last_exc"):
    if hasattr(sys, _attr):
        setattr(sys, _attr, None)

gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
_free, _total = torch.cuda.mem_get_info(0)
print(f" VRAM free: {_free/1e9:.2f} / {_total/1e9:.2f} GB")
# Explicit raise (not `assert`) so the headroom check survives `python -O`.
if _free / 1e9 < 6.0:
    raise AssertionError(
        f"FAIL: only {_free/1e9:.2f} GB free; need >= 6 GB. "
        "Do Run -> Restart kernel, then re-run Cell 6 (mount), Cell 7 (secrets), "
        "and this cell. The kernel has unrecoverable VRAM fragmentation."
    )
# =============================================================================
# 6. Groq pre-flight (routing + 4-key liveness)
# =============================================================================
print("\n[6/7] Pre-flight: Groq routing + key liveness...")
from ER_MAP.envs.api_router import AgentRouter  # noqa: E402


def _groq_ping(client, model):
    """Ask *model* to echo PING; return ``(api_ok, error_suffix)``.

    Any exception raised by the client (or while reading its reply) is caught
    and summarised as `` (ExcType: first 80 chars)`` so every role still gets
    a status line before the overall verdict.
    """
    try:
        reply = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": "Reply with exactly: PING"}],
            max_tokens=4, temperature=0,
        )
        text = reply.choices[0].message.content or ""
        return "PING" in text.upper(), ""
    except Exception as exc:  # best-effort probe: reported, never re-raised
        return False, f" ({type(exc).__name__}: {str(exc)[:80]})"


_router = AgentRouter()
# Role -> model the router is expected to use after section 4's env setup.
_expected = {
    "nurse": "llama-3.1-8b-instant",
    "patient": "llama-3.1-8b-instant",
    "empathy_judge": "llama-3.3-70b-versatile",
    "medical_judge": "llama-3.3-70b-versatile",
}
_failures = 0
for _role, _exp in _expected.items():
    _actual = _router._models.get(_role, "?")
    _client = _router._clients.get(_role)
    if _client is None:
        # No client means no key: reported as SKIP but still fails the pre-flight.
        print(f" [SKIP] {_role:14s} -> no Groq client (key missing)")
        _failures += 1
        continue
    _api_ok, _err = _groq_ping(_client, _exp)
    _flag = "PASS" if (_actual == _exp and _api_ok) else "FAIL"
    print(f" [{_flag}] {_role:14s} | model={_actual:25s} | api_ok={_api_ok}{_err}")
    if _flag == "FAIL":
        _failures += 1
_all_pass = _failures == 0
assert _all_pass, "Pre-flight FAILED. Re-run Cell 7 (secrets) and Cell 6 (repo)."
# =============================================================================
# 7. LAUNCH — fixed-budget GRPO training
# =============================================================================
# Display names for the curriculum phases keyed like PHASE_EPISODE_BUDGETS.
_PHASE_NAMES = {
    1: "Tool Mastery",
    2: "Clinical Reasoning",
    3: "Empathetic Negotiation",
}
# The banner is derived from the configured budgets so it cannot drift from
# what train() actually receives when the budgets are edited.
print(f"\n[7/7] Launching GRPO training ({NUM_EPISODES} episodes, fixed budget)...")
print("=" * 72)
for _phase, _budget in sorted(PHASE_EPISODE_BUDGETS.items()):
    print(f" Phase {_phase} ({_PHASE_NAMES.get(_phase, 'unnamed')}) : {_budget} episodes")
print(f" Total : {NUM_EPISODES} episodes (~3-5 hours on T4)")
print(" HF Hub backup : every 20 episodes")
print("=" * 72)

# Groq key lookup order: the nurse-specific key, then the shared key, each
# also accepted under its lowercase secret name ("nurse" / "groq").
# Falls back to "" when none is set.
_groq_key = (
    os.environ.get("GROQ_NURSE_API_KEY")
    or os.environ.get("nurse")
    or os.environ.get("GROQ_API_KEY")
    or os.environ.get("groq")
    or ""
)

metrics = tg.train(
    num_episodes=NUM_EPISODES,
    group_size=GROUP_SIZE,
    model_name=MODEL_NAME,
    groq_api_key=_groq_key,
    learning_rate=LEARNING_RATE,
    kl_beta=KL_BETA,
    use_wandb=False,
    output_dir=OUTPUT_DIR,
    dry_run=False,
    phase_reward_targets=PHASE_REWARD_TARGETS,
    phase_min_win_rate=PHASE_MIN_WIN_RATE,
    convergence_window=CONVERGENCE_WINDOW,
    early_stop=EARLY_STOP_ENABLED,
    phase_episode_budgets=PHASE_EPISODE_BUDGETS,
)
print("=" * 72)
print(f"\nTRAINING COMPLETE — {len(metrics)} metric records collected.")
print(f"Final LoRA adapter: {OUTPUT_DIR}/final_lora")
print("Plots will be rendered by Cell 15 (run it next).")