#!/usr/bin/env python3
"""
GRPO training script for arithmetic countdown problems.

This script trains a language model using GRPO (Group Relative Policy
Optimization) to solve arithmetic problems with proper reasoning and
formatting.
"""

import argparse
import logging
import os
from collections.abc import Callable
from pathlib import Path

from datasets import Dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, PreTrainedModel
from trl import GRPOConfig, GRPOTrainer

from src.utils.dataset import load_csv_dataset
from src.utils.rewards import (
    mathematical_correctness_reward_function,
)

# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("grpo_training")


def load_train_dataset(
    dataset_csv: str, max_rows: int = 2000, seed: int = 42
) -> Dataset:
    """
    Load, shuffle, and subsample the training dataset.

    Args:
        dataset_csv: Absolute path to the dataset CSV file
        max_rows: Maximum number of rows to select for training
        seed: Seed for dataset shuffling

    Returns:
        Dataset: A datasets.Dataset ready for GRPO training
    """
    raw_dataset: Dataset = load_csv_dataset(dataset_csv)
    raw_dataset = raw_dataset.shuffle(seed=seed)
    train_dataset = raw_dataset.select(range(min(max_rows, len(raw_dataset))))
    logger.info("Train rows: %d", len(train_dataset))
    return train_dataset


def create_lora_model(model_id: str, device_map: str = "cuda") -> PreTrainedModel:
    """
    Create a base causal LM and wrap it with LoRA adapters.

    Args:
        model_id: Hugging Face model identifier to load as the base model
        device_map: Device mapping strategy for model loading

    Returns:
        PreTrainedModel: A transformers.PreTrainedModel with LoRA adapters applied
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map=device_map,
    )
    lora_cfg = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_cfg)
    logger.info("Model with LoRA ready")
    return model
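
# Illustrative usage of the helpers above (a hedged sketch, not executed by
# this script): the model id mirrors the CLI default below, and
# `print_trainable_parameters` is the peft method that reports how small the
# LoRA adapter is relative to the frozen base model.
#
#   model = create_lora_model("Qwen/Qwen2.5-3B-Instruct", device_map="cuda")
#   model.print_trainable_parameters()
#   # -> trainable params: ... || all params: ... || trainable%: ...
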
def create_grpo_config(
    output_dir: str,
    learning_rate: float = 5e-6,
    num_train_epochs: int = 1,
    per_device_train_batch_size: int = 1,
    gradient_accumulation_steps: int = 16,
    max_completion_length: int = 512,
    num_generations: int = 16,
    temperature: float = 1.0,
    save_steps: int = 50,
    logging_steps: int = 1,
    max_prompt_length: int = 4096,
) -> GRPOConfig:
    """
    Create GRPO training configuration.

    Args:
        output_dir: Directory where checkpoints and logs will be written
        learning_rate: Learning rate for training
        num_train_epochs: Number of training epochs
        per_device_train_batch_size: Batch size per device
        gradient_accumulation_steps: Steps to accumulate gradients
        max_completion_length: Maximum length for completions
        num_generations: Number of generations per prompt
        temperature: Sampling temperature
        save_steps: Steps between model saves
        logging_steps: Steps between log outputs
        max_prompt_length: Maximum length for input prompts

    Returns:
        GRPOConfig: A configured trl.GRPOConfig instance
    """
    return GRPOConfig(
        output_dir=output_dir,
        learning_rate=learning_rate,
        weight_decay=0.01,
        warmup_ratio=0.1,
        lr_scheduler_type="linear",
        optim="adamw_8bit",
        remove_unused_columns=False,
        gradient_accumulation_steps=gradient_accumulation_steps,
        num_train_epochs=num_train_epochs,
        bf16=True,
        per_device_train_batch_size=per_device_train_batch_size,
        temperature=temperature,
        # Generation and sequence-length controls
        max_completion_length=max_completion_length,
        num_generations=num_generations,
        max_prompt_length=max_prompt_length,
        # Logging and saving
        report_to=["tensorboard"],
        logging_steps=logging_steps,
        save_strategy="steps",
        save_steps=save_steps,
    )


def create_trainer(
    model: PreTrainedModel,
    train_dataset: Dataset,
    args: GRPOConfig,
) -> GRPOTrainer:
    """
    Construct a GRPOTrainer with arithmetic-specific reward functions.

    Args:
        model: The LoRA-wrapped pretrained model to train
        train_dataset: The dataset to use for training
        args: The GRPO configuration

    Returns:
        GRPOTrainer: An initialized trl.GRPOTrainer instance
    """
    reward_funcs: list[Callable[..., list[float]]] = [
        mathematical_correctness_reward_function,
    ]
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=args,
        train_dataset=train_dataset,
    )
    return trainer


def train_and_save(trainer: GRPOTrainer, output_dir: str) -> None:
    """
    Run training and save the final model to disk.

    Args:
        trainer: The configured GRPO trainer instance
        output_dir: Output directory to save the trained model

    Returns:
        None
    """
    train_result = trainer.train()
    logger.info("Training complete: %s", str(train_result))
    trainer.save_model(output_dir)
    logger.info("Saved to %s", output_dir)
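
# For reference, GRPOTrainer calls each reward function with the batch prompts
# and completions plus any remaining dataset columns as keyword arguments, and
# expects one float score back per completion. A minimal sketch of that
# contract (the real scoring lives in src.utils.rewards; the `target` column
# name is an assumption about the CSV schema, used here only for illustration):
#
#   def sketch_reward(prompts, completions, target, **kwargs) -> list[float]:
#       # Reward 1.0 when the ground-truth answer appears in the completion.
#       return [
#           1.0 if str(t) in str(c) else 0.0
#           for c, t in zip(completions, target)
#       ]
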
def main() -> None:
    """
    Run the full GRPO training workflow with command-line arguments.

    Returns:
        None
    """
    parser = argparse.ArgumentParser(
        description="Train a language model using GRPO for arithmetic countdown problems"
    )

    # Dataset arguments
    parser.add_argument(
        "--dataset_csv",
        type=str,
        required=True,
        help="Path to the training dataset CSV file",
    )
    parser.add_argument(
        "--max_rows", type=int, default=2000, help="Maximum number of training samples"
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for dataset shuffling"
    )

    # Model arguments
    parser.add_argument(
        "--model_id",
        type=str,
        default="Qwen/Qwen2.5-3B-Instruct",
        help="Hugging Face model identifier",
    )
    parser.add_argument(
        "--device_map", type=str, default="cuda", help="Device mapping strategy"
    )

    # Training arguments
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory to save model checkpoints and logs",
    )
    parser.add_argument(
        "--learning_rate", type=float, default=5e-6, help="Learning rate"
    )
    parser.add_argument(
        "--num_train_epochs", type=int, default=1, help="Number of training epochs"
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=1,
        help="Batch size per device",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=16,
        help="Gradient accumulation steps",
    )
    parser.add_argument(
        "--max_completion_length",
        type=int,
        default=512,
        help="Maximum completion length",
    )
    parser.add_argument(
        "--num_generations",
        type=int,
        default=16,
        help="Number of generations per prompt",
    )
    parser.add_argument(
        "--temperature", type=float, default=1.0, help="Sampling temperature"
    )
    parser.add_argument(
        "--save_steps", type=int, default=50, help="Steps between model saves"
    )
    parser.add_argument(
        "--logging_steps", type=int, default=1, help="Steps between log outputs"
    )
    parser.add_argument(
        "--max_prompt_length",
        type=int,
        default=4096,
        help="Maximum length for input prompts",
    )

    args = parser.parse_args()

    # Validate arguments
    if not Path(args.dataset_csv).exists():
        logger.error("Dataset CSV file does not exist: %s", args.dataset_csv)
        return
    if args.max_rows <= 0:
        logger.error("max_rows must be positive")
        return

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)
    logger.info("Output dir: %s", args.output_dir)

    # Load dataset
    train_dataset = load_train_dataset(args.dataset_csv, args.max_rows, args.seed)

    # Create model
    model = create_lora_model(args.model_id, args.device_map)

    # Create training configuration
    training_args = create_grpo_config(
        output_dir=args.output_dir,
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        max_completion_length=args.max_completion_length,
        num_generations=args.num_generations,
        temperature=args.temperature,
        save_steps=args.save_steps,
        logging_steps=args.logging_steps,
        max_prompt_length=args.max_prompt_length,
    )

    # Create trainer
    trainer = create_trainer(
        model=model, train_dataset=train_dataset, args=training_args
    )

    # Train and save
    train_and_save(trainer=trainer, output_dir=args.output_dir)


if __name__ == "__main__":
    main()
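
# Example invocation (illustrative only: the script filename and paths are
# placeholders, and all other flags fall back to the defaults defined above):
#
#   python train_grpo.py \
#       --dataset_csv data/countdown.csv \
#       --output_dir runs/grpo-countdown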