Uddiii committed on
Commit
cd923aa
·
1 Parent(s): 69f89ec

feat(kaggle): add clean_launch.py + shrink budget to 20/25/30 = 75 eps

Browse files
kaggle/build_notebook.py CHANGED
@@ -476,8 +476,8 @@ USE_WANDB = False # WANDB conflicts with protobuf 7 on Kaggle base image
476
  # but less predictable). When PHASE_EPISODE_BUDGETS is set, EARLY_STOP_ENABLED
477
  # is automatically forced to False inside train() — the reward targets below
478
  # become observational only (logged on the plots, not used for promotion).
479
- PHASE_EPISODE_BUDGETS = {1: 20, 2: 30, 3: 50} # 20 + 30 + 50 = 100 episodes
480
- NUM_EPISODES = sum(PHASE_EPISODE_BUDGETS.values()) # = 100
481
 
482
  # --- Per-phase reward thresholds (observational under fixed-budget) --------
483
  # Plotted as horizontal target lines on the reward-growth chart so you can
 
476
  # but less predictable). When PHASE_EPISODE_BUDGETS is set, EARLY_STOP_ENABLED
477
  # is automatically forced to False inside train() — the reward targets below
478
  # become observational only (logged on the plots, not used for promotion).
479
+ PHASE_EPISODE_BUDGETS = {1: 20, 2: 25, 3: 30} # 20 + 25 + 30 = 75 episodes
480
+ NUM_EPISODES = sum(PHASE_EPISODE_BUDGETS.values()) # = 75
481
 
482
  # --- Per-phase reward thresholds (observational under fixed-budget) --------
483
  # Plotted as horizontal target lines on the reward-growth chart so you can
kaggle/clean_launch.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ER-MAP CLEAN TRAINING LAUNCH (Kaggle, T4-safe, 75-episode fixed budget)
3
+
4
+ Self-contained, idempotent, foolproof. Replaces the old Cell 9 / Cell 11 /
5
+ Cell 13 sequence with ONE cell that:
6
+
7
+ 1. Force-pulls the repo to origin/main (picks up any new fix commits)
8
+ 2. Drops the cached ER_MAP module so the next import picks up the fresh disk
9
+ 3. Asserts the train_grpo patches are live in the running module (kl-gate
10
+ + use_kl loss branch + phase_episode_budgets parameter)
11
+ 4. Sets all hyperparameters EXPLICITLY — does not depend on any earlier
12
+ cell's globals being correct
13
+ 5. Frees VRAM aggressively and asserts >= 6 GB free before launch
14
+ 6. Runs the Groq pre-flight (routing + 4-key liveness) and asserts all PASS
15
+ 7. Calls train() with phase_episode_budgets={1: 20, 2: 25, 3: 30}
16
+
17
+ Usage from a Kaggle notebook cell:
18
+
19
+ exec(open("/kaggle/working/Meta_Finals/kaggle/clean_launch.py").read())
20
+
21
+ That one line is all you paste. Press play. Walk away for ~4 hours.
22
+ """
23
+
24
+ import os, sys, gc, subprocess, importlib # noqa: E401
25
+
26
+ # =============================================================================
27
+ # 1. Force repo to latest commit on origin/main
28
+ # =============================================================================
29
+ REPO_ROOT = "/kaggle/working/Meta_Finals"
30
+ print("[1/7] Updating repo to origin/main...")
31
+ subprocess.run(["git", "-C", REPO_ROOT, "fetch", "origin"], check=True)
32
+ subprocess.run(["git", "-C", REPO_ROOT, "reset", "--hard", "origin/main"], check=True)
33
+ subprocess.run(["git", "-C", REPO_ROOT, "log", "-1", "--oneline"])
34
+
35
+ # =============================================================================
36
+ # 2. Drop cached ER_MAP modules so import picks up the latest disk version
37
+ # =============================================================================
38
+ print("\n[2/7] Dropping cached modules...")
39
+ for _m in list(sys.modules):
40
+ if _m.startswith("ER_MAP"):
41
+ del sys.modules[_m]
42
+ if REPO_ROOT not in sys.path:
43
+ sys.path.insert(0, REPO_ROOT)
44
+
45
+ # =============================================================================
46
+ # 3. Verify all required patches are live in the running module
47
+ # =============================================================================
48
+ print("\n[3/7] Verifying patches...")
49
+ import inspect
50
+ import ER_MAP.training.train_grpo as tg
51
+
52
+ _train_src = inspect.getsource(tg.train)
53
+ assert "if kl_beta > 0.0:" in _train_src, (
54
+ "FAIL: train() missing kl_beta gate. Pull the latest commit on origin/main."
55
+ )
56
+ assert "phase_episode_budgets" in _train_src, (
57
+ "FAIL: train() missing phase_episode_budgets support."
58
+ )
59
+ assert "use_kl" in tg.manual_grpo_step.__code__.co_varnames, (
60
+ "FAIL: manual_grpo_step missing 'use_kl' branch."
61
+ )
62
+ print(" OK — kl_beta gate live")
62
+ print(" OK — phase_episode_budgets supported")
63
+ print(" OK — use_kl branch in loss function")
65
+
66
+ # =============================================================================
67
+ # 4. EXPLICIT hyperparameters — does not rely on any previous cell's globals
68
+ # =============================================================================
69
+ print("\n[4/7] Setting hyperparameters (explicit, no Cell 9 dependency)...")
70
+
71
+ MODEL_NAME = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
72
+ GROUP_SIZE = 2
73
+ LEARNING_RATE = 5e-6
74
+ KL_BETA = 0.0 # T4-safe: skip reference model load (saves ~5 GB VRAM)
75
+ PHASE_EPISODE_BUDGETS = {1: 20, 2: 25, 3: 30} # 75 episodes total
76
+ NUM_EPISODES = sum(PHASE_EPISODE_BUDGETS.values())
77
+ PHASE_REWARD_TARGETS = {1: 1.2, 2: 1.1, 3: 1.0} # observational only
78
+ PHASE_MIN_WIN_RATE = 0.20
79
+ CONVERGENCE_WINDOW = 3
80
+ EARLY_STOP_ENABLED = False # forced off by train() under fixed-budget anyway
81
+ OUTPUT_DIR = "/kaggle/working/er_map_grpo_checkpoints"
82
+
83
+ # Anti-fragmentation for the GRPO backward pass on T4
84
+ os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
85
+
86
+ # Groq traffic shaping β€” 8B for actors, 70B for judges
87
+ os.environ["ERMAP_NURSE_MODEL"] = "llama-3.1-8b-instant"
88
+ os.environ["ERMAP_PATIENT_MODEL"] = "llama-3.1-8b-instant"
89
+ os.environ["ERMAP_EMPATHY_JUDGE_MODEL"] = "llama-3.3-70b-versatile"
90
+ os.environ["ERMAP_MEDICAL_JUDGE_MODEL"] = "llama-3.3-70b-versatile"
91
+
92
+ # Episode budget controls (read by triage_env)
93
+ os.environ["ERMAP_MAX_EPISODE_STEPS"] = "20"
94
+ os.environ["ERMAP_MAX_INTERNAL_EXCHANGES"] = "5"
95
+
96
+ print(f" NUM_EPISODES = {NUM_EPISODES}")
97
+ print(f" PHASE_EPISODE_BUDGETS = {PHASE_EPISODE_BUDGETS}")
98
+ print(f" GROUP_SIZE = {GROUP_SIZE}")
99
+ print(f" KL_BETA = {KL_BETA} (skip ref model)")
100
+ print(f" PHASE_REWARD_TARGETS = {PHASE_REWARD_TARGETS} (observational)")
101
+
102
+ # =============================================================================
103
+ # 5. Free VRAM and assert headroom for the model load
104
+ # =============================================================================
105
+ print("\n[5/7] Freeing VRAM...")
106
+ import torch # noqa: E402
107
+
108
+ for _name in ("model", "tokenizer", "ref_model", "optimizer"):
109
+ if _name in globals():
110
+ try:
111
+ del globals()[_name]
112
+ except KeyError:
113
+ pass
114
+
115
+ gc.collect()
116
+ torch.cuda.empty_cache()
117
+ torch.cuda.ipc_collect()
118
+
119
+ _free, _total = torch.cuda.mem_get_info(0)
120
+ print(f" VRAM free: {_free/1e9:.2f} / {_total/1e9:.2f} GB")
121
+ assert _free / 1e9 >= 6.0, (
122
+ f"FAIL: only {_free/1e9:.2f} GB free; need >= 6 GB. "
123
+ "Do Run -> Restart kernel, then re-run Cell 6 (mount), Cell 7 (secrets), "
124
+ "and this cell. The kernel has unrecoverable VRAM fragmentation."
125
+ )
126
+
127
+ # =============================================================================
128
+ # 6. Groq pre-flight (routing + 4-key liveness)
129
+ # =============================================================================
130
+ print("\n[6/7] Pre-flight: Groq routing + key liveness...")
131
+ from ER_MAP.envs.api_router import AgentRouter # noqa: E402
132
+
133
+ _router = AgentRouter()
134
+ _expected = {
135
+ "nurse": "llama-3.1-8b-instant",
136
+ "patient": "llama-3.1-8b-instant",
137
+ "empathy_judge": "llama-3.3-70b-versatile",
138
+ "medical_judge": "llama-3.3-70b-versatile",
139
+ }
140
+ _all_pass = True
141
+ for _role, _exp in _expected.items():
142
+ _actual = _router._models.get(_role, "?")
143
+ _client = _router._clients.get(_role)
144
+ if _client is None:
145
+ print(f" [SKIP] {_role:14s} -> no Groq client (key missing)")
146
+ _all_pass = False
147
+ continue
148
+ try:
149
+ _resp = _client.chat.completions.create(
150
+ model=_exp,
151
+ messages=[{"role": "user", "content": "Reply with exactly: PING"}],
152
+ max_tokens=4, temperature=0,
153
+ )
154
+ _api_ok = "PING" in (_resp.choices[0].message.content or "").upper()
155
+ _err = ""
156
+ except Exception as _e:
157
+ _api_ok = False
158
+ _err = f" ({type(_e).__name__}: {str(_e)[:80]})"
159
+ _flag = "PASS" if (_actual == _exp and _api_ok) else "FAIL"
160
+ print(f" [{_flag}] {_role:14s} | model={_actual:25s} | api_ok={_api_ok}{_err}")
161
+ if _flag == "FAIL":
162
+ _all_pass = False
163
+ assert _all_pass, "Pre-flight FAILED. Re-run Cell 7 (secrets) and Cell 6 (repo)."
164
+
165
+ # =============================================================================
166
+ # 7. LAUNCH — fixed-budget GRPO training
167
+ # =============================================================================
168
+ print("\n[7/7] Launching GRPO training (75 episodes, fixed budget)...")
169
+ print("=" * 72)
170
+ print(" Phase 1 (Tool Mastery) : 20 episodes")
171
+ print(" Phase 2 (Clinical Reasoning) : 25 episodes")
172
+ print(" Phase 3 (Empathetic Negotiation) : 30 episodes")
173
+ print(" Total : 75 episodes (~3-5 hours on T4)")
174
+ print(" HF Hub backup : every 20 episodes")
175
+ print("=" * 72)
176
+
177
+ metrics = tg.train(
178
+ num_episodes=NUM_EPISODES,
179
+ group_size=GROUP_SIZE,
180
+ model_name=MODEL_NAME,
181
+ groq_api_key=os.environ.get("GROQ_NURSE_API_KEY", "")
182
+ or os.environ.get("GROQ_API_KEY", ""),
183
+ learning_rate=LEARNING_RATE,
184
+ kl_beta=KL_BETA,
185
+ use_wandb=False,
186
+ output_dir=OUTPUT_DIR,
187
+ dry_run=False,
188
+ phase_reward_targets=PHASE_REWARD_TARGETS,
189
+ phase_min_win_rate=PHASE_MIN_WIN_RATE,
190
+ convergence_window=CONVERGENCE_WINDOW,
191
+ early_stop=EARLY_STOP_ENABLED,
192
+ phase_episode_budgets=PHASE_EPISODE_BUDGETS,
193
+ )
194
+ print("=" * 72)
195
+ print(f"\nTRAINING COMPLETE — {len(metrics)} metric records collected.")
196
+ print(f"Final LoRA adapter: {OUTPUT_DIR}/final_lora")
197
+ print(f"Plots will be rendered by Cell 15 (run it next).")
kaggle/train_ermap_grpo_kaggle.ipynb CHANGED
@@ -452,8 +452,8 @@
452
  "# but less predictable). When PHASE_EPISODE_BUDGETS is set, EARLY_STOP_ENABLED\n",
453
  "# is automatically forced to False inside train() — the reward targets below\n",
454
  "# become observational only (logged on the plots, not used for promotion).\n",
455
- "PHASE_EPISODE_BUDGETS = {1: 20, 2: 30, 3: 50} # 20 + 30 + 50 = 100 episodes\n",
456
- "NUM_EPISODES = sum(PHASE_EPISODE_BUDGETS.values()) # = 100\n",
457
  "\n",
458
  "# --- Per-phase reward thresholds (observational under fixed-budget) --------\n",
459
  "# Plotted as horizontal target lines on the reward-growth chart so you can\n",
 
452
  "# but less predictable). When PHASE_EPISODE_BUDGETS is set, EARLY_STOP_ENABLED\n",
453
  "# is automatically forced to False inside train() — the reward targets below\n",
454
  "# become observational only (logged on the plots, not used for promotion).\n",
455
+ "PHASE_EPISODE_BUDGETS = {1: 20, 2: 25, 3: 30} # 20 + 25 + 30 = 75 episodes\n",
456
+ "NUM_EPISODES = sum(PHASE_EPISODE_BUDGETS.values()) # = 75\n",
457
  "\n",
458
  "# --- Per-phase reward thresholds (observational under fixed-budget) --------\n",
459
  "# Plotted as horizontal target lines on the reward-growth chart so you can\n",