Commit
·
7aa68cd
1
Parent(s):
1ee2461
Fix resume adapter training (no peft_config with PeftModel)
Browse files
app.py
CHANGED
|
@@ -517,10 +517,12 @@ def run_training(model_name, chunk_steps, max_total_steps, max_minutes, auto_con
|
|
| 517 |
_try_download_adapter(add_log)
|
| 518 |
|
| 519 |
# Resume LoRA adapter if present
|
|
|
|
| 520 |
if os.path.isdir(ADAPTER_DIR) and os.path.exists(os.path.join(ADAPTER_DIR, "adapter_config.json")):
|
| 521 |
add_log("Loading existing LoRA adapter (resume)...")
|
| 522 |
model = PeftModel.from_pretrained(base_model, ADAPTER_DIR, is_trainable=True)
|
| 523 |
add_log("✓ Adapter loaded")
|
|
|
|
| 524 |
else:
|
| 525 |
model = base_model
|
| 526 |
|
|
@@ -609,15 +611,18 @@ def run_training(model_name, chunk_steps, max_total_steps, max_minutes, auto_con
|
|
| 609 |
num_generations=4,
|
| 610 |
)
|
| 611 |
|
| 612 |
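-
trainer = GRPOTrainer(
|
| 613 |
-
model=model,
|
| 614 |
-
args=config,
|
| 615 |
-
train_dataset=dataset,
|
| 616 |
-
reward_funcs=perf_takehome_reward_fn,
|
| 617 |
-
peft_config=lora_config,
|
| 618 |
-
processing_class=tokenizer,
|
| 619 |
-
callbacks=[VLIWCallback()],
|
| 620 |
-
)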
|
|
|
|
|
|
|
|
|
|
| 621 |
|
| 622 |
train_result = trainer.train()
|
| 623 |
metrics = train_result.metrics
|
|
|
|
| 517 |
_try_download_adapter(add_log)
|
| 518 |
|
| 519 |
# Resume LoRA adapter if present
|
| 520 |
+
resume_adapter = False
|
| 521 |
if os.path.isdir(ADAPTER_DIR) and os.path.exists(os.path.join(ADAPTER_DIR, "adapter_config.json")):
|
| 522 |
add_log("Loading existing LoRA adapter (resume)...")
|
| 523 |
model = PeftModel.from_pretrained(base_model, ADAPTER_DIR, is_trainable=True)
|
| 524 |
add_log("✓ Adapter loaded")
|
| 525 |
+
resume_adapter = True
|
| 526 |
else:
|
| 527 |
model = base_model
|
| 528 |
|
|
|
|
| 611 |
num_generations=4,
|
| 612 |
)
|
| 613 |
|
| 614 |
+
trainer_kwargs = {
|
| 615 |
+
"model": model,
|
| 616 |
+
"args": config,
|
| 617 |
+
"train_dataset": dataset,
|
| 618 |
+
"reward_funcs": perf_takehome_reward_fn,
|
| 619 |
+
"processing_class": tokenizer,
|
| 620 |
+
"callbacks": [VLIWCallback()],
|
| 621 |
+
}
|
| 622 |
+
if not resume_adapter:
|
| 623 |
+
trainer_kwargs["peft_config"] = lora_config
|
| 624 |
+
|
| 625 |
+
trainer = GRPOTrainer(**trainer_kwargs)
|
| 626 |
|
| 627 |
train_result = trainer.train()
|
| 628 |
metrics = train_result.metrics
|