#!/usr/bin/env python3
"""GRPO training script for arithmetic countdown problems using Hydra configuration.

After training, the model is automatically pushed to HuggingFace Hub.
"""

import logging
import os
import sys
from collections.abc import Callable  # noqa: F401 -- kept; may be used by type checkers
from pathlib import Path

# ---------------- FIX IMPORT WHEN USING HYDRA ----------------
# Hydra changes the working directory at launch time, so the project root must
# be put on sys.path explicitly for the `src.*` imports below to resolve.
FILE_DIR = os.path.dirname(os.path.abspath(__file__))  # src/training/grpo
PROJECT_ROOT = os.path.abspath(os.path.join(FILE_DIR, "../../../"))
sys.path.insert(0, PROJECT_ROOT)

import hydra
from datasets import Dataset
from huggingface_hub import HfApi, create_repo, login  # noqa: F401 -- login/create_repo kept for parity
from omegaconf import DictConfig, OmegaConf
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
from trl import GRPOConfig, GRPOTrainer

from src.dataset import load_csv_dataset_grpo
from src.dataset.grpo import map_problem_description_to_conversation_grpo
from src.utils.rewards import mathematical_correctness_reward_function

# Logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("grpo_training")


# -------------------------------------------------------------
# DATASET
# -------------------------------------------------------------
def load_train_dataset(cfg: DictConfig) -> Dataset:
    """Load, shuffle, and truncate the GRPO training dataset.

    Args:
        cfg: Dataset config providing ``file_path``, ``split``, ``seed``
            and ``max_rows``.

    Returns:
        A shuffled ``Dataset`` capped at ``cfg.max_rows`` rows.
    """
    raw_dataset: Dataset = load_csv_dataset_grpo(
        cfg.file_path, cfg.split, map_problem_description_to_conversation_grpo
    )
    raw_dataset = raw_dataset.shuffle(seed=cfg.seed)
    # min() guards against max_rows exceeding the dataset size.
    return raw_dataset.select(range(min(cfg.max_rows, len(raw_dataset))))


# -------------------------------------------------------------
# MODEL (LoRA + resume)
# -------------------------------------------------------------
def create_lora_model(
    cfg: DictConfig, resume_from_checkpoint: str | None = None
) -> PreTrainedModel:
    """Build the base causal-LM and wrap it with a fresh LoRA adapter.

    If ``resume_from_checkpoint`` points to an existing directory, the prior
    LoRA adapter is loaded and merged into the base weights first, so the new
    adapter trains on top of the previously fine-tuned model.

    Args:
        cfg: Model config with ``model_id``, ``device_map`` and a ``lora``
            sub-config (``r``, ``lora_alpha``, ``target_modules``,
            ``lora_dropout``, ``bias``, ``task_type``).
        resume_from_checkpoint: Optional path to an existing LoRA adapter.

    Returns:
        A PEFT-wrapped model ready for GRPO training.
    """
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model_id, device_map=cfg.device_map
    )
    if resume_from_checkpoint and Path(resume_from_checkpoint).exists():
        logger.info(
            "Loading existing LoRA adapter and merging: %s", resume_from_checkpoint
        )
        model = PeftModel.from_pretrained(model, resume_from_checkpoint)
        # Fold the old adapter into the base weights before attaching a new one.
        model = model.merge_and_unload()
    lora_cfg = LoraConfig(
        r=cfg.lora.r,
        lora_alpha=cfg.lora.lora_alpha,
        # OmegaConf ListConfig must be converted to a plain list for peft.
        target_modules=OmegaConf.to_container(cfg.lora.target_modules),
        lora_dropout=cfg.lora.lora_dropout,
        bias=cfg.lora.bias,
        task_type=cfg.lora.task_type,
    )
    return get_peft_model(model, lora_cfg)


# -------------------------------------------------------------
# GRPO CONFIG
# -------------------------------------------------------------
def create_grpo_config(cfg: DictConfig, output_dir: str) -> GRPOConfig:
    """Translate the Hydra training config into a TRL ``GRPOConfig``.

    Args:
        cfg: Training sub-config (optimizer, batching, generation, logging).
        output_dir: Directory where checkpoints and the final model are saved.

    Returns:
        A fully populated ``GRPOConfig``.
    """
    return GRPOConfig(
        output_dir=output_dir,
        learning_rate=cfg.learning_rate,
        warmup_ratio=cfg.warmup_ratio,
        weight_decay=cfg.weight_decay,
        lr_scheduler_type=cfg.lr_scheduler_type,
        optim=cfg.optim,
        remove_unused_columns=cfg.remove_unused_columns,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        num_train_epochs=cfg.num_train_epochs,
        bf16=cfg.bf16,
        per_device_train_batch_size=cfg.per_device_train_batch_size,
        temperature=cfg.temperature,
        max_completion_length=cfg.max_completion_length,
        num_generations=cfg.num_generations,
        max_prompt_length=cfg.max_prompt_length,
        report_to=cfg.report_to,
        logging_steps=cfg.logging_steps,
        save_strategy=cfg.save_strategy,
        save_steps=cfg.save_steps,
    )


# -------------------------------------------------------------
# TRAINER
# -------------------------------------------------------------
def create_trainer(model, train_dataset, args) -> GRPOTrainer:
    """Assemble the GRPO trainer with the mathematical-correctness reward.

    Args:
        model: PEFT-wrapped model from :func:`create_lora_model`.
        train_dataset: Dataset from :func:`load_train_dataset`.
        args: ``GRPOConfig`` from :func:`create_grpo_config`.

    Returns:
        A configured ``GRPOTrainer``.
    """
    reward_funcs = [mathematical_correctness_reward_function]
    return GRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=args,
        train_dataset=train_dataset,
    )


# -------------------------------------------------------------
# TRAIN + SAVE
# -------------------------------------------------------------
def train_and_save(
    trainer, output_dir, resume_from_checkpoint=None, save_before_training=True
):
    """Run GRPO training and persist the model to ``output_dir``.

    Args:
        trainer: The ``GRPOTrainer`` to run.
        output_dir: Directory to save the (pre- and post-training) model.
        resume_from_checkpoint: Optional trainer checkpoint to resume from.
        save_before_training: When True, snapshot the untrained adapter first
            so the output directory is valid even if training is interrupted.
    """
    if save_before_training:
        trainer.save_model(output_dir)
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    trainer.save_model(output_dir)
    logger.info("Training completed.")
    logger.info("Saved final model to: %s", output_dir)


# -------------------------------------------------------------
# PUSH TO HUGGINGFACE HUB
# -------------------------------------------------------------
def push_to_huggingface(output_dir: str, repo_id: str, model_id: str):
    """Upload the trained model folder and its tokenizer to the HF Hub.

    Args:
        output_dir: Local folder containing the saved model.
        repo_id: Target Hub repository (``user/name``).
        model_id: Base model id used to fetch the matching tokenizer.
    """
    logger.info("Pushing model to HuggingFace Hub...")

    # Login must be done BEFORE training
    # NOTE(review): no login() call is made here — this presumably relies on a
    # cached HF token or an earlier `huggingface-cli login`; verify in CI.
    api = HfApi()

    # Create repo if not exists (exist_ok=True makes re-runs idempotent).
    try:
        api.create_repo(repo_id, exist_ok=True)
    except Exception as exc:
        # Don't abort the upload attempt, but surface the failure instead of
        # swallowing it silently (the repo may already exist or be created
        # lazily by upload_folder).
        logger.warning("create_repo failed for %s: %s", repo_id, exc)

    # Load tokenizer (important!)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Push
    api.upload_folder(
        folder_path=output_dir,
        repo_id=repo_id,
        commit_message="Upload GRPO fine-tuned model",
    )
    tokenizer.push_to_hub(repo_id)
    logger.info("Upload complete! HF repo: https://huggingface.co/%s", repo_id)


# -------------------------------------------------------------
# MAIN
# -------------------------------------------------------------
@hydra.main(version_base=None, config_path="../../config/grpo", config_name="config")
def main(cfg: DictConfig):
    """Entry point: load data, build the model, train, save, optionally push."""
    logger.info("Configuration:\n%s", OmegaConf.to_yaml(cfg))

    if not Path(cfg.dataset.file_path).exists():
        logger.error("Dataset CSV file does not exist: %s", cfg.dataset.file_path)
        return

    os.makedirs(cfg.output_dir, exist_ok=True)

    # Load dataset
    train_dataset = load_train_dataset(cfg.dataset)

    # Model
    resume_sft = cfg.get("resume_from_checkpoint_sft", None)
    model = create_lora_model(cfg.model, resume_sft)

    # Trainer
    training_args = create_grpo_config(cfg.training, cfg.output_dir)
    trainer = create_trainer(model, train_dataset, training_args)

    # Train (use .get for optional keys, consistent with resume_sft above,
    # so a config missing them doesn't raise at attribute access).
    train_and_save(
        trainer,
        cfg.output_dir,
        resume_from_checkpoint=cfg.get("resume_from_checkpoint_grpo", None),
        save_before_training=cfg.get("save_before_training", True),
    )

    # Push to HF
    if cfg.get("push_to_hub", False):
        push_to_huggingface(
            output_dir=cfg.output_dir,
            repo_id=cfg.hf_repo_id,
            model_id=cfg.model.model_id,
        )


if __name__ == "__main__":
    main()