ligaments-dev committed
Commit 5f0d2a1 · verified · 1 parent: 302b22a

Upload grpo_training.py with huggingface_hub

Files changed (1): grpo_training.py (+91 −194)
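Per the commit message, the file was pushed programmatically with huggingface_hub rather than through the web UI. A minimal sketch of that kind of upload is below; the target repo_id is not shown on this page, so the one in the snippet is a placeholder.

from huggingface_hub import HfApi

api = HfApi()  # authenticates via HF_TOKEN or a cached `huggingface-cli login`
api.upload_file(
    path_or_fileobj="grpo_training.py",  # local file to push
    path_in_repo="grpo_training.py",     # destination path inside the repo
    repo_id="your-namespace/your-repo",  # placeholder; the actual repo is not named on this page
    commit_message="Upload grpo_training.py with huggingface_hub",
)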
grpo_training.py CHANGED
@@ -1,200 +1,97 @@
 # /// script
-# dependencies = ["trl>=0.12.0", "peft>=0.18.0", "transformers>=4.45.0", "torch>=2.0.0", "trackio", "wandb", "accelerate>=0.21.0", "bitsandbytes"]
+# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "torch", "transformers"]
 # ///
 
-import os
-import torch
 from datasets import load_dataset
-from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
-from transformers import (
-    AutoTokenizer,
-    AutoModelForCausalLM,
-    BitsAndBytesConfig,
-    TrainingArguments
-)
-from trl import GRPOTrainer, GRPOConfig
+from peft import LoraConfig
+from trl.trainer.grpo_trainer import GRPOTrainer, GRPOConfig
+from transformers import AutoTokenizer
 import trackio
-import wandb
-from huggingface_hub import HfApi
-
-def main():
-    # Initialize tracking
-    trackio.init(project="sec_grpo_training", run_name="llama32_1b_sec_grpo")
-
-    print("🚀 Starting GRPO training for SEC model...")
-
-    # Configuration
-    model_name = "ligaments-enterprise/llama3.2-1b-instruct-sec-finetuned"
-    dataset_name = "ligaments-enterprise/sec-data-preferences"
-    output_model = "ligaments-enterprise/llama3.2-1b-sec-grpo"
-
-    # BitsAndBytesConfig for QLoRA (4-bit quantization)
-    bnb_config = BitsAndBytesConfig(
-        load_in_4bit=True,
-        bnb_4bit_quant_type="nf4",
-        bnb_4bit_compute_dtype=torch.float16,
-        bnb_4bit_use_double_quant=True
-    )
-
-    # Load tokenizer and model
-    print(f"📦 Loading model: {model_name}")
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-
-    # Load model with QLoRA
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        quantization_config=bnb_config,
-        device_map="auto",
-        trust_remote_code=True,
-        attn_implementation="flash_attention_2" if torch.cuda.get_device_capability()[0] >= 8 else None
-    )
-
-    # Prepare model for k-bit training
-    model = prepare_model_for_kbit_training(model)
-
-    # LoRA configuration for GRPO
-    lora_config = LoraConfig(
-        task_type=TaskType.CAUSAL_LM,
-        r=16,  # rank
-        lora_alpha=32,  # alpha scaling
-        lora_dropout=0.1,
-        target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+import torch
+
+# Load your fine-tuned model and preference dataset
+model_name = "ligaments-enterprise/llama3.2-1b-instruct-sec-finetuned"
+dataset_name = "ligaments-enterprise/sec-data-preferences"
+output_model = "ligaments-enterprise/llama3.2-1b-sec-grpo"
+
+# Load dataset
+dataset = load_dataset(dataset_name, split="train")
+print(f"Loaded {len(dataset)} preference pairs from {dataset_name}")
+
+# Create train/eval split for monitoring
+dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
+train_dataset = dataset_split["train"]
+eval_dataset = dataset_split["test"]
+
+# Load tokenizer
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+if tokenizer.pad_token is None:
+    tokenizer.pad_token = tokenizer.eos_token
+
+# Configure GRPO training
+config = GRPOConfig(
+    output_dir=output_model,
+    num_train_epochs=3,
+    per_device_train_batch_size=1,
+    per_device_eval_batch_size=1,
+    gradient_accumulation_steps=8,  # Effective batch size = 8
+    learning_rate=1e-6,
+    max_length=1024,
+
+    # Evaluation and logging
+    eval_strategy="steps",
+    eval_steps=50,
+    logging_steps=10,
+    save_strategy="steps",
+    save_steps=100,
+
+    # Hub integration
+    push_to_hub=True,
+    hub_model_id=output_model,
+    hub_strategy="every_save",
+
+    # Optimization
+    gradient_checkpointing=True,
+    bf16=True if torch.cuda.is_bf16_supported() else False,
+    fp16=False if torch.cuda.is_bf16_supported() else True,
+
+    # Trackio monitoring
+    report_to="trackio",
+    run_name="llama3.2-1b-sec-grpo-training",
+    project="ligaments-sec-alignment",
+
+    # GRPO specific parameters
+    kl_penalty="kl",  # KL penalty for policy regularization
+    temperature=0.7,
+)
+
+# Initialize GRPO trainer
+trainer = GRPOTrainer(
+    model=model_name,
+    tokenizer=tokenizer,
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset,
+    peft_config=LoraConfig(
+        r=16,
+        lora_alpha=32,
+        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
+        lora_dropout=0.05,
         bias="none",
-    )
-
-    # Apply LoRA to model
-    model = get_peft_model(model, lora_config)
-    model.print_trainable_parameters()
-
-    # Load preference dataset
-    print(f"📊 Loading preference dataset: {dataset_name}")
-    dataset = load_dataset(dataset_name, split="train")
-    print(f"✅ Loaded {len(dataset)} preference pairs")
-
-    # Create train/eval split
-    train_test_split = dataset.train_test_split(test_size=0.1, seed=42)
-    train_dataset = train_test_split["train"]
-    eval_dataset = train_test_split["test"]
-
-    print(f"📈 Training samples: {len(train_dataset)}")
-    print(f"📉 Evaluation samples: {len(eval_dataset)}")
-
-    # GRPO Configuration
-    training_args = GRPOConfig(
-        output_dir="./grpo_sec_model",
-
-        # Basic training settings
-        num_train_epochs=2,
-        per_device_train_batch_size=1,
-        per_device_eval_batch_size=1,
-        gradient_accumulation_steps=8,  # Effective batch size = 8
-
-        # Learning rate and optimization
-        learning_rate=5e-6,  # Lower LR for RL fine-tuning
-        lr_scheduler_type="cosine",
-        warmup_ratio=0.03,
-
-        # Memory and efficiency
-        gradient_checkpointing=True,
-        dataloader_pin_memory=True,
-        bf16=True,
-        remove_unused_columns=False,
-
-        # GRPO specific parameters
-        beta=0.1,  # KL penalty coefficient
-        grpo_score_clip=5.0,  # Clip scores to prevent instability
-
-        # Evaluation and logging
-        eval_strategy="steps",
-        eval_steps=50,
-        logging_steps=10,
-        save_strategy="steps",
-        save_steps=100,
-        save_total_limit=3,
-
-        # Tracking
-        report_to=["trackio"],
-        run_name="sec_grpo_training",
-
-        # Hub integration
-        push_to_hub=True,
-        hub_model_id=output_model,
-        hub_strategy="every_save",
-
-        # Length settings
-        max_length=512,
-        max_prompt_length=256,
-    )
-
-    # Initialize GRPO Trainer
-    print("🎯 Initializing GRPO Trainer...")
-    trainer = GRPOTrainer(
-        model=model,
-        args=training_args,
-        train_dataset=train_dataset,
-        eval_dataset=eval_dataset,
-        tokenizer=tokenizer,
-        peft_config=lora_config,
-    )
-
-    # Log initial metrics
-    trackio.log({
-        "model_name": model_name,
-        "dataset_name": dataset_name,
-        "output_model": output_model,
-        "train_samples": len(train_dataset),
-        "eval_samples": len(eval_dataset),
-        "lora_rank": lora_config.r,
-        "lora_alpha": lora_config.lora_alpha,
-        "beta": training_args.beta,
-        "learning_rate": training_args.learning_rate,
-    })
-
-    # Start training
-    print("🚀 Starting GRPO training...")
-    try:
-        trainer.train()
-
-        # Log final metrics
-        trainer_state = trainer.state
-        trackio.log({
-            "final_train_loss": trainer_state.log_history[-1].get("train_loss", 0),
-            "final_eval_loss": trainer_state.log_history[-1].get("eval_loss", 0),
-            "training_completed": True
-        })
-
-        # Save final model
-        print("💾 Saving final model...")
-        trainer.save_model()
-
-        # Push to hub
-        print("📤 Pushing to Hub...")
-        trainer.push_to_hub(commit_message="GRPO training completed")
-
-        print(f"✅ GRPO training completed successfully!")
-        print(f"📦 Model saved to: {output_model}")
-
-        # Create evaluation summary
-        eval_summary = {
-            "total_steps": trainer_state.global_step,
-            "total_epochs": trainer_state.epoch,
-            "final_train_loss": trainer_state.log_history[-1].get("train_loss", "N/A"),
-            "final_eval_loss": trainer_state.log_history[-1].get("eval_loss", "N/A"),
-            "model_parameters": sum(p.numel() for p in model.parameters() if p.requires_grad),
-        }
-
-        print("📊 Training Summary:")
-        for key, value in eval_summary.items():
-            print(f" {key}: {value}")
-
-        trackio.log(eval_summary)
-
-    except Exception as e:
-        print(f"❌ Training failed: {e}")
-        trackio.log({"error": str(e), "training_completed": False})
-        raise e
-
-if __name__ == "__main__":
-    main()
+        task_type="CAUSAL_LM"
+    ),
+    args=config,
+)
+
+print("Starting GRPO training...")
+print(f"Training on {len(train_dataset)} preference pairs")
+print(f"Evaluating on {len(eval_dataset)} preference pairs")
+print(f"Output model will be saved to: {output_model}")
+
+# Train the model
+trainer.train()
+
+# Push final model to Hub
+trainer.push_to_hub()
+
+print("GRPO training completed successfully!")
+print(f"Final model available at: https://huggingface.co/{output_model}")