ligaments-dev commited on
Commit
2ce6fb9
·
verified ·
1 Parent(s): 4eff6b5

Upload grpo_training.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. grpo_training.py +1 -2
grpo_training.py CHANGED
@@ -31,7 +31,7 @@ if tokenizer.pad_token is None:
31
  # Load the model explicitly
32
  model = AutoModelForCausalLM.from_pretrained(
33
  model_name,
34
- torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
35
  device_map="auto"
36
  )
37
 
@@ -87,7 +87,6 @@ def preference_reward_func(samples):
87
  # Initialize GRPO trainer
88
  trainer = GRPOTrainer(
89
  model=model,
90
- tokenizer=tokenizer,
91
  reward_funcs=[preference_reward_func],
92
  train_dataset=train_dataset,
93
  eval_dataset=eval_dataset,
 
31
  # Load the model explicitly
32
  model = AutoModelForCausalLM.from_pretrained(
33
  model_name,
34
+ dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
35
  device_map="auto"
36
  )
37
 
 
87
  # Initialize GRPO trainer
88
  trainer = GRPOTrainer(
89
  model=model,
 
90
  reward_funcs=[preference_reward_func],
91
  train_dataset=train_dataset,
92
  eval_dataset=eval_dataset,