Spaces:
Sleeping
perf(notebook): shrink SFT+GRPO budget for Colab free T4
Browse files
Previous run took ~3hr (SFT 88min + GRPO 24min + eval 30-60min) and
exhausted Colab free quota. Cuts that preserve learning quality:
SFT:
- SFT_TARGET_ROWS 10000 -> 4000 (loss hit 0.10 at step 200/800, rest wasted)
- SFT_MAX_STEPS 800 -> 300 (stops at loss ~0.20, mean_acc ~0.92)
- save_steps 500 -> 300 (only saves final)
GRPO:
- max_steps 120 -> 60 (real gradient flow visible by step 30)
- num_generations 8 -> 6 (still > 4 for variance, 25% per-step speedup)
- save_steps 40 -> 60 (only saves final, eval iterates fewer ckpts)
- save_total_limit 3 -> 2
New budget: ~65min total vs ~180min before. Bonus: stopping SFT at
loss 0.20 instead of 0.10 leaves entropy higher -> GRPO has more room
to learn (fixes the v1 collapse problem at the source instead of
working around it with widened sampling).
|
@@ -156,79 +156,7 @@
|
|
| 156 |
"id": "sft-data-code",
|
| 157 |
"metadata": {},
|
| 158 |
"outputs": [],
|
| 159 |
-
"source": [
|
| 160 |
-
"from datasets import Dataset\n",
|
| 161 |
-
"from scenarios.simulation import list_tasks, get_task\n",
|
| 162 |
-
"from training.sft_dataset import build_sft_dataset\n",
|
| 163 |
-
"from collections import Counter\n",
|
| 164 |
-
"\n",
|
| 165 |
-
"# Synthetic pool. Default 10k rows for T4 speed; override via env var.\n",
|
| 166 |
-
"SFT_TARGET_ROWS = int(os.environ.get('SFT_TARGET_ROWS', '10000'))\n",
|
| 167 |
-
"SFT_MAX_ROWS = int(os.environ.get('SFT_MAX_ROWS', str(SFT_TARGET_ROWS)))\n",
|
| 168 |
-
"SFT_SEED_START = int(os.environ.get('SFT_SEED_START', '1000'))\n",
|
| 169 |
-
"SFT_SEED_BATCH = int(os.environ.get('SFT_SEED_BATCH', '128'))\n",
|
| 170 |
-
"SFT_MAX_STATES_PER_TASK = int(os.environ.get('SFT_MAX_STATES_PER_TASK', '24'))\n",
|
| 171 |
-
"GRPO_SEED_COUNT = int(os.environ.get('GRPO_SEED_COUNT', '160'))\n",
|
| 172 |
-
"\n",
|
| 173 |
-
"# Holdout seeds excluded from training so eval is defensible.\n",
|
| 174 |
-
"HOLDOUT_SEEDS_BY_DIFF = {\n",
|
| 175 |
-
" 'easy': {42},\n",
|
| 176 |
-
" 'medium': {17, 99},\n",
|
| 177 |
-
" 'hard': {7, 53},\n",
|
| 178 |
-
" 'nightmare': {31, 77},\n",
|
| 179 |
-
"}\n",
|
| 180 |
-
"DIFFICULTIES = ['easy', 'medium', 'hard', 'nightmare']\n",
|
| 181 |
-
"\n",
|
| 182 |
-
"headline_task_ids = [t.task_id for t in list_tasks()]\n",
|
| 183 |
-
"task_ids = list(headline_task_ids)\n",
|
| 184 |
-
"raw_sft = build_sft_dataset(headline_task_ids, max_states_per_task=SFT_MAX_STATES_PER_TASK)\n",
|
| 185 |
-
"generated_train_task_ids = []\n",
|
| 186 |
-
"\n",
|
| 187 |
-
"seed_cursor = SFT_SEED_START\n",
|
| 188 |
-
"while len(raw_sft) < SFT_TARGET_ROWS:\n",
|
| 189 |
-
" batch_task_ids = []\n",
|
| 190 |
-
" for diff in DIFFICULTIES:\n",
|
| 191 |
-
" blocked = HOLDOUT_SEEDS_BY_DIFF.get(diff, set())\n",
|
| 192 |
-
" for seed in range(seed_cursor, seed_cursor + SFT_SEED_BATCH):\n",
|
| 193 |
-
" if seed in blocked:\n",
|
| 194 |
-
" continue\n",
|
| 195 |
-
" tid = f'generated_{diff}_s{seed}'\n",
|
| 196 |
-
" get_task(tid)\n",
|
| 197 |
-
" batch_task_ids.append(tid)\n",
|
| 198 |
-
" raw_sft.extend(build_sft_dataset(batch_task_ids, max_states_per_task=SFT_MAX_STATES_PER_TASK))\n",
|
| 199 |
-
" generated_train_task_ids.extend(batch_task_ids)\n",
|
| 200 |
-
" task_ids.extend(batch_task_ids)\n",
|
| 201 |
-
" seed_cursor += SFT_SEED_BATCH\n",
|
| 202 |
-
" print(f'generated SFT rows: {len(raw_sft):,} / target {SFT_TARGET_ROWS:,}')\n",
|
| 203 |
-
"\n",
|
| 204 |
-
"if len(raw_sft) > SFT_MAX_ROWS:\n",
|
| 205 |
-
" raw_sft = raw_sft[:SFT_MAX_ROWS]\n",
|
| 206 |
-
"\n",
|
| 207 |
-
"# Seed list for GRPO state-action curriculum (smaller than SFT pool because\n",
|
| 208 |
-
"# GRPO rollouts are slower than SFT forward passes).\n",
|
| 209 |
-
"seeds = list(range(SFT_SEED_START, SFT_SEED_START + GRPO_SEED_COUNT))\n",
|
| 210 |
-
"\n",
|
| 211 |
-
"def to_chat_text(prompt, completion):\n",
|
| 212 |
-
" return tokenizer.apply_chat_template(\n",
|
| 213 |
-
" [\n",
|
| 214 |
-
" {'role': 'user', 'content': prompt},\n",
|
| 215 |
-
" {'role': 'assistant', 'content': completion},\n",
|
| 216 |
-
" ],\n",
|
| 217 |
-
" tokenize=False,\n",
|
| 218 |
-
" add_generation_prompt=False,\n",
|
| 219 |
-
" )\n",
|
| 220 |
-
"\n",
|
| 221 |
-
"sft_rows = [{'text': to_chat_text(s['prompt'], s['completion'])} for s in raw_sft]\n",
|
| 222 |
-
"sft_dataset = Dataset.from_list(sft_rows)\n",
|
| 223 |
-
"\n",
|
| 224 |
-
"atype_counts = Counter(s['action_type'] for s in raw_sft)\n",
|
| 225 |
-
"print(f'SFT samples: {len(sft_dataset):,}, unique tasks: {len(set(s[\"task_id\"] for s in raw_sft)):,}')\n",
|
| 226 |
-
"print(f'headline tasks: {len(headline_task_ids)}, generated train tasks used: {len(generated_train_task_ids):,}')\n",
|
| 227 |
-
"print(f'excluded generated holdout seeds: {HOLDOUT_SEEDS_BY_DIFF}')\n",
|
| 228 |
-
"print(f'action_type distribution: {dict(atype_counts)}')\n",
|
| 229 |
-
"print('sample (first 500 chars):')\n",
|
| 230 |
-
"print(sft_rows[0]['text'][:500])"
|
| 231 |
-
]
|
| 232 |
},
|
| 233 |
{
|
| 234 |
"cell_type": "code",
|
|
@@ -236,60 +164,7 @@
|
|
| 236 |
"id": "sft-train-code",
|
| 237 |
"metadata": {},
|
| 238 |
"outputs": [],
|
| 239 |
-
"source":
|
| 240 |
-
"from trl import SFTConfig, SFTTrainer\n",
|
| 241 |
-
"\n",
|
| 242 |
-
"OUT_ROOT = PERSIST_ROOT\n",
|
| 243 |
-
"SFT_DIR = os.path.join(OUT_ROOT, 'sft-merchant-agent')\n",
|
| 244 |
-
"GRPO_DIR = os.path.join(OUT_ROOT, 'grpo-merchant-agent')\n",
|
| 245 |
-
"\n",
|
| 246 |
-
"SFT_FINAL_DIR = os.path.join(SFT_DIR, 'final')\n",
|
| 247 |
-
"RUN_SFT_TRAIN = os.environ.get('RUN_SFT_TRAIN', 'auto').strip().lower()\n",
|
| 248 |
-
"TRAIN_SFT = RUN_SFT_TRAIN in {'1', 'true', 'yes', 'y', 'on'} or (\n",
|
| 249 |
-
" RUN_SFT_TRAIN == 'auto' and not os.path.isdir(SFT_FINAL_DIR)\n",
|
| 250 |
-
")\n",
|
| 251 |
-
"SFT_EPOCHS = float(os.environ.get('SFT_EPOCHS', '1'))\n",
|
| 252 |
-
"SFT_LR = float(os.environ.get('SFT_LR', '1e-4'))\n",
|
| 253 |
-
"SFT_MAX_STEPS = int(os.environ.get('SFT_MAX_STEPS', '800'))\n",
|
| 254 |
-
"\n",
|
| 255 |
-
"if not TRAIN_SFT:\n",
|
| 256 |
-
" print(f'Skipping SFT train; using existing adapter at {SFT_FINAL_DIR}')\n",
|
| 257 |
-
"else:\n",
|
| 258 |
-
" sft_config = SFTConfig(\n",
|
| 259 |
-
" output_dir=SFT_DIR,\n",
|
| 260 |
-
" per_device_train_batch_size=1,\n",
|
| 261 |
-
" gradient_accumulation_steps=8,\n",
|
| 262 |
-
" num_train_epochs=SFT_EPOCHS,\n",
|
| 263 |
-
" max_steps=SFT_MAX_STEPS,\n",
|
| 264 |
-
" learning_rate=SFT_LR,\n",
|
| 265 |
-
" logging_steps=10,\n",
|
| 266 |
-
" save_steps=500,\n",
|
| 267 |
-
" save_total_limit=2,\n",
|
| 268 |
-
" bf16=False,\n",
|
| 269 |
-
" fp16=True,\n",
|
| 270 |
-
" gradient_checkpointing=True,\n",
|
| 271 |
-
" gradient_checkpointing_kwargs={'use_reentrant': False},\n",
|
| 272 |
-
" max_length=1024,\n",
|
| 273 |
-
" dataset_text_field='text',\n",
|
| 274 |
-
" report_to='none',\n",
|
| 275 |
-
" optim='adamw_torch',\n",
|
| 276 |
-
" warmup_ratio=0.03,\n",
|
| 277 |
-
" )\n",
|
| 278 |
-
" print(f'SFT config: rows={len(sft_dataset):,}, epochs={SFT_EPOCHS}, lr={SFT_LR}, max_steps={SFT_MAX_STEPS}')\n",
|
| 279 |
-
" if hasattr(model, 'config'):\n",
|
| 280 |
-
" model.config.use_cache = False\n",
|
| 281 |
-
" sft_trainer = SFTTrainer(\n",
|
| 282 |
-
" model=model,\n",
|
| 283 |
-
" args=sft_config,\n",
|
| 284 |
-
" train_dataset=sft_dataset,\n",
|
| 285 |
-
" processing_class=tokenizer,\n",
|
| 286 |
-
" )\n",
|
| 287 |
-
" sft_trainer.train()\n",
|
| 288 |
-
" sft_trainer.save_model(SFT_FINAL_DIR)\n",
|
| 289 |
-
" del sft_trainer\n",
|
| 290 |
-
" torch.cuda.empty_cache()\n",
|
| 291 |
-
" print(f'PEAK VRAM (SFT): {torch.cuda.max_memory_allocated()/1e9:.2f} GB')"
|
| 292 |
-
]
|
| 293 |
},
|
| 294 |
{
|
| 295 |
"cell_type": "markdown",
|
|
@@ -390,7 +265,7 @@
|
|
| 390 |
"id": "grpo-train-code",
|
| 391 |
"metadata": {},
|
| 392 |
"outputs": [],
|
| 393 |
-
"source": "from trl import GRPOConfig, GRPOTrainer\nfrom training.outcome_reward import compute_outcome_reward, compute_format_reward\n\n# Re-arm hooks and disable cache. Do NOT zero out dropout - the merge cell\n# attached the Phase B LoRA with lora_dropout=0.1 specifically so train()-mode\n# generation has stochasticity. Zeroing it here was the v1 bug.\nmodel.enable_input_require_grads()\nif hasattr(model, 'config'):\n model.config.use_cache = False\nmodel.config.eos_token_id = tokenizer.eos_token_id\nmodel.config.pad_token_id = tokenizer.pad_token_id\nif hasattr(model, 'generation_config'):\n model.generation_config.eos_token_id = tokenizer.eos_token_id\n model.generation_config.pad_token_id = tokenizer.pad_token_id\n\ndef outcome_reward_fn(prompts, completions, **kwargs):\n task_ids = kwargs.get('task_id') or kwargs.get('task_ids')\n state_steps = kwargs.get('state_step') or kwargs.get('state_steps')\n return compute_outcome_reward(\n prompts, completions,\n task_ids=task_ids, state_steps=state_steps,\n )\n\ndef format_reward_fn(prompts, completions, **kwargs):\n return compute_format_reward(prompts, completions)\n\n# CRITICAL sampling kwargs - rewritten after the v1 run had grad_norm=0.0 on\n# 95% of steps. The v1 logs showed:\n# - frac_reward_zero_std=1.0 on ~80% of steps (all 4 generations identical)\n# - entropy=0.001-0.017 (policy near-delta after SFT mean_acc=0.96)\n# - When std=0 inside a group, advantage=0 and gradient=0.\n#\n# Fix: aggressively widen the sampling distribution.\n# temperature: 0.7 -> 1.3 (past 1.0 breaks the SFT argmax lock)\n# top_p: 0.9 -> 1.0 (no nucleus truncation)\n# top_k: 50 -> 0 (no top-k truncation)\n# num_generations: 4 ->
|
| 394 |
},
|
| 395 |
{
|
| 396 |
"cell_type": "markdown",
|
|
|
|
| 156 |
"id": "sft-data-code",
|
| 157 |
"metadata": {},
|
| 158 |
"outputs": [],
|
| 159 |
+
"source": "from datasets import Dataset\nfrom scenarios.simulation import list_tasks, get_task\nfrom training.sft_dataset import build_sft_dataset\nfrom collections import Counter\n\n# Synthetic pool. Default 4k rows for T4 speed; override via env var.\nSFT_TARGET_ROWS = int(os.environ.get('SFT_TARGET_ROWS', '4000'))\nSFT_MAX_ROWS = int(os.environ.get('SFT_MAX_ROWS', str(SFT_TARGET_ROWS)))\nSFT_SEED_START = int(os.environ.get('SFT_SEED_START', '1000'))\nSFT_SEED_BATCH = int(os.environ.get('SFT_SEED_BATCH', '128'))\nSFT_MAX_STATES_PER_TASK = int(os.environ.get('SFT_MAX_STATES_PER_TASK', '24'))\nGRPO_SEED_COUNT = int(os.environ.get('GRPO_SEED_COUNT', '160'))\n\n# Holdout seeds excluded from training so eval is defensible.\nHOLDOUT_SEEDS_BY_DIFF = {\n 'easy': {42},\n 'medium': {17, 99},\n 'hard': {7, 53},\n 'nightmare': {31, 77},\n}\nDIFFICULTIES = ['easy', 'medium', 'hard', 'nightmare']\n\nheadline_task_ids = [t.task_id for t in list_tasks()]\ntask_ids = list(headline_task_ids)\nraw_sft = build_sft_dataset(headline_task_ids, max_states_per_task=SFT_MAX_STATES_PER_TASK)\ngenerated_train_task_ids = []\n\nseed_cursor = SFT_SEED_START\nwhile len(raw_sft) < SFT_TARGET_ROWS:\n batch_task_ids = []\n for diff in DIFFICULTIES:\n blocked = HOLDOUT_SEEDS_BY_DIFF.get(diff, set())\n for seed in range(seed_cursor, seed_cursor + SFT_SEED_BATCH):\n if seed in blocked:\n continue\n tid = f'generated_{diff}_s{seed}'\n get_task(tid)\n batch_task_ids.append(tid)\n raw_sft.extend(build_sft_dataset(batch_task_ids, max_states_per_task=SFT_MAX_STATES_PER_TASK))\n generated_train_task_ids.extend(batch_task_ids)\n task_ids.extend(batch_task_ids)\n seed_cursor += SFT_SEED_BATCH\n print(f'generated SFT rows: {len(raw_sft):,} / target {SFT_TARGET_ROWS:,}')\n\nif len(raw_sft) > SFT_MAX_ROWS:\n raw_sft = raw_sft[:SFT_MAX_ROWS]\n\n# Seed list for GRPO state-action curriculum (smaller than SFT pool because\n# GRPO rollouts are slower than SFT forward passes).\nseeds = list(range(SFT_SEED_START, SFT_SEED_START + GRPO_SEED_COUNT))\n\ndef to_chat_text(prompt, completion):\n return tokenizer.apply_chat_template(\n [\n {'role': 'user', 'content': prompt},\n {'role': 'assistant', 'content': completion},\n ],\n tokenize=False,\n add_generation_prompt=False,\n )\n\nsft_rows = [{'text': to_chat_text(s['prompt'], s['completion'])} for s in raw_sft]\nsft_dataset = Dataset.from_list(sft_rows)\n\natype_counts = Counter(s['action_type'] for s in raw_sft)\nprint(f'SFT samples: {len(sft_dataset):,}, unique tasks: {len(set(s[\"task_id\"] for s in raw_sft)):,}')\nprint(f'headline tasks: {len(headline_task_ids)}, generated train tasks used: {len(generated_train_task_ids):,}')\nprint(f'excluded generated holdout seeds: {HOLDOUT_SEEDS_BY_DIFF}')\nprint(f'action_type distribution: {dict(atype_counts)}')\nprint('sample (first 500 chars):')\nprint(sft_rows[0]['text'][:500])"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
},
|
| 161 |
{
|
| 162 |
"cell_type": "code",
|
|
|
|
| 164 |
"id": "sft-train-code",
|
| 165 |
"metadata": {},
|
| 166 |
"outputs": [],
|
| 167 |
+
"source": "from trl import SFTConfig, SFTTrainer\n\nOUT_ROOT = PERSIST_ROOT\nSFT_DIR = os.path.join(OUT_ROOT, 'sft-merchant-agent')\nGRPO_DIR = os.path.join(OUT_ROOT, 'grpo-merchant-agent')\n\nSFT_FINAL_DIR = os.path.join(SFT_DIR, 'final')\nRUN_SFT_TRAIN = os.environ.get('RUN_SFT_TRAIN', 'auto').strip().lower()\nTRAIN_SFT = RUN_SFT_TRAIN in {'1', 'true', 'yes', 'y', 'on'} or (\n RUN_SFT_TRAIN == 'auto' and not os.path.isdir(SFT_FINAL_DIR)\n)\nSFT_EPOCHS = float(os.environ.get('SFT_EPOCHS', '1'))\nSFT_LR = float(os.environ.get('SFT_LR', '1e-4'))\nSFT_MAX_STEPS = int(os.environ.get('SFT_MAX_STEPS', '300'))\n\nif not TRAIN_SFT:\n print(f'Skipping SFT train; using existing adapter at {SFT_FINAL_DIR}')\nelse:\n sft_config = SFTConfig(\n output_dir=SFT_DIR,\n per_device_train_batch_size=1,\n gradient_accumulation_steps=8,\n num_train_epochs=SFT_EPOCHS,\n max_steps=SFT_MAX_STEPS,\n learning_rate=SFT_LR,\n logging_steps=10,\n save_steps=300,\n save_total_limit=2,\n bf16=False,\n fp16=True,\n gradient_checkpointing=True,\n gradient_checkpointing_kwargs={'use_reentrant': False},\n max_length=1024,\n dataset_text_field='text',\n report_to='none',\n optim='adamw_torch',\n warmup_ratio=0.03,\n )\n print(f'SFT config: rows={len(sft_dataset):,}, epochs={SFT_EPOCHS}, lr={SFT_LR}, max_steps={SFT_MAX_STEPS}')\n if hasattr(model, 'config'):\n model.config.use_cache = False\n sft_trainer = SFTTrainer(\n model=model,\n args=sft_config,\n train_dataset=sft_dataset,\n processing_class=tokenizer,\n )\n sft_trainer.train()\n sft_trainer.save_model(SFT_FINAL_DIR)\n del sft_trainer\n torch.cuda.empty_cache()\n print(f'PEAK VRAM (SFT): {torch.cuda.max_memory_allocated()/1e9:.2f} GB')"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
},
|
| 169 |
{
|
| 170 |
"cell_type": "markdown",
|
|
|
|
| 265 |
"id": "grpo-train-code",
|
| 266 |
"metadata": {},
|
| 267 |
"outputs": [],
|
| 268 |
+
"source": "from trl import GRPOConfig, GRPOTrainer\nfrom training.outcome_reward import compute_outcome_reward, compute_format_reward\n\n# Re-arm hooks and disable cache. Do NOT zero out dropout - the merge cell\n# attached the Phase B LoRA with lora_dropout=0.1 specifically so train()-mode\n# generation has stochasticity. Zeroing it here was the v1 bug.\nmodel.enable_input_require_grads()\nif hasattr(model, 'config'):\n model.config.use_cache = False\nmodel.config.eos_token_id = tokenizer.eos_token_id\nmodel.config.pad_token_id = tokenizer.pad_token_id\nif hasattr(model, 'generation_config'):\n model.generation_config.eos_token_id = tokenizer.eos_token_id\n model.generation_config.pad_token_id = tokenizer.pad_token_id\n\n# Reward wrappers: adapt the trainer's (prompts, completions, **kwargs) call\n# to the project reward helpers. Dataset columns arrive through kwargs; both\n# singular and plural column names are accepted.\ndef outcome_reward_fn(prompts, completions, **kwargs):\n task_ids = kwargs.get('task_id') or kwargs.get('task_ids')\n state_steps = kwargs.get('state_step') or kwargs.get('state_steps')\n return compute_outcome_reward(\n prompts, completions,\n task_ids=task_ids, state_steps=state_steps,\n )\n\ndef format_reward_fn(prompts, completions, **kwargs):\n return compute_format_reward(prompts, completions)\n\n# CRITICAL sampling kwargs - rewritten after the v1 run had grad_norm=0.0 on\n# 95% of steps. The v1 logs showed:\n# - frac_reward_zero_std=1.0 on ~80% of steps (all 4 generations identical)\n# - entropy=0.001-0.017 (policy near-delta after SFT mean_acc=0.96)\n# - When std=0 inside a group, advantage=0 and gradient=0.\n#\n# Fix: aggressively widen the sampling distribution.\n# temperature: 0.7 -> 1.3 (past 1.0 breaks the SFT argmax lock)\n# top_p: 0.9 -> 1.0 (no nucleus truncation)\n# top_k: 50 -> 0 (no top-k truncation)\n# num_generations: 4 -> 6 (1.5x within-group variance odds, T4-friendly)\n# learning_rate: 5e-6 -> 2e-5 (bigger push to escape SFT collapse)\n# beta: 0.0 -> 0.04 (small KL anchor; v1 collapse risk is gone)\n# lora_dropout: 0.0 -> 0.1 (set in the merge cell, kept here)\n#\n# The format_reward_fn (-0.10 for invalid JSON) is the safety net that stops\n# the model from drifting into pure noise at the higher temperature.\n#\n# NOTE: gradient_accumulation_steps is 6 (not 8) on purpose. TRL requires the\n# generation batch (per_device_train_batch_size * num_processes *\n# gradient_accumulation_steps) to be divisible by num_generations; on a single\n# GPU 1 * 1 * 8 is not divisible by 6 and GRPOConfig raises ValueError.\n# Keep these two values in sync if num_generations changes again.\ngrpo_config = GRPOConfig(\n output_dir=GRPO_DIR,\n per_device_train_batch_size=1,\n gradient_accumulation_steps=6,\n num_generations=6,\n max_prompt_length=1024,\n max_completion_length=192,\n learning_rate=float(os.environ.get('GRPO_LR', '2e-5')),\n max_steps=int(os.environ.get('GRPO_MAX_STEPS', '60')),\n logging_steps=5,\n save_steps=60,\n save_total_limit=2,\n bf16=False,\n fp16=True,\n max_grad_norm=0.5,\n gradient_checkpointing=False,\n report_to='none',\n beta=0.04,\n temperature=1.3,\n top_p=1.0,\n top_k=0,\n repetition_penalty=1.0,\n use_vllm=False,\n log_completions=True,\n num_completions_to_print=2,\n optim='adamw_torch',\n lr_scheduler_type='constant',\n)\n\nRUN_GRPO = os.environ.get('RUN_GRPO', '1').strip().lower() not in {'0', 'false', 'no'}\nif RUN_GRPO:\n grpo_trainer = GRPOTrainer(\n model=model,\n processing_class=tokenizer,\n reward_funcs=[outcome_reward_fn, format_reward_fn],\n args=grpo_config,\n train_dataset=grpo_dataset,\n )\n grpo_trainer.train()\n grpo_trainer.save_model(os.path.join(GRPO_DIR, 'final'))\n del grpo_trainer\n torch.cuda.empty_cache()\nelse:\n print('RUN_GRPO=0: skipped GRPO training')\nprint(f'PEAK VRAM (GRPO): {torch.cuda.max_memory_allocated()/1e9:.2f} GB')\n"
|
| 269 |
},
|
| 270 |
{
|
| 271 |
"cell_type": "markdown",
|