Spaces:
Sleeping
fix(eval): preload bases, swap adapters, fix peft offload_dir crash
Browse files
Old eval cell loaded a fresh 6.3GB base on every iteration. The
'del m_ckpt; torch.cuda.empty_cache()' pattern did NOT free GPU memory
because the policy closure still referenced 'm', which in turn referenced
'base'. By the second iteration, T4's 15GB was full. accelerate's auto
device_map decided to offload some layers to CPU/disk, then peft 0.14
crashed with 'We need an offload_dir to dispatch this model'.
Fix:
- Free all training-cell globals (model, merged_base, sft_trainer,
grpo_trainer, etc) before any eval load.
- Load eval_base ONCE (for base + SFT eval).
- Pre-merge SFT into a second base ONCE (for GRPO eval).
- Per-checkpoint: attach adapter via PeftModel.from_pretrained with
unique adapter_name, run eval, then delete_adapter to release the
LoRA tensors. Base stays in VRAM the whole time.
- Cleanup function returned alongside policy fn so the loop can
release adapter weights without nuking the base.
|
@@ -408,97 +408,7 @@
|
|
| 408 |
"id": "eval-code",
|
| 409 |
"metadata": {},
|
| 410 |
"outputs": [],
|
| 411 |
-
"source": [
|
| 412 |
-
"import glob, re\n",
|
| 413 |
-
"from peft import PeftModel\n",
|
| 414 |
-
"from training.curve import (\n",
|
| 415 |
-
" evaluate_checkpoint, evaluate_checkpoint_by_family,\n",
|
| 416 |
-
" plot_training_curve, plot_training_curve_by_family,\n",
|
| 417 |
-
")\n",
|
| 418 |
-
"\n",
|
| 419 |
-
"def make_text_policy(adapter_path, adapter_kind='base'):\n",
|
| 420 |
-
" base = AutoModelForCausalLM.from_pretrained(\n",
|
| 421 |
-
" MODEL_ID, torch_dtype=torch.float16, device_map='auto', trust_remote_code=True,\n",
|
| 422 |
-
" )\n",
|
| 423 |
-
" if adapter_kind == 'grpo':\n",
|
| 424 |
-
" sft = PeftModel.from_pretrained(base, SFT_FINAL_DIR)\n",
|
| 425 |
-
" base = sft.merge_and_unload()\n",
|
| 426 |
-
" m = PeftModel.from_pretrained(base, adapter_path)\n",
|
| 427 |
-
" elif adapter_path is not None:\n",
|
| 428 |
-
" m = PeftModel.from_pretrained(base, adapter_path)\n",
|
| 429 |
-
" else:\n",
|
| 430 |
-
" m = base\n",
|
| 431 |
-
" m.eval()\n",
|
| 432 |
-
" tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)\n",
|
| 433 |
-
" if tok.pad_token is None:\n",
|
| 434 |
-
" tok.pad_token = tok.eos_token\n",
|
| 435 |
-
" tok.padding_side = 'left'\n",
|
| 436 |
-
"\n",
|
| 437 |
-
" def policy(prompt):\n",
|
| 438 |
-
" chat = tok.apply_chat_template(\n",
|
| 439 |
-
" [{'role': 'user', 'content': prompt}],\n",
|
| 440 |
-
" tokenize=False, add_generation_prompt=True,\n",
|
| 441 |
-
" )\n",
|
| 442 |
-
" inputs = tok(chat, return_tensors='pt', truncation=True, max_length=1024).to(m.device)\n",
|
| 443 |
-
" with torch.no_grad():\n",
|
| 444 |
-
" out = m.generate(\n",
|
| 445 |
-
" **inputs, max_new_tokens=192, do_sample=False,\n",
|
| 446 |
-
" pad_token_id=tok.eos_token_id, eos_token_id=tok.eos_token_id,\n",
|
| 447 |
-
" )\n",
|
| 448 |
-
" return tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n",
|
| 449 |
-
" return policy, m\n",
|
| 450 |
-
"\n",
|
| 451 |
-
"# Catalog of checkpoints: untrained base, SFT final, every saved GRPO checkpoint.\n",
|
| 452 |
-
"ckpt_specs = [('base', None, 0, 'base'),\n",
|
| 453 |
-
" ('sft', SFT_FINAL_DIR, 1, 'sft')]\n",
|
| 454 |
-
"grpo_dirs = sorted(\n",
|
| 455 |
-
" glob.glob(os.path.join(GRPO_DIR, 'checkpoint-*')),\n",
|
| 456 |
-
" key=lambda p: int(re.search(r'checkpoint-(\\d+)', p).group(1)),\n",
|
| 457 |
-
")\n",
|
| 458 |
-
"for d in grpo_dirs:\n",
|
| 459 |
-
" step = int(re.search(r'checkpoint-(\\d+)', d).group(1))\n",
|
| 460 |
-
" ckpt_specs.append((f'grpo-{step}', d, 1 + step, 'grpo'))\n",
|
| 461 |
-
"final_grpo = os.path.join(GRPO_DIR, 'final')\n",
|
| 462 |
-
"if os.path.isdir(final_grpo):\n",
|
| 463 |
-
" ckpt_specs.append(('grpo-final', final_grpo, 1 + (max(int(re.search(r'checkpoint-(\\d+)', d).group(1)) for d in grpo_dirs) if grpo_dirs else 0) + 1, 'grpo'))\n",
|
| 464 |
-
"\n",
|
| 465 |
-
"overall = []\n",
|
| 466 |
-
"grouped = []\n",
|
| 467 |
-
"for label, path, step, kind in ckpt_specs:\n",
|
| 468 |
-
" print(f'eval {label} from {path}')\n",
|
| 469 |
-
" pol, m_ckpt = make_text_policy(path, kind)\n",
|
| 470 |
-
" overall.append(evaluate_checkpoint(step=step, policy=pol))\n",
|
| 471 |
-
" grouped.append(evaluate_checkpoint_by_family(step=step, policy=pol))\n",
|
| 472 |
-
" del m_ckpt\n",
|
| 473 |
-
" torch.cuda.empty_cache()\n",
|
| 474 |
-
"\n",
|
| 475 |
-
"print('\\nOVERALL CURVE:')\n",
|
| 476 |
-
"for c in overall:\n",
|
| 477 |
-
" print(f' step={c.step:4d} mean={c.mean_score:.4f}')\n",
|
| 478 |
-
"\n",
|
| 479 |
-
"print('\\nPER-FAMILY CURVE:')\n",
|
| 480 |
-
"for g in grouped:\n",
|
| 481 |
-
" line = f' step={g.step:4d}'\n",
|
| 482 |
-
" for fam in sorted(g.by_family.keys()):\n",
|
| 483 |
-
" line += f' {fam}={g.by_family[fam].mean_score:.3f}'\n",
|
| 484 |
-
" print(line)\n",
|
| 485 |
-
"\n",
|
| 486 |
-
"from runners.benchmark_runner import run_policy_sweep\n",
|
| 487 |
-
"sweep = run_policy_sweep()\n",
|
| 488 |
-
"heur_overall = next(s.mean_score for s in sweep.policies if s.policy == 'heuristic')\n",
|
| 489 |
-
"\n",
|
| 490 |
-
"FIG_DIR = os.path.join(REPO_DIR, 'docs', 'figures')\n",
|
| 491 |
-
"os.makedirs(FIG_DIR, exist_ok=True)\n",
|
| 492 |
-
"plot_training_curve(\n",
|
| 493 |
-
" overall, os.path.join(FIG_DIR, 'training_curve.png'),\n",
|
| 494 |
-
" baseline_scores={'heuristic': heur_overall, 'naive': 0.0},\n",
|
| 495 |
-
")\n",
|
| 496 |
-
"plot_training_curve_by_family(\n",
|
| 497 |
-
" grouped, os.path.join(FIG_DIR, 'training_curve_by_family.png'),\n",
|
| 498 |
-
" family_order=['easy', 'medium', 'hard', 'nightmare'],\n",
|
| 499 |
-
")\n",
|
| 500 |
-
"print(f'\\nfigures saved to {FIG_DIR}/')"
|
| 501 |
-
]
|
| 502 |
},
|
| 503 |
{
|
| 504 |
"cell_type": "markdown",
|
|
|
|
| 408 |
"id": "eval-code",
|
| 409 |
"metadata": {},
|
| 410 |
"outputs": [],
|
| 411 |
+
# Decoded source of notebook cell "eval-code": evaluate every checkpoint and
# plot the training curves.
# Uses names this cell does not define itself (os, torch, AutoTokenizer,
# AutoModelForCausalLM, MODEL_ID, SFT_FINAL_DIR, GRPO_DIR, REPO_DIR); they are
# presumably provided by earlier cells -- run those first.
import glob, re, gc
from peft import PeftModel
from training.curve import (
    evaluate_checkpoint, evaluate_checkpoint_by_family,
    plot_training_curve, plot_training_curve_by_family,
)
# Imported up front so a missing module fails before the long GPU evaluation.
from runners.benchmark_runner import run_policy_sweep

# Free EVERYTHING from training cells before loading eval models. The GRPO
# trainer's model/merged_base/sft_model are still pinning 6+GB of VRAM otherwise.
# The eval_* names and the last policy/cleanup closures are listed as well so
# that RE-RUNNING this cell drops the previous run's bases before new ones are
# loaded, instead of briefly holding three bases at once.
for name in ['model', 'merged_base', 'base_model', 'sft_model', 'sft_trainer',
             'grpo_trainer', 'fresh_base', 'tmp_base', 'sft_for_merge',
             'eval_base', 'sft_merged_base', 'pol', 'cleanup']:
    if name in globals():
        del globals()[name]
gc.collect()
torch.cuda.empty_cache()
print(f'before eval setup: VRAM {torch.cuda.memory_allocated()/1e9:.2f} GB')

# Load BASE once. For 'base' and 'sft' eval, attach/detach SFT adapter on this.
eval_tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
if eval_tok.pad_token is None:
    eval_tok.pad_token = eval_tok.eos_token
eval_tok.padding_side = 'left'

eval_base = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch.float16, device_map='auto', trust_remote_code=True,
)
eval_base.eval()
print(f'after eval_base load: VRAM {torch.cuda.memory_allocated()/1e9:.2f} GB')

# For GRPO eval we need a SECOND base with SFT pre-merged. Load it once, hold
# in VRAM. With 3B fp16 we need ~12-13 GB for both bases - fits T4 but tight.
sft_merged_base = None
if os.path.isdir(SFT_FINAL_DIR):
    tmp_base = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype=torch.float16, device_map='auto', trust_remote_code=True,
    )
    sft_for_merge = PeftModel.from_pretrained(tmp_base, SFT_FINAL_DIR)
    sft_merged_base = sft_for_merge.merge_and_unload()
    sft_merged_base.eval()
    # Drop the temporary handles; only sft_merged_base should keep the weights alive.
    del sft_for_merge, tmp_base
    gc.collect(); torch.cuda.empty_cache()
    print(f'after sft_merged_base: VRAM {torch.cuda.memory_allocated()/1e9:.2f} GB')


def make_policy(adapter_path, adapter_kind):
    """Return (policy, cleanup) for one checkpoint. No new base loads.

    Adapters swap on the pre-loaded eval_base / sft_merged_base.

    Args:
        adapter_path: directory of the LoRA adapter, or None for the bare base.
        adapter_kind: 'base', 'sft' or 'grpo'. 'grpo' adapters are attached to
            sft_merged_base; every other kind uses eval_base.

    Returns:
        policy: callable taking a prompt string and returning the decoded
            completion (greedy, at most 192 new tokens).
        cleanup: zero-argument callable that detaches the adapter again; a
            no-op when no adapter was attached.

    Raises:
        RuntimeError: for a 'grpo' checkpoint when sft_merged_base was not built.
    """
    if adapter_kind == 'grpo':
        if sft_merged_base is None:
            raise RuntimeError('grpo eval needs SFT adapter merged first')
        target = sft_merged_base
    else:
        target = eval_base

    if adapter_kind == 'base' or adapter_path is None:
        m = target
        cleanup = lambda: None
    else:
        # Unique name per adapter path so successive checkpoints never collide
        # on the shared base. hash() is only stable within this process, which
        # is all that is needed here.
        adapter_name = f'eval_{adapter_kind}_{abs(hash(adapter_path))}'
        m = PeftModel.from_pretrained(target, adapter_path, adapter_name=adapter_name)
        m.eval()

        def cleanup(_m=m, _name=adapter_name):
            try:
                _m.delete_adapter(_name)
            except Exception as exc:
                # Still best effort (the sweep continues), but report it: an
                # adapter that could not be deleted keeps its LoRA tensors.
                print(f'warning: could not delete adapter {_name}: {exc!r}')
            gc.collect(); torch.cuda.empty_cache()

    def policy(prompt):
        chat = eval_tok.apply_chat_template(
            [{'role': 'user', 'content': prompt}],
            tokenize=False, add_generation_prompt=True,
        )
        inputs = eval_tok(chat, return_tensors='pt', truncation=True, max_length=1024).to(m.device)
        with torch.no_grad():
            out = m.generate(
                **inputs, max_new_tokens=192, do_sample=False,
                pad_token_id=eval_tok.eos_token_id, eos_token_id=eval_tok.eos_token_id,
            )
        # Slice off the prompt tokens so only the newly generated text is decoded.
        return eval_tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    return policy, cleanup


def _ckpt_step(path):
    """Return the step number N parsed from a '.../checkpoint-N' path."""
    return int(re.search(r'checkpoint-(\d+)', path).group(1))


# Catalog: untrained base, SFT final, every saved GRPO checkpoint + final.
# Each entry is (label, adapter_path, x-axis step, adapter_kind).
ckpt_specs = [('base', None, 0, 'base'),
              ('sft', SFT_FINAL_DIR, 1, 'sft')]
grpo_dirs = sorted(glob.glob(os.path.join(GRPO_DIR, 'checkpoint-*')), key=_ckpt_step)
for d in grpo_dirs:
    step = _ckpt_step(d)
    ckpt_specs.append((f'grpo-{step}', d, 1 + step, 'grpo'))
final_grpo = os.path.join(GRPO_DIR, 'final')
if os.path.isdir(final_grpo):
    # grpo_dirs is sorted ascending by step, so its last entry is the maximum.
    last_step = _ckpt_step(grpo_dirs[-1]) if grpo_dirs else 0
    ckpt_specs.append(('grpo-final', final_grpo, 1 + last_step + 1, 'grpo'))

overall = []
grouped = []
for label, path, step, kind in ckpt_specs:
    print(f'eval {label} from {path}')
    pol, cleanup = make_policy(path, kind)
    try:
        overall.append(evaluate_checkpoint(step=step, policy=pol))
        grouped.append(evaluate_checkpoint_by_family(step=step, policy=pol))
    finally:
        # Detach the adapter even when an evaluation raises, so the shared
        # base is never left with a stale adapter attached.
        cleanup()

print('\nOVERALL CURVE:')
for c in overall:
    print(f'  step={c.step:4d}  mean={c.mean_score:.4f}')

print('\nPER-FAMILY CURVE:')
for g in grouped:
    line = f'  step={g.step:4d}'
    for fam in sorted(g.by_family.keys()):
        line += f'  {fam}={g.by_family[fam].mean_score:.3f}'
    print(line)

sweep = run_policy_sweep()
heur_overall = next((s.mean_score for s in sweep.policies if s.policy == 'heuristic'), None)
if heur_overall is None:
    # Without this guard the cell dies with a bare StopIteration.
    raise RuntimeError("run_policy_sweep() returned no 'heuristic' policy; cannot draw the baseline")

FIG_DIR = os.path.join(REPO_DIR, 'docs', 'figures')
os.makedirs(FIG_DIR, exist_ok=True)
plot_training_curve(
    overall, os.path.join(FIG_DIR, 'training_curve.png'),
    baseline_scores={'heuristic': heur_overall, 'naive': 0.0},
)
plot_training_curve_by_family(
    grouped, os.path.join(FIG_DIR, 'training_curve_by_family.png'),
    family_order=['easy', 'medium', 'hard', 'nightmare'],
)
print(f'\nfigures saved to {FIG_DIR}/')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
},
|
| 413 |
{
|
| 414 |
"cell_type": "markdown",
|