shank commited on
Commit
8bd8552
Β·
1 Parent(s): cf25957

Add --hub-resume and --step-offset to allow resuming directly from HF Hub PEFT checkpoints

Browse files
Files changed (1) hide show
  1. training/train_grpo.py +23 -16
training/train_grpo.py CHANGED
@@ -33,7 +33,9 @@ parser = argparse.ArgumentParser()
33
  parser.add_argument("--test", action="store_true", help="Run 10 steps for testing (Colab/GPU)")
34
  parser.add_argument("--test-local", action="store_true", dest="test_local",
35
  help="Sanity-check reward function locally without any model or GPU")
36
- parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint")
 
 
37
  parser.add_argument("--max_steps", type=int, default=500)
38
  args = parser.parse_args()
39
 
@@ -116,7 +118,7 @@ from server.models import parse_agent_output
116
  # ── Configuration ─────────────────────────────────────────────────────────────
117
  MODEL_NAME = "Qwen/Qwen2.5-Coder-3B-Instruct"
118
  HF_REPO = "shashaank0707/AgentDebugger-trained"
119
- MAX_STEPS = 10 if args.test else args.max_steps
120
  CHECKPOINT_DIR = "./checkpoints"
121
 
122
  # W&B and HF Token
@@ -353,16 +355,21 @@ model = AutoModelForCausalLM.from_pretrained(
353
  )
354
  model.config.use_cache = False
355
 
356
- lora_config = LoraConfig(
357
- r=_lora_r,
358
- lora_alpha=_lora_r * 2,
359
- target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
360
- "gate_proj", "up_proj", "down_proj"],
361
- lora_dropout=0.0,
362
- bias="none",
363
- task_type=TaskType.CAUSAL_LM,
364
- )
365
- model = get_peft_model(model, lora_config)
 
 
 
 
 
366
  model.enable_input_require_grads()
367
  model.gradient_checkpointing_enable()
368
  print(f"Trainable params: {model.num_parameters(only_trainable=True):,}")
@@ -507,8 +514,8 @@ trainer = GRPOTrainer(
507
 
508
  # ── Curriculum callback ───────────────────────────────────────────────────────
509
  class CurriculumCallback(TrainerCallback):
510
- def on_step_end(self, args, state, control, **kwargs):
511
- step = state.global_step
512
  if step in [150, 350]:
513
  trainer.train_dataset = make_dataset(step)
514
  print(f"\nCurriculum advanced at step {step}!")
@@ -527,8 +534,8 @@ HUB_PUSH_EVERY = 50 # push every 50 steps β€” ~15min on T4, ~5min on A100
527
  class CheckpointPushCallback(TrainerCallback):
528
  """Push LoRA adapter to HF Hub every HUB_PUSH_EVERY steps."""
529
 
530
- def on_step_end(self, args, state, control, **kwargs):
531
- step = state.global_step
532
  if not HF_TOKEN or step == 0 or step % HUB_PUSH_EVERY != 0:
533
  return
534
  try:
 
33
  parser.add_argument("--test", action="store_true", help="Run 10 steps for testing (Colab/GPU)")
34
  parser.add_argument("--test-local", action="store_true", dest="test_local",
35
  help="Sanity-check reward function locally without any model or GPU")
36
+ parser.add_argument("--resume", type=str, default=None, help="Local path to checkpoint")
37
+ parser.add_argument("--hub-resume", type=str, default=None, help="HF repo to resume LoRA from (e.g. shashaank0707/AgentDebugger-trained-checkpoints)")
38
+ parser.add_argument("--step-offset", type=int, default=0, help="Steps already completed if using hub-resume")
39
  parser.add_argument("--max_steps", type=int, default=500)
40
  args = parser.parse_args()
41
 
 
118
  # ── Configuration ─────────────────────────────────────────────────────────────
119
  MODEL_NAME = "Qwen/Qwen2.5-Coder-3B-Instruct"
120
  HF_REPO = "shashaank0707/AgentDebugger-trained"
121
+ MAX_STEPS = (10 if args.test else args.max_steps) - args.step_offset
122
  CHECKPOINT_DIR = "./checkpoints"
123
 
124
  # W&B and HF Token
 
355
  )
356
  model.config.use_cache = False
357
 
358
+ if args.hub_resume:
359
+ from peft import PeftModel
360
+ print(f"\nResuming LoRA adapter from Hub: {args.hub_resume}")
361
+ model = PeftModel.from_pretrained(model, args.hub_resume, is_trainable=True)
362
+ else:
363
+ lora_config = LoraConfig(
364
+ r=_lora_r,
365
+ lora_alpha=_lora_r * 2,
366
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
367
+ "gate_proj", "up_proj", "down_proj"],
368
+ lora_dropout=0.0,
369
+ bias="none",
370
+ task_type=TaskType.CAUSAL_LM,
371
+ )
372
+ model = get_peft_model(model, lora_config)
373
  model.enable_input_require_grads()
374
  model.gradient_checkpointing_enable()
375
  print(f"Trainable params: {model.num_parameters(only_trainable=True):,}")
 
514
 
515
  # ── Curriculum callback ───────────────────────────────────────────────────────
516
  class CurriculumCallback(TrainerCallback):
517
+ def on_step_end(self, callback_args, state, control, **kwargs):
518
+ step = state.global_step + args.step_offset
519
  if step in [150, 350]:
520
  trainer.train_dataset = make_dataset(step)
521
  print(f"\nCurriculum advanced at step {step}!")
 
534
  class CheckpointPushCallback(TrainerCallback):
535
  """Push LoRA adapter to HF Hub every HUB_PUSH_EVERY steps."""
536
 
537
+ def on_step_end(self, callback_args, state, control, **kwargs):
538
+ step = state.global_step + args.step_offset
539
  if not HF_TOKEN or step == 0 or step % HUB_PUSH_EVERY != 0:
540
  return
541
  try: