import os
from functools import partial

import torch
import wandb
from jinja2 import Environment, FileSystemLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
)

from sotopia_rl.data import SFTDataset

# Disable NCCL peer-to-peer transfers (works around hangs on some multi-GPU setups).
os.environ["NCCL_P2P_DISABLE"] = "1"
# Silence the tokenizers fork-parallelism warning when DataLoader workers are used.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def sft_collate_fn(batch, tokenizer):
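    """Pad variable-length examples into a single batch.

    input_ids are padded with the tokenizer's pad token, attention_mask with 0,
    and labels with -100 so padded positions are ignored by the loss.
    """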
    input_ids = pad_sequence(
        [x["input_ids"] for x in batch], batch_first=True, padding_value=tokenizer.pad_token_id
    )
    attention_mask = pad_sequence(
        [x["attention_mask"] for x in batch], batch_first=True, padding_value=0
    )
    labels = pad_sequence(
        [x["labels"] for x in batch], batch_first=True, padding_value=-100
    )
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


class SotopiaSFTTrainer(Trainer):
    def __init__(self, args, accelerator):
        # 1️⃣ Initialize wandb on main process
        self.accelerator = accelerator
        self.device = accelerator.device

        if self.accelerator.is_main_process:
            wandb.init(
                project=args.wandb_project,
                name=args.wandb_run_name,
                config={k: v for k, v in vars(args).items() if isinstance(v, (int, float, str))},
            )

        # 2️⃣ Load config + tokenizer
        config = AutoConfig.from_pretrained(args.model_name)
        config.use_cache = False
        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
        tokenizer.model_max_length = args.max_length
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token  # needed for padding in the collate fn

        if args.use_qlora:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            print(f"Using QLoRA (4bit) to load model: {args.model_name}")
            base_model = AutoModelForCausalLM.from_pretrained(
                args.model_name,
                torch_dtype=torch.bfloat16,  # keep non-quantized modules in bf16 to match the compute dtype
                quantization_config=quantization_config,
            )
        else:
            base_model = AutoModelForCausalLM.from_pretrained(args.model_name).to(self.device)

        # 3️⃣ (Optional) LoRA-wrap the model loaded above
        if args.use_lora:
            from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
            if args.use_qlora:
                # Prepare the 4-bit base model for adapter training
                # (enables input grads, casts norm layers for stability).
                base_model = prepare_model_for_kbit_training(base_model)
            peft_config = LoraConfig(
                r=args.lora_r,
                lora_alpha=args.lora_alpha,
                lora_dropout=args.lora_dropout,
                target_modules=args.target_modules.split(","),
                task_type="CAUSAL_LM",
            )
            base_model = get_peft_model(base_model, peft_config)
        model = base_model

        # 4️⃣ Prepare dataset + split
        env = Environment(loader=FileSystemLoader(os.path.dirname(args.template_path)))
        template = env.get_template(os.path.basename(args.template_path))
        full_ds = SFTDataset(args.sft_data_path, tokenizer, template, args.max_length)
        train_size = int(0.95 * len(full_ds))
        val_size = len(full_ds) - train_size
        train_ds, eval_ds = torch.utils.data.random_split(
            full_ds, [train_size, val_size], generator=torch.Generator().manual_seed(42)
        )

        # 5️⃣ Build HF TrainingArguments
        hf_args = TrainingArguments(
            output_dir=args.checkpoint_dir,
            num_train_epochs=args.num_epochs,
            per_device_train_batch_size=args.train_batch_size,
            per_device_eval_batch_size=args.val_batch_size,
            gradient_accumulation_steps=args.accumulation_steps,
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            eval_steps=args.evaluation_steps,
            save_steps=50,
            logging_dir="./logs",
            logging_steps=1,
            report_to="wandb",
            bf16=True,
            optim="paged_adamw_8bit" if args.use_qlora else "adamw_torch",
            dataloader_num_workers=4,
            ddp_find_unused_parameters=False,
            eval_strategy="steps",
            label_names=["labels"]
        )

        # 6️⃣ Call the Trainer constructor
        super().__init__(
            model=model,
            args=hf_args,
            train_dataset=train_ds,
            eval_dataset=eval_ds,
            data_collator=partial(sft_collate_fn, tokenizer=tokenizer),
            tokenizer=tokenizer,
        )

    def train(self, **kwargs):
        # run the usual HF train loop
        super().train(**kwargs)
        # then save your LoRA adapter if needed
        self._save_lora()
        # optionally run final evaluation
        return self.evaluate()

    def _save_lora(self):
        # `self.args` is the HF TrainingArguments and has no `use_lora` flag,
        # so detect a LoRA-wrapped model via its PEFT config instead.
        if getattr(self.model, "peft_config", None) is not None:
            ckpt = os.path.join(self.args.output_dir, "best_lora_checkpoint")
            os.makedirs(ckpt, exist_ok=True)
            # PEFT's save_pretrained writes only the adapter weights and config.
            self.model.save_pretrained(ckpt)
            print(f"LoRA checkpoint saved at {ckpt}")