# /// script
# dependencies = ["trl>=0.14.0", "peft>=0.7.0", "datasets", "accelerate", "trackio", "torch", "transformers"]
# ///
from datasets import load_dataset
from peft import LoraConfig
from trl import GRPOTrainer, GRPOConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
import trackio  # not called directly; ensures the "trackio" reporting backend is installed
import torch
# Hub identifiers: fine-tuned base model, preference dataset, and output repo
model_name = "ligaments-enterprise/llama3.2-1b-instruct-sec-finetuned"
dataset_name = "ligaments-enterprise/sec-data-preferences"
output_model = "ligaments-enterprise/llama3.2-1b-sec-grpo"
# Load dataset
dataset = load_dataset(dataset_name, split="train")
print(f"Loaded {len(dataset)} preference pairs from {dataset_name}")
# Create train/eval split for monitoring
dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = dataset_split["train"]
eval_dataset = dataset_split["test"]
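# GRPOTrainer only reads the "prompt" column; any other columns (e.g. the
# "chosen"/"rejected" fields of a preference dataset) are forwarded to the
# reward functions as extra keyword arguments. A quick sanity check, assuming
# the dataset follows that layout:
assert "prompt" in train_dataset.column_names, (
    f"GRPO needs a 'prompt' column; found: {train_dataset.column_names}"
)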
# Load the tokenizer; GRPO's generation step needs a pad token, so fall back to EOS
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Load the model explicitly, in bf16 where the GPU supports it
# ("torch_dtype" is the widely supported kwarg; newer transformers also accepts "dtype")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
    device_map="auto",
)
# Configure GRPO training
config = GRPOConfig(
    output_dir=output_model,
    num_train_epochs=3,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=8,  # must be divisible by num_generations (default 8)
    gradient_accumulation_steps=8,  # effective batch size = 1 * 8 = 8
    learning_rate=1e-6,
    # Evaluation and logging
    eval_strategy="steps",
    eval_steps=50,
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    # Hub integration
    push_to_hub=True,
    hub_model_id=output_model,
    hub_strategy="every_save",
    # Optimization
    gradient_checkpointing=True,
    bf16=torch.cuda.is_bf16_supported(),
    fp16=not torch.cuda.is_bf16_supported(),
    # Trackio monitoring
    report_to="trackio",
    run_name="llama3.2-1b-sec-grpo-training",
    project="ligaments-sec-alignment",
)
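# Sanity check on the GRPO batch geometry: recent TRL releases build each
# generation group across gradient-accumulation steps, so the prompts per
# optimization cycle (1 * 8 = 8 here) must be divisible by num_generations
# (default 8). A guard sketch, assuming single-process training:
prompts_per_cycle = config.per_device_train_batch_size * config.gradient_accumulation_steps
assert prompts_per_cycle % config.num_generations == 0, (
    f"{prompts_per_cycle} prompts per cycle is not divisible by "
    f"num_generations={config.num_generations}"
)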
# Define the reward function for GRPO. TRL calls each reward function with
# keyword arguments (prompts, completions, plus any extra dataset columns)
# and expects one float reward per completion.
def preference_reward_func(completions, **kwargs):
    """Reward shorter, more concise responses (addressing the verbosity issue)."""
    rewards = []
    for completion in completions:
        # With a conversational dataset each completion is a list of message
        # dicts; with a standard dataset it is a plain string.
        text = completion[-1]["content"] if isinstance(completion, list) else completion
        response_length = len(text.split())
        if response_length < 50:
            rewards.append(1.0)   # concise: full reward
        elif response_length < 100:
            rewards.append(0.5)   # moderate length: partial reward
        else:
            rewards.append(0.0)   # overly verbose: no reward
    return rewards
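# Optional smoke test of the reward function before any GPU time is spent;
# the completions below are illustrative strings, not drawn from the dataset:
_smoke = preference_reward_func(
    completions=["A short, direct answer.", "word " * 120],
    prompts=["q1", "q2"],
)
assert _smoke == [1.0, 0.0], f"Unexpected rewards: {_smoke}"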
# Initialize GRPO trainer with a LoRA adapter so only a small set of weights trains
trainer = GRPOTrainer(
    model=model,
    reward_funcs=[preference_reward_func],
    args=config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=tokenizer,  # reuse the tokenizer with the pad token set above
    peft_config=LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    ),
)
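# Optional visibility into the LoRA footprint: with the PEFT config applied,
# only the adapter weights should require gradients.
trainable = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)
total = sum(p.numel() for p in trainer.model.parameters())
print(f"Trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")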
print("Starting GRPO training...")
print(f"Training on {len(train_dataset)} preference pairs")
print(f"Evaluating on {len(eval_dataset)} preference pairs")
print(f"Output model will be saved to: {output_model}")
# Train the model
trainer.train()
# Push final model to Hub
trainer.push_to_hub()
print("GRPO training completed successfully!")
print(f"Final model available at: https://huggingface.co/{output_model}")