parth-1 committed on
Commit
6557131
·
verified ·
1 Parent(s): 08bdaa0

Update grpo_train.py: pass dtype=torch.float16 to from_pretrained when USE_4BIT is set (was dtype=None)

Browse files
Files changed (1) hide show
  1. grpo_train.py +1 -1
grpo_train.py CHANGED
@@ -362,7 +362,7 @@ model, tokenizer = FastLanguageModel.from_pretrained(
362
  model_name="unsloth/Llama-3.1-8B-Instruct",
363
  load_in_4bit=USE_4BIT,
364
  max_seq_length=2048,
365
- dtype=None,
366
  )
367
 
368
  model = FastLanguageModel.get_peft_model(
 
362
  model_name="unsloth/Llama-3.1-8B-Instruct",
363
  load_in_4bit=USE_4BIT,
364
  max_seq_length=2048,
365
+ dtype=torch.float16 if USE_4BIT else None,
366
  )
367
 
368
  model = FastLanguageModel.get_peft_model(