mitudrudutta committed on
Commit
8fcb6cf
·
1 Parent(s): e26c2ac

fix(eval): preload bases, swap adapters, fix peft offload_dir crash

Browse files

The old eval cell loaded a fresh 6.3GB base on every iteration. The
'del m_ckpt; torch.cuda.empty_cache()' pattern did NOT free GPU memory
because the policy closure still referenced 'm', which in turn
references 'base'. By the second iteration, the T4's 15GB was full.
accelerate's auto device_map then offloaded some layers to CPU/disk,
and peft 0.14 crashed with 'We need an offload_dir to dispatch this model'.

Fix:
- Free all training-cell globals (model, merged_base, sft_trainer,
  grpo_trainer, etc.) before any eval load.
- Load eval_base ONCE (for base + SFT eval).
- Pre-merge SFT into a second base ONCE (for GRPO eval).
- Per-checkpoint: attach the adapter via PeftModel.from_pretrained with a
  unique adapter_name, run eval, then call delete_adapter to release the
  LoRA tensors. The base stays in VRAM the whole time.
- make_policy returns a cleanup function alongside the policy fn, so the
  loop can release adapter weights without nuking the base.

notebooks/train_merchant_agent.ipynb CHANGED
@@ -408,97 +408,7 @@
408
  "id": "eval-code",
409
  "metadata": {},
410
  "outputs": [],
411
- "source": [
412
- "import glob, re\n",
413
- "from peft import PeftModel\n",
414
- "from training.curve import (\n",
415
- " evaluate_checkpoint, evaluate_checkpoint_by_family,\n",
416
- " plot_training_curve, plot_training_curve_by_family,\n",
417
- ")\n",
418
- "\n",
419
- "def make_text_policy(adapter_path, adapter_kind='base'):\n",
420
- " base = AutoModelForCausalLM.from_pretrained(\n",
421
- " MODEL_ID, torch_dtype=torch.float16, device_map='auto', trust_remote_code=True,\n",
422
- " )\n",
423
- " if adapter_kind == 'grpo':\n",
424
- " sft = PeftModel.from_pretrained(base, SFT_FINAL_DIR)\n",
425
- " base = sft.merge_and_unload()\n",
426
- " m = PeftModel.from_pretrained(base, adapter_path)\n",
427
- " elif adapter_path is not None:\n",
428
- " m = PeftModel.from_pretrained(base, adapter_path)\n",
429
- " else:\n",
430
- " m = base\n",
431
- " m.eval()\n",
432
- " tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)\n",
433
- " if tok.pad_token is None:\n",
434
- " tok.pad_token = tok.eos_token\n",
435
- " tok.padding_side = 'left'\n",
436
- "\n",
437
- " def policy(prompt):\n",
438
- " chat = tok.apply_chat_template(\n",
439
- " [{'role': 'user', 'content': prompt}],\n",
440
- " tokenize=False, add_generation_prompt=True,\n",
441
- " )\n",
442
- " inputs = tok(chat, return_tensors='pt', truncation=True, max_length=1024).to(m.device)\n",
443
- " with torch.no_grad():\n",
444
- " out = m.generate(\n",
445
- " **inputs, max_new_tokens=192, do_sample=False,\n",
446
- " pad_token_id=tok.eos_token_id, eos_token_id=tok.eos_token_id,\n",
447
- " )\n",
448
- " return tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n",
449
- " return policy, m\n",
450
- "\n",
451
- "# Catalog of checkpoints: untrained base, SFT final, every saved GRPO checkpoint.\n",
452
- "ckpt_specs = [('base', None, 0, 'base'),\n",
453
- " ('sft', SFT_FINAL_DIR, 1, 'sft')]\n",
454
- "grpo_dirs = sorted(\n",
455
- " glob.glob(os.path.join(GRPO_DIR, 'checkpoint-*')),\n",
456
- " key=lambda p: int(re.search(r'checkpoint-(\\d+)', p).group(1)),\n",
457
- ")\n",
458
- "for d in grpo_dirs:\n",
459
- " step = int(re.search(r'checkpoint-(\\d+)', d).group(1))\n",
460
- " ckpt_specs.append((f'grpo-{step}', d, 1 + step, 'grpo'))\n",
461
- "final_grpo = os.path.join(GRPO_DIR, 'final')\n",
462
- "if os.path.isdir(final_grpo):\n",
463
- " ckpt_specs.append(('grpo-final', final_grpo, 1 + (max(int(re.search(r'checkpoint-(\\d+)', d).group(1)) for d in grpo_dirs) if grpo_dirs else 0) + 1, 'grpo'))\n",
464
- "\n",
465
- "overall = []\n",
466
- "grouped = []\n",
467
- "for label, path, step, kind in ckpt_specs:\n",
468
- " print(f'eval {label} from {path}')\n",
469
- " pol, m_ckpt = make_text_policy(path, kind)\n",
470
- " overall.append(evaluate_checkpoint(step=step, policy=pol))\n",
471
- " grouped.append(evaluate_checkpoint_by_family(step=step, policy=pol))\n",
472
- " del m_ckpt\n",
473
- " torch.cuda.empty_cache()\n",
474
- "\n",
475
- "print('\\nOVERALL CURVE:')\n",
476
- "for c in overall:\n",
477
- " print(f' step={c.step:4d} mean={c.mean_score:.4f}')\n",
478
- "\n",
479
- "print('\\nPER-FAMILY CURVE:')\n",
480
- "for g in grouped:\n",
481
- " line = f' step={g.step:4d}'\n",
482
- " for fam in sorted(g.by_family.keys()):\n",
483
- " line += f' {fam}={g.by_family[fam].mean_score:.3f}'\n",
484
- " print(line)\n",
485
- "\n",
486
- "from runners.benchmark_runner import run_policy_sweep\n",
487
- "sweep = run_policy_sweep()\n",
488
- "heur_overall = next(s.mean_score for s in sweep.policies if s.policy == 'heuristic')\n",
489
- "\n",
490
- "FIG_DIR = os.path.join(REPO_DIR, 'docs', 'figures')\n",
491
- "os.makedirs(FIG_DIR, exist_ok=True)\n",
492
- "plot_training_curve(\n",
493
- " overall, os.path.join(FIG_DIR, 'training_curve.png'),\n",
494
- " baseline_scores={'heuristic': heur_overall, 'naive': 0.0},\n",
495
- ")\n",
496
- "plot_training_curve_by_family(\n",
497
- " grouped, os.path.join(FIG_DIR, 'training_curve_by_family.png'),\n",
498
- " family_order=['easy', 'medium', 'hard', 'nightmare'],\n",
499
- ")\n",
500
- "print(f'\\nfigures saved to {FIG_DIR}/')"
501
- ]
502
  },
503
  {
504
  "cell_type": "markdown",
 
408
  "id": "eval-code",
409
  "metadata": {},
410
  "outputs": [],
411
+ "source": "import glob, re, gc\nfrom peft import PeftModel\nfrom training.curve import (\n evaluate_checkpoint, evaluate_checkpoint_by_family,\n plot_training_curve, plot_training_curve_by_family,\n)\n\n# Free EVERYTHING from training cells before loading eval models. The GRPO\n# trainer's model/merged_base/sft_model are still pinning 6+GB of VRAM otherwise.\nfor name in ['model', 'merged_base', 'base_model', 'sft_model', 'sft_trainer',\n 'grpo_trainer', 'fresh_base', 'tmp_base', 'sft_for_merge']:\n if name in globals():\n del globals()[name]\ngc.collect()\ntorch.cuda.empty_cache()\nprint(f'before eval setup: VRAM {torch.cuda.memory_allocated()/1e9:.2f} GB')\n\n# Load BASE once. For 'base' and 'sft' eval, attach/detach SFT adapter on this.\neval_tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)\nif eval_tok.pad_token is None:\n eval_tok.pad_token = eval_tok.eos_token\neval_tok.padding_side = 'left'\n\neval_base = AutoModelForCausalLM.from_pretrained(\n MODEL_ID, torch_dtype=torch.float16, device_map='auto', trust_remote_code=True,\n)\neval_base.eval()\nprint(f'after eval_base load: VRAM {torch.cuda.memory_allocated()/1e9:.2f} GB')\n\n# For GRPO eval we need a SECOND base with SFT pre-merged. Load it once, hold\n# in VRAM. With 3B fp16 we need ~12-13 GB for both bases - fits T4 but tight.\nsft_merged_base = None\nif os.path.isdir(SFT_FINAL_DIR):\n tmp_base = AutoModelForCausalLM.from_pretrained(\n MODEL_ID, torch_dtype=torch.float16, device_map='auto', trust_remote_code=True,\n )\n sft_for_merge = PeftModel.from_pretrained(tmp_base, SFT_FINAL_DIR)\n sft_merged_base = sft_for_merge.merge_and_unload()\n sft_merged_base.eval()\n del sft_for_merge, tmp_base\n gc.collect(); torch.cuda.empty_cache()\n print(f'after sft_merged_base: VRAM {torch.cuda.memory_allocated()/1e9:.2f} GB')\n\n\ndef make_policy(adapter_path, adapter_kind):\n \"\"\"Return policy fn + cleanup fn. No new base loads. 
Adapters swap on\n pre-loaded eval_base / sft_merged_base.\"\"\"\n if adapter_kind == 'grpo':\n if sft_merged_base is None:\n raise RuntimeError('grpo eval needs SFT adapter merged first')\n target = sft_merged_base\n else:\n target = eval_base\n\n if adapter_kind == 'base' or adapter_path is None:\n m = target\n cleanup = lambda: None\n else:\n adapter_name = f'eval_{adapter_kind}_{abs(hash(adapter_path))}'\n m = PeftModel.from_pretrained(target, adapter_path, adapter_name=adapter_name)\n m.eval()\n def cleanup(_m=m, _name=adapter_name):\n try:\n _m.delete_adapter(_name)\n except Exception:\n pass\n gc.collect(); torch.cuda.empty_cache()\n\n def policy(prompt):\n chat = eval_tok.apply_chat_template(\n [{'role': 'user', 'content': prompt}],\n tokenize=False, add_generation_prompt=True,\n )\n inputs = eval_tok(chat, return_tensors='pt', truncation=True, max_length=1024).to(m.device)\n with torch.no_grad():\n out = m.generate(\n **inputs, max_new_tokens=192, do_sample=False,\n pad_token_id=eval_tok.eos_token_id, eos_token_id=eval_tok.eos_token_id,\n )\n return eval_tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n return policy, cleanup\n\n\n# Catalog: untrained base, SFT final, every saved GRPO checkpoint + final.\nckpt_specs = [('base', None, 0, 'base'),\n ('sft', SFT_FINAL_DIR, 1, 'sft')]\ngrpo_dirs = sorted(\n glob.glob(os.path.join(GRPO_DIR, 'checkpoint-*')),\n key=lambda p: int(re.search(r'checkpoint-(\\d+)', p).group(1)),\n)\nfor d in grpo_dirs:\n step = int(re.search(r'checkpoint-(\\d+)', d).group(1))\n ckpt_specs.append((f'grpo-{step}', d, 1 + step, 'grpo'))\nfinal_grpo = os.path.join(GRPO_DIR, 'final')\nif os.path.isdir(final_grpo):\n last_step = max(int(re.search(r'checkpoint-(\\d+)', d).group(1)) for d in grpo_dirs) if grpo_dirs else 0\n ckpt_specs.append(('grpo-final', final_grpo, 1 + last_step + 1, 'grpo'))\n\noverall = []\ngrouped = []\nfor label, path, step, kind in ckpt_specs:\n print(f'eval {label} from {path}')\n pol, 
cleanup = make_policy(path, kind)\n overall.append(evaluate_checkpoint(step=step, policy=pol))\n grouped.append(evaluate_checkpoint_by_family(step=step, policy=pol))\n cleanup()\n\nprint('\\nOVERALL CURVE:')\nfor c in overall:\n print(f' step={c.step:4d} mean={c.mean_score:.4f}')\n\nprint('\\nPER-FAMILY CURVE:')\nfor g in grouped:\n line = f' step={g.step:4d}'\n for fam in sorted(g.by_family.keys()):\n line += f' {fam}={g.by_family[fam].mean_score:.3f}'\n print(line)\n\nfrom runners.benchmark_runner import run_policy_sweep\nsweep = run_policy_sweep()\nheur_overall = next(s.mean_score for s in sweep.policies if s.policy == 'heuristic')\n\nFIG_DIR = os.path.join(REPO_DIR, 'docs', 'figures')\nos.makedirs(FIG_DIR, exist_ok=True)\nplot_training_curve(\n overall, os.path.join(FIG_DIR, 'training_curve.png'),\n baseline_scores={'heuristic': heur_overall, 'naive': 0.0},\n)\nplot_training_curve_by_family(\n grouped, os.path.join(FIG_DIR, 'training_curve_by_family.png'),\n family_order=['easy', 'medium', 'hard', 'nightmare'],\n)\nprint(f'\\nfigures saved to {FIG_DIR}/')\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
  },
413
  {
414
  "cell_type": "markdown",