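"""Supervised fine-tuning script for LLaDA-8B: attaches LoRA adapters to the backbone,
optionally wraps it in the variable-length LLaDA_DIT head, and trains with the custom
FlexMDM trainer on a pre-tokenized dataset (e.g. OpenWebText)."""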
import torch
import torch.nn as nn
import argparse
from transformers import AutoTokenizer, AutoModel, TrainingArguments
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import is_main_process
from datasets import load_dataset, load_from_disk, Features, Sequence, Value, concatenate_datasets
from datasets.distributed import split_dataset_by_node
import os, multiprocessing, random, pathlib
from torch.utils.data import DataLoader
from peft import LoraConfig, get_peft_model, TaskType
from flexmdm_trainer import *
from collections import Counter
from llada_dit import LLaDA_DIT
from pathlib import Path
import torch.distributed as dist
import tqdm
import numpy as np
import wandb
import glob



def init_seed(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # torch.backends.cudnn.deterministic = True # for the training speed, we comment this out


# ------------------------------------------------------------
# Util function for logging
# ------------------------------------------------------------
def count_parameters(named_params, key: str | None = None):
    return sum(p.numel()
        for n, p in named_params
        if p.requires_grad and (key is None or key in n)
    )

class LogLrCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        if not is_main_process(args.local_rank):
            return
        opt = kwargs["optimizer"]
        wandb.log(
            {
                "lr/lora": opt.param_groups[0]["lr"],
                "lr/token_head": opt.param_groups[1]["lr"],
                "lr/from_scratch": opt.param_groups[2]["lr"],
                "step": state.global_step,
            }
        )


# Initialize argument parser
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name", type=str, default="GSAI-ML/LLaDA-8B-Base", help="Name of the pretrained model"
    )

    # Training hyperparameters
    parser.add_argument("--batch_size", type=int, default=4, help="batch size per device")
    parser.add_argument("--lora_lr", type=float, default=1e-4, help="Learning rate for the LoRA")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate for other parameters")
    parser.add_argument("--grad_accum_steps", type=int, default=2, help="Gradient accumulation steps")
    parser.add_argument("--max_steps", type=int, default=500000, help="Maximum number of training steps")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Path to the checkpoint to resume from")
    parser.add_argument("--low_discrepancy", type=bool, default=False, help="whether to use low discrepancy sampling")

    # Output directory and job name
    parser.add_argument(
        "--output_dir",
        type=str,
        default="/n/netscratch/albergo_lab/Lab/transdim-flow/sft-datamix-checkpoints",
        help="Directory to save model checkpoints and logs",
    )
    parser.add_argument("--job_name", type=str, default="llada-sft-openwebtext", help="Job Name")
    parser.add_argument("--train_data", type=str, default="openwebtext", help="Path to training data")
    parser.add_argument("--wandb", action="store_true", help="whether to use wandb")
    parser.add_argument("--variable_length", action="store_true", help="whether to use variable length training")
    parser.add_argument("--sanity_run", action="store_true", help="whether to run the sanity run (overfitting the model)")

    # CLI flags for openwebtext dataset preprocessing
    parser.add_argument("--sft_max_length", type=int, default=1024, help="Maximum sequence length for tokenization")
    parser.add_argument("--cache_path", type=str, default="/n/netscratch/albergo_lab/Everyone/jay_brian/datamix", help="Path of the tokenized openwebtext dataset")

    return parser.parse_args()



# Model loading with LoRA integration
def load_model_and_tokenizer(args):
    # Load the backbone LLaDA model
    backbone = AutoModel.from_pretrained(
        args.model_name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        return_dict=True,
    )

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, padding_side="right", trust_remote_code=True, use_fast=True)

    print("Tokenizer and backbone loaded!")

    backbone.config.output_hidden_states = True
    backbone.config.return_dict = True

    # lora adapter for the backbone LLaDA
    lora_config = LoraConfig(
        r=128,
        lora_alpha=128,
        target_modules=["q_proj", "k_proj", "v_proj", "transformer.ff_out"],
        lora_dropout=0.1,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
    backbone = get_peft_model(backbone, lora_config)
    backbone = backbone.to(torch.bfloat16)

    if args.variable_length:
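        # Wrap the LoRA-augmented backbone in LLaDA_DIT, which adds the modules trained
        # from scratch for variable-length training (e.g. the scalar length head).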
        model = LLaDA_DIT(backbone, pad_token_id=tokenizer.pad_token_id, d_model=4096)
    else:
        model = backbone

    if args.resume_from_checkpoint:
        ckpt_dir = Path(args.resume_from_checkpoint)
        state = torch.load(ckpt_dir / "pytorch_model.bin", map_location="cpu")
        model.load_state_dict(state, strict=False)
        print(f"Resumed from checkpoint {args.resume_from_checkpoint}")

    print("Final trainer model loaded!")
    
    return tokenizer, model


# Dataset loading
def load_data(args, tokenizer):
    # load the pre-processed tokenized dataset (already int64)
    cache_dir = pathlib.Path(args.cache_path)
    if not cache_dir.exists():
        raise FileNotFoundError(f"Cache directory {cache_dir} does not exist")
    ds = load_from_disk(cache_dir)
    ds = ds.shuffle(seed=42)
    data = ds.train_test_split(test_size=0.001, seed=42)
    print("Training and evaluation datasets successfully loaded!")

    if args.sanity_run:
        data = data["train"].select(range(128))
        print("Sanity run dataset loaded!")
        data.save_to_disk("sanity_run_dataset")
        return data, data

    return data["train"], data["test"]


# Training setup
def train_model(args, tokenizer, model):
    # Load dataset
    train_dataset, eval_dataset = load_data(args, tokenizer)

    # Training arguments setup
    training_args = TrainingArguments(
        output_dir=os.path.join(args.output_dir, args.job_name),
        max_steps=args.max_steps,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum_steps,
        eval_strategy="steps",
        eval_steps=1000,
        prediction_loss_only=True,
        logging_steps=10,
        save_steps=10000,
        save_total_limit=20,
        save_safetensors=False,
        max_grad_norm=1.0,
        bf16=True,
        lr_scheduler_type="cosine",
        lr_scheduler_kwargs={"num_cycles": 5},
        warmup_ratio=0.05,
        remove_unused_columns=False,
        report_to="wandb" if args.wandb else "none",
    )

    # Split the trainable parameters into three groups: LoRA adapters, the token head
    # (ff_out), and newly initialized modules trained from scratch (length head, time embedding).
    lora_params  = [p for n, p in model.named_parameters() if "lora" in n and p.requires_grad]
    head_params  = [p for n, p in model.named_parameters() if "lora" not in n and "ff_out" in n and p.requires_grad]
    from_scratch_params = [p for n, p in model.named_parameters() if "lora" not in n and "ff_out" not in n and p.requires_grad]

    trainable = [p for _, p in model.named_parameters() if p.requires_grad]
    assert set(trainable) == set(lora_params) | set(head_params) | set(from_scratch_params), "Trainable parameters are not correctly set"

    # parameter count check
    print(f"Total trainable parameters: {count_parameters(model.named_parameters(), key = None)}")
    print(f"  └─ LoRA adapter params          : {count_parameters(model.named_parameters() , key = 'lora')}")
    print(f"  └─ Token Head params                  : {count_parameters(model.named_parameters(), key = 'ff_out')}")
    print(f"  └─ Scalar Length Head params        : {count_parameters(model.named_parameters(), key = 'scalar_length_head')}")
    print(f"  └─ Time embedding params            : {count_parameters(model.named_parameters(), key = '.temb_mod')}")


    # Initialize Trainer with custom dLLMTrainer
    if args.variable_length:
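        # Note: the param-group order (LoRA, token head, from-scratch) must match the
        # indices logged by LogLrCallback (lr/lora, lr/token_head, lr/from_scratch).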
        optimizer = torch.optim.AdamW(
            [
                {"params": lora_params, "lr": args.lora_lr, "weight_decay": 0.0},
                {"params": head_params, "lr": args.lora_lr / 4, "weight_decay": 0.01},
                {"params": from_scratch_params, "lr": args.lr, "weight_decay": 0.01}
            ],
        )
        trainer = dLLMVariableLengthTrainer(
            model=model,
            args=training_args,
            data_collator=dLLMVariableDataCollator(
                tokenizer=tokenizer, mask_token_id=126336,  # 126336 is LLaDA's reserved [MASK] token id
                max_length=args.sft_max_length, compute_metrics=None,
                low_discrepancy=args.low_discrepancy,
            ),
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            optimizers=(optimizer, None),
        )
    else:
        raise NotImplementedError("Currently we don't support fixed length training")

    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    if args.wandb and local_rank == 0:
        wandb.init(project="SFT-llada", name=args.job_name, entity="jaeyeon_kim-harvard-university")

    # double-check the optimizer
    for i, g in enumerate(trainer.optimizer.param_groups):
        print(f"group {i}  init-lr={g['lr']}  wd={g['weight_decay']}")

    # add the callback
    trainer.add_callback(LogLrCallback())

    # Start training
    trainer.train()


if __name__ == "__main__":
    init_seed(42)
    # Parse command-line arguments
    args = parse_args()

    # Load model and tokenizer
    tokenizer, model = load_model_and_tokenizer(args)

    # Train the model
    train_model(args, tokenizer, model)