shank committed on
Commit ·
8bd8552
1
Parent(s): cf25957
Add --hub-resume and --step-offset to allow resuming directly from HF Hub PEFT checkpoints
Browse files- training/train_grpo.py +23 -16
training/train_grpo.py
CHANGED
|
@@ -33,7 +33,9 @@ parser = argparse.ArgumentParser()
|
|
| 33 |
parser.add_argument("--test", action="store_true", help="Run 10 steps for testing (Colab/GPU)")
|
| 34 |
parser.add_argument("--test-local", action="store_true", dest="test_local",
|
| 35 |
help="Sanity-check reward function locally without any model or GPU")
|
| 36 |
-
parser.add_argument("--resume", type=str, default=None, help="
|
|
|
|
|
|
|
| 37 |
parser.add_argument("--max_steps", type=int, default=500)
|
| 38 |
args = parser.parse_args()
|
| 39 |
|
|
@@ -116,7 +118,7 @@ from server.models import parse_agent_output
|
|
| 116 |
# ── Configuration ─────────────────────────────────────────────────────────────
|
| 117 |
MODEL_NAME = "Qwen/Qwen2.5-Coder-3B-Instruct"
|
| 118 |
HF_REPO = "shashaank0707/AgentDebugger-trained"
|
| 119 |
-
MAX_STEPS = 10 if args.test else args.max_steps
|
| 120 |
CHECKPOINT_DIR = "./checkpoints"
|
| 121 |
|
| 122 |
# W&B and HF Token
|
|
@@ -353,16 +355,21 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
| 353 |
)
|
| 354 |
model.config.use_cache = False
|
| 355 |
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
model.enable_input_require_grads()
|
| 367 |
model.gradient_checkpointing_enable()
|
| 368 |
print(f"Trainable params: {model.num_parameters(only_trainable=True):,}")
|
|
@@ -507,8 +514,8 @@ trainer = GRPOTrainer(
|
|
| 507 |
|
| 508 |
# ── Curriculum callback ───────────────────────────────────────────────────────
|
| 509 |
class CurriculumCallback(TrainerCallback):
|
| 510 |
-
def on_step_end(self,
|
| 511 |
-
step = state.global_step
|
| 512 |
if step in [150, 350]:
|
| 513 |
trainer.train_dataset = make_dataset(step)
|
| 514 |
print(f"\nCurriculum advanced at step {step}!")
|
|
@@ -527,8 +534,8 @@ HUB_PUSH_EVERY = 50 # push every 50 steps — ~15min on T4, ~5min on A100
|
|
| 527 |
class CheckpointPushCallback(TrainerCallback):
|
| 528 |
"""Push LoRA adapter to HF Hub every HUB_PUSH_EVERY steps."""
|
| 529 |
|
| 530 |
-
def on_step_end(self,
|
| 531 |
-
step = state.global_step
|
| 532 |
if not HF_TOKEN or step == 0 or step % HUB_PUSH_EVERY != 0:
|
| 533 |
return
|
| 534 |
try:
|
|
|
|
| 33 |
parser.add_argument("--test", action="store_true", help="Run 10 steps for testing (Colab/GPU)")
|
| 34 |
parser.add_argument("--test-local", action="store_true", dest="test_local",
|
| 35 |
help="Sanity-check reward function locally without any model or GPU")
|
| 36 |
+
parser.add_argument("--resume", type=str, default=None, help="Local path to checkpoint")
|
| 37 |
+
parser.add_argument("--hub-resume", type=str, default=None, help="HF repo to resume LoRA from (e.g. shashaank0707/AgentDebugger-trained-checkpoints)")
|
| 38 |
+
parser.add_argument("--step-offset", type=int, default=0, help="Steps already completed if using hub-resume")
|
| 39 |
parser.add_argument("--max_steps", type=int, default=500)
|
| 40 |
args = parser.parse_args()
|
| 41 |
|
|
|
|
| 118 |
# ── Configuration ─────────────────────────────────────────────────────────────
|
| 119 |
MODEL_NAME = "Qwen/Qwen2.5-Coder-3B-Instruct"
|
| 120 |
HF_REPO = "shashaank0707/AgentDebugger-trained"
|
| 121 |
+
MAX_STEPS = (10 if args.test else args.max_steps) - args.step_offset
|
| 122 |
CHECKPOINT_DIR = "./checkpoints"
|
| 123 |
|
| 124 |
# W&B and HF Token
|
|
|
|
| 355 |
)
|
| 356 |
model.config.use_cache = False
|
| 357 |
|
| 358 |
+
if args.hub_resume:
|
| 359 |
+
from peft import PeftModel
|
| 360 |
+
print(f"\nResuming LoRA adapter from Hub: {args.hub_resume}")
|
| 361 |
+
model = PeftModel.from_pretrained(model, args.hub_resume, is_trainable=True)
|
| 362 |
+
else:
|
| 363 |
+
lora_config = LoraConfig(
|
| 364 |
+
r=_lora_r,
|
| 365 |
+
lora_alpha=_lora_r * 2,
|
| 366 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
| 367 |
+
"gate_proj", "up_proj", "down_proj"],
|
| 368 |
+
lora_dropout=0.0,
|
| 369 |
+
bias="none",
|
| 370 |
+
task_type=TaskType.CAUSAL_LM,
|
| 371 |
+
)
|
| 372 |
+
model = get_peft_model(model, lora_config)
|
| 373 |
model.enable_input_require_grads()
|
| 374 |
model.gradient_checkpointing_enable()
|
| 375 |
print(f"Trainable params: {model.num_parameters(only_trainable=True):,}")
|
|
|
|
| 514 |
|
| 515 |
# ── Curriculum callback ───────────────────────────────────────────────────────
|
| 516 |
class CurriculumCallback(TrainerCallback):
|
| 517 |
+
def on_step_end(self, callback_args, state, control, **kwargs):
|
| 518 |
+
step = state.global_step + args.step_offset
|
| 519 |
if step in [150, 350]:
|
| 520 |
trainer.train_dataset = make_dataset(step)
|
| 521 |
print(f"\nCurriculum advanced at step {step}!")
|
|
|
|
| 534 |
class CheckpointPushCallback(TrainerCallback):
|
| 535 |
"""Push LoRA adapter to HF Hub every HUB_PUSH_EVERY steps."""
|
| 536 |
|
| 537 |
+
def on_step_end(self, callback_args, state, control, **kwargs):
|
| 538 |
+
step = state.global_step + args.step_offset
|
| 539 |
if not HF_TOKEN or step == 0 or step % HUB_PUSH_EVERY != 0:
|
| 540 |
return
|
| 541 |
try:
|