"""Train a model with GRPO using a reward model."""

import argparse
import os

# Disable compilation in transformers; set here, before any transformers import.
os.environ["TRANSFORMERS_NO_COMPILE"] = "1"

from accelerate import Accelerator

from sotopia_rl import SotopiaGRPOTrainer

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a model with GRPO using a reward model.")
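
    # Base model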
    parser.add_argument("--model_name", type=str, default="/data/models/gemma-2-2b-it",
                        help="Path to the model")
    parser.add_argument("--per_device_train_batch_size", type=int, default=1,
                        help="Batch size per device for training")
    parser.add_argument("--per_device_eval_batch_size", type=int, default=1,
                        help="Batch size per device for evaluation")
    parser.add_argument("--num_train_epochs", type=int, default=3,
                        help="Number of training epochs")
    parser.add_argument("--num_grpo_epochs", type=int, default=4,
                        help="Number of GRPO epochs per update")
    parser.add_argument("--learning_rate", type=float, default=5e-6,
                        help="Learning rate for the optimizer")
    parser.add_argument("--max_length", type=int, default=4096,
                        help="Maximum length of input sequences")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
                        help="Number of steps to accumulate gradients before performing an update")
    parser.add_argument("--val_ratio", type=float, default=0.05,
                        help="Fraction of the data held out for validation")
    parser.add_argument("--response_length", type=int, default=128,
                        help="Maximum length of generated responses")
    parser.add_argument("--num_generations", type=int, default=4,
                        help="Number of generations per prompt for GRPO")
    parser.add_argument("--beta", type=float, default=0.04,
                        help="KL penalty coefficient (beta) for GRPO")
    parser.add_argument("--policy_adapter_path", type=str, default=None,
                        help="Path to policy model adapter")
    parser.add_argument("--reward_adapter_path", type=str, default=None,
                        help="Path to reward model adapter")
    parser.add_argument("--grpo_data_path", type=str, required=True,
                        help="Path to the GRPO training data file")
    parser.add_argument("--template_path", type=str, required=True,
                        help="Path to the Jinja template file")
    parser.add_argument("--output_dir", type=str, default="checkpoints",
                        help="Directory to save the best LoRA checkpoint")
    parser.add_argument("--save_steps", type=int, default=50,
                        help="Number of steps between saving checkpoints")
    parser.add_argument("--wandb_project", type=str, default="grpo-model-training",
                        help="Wandb project name")
    parser.add_argument("--wandb_run_name", type=str, default=None,
                        help="Wandb run name")
    parser.add_argument("--use_lora_train_grpo", action="store_true",
                        help="Use LoRA when training with GRPO")

    args = parser.parse_args()

    accelerator = Accelerator()
    trainer = SotopiaGRPOTrainer(args, accelerator)
    trainer.train()
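
# Example invocation (illustrative; the script file name and the data/template
# paths below are placeholders, not taken from the repository):
#
#   accelerate launch train_grpo.py \
#       --grpo_data_path data/grpo_data.json \
#       --template_path templates/chat_template.jinja \
#       --use_lora_train_grpo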