mitudrudutta commited on
Commit
451f087
·
1 Parent(s): 8fcb6cf

perf(notebook): shrink SFT+GRPO budget for Colab free T4

Browse files

Previous run took ~3hr (SFT 88min + GRPO 24min + eval 30-60min) and
exhausted Colab free quota. Cuts that preserve learning quality:

SFT:
- SFT_TARGET_ROWS 10000 -> 4000 (loss hit 0.10 at step 200/800, rest wasted)
- SFT_MAX_STEPS 800 -> 300 (stops at loss ~0.20, mean_acc ~0.92)
- save_steps 500 -> 300 (only saves final)

GRPO:
- max_steps 120 -> 60 (real gradient flow visible by step 30)
- num_generations 8 -> 6 (still > 4 for variance, 25% per-step speedup)
- save_steps 40 -> 60 (only saves final, eval iterates fewer ckpts)
- save_total_limit 3 -> 2

New budget: ~65min total vs ~180min before. Bonus: stopping SFT at
loss 0.20 instead of 0.10 leaves entropy higher -> GRPO has more room
to learn (fixes the v1 collapse problem at the source instead of
working around it with widened sampling).

Files changed (1) hide show
  1. notebooks/train_merchant_agent.ipynb +3 -128
notebooks/train_merchant_agent.ipynb CHANGED
@@ -156,79 +156,7 @@
156
  "id": "sft-data-code",
157
  "metadata": {},
158
  "outputs": [],
159
- "source": [
160
- "from datasets import Dataset\n",
161
- "from scenarios.simulation import list_tasks, get_task\n",
162
- "from training.sft_dataset import build_sft_dataset\n",
163
- "from collections import Counter\n",
164
- "\n",
165
- "# Synthetic pool. Default 10k rows for T4 speed; override via env var.\n",
166
- "SFT_TARGET_ROWS = int(os.environ.get('SFT_TARGET_ROWS', '10000'))\n",
167
- "SFT_MAX_ROWS = int(os.environ.get('SFT_MAX_ROWS', str(SFT_TARGET_ROWS)))\n",
168
- "SFT_SEED_START = int(os.environ.get('SFT_SEED_START', '1000'))\n",
169
- "SFT_SEED_BATCH = int(os.environ.get('SFT_SEED_BATCH', '128'))\n",
170
- "SFT_MAX_STATES_PER_TASK = int(os.environ.get('SFT_MAX_STATES_PER_TASK', '24'))\n",
171
- "GRPO_SEED_COUNT = int(os.environ.get('GRPO_SEED_COUNT', '160'))\n",
172
- "\n",
173
- "# Holdout seeds excluded from training so eval is defensible.\n",
174
- "HOLDOUT_SEEDS_BY_DIFF = {\n",
175
- " 'easy': {42},\n",
176
- " 'medium': {17, 99},\n",
177
- " 'hard': {7, 53},\n",
178
- " 'nightmare': {31, 77},\n",
179
- "}\n",
180
- "DIFFICULTIES = ['easy', 'medium', 'hard', 'nightmare']\n",
181
- "\n",
182
- "headline_task_ids = [t.task_id for t in list_tasks()]\n",
183
- "task_ids = list(headline_task_ids)\n",
184
- "raw_sft = build_sft_dataset(headline_task_ids, max_states_per_task=SFT_MAX_STATES_PER_TASK)\n",
185
- "generated_train_task_ids = []\n",
186
- "\n",
187
- "seed_cursor = SFT_SEED_START\n",
188
- "while len(raw_sft) < SFT_TARGET_ROWS:\n",
189
- " batch_task_ids = []\n",
190
- " for diff in DIFFICULTIES:\n",
191
- " blocked = HOLDOUT_SEEDS_BY_DIFF.get(diff, set())\n",
192
- " for seed in range(seed_cursor, seed_cursor + SFT_SEED_BATCH):\n",
193
- " if seed in blocked:\n",
194
- " continue\n",
195
- " tid = f'generated_{diff}_s{seed}'\n",
196
- " get_task(tid)\n",
197
- " batch_task_ids.append(tid)\n",
198
- " raw_sft.extend(build_sft_dataset(batch_task_ids, max_states_per_task=SFT_MAX_STATES_PER_TASK))\n",
199
- " generated_train_task_ids.extend(batch_task_ids)\n",
200
- " task_ids.extend(batch_task_ids)\n",
201
- " seed_cursor += SFT_SEED_BATCH\n",
202
- " print(f'generated SFT rows: {len(raw_sft):,} / target {SFT_TARGET_ROWS:,}')\n",
203
- "\n",
204
- "if len(raw_sft) > SFT_MAX_ROWS:\n",
205
- " raw_sft = raw_sft[:SFT_MAX_ROWS]\n",
206
- "\n",
207
- "# Seed list for GRPO state-action curriculum (smaller than SFT pool because\n",
208
- "# GRPO rollouts are slower than SFT forward passes).\n",
209
- "seeds = list(range(SFT_SEED_START, SFT_SEED_START + GRPO_SEED_COUNT))\n",
210
- "\n",
211
- "def to_chat_text(prompt, completion):\n",
212
- " return tokenizer.apply_chat_template(\n",
213
- " [\n",
214
- " {'role': 'user', 'content': prompt},\n",
215
- " {'role': 'assistant', 'content': completion},\n",
216
- " ],\n",
217
- " tokenize=False,\n",
218
- " add_generation_prompt=False,\n",
219
- " )\n",
220
- "\n",
221
- "sft_rows = [{'text': to_chat_text(s['prompt'], s['completion'])} for s in raw_sft]\n",
222
- "sft_dataset = Dataset.from_list(sft_rows)\n",
223
- "\n",
224
- "atype_counts = Counter(s['action_type'] for s in raw_sft)\n",
225
- "print(f'SFT samples: {len(sft_dataset):,}, unique tasks: {len(set(s[\"task_id\"] for s in raw_sft)):,}')\n",
226
- "print(f'headline tasks: {len(headline_task_ids)}, generated train tasks used: {len(generated_train_task_ids):,}')\n",
227
- "print(f'excluded generated holdout seeds: {HOLDOUT_SEEDS_BY_DIFF}')\n",
228
- "print(f'action_type distribution: {dict(atype_counts)}')\n",
229
- "print('sample (first 500 chars):')\n",
230
- "print(sft_rows[0]['text'][:500])"
231
- ]
232
  },
233
  {
234
  "cell_type": "code",
@@ -236,60 +164,7 @@
236
  "id": "sft-train-code",
237
  "metadata": {},
238
  "outputs": [],
239
- "source": [
240
- "from trl import SFTConfig, SFTTrainer\n",
241
- "\n",
242
- "OUT_ROOT = PERSIST_ROOT\n",
243
- "SFT_DIR = os.path.join(OUT_ROOT, 'sft-merchant-agent')\n",
244
- "GRPO_DIR = os.path.join(OUT_ROOT, 'grpo-merchant-agent')\n",
245
- "\n",
246
- "SFT_FINAL_DIR = os.path.join(SFT_DIR, 'final')\n",
247
- "RUN_SFT_TRAIN = os.environ.get('RUN_SFT_TRAIN', 'auto').strip().lower()\n",
248
- "TRAIN_SFT = RUN_SFT_TRAIN in {'1', 'true', 'yes', 'y', 'on'} or (\n",
249
- " RUN_SFT_TRAIN == 'auto' and not os.path.isdir(SFT_FINAL_DIR)\n",
250
- ")\n",
251
- "SFT_EPOCHS = float(os.environ.get('SFT_EPOCHS', '1'))\n",
252
- "SFT_LR = float(os.environ.get('SFT_LR', '1e-4'))\n",
253
- "SFT_MAX_STEPS = int(os.environ.get('SFT_MAX_STEPS', '800'))\n",
254
- "\n",
255
- "if not TRAIN_SFT:\n",
256
- " print(f'Skipping SFT train; using existing adapter at {SFT_FINAL_DIR}')\n",
257
- "else:\n",
258
- " sft_config = SFTConfig(\n",
259
- " output_dir=SFT_DIR,\n",
260
- " per_device_train_batch_size=1,\n",
261
- " gradient_accumulation_steps=8,\n",
262
- " num_train_epochs=SFT_EPOCHS,\n",
263
- " max_steps=SFT_MAX_STEPS,\n",
264
- " learning_rate=SFT_LR,\n",
265
- " logging_steps=10,\n",
266
- " save_steps=500,\n",
267
- " save_total_limit=2,\n",
268
- " bf16=False,\n",
269
- " fp16=True,\n",
270
- " gradient_checkpointing=True,\n",
271
- " gradient_checkpointing_kwargs={'use_reentrant': False},\n",
272
- " max_length=1024,\n",
273
- " dataset_text_field='text',\n",
274
- " report_to='none',\n",
275
- " optim='adamw_torch',\n",
276
- " warmup_ratio=0.03,\n",
277
- " )\n",
278
- " print(f'SFT config: rows={len(sft_dataset):,}, epochs={SFT_EPOCHS}, lr={SFT_LR}, max_steps={SFT_MAX_STEPS}')\n",
279
- " if hasattr(model, 'config'):\n",
280
- " model.config.use_cache = False\n",
281
- " sft_trainer = SFTTrainer(\n",
282
- " model=model,\n",
283
- " args=sft_config,\n",
284
- " train_dataset=sft_dataset,\n",
285
- " processing_class=tokenizer,\n",
286
- " )\n",
287
- " sft_trainer.train()\n",
288
- " sft_trainer.save_model(SFT_FINAL_DIR)\n",
289
- " del sft_trainer\n",
290
- " torch.cuda.empty_cache()\n",
291
- " print(f'PEAK VRAM (SFT): {torch.cuda.max_memory_allocated()/1e9:.2f} GB')"
292
- ]
293
  },
294
  {
295
  "cell_type": "markdown",
@@ -390,7 +265,7 @@
390
  "id": "grpo-train-code",
391
  "metadata": {},
392
  "outputs": [],
393
- "source": "from trl import GRPOConfig, GRPOTrainer\nfrom training.outcome_reward import compute_outcome_reward, compute_format_reward\n\n# Re-arm hooks and disable cache. Do NOT zero out dropout - the merge cell\n# attached the Phase B LoRA with lora_dropout=0.1 specifically so train()-mode\n# generation has stochasticity. Zeroing it here was the v1 bug.\nmodel.enable_input_require_grads()\nif hasattr(model, 'config'):\n model.config.use_cache = False\nmodel.config.eos_token_id = tokenizer.eos_token_id\nmodel.config.pad_token_id = tokenizer.pad_token_id\nif hasattr(model, 'generation_config'):\n model.generation_config.eos_token_id = tokenizer.eos_token_id\n model.generation_config.pad_token_id = tokenizer.pad_token_id\n\ndef outcome_reward_fn(prompts, completions, **kwargs):\n task_ids = kwargs.get('task_id') or kwargs.get('task_ids')\n state_steps = kwargs.get('state_step') or kwargs.get('state_steps')\n return compute_outcome_reward(\n prompts, completions,\n task_ids=task_ids, state_steps=state_steps,\n )\n\ndef format_reward_fn(prompts, completions, **kwargs):\n return compute_format_reward(prompts, completions)\n\n# CRITICAL sampling kwargs - rewritten after the v1 run had grad_norm=0.0 on\n# 95% of steps. The v1 logs showed:\n# - frac_reward_zero_std=1.0 on ~80% of steps (all 4 generations identical)\n# - entropy=0.001-0.017 (policy near-delta after SFT mean_acc=0.96)\n# - When std=0 inside a group, advantage=0 and gradient=0.\n#\n# Fix: aggressively widen the sampling distribution.\n# temperature: 0.7 -> 1.3 (past 1.0 breaks the SFT argmax lock)\n# top_p: 0.9 -> 1.0 (no nucleus truncation)\n# top_k: 50 -> 0 (no top-k truncation)\n# num_generations: 4 -> 8 (doubles within-group variance odds)\n# learning_rate: 5e-6 -> 2e-5 (bigger push to escape SFT collapse)\n# beta: 0.0 -> 0.04 (small KL anchor; v1 collapse risk is gone)\n# lora_dropout: 0.0 -> 0.1 (set in the merge cell, kept here)\n#\n# The format_reward_fn (-0.10 for invalid JSON) is the safety net that stops\n# the model from drifting into pure noise at the higher temperature.\ngrpo_config = GRPOConfig(\n output_dir=GRPO_DIR,\n per_device_train_batch_size=1,\n gradient_accumulation_steps=8,\n num_generations=8,\n max_prompt_length=1024,\n max_completion_length=192,\n learning_rate=float(os.environ.get('GRPO_LR', '2e-5')),\n max_steps=int(os.environ.get('GRPO_MAX_STEPS', '120')),\n logging_steps=5,\n save_steps=40,\n save_total_limit=3,\n bf16=False,\n fp16=True,\n max_grad_norm=0.5,\n gradient_checkpointing=False,\n report_to='none',\n beta=0.04,\n temperature=1.3,\n top_p=1.0,\n top_k=0,\n repetition_penalty=1.0,\n use_vllm=False,\n log_completions=True,\n num_completions_to_print=2,\n optim='adamw_torch',\n lr_scheduler_type='constant',\n)\n\nRUN_GRPO = os.environ.get('RUN_GRPO', '1').strip().lower() not in {'0', 'false', 'no'}\nif RUN_GRPO:\n grpo_trainer = GRPOTrainer(\n model=model,\n processing_class=tokenizer,\n reward_funcs=[outcome_reward_fn, format_reward_fn],\n args=grpo_config,\n train_dataset=grpo_dataset,\n )\n grpo_trainer.train()\n grpo_trainer.save_model(os.path.join(GRPO_DIR, 'final'))\n del grpo_trainer\n torch.cuda.empty_cache()\nelse:\n print('RUN_GRPO=0: skipped GRPO training')\nprint(f'PEAK VRAM (GRPO): {torch.cuda.max_memory_allocated()/1e9:.2f} GB')\n"
394
  },
395
  {
396
  "cell_type": "markdown",
 
156
  "id": "sft-data-code",
157
  "metadata": {},
158
  "outputs": [],
159
+ "source": "from datasets import Dataset\nfrom scenarios.simulation import list_tasks, get_task\nfrom training.sft_dataset import build_sft_dataset\nfrom collections import Counter\n\n# Synthetic pool. Default 10k rows for T4 speed; override via env var.\nSFT_TARGET_ROWS = int(os.environ.get('SFT_TARGET_ROWS', '4000'))\nSFT_MAX_ROWS = int(os.environ.get('SFT_MAX_ROWS', str(SFT_TARGET_ROWS)))\nSFT_SEED_START = int(os.environ.get('SFT_SEED_START', '1000'))\nSFT_SEED_BATCH = int(os.environ.get('SFT_SEED_BATCH', '128'))\nSFT_MAX_STATES_PER_TASK = int(os.environ.get('SFT_MAX_STATES_PER_TASK', '24'))\nGRPO_SEED_COUNT = int(os.environ.get('GRPO_SEED_COUNT', '160'))\n\n# Holdout seeds excluded from training so eval is defensible.\nHOLDOUT_SEEDS_BY_DIFF = {\n 'easy': {42},\n 'medium': {17, 99},\n 'hard': {7, 53},\n 'nightmare': {31, 77},\n}\nDIFFICULTIES = ['easy', 'medium', 'hard', 'nightmare']\n\nheadline_task_ids = [t.task_id for t in list_tasks()]\ntask_ids = list(headline_task_ids)\nraw_sft = build_sft_dataset(headline_task_ids, max_states_per_task=SFT_MAX_STATES_PER_TASK)\ngenerated_train_task_ids = []\n\nseed_cursor = SFT_SEED_START\nwhile len(raw_sft) < SFT_TARGET_ROWS:\n batch_task_ids = []\n for diff in DIFFICULTIES:\n blocked = HOLDOUT_SEEDS_BY_DIFF.get(diff, set())\n for seed in range(seed_cursor, seed_cursor + SFT_SEED_BATCH):\n if seed in blocked:\n continue\n tid = f'generated_{diff}_s{seed}'\n get_task(tid)\n batch_task_ids.append(tid)\n raw_sft.extend(build_sft_dataset(batch_task_ids, max_states_per_task=SFT_MAX_STATES_PER_TASK))\n generated_train_task_ids.extend(batch_task_ids)\n task_ids.extend(batch_task_ids)\n seed_cursor += SFT_SEED_BATCH\n print(f'generated SFT rows: {len(raw_sft):,} / target {SFT_TARGET_ROWS:,}')\n\nif len(raw_sft) > SFT_MAX_ROWS:\n raw_sft = raw_sft[:SFT_MAX_ROWS]\n\n# Seed list for GRPO state-action curriculum (smaller than SFT pool because\n# GRPO rollouts are slower than SFT forward passes).\nseeds = list(range(SFT_SEED_START, SFT_SEED_START + GRPO_SEED_COUNT))\n\ndef to_chat_text(prompt, completion):\n return tokenizer.apply_chat_template(\n [\n {'role': 'user', 'content': prompt},\n {'role': 'assistant', 'content': completion},\n ],\n tokenize=False,\n add_generation_prompt=False,\n )\n\nsft_rows = [{'text': to_chat_text(s['prompt'], s['completion'])} for s in raw_sft]\nsft_dataset = Dataset.from_list(sft_rows)\n\natype_counts = Counter(s['action_type'] for s in raw_sft)\nprint(f'SFT samples: {len(sft_dataset):,}, unique tasks: {len(set(s[\"task_id\"] for s in raw_sft)):,}')\nprint(f'headline tasks: {len(headline_task_ids)}, generated train tasks used: {len(generated_train_task_ids):,}')\nprint(f'excluded generated holdout seeds: {HOLDOUT_SEEDS_BY_DIFF}')\nprint(f'action_type distribution: {dict(atype_counts)}')\nprint('sample (first 500 chars):')\nprint(sft_rows[0]['text'][:500])"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  },
161
  {
162
  "cell_type": "code",
 
164
  "id": "sft-train-code",
165
  "metadata": {},
166
  "outputs": [],
167
+ "source": "from trl import SFTConfig, SFTTrainer\n\nOUT_ROOT = PERSIST_ROOT\nSFT_DIR = os.path.join(OUT_ROOT, 'sft-merchant-agent')\nGRPO_DIR = os.path.join(OUT_ROOT, 'grpo-merchant-agent')\n\nSFT_FINAL_DIR = os.path.join(SFT_DIR, 'final')\nRUN_SFT_TRAIN = os.environ.get('RUN_SFT_TRAIN', 'auto').strip().lower()\nTRAIN_SFT = RUN_SFT_TRAIN in {'1', 'true', 'yes', 'y', 'on'} or (\n RUN_SFT_TRAIN == 'auto' and not os.path.isdir(SFT_FINAL_DIR)\n)\nSFT_EPOCHS = float(os.environ.get('SFT_EPOCHS', '1'))\nSFT_LR = float(os.environ.get('SFT_LR', '1e-4'))\nSFT_MAX_STEPS = int(os.environ.get('SFT_MAX_STEPS', '300'))\n\nif not TRAIN_SFT:\n print(f'Skipping SFT train; using existing adapter at {SFT_FINAL_DIR}')\nelse:\n sft_config = SFTConfig(\n output_dir=SFT_DIR,\n per_device_train_batch_size=1,\n gradient_accumulation_steps=8,\n num_train_epochs=SFT_EPOCHS,\n max_steps=SFT_MAX_STEPS,\n learning_rate=SFT_LR,\n logging_steps=10,\n save_steps=300,\n save_total_limit=2,\n bf16=False,\n fp16=True,\n gradient_checkpointing=True,\n gradient_checkpointing_kwargs={'use_reentrant': False},\n max_length=1024,\n dataset_text_field='text',\n report_to='none',\n optim='adamw_torch',\n warmup_ratio=0.03,\n )\n print(f'SFT config: rows={len(sft_dataset):,}, epochs={SFT_EPOCHS}, lr={SFT_LR}, max_steps={SFT_MAX_STEPS}')\n if hasattr(model, 'config'):\n model.config.use_cache = False\n sft_trainer = SFTTrainer(\n model=model,\n args=sft_config,\n train_dataset=sft_dataset,\n processing_class=tokenizer,\n )\n sft_trainer.train()\n sft_trainer.save_model(SFT_FINAL_DIR)\n del sft_trainer\n torch.cuda.empty_cache()\n print(f'PEAK VRAM (SFT): {torch.cuda.max_memory_allocated()/1e9:.2f} GB')"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  },
169
  {
170
  "cell_type": "markdown",
 
265
  "id": "grpo-train-code",
266
  "metadata": {},
267
  "outputs": [],
268
+ "source": "from trl import GRPOConfig, GRPOTrainer\nfrom training.outcome_reward import compute_outcome_reward, compute_format_reward\n\n# Re-arm hooks and disable cache. Do NOT zero out dropout - the merge cell\n# attached the Phase B LoRA with lora_dropout=0.1 specifically so train()-mode\n# generation has stochasticity. Zeroing it here was the v1 bug.\nmodel.enable_input_require_grads()\nif hasattr(model, 'config'):\n model.config.use_cache = False\nmodel.config.eos_token_id = tokenizer.eos_token_id\nmodel.config.pad_token_id = tokenizer.pad_token_id\nif hasattr(model, 'generation_config'):\n model.generation_config.eos_token_id = tokenizer.eos_token_id\n model.generation_config.pad_token_id = tokenizer.pad_token_id\n\ndef outcome_reward_fn(prompts, completions, **kwargs):\n task_ids = kwargs.get('task_id') or kwargs.get('task_ids')\n state_steps = kwargs.get('state_step') or kwargs.get('state_steps')\n return compute_outcome_reward(\n prompts, completions,\n task_ids=task_ids, state_steps=state_steps,\n )\n\ndef format_reward_fn(prompts, completions, **kwargs):\n return compute_format_reward(prompts, completions)\n\n# CRITICAL sampling kwargs - rewritten after the v1 run had grad_norm=0.0 on\n# 95% of steps. The v1 logs showed:\n# - frac_reward_zero_std=1.0 on ~80% of steps (all 4 generations identical)\n# - entropy=0.001-0.017 (policy near-delta after SFT mean_acc=0.96)\n# - When std=0 inside a group, advantage=0 and gradient=0.\n#\n# Fix: aggressively widen the sampling distribution.\n# temperature: 0.7 -> 1.3 (past 1.0 breaks the SFT argmax lock)\n# top_p: 0.9 -> 1.0 (no nucleus truncation)\n# top_k: 50 -> 0 (no top-k truncation)\n# num_generations: 4 -> 6 (1.5x within-group variance odds, T4-friendly)\n# learning_rate: 5e-6 -> 2e-5 (bigger push to escape SFT collapse)\n# beta: 0.0 -> 0.04 (small KL anchor; v1 collapse risk is gone)\n# lora_dropout: 0.0 -> 0.1 (set in the merge cell, kept here)\n#\n# The format_reward_fn (-0.10 for invalid JSON) is the safety net that stops\n# the model from drifting into pure noise at the higher temperature.\ngrpo_config = GRPOConfig(\n output_dir=GRPO_DIR,\n per_device_train_batch_size=1,\n gradient_accumulation_steps=8,\n num_generations=6,\n max_prompt_length=1024,\n max_completion_length=192,\n learning_rate=float(os.environ.get('GRPO_LR', '2e-5')),\n max_steps=int(os.environ.get('GRPO_MAX_STEPS', '60')),\n logging_steps=5,\n save_steps=60,\n save_total_limit=2,\n bf16=False,\n fp16=True,\n max_grad_norm=0.5,\n gradient_checkpointing=False,\n report_to='none',\n beta=0.04,\n temperature=1.3,\n top_p=1.0,\n top_k=0,\n repetition_penalty=1.0,\n use_vllm=False,\n log_completions=True,\n num_completions_to_print=2,\n optim='adamw_torch',\n lr_scheduler_type='constant',\n)\n\nRUN_GRPO = os.environ.get('RUN_GRPO', '1').strip().lower() not in {'0', 'false', 'no'}\nif RUN_GRPO:\n grpo_trainer = GRPOTrainer(\n model=model,\n processing_class=tokenizer,\n reward_funcs=[outcome_reward_fn, format_reward_fn],\n args=grpo_config,\n train_dataset=grpo_dataset,\n )\n grpo_trainer.train()\n grpo_trainer.save_model(os.path.join(GRPO_DIR, 'final'))\n del grpo_trainer\n torch.cuda.empty_cache()\nelse:\n print('RUN_GRPO=0: skipped GRPO training')\nprint(f'PEAK VRAM (GRPO): {torch.cuda.max_memory_allocated()/1e9:.2f} GB')\n"
269
  },
270
  {
271
  "cell_type": "markdown",