Spaces:
Runtime error
Runtime error
Update grpo_train.py
Browse files- grpo_train.py +4 -4
grpo_train.py
CHANGED
|
@@ -302,7 +302,7 @@ model, tokenizer = FastLanguageModel.from_pretrained(
|
|
| 302 |
model_name="unsloth/Llama-3.1-8B-Instruct",
|
| 303 |
load_in_4bit=True, # Strictly True for L4 24GB
|
| 304 |
max_seq_length=2048,
|
| 305 |
-
dtype=torch.
|
| 306 |
)
|
| 307 |
|
| 308 |
model = FastLanguageModel.get_peft_model(
|
|
@@ -339,8 +339,8 @@ trainer = GRPOTrainer(
|
|
| 339 |
max_completion_length=128,
|
| 340 |
logging_steps=5,
|
| 341 |
warmup_ratio=0.1,
|
| 342 |
-
bf16=
|
| 343 |
-
fp16=
|
| 344 |
report_to="none",
|
| 345 |
),
|
| 346 |
train_dataset=dataset,
|
|
@@ -371,7 +371,7 @@ if __name__ == "__main__":
|
|
| 371 |
tokenizer.save_pretrained(LORA_DIR)
|
| 372 |
print(f"LoRA adapter saved to {LORA_DIR}")
|
| 373 |
|
| 374 |
-
print("Merging adapter into base model (
|
| 375 |
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
|
| 376 |
model_name=LORA_DIR,
|
| 377 |
load_in_4bit=False,
|
|
|
|
| 302 |
model_name="unsloth/Llama-3.1-8B-Instruct",
|
| 303 |
load_in_4bit=True, # Strictly True for L4 24GB
|
| 304 |
max_seq_length=2048,
|
| 305 |
+
dtype=torch.float16, # PERFECT ALIGNMENT: 4-bit uses fp16 math natively
|
| 306 |
)
|
| 307 |
|
| 308 |
model = FastLanguageModel.get_peft_model(
|
|
|
|
| 339 |
max_completion_length=128,
|
| 340 |
logging_steps=5,
|
| 341 |
warmup_ratio=0.1,
|
| 342 |
+
bf16=False, # DISABLED TO PREVENT CLASH
|
| 343 |
+
fp16=True, # ENABLED TO MATCH MODEL DTYPE
|
| 344 |
report_to="none",
|
| 345 |
),
|
| 346 |
train_dataset=dataset,
|
|
|
|
| 371 |
tokenizer.save_pretrained(LORA_DIR)
|
| 372 |
print(f"LoRA adapter saved to {LORA_DIR}")
|
| 373 |
|
| 374 |
+
print("Merging adapter into base model (fp16)...")
|
| 375 |
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
|
| 376 |
model_name=LORA_DIR,
|
| 377 |
load_in_4bit=False,
|