import os
from functools import partial
import torch
import wandb
from jinja2 import Environment, FileSystemLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
)

from sotopia_rl.data import SFTDataset
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def sft_collate_fn(batch, tokenizer):
    # Pad variable-length examples to the longest sequence in the batch;
    # labels are padded with -100 so padded positions are ignored by the loss.
    input_ids = pad_sequence(
        [x["input_ids"] for x in batch], batch_first=True, padding_value=tokenizer.pad_token_id
    )
    attention_mask = pad_sequence(
        [x["attention_mask"] for x in batch], batch_first=True, padding_value=0
    )
    labels = pad_sequence(
        [x["labels"] for x in batch], batch_first=True, padding_value=-100
    )
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


class SotopiaSFTTrainer(Trainer):
    def __init__(self, args, accelerator):
        # 1️⃣ Initialize wandb on main process
        self.accelerator = accelerator
        self.device = accelerator.device
        # keep the LoRA flag around: once super().__init__ runs, self.args is the
        # HF TrainingArguments, which has no use_lora attribute (see _save_lora)
        self.use_lora = args.use_lora
        if self.accelerator.is_main_process:
            wandb.init(
                project=args.wandb_project,
                name=args.wandb_run_name,
                config={k: v for k, v in vars(args).items() if isinstance(v, (int, float, str))},
            )

        # 2️⃣ Load config + tokenizer
        config = AutoConfig.from_pretrained(args.model_name)
        config.use_cache = False
        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
        tokenizer.model_max_length = args.max_length
        if tokenizer.pad_token is None:
            # some causal-LM tokenizers ship without a pad token; fall back to EOS
            # so that sft_collate_fn has a valid pad_token_id for padding
            tokenizer.pad_token = tokenizer.eos_token

        # 3️⃣ Load base model & (optionally) LoRA-wrap it
        if args.use_qlora:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            print(f"Using QLoRA (4bit) to load model: {args.model_name}")
            base_model = AutoModelForCausalLM.from_pretrained(
                args.model_name,
                config=config,  # pass the config so use_cache=False takes effect
                torch_dtype=torch.float16,
                quantization_config=quantization_config,
            )
        else:
            base_model = AutoModelForCausalLM.from_pretrained(
                args.model_name, config=config
            ).to(self.device)
        if args.use_lora:
            from peft import LoraConfig, get_peft_model
            peft_config = LoraConfig(
                r=args.lora_r,
                lora_alpha=args.lora_alpha,
                lora_dropout=args.lora_dropout,
                target_modules=args.target_modules.split(","),
            )
            base_model = get_peft_model(base_model, peft_config)
        model = base_model

        # 4️⃣ Prepare dataset + split
        env = Environment(loader=FileSystemLoader(os.path.dirname(args.template_path)))
        template = env.get_template(os.path.basename(args.template_path))
        full_ds = SFTDataset(args.sft_data_path, tokenizer, template, args.max_length)
        train_size = int(0.95 * len(full_ds))
        val_size = len(full_ds) - train_size
        train_ds, eval_ds = torch.utils.data.random_split(
            full_ds, [train_size, val_size], generator=torch.Generator().manual_seed(42)
        )

        # 5️⃣ Build HF TrainingArguments
        hf_args = TrainingArguments(
            output_dir=args.checkpoint_dir,
            num_train_epochs=args.num_epochs,
            per_device_train_batch_size=args.train_batch_size,
            per_device_eval_batch_size=args.val_batch_size,
            gradient_accumulation_steps=args.accumulation_steps,
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            eval_steps=args.evaluation_steps,
            save_steps=50,
            logging_dir="./logs",
            logging_steps=1,
            report_to="wandb",
            bf16=True,
            optim="paged_adamw_8bit" if args.use_qlora else "adamw_torch",
            dataloader_num_workers=4,
            ddp_find_unused_parameters=False,
            eval_strategy="steps",
            label_names=["labels"],
        )

        # 6️⃣ Call the Trainer constructor
        super().__init__(
            model=model,
            args=hf_args,
            train_dataset=train_ds,
            eval_dataset=eval_ds,
            data_collator=partial(sft_collate_fn, tokenizer=tokenizer),
            tokenizer=tokenizer,
        )

    def train(self, **kwargs):
        # run the usual HF train loop
        super().train(**kwargs)
        # then save your LoRA adapter if needed
        self._save_lora()
        # optionally run final evaluation
        return self.evaluate()

    def _save_lora(self):
        # self.args is the HF TrainingArguments here and has no use_lora field,
        # so check the flag captured in __init__ instead
        if self.use_lora:
            ckpt = os.path.join(self.args.output_dir, "best_lora_checkpoint")
            os.makedirs(ckpt, exist_ok=True)
            # HF/PEFT save
            self.model.save_pretrained(ckpt)
            print(f"LoRA checkpoint saved at {ckpt}")