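"""
Supervised fine-tuning (SFT) script for the LLaDA backbone with LoRA adapters,
using the variable-length FlexMDM trainer (dLLMVariableLengthTrainer /
dLLMVariableDataCollator from flexmdm_trainer). The script loads a pre-tokenized
dataset from disk, builds three optimizer parameter groups (LoRA adapters, token
head, from-scratch modules) with separate learning rates, and trains through the
Hugging Face Trainer loop, optionally logging to wandb.
"""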
import torch
import torch.nn as nn
import argparse
from transformers import AutoTokenizer, AutoModel, TrainingArguments
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import is_main_process
from datasets import load_dataset, load_from_disk, Features, Sequence, Value, concatenate_datasets
from datasets.distributed import split_dataset_by_node
import os, multiprocessing, random, pathlib
from torch.utils.data import DataLoader
from peft import LoraConfig, get_peft_model, TaskType
from flexmdm_trainer import *
from collections import Counter
from llada_dit import LLaDA_DIT
from pathlib import Path
import torch.distributed as dist
import tqdm
import numpy as np
import wandb
import glob


def init_seed(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # torch.backends.cudnn.deterministic = True  # commented out for training speed


# ------------------------------------------------------------
# Util function for logging
# ------------------------------------------------------------
def count_parameters(named_params, key: str | None = None):
    return sum(
        p.numel()
        for n, p in named_params
        if p.requires_grad and (key is None or key in n)
    )


class LogLrCallback(TrainerCallback):
    """Log the per-group learning rates to wandb at each optimizer step."""

    def on_step_end(self, args, state, control, **kwargs):
        if not is_main_process(args.local_rank):
            return
        opt = kwargs["optimizer"]
        wandb.log(
            {
                "lr/lora": opt.param_groups[0]["lr"],
                "lr/token_head": opt.param_groups[1]["lr"],
                "lr/from_scratch": opt.param_groups[2]["lr"],
                "step": state.global_step,
            }
        )


# Initialize argument parser
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name", type=str, default="GSAI-ML/LLaDA-8B-Base", help="Name of the pretrained model"
    )
    # Training hyperparameters
    parser.add_argument("--batch_size", type=int, default=4, help="batch size per device")
    parser.add_argument("--lora_lr", type=float, default=1e-4, help="Learning rate for the LoRA adapters")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate for other parameters")
    parser.add_argument("--grad_accum_steps", type=int, default=2, help="Gradient accumulation steps")
    parser.add_argument("--max_steps", type=int, default=500000, help="Maximum number of training steps")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Path to the checkpoint to resume from")
    parser.add_argument("--low_discrepancy", action="store_true", help="whether to use low-discrepancy sampling")
    # Output directory and job name
    parser.add_argument(
        "--output_dir",
        type=str,
        default="/n/netscratch/albergo_lab/Lab/transdim-flow/sft-datamix-checkpoints",
        help="Directory to save model checkpoints and logs",
    )
    parser.add_argument("--job_name", type=str, default="llada-sft-openwebtext", help="Job name")
    parser.add_argument("--train_data", type=str, default="openwebtext", help="Path to training data")
    parser.add_argument("--wandb", action="store_true", help="whether to use wandb")
    parser.add_argument("--variable_length", action="store_true", help="whether to use variable-length training")
    parser.add_argument("--sanity_run", action="store_true", help="whether to run a sanity run (overfitting the model)")
    # CLI flags for openwebtext dataset preprocessing
    parser.add_argument("--sft_max_length", type=int, default=1024, help="Maximum sequence length for tokenization")
    parser.add_argument("--cache_path", type=str, default="/n/netscratch/albergo_lab/Everyone/jay_brian/datamix", help="Path of the tokenized openwebtext dataset")
    return parser.parse_args()


# Model loading with LoRA integration
def load_model_and_tokenizer(args):
    # Load the backbone LLaDA model
    backbone = AutoModel.from_pretrained(
        args.model_name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        return_dict=True,
    )
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name, padding_side="right", trust_remote_code=True, use_fast=True
    )
    print("Tokenizer and backbone loaded!")
    backbone.config.output_hidden_states = True
    backbone.config.return_dict = True

    # LoRA adapter for the backbone LLaDA
    lora_config = LoraConfig(
        r=128,
        lora_alpha=128,
        target_modules=["q_proj", "k_proj", "v_proj", "transformer.ff_out"],
        lora_dropout=0.1,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
    backbone = get_peft_model(backbone, lora_config)
    backbone = backbone.to(torch.bfloat16)

    if args.variable_length:
        model = LLaDA_DIT(backbone, pad_token_id=tokenizer.pad_token_id, d_model=4096)
    else:
        model = backbone

    if args.resume_from_checkpoint:
        ckpt_dir = Path(args.resume_from_checkpoint)
        state = torch.load(ckpt_dir / "pytorch_model.bin", map_location="cpu")
        model.load_state_dict(state, strict=False)
        print(f"Resumed from checkpoint {args.resume_from_checkpoint}")

    print("Final trainer model loaded!")
    return tokenizer, model


# Dataset loading
def load_data(args, tokenizer):
    # Load the pre-processed tokenized dataset (already int64)
    cache_dir = pathlib.Path(args.cache_path)
    if not cache_dir.exists():
        raise FileNotFoundError(f"Cache directory {cache_dir} does not exist")
    ds = load_from_disk(cache_dir)
    ds = ds.shuffle(seed=42)
    data = ds.train_test_split(test_size=0.001, seed=42)
    print("Training and evaluation datasets successfully loaded!")
    if args.sanity_run:
        data = data["train"].select(range(128))
        print("Sanity run dataset loaded!")
        data.save_to_disk("sanity_run_dataset")
        return data, data
    return data["train"], data["test"]


# Training setup
def train_model(args, tokenizer, model):
    # Load dataset
    train_dataset, eval_dataset = load_data(args, tokenizer)

    # Training arguments setup
    training_args = TrainingArguments(
        output_dir=os.path.join(args.output_dir, args.job_name),
        max_steps=args.max_steps,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum_steps,
        eval_strategy="steps",
        eval_steps=1000,
        prediction_loss_only=True,
        logging_steps=10,
        save_steps=10000,
        save_total_limit=20,
        save_safetensors=False,
        max_grad_norm=1.0,
        bf16=True,
        lr_scheduler_type="cosine",
        lr_scheduler_kwargs={"num_cycles": 5},
        warmup_ratio=0.05,
        remove_unused_columns=False,
        report_to="wandb" if args.wandb else "none",
    )

    # Set up the trainable parameter groups
    lora_params = [p for n, p in model.named_parameters() if "lora" in n and p.requires_grad]
    head_params = [p for n, p in model.named_parameters() if "lora" not in n and "ff_out" in n and p.requires_grad]
    from_scratch_params = [p for n, p in model.named_parameters() if "lora" not in n and "ff_out" not in n and p.requires_grad]
    trainable = [p for _, p in model.named_parameters() if p.requires_grad]
    assert set(trainable) == set(lora_params) | set(head_params) | set(from_scratch_params), "Trainable parameters are not correctly set"

    # Parameter count check
    print(f"Total trainable parameters: {count_parameters(model.named_parameters(), key=None)}")
    print(f"  ├─ LoRA adapter params       : {count_parameters(model.named_parameters(), key='lora')}")
    print(f"  ├─ Token head params         : {count_parameters(model.named_parameters(), key='ff_out')}")
    print(f"  ├─ Scalar length head params : {count_parameters(model.named_parameters(), key='scalar_length_head')}")
    print(f"  └─ Time embedding params     : {count_parameters(model.named_parameters(), key='.temb_mod')}")

    # Initialize Trainer with the custom dLLM trainer
    if args.variable_length:
        optimizer = torch.optim.AdamW(
            [
                {"params": lora_params, "lr": args.lora_lr, "weight_decay": 0.0},
                {"params": head_params, "lr": args.lora_lr / 4, "weight_decay": 0.01},
                {"params": from_scratch_params, "lr": args.lr, "weight_decay": 0.01},
            ],
        )
        trainer = dLLMVariableLengthTrainer(
            model=model,
            args=training_args,
            data_collator=dLLMVariableDataCollator(
                tokenizer=tokenizer,
                mask_token_id=126336,
                max_length=args.sft_max_length,
                compute_metrics=None,
                low_discrepancy=args.low_discrepancy,
            ),
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            optimizers=(optimizer, None),
        )
    else:
        raise NotImplementedError("Currently we don't support fixed-length training")

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if args.wandb and local_rank == 0:
        wandb.init(project="SFT-llada", name=args.job_name, entity="jaeyeon_kim-harvard-university")

    # Double-check the optimizer param groups
    for i, g in enumerate(trainer.optimizer.param_groups):
        print(f"group {i} init-lr={g['lr']} wd={g['weight_decay']}")

    # Add the learning-rate logging callback
    trainer.add_callback(LogLrCallback())

    # Start training
    trainer.train()


if __name__ == "__main__":
    init_seed(42)
    # Parse command-line arguments
    args = parse_args()
    # Load model and tokenizer
    tokenizer, model = load_model_and_tokenizer(args)
    # Train the model
    train_model(args, tokenizer, model)
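
# Example launch (sketch only; the script filename, GPU count, and dataset path are
# placeholders, and a torchrun-style launch is assumed since the script reads LOCAL_RANK):
#   torchrun --nproc_per_node=4 train_llada_sft.py \
#       --variable_length --wandb \
#       --cache_path /path/to/tokenized_dataset \
#       --job_name llada-sft-openwebtext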