ligaments-dev committed on
Commit
4eff6b5
·
verified ·
1 Parent(s): 28a17b6

Upload grpo_training.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. grpo_training.py +11 -3
grpo_training.py CHANGED
@@ -5,7 +5,7 @@
5
  from datasets import load_dataset
6
  from peft import LoraConfig
7
  from trl.trainer.grpo_trainer import GRPOTrainer, GRPOConfig
8
- from transformers import AutoTokenizer
9
  import trackio
10
  import torch
11
 
@@ -23,11 +23,18 @@ dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
23
  train_dataset = dataset_split["train"]
24
  eval_dataset = dataset_split["test"]
25
 
26
- # Load tokenizer
27
  tokenizer = AutoTokenizer.from_pretrained(model_name)
28
  if tokenizer.pad_token is None:
29
  tokenizer.pad_token = tokenizer.eos_token
30
 
 
 
 
 
 
 
 
31
  # Configure GRPO training
32
  config = GRPOConfig(
33
  output_dir=output_model,
@@ -79,7 +86,8 @@ def preference_reward_func(samples):
79
 
80
  # Initialize GRPO trainer
81
  trainer = GRPOTrainer(
82
- model=model_name,
 
83
  reward_funcs=[preference_reward_func],
84
  train_dataset=train_dataset,
85
  eval_dataset=eval_dataset,
 
5
  from datasets import load_dataset
6
  from peft import LoraConfig
7
  from trl.trainer.grpo_trainer import GRPOTrainer, GRPOConfig
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
  import trackio
10
  import torch
11
 
 
23
  train_dataset = dataset_split["train"]
24
  eval_dataset = dataset_split["test"]
25
 
26
+ # Load tokenizer and model
27
  tokenizer = AutoTokenizer.from_pretrained(model_name)
28
  if tokenizer.pad_token is None:
29
  tokenizer.pad_token = tokenizer.eos_token
30
 
31
+ # Load the model explicitly
32
+ model = AutoModelForCausalLM.from_pretrained(
33
+ model_name,
34
+ torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
35
+ device_map="auto"
36
+ )
37
+
38
  # Configure GRPO training
39
  config = GRPOConfig(
40
  output_dir=output_model,
 
86
 
87
  # Initialize GRPO trainer
88
  trainer = GRPOTrainer(
89
+ model=model,
90
+ tokenizer=tokenizer,
91
  reward_funcs=[preference_reward_func],
92
  train_dataset=train_dataset,
93
  eval_dataset=eval_dataset,