# Dat1710's picture
# Upload folder using huggingface_hub
# 00db46c verified
#!/usr/bin/env python3
"""
GRPO training script for arithmetic countdown problems using Hydra configuration.
After training, the model is automatically pushed to HuggingFace Hub.
"""
import logging
import os
from collections.abc import Callable
from pathlib import Path
import sys
# ---------------- FIX IMPORT WHEN USING HYDRA ----------------
# Make the project root importable before the `src.*` imports below —
# presumably needed because Hydra changes the working directory at launch,
# which would otherwise break relative package resolution (TODO confirm).
FILE_DIR = os.path.dirname(os.path.abspath(__file__)) # src/training/grpo
PROJECT_ROOT = os.path.abspath(os.path.join(FILE_DIR, "../../../"))
sys.path.insert(0, PROJECT_ROOT)
import hydra
from datasets import Dataset
from omegaconf import DictConfig, OmegaConf
from transformers import AutoModelForCausalLM, PreTrainedModel, AutoTokenizer
from huggingface_hub import HfApi, login, create_repo
from peft import LoraConfig, PeftModel, get_peft_model
from trl import GRPOConfig, GRPOTrainer
from src.dataset import load_csv_dataset_grpo
from src.dataset.grpo import map_problem_description_to_conversation_grpo
from src.utils.rewards import mathematical_correctness_reward_function
# Logging: timestamped INFO-level messages on the root logger; all module
# functions below log through the named "grpo_training" logger.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("grpo_training")
# -------------------------------------------------------------
# DATASET
# -------------------------------------------------------------
def load_train_dataset(cfg: DictConfig) -> Dataset:
    """Load, shuffle, and truncate the GRPO training dataset.

    Reads the CSV at ``cfg.file_path`` for ``cfg.split``, maps each row to a
    GRPO conversation, shuffles with ``cfg.seed``, and returns at most
    ``cfg.max_rows`` examples.
    """
    dataset: Dataset = load_csv_dataset_grpo(
        cfg.file_path, cfg.split, map_problem_description_to_conversation_grpo
    )
    shuffled = dataset.shuffle(seed=cfg.seed)
    n_rows = min(cfg.max_rows, len(shuffled))
    return shuffled.select(range(n_rows))
# -------------------------------------------------------------
# MODEL (LoRA + resume)
# -------------------------------------------------------------
def create_lora_model(cfg: DictConfig, resume_from_checkpoint: str | None = None) -> PreTrainedModel:
    """Build the base causal LM and wrap it with a fresh LoRA adapter.

    If ``resume_from_checkpoint`` is a non-empty path to an existing
    directory, the adapter stored there is first merged into the base
    weights, so the new LoRA adapter trains on top of the resumed model.
    """
    base = AutoModelForCausalLM.from_pretrained(cfg.model_id, device_map=cfg.device_map)
    # Truthiness check (not `is not None`) deliberately skips the empty string.
    if resume_from_checkpoint and Path(resume_from_checkpoint).exists():
        logger.info("Loading existing LoRA adapter and merging: %s", resume_from_checkpoint)
        base = PeftModel.from_pretrained(base, resume_from_checkpoint).merge_and_unload()
    adapter_cfg = LoraConfig(
        r=cfg.lora.r,
        lora_alpha=cfg.lora.lora_alpha,
        target_modules=OmegaConf.to_container(cfg.lora.target_modules),
        lora_dropout=cfg.lora.lora_dropout,
        bias=cfg.lora.bias,
        task_type=cfg.lora.task_type,
    )
    return get_peft_model(base, adapter_cfg)
# -------------------------------------------------------------
# GRPO CONFIG
# -------------------------------------------------------------
def create_grpo_config(cfg: DictConfig, output_dir: str) -> GRPOConfig:
    """Translate the Hydra training section into a ``GRPOConfig``.

    All hyperparameters come straight from ``cfg``; only ``output_dir`` is
    supplied by the caller.
    """
    hyperparams = {
        "learning_rate": cfg.learning_rate,
        "warmup_ratio": cfg.warmup_ratio,
        "weight_decay": cfg.weight_decay,
        "lr_scheduler_type": cfg.lr_scheduler_type,
        "optim": cfg.optim,
        "remove_unused_columns": cfg.remove_unused_columns,
        "gradient_accumulation_steps": cfg.gradient_accumulation_steps,
        "num_train_epochs": cfg.num_train_epochs,
        "bf16": cfg.bf16,
        "per_device_train_batch_size": cfg.per_device_train_batch_size,
        # Generation/sampling settings used for GRPO rollouts.
        "temperature": cfg.temperature,
        "max_completion_length": cfg.max_completion_length,
        "num_generations": cfg.num_generations,
        "max_prompt_length": cfg.max_prompt_length,
        # Reporting / checkpointing.
        "report_to": cfg.report_to,
        "logging_steps": cfg.logging_steps,
        "save_strategy": cfg.save_strategy,
        "save_steps": cfg.save_steps,
    }
    return GRPOConfig(output_dir=output_dir, **hyperparams)
# -------------------------------------------------------------
# TRAINER
# -------------------------------------------------------------
def create_trainer(model, train_dataset, args):
    """Assemble a ``GRPOTrainer`` using the mathematical-correctness reward."""
    return GRPOTrainer(
        model=model,
        reward_funcs=[mathematical_correctness_reward_function],
        args=args,
        train_dataset=train_dataset,
    )
# -------------------------------------------------------------
# TRAIN + SAVE
# -------------------------------------------------------------
def train_and_save(trainer, output_dir, resume_from_checkpoint=None, save_before_training=True):
    """Run training and persist the final model to ``output_dir``.

    When ``save_before_training`` is true, an initial snapshot is written
    before the first optimizer step, so the output directory is populated
    even if training is interrupted early.
    """
    if save_before_training:
        trainer.save_model(output_dir)
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    trainer.save_model(output_dir)
    logger.info("Training completed.")
    logger.info("Saved final model to: %s", output_dir)
# -------------------------------------------------------------
# PUSH TO HUGGINGFACE HUB
# -------------------------------------------------------------
def push_to_huggingface(output_dir: str, repo_id: str, model_id: str) -> None:
    """Upload the trained model folder and its tokenizer to the HuggingFace Hub.

    Assumes the process is already authenticated (login must happen before
    training). The tokenizer is pushed alongside the weights so the resulting
    repo is directly loadable.

    Args:
        output_dir: Local directory containing the saved model files.
        repo_id: Target Hub repository, e.g. ``"user/model-name"``.
        model_id: Base model id used to load the matching tokenizer.
    """
    logger.info("Pushing model to HuggingFace Hub...")
    # Login must be done BEFORE training
    api = HfApi()
    # Create the repo if it does not exist yet. Failures here are non-fatal
    # (the repo may already exist), but a bare `except: pass` would also
    # swallow SystemExit/KeyboardInterrupt and hide real errors — narrow it
    # and log instead.
    try:
        api.create_repo(repo_id, exist_ok=True)
    except Exception as exc:
        logger.warning("create_repo failed (continuing with upload): %s", exc)
    # Load tokenizer (important!) — pushed below so the repo is self-contained.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    api.upload_folder(
        folder_path=output_dir,
        repo_id=repo_id,
        commit_message="Upload GRPO fine-tuned model",
    )
    tokenizer.push_to_hub(repo_id)
    logger.info("Upload complete! HF repo: https://huggingface.co/%s", repo_id)
# -------------------------------------------------------------
# MAIN
# -------------------------------------------------------------
@hydra.main(version_base=None, config_path="../../config/grpo", config_name="config")
def main(cfg: DictConfig):
    """Entry point: GRPO-train a LoRA model per the Hydra config, then optionally push to the Hub."""
    logger.info("Configuration:\n%s", OmegaConf.to_yaml(cfg))

    # Guard clause: bail out early if the dataset file is missing.
    if not Path(cfg.dataset.file_path).exists():
        logger.error("Dataset CSV file does not exist: %s", cfg.dataset.file_path)
        return

    os.makedirs(cfg.output_dir, exist_ok=True)

    # Data and model.
    train_dataset = load_train_dataset(cfg.dataset)
    model = create_lora_model(cfg.model, cfg.get("resume_from_checkpoint_sft", None))

    # Trainer wiring.
    training_args = create_grpo_config(cfg.training, cfg.output_dir)
    trainer = create_trainer(model, train_dataset, training_args)

    # Train + save.
    train_and_save(
        trainer,
        cfg.output_dir,
        resume_from_checkpoint=cfg.resume_from_checkpoint_grpo,
        save_before_training=cfg.save_before_training,
    )

    # Optional Hub upload.
    if cfg.get("push_to_hub", False):
        push_to_huggingface(
            output_dir=cfg.output_dir,
            repo_id=cfg.hf_repo_id,
            model_id=cfg.model.model_id,
        )


if __name__ == "__main__":
    main()