# flexmdm_transfer_openwebtext.py
import torch
import torch.nn as nn
import argparse
from transformers import AutoTokenizer, AutoModel, TrainingArguments
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import is_main_process
from datasets import load_dataset, load_from_disk, Features, Sequence, Value, concatenate_datasets
from datasets.distributed import split_dataset_by_node
import os, multiprocessing, random, pathlib
from torch.utils.data import DataLoader
from peft import LoraConfig, get_peft_model, TaskType
from flexmdm_trainer import *
from collections import Counter
from llada_dit import LLaDA_DIT
from pathlib import Path
import torch.distributed as dist
import tqdm
import numpy as np
import wandb
import glob
def init_seed(seed):
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# torch.backends.cudnn.deterministic = True # for the training speed, we comment this out
# ------------------------------------------------------------
# Util function for logging
# ------------------------------------------------------------
def count_parameters(named_params, key: str | None = None):
return sum(p.numel()
for n, p in named_params
if p.requires_grad and (key is None or key in n)
)
class LogLrCallback(TrainerCallback):
    """Log the per-group learning rates to wandb after every optimizer step."""
    def on_step_end(self, args, state, control, **kwargs):
        # Only the main process logs to wandb.
        if not is_main_process(args.local_rank):
            return
        opt = kwargs["optimizer"]
        wandb.log(
            {
                "lr/lora": opt.param_groups[0]["lr"],
                "lr/token_head": opt.param_groups[1]["lr"],
                "lr/from_scratch": opt.param_groups[2]["lr"],
                "step": state.global_step,
            }
        )
# Initialize argument parser
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name", type=str, default="GSAI-ML/LLaDA-8B-Base", help="Name of the pretrained model"
)
# Training hyperparameters
parser.add_argument("--batch_size", type=int, default=4, help="batch size per device")
parser.add_argument("--lora_lr", type=float, default=1e-4, help="Learning rate for the LoRA")
parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate for other parameters")
parser.add_argument("--grad_accum_steps", type=int, default=2, help="Gradient accumulation steps")
parser.add_argument("--max_steps", type=int, default=500000, help="Maximum number of training steps")
parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Path to the checkpoint to resume from")
parser.add_argument("--low_discrepancy", type=bool, default=False, help="whether to use low discrepancy sampling")
# Output directory and job name
parser.add_argument(
"--output_dir",
type=str,
default="/n/netscratch/albergo_lab/Lab/transdim-flow/sft-datamix-checkpoints",
help="Directory to save model checkpoints and logs",
)
parser.add_argument("--job_name", type=str, default="llada-sft-openwebtext", help="Job Name")
parser.add_argument("--train_data", type=str, default="openwebtext", help="Path to training data")
parser.add_argument("--wandb", action="store_true", help="whether to use wandb")
parser.add_argument("--variable_length", action="store_true", help="whether to use variable length training")
parser.add_argument("--sanity_run", action="store_true", help="whether to run the sanity run (overfitting the model)")
# CLI flags for openwebtext dataset preprocessing
parser.add_argument("--sft_max_length", type=int, default=1024, help="Maximum sequence length for tokenization")
parser.add_argument("--cache_path", type=str, default="/n/netscratch/albergo_lab/Everyone/jay_brian/datamix", help="Path of the tokenized openwebtext dataset")
return parser.parse_args()
# Model loading with LoRA integration
def load_model_and_tokenizer(args):
# Load the backbone LLaDA model
backbone = AutoModel.from_pretrained(
args.model_name,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
return_dict=True,
)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model_name, padding_side="right", trust_remote_code=True, use_fast=True)
print("Tokenizer and backbone loaded!")
backbone.config.output_hidden_states = True
backbone.config.return_dict = True
# lora adapter for the backbone LLaDA
lora_config = LoraConfig(
r=128,
lora_alpha=128,
target_modules=["q_proj", "k_proj", "v_proj", "transformer.ff_out"],
lora_dropout=0.1,
bias="none",
task_type=TaskType.CAUSAL_LM,
)
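    # Wrap the backbone with PEFT: LoRA adapters are injected into the q/k/v
    # projections and transformer.ff_out listed above; only those adapter
    # weights remain trainable inside the backbone itself.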
backbone = get_peft_model(backbone, lora_config)
backbone = backbone.to(torch.bfloat16)
    if args.variable_length:
        model = LLaDA_DIT(backbone, pad_token_id=tokenizer.pad_token_id, d_model=4096)
    else:
        model = backbone
    if args.resume_from_checkpoint:
        ckpt_dir = Path(args.resume_from_checkpoint)
        state = torch.load(ckpt_dir / "pytorch_model.bin", map_location="cpu")
        model.load_state_dict(state, strict=False)
        print(f"Resumed from checkpoint {args.resume_from_checkpoint}")
print("Final trainer model loaded!")
return tokenizer, model
# Dataset loading
def load_data(args, tokenizer):
    # load the pre-processed tokenized dataset (already int64)
cache_dir = pathlib.Path(args.cache_path)
if not cache_dir.exists():
raise FileNotFoundError(f"Cache directory {cache_dir} does not exist")
ds = load_from_disk(cache_dir)
ds = ds.shuffle(seed=42)
data = ds.train_test_split(test_size=0.001, seed=42)
print("Training and evaluation datasets successfully loaded!")
if args.sanity_run:
data = data["train"].select(range(128))
print("Sanity run dataset loaded!")
data.save_to_disk("sanity_run_dataset")
return data, data
return data["train"], data["test"]
# Training setup
def train_model(args, tokenizer, model):
# Load dataset
train_dataset, eval_dataset = load_data(args, tokenizer)
# Training arguments setup
training_args = TrainingArguments(
output_dir=os.path.join(args.output_dir, args.job_name),
        max_steps=args.max_steps,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum_steps,
        eval_strategy="steps",
        eval_steps=1000,
        prediction_loss_only=True,
        logging_steps=10,
        save_steps=10000,
        save_total_limit=20,
        save_safetensors=False,
        max_grad_norm=1.0,
        bf16=True,
        lr_scheduler_type="cosine",
        lr_scheduler_kwargs={"num_cycles": 5},
        warmup_ratio=0.05,
        remove_unused_columns=False,
        report_to="wandb" if args.wandb else "none",
)
# setup the trainable parameters
lora_params = [p for n, p in model.named_parameters() if "lora" in n and p.requires_grad]
head_params = [p for n, p in model.named_parameters() if "lora" not in n and "ff_out" in n and p.requires_grad]
from_scratch_params = [p for n, p in model.named_parameters() if "lora" not in n and "ff_out" not in n and p.requires_grad]
trainable = [p for _, p in model.named_parameters() if p.requires_grad]
assert set(trainable) == set(lora_params) | set(head_params) | set(from_scratch_params), "Trainable parameters are not correctly set"
# parameter count check
print(f"Total trainable parameters: {count_parameters(model.named_parameters(), key = None)}")
print(f" └─ LoRA adapter params : {count_parameters(model.named_parameters() , key = 'lora')}")
print(f" └─ Token Head params : {count_parameters(model.named_parameters(), key = 'ff_out')}")
print(f" └─ Scalar Length Head params : {count_parameters(model.named_parameters(), key = 'scalar_length_head')}")
print(f" └─ Time embedding params : {count_parameters(model.named_parameters(), key = '.temb_mod')}")
# Initialize Trainer with custom dLLMTrainer
if args.variable_length:
optimizer = torch.optim.AdamW(
[
{"params": lora_params, "lr": args.lora_lr, "weight_decay": 0.0},
{"params": head_params, "lr": args.lora_lr / 4, "weight_decay": 0.01},
{"params": from_scratch_params, "lr": args.lr, "weight_decay": 0.01}
],
)
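        # The order of these param groups (0: LoRA, 1: token head, 2: from-scratch)
        # is assumed by LogLrCallback, which reads opt.param_groups by index.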
        trainer = dLLMVariableLengthTrainer(
            model=model,
            args=training_args,
            data_collator=dLLMVariableDataCollator(
                tokenizer=tokenizer,
                mask_token_id=126336,
                max_length=args.sft_max_length,
                compute_metrics=None,
                low_discrepancy=args.low_discrepancy,
            ),
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            optimizers=(optimizer, None),
        )
else:
raise NotImplementedError("Currently we don't support fixed length training")
local_rank = int(os.environ.get("LOCAL_RANK", 0))
if args.wandb and local_rank == 0:
wandb.init(project="SFT-llada", name=args.job_name, entity="jaeyeon_kim-harvard-university")
# double-check the optimizer
for i, g in enumerate(trainer.optimizer.param_groups):
print(f"group {i} init-lr={g['lr']} wd={g['weight_decay']}")
# add the callback
trainer.add_callback(LogLrCallback())
# Start training
trainer.train()
if __name__ == "__main__":
init_seed(42)
# Parse command-line arguments
args = parse_args()
# Load model and tokenizer
tokenizer, model = load_model_and_tokenizer(args)
# Train the model
train_model(args, tokenizer, model)
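# Example launch (illustrative; adjust GPU count, paths, and flags to your setup):
#   torchrun --nproc_per_node=8 flexmdm_transfer_openwebtext.py \
#       --variable_length --wandb \
#       --cache_path /path/to/tokenized_openwebtext \
#       --job_name flexmdm-openwebtext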