import os
import torch
import wandb
from accelerate import Accelerator, PartialState
from jinja2 import Environment, FileSystemLoader
from peft import PeftModelForCausalLM, PeftModelForSequenceClassification
from torch.utils.data import random_split
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoTokenizer,
BitsAndBytesConfig,
)
from trl import PPOConfig, PPOTrainer, get_kbit_device_map
from sotopia_rl.data import PPODataset
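# Disable NCCL peer-to-peer transfers (a common workaround for multi-GPU
# communication hangs) and enable verbose torch.distributed debug output.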
os.environ['NCCL_P2P_DISABLE'] = '1'
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
class SotopiaPPOTrainer:
def __init__(self, args, accelerator: Accelerator):
self.accelerator = accelerator
self.args = args
self._init_wandb()
self._setup_tokenizer()
self._setup_dataset()
self._create_quantization_config()
self._setup_generation_models()
self._setup_classification_models()
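        # Prepare the models with Accelerate so they land on the correct devices,
        # then unwrap them again: PPOTrainer expects plain modules and does its
        # own distributed wrapping.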
self.policy, self.ref_policy, self.reward_model, self.value_model = self.accelerator.prepare(
self.policy, self.ref_policy, self.reward_model, self.value_model
)
self.policy = self.accelerator.unwrap_model(self.policy)
self.ref_policy = self.accelerator.unwrap_model(self.ref_policy)
self.reward_model = self.accelerator.unwrap_model(self.reward_model)
self.value_model = self.accelerator.unwrap_model(self.value_model)
self._setup_ppo_trainer()
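        # Replacement save_model that is bound onto the PPOTrainer instance below:
        # it saves only the trained policy (and the tokenizer) rather than the
        # combined policy/value wrapper that PPOTrainer trains internally.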
def save_model(self, output_dir: str, _internal_call: bool = False):
if hasattr(self.model, "policy"):
model_to_save = self.model.policy
elif hasattr(self.model, "module") and hasattr(self.model.module, "policy"):
model_to_save = self.model.module.policy
else:
model_to_save = self.model
model_to_save.save_pretrained(output_dir)
self.tokenizer.save_pretrained(output_dir)
print(f"Model saved to {output_dir}")
self.ppo_trainer.save_model = save_model.__get__(self.ppo_trainer, type(self.ppo_trainer))
def _init_wandb(self):
wandb.init(
project=self.args.wandb_project,
name=self.args.wandb_run_name,
config={k: v for k, v in vars(self.args).items() if isinstance(v, (int, float, str))}
)
def _setup_tokenizer(self):
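        # Decoder-only models generate from the right, so pad on the left; add an
        # explicit [PAD] token in case the base tokenizer does not define one.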
self.tokenizer = AutoTokenizer.from_pretrained(self.args.model_name, padding_side="left")
self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
def _setup_dataset(self):
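        # Build the dataset on each node's main process first so preprocessing
        # and caching are not duplicated across local workers.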
with PartialState().local_main_process_first():
            template_dir = os.path.dirname(self.args.template_path)
            template_file = os.path.basename(self.args.template_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
# Create and split dataset
dataset = PPODataset(
self.args.ppo_data_path,
self.tokenizer,
template,
max_length=self.args.max_length
)
print(f"dataset: {len(dataset)}")
generator = torch.Generator().manual_seed(42)
val_ratio = getattr(self.args, 'val_ratio', 0.05)
train_size = min(int(len(dataset) * (1 - val_ratio)), len(dataset) - 2)
val_size = len(dataset) - train_size
self.train_dataset, self.val_dataset = random_split(dataset, [train_size, val_size], generator=generator)
print(f"Dataset split: {len(self.train_dataset)} train, {len(self.val_dataset)} validation")
def _create_quantization_config(self):
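        # 4-bit NF4 quantization with double quantization and bfloat16 compute,
        # shared by every quantized base model loaded below.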
self.quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
)
def _setup_generation_models(self):
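        # Reference policy: a frozen LoRA adapter on a 4-bit quantized base.
        # The trainable policy is either a second LoRA adapter on its own
        # quantized base (use_lora_train_ppo) or a fully trainable
        # full-precision copy of the base model.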
base_gen_ref = AutoModelForCausalLM.from_pretrained(
self.args.model_name,
torch_dtype='auto',
quantization_config=self.quant_config,
device_map=get_kbit_device_map(),
)
self.ref_policy = PeftModelForCausalLM.from_pretrained(
base_gen_ref,
self.args.ref_adapter_path,
is_trainable=False,
adapter_name="ref_adapter"
)
if self.args.use_lora_train_ppo:
base_gen_policy = AutoModelForCausalLM.from_pretrained(
self.args.model_name,
torch_dtype='auto',
quantization_config=self.quant_config,
device_map=get_kbit_device_map(),
)
self.policy = PeftModelForCausalLM.from_pretrained(
base_gen_policy,
self.args.policy_adapter_path,
is_trainable=True,
adapter_name="policy_adapter"
)
else:
self.policy = AutoModelForCausalLM.from_pretrained(
self.args.model_name,
torch_dtype='auto',
)
requires_grad_num = 0
for name, param in self.policy.named_parameters():
print(name, param.requires_grad)
if param.requires_grad:
requires_grad_num += 1
print(f"Number of trainable parameters in policy: {requires_grad_num}")
requires_grad_num = 0
for name, param in self.ref_policy.named_parameters():
if param.requires_grad:
requires_grad_num += 1
print(f"Number of trainable parameters in ref policy: {requires_grad_num}")
def _setup_classification_models(self):
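        # Reward model: a frozen LoRA adapter on a quantized scalar-head
        # (num_labels=1) base, used only to score rollouts. The value model
        # mirrors the policy setup: LoRA on a quantized base or a fully
        # trainable copy.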
base_reward_model = AutoModelForSequenceClassification.from_pretrained(
self.args.model_name,
torch_dtype='auto',
num_labels=1,
quantization_config=self.quant_config,
device_map=get_kbit_device_map(),
)
self.reward_model = PeftModelForSequenceClassification.from_pretrained(
base_reward_model,
self.args.reward_adapter_path,
is_trainable=False,
adapter_name="reward_adapter"
)
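        # Explicitly freeze the reward adapter's weights; the reward model is
        # never updated during PPO.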
for name, param in self.reward_model.named_parameters():
if self.reward_model.active_adapter in name:
param.requires_grad = False
if self.args.use_lora_train_ppo:
base_value_model = AutoModelForSequenceClassification.from_pretrained(
self.args.model_name,
torch_dtype='auto',
num_labels=1,
quantization_config=self.quant_config,
device_map=get_kbit_device_map(),
)
self.value_model = PeftModelForSequenceClassification.from_pretrained(
base_value_model,
self.args.value_adapter_path,
is_trainable=True,
adapter_name="value_adapter"
)
else:
self.value_model = AutoModelForSequenceClassification.from_pretrained(
self.args.model_name,
torch_dtype='auto',
num_labels=1,
)
        # Important for PPO training: TRL's get_reward utility masks padded
        # positions in input_ids by filling them with 0, so both scalar-head
        # models must treat token id 0 as their pad token.
self.value_model.config.pad_token_id = 0
self.reward_model.config.pad_token_id = 0
requires_grad_num = 0
for name, param in self.value_model.named_parameters():
if param.requires_grad:
requires_grad_num += 1
print(f"Number of trainable parameters in value model: {requires_grad_num}")
requires_grad_num = 0
for name, param in self.reward_model.named_parameters():
if param.requires_grad:
requires_grad_num += 1
print(f"Number of trainable parameters in reward model: {requires_grad_num}")
def _setup_ppo_trainer(self):
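        # Mirror the command-line arguments into TRL's PPOConfig; generation
        # stops at the EOS token and responses that never emit EOS are penalized.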
training_args = PPOConfig(
per_device_train_batch_size=self.args.per_device_train_batch_size,
per_device_eval_batch_size=self.args.per_device_eval_batch_size,
num_mini_batches=self.args.num_mini_batches,
local_rollout_forward_batch_size=self.args.local_rollout_forward_batch_size,
gradient_accumulation_steps=self.args.gradient_accumulation_steps,
num_train_epochs=self.args.num_train_epochs,
num_ppo_epochs=self.args.num_ppo_epochs,
learning_rate=self.args.learning_rate,
output_dir=self.args.output_dir,
gamma=self.args.gamma,
lam=self.args.lam,
save_steps=self.args.save_steps,
missing_eos_penalty=self.args.missing_eos_penalty,
ddp_find_unused_parameters=True,
response_length=self.args.response_length,
stop_token='eos',
)
self.ppo_trainer = PPOTrainer(
args=training_args,
model=self.policy,
processing_class=self.tokenizer,
ref_model=self.ref_policy,
reward_model=self.reward_model,
value_model=self.value_model,
train_dataset=self.train_dataset,
eval_dataset=self.val_dataset,
)
print("PPOtrainer setup complete")
def train(self):
try:
print("Starting PPO training...")
train_stats = self.ppo_trainer.train()
return train_stats
except Exception as e:
print(f"Training error: {str(e)}")
raise
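# Usage sketch (hypothetical launcher, not part of this module). The attribute
# names on `args` mirror what this class reads above; the import path is an
# assumption and may differ from the actual package layout.
#
#   from accelerate import Accelerator
#   from sotopia_rl import SotopiaPPOTrainer
#
#   accelerator = Accelerator()
#   trainer = SotopiaPPOTrainer(args, accelerator)  # args: parsed CLI namespace
#   trainer.train()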