|
|
import os |
|
|
|
|
|
import torch |
|
|
import torch._dynamo |
|
|
import wandb |
|
|
from jinja2 import Environment, FileSystemLoader |
|
|
from peft import LoraConfig, PeftModelForSequenceClassification |
|
|
from torch.nn import MSELoss |
|
|
from torch.utils.data import random_split |
|
|
from transformers import ( |
|
|
AutoModelForSequenceClassification, |
|
|
AutoTokenizer, |
|
|
Trainer, |
|
|
TrainingArguments, |
|
|
) |
|
|
|
|
|
from sotopia_rl.data import RMDataset |
|
|
|
|
|
# Fall back to eager execution instead of raising when torch.compile /
# dynamo tracing hits an unsupported construct.
torch._dynamo.config.suppress_errors = True

# Process-wide runtime knobs (must be set before CUDA/NCCL initialization):
#   NCCL_P2P_DISABLE=1          — disable peer-to-peer GPU transfers.
#   PYTORCH_CUDA_ALLOC_CONF     — allow the CUDA caching allocator to grow
#                                 segments, reducing fragmentation OOMs.
os.environ.update(
    {
        "NCCL_P2P_DISABLE": "1",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    }
)
|
|
|
|
|
|
|
|
class SotopiaRMTrainer(Trainer):
    """HF ``Trainer`` specialized for scalar reward-model training.

    Wraps a base sequence-classification model (``num_labels=1``) in a LoRA
    adapter and regresses its single logit against the reward labels produced
    by :class:`RMDataset`, using mean-squared error.
    """

    def __init__(self, args, accelerator, **kwargs):
        """Build tokenizer, datasets, LoRA model and HF TrainingArguments.

        Args:
            args: CLI namespace (model_name, lora_*, batch sizes, paths, ...).
            accelerator: ``accelerate.Accelerator`` used for device placement
                and main-process gating of wandb logging.
            **kwargs: forwarded to ``transformers.Trainer.__init__``.
        """
        # NOTE: ``super().__init__`` below rebinds ``self.args`` to the
        # ``TrainingArguments`` instance. The CLI namespace is only needed
        # during construction (``setup_dataset`` reads it before that call),
        # so this shadowing is harmless.
        self.args = args
        self.accelerator = accelerator
        self.device = accelerator.device

        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
        train_dataset, eval_dataset = self.setup_dataset(tokenizer)

        # Only the main process creates a wandb run; keep the run config to
        # plain scalars/strings so it serializes cleanly.
        if self.accelerator.is_main_process:
            wandb.init(
                project=args.wandb_project,
                name=args.wandb_run_name,
                config={
                    k: v
                    for k, v in vars(args).items()
                    if isinstance(v, (int, float, str))
                },
            )

        peft_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            # e.g. "q_proj,v_proj" -> ["q_proj", "v_proj"]
            target_modules=args.target_modules.split(","),
        )

        # Single regression head: one scalar reward per sequence.
        base_model = AutoModelForSequenceClassification.from_pretrained(
            args.model_name,
            num_labels=1,
            torch_dtype="auto",
        )

        self.model = PeftModelForSequenceClassification(base_model, peft_config)
        # The classification head uses pad_token_id to find the last
        # non-padding position of each sequence; without it, batched
        # inference on padded inputs fails.
        self.model.config.pad_token_id = tokenizer.pad_token_id

        training_args = TrainingArguments(
            output_dir=args.checkpoint_dir,
            per_device_train_batch_size=args.train_batch_size,
            per_device_eval_batch_size=args.val_batch_size,
            num_train_epochs=args.num_epochs,
            logging_steps=1,
            save_steps=args.evaluation_steps,
            save_strategy="steps",
            logging_dir="./logs",
            gradient_accumulation_steps=args.accumulation_steps,
            learning_rate=args.learning_rate,
            # NOTE(review): this counts warmup in *examples*, not optimizer
            # steps (no division by batch size / accumulation) — presumably
            # intentional; confirm against the original training recipe.
            warmup_steps=int(len(train_dataset) * args.warmup_epochs),
            optim="adamw_torch",
            # Keep non-forward columns so the custom collate_fn sees them.
            remove_unused_columns=False,
            dataloader_num_workers=4,
            report_to="wandb",
            ddp_find_unused_parameters=False,
            eval_strategy="steps",
            label_names=["labels"],
        )

        # random_split returns Subset objects; reach through ``.dataset`` to
        # the underlying RMDataset for its custom collate_fn.
        collate_fn = (
            train_dataset.dataset.collate_fn
            if hasattr(train_dataset, "dataset")
            else None
        )

        super().__init__(
            model=self.model,
            args=training_args,
            processing_class=tokenizer,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=collate_fn,
            **kwargs,
        )
        self.loss_fn = MSELoss()

    def setup_dataset(self, tokenizer):
        """Load RMDataset from the jinja template and split 95/5 train/val.

        Returns:
            (train_dataset, val_dataset) — deterministic ``random_split``
            Subsets (fixed seed 42) of the full RMDataset.
        """
        # FIX: use os.path instead of manual "/" splitting so a bare template
        # filename (empty directory part) and platform-specific separators
        # are handled correctly.
        template_dir, template_file = os.path.split(self.args.template_path)
        env = Environment(loader=FileSystemLoader(template_dir or "."))
        template = env.get_template(template_file)
        dataset = RMDataset(
            self.args.reward_data_path, tokenizer, template, self.args.max_length
        )

        if self.accelerator.is_main_process:
            print(f"dataset: {len(dataset)}")

        train_size = int(len(dataset) * 0.95)
        val_size = len(dataset) - train_size

        # Fixed seed so every process (and every run) gets the same split.
        generator = torch.Generator().manual_seed(42)
        train_dataset, val_dataset = random_split(
            dataset, [train_size, val_size], generator=generator
        )
        return train_dataset, val_dataset

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """MSE between the model's scalar logit and the reward label.

        Overrides ``Trainer.compute_loss``; signature kept compatible with
        the HF Trainer contract (``return_outputs`` and pass-through kwargs).
        """
        input_ids = inputs["input_ids"]
        attention_masks = inputs["attention_mask"]
        true_rewards = inputs["labels"]

        outputs = model(input_ids, attention_mask=attention_masks)
        # (batch, 1) logits -> (batch,) predicted rewards.
        predicted_rewards = outputs.logits.squeeze(-1)
        # FIX: cast labels to the prediction dtype — MSELoss raises on
        # integer labels and can silently upcast on mixed float dtypes;
        # the cast is a no-op when dtypes already match.
        loss = self.loss_fn(predicted_rewards, true_rewards.to(predicted_rewards.dtype))

        return (loss, outputs) if return_outputs else loss
|
|
|