ligaments-dev commited on
Commit
302b22a
Β·
verified Β·
1 Parent(s): e2bfd5b

GRPO training script for SEC model

Browse files
Files changed (1) hide show
  1. grpo_training.py +200 -0
grpo_training.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = ["trl>=0.12.0", "peft>=0.18.0", "transformers>=4.45.0", "torch>=2.0.0", "trackio", "wandb", "accelerate>=0.21.0", "bitsandbytes"]
3
+ # ///
4
+
5
+ import os
6
+ import torch
7
+ from datasets import load_dataset
8
+ from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
9
+ from transformers import (
10
+ AutoTokenizer,
11
+ AutoModelForCausalLM,
12
+ BitsAndBytesConfig,
13
+ TrainingArguments
14
+ )
15
+ from trl import GRPOTrainer, GRPOConfig
16
+ import trackio
17
+ import wandb
18
+ from huggingface_hub import HfApi
19
+
20
+ def main():
21
+ # Initialize tracking
22
+ trackio.init(project="sec_grpo_training", run_name="llama32_1b_sec_grpo")
23
+
24
+ print("πŸš€ Starting GRPO training for SEC model...")
25
+
26
+ # Configuration
27
+ model_name = "ligaments-enterprise/llama3.2-1b-instruct-sec-finetuned"
28
+ dataset_name = "ligaments-enterprise/sec-data-preferences"
29
+ output_model = "ligaments-enterprise/llama3.2-1b-sec-grpo"
30
+
31
+ # BitsAndBytesConfig for QLoRA (4-bit quantization)
32
+ bnb_config = BitsAndBytesConfig(
33
+ load_in_4bit=True,
34
+ bnb_4bit_quant_type="nf4",
35
+ bnb_4bit_compute_dtype=torch.float16,
36
+ bnb_4bit_use_double_quant=True
37
+ )
38
+
39
+ # Load tokenizer and model
40
+ print(f"πŸ“¦ Loading model: {model_name}")
41
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
42
+ if tokenizer.pad_token is None:
43
+ tokenizer.pad_token = tokenizer.eos_token
44
+
45
+ # Load model with QLoRA
46
+ model = AutoModelForCausalLM.from_pretrained(
47
+ model_name,
48
+ quantization_config=bnb_config,
49
+ device_map="auto",
50
+ trust_remote_code=True,
51
+ attn_implementation="flash_attention_2" if torch.cuda.get_device_capability()[0] >= 8 else None
52
+ )
53
+
54
+ # Prepare model for k-bit training
55
+ model = prepare_model_for_kbit_training(model)
56
+
57
+ # LoRA configuration for GRPO
58
+ lora_config = LoraConfig(
59
+ task_type=TaskType.CAUSAL_LM,
60
+ r=16, # rank
61
+ lora_alpha=32, # alpha scaling
62
+ lora_dropout=0.1,
63
+ target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
64
+ bias="none",
65
+ )
66
+
67
+ # Apply LoRA to model
68
+ model = get_peft_model(model, lora_config)
69
+ model.print_trainable_parameters()
70
+
71
+ # Load preference dataset
72
+ print(f"πŸ“Š Loading preference dataset: {dataset_name}")
73
+ dataset = load_dataset(dataset_name, split="train")
74
+ print(f"βœ… Loaded {len(dataset)} preference pairs")
75
+
76
+ # Create train/eval split
77
+ train_test_split = dataset.train_test_split(test_size=0.1, seed=42)
78
+ train_dataset = train_test_split["train"]
79
+ eval_dataset = train_test_split["test"]
80
+
81
+ print(f"πŸ“ˆ Training samples: {len(train_dataset)}")
82
+ print(f"πŸ“‰ Evaluation samples: {len(eval_dataset)}")
83
+
84
+ # GRPO Configuration
85
+ training_args = GRPOConfig(
86
+ output_dir="./grpo_sec_model",
87
+
88
+ # Basic training settings
89
+ num_train_epochs=2,
90
+ per_device_train_batch_size=1,
91
+ per_device_eval_batch_size=1,
92
+ gradient_accumulation_steps=8, # Effective batch size = 8
93
+
94
+ # Learning rate and optimization
95
+ learning_rate=5e-6, # Lower LR for RL fine-tuning
96
+ lr_scheduler_type="cosine",
97
+ warmup_ratio=0.03,
98
+
99
+ # Memory and efficiency
100
+ gradient_checkpointing=True,
101
+ dataloader_pin_memory=True,
102
+ bf16=True,
103
+ remove_unused_columns=False,
104
+
105
+ # GRPO specific parameters
106
+ beta=0.1, # KL penalty coefficient
107
+ grpo_score_clip=5.0, # Clip scores to prevent instability
108
+
109
+ # Evaluation and logging
110
+ eval_strategy="steps",
111
+ eval_steps=50,
112
+ logging_steps=10,
113
+ save_strategy="steps",
114
+ save_steps=100,
115
+ save_total_limit=3,
116
+
117
+ # Tracking
118
+ report_to=["trackio"],
119
+ run_name="sec_grpo_training",
120
+
121
+ # Hub integration
122
+ push_to_hub=True,
123
+ hub_model_id=output_model,
124
+ hub_strategy="every_save",
125
+
126
+ # Length settings
127
+ max_length=512,
128
+ max_prompt_length=256,
129
+ )
130
+
131
+ # Initialize GRPO Trainer
132
+ print("🎯 Initializing GRPO Trainer...")
133
+ trainer = GRPOTrainer(
134
+ model=model,
135
+ args=training_args,
136
+ train_dataset=train_dataset,
137
+ eval_dataset=eval_dataset,
138
+ tokenizer=tokenizer,
139
+ peft_config=lora_config,
140
+ )
141
+
142
+ # Log initial metrics
143
+ trackio.log({
144
+ "model_name": model_name,
145
+ "dataset_name": dataset_name,
146
+ "output_model": output_model,
147
+ "train_samples": len(train_dataset),
148
+ "eval_samples": len(eval_dataset),
149
+ "lora_rank": lora_config.r,
150
+ "lora_alpha": lora_config.lora_alpha,
151
+ "beta": training_args.beta,
152
+ "learning_rate": training_args.learning_rate,
153
+ })
154
+
155
+ # Start training
156
+ print("πŸš€ Starting GRPO training...")
157
+ try:
158
+ trainer.train()
159
+
160
+ # Log final metrics
161
+ trainer_state = trainer.state
162
+ trackio.log({
163
+ "final_train_loss": trainer_state.log_history[-1].get("train_loss", 0),
164
+ "final_eval_loss": trainer_state.log_history[-1].get("eval_loss", 0),
165
+ "training_completed": True
166
+ })
167
+
168
+ # Save final model
169
+ print("πŸ’Ύ Saving final model...")
170
+ trainer.save_model()
171
+
172
+ # Push to hub
173
+ print("πŸ“€ Pushing to Hub...")
174
+ trainer.push_to_hub(commit_message="GRPO training completed")
175
+
176
+ print(f"βœ… GRPO training completed successfully!")
177
+ print(f"πŸ“¦ Model saved to: {output_model}")
178
+
179
+ # Create evaluation summary
180
+ eval_summary = {
181
+ "total_steps": trainer_state.global_step,
182
+ "total_epochs": trainer_state.epoch,
183
+ "final_train_loss": trainer_state.log_history[-1].get("train_loss", "N/A"),
184
+ "final_eval_loss": trainer_state.log_history[-1].get("eval_loss", "N/A"),
185
+ "model_parameters": sum(p.numel() for p in model.parameters() if p.requires_grad),
186
+ }
187
+
188
+ print("πŸ“Š Training Summary:")
189
+ for key, value in eval_summary.items():
190
+ print(f" {key}: {value}")
191
+
192
+ trackio.log(eval_summary)
193
+
194
+ except Exception as e:
195
+ print(f"❌ Training failed: {e}")
196
+ trackio.log({"error": str(e), "training_completed": False})
197
+ raise e
198
+
199
+ if __name__ == "__main__":
200
+ main()