Spaces:
Sleeping
Sleeping
feat(kaggle): add clean_launch.py + shrink budget to 20/25/30 = 75 eps
Browse files- kaggle/build_notebook.py +2 -2
- kaggle/clean_launch.py +197 -0
- kaggle/train_ermap_grpo_kaggle.ipynb +2 -2
kaggle/build_notebook.py
CHANGED
|
@@ -476,8 +476,8 @@ USE_WANDB = False # WANDB conflicts with protobuf 7 on Kaggle base image
|
|
| 476 |
# but less predictable). When PHASE_EPISODE_BUDGETS is set, EARLY_STOP_ENABLED
|
| 477 |
# is automatically forced to False inside train() — the reward targets below
|
| 478 |
# become observational only (logged on the plots, not used for promotion).
|
| 479 |
-
PHASE_EPISODE_BUDGETS = {1: 20, 2:
|
| 480 |
-
NUM_EPISODES = sum(PHASE_EPISODE_BUDGETS.values()) # =
|
| 481 |
|
| 482 |
# --- Per-phase reward thresholds (observational under fixed-budget) --------
|
| 483 |
# Plotted as horizontal target lines on the reward-growth chart so you can
|
|
|
|
| 476 |
# but less predictable). When PHASE_EPISODE_BUDGETS is set, EARLY_STOP_ENABLED
|
| 477 |
# is automatically forced to False inside train() — the reward targets below
|
| 478 |
# become observational only (logged on the plots, not used for promotion).
|
| 479 |
+
PHASE_EPISODE_BUDGETS = {1: 20, 2: 25, 3: 30} # 20 + 25 + 30 = 75 episodes
|
| 480 |
+
NUM_EPISODES = sum(PHASE_EPISODE_BUDGETS.values()) # = 75
|
| 481 |
|
| 482 |
# --- Per-phase reward thresholds (observational under fixed-budget) --------
|
| 483 |
# Plotted as horizontal target lines on the reward-growth chart so you can
|
kaggle/clean_launch.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ER-MAP CLEAN TRAINING LAUNCH (Kaggle, T4-safe, 75-episode fixed budget)
|
| 3 |
+
|
| 4 |
+
Self-contained, idempotent, foolproof. Replaces the old Cell 9 / Cell 11 /
|
| 5 |
+
Cell 13 sequence with ONE cell that:
|
| 6 |
+
|
| 7 |
+
1. Force-pulls the repo to origin/main (picks up any new fix commits)
|
| 8 |
+
2. Drops the cached ER_MAP module so the next import picks up the fresh disk
|
| 9 |
+
3. Asserts the train_grpo patches are live in the running module (kl-gate
|
| 10 |
+
+ use_kl loss branch + phase_episode_budgets parameter)
|
| 11 |
+
4. Sets all hyperparameters EXPLICITLY — does not depend on any earlier
|
| 12 |
+
cell's globals being correct
|
| 13 |
+
5. Frees VRAM aggressively and asserts >= 6 GB free before launch
|
| 14 |
+
6. Runs the Groq pre-flight (routing + 4-key liveness) and asserts all PASS
|
| 15 |
+
7. Calls train() with phase_episode_budgets={1: 20, 2: 25, 3: 30}
|
| 16 |
+
|
| 17 |
+
Usage from a Kaggle notebook cell:
|
| 18 |
+
|
| 19 |
+
exec(open("/kaggle/working/Meta_Finals/kaggle/clean_launch.py").read())
|
| 20 |
+
|
| 21 |
+
That one line is all you paste. Press play. Walk away for ~4 hours.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import os, sys, gc, subprocess, importlib # noqa: E401
|
| 25 |
+
|
| 26 |
+
# =============================================================================
|
| 27 |
+
# 1. Force repo to latest commit on origin/main
|
| 28 |
+
# =============================================================================
|
| 29 |
+
REPO_ROOT = "/kaggle/working/Meta_Finals"
|
| 30 |
+
print("[1/7] Updating repo to origin/main...")
|
| 31 |
+
subprocess.run(["git", "-C", REPO_ROOT, "fetch", "origin"], check=True)
|
| 32 |
+
subprocess.run(["git", "-C", REPO_ROOT, "reset", "--hard", "origin/main"], check=True)
|
| 33 |
+
subprocess.run(["git", "-C", REPO_ROOT, "log", "-1", "--oneline"])
|
| 34 |
+
|
| 35 |
+
# =============================================================================
|
| 36 |
+
# 2. Drop cached ER_MAP modules so import picks up the latest disk version
|
| 37 |
+
# =============================================================================
|
| 38 |
+
print("\n[2/7] Dropping cached modules...")
|
| 39 |
+
for _m in list(sys.modules):
|
| 40 |
+
if _m.startswith("ER_MAP"):
|
| 41 |
+
del sys.modules[_m]
|
| 42 |
+
if REPO_ROOT not in sys.path:
|
| 43 |
+
sys.path.insert(0, REPO_ROOT)
|
| 44 |
+
|
| 45 |
+
# =============================================================================
|
| 46 |
+
# 3. Verify all required patches are live in the running module
|
| 47 |
+
# =============================================================================
|
| 48 |
+
print("\n[3/7] Verifying patches...")
|
| 49 |
+
import inspect
|
| 50 |
+
import ER_MAP.training.train_grpo as tg
|
| 51 |
+
|
| 52 |
+
_train_src = inspect.getsource(tg.train)
|
| 53 |
+
assert "if kl_beta > 0.0:" in _train_src, (
|
| 54 |
+
"FAIL: train() missing kl_beta gate. Pull the latest commit on origin/main."
|
| 55 |
+
)
|
| 56 |
+
assert "phase_episode_budgets" in _train_src, (
|
| 57 |
+
"FAIL: train() missing phase_episode_budgets support."
|
| 58 |
+
)
|
| 59 |
+
assert "use_kl" in tg.manual_grpo_step.__code__.co_varnames, (
|
| 60 |
+
"FAIL: manual_grpo_step missing 'use_kl' branch."
|
| 61 |
+
)
|
| 62 |
+
print(" OK β kl_beta gate live")
|
| 63 |
+
print(" OK β phase_episode_budgets supported")
|
| 64 |
+
print(" OK β use_kl branch in loss function")
|
| 65 |
+
|
| 66 |
+
# =============================================================================
|
| 67 |
+
# 4. EXPLICIT hyperparameters — does not rely on any previous cell's globals
|
| 68 |
+
# =============================================================================
|
| 69 |
+
print("\n[4/7] Setting hyperparameters (explicit, no Cell 9 dependency)...")
|
| 70 |
+
|
| 71 |
+
MODEL_NAME = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
|
| 72 |
+
GROUP_SIZE = 2
|
| 73 |
+
LEARNING_RATE = 5e-6
|
| 74 |
+
KL_BETA = 0.0 # T4-safe: skip reference model load (saves ~5 GB VRAM)
|
| 75 |
+
PHASE_EPISODE_BUDGETS = {1: 20, 2: 25, 3: 30} # 75 episodes total
|
| 76 |
+
NUM_EPISODES = sum(PHASE_EPISODE_BUDGETS.values())
|
| 77 |
+
PHASE_REWARD_TARGETS = {1: 1.2, 2: 1.1, 3: 1.0} # observational only
|
| 78 |
+
PHASE_MIN_WIN_RATE = 0.20
|
| 79 |
+
CONVERGENCE_WINDOW = 3
|
| 80 |
+
EARLY_STOP_ENABLED = False # forced off by train() under fixed-budget anyway
|
| 81 |
+
OUTPUT_DIR = "/kaggle/working/er_map_grpo_checkpoints"
|
| 82 |
+
|
| 83 |
+
# Anti-fragmentation for the GRPO backward pass on T4
|
| 84 |
+
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
| 85 |
+
|
| 86 |
+
# Groq traffic shaping — 8B for actors, 70B for judges
|
| 87 |
+
os.environ["ERMAP_NURSE_MODEL"] = "llama-3.1-8b-instant"
|
| 88 |
+
os.environ["ERMAP_PATIENT_MODEL"] = "llama-3.1-8b-instant"
|
| 89 |
+
os.environ["ERMAP_EMPATHY_JUDGE_MODEL"] = "llama-3.3-70b-versatile"
|
| 90 |
+
os.environ["ERMAP_MEDICAL_JUDGE_MODEL"] = "llama-3.3-70b-versatile"
|
| 91 |
+
|
| 92 |
+
# Episode budget controls (read by triage_env)
|
| 93 |
+
os.environ["ERMAP_MAX_EPISODE_STEPS"] = "20"
|
| 94 |
+
os.environ["ERMAP_MAX_INTERNAL_EXCHANGES"] = "5"
|
| 95 |
+
|
| 96 |
+
print(f" NUM_EPISODES = {NUM_EPISODES}")
|
| 97 |
+
print(f" PHASE_EPISODE_BUDGETS = {PHASE_EPISODE_BUDGETS}")
|
| 98 |
+
print(f" GROUP_SIZE = {GROUP_SIZE}")
|
| 99 |
+
print(f" KL_BETA = {KL_BETA} (skip ref model)")
|
| 100 |
+
print(f" PHASE_REWARD_TARGETS = {PHASE_REWARD_TARGETS} (observational)")
|
| 101 |
+
|
| 102 |
+
# =============================================================================
|
| 103 |
+
# 5. Free VRAM and assert headroom for the model load
|
| 104 |
+
# =============================================================================
|
| 105 |
+
print("\n[5/7] Freeing VRAM...")
|
| 106 |
+
import torch # noqa: E402
|
| 107 |
+
|
| 108 |
+
for _name in ("model", "tokenizer", "ref_model", "optimizer"):
|
| 109 |
+
if _name in globals():
|
| 110 |
+
try:
|
| 111 |
+
del globals()[_name]
|
| 112 |
+
except KeyError:
|
| 113 |
+
pass
|
| 114 |
+
|
| 115 |
+
gc.collect()
|
| 116 |
+
torch.cuda.empty_cache()
|
| 117 |
+
torch.cuda.ipc_collect()
|
| 118 |
+
|
| 119 |
+
_free, _total = torch.cuda.mem_get_info(0)
|
| 120 |
+
print(f" VRAM free: {_free/1e9:.2f} / {_total/1e9:.2f} GB")
|
| 121 |
+
assert _free / 1e9 >= 6.0, (
|
| 122 |
+
f"FAIL: only {_free/1e9:.2f} GB free; need >= 6 GB. "
|
| 123 |
+
"Do Run -> Restart kernel, then re-run Cell 6 (mount), Cell 7 (secrets), "
|
| 124 |
+
"and this cell. The kernel has unrecoverable VRAM fragmentation."
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
# =============================================================================
|
| 128 |
+
# 6. Groq pre-flight (routing + 4-key liveness)
|
| 129 |
+
# =============================================================================
|
| 130 |
+
print("\n[6/7] Pre-flight: Groq routing + key liveness...")
|
| 131 |
+
from ER_MAP.envs.api_router import AgentRouter # noqa: E402
|
| 132 |
+
|
| 133 |
+
_router = AgentRouter()
|
| 134 |
+
_expected = {
|
| 135 |
+
"nurse": "llama-3.1-8b-instant",
|
| 136 |
+
"patient": "llama-3.1-8b-instant",
|
| 137 |
+
"empathy_judge": "llama-3.3-70b-versatile",
|
| 138 |
+
"medical_judge": "llama-3.3-70b-versatile",
|
| 139 |
+
}
|
| 140 |
+
_all_pass = True
|
| 141 |
+
for _role, _exp in _expected.items():
|
| 142 |
+
_actual = _router._models.get(_role, "?")
|
| 143 |
+
_client = _router._clients.get(_role)
|
| 144 |
+
if _client is None:
|
| 145 |
+
print(f" [SKIP] {_role:14s} -> no Groq client (key missing)")
|
| 146 |
+
_all_pass = False
|
| 147 |
+
continue
|
| 148 |
+
try:
|
| 149 |
+
_resp = _client.chat.completions.create(
|
| 150 |
+
model=_exp,
|
| 151 |
+
messages=[{"role": "user", "content": "Reply with exactly: PING"}],
|
| 152 |
+
max_tokens=4, temperature=0,
|
| 153 |
+
)
|
| 154 |
+
_api_ok = "PING" in (_resp.choices[0].message.content or "").upper()
|
| 155 |
+
_err = ""
|
| 156 |
+
except Exception as _e:
|
| 157 |
+
_api_ok = False
|
| 158 |
+
_err = f" ({type(_e).__name__}: {str(_e)[:80]})"
|
| 159 |
+
_flag = "PASS" if (_actual == _exp and _api_ok) else "FAIL"
|
| 160 |
+
print(f" [{_flag}] {_role:14s} | model={_actual:25s} | api_ok={_api_ok}{_err}")
|
| 161 |
+
if _flag == "FAIL":
|
| 162 |
+
_all_pass = False
|
| 163 |
+
assert _all_pass, "Pre-flight FAILED. Re-run Cell 7 (secrets) and Cell 6 (repo)."
|
| 164 |
+
|
| 165 |
+
# =============================================================================
|
| 166 |
+
# 7. LAUNCH — fixed-budget GRPO training
|
| 167 |
+
# =============================================================================
|
| 168 |
+
print("\n[7/7] Launching GRPO training (75 episodes, fixed budget)...")
|
| 169 |
+
print("=" * 72)
|
| 170 |
+
print(" Phase 1 (Tool Mastery) : 20 episodes")
|
| 171 |
+
print(" Phase 2 (Clinical Reasoning) : 25 episodes")
|
| 172 |
+
print(" Phase 3 (Empathetic Negotiation) : 30 episodes")
|
| 173 |
+
print(" Total : 75 episodes (~3-5 hours on T4)")
|
| 174 |
+
print(" HF Hub backup : every 20 episodes")
|
| 175 |
+
print("=" * 72)
|
| 176 |
+
|
| 177 |
+
metrics = tg.train(
|
| 178 |
+
num_episodes=NUM_EPISODES,
|
| 179 |
+
group_size=GROUP_SIZE,
|
| 180 |
+
model_name=MODEL_NAME,
|
| 181 |
+
groq_api_key=os.environ.get("GROQ_NURSE_API_KEY", "")
|
| 182 |
+
or os.environ.get("GROQ_API_KEY", ""),
|
| 183 |
+
learning_rate=LEARNING_RATE,
|
| 184 |
+
kl_beta=KL_BETA,
|
| 185 |
+
use_wandb=False,
|
| 186 |
+
output_dir=OUTPUT_DIR,
|
| 187 |
+
dry_run=False,
|
| 188 |
+
phase_reward_targets=PHASE_REWARD_TARGETS,
|
| 189 |
+
phase_min_win_rate=PHASE_MIN_WIN_RATE,
|
| 190 |
+
convergence_window=CONVERGENCE_WINDOW,
|
| 191 |
+
early_stop=EARLY_STOP_ENABLED,
|
| 192 |
+
phase_episode_budgets=PHASE_EPISODE_BUDGETS,
|
| 193 |
+
)
|
| 194 |
+
print("=" * 72)
|
| 195 |
+
print(f"\nTRAINING COMPLETE β {len(metrics)} metric records collected.")
|
| 196 |
+
print(f"Final LoRA adapter: {OUTPUT_DIR}/final_lora")
|
| 197 |
+
print(f"Plots will be rendered by Cell 15 (run it next).")
|
kaggle/train_ermap_grpo_kaggle.ipynb
CHANGED
|
@@ -452,8 +452,8 @@
|
|
| 452 |
"# but less predictable). When PHASE_EPISODE_BUDGETS is set, EARLY_STOP_ENABLED\n",
|
| 453 |
"# is automatically forced to False inside train() — the reward targets below\n",
|
| 454 |
"# become observational only (logged on the plots, not used for promotion).\n",
|
| 455 |
-
"PHASE_EPISODE_BUDGETS = {1: 20, 2:
|
| 456 |
-
"NUM_EPISODES = sum(PHASE_EPISODE_BUDGETS.values()) # =
|
| 457 |
"\n",
|
| 458 |
"# --- Per-phase reward thresholds (observational under fixed-budget) --------\n",
|
| 459 |
"# Plotted as horizontal target lines on the reward-growth chart so you can\n",
|
|
|
|
| 452 |
"# but less predictable). When PHASE_EPISODE_BUDGETS is set, EARLY_STOP_ENABLED\n",
|
| 453 |
"# is automatically forced to False inside train() — the reward targets below\n",
|
| 454 |
"# become observational only (logged on the plots, not used for promotion).\n",
|
| 455 |
+
"PHASE_EPISODE_BUDGETS = {1: 20, 2: 25, 3: 30} # 20 + 25 + 30 = 75 episodes\n",
|
| 456 |
+
"NUM_EPISODES = sum(PHASE_EPISODE_BUDGETS.values()) # = 75\n",
|
| 457 |
"\n",
|
| 458 |
"# --- Per-phase reward thresholds (observational under fixed-budget) --------\n",
|
| 459 |
"# Plotted as horizontal target lines on the reward-growth chart so you can\n",
|