"""Command-line entry point for GRPO training with a reward model.

Parses the training options and passes them, together with an
``Accelerator`` instance, to ``SotopiaGRPOTrainer``, then calls ``train()``.
"""
import argparse
import os
from typing import Optional, Sequence

# Kept at module level so it is set before accelerate / sotopia_rl are
# imported in main(), matching the original ordering of the script.
# NOTE(review): presumably this must be visible before transformers is
# loaded -- confirm before moving it.
os.environ["TRANSFORMERS_NO_COMPILE"] = "1"


def build_parser() -> argparse.ArgumentParser:
    """Return the argument parser describing every GRPO training option.

    Only ``--grpo_data_path`` and ``--template_path`` are required; all other
    options fall back to the defaults declared below.
    """
    parser = argparse.ArgumentParser(
        description="Train a model with GRPO using a reward model."
    )

    # Model and optimisation parameters
    parser.add_argument(
        "--model_name",
        type=str,
        default="/data/models/gemma-2-2b-it",
        help="Path to the model",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=1,
        help="Batch size per device for training",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=1,
        help="Batch size per device for evaluation",
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=3,
        help="Number of training epochs",
    )
    parser.add_argument(
        "--num_grpo_epochs",
        type=int,
        default=4,
        help="Number of GRPO epochs per update",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Learning rate for optimizer",
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=4096,
        help="Maximum length of input sequences",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of steps to accumulate gradients before performing an update",
    )
    parser.add_argument(
        "--val_ratio",
        type=float,
        default=0.05,
        help="Ratio of validation data",
    )
    parser.add_argument(
        "--response_length",
        type=int,
        default=128,
        help="Maximum length of generated responses",
    )
    parser.add_argument(
        "--num_generations",
        type=int,
        default=4,
        help="Number of generations for GRPO",
    )
    parser.add_argument(
        "--beta",
        type=float,
        default=0.04,
        help="Beta parameter for GRPO",
    )

    # Adapter parameters
    parser.add_argument(
        "--policy_adapter_path",
        type=str,
        default=None,
        help="Path to policy model adapter",
    )
    parser.add_argument(
        "--reward_adapter_path",
        type=str,
        default=None,
        help="Path to reward model adapter",
    )

    # Data and checkpoint paths
    parser.add_argument(
        "--grpo_data_path",
        type=str,
        required=True,
        help="Path to the reward data file",
    )
    parser.add_argument(
        "--template_path",
        type=str,
        required=True,
        help="Path to the Jinja template file",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="checkpoints",
        help="Directory to save the best LoRA checkpoint",
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=50,
        help="Number of steps between saving checkpoints",
    )

    # Logging parameters
    parser.add_argument(
        "--wandb_project",
        type=str,
        default="grpo-model-training",
        help="Wandb project name",
    )
    parser.add_argument(
        "--wandb_run_name",
        type=str,
        default=None,
        help="Wandb run name",
    )
    parser.add_argument(
        "--use_lora_train_grpo",
        action="store_true",
        help="Use LoRA for training GRPO",
    )

    return parser


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Parse the command line and run GRPO training.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as the script did originally.
    """
    args = build_parser().parse_args(argv)

    # Imported only after the arguments parsed successfully, so ``--help``
    # and usage errors are reported without first loading these packages.
    from accelerate import Accelerator
    from sotopia_rl import SotopiaGRPOTrainer

    accelerator = Accelerator()
    trainer = SotopiaGRPOTrainer(args, accelerator)
    trainer.train()


if __name__ == "__main__":
    main()