|
|
|
|
|
"""
|
|
|
GRPO training script for arithmetic countdown problems using Hydra configuration.
|
|
|
After training, the model is automatically pushed to HuggingFace Hub.
|
|
|
"""
|
|
|
|
|
|
import logging
|
|
|
import os
|
|
|
from collections.abc import Callable
|
|
|
from pathlib import Path
|
|
|
import sys
|
|
|
|
|
|
|
|
|
FILE_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
|
PROJECT_ROOT = os.path.abspath(os.path.join(FILE_DIR, "../../../"))
|
|
|
sys.path.insert(0, PROJECT_ROOT)
|
|
|
|
|
|
import hydra
|
|
|
from datasets import Dataset
|
|
|
from omegaconf import DictConfig, OmegaConf
|
|
|
from transformers import AutoModelForCausalLM, PreTrainedModel, AutoTokenizer
|
|
|
from huggingface_hub import HfApi, login, create_repo
|
|
|
from peft import LoraConfig, PeftModel, get_peft_model
|
|
|
from trl import GRPOConfig, GRPOTrainer
|
|
|
|
|
|
from src.dataset import load_csv_dataset_grpo
|
|
|
from src.dataset.grpo import map_problem_description_to_conversation_grpo
|
|
|
from src.utils.rewards import mathematical_correctness_reward_function
|
|
|
|
|
|
|
|
|
# Configure root logging once at import time; every function in this script
# emits through the dedicated "grpo_training" logger below.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("grpo_training")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_train_dataset(cfg: DictConfig) -> Dataset:
    """Load, shuffle, and truncate the countdown training split.

    Reads the CSV at ``cfg.file_path`` for ``cfg.split``, maps each row to a
    GRPO conversation, shuffles deterministically with ``cfg.seed``, and keeps
    at most ``cfg.max_rows`` examples.
    """
    dataset = load_csv_dataset_grpo(
        cfg.file_path, cfg.split, map_problem_description_to_conversation_grpo
    ).shuffle(seed=cfg.seed)
    # Cap at max_rows without failing when the dataset is smaller than the cap.
    keep = min(cfg.max_rows, len(dataset))
    return dataset.select(range(keep))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_lora_model(cfg: DictConfig, resume_from_checkpoint: str | None = None) -> PreTrainedModel:
    """Build a LoRA-wrapped causal LM for GRPO training.

    Loads the base model from ``cfg.model_id``. If ``resume_from_checkpoint``
    points to an existing LoRA adapter (e.g. from a prior SFT stage), that
    adapter is loaded and merged into the base weights before a fresh,
    trainable GRPO LoRA adapter is attached.

    Args:
        cfg: Model section of the Hydra config (``model_id``, ``device_map``,
            and the ``lora.*`` hyperparameters).
        resume_from_checkpoint: Optional path to a previously trained LoRA
            adapter to merge into the base model.

    Returns:
        The PEFT-wrapped model carrying a new LoRA adapter.
    """
    model = AutoModelForCausalLM.from_pretrained(cfg.model_id, device_map=cfg.device_map)

    if resume_from_checkpoint:
        if Path(resume_from_checkpoint).exists():
            logger.info("Loading existing LoRA adapter and merging: %s", resume_from_checkpoint)
            model = PeftModel.from_pretrained(model, resume_from_checkpoint)
            # Fold the adapter weights into the base model so the new GRPO
            # adapter trains on top of the merged weights.
            model = model.merge_and_unload()
        else:
            # Previously a missing path was silently ignored; surface it so a
            # typo in the checkpoint path is not mistaken for a fresh run.
            logger.warning(
                "resume_from_checkpoint does not exist, starting from the base model: %s",
                resume_from_checkpoint,
            )

    lora_cfg = LoraConfig(
        r=cfg.lora.r,
        lora_alpha=cfg.lora.lora_alpha,
        # OmegaConf ListConfig -> plain list, as peft expects a native container.
        target_modules=OmegaConf.to_container(cfg.lora.target_modules),
        lora_dropout=cfg.lora.lora_dropout,
        bias=cfg.lora.bias,
        task_type=cfg.lora.task_type,
    )
    return get_peft_model(model, lora_cfg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_grpo_config(cfg: DictConfig, output_dir: str) -> GRPOConfig:
    """Translate the Hydra ``training`` section into a trl ``GRPOConfig``.

    The Hydra keys mirror the ``GRPOConfig`` field names one-to-one, so this
    is a straight copy of each hyperparameter plus the run's output directory.
    """
    hyperparams = dict(
        learning_rate=cfg.learning_rate,
        warmup_ratio=cfg.warmup_ratio,
        weight_decay=cfg.weight_decay,
        lr_scheduler_type=cfg.lr_scheduler_type,
        optim=cfg.optim,
        remove_unused_columns=cfg.remove_unused_columns,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        num_train_epochs=cfg.num_train_epochs,
        bf16=cfg.bf16,
        per_device_train_batch_size=cfg.per_device_train_batch_size,
        # Generation-side settings used for GRPO rollouts.
        temperature=cfg.temperature,
        max_completion_length=cfg.max_completion_length,
        num_generations=cfg.num_generations,
        max_prompt_length=cfg.max_prompt_length,
        # Logging / checkpointing.
        report_to=cfg.report_to,
        logging_steps=cfg.logging_steps,
        save_strategy=cfg.save_strategy,
        save_steps=cfg.save_steps,
    )
    return GRPOConfig(output_dir=output_dir, **hyperparams)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_trainer(model, train_dataset, args):
    """Assemble a ``GRPOTrainer`` scoring completions with the
    arithmetic-correctness reward function."""
    return GRPOTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        reward_funcs=[mathematical_correctness_reward_function],
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_and_save(trainer, output_dir, resume_from_checkpoint=None, save_before_training=True):
    """Run GRPO training and persist the final model to ``output_dir``.

    When ``save_before_training`` is set, a snapshot is written before the
    first optimization step as well, so ``output_dir`` is never empty even if
    training is interrupted.
    """
    if save_before_training:
        trainer.save_model(output_dir)

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    # Overwrite the pre-training snapshot (if any) with the trained weights.
    trainer.save_model(output_dir)
    logger.info("Training completed.")
    logger.info("Saved final model to: %s", output_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def push_to_huggingface(output_dir: str, repo_id: str, model_id: str):
    """Upload the trained model folder and its tokenizer to the HuggingFace Hub.

    Args:
        output_dir: Local directory containing the saved model/adapter files.
        repo_id: Target Hub repository (``user/name``).
        model_id: Base model identifier used to fetch the matching tokenizer.
    """
    logger.info("Pushing model to HuggingFace Hub...")

    api = HfApi()

    # ``exist_ok=True`` already makes re-creation a no-op, so only unexpected
    # failures (auth, network) can surface here.  The original bare
    # ``except: pass`` swallowed everything, including KeyboardInterrupt;
    # catch Exception and log it instead — upload_folder below still fails
    # loudly if the repo truly could not be created.
    try:
        api.create_repo(repo_id, exist_ok=True)
    except Exception as exc:
        logger.warning("create_repo failed (continuing): %s", exc)

    # Fetch the tokenizer first so a bad model_id fails before the large
    # folder upload starts.
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    api.upload_folder(
        folder_path=output_dir,
        repo_id=repo_id,
        commit_message="Upload GRPO fine-tuned model",
    )

    tokenizer.push_to_hub(repo_id)

    logger.info("Upload complete! HF repo: https://huggingface.co/%s", repo_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@hydra.main(version_base=None, config_path="../../config/grpo", config_name="config")
def main(cfg: DictConfig):
    """End-to-end GRPO run: dataset -> LoRA model -> train -> optional Hub push."""
    logger.info("Configuration:\n%s", OmegaConf.to_yaml(cfg))

    # Guard clause: fail fast with a clear message instead of erroring deep
    # inside dataset loading.
    if not Path(cfg.dataset.file_path).exists():
        logger.error("Dataset CSV file does not exist: %s", cfg.dataset.file_path)
        return

    os.makedirs(cfg.output_dir, exist_ok=True)

    train_dataset = load_train_dataset(cfg.dataset)

    # An SFT LoRA adapter, if configured, is merged into the base weights
    # before the GRPO adapter is attached.
    model = create_lora_model(cfg.model, cfg.get("resume_from_checkpoint_sft", None))

    trainer = create_trainer(
        model,
        train_dataset,
        create_grpo_config(cfg.training, cfg.output_dir),
    )

    train_and_save(
        trainer,
        cfg.output_dir,
        resume_from_checkpoint=cfg.resume_from_checkpoint_grpo,
        save_before_training=cfg.save_before_training,
    )

    if cfg.get("push_to_hub", False):
        push_to_huggingface(
            output_dir=cfg.output_dir,
            repo_id=cfg.hf_repo_id,
            model_id=cfg.model.model_id,
        )
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Hydra parses CLI overrides and injects the composed config into main().
    main()
|
|
|
|