import os
from functools import partial

import torch
import wandb
from jinja2 import Environment, FileSystemLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
)

from sotopia_rl.data import SFTDataset

os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def sft_collate_fn(batch, tokenizer):
    input_ids = pad_sequence(
        [x["input_ids"] for x in batch], batch_first=True, padding_value=tokenizer.pad_token_id
    )
    attention_mask = pad_sequence(
        [x["attention_mask"] for x in batch], batch_first=True, padding_value=0
    )
    labels = pad_sequence(
        [x["labels"] for x in batch], batch_first=True, padding_value=-100
    )
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
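# Illustrative usage of sft_collate_fn (a hypothetical toy batch, not part of the
# training pipeline): shorter items are right-padded to the longest sequence in the
# batch, and padded label positions use -100 so the loss ignores them.
#
#   toy_batch = [
#       {"input_ids": torch.tensor([1, 2, 3]), "attention_mask": torch.tensor([1, 1, 1]), "labels": torch.tensor([-100, 2, 3])},
#       {"input_ids": torch.tensor([4, 5]), "attention_mask": torch.tensor([1, 1]), "labels": torch.tensor([4, 5])},
#   ]
#   collated = sft_collate_fn(toy_batch, tokenizer)  # every tensor has shape (2, 3)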


class SotopiaSFTTrainer(Trainer):
    def __init__(self, args, accelerator):
        self.accelerator = accelerator
        self.device = accelerator.device
        # Keep the CLI flag around: after super().__init__ runs, self.args refers to
        # the TrainingArguments object, which has no use_lora attribute.
        self.use_lora = args.use_lora

        if self.accelerator.is_main_process:
            wandb.init(
                project=args.wandb_project,
                name=args.wandb_run_name,
                config={k: v for k, v in vars(args).items() if isinstance(v, (int, float, str))},
            )

        config = AutoConfig.from_pretrained(args.model_name)
        config.use_cache = False
        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
        tokenizer.model_max_length = args.max_length
        if tokenizer.pad_token is None:
            # Some causal LM tokenizers ship without a pad token; fall back to EOS so
            # pad_token_id is valid for the collator's pad_sequence calls.
            tokenizer.pad_token = tokenizer.eos_token

        if args.use_qlora:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            print(f"Using QLoRA (4bit) to load model: {args.model_name}")
            base_model = AutoModelForCausalLM.from_pretrained(
                args.model_name,
                # bfloat16 matches bnb_4bit_compute_dtype above and bf16 training below.
                torch_dtype=torch.bfloat16,
                quantization_config=quantization_config,
            )
        else:
            base_model = AutoModelForCausalLM.from_pretrained(args.model_name).to(self.device)

        if args.use_lora:
            from peft import LoraConfig, get_peft_model
            peft_config = LoraConfig(
                r=args.lora_r,
                lora_alpha=args.lora_alpha,
                lora_dropout=args.lora_dropout,
                target_modules=args.target_modules.split(","),
            )
            base_model = get_peft_model(base_model, peft_config)
        model = base_model

        # Render prompts with the Jinja template, then split the dataset 95/5
        # into train/eval with a fixed seed for reproducibility.
        env = Environment(loader=FileSystemLoader(os.path.dirname(args.template_path)))
        template = env.get_template(os.path.basename(args.template_path))
        full_ds = SFTDataset(args.sft_data_path, tokenizer, template, args.max_length)
        train_size = int(0.95 * len(full_ds))
        val_size = len(full_ds) - train_size
        train_ds, eval_ds = torch.utils.data.random_split(
            full_ds, [train_size, val_size], generator=torch.Generator().manual_seed(42)
        )

        hf_args = TrainingArguments(
            output_dir=args.checkpoint_dir,
            num_train_epochs=args.num_epochs,
            per_device_train_batch_size=args.train_batch_size,
            per_device_eval_batch_size=args.val_batch_size,
            gradient_accumulation_steps=args.accumulation_steps,
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            eval_steps=args.evaluation_steps,
            save_steps=50,
            logging_dir="./logs",
            logging_steps=1,
            report_to="wandb",
            bf16=True,
            optim="paged_adamw_8bit" if args.use_qlora else "adamw_torch",
            dataloader_num_workers=4,
            ddp_find_unused_parameters=False,
            eval_strategy="steps",
            label_names=["labels"],
        )

        super().__init__(
            model=model,
            args=hf_args,
            train_dataset=train_ds,
            eval_dataset=eval_ds,
            data_collator=partial(sft_collate_fn, tokenizer=tokenizer),
            tokenizer=tokenizer,
        )

    def train(self, **kwargs):
        # Run the Hugging Face training loop, persist the LoRA adapter if one was
        # trained, and return the final evaluation metrics.
        super().train(**kwargs)
        self._save_lora()
        return self.evaluate()

    def _save_lora(self):
        # Use the flag captured from the CLI args in __init__; self.args is the
        # TrainingArguments instance here and does not carry use_lora.
        if self.use_lora:
            ckpt = os.path.join(self.args.output_dir, "best_lora_checkpoint")
            os.makedirs(ckpt, exist_ok=True)
            self.model.save_pretrained(ckpt)
            print(f"LoRA checkpoint saved at {ckpt}")
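

# A minimal launch sketch (assumed, not part of the original pipeline): it wires up
# an argparse namespace with the fields this class reads and an accelerate.Accelerator.
# Argument defaults below are placeholders; the real sotopia_rl entry point may differ.
if __name__ == "__main__":
    import argparse

    from accelerate import Accelerator

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--sft_data_path", type=str, required=True)
    parser.add_argument("--template_path", type=str, required=True)
    parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints")
    parser.add_argument("--max_length", type=int, default=4096)
    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--train_batch_size", type=int, default=1)
    parser.add_argument("--val_batch_size", type=int, default=1)
    parser.add_argument("--accumulation_steps", type=int, default=8)
    parser.add_argument("--learning_rate", type=float, default=1e-5)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--evaluation_steps", type=int, default=50)
    parser.add_argument("--use_qlora", action="store_true")
    parser.add_argument("--use_lora", action="store_true")
    parser.add_argument("--lora_r", type=int, default=8)
    parser.add_argument("--lora_alpha", type=int, default=16)
    parser.add_argument("--lora_dropout", type=float, default=0.05)
    parser.add_argument("--target_modules", type=str, default="q_proj,v_proj")
    parser.add_argument("--wandb_project", type=str, default="sotopia-rl-sft")
    parser.add_argument("--wandb_run_name", type=str, default="sft-run")
    cli_args = parser.parse_args()

    trainer = SotopiaSFTTrainer(cli_args, Accelerator())
    metrics = trainer.train()
    print(metrics)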