#!/usr/bin/env python3
"""
GRPO training script for arithmetic countdown problems using Hydra configuration.
After training, the model is automatically pushed to HuggingFace Hub.
"""
import logging
import os
from collections.abc import Callable
from pathlib import Path
import sys
# ---------------- FIX IMPORT WHEN USING HYDRA ----------------
FILE_DIR = os.path.dirname(os.path.abspath(__file__)) # src/training/grpo
PROJECT_ROOT = os.path.abspath(os.path.join(FILE_DIR, "../../../"))
sys.path.insert(0, PROJECT_ROOT)
import hydra
from datasets import Dataset
from omegaconf import DictConfig, OmegaConf
from transformers import AutoModelForCausalLM, PreTrainedModel, AutoTokenizer
from huggingface_hub import HfApi, login, create_repo
from peft import LoraConfig, PeftModel, get_peft_model
from trl import GRPOConfig, GRPOTrainer
from src.dataset import load_csv_dataset_grpo
from src.dataset.grpo import map_problem_description_to_conversation_grpo
from src.utils.rewards import mathematical_correctness_reward_function
# Logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("grpo_training")
# -------------------------------------------------------------
# DATASET
# -------------------------------------------------------------
def load_train_dataset(cfg: DictConfig) -> Dataset:
    """Load, shuffle, and truncate the GRPO training dataset.

    Reads the CSV at ``cfg.file_path`` for the requested ``cfg.split``,
    mapping each row to a GRPO conversation, then shuffles deterministically
    with ``cfg.seed`` and keeps at most ``cfg.max_rows`` examples.
    """
    dataset: Dataset = load_csv_dataset_grpo(
        cfg.file_path, cfg.split, map_problem_description_to_conversation_grpo
    )
    shuffled = dataset.shuffle(seed=cfg.seed)
    # Cap at max_rows without failing when the dataset is smaller than the cap.
    keep = min(cfg.max_rows, len(shuffled))
    return shuffled.select(range(keep))
# -------------------------------------------------------------
# MODEL (LoRA + resume)
# -------------------------------------------------------------
def create_lora_model(cfg: DictConfig, resume_from_checkpoint: str | None = None) -> PreTrainedModel:
    """Build the base causal LM and attach a fresh LoRA adapter.

    When ``resume_from_checkpoint`` is a non-empty path that exists, the
    adapter stored there is first loaded and merged into the base weights, so
    the new LoRA adapter trains on top of the previously fine-tuned model.
    """
    base_model = AutoModelForCausalLM.from_pretrained(cfg.model_id, device_map=cfg.device_map)

    # Truthiness check (not `is not None`) deliberately skips empty strings.
    if resume_from_checkpoint and Path(resume_from_checkpoint).exists():
        logger.info("Loading existing LoRA adapter and merging: %s", resume_from_checkpoint)
        adapter_model = PeftModel.from_pretrained(base_model, resume_from_checkpoint)
        base_model = adapter_model.merge_and_unload()

    adapter_config = LoraConfig(
        r=cfg.lora.r,
        lora_alpha=cfg.lora.lora_alpha,
        # OmegaConf list must be converted to a plain Python list for peft.
        target_modules=OmegaConf.to_container(cfg.lora.target_modules),
        lora_dropout=cfg.lora.lora_dropout,
        bias=cfg.lora.bias,
        task_type=cfg.lora.task_type,
    )
    return get_peft_model(base_model, adapter_config)
# -------------------------------------------------------------
# GRPO CONFIG
# -------------------------------------------------------------
def create_grpo_config(cfg: DictConfig, output_dir: str) -> GRPOConfig:
    """Translate the Hydra training section into a ``GRPOConfig``.

    All hyperparameters are forwarded one-to-one from ``cfg``; only the
    output directory is supplied separately by the caller.
    """
    hyperparams = dict(
        learning_rate=cfg.learning_rate,
        warmup_ratio=cfg.warmup_ratio,
        weight_decay=cfg.weight_decay,
        lr_scheduler_type=cfg.lr_scheduler_type,
        optim=cfg.optim,
        remove_unused_columns=cfg.remove_unused_columns,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        num_train_epochs=cfg.num_train_epochs,
        bf16=cfg.bf16,
        per_device_train_batch_size=cfg.per_device_train_batch_size,
        temperature=cfg.temperature,
        max_completion_length=cfg.max_completion_length,
        num_generations=cfg.num_generations,
        max_prompt_length=cfg.max_prompt_length,
        report_to=cfg.report_to,
        logging_steps=cfg.logging_steps,
        save_strategy=cfg.save_strategy,
        save_steps=cfg.save_steps,
    )
    return GRPOConfig(output_dir=output_dir, **hyperparams)
# -------------------------------------------------------------
# TRAINER
# -------------------------------------------------------------
def create_trainer(model, train_dataset, args):
    """Assemble a ``GRPOTrainer`` using the mathematical-correctness reward.

    The single reward function scores completions for arithmetic correctness;
    GRPO optimizes the policy against it.
    """
    return GRPOTrainer(
        model=model,
        reward_funcs=[mathematical_correctness_reward_function],
        args=args,
        train_dataset=train_dataset,
    )
# -------------------------------------------------------------
# TRAIN + SAVE
# -------------------------------------------------------------
def train_and_save(trainer, output_dir, resume_from_checkpoint=None, save_before_training=True):
    """Run GRPO training and persist the resulting model to ``output_dir``.

    Args:
        trainer: a configured ``GRPOTrainer``.
        output_dir: directory that receives the saved model.
        resume_from_checkpoint: optional trainer checkpoint to resume from.
        save_before_training: when True, snapshot the model up front so the
            output directory holds a usable model even if training is
            interrupted.
    """
    if save_before_training:
        trainer.save_model(output_dir)

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    # Always persist the final weights after training finishes.
    trainer.save_model(output_dir)
    logger.info("Training completed.")
    logger.info("Saved final model to: %s", output_dir)
# -------------------------------------------------------------
# PUSH TO HUGGINGFACE HUB
# -------------------------------------------------------------
def push_to_huggingface(output_dir: str, repo_id: str, model_id: str) -> None:
    """Upload the trained model folder and its tokenizer to the HuggingFace Hub.

    Assumes authentication has already happened (login must be done before
    training). Creates the target repo when it does not exist yet, then
    uploads ``output_dir`` and pushes the tokenizer so the repo is usable
    with ``from_pretrained``.

    Args:
        output_dir: local directory containing the saved model files.
        repo_id: target Hub repository id (``user/name``).
        model_id: base model id, used to fetch the matching tokenizer.
    """
    logger.info("Pushing model to HuggingFace Hub...")
    api = HfApi()

    # Best-effort repo creation: exist_ok makes this idempotent. A bare
    # `except: pass` here previously swallowed everything (including
    # KeyboardInterrupt) and hid auth/network errors — log them instead.
    try:
        api.create_repo(repo_id, exist_ok=True)
    except Exception:
        logger.exception("create_repo failed for %s; attempting upload anyway", repo_id)

    # Tokenizer must ship alongside the weights for the repo to be loadable.
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    api.upload_folder(
        folder_path=output_dir,
        repo_id=repo_id,
        commit_message="Upload GRPO fine-tuned model",
    )
    tokenizer.push_to_hub(repo_id)
    logger.info("Upload complete! HF repo: https://huggingface.co/%s", repo_id)
# -------------------------------------------------------------
# MAIN
# -------------------------------------------------------------
@hydra.main(version_base=None, config_path="../../config/grpo", config_name="config")
def main(cfg: DictConfig):
    """Entry point: train a GRPO LoRA model, then optionally push it to the Hub.

    Pipeline: validate dataset path -> load dataset -> build (optionally
    resumed) LoRA model -> configure and run GRPO training -> save -> push.
    """
    logger.info("Configuration:\n%s", OmegaConf.to_yaml(cfg))

    # Fail fast with a clear message instead of a deep stack trace later.
    if not Path(cfg.dataset.file_path).exists():
        logger.error("Dataset CSV file does not exist: %s", cfg.dataset.file_path)
        return
    os.makedirs(cfg.output_dir, exist_ok=True)

    # Load dataset
    train_dataset = load_train_dataset(cfg.dataset)

    # Model — optionally resume from an SFT LoRA adapter (merged into base).
    resume_sft = cfg.get("resume_from_checkpoint_sft", None)
    model = create_lora_model(cfg.model, resume_sft)

    # Trainer
    training_args = create_grpo_config(cfg.training, cfg.output_dir)
    trainer = create_trainer(model, train_dataset, training_args)

    # Train. Use cfg.get for optional keys (consistent with
    # resume_from_checkpoint_sft above) so a missing entry falls back to a
    # sane default instead of raising an attribute error.
    train_and_save(
        trainer,
        cfg.output_dir,
        resume_from_checkpoint=cfg.get("resume_from_checkpoint_grpo", None),
        save_before_training=cfg.get("save_before_training", True),
    )

    # Push to HF only when explicitly enabled in the config.
    if cfg.get("push_to_hub", False):
        push_to_huggingface(
            output_dir=cfg.output_dir,
            repo_id=cfg.hf_repo_id,
            model_id=cfg.model.model_id,
        )
# Script entry point: Hydra composes the config (and applies CLI overrides)
# before invoking main().
if __name__ == "__main__":
    main()