parth-1 committed on
Commit
2961503
·
verified ·
1 Parent(s): a5a9c5a

Update grpo_train.py

Browse files
Files changed (1) hide show
  1. grpo_train.py +4 -4
grpo_train.py CHANGED
@@ -302,7 +302,7 @@ model, tokenizer = FastLanguageModel.from_pretrained(
302
  model_name="unsloth/Llama-3.1-8B-Instruct",
303
  load_in_4bit=True, # Strictly True for L4 24GB
304
  max_seq_length=2048,
305
- dtype=torch.bfloat16, # L4 Native Support
306
  )
307
 
308
  model = FastLanguageModel.get_peft_model(
@@ -339,8 +339,8 @@ trainer = GRPOTrainer(
339
  max_completion_length=128,
340
  logging_steps=5,
341
  warmup_ratio=0.1,
342
- bf16=True,
343
- fp16=False,
344
  report_to="none",
345
  ),
346
  train_dataset=dataset,
@@ -371,7 +371,7 @@ if __name__ == "__main__":
371
  tokenizer.save_pretrained(LORA_DIR)
372
  print(f"LoRA adapter saved to {LORA_DIR}")
373
 
374
- print("Merging adapter into base model (bf16)...")
375
  merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
376
  model_name=LORA_DIR,
377
  load_in_4bit=False,
 
302
  model_name="unsloth/Llama-3.1-8B-Instruct",
303
  load_in_4bit=True, # Strictly True for L4 24GB
304
  max_seq_length=2048,
305
+ dtype=torch.float16, # PERFECT ALIGNMENT: 4-bit uses fp16 math natively
306
  )
307
 
308
  model = FastLanguageModel.get_peft_model(
 
339
  max_completion_length=128,
340
  logging_steps=5,
341
  warmup_ratio=0.1,
342
+ bf16=False, # DISABLED TO PREVENT CLASH
343
+ fp16=True, # ENABLED TO MATCH MODEL DTYPE
344
  report_to="none",
345
  ),
346
  train_dataset=dataset,
 
371
  tokenizer.save_pretrained(LORA_DIR)
372
  print(f"LoRA adapter saved to {LORA_DIR}")
373
 
374
+ print("Merging adapter into base model (fp16)...")
375
  merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
376
  model_name=LORA_DIR,
377
  load_in_4bit=False,