# sotopia_rl/ppo_trainer.py
import os
import torch
import wandb
from accelerate import Accelerator, PartialState
from jinja2 import Environment, FileSystemLoader
from peft import PeftModelForCausalLM, PeftModelForSequenceClassification
from torch.utils.data import random_split
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from trl import PPOConfig, PPOTrainer, get_kbit_device_map
from sotopia_rl.data import PPODataset
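# Descriptive note (added): NCCL_P2P_DISABLE=1 turns off NCCL's peer-to-peer transport,
# commonly used to work around hangs on hosts where GPU P2P is unreliable, and
# TORCH_DISTRIBUTED_DEBUG=DETAIL makes torch.distributed emit extra diagnostics
# (e.g. about unused parameters) for the multi-model DDP setup used here.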
os.environ['NCCL_P2P_DISABLE'] = '1'
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
class SotopiaPPOTrainer:
    """Wraps TRL's PPOTrainer with Sotopia-specific tokenizer, dataset, and adapter setup."""

    def __init__(self, args, accelerator: Accelerator):
        self.accelerator = accelerator
        self.args = args

        self._init_wandb()
        self._setup_tokenizer()
        self._setup_dataset()
        self._create_quantization_config()
        self._setup_generation_models()
        self._setup_classification_models()

        # Let accelerate place the four models, then unwrap them again so the
        # PPOTrainer below is handed plain modules rather than DDP wrappers.
        self.policy, self.ref_policy, self.reward_model, self.value_model = self.accelerator.prepare(
            self.policy, self.ref_policy, self.reward_model, self.value_model
        )
        self.policy = self.accelerator.unwrap_model(self.policy)
        self.ref_policy = self.accelerator.unwrap_model(self.ref_policy)
        self.reward_model = self.accelerator.unwrap_model(self.reward_model)
        self.value_model = self.accelerator.unwrap_model(self.value_model)

        self._setup_ppo_trainer()
        def save_model(self, output_dir: str, _internal_call: bool = False):
            # Bound onto the PPOTrainer below, so `self` here is the PPOTrainer
            # instance: save only the policy, unwrapping a possible DDP `.module`.
            if hasattr(self.model, "policy"):
                model_to_save = self.model.policy
            elif hasattr(self.model, "module") and hasattr(self.model.module, "policy"):
                model_to_save = self.model.module.policy
            else:
                model_to_save = self.model
            model_to_save.save_pretrained(output_dir)
            self.tokenizer.save_pretrained(output_dir)
            print(f"Model saved to {output_dir}")

        # Override PPOTrainer.save_model so checkpoints contain only the policy weights.
        self.ppo_trainer.save_model = save_model.__get__(self.ppo_trainer, type(self.ppo_trainer))
    def _init_wandb(self):
        # Log only scalar and string hyperparameters to the wandb run config.
        wandb.init(
            project=self.args.wandb_project,
            name=self.args.wandb_run_name,
            config={k: v for k, v in vars(self.args).items() if isinstance(v, (int, float, str))},
        )
    def _setup_tokenizer(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.args.model_name, padding_side="left")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    def _setup_dataset(self):
        # Let the local main process build the dataset first; the other ranks wait here.
        with PartialState().local_main_process_first():
            template_dir = "/".join(self.args.template_path.split("/")[:-1])
            template_file = self.args.template_path.split("/")[-1]
            env = Environment(loader=FileSystemLoader(template_dir))
            template = env.get_template(template_file)

            # Create and split the dataset.
            dataset = PPODataset(
                self.args.ppo_data_path,
                self.tokenizer,
                template,
                max_length=self.args.max_length,
            )
            print(f"dataset size: {len(dataset)}")

            generator = torch.Generator().manual_seed(42)
            val_ratio = getattr(self.args, 'val_ratio', 0.05)
            # Keep at least two validation examples, even for tiny datasets.
            train_size = min(int(len(dataset) * (1 - val_ratio)), len(dataset) - 2)
            val_size = len(dataset) - train_size
            self.train_dataset, self.val_dataset = random_split(dataset, [train_size, val_size], generator=generator)
            print(f"Dataset split: {len(self.train_dataset)} train, {len(self.val_dataset)} validation")
    def _create_quantization_config(self):
        # 4-bit NF4 quantization with double quantization; compute in bfloat16.
        self.quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    def _setup_generation_models(self):
        # Frozen reference policy: quantized base model + reference LoRA adapter.
        base_gen_ref = AutoModelForCausalLM.from_pretrained(
            self.args.model_name,
            torch_dtype='auto',
            quantization_config=self.quant_config,
            device_map=get_kbit_device_map(),
        )
        self.ref_policy = PeftModelForCausalLM.from_pretrained(
            base_gen_ref,
            self.args.ref_adapter_path,
            is_trainable=False,
            adapter_name="ref_adapter"
        )

        if self.args.use_lora_train_ppo:
            # Trainable policy: quantized base model + trainable LoRA adapter.
            base_gen_policy = AutoModelForCausalLM.from_pretrained(
                self.args.model_name,
                torch_dtype='auto',
                quantization_config=self.quant_config,
                device_map=get_kbit_device_map(),
            )
            self.policy = PeftModelForCausalLM.from_pretrained(
                base_gen_policy,
                self.args.policy_adapter_path,
                is_trainable=True,
                adapter_name="policy_adapter"
            )
        else:
            # Full-parameter policy training without quantization or LoRA.
            self.policy = AutoModelForCausalLM.from_pretrained(
                self.args.model_name,
                torch_dtype='auto',
            )

        requires_grad_num = 0
        for name, param in self.policy.named_parameters():
            print(name, param.requires_grad)
            if param.requires_grad:
                requires_grad_num += 1
        print(f"Number of trainable parameters in policy: {requires_grad_num}")

        requires_grad_num = 0
        for name, param in self.ref_policy.named_parameters():
            if param.requires_grad:
                requires_grad_num += 1
        print(f"Number of trainable parameters in ref policy: {requires_grad_num}")
    def _setup_classification_models(self):
        # Frozen reward model: quantized base model + reward LoRA adapter.
        base_reward_model = AutoModelForSequenceClassification.from_pretrained(
            self.args.model_name,
            torch_dtype='auto',
            num_labels=1,
            quantization_config=self.quant_config,
            device_map=get_kbit_device_map(),
        )
        self.reward_model = PeftModelForSequenceClassification.from_pretrained(
            base_reward_model,
            self.args.reward_adapter_path,
            is_trainable=False,
            adapter_name="reward_adapter"
        )
        # Make sure the reward adapter weights stay frozen during PPO updates.
        for name, param in self.reward_model.named_parameters():
            if self.reward_model.active_adapter in name:
                param.requires_grad = False

        if self.args.use_lora_train_ppo:
            # Trainable value model: quantized base model + trainable value LoRA adapter.
            base_value_model = AutoModelForSequenceClassification.from_pretrained(
                self.args.model_name,
                torch_dtype='auto',
                num_labels=1,
                quantization_config=self.quant_config,
                device_map=get_kbit_device_map(),
            )
            self.value_model = PeftModelForSequenceClassification.from_pretrained(
                base_value_model,
                self.args.value_adapter_path,
                is_trainable=True,
                adapter_name="value_adapter"
            )
        else:
            # Full-parameter value model without quantization or LoRA.
            self.value_model = AutoModelForSequenceClassification.from_pretrained(
                self.args.model_name,
                torch_dtype='auto',
                num_labels=1,
            )
        # Important: this is tailored to TRL's PPO training loop. The get_reward
        # helper fills the padded positions of input_ids with 0s, so both score
        # models must use pad_token_id = 0 to recognize those positions as padding.
        self.value_model.config.pad_token_id = 0
        self.reward_model.config.pad_token_id = 0

        requires_grad_num = 0
        for name, param in self.value_model.named_parameters():
            if param.requires_grad:
                requires_grad_num += 1
        print(f"Number of trainable parameters in value model: {requires_grad_num}")

        requires_grad_num = 0
        for name, param in self.reward_model.named_parameters():
            if param.requires_grad:
                requires_grad_num += 1
        print(f"Number of trainable parameters in reward model: {requires_grad_num}")
    def _setup_ppo_trainer(self):
        training_args = PPOConfig(
            per_device_train_batch_size=self.args.per_device_train_batch_size,
            per_device_eval_batch_size=self.args.per_device_eval_batch_size,
            num_mini_batches=self.args.num_mini_batches,
            local_rollout_forward_batch_size=self.args.local_rollout_forward_batch_size,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
            num_train_epochs=self.args.num_train_epochs,
            num_ppo_epochs=self.args.num_ppo_epochs,
            learning_rate=self.args.learning_rate,
            output_dir=self.args.output_dir,
            gamma=self.args.gamma,
            lam=self.args.lam,
            save_steps=self.args.save_steps,
            missing_eos_penalty=self.args.missing_eos_penalty,
            ddp_find_unused_parameters=True,
            response_length=self.args.response_length,
            stop_token='eos',
        )
        self.ppo_trainer = PPOTrainer(
            args=training_args,
            model=self.policy,
            processing_class=self.tokenizer,
            ref_model=self.ref_policy,
            reward_model=self.reward_model,
            value_model=self.value_model,
            train_dataset=self.train_dataset,
            eval_dataset=self.val_dataset,
        )
        print("PPOTrainer setup complete")
    def train(self):
        try:
            print("Starting PPO training...")
            train_stats = self.ppo_trainer.train()
            return train_stats
        except Exception as e:
            print(f"Training error: {str(e)}")
            raise
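
# Usage sketch (added for illustration, not part of the original upload): a minimal,
# hedged example of how SotopiaPPOTrainer might be driven from a launch script. The
# attribute names mirror the `self.args.*` accesses above; the concrete values, paths,
# and model name are placeholders/assumptions, and the real project may instead use an
# argparse-based entry point launched via `accelerate launch`.
if __name__ == "__main__":
    from types import SimpleNamespace

    example_args = SimpleNamespace(
        # wandb logging
        wandb_project="sotopia-rl",
        wandb_run_name="ppo-example",
        # base model and adapters (hypothetical identifiers/paths)
        model_name="meta-llama/Llama-3.1-8B-Instruct",
        use_lora_train_ppo=True,
        policy_adapter_path="checkpoints/policy_adapter",
        ref_adapter_path="checkpoints/ref_adapter",
        reward_adapter_path="checkpoints/reward_adapter",
        value_adapter_path="checkpoints/value_adapter",
        # data and prompt template
        ppo_data_path="data/ppo_data.json",
        template_path="templates/ppo.jinja",
        max_length=2048,
        val_ratio=0.05,
        # PPOConfig fields consumed in _setup_ppo_trainer (placeholder values)
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        num_mini_batches=1,
        local_rollout_forward_batch_size=1,
        gradient_accumulation_steps=8,
        num_train_epochs=1,
        num_ppo_epochs=2,
        learning_rate=1e-5,
        output_dir="outputs/ppo",
        gamma=1.0,
        lam=0.95,
        save_steps=100,
        missing_eos_penalty=1.0,
        response_length=256,
    )

    trainer = SotopiaPPOTrainer(example_args, Accelerator())
    trainer.train()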