CreativeEngineer commited on
Commit
7aa68cd
·
1 Parent(s): 1ee2461

Fix resume adapter training (no peft_config with PeftModel)

Browse files
Files changed (1) hide show
  1. app.py +14 -9
app.py CHANGED
@@ -517,10 +517,12 @@ def run_training(model_name, chunk_steps, max_total_steps, max_minutes, auto_con
517
  _try_download_adapter(add_log)
518
 
519
  # Resume LoRA adapter if present
 
520
  if os.path.isdir(ADAPTER_DIR) and os.path.exists(os.path.join(ADAPTER_DIR, "adapter_config.json")):
521
  add_log("Loading existing LoRA adapter (resume)...")
522
  model = PeftModel.from_pretrained(base_model, ADAPTER_DIR, is_trainable=True)
523
  add_log("✓ Adapter loaded")
 
524
  else:
525
  model = base_model
526
 
@@ -609,15 +611,18 @@ def run_training(model_name, chunk_steps, max_total_steps, max_minutes, auto_con
609
  num_generations=4,
610
  )
611
 
612
- trainer = GRPOTrainer(
613
- model=model,
614
- args=config,
615
- train_dataset=dataset,
616
- reward_funcs=perf_takehome_reward_fn,
617
- peft_config=lora_config,
618
- processing_class=tokenizer,
619
- callbacks=[VLIWCallback()],
620
- )
 
 
 
621
 
622
  train_result = trainer.train()
623
  metrics = train_result.metrics
 
517
  _try_download_adapter(add_log)
518
 
519
  # Resume LoRA adapter if present
520
+ resume_adapter = False
521
  if os.path.isdir(ADAPTER_DIR) and os.path.exists(os.path.join(ADAPTER_DIR, "adapter_config.json")):
522
  add_log("Loading existing LoRA adapter (resume)...")
523
  model = PeftModel.from_pretrained(base_model, ADAPTER_DIR, is_trainable=True)
524
  add_log("✓ Adapter loaded")
525
+ resume_adapter = True
526
  else:
527
  model = base_model
528
 
 
611
  num_generations=4,
612
  )
613
 
614
+ trainer_kwargs = {
615
+ "model": model,
616
+ "args": config,
617
+ "train_dataset": dataset,
618
+ "reward_funcs": perf_takehome_reward_fn,
619
+ "processing_class": tokenizer,
620
+ "callbacks": [VLIWCallback()],
621
+ }
622
+ if not resume_adapter:
623
+ trainer_kwargs["peft_config"] = lora_config
624
+
625
+ trainer = GRPOTrainer(**trainer_kwargs)
626
 
627
  train_result = trainer.train()
628
  metrics = train_result.metrics