parth-1 commited on
Commit
4ae43fc
·
verified ·
1 Parent(s): 0f7168b

Update grpo_train.py

Browse files
Files changed (1) hide show
  1. grpo_train.py +1 -0
grpo_train.py CHANGED
@@ -305,6 +305,7 @@ USE_4BIT = not torch.cuda.is_available() or torch.cuda.get_device_properties(0).
305
  model, tokenizer = FastLanguageModel.from_pretrained(
306
  model_name="unsloth/Llama-3.1-8B-Instruct",
307
  load_in_4bit=USE_4BIT,
 
308
  max_seq_length=2048,
309
  dtype=None, # auto-detect bf16 on A100
310
  )
 
305
  model, tokenizer = FastLanguageModel.from_pretrained(
306
  model_name="unsloth/Llama-3.1-8B-Instruct",
307
  load_in_4bit=USE_4BIT,
308
+ dtype = torch.bfloat16,
309
  max_seq_length=2048,
310
  dtype=None, # auto-detect bf16 on A100
311
  )