Spaces:
Sleeping
perf(training): v5 - shorter SFT, longer GRPO, hard/nightmare oversample
Browse files
v4 results showed:
- SFT collapsed to mean_token_acc=0.96 even at 300 steps -> entropy 0.001
- GRPO improved nightmare 0.55 -> 0.69 but regressed easy 0.92 -> 0.61
- Trained overall 0.73 vs SFT 0.75 vs heuristic 0.81
v5 changes (compounding):
- SFT_MAX_STEPS 300 -> 150: stops at loss ~0.20, mean_acc ~0.88, entropy
stays higher so post-SFT policy has room to explore. Fixes the SFT
collapse problem at the source instead of working around it.
- GRPO max_steps 60 -> 200: 3.3x more training steps. v4 showed gradient
signal active on ~40% of steps; more steps = more opportunities for
policy refinement.
- GRPO learning_rate 2e-5 -> 3e-5: 50% larger learning rate, applied on
every one of the extra steps (the GRPO config uses a constant LR schedule,
so there is no warmup or decay).
- GRPO data: hard + nightmare difficulties appear 2x in curriculum.
Concentrates training where exploration helps (per v4 per-family
results) and where the easy-case argmax-lock matters less.
- GRPO save_steps 60 -> 80: 2 intermediate checkpoints + final, richer
eval curve.
Expected runtime change: SFT ~33min -> ~17min, GRPO ~12min -> ~40min.
Net: ~75min -> ~87min (the ~16min SFT saving only partly offsets the ~28min GRPO extension).
|
@@ -164,7 +164,7 @@
|
|
| 164 |
"id": "sft-train-code",
|
| 165 |
"metadata": {},
|
| 166 |
"outputs": [],
|
| 167 |
-
"source": "from trl import SFTConfig, SFTTrainer\n\nOUT_ROOT = PERSIST_ROOT\nSFT_DIR = os.path.join(OUT_ROOT, 'sft-merchant-agent')\nGRPO_DIR = os.path.join(OUT_ROOT, 'grpo-merchant-agent')\n\nSFT_FINAL_DIR = os.path.join(SFT_DIR, 'final')\nRUN_SFT_TRAIN = os.environ.get('RUN_SFT_TRAIN', 'auto').strip().lower()\nTRAIN_SFT = RUN_SFT_TRAIN in {'1', 'true', 'yes', 'y', 'on'} or (\n RUN_SFT_TRAIN == 'auto' and not os.path.isdir(SFT_FINAL_DIR)\n)\nSFT_EPOCHS = float(os.environ.get('SFT_EPOCHS', '1'))\nSFT_LR = float(os.environ.get('SFT_LR', '1e-4'))\nSFT_MAX_STEPS = int(os.environ.get('SFT_MAX_STEPS', '
|
| 168 |
},
|
| 169 |
{
|
| 170 |
"cell_type": "markdown",
|
|
@@ -202,62 +202,7 @@
|
|
| 202 |
"id": "grpo-data-code",
|
| 203 |
"metadata": {},
|
| 204 |
"outputs": [],
|
| 205 |
-
"source": [
|
| 206 |
-
"from training.reward_adapter import build_state_action_dataset\n",
|
| 207 |
-
"\n",
|
| 208 |
-
"# State-action samples: (task_id, state_step, prompt) tuples captured by\n",
|
| 209 |
-
"# rolling the heuristic forward on each task. The model will be asked to\n",
|
| 210 |
-
"# pick the next action at each captured state.\n",
|
| 211 |
-
"PHASE_B_MAX_STATES_PER_TASK = int(os.environ.get('PHASE_B_MAX_STATES_PER_TASK', '10'))\n",
|
| 212 |
-
"GRPO_DIFFICULTIES = tuple(\n",
|
| 213 |
-
" d.strip()\n",
|
| 214 |
-
" for d in os.environ.get('GRPO_DIFFICULTIES', 'easy,medium,hard,nightmare').split(',')\n",
|
| 215 |
-
" if d.strip()\n",
|
| 216 |
-
")\n",
|
| 217 |
-
"curriculum_task_ids = [\n",
|
| 218 |
-
" t.task_id for t in list_tasks()\n",
|
| 219 |
-
" if t.difficulty in GRPO_DIFFICULTIES\n",
|
| 220 |
-
"]\n",
|
| 221 |
-
"for diff in GRPO_DIFFICULTIES:\n",
|
| 222 |
-
" blocked = HOLDOUT_SEEDS_BY_DIFF.get(diff, set())\n",
|
| 223 |
-
" for s in seeds:\n",
|
| 224 |
-
" if s in blocked:\n",
|
| 225 |
-
" continue\n",
|
| 226 |
-
" tid = f'generated_{diff}_s{s}'\n",
|
| 227 |
-
" try:\n",
|
| 228 |
-
" get_task(tid)\n",
|
| 229 |
-
" if tid not in curriculum_task_ids:\n",
|
| 230 |
-
" curriculum_task_ids.append(tid)\n",
|
| 231 |
-
" except Exception:\n",
|
| 232 |
-
" pass\n",
|
| 233 |
-
"\n",
|
| 234 |
-
"raw_grpo = build_state_action_dataset(\n",
|
| 235 |
-
" curriculum_task_ids, max_states_per_task=PHASE_B_MAX_STATES_PER_TASK,\n",
|
| 236 |
-
")\n",
|
| 237 |
-
"\n",
|
| 238 |
-
"def to_chat_prompt(prompt):\n",
|
| 239 |
-
" return tokenizer.apply_chat_template(\n",
|
| 240 |
-
" [{'role': 'user', 'content': prompt}],\n",
|
| 241 |
-
" tokenize=False, add_generation_prompt=True,\n",
|
| 242 |
-
" )\n",
|
| 243 |
-
"\n",
|
| 244 |
-
"grpo_rows = []\n",
|
| 245 |
-
"for sample in raw_grpo:\n",
|
| 246 |
-
" chat_prompt = to_chat_prompt(sample['prompt'])\n",
|
| 247 |
-
" n_tokens = len(tokenizer(chat_prompt, add_special_tokens=False)['input_ids'])\n",
|
| 248 |
-
" if n_tokens <= 1000:\n",
|
| 249 |
-
" grpo_rows.append({\n",
|
| 250 |
-
" 'prompt': chat_prompt,\n",
|
| 251 |
-
" 'task_id': sample['task_id'],\n",
|
| 252 |
-
" 'state_step': int(sample['state_step']),\n",
|
| 253 |
-
" })\n",
|
| 254 |
-
"\n",
|
| 255 |
-
"grpo_dataset = Dataset.from_list(grpo_rows)\n",
|
| 256 |
-
"unique_tasks = len({row['task_id'] for row in grpo_rows})\n",
|
| 257 |
-
"print(f'GRPO state-action samples: {len(grpo_dataset)}, '\n",
|
| 258 |
-
" f'unique tasks: {unique_tasks}, difficulties={GRPO_DIFFICULTIES}, '\n",
|
| 259 |
-
" f'max_states_per_task={PHASE_B_MAX_STATES_PER_TASK}')"
|
| 260 |
-
]
|
| 261 |
},
|
| 262 |
{
|
| 263 |
"cell_type": "code",
|
|
@@ -265,7 +210,7 @@
|
|
| 265 |
"id": "grpo-train-code",
|
| 266 |
"metadata": {},
|
| 267 |
"outputs": [],
|
| 268 |
-
"source": "from trl import GRPOConfig, GRPOTrainer\nfrom training.outcome_reward import compute_outcome_reward, compute_format_reward\n\n# Re-arm hooks and disable cache. Do NOT zero out dropout - the merge cell\n# attached the Phase B LoRA with lora_dropout=0.1 specifically so train()-mode\n# generation has stochasticity. Zeroing it here was the v1 bug.\nmodel.enable_input_require_grads()\nif hasattr(model, 'config'):\n model.config.use_cache = False\nmodel.config.eos_token_id = tokenizer.eos_token_id\nmodel.config.pad_token_id = tokenizer.pad_token_id\nif hasattr(model, 'generation_config'):\n model.generation_config.eos_token_id = tokenizer.eos_token_id\n model.generation_config.pad_token_id = tokenizer.pad_token_id\n\ndef outcome_reward_fn(prompts, completions, **kwargs):\n task_ids = kwargs.get('task_id') or kwargs.get('task_ids')\n state_steps = kwargs.get('state_step') or kwargs.get('state_steps')\n return compute_outcome_reward(\n prompts, completions,\n task_ids=task_ids, state_steps=state_steps,\n )\n\ndef format_reward_fn(prompts, completions, **kwargs):\n return compute_format_reward(prompts, completions)\n\n# CRITICAL sampling kwargs - rewritten after the v1 run had grad_norm=0.0 on\n# 95% of steps. 
The v1 logs showed:\n# - frac_reward_zero_std=1.0 on ~80% of steps (all 4 generations identical)\n# - entropy=0.001-0.017 (policy near-delta after SFT mean_acc=0.96)\n# - When std=0 inside a group, advantage=0 and gradient=0.\n#\n# Fix: aggressively widen the sampling distribution.\n# temperature: 0.7 -> 1.3 (past 1.0 breaks the SFT argmax lock)\n# top_p: 0.9 -> 1.0 (no nucleus truncation)\n# top_k: 50 -> 0 (no top-k truncation)\n# num_generations: 4 -> 8 (2x within-group variance odds; gen_batch=8 must be divisible)\n# learning_rate: 5e-6 -> 2e-5 (bigger push to escape SFT collapse)\n# beta: 0.0 -> 0.04 (small KL anchor; v1 collapse risk is gone)\n# lora_dropout: 0.0 -> 0.1 (set in the merge cell, kept here)\n#\n# The format_reward_fn (-0.10 for invalid JSON) is the safety net that stops\n# the model from drifting into pure noise at the higher temperature.\ngrpo_config = GRPOConfig(\n output_dir=GRPO_DIR,\n per_device_train_batch_size=1,\n gradient_accumulation_steps=8,\n num_generations=8,\n max_prompt_length=1024,\n max_completion_length=192,\n learning_rate=float(os.environ.get('GRPO_LR', '
|
| 269 |
},
|
| 270 |
{
|
| 271 |
"cell_type": "markdown",
|
|
|
|
| 164 |
"id": "sft-train-code",
|
| 165 |
"metadata": {},
|
| 166 |
"outputs": [],
|
| 167 |
+
"source": "from trl import SFTConfig, SFTTrainer\n\nOUT_ROOT = PERSIST_ROOT\nSFT_DIR = os.path.join(OUT_ROOT, 'sft-merchant-agent')\nGRPO_DIR = os.path.join(OUT_ROOT, 'grpo-merchant-agent')\n\nSFT_FINAL_DIR = os.path.join(SFT_DIR, 'final')\nRUN_SFT_TRAIN = os.environ.get('RUN_SFT_TRAIN', 'auto').strip().lower()\nTRAIN_SFT = RUN_SFT_TRAIN in {'1', 'true', 'yes', 'y', 'on'} or (\n RUN_SFT_TRAIN == 'auto' and not os.path.isdir(SFT_FINAL_DIR)\n)\nSFT_EPOCHS = float(os.environ.get('SFT_EPOCHS', '1'))\nSFT_LR = float(os.environ.get('SFT_LR', '1e-4'))\nSFT_MAX_STEPS = int(os.environ.get('SFT_MAX_STEPS', '150'))\n\nif not TRAIN_SFT:\n print(f'Skipping SFT train; using existing adapter at {SFT_FINAL_DIR}')\nelse:\n sft_config = SFTConfig(\n output_dir=SFT_DIR,\n per_device_train_batch_size=1,\n gradient_accumulation_steps=8,\n num_train_epochs=SFT_EPOCHS,\n max_steps=SFT_MAX_STEPS,\n learning_rate=SFT_LR,\n logging_steps=10,\n save_steps=150,\n save_total_limit=2,\n bf16=False,\n fp16=True,\n gradient_checkpointing=True,\n gradient_checkpointing_kwargs={'use_reentrant': False},\n max_length=1024,\n dataset_text_field='text',\n report_to='none',\n optim='adamw_torch',\n warmup_ratio=0.03,\n )\n print(f'SFT config: rows={len(sft_dataset):,}, epochs={SFT_EPOCHS}, lr={SFT_LR}, max_steps={SFT_MAX_STEPS}')\n if hasattr(model, 'config'):\n model.config.use_cache = False\n sft_trainer = SFTTrainer(\n model=model,\n args=sft_config,\n train_dataset=sft_dataset,\n processing_class=tokenizer,\n )\n sft_trainer.train()\n sft_trainer.save_model(SFT_FINAL_DIR)\n del sft_trainer\n torch.cuda.empty_cache()\n print(f'PEAK VRAM (SFT): {torch.cuda.max_memory_allocated()/1e9:.2f} GB')"
|
| 168 |
},
|
| 169 |
{
|
| 170 |
"cell_type": "markdown",
|
|
|
|
| 202 |
"id": "grpo-data-code",
|
| 203 |
"metadata": {},
|
| 204 |
"outputs": [],
|
| 205 |
+
"source": "from training.reward_adapter import build_state_action_dataset\n\n# State-action samples: (task_id, state_step, prompt) tuples captured by\n# rolling the heuristic forward on each task. The model will be asked to\n# pick the next action at each captured state.\nPHASE_B_MAX_STATES_PER_TASK = int(os.environ.get('PHASE_B_MAX_STATES_PER_TASK', '10'))\nGRPO_DIFFICULTIES = tuple(\n d.strip()\n for d in os.environ.get('GRPO_DIFFICULTIES', 'easy,medium,hard,nightmare').split(',')\n if d.strip()\n)\ncurriculum_task_ids = [\n t.task_id for t in list_tasks()\n if t.difficulty in GRPO_DIFFICULTIES\n]\nfor diff in GRPO_DIFFICULTIES:\n blocked = HOLDOUT_SEEDS_BY_DIFF.get(diff, set())\n for s in seeds:\n if s in blocked:\n continue\n tid = f'generated_{diff}_s{s}'\n try:\n get_task(tid)\n if tid not in curriculum_task_ids:\n curriculum_task_ids.append(tid)\n except Exception:\n pass\n\n# Curriculum bias: hard + nightmare appear 2x to give GRPO more chances\n# to learn the cases where exploration beats SFT-locked argmax. v4 results\n# showed GRPO improved nightmare 0.55 -> 0.69 but regressed easy 0.92 -> 0.61\n# because easy was already SFT-saturated. 
Oversample where there is room to learn.\nhard_difficulties = ('hard', 'nightmare')\nhard_task_ids = [t.task_id for t in list_tasks() if t.difficulty in hard_difficulties]\nfor diff in hard_difficulties:\n blocked = HOLDOUT_SEEDS_BY_DIFF.get(diff, set())\n for s in seeds:\n if s in blocked:\n continue\n tid = f'generated_{diff}_s{s}'\n try:\n get_task(tid)\n if tid not in hard_task_ids:\n hard_task_ids.append(tid)\n except Exception:\n pass\n\ncurriculum_with_oversample = curriculum_task_ids + hard_task_ids # hard/nightmare appear twice\nraw_grpo = build_state_action_dataset(\n curriculum_with_oversample, max_states_per_task=PHASE_B_MAX_STATES_PER_TASK,\n)\n\ndef to_chat_prompt(prompt):\n return tokenizer.apply_chat_template(\n [{'role': 'user', 'content': prompt}],\n tokenize=False, add_generation_prompt=True,\n )\n\ngrpo_rows = []\nfor sample in raw_grpo:\n chat_prompt = to_chat_prompt(sample['prompt'])\n n_tokens = len(tokenizer(chat_prompt, add_special_tokens=False)['input_ids'])\n if n_tokens <= 1000:\n grpo_rows.append({\n 'prompt': chat_prompt,\n 'task_id': sample['task_id'],\n 'state_step': int(sample['state_step']),\n })\n\ngrpo_dataset = Dataset.from_list(grpo_rows)\nunique_tasks = len({row['task_id'] for row in grpo_rows})\nprint(f'GRPO state-action samples: {len(grpo_dataset)}, '\n f'unique tasks: {unique_tasks}, difficulties={GRPO_DIFFICULTIES}, '\n f'max_states_per_task={PHASE_B_MAX_STATES_PER_TASK}')"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
},
|
| 207 |
{
|
| 208 |
"cell_type": "code",
|
|
|
|
| 210 |
"id": "grpo-train-code",
|
| 211 |
"metadata": {},
|
| 212 |
"outputs": [],
|
| 213 |
+
"source": "from trl import GRPOConfig, GRPOTrainer\nfrom training.outcome_reward import compute_outcome_reward, compute_format_reward\n\n# Re-arm hooks and disable cache. Do NOT zero out dropout - the merge cell\n# attached the Phase B LoRA with lora_dropout=0.1 specifically so train()-mode\n# generation has stochasticity. Zeroing it here was the v1 bug.\nmodel.enable_input_require_grads()\nif hasattr(model, 'config'):\n model.config.use_cache = False\nmodel.config.eos_token_id = tokenizer.eos_token_id\nmodel.config.pad_token_id = tokenizer.pad_token_id\nif hasattr(model, 'generation_config'):\n model.generation_config.eos_token_id = tokenizer.eos_token_id\n model.generation_config.pad_token_id = tokenizer.pad_token_id\n\ndef outcome_reward_fn(prompts, completions, **kwargs):\n task_ids = kwargs.get('task_id') or kwargs.get('task_ids')\n state_steps = kwargs.get('state_step') or kwargs.get('state_steps')\n return compute_outcome_reward(\n prompts, completions,\n task_ids=task_ids, state_steps=state_steps,\n )\n\ndef format_reward_fn(prompts, completions, **kwargs):\n return compute_format_reward(prompts, completions)\n\n# CRITICAL sampling kwargs - rewritten after the v1 run had grad_norm=0.0 on\n# 95% of steps. 
The v1 logs showed:\n# - frac_reward_zero_std=1.0 on ~80% of steps (all 4 generations identical)\n# - entropy=0.001-0.017 (policy near-delta after SFT mean_acc=0.96)\n# - When std=0 inside a group, advantage=0 and gradient=0.\n#\n# Fix: aggressively widen the sampling distribution.\n# temperature: 0.7 -> 1.3 (past 1.0 breaks the SFT argmax lock)\n# top_p: 0.9 -> 1.0 (no nucleus truncation)\n# top_k: 50 -> 0 (no top-k truncation)\n# num_generations: 4 -> 8 (2x within-group variance odds; gen_batch=8 must be divisible)\n# learning_rate: 5e-6 -> 2e-5 (bigger push to escape SFT collapse)\n# beta: 0.0 -> 0.04 (small KL anchor; v1 collapse risk is gone)\n# lora_dropout: 0.0 -> 0.1 (set in the merge cell, kept here)\n#\n# The format_reward_fn (-0.10 for invalid JSON) is the safety net that stops\n# the model from drifting into pure noise at the higher temperature.\ngrpo_config = GRPOConfig(\n output_dir=GRPO_DIR,\n per_device_train_batch_size=1,\n gradient_accumulation_steps=8,\n num_generations=8,\n max_prompt_length=1024,\n max_completion_length=192,\n learning_rate=float(os.environ.get('GRPO_LR', '3e-5')),\n max_steps=int(os.environ.get('GRPO_MAX_STEPS', '200')),\n logging_steps=5,\n save_steps=80,\n save_total_limit=3,\n bf16=False,\n fp16=True,\n max_grad_norm=0.5,\n gradient_checkpointing=False,\n report_to='none',\n beta=0.04,\n temperature=1.3,\n top_p=1.0,\n top_k=0,\n repetition_penalty=1.0,\n use_vllm=False,\n log_completions=True,\n num_completions_to_print=2,\n optim='adamw_torch',\n lr_scheduler_type='constant',\n)\n\nRUN_GRPO = os.environ.get('RUN_GRPO', '1').strip().lower() not in {'0', 'false', 'no'}\nif RUN_GRPO:\n grpo_trainer = GRPOTrainer(\n model=model,\n processing_class=tokenizer,\n reward_funcs=[outcome_reward_fn, format_reward_fn],\n args=grpo_config,\n train_dataset=grpo_dataset,\n )\n grpo_trainer.train()\n grpo_trainer.save_model(os.path.join(GRPO_DIR, 'final'))\n del grpo_trainer\n torch.cuda.empty_cache()\nelse:\n print('RUN_GRPO=0: skipped 
GRPO training')\nprint(f'PEAK VRAM (GRPO): {torch.cuda.max_memory_allocated()/1e9:.2f} GB')\n"
|
| 214 |
},
|
| 215 |
{
|
| 216 |
"cell_type": "markdown",
|