%%writefile qat_sft.py


import os
import argparse
import torch
import numpy as np
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForMaskedLM, get_scheduler
from datasets import load_from_disk
from accelerate import Accelerator
from tqdm.auto import tqdm
from tokenizer import get_tokenizer
from safetensors.torch import load_file


from torch.quantization import get_default_qat_qconfig, prepare_qat, convert


from data_utils import SFTCollator

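# Quantization-aware training (QAT) flow used below (PyTorch eager-mode API):
# attach a qconfig to the model, let prepare_qat() insert fake-quantize/observer
# modules, train as usual so the weights adapt to quantization noise, and finally
# call convert() to swap in real INT8 modules once training is finished.
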
def parse_args():
    parser = argparse.ArgumentParser(description="Supervised finetuning (SFT) with quantization-aware training.")

    parser.add_argument("--experiment_name",
                        type=str,
                        required=True,
                        )

    parser.add_argument("--working_dir",
                        type=str,
                        required=True,
                        )

    parser.add_argument("--path_to_pretrained_checkpoint",
                        type=str,
                        default=None,
                        help="Path to a pretrained model checkpoint to continue training from.",
                        )

    parser.add_argument("--seed",
                        type=int,
                        default=42,
                        help="Random seed for reproducibility.",
                        )

    parser.add_argument("--hf_model_name",
                        type=str,
                        required=True,
                        )

    parser.add_argument("--path_to_prepped_data",
                        type=str,
                        required=True,
                        help="Path to the preprocessed dataset saved to disk by prepare_pretrain_data.py.",
                        )

    parser.add_argument("--num_workers",
                        type=int,
                        default=24,
                        help="Number of workers for data loading.",
                        )

    parser.add_argument("--mixed_precision",
                        type=str,
                        default="bf16",
                        choices=["no", "fp16", "bf16"],
                        help="Whether to use mixed precision. Choose between no, fp16 and bf16.",
                        )

    parser.add_argument("--batch_size",
                        type=int,
                        default=16,
                        help="Batch size per GPU/TPU core/CPU for training.",
                        )

    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.",
                        )

    parser.add_argument("--num_training_steps",
                        type=int,
                        default=100000,
                        help="Total number of training steps to perform.",
                        )

    parser.add_argument("--max_grad_norm",
                        type=float,
                        default=1.0,
                        help="Maximum gradient norm for gradient clipping.",
                        )

    parser.add_argument("--lr_scheduler_type",
                        type=str,
                        default="cosine",
                        help="Type of learning rate scheduler to use.",
                        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
                        )

    parser.add_argument("--num_warmup_steps",
                        type=int,
                        default=1000,
                        help="Number of steps for the warmup in the lr scheduler.",
                        )

    parser.add_argument("--evaluation_interval",
                        type=int,
                        default=2500,
                        help="Number of steps between evaluations.",
                        )

    parser.add_argument("--checkpoint_interval",
                        type=int,
                        default=2500,
                        help="Number of steps between model checkpoints.",
                        )

    parser.add_argument("--learning_rate",
                        type=float,
                        default=5e-5,
                        help="Max learning rate.",
                        )

    parser.add_argument("--weight_decay",
                        type=float,
                        default=0.05,
                        help="Weight decay to use.",
                        )

    parser.add_argument("--log_wandb",
                        default=False,
                        help="Whether to log metrics and model checkpoints to Weights & Biases.",
                        action=argparse.BooleanOptionalAction,
                        )

    args = parser.parse_args()
    return args


| """ |
| python finetune_sft.py \ |
| --experiment_name my_sft_experiment \ |
| --working_dir ./experiments \ |
| --path_to_pretrained_checkpoint ./pretrained_models/modernbert_pretrained \ |
| --seed 42 \ |
| --hf_model_name answerdotai/ModernBERT-base \ |
| --path_to_prepped_data ./data/tokenized_sft_dataset \ |
| --num_workers 24 \ |
| --mixed_precision bf16 \ |
| --batch_size 16 \ |
| --gradient_accumulation_steps 1 \ |
| --num_training_steps 100000 \ |
| --max_grad_norm 1.0 \ |
| --lr_scheduler_type cosine \ |
| --num_warmup_steps 1000 \ |
| --evaluation_interval 2500 \ |
| --checkpoint_interval 2500 \ |
| --learning_rate 5e-5 \ |
| --weight_decay 0.05 \ |
| --log_wandb \ |
| """ |
|
|
def seed_everything(seed: int):
    import random, os
    import numpy as np
    import torch

    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark must stay off for deterministic behavior
    torch.backends.cudnn.benchmark = False


args = parse_args()

# Seed from the CLI argument so --seed actually takes effect.
seed_everything(args.seed)


path_to_experiment = os.path.join(args.working_dir, args.experiment_name)
os.makedirs(path_to_experiment, exist_ok=True)
accelerator = Accelerator(
    mixed_precision=args.mixed_precision,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    log_with="wandb" if args.log_wandb else None,
    project_dir=path_to_experiment,
)


if args.log_wandb:
    accelerator.init_trackers(args.experiment_name, config=vars(args))


tokenizer = get_tokenizer(args.hf_model_name)


model = AutoModelForMaskedLM.from_pretrained(args.hf_model_name)
model.resize_token_embeddings(len(tokenizer))

# Optionally load a pretrained checkpoint (safetensors file or a torch .pt state dict).
if args.path_to_pretrained_checkpoint is not None:
    if args.path_to_pretrained_checkpoint.endswith(".safetensors"):
        state_dict = load_file(args.path_to_pretrained_checkpoint)
    else:
        state_dict = torch.load(args.path_to_pretrained_checkpoint, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)


compile_model = False
if compile_model:
    model = torch.compile(model)


model.train()
model.qconfig = get_default_qat_qconfig("x86")


for name, module in model.named_modules():
    if isinstance(module, torch.nn.Embedding):
        module.qconfig = None


model = prepare_qat(model, inplace=False)

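# Optional sanity check (assumption: the eager-mode QAT swap attaches a
# `weight_fake_quant` attribute to converted modules such as torch.nn.qat.Linear).
num_fake_quant_modules = sum(1 for m in model.modules() if hasattr(m, "weight_fake_quant"))
accelerator.print(f"Modules carrying fake-quant observers: {num_fake_quant_modules}")
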

model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
accelerator.print(f"Number of trainable parameters: {params}")


batch_size = args.batch_size


tokenized_data = load_from_disk(args.path_to_prepped_data)
train_dataloader = DataLoader(tokenized_data["train"],
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              collate_fn=SFTCollator(args.hf_model_name),
                              drop_last=True)


eval_dataloader = DataLoader(tokenized_data["test"],
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=args.num_workers,
                             collate_fn=SFTCollator(args.hf_model_name),
                             drop_last=True)
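
# Note: SFTCollator (from data_utils) is assumed to return batches with at least
# `input_ids` and `query_mask` (1 on answer/response tokens, 0 on prompt tokens);
# the training loop below only masks and scores positions where query_mask == 1.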

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=args.learning_rate,
    weight_decay=args.weight_decay,
)


scheduler = get_scheduler(
    name=args.lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=args.num_warmup_steps,
    num_training_steps=args.num_training_steps,
)


loss_func = nn.CrossEntropyLoss(reduction="none")


model, optimizer, train_dataloader, eval_dataloader, scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader, scheduler
)


train = True
global_step = 0
progress_bar = tqdm(range(args.num_training_steps), disable=not accelerator.is_local_main_process)

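# Training objective (masked-diffusion style SFT):
#   1. sample a corruption level t ~ U(0, 1] per sequence,
#   2. mask each answer token independently with probability t,
#   3. predict the original token at every masked position,
#   4. weight the per-token loss by 1/t and normalize by the answer length,
# so that all corruption levels contribute on a comparable scale.
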

while train:
    model.train()
    for batch in train_dataloader:

        input_ids = batch["input_ids"]
        query_mask = batch["query_mask"]

        batch_size, seq_len = input_ids.size()
        attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long, device=accelerator.device)

        t = torch.rand(batch_size, 1, device=accelerator.device)
        t = t.expand(batch_size, seq_len).clamp_min(1e-5)
        mask = torch.bernoulli(t).bool()

        mask = mask * query_mask
        mask = mask.bool()

        masked_input_ids = input_ids.masked_fill(mask, tokenizer.mask_token_id)
        labels = input_ids.masked_fill(~mask, -100)

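        # The cross-entropy below is computed per token (reduction="none"), rescaled
        # by 1/t (importance weight for the sampled mask rate), and divided by the
        # number of answer tokens before averaging over the batch.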
        with accelerator.accumulate(model):
            logits = model(input_ids=masked_input_ids,
                           attention_mask=attention_mask)["logits"]

            num_classes = logits.size(-1)
            loss = loss_func(logits.view(batch_size * seq_len, num_classes),
                             labels.flatten())

            loss = loss.reshape(batch_size, seq_len) / t

            answer_lengths = query_mask.sum(dim=1, keepdim=True)
            answer_lengths = torch.clamp(answer_lengths, min=1)
            loss = loss / answer_lengths

            loss = loss.sum(dim=1).mean()

            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        if accelerator.is_local_main_process:
            progress_bar.update(1)

        # Step counting and logging run on every process so that the evaluation,
        # checkpoint, and stopping intervals stay in sync.
        global_step += 1
        accelerator.log({"train_loss": loss.item(),
                         "lr": scheduler.get_last_lr()[0],
                         }, step=global_step)

        if global_step % args.evaluation_interval == 0:
            model.eval()

            log = {"eval_loss": 0.0}
            eval_steps = 0
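
            # The same random corruption is re-sampled at every evaluation pass, so
            # eval_loss is a stochastic estimate and will fluctuate a little between runs.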
            for batch in tqdm(eval_dataloader, desc="Evaluating", disable=not accelerator.is_local_main_process):
                with torch.no_grad():

                    input_ids = batch["input_ids"]
                    query_mask = batch["query_mask"]

                    batch_size, seq_len = input_ids.size()
                    attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long, device=accelerator.device)

                    t = torch.rand(batch_size, 1, device=accelerator.device)
                    t = t.expand(batch_size, seq_len).clamp_min(1e-5)
                    mask = torch.bernoulli(t).bool()

                    mask = mask * query_mask
                    mask = mask.bool()

                    masked_input_ids = input_ids.masked_fill(mask, tokenizer.mask_token_id)
                    labels = input_ids.masked_fill(~mask, -100)

                    logits = model(input_ids=masked_input_ids,
                                   attention_mask=attention_mask)["logits"]

                    num_classes = logits.size(-1)
                    loss = loss_func(logits.view(batch_size * seq_len, num_classes),
                                     labels.flatten())

                    loss = loss.reshape(batch_size, seq_len) / t

                    answer_lengths = query_mask.sum(dim=1, keepdim=True)
                    answer_lengths = torch.clamp(answer_lengths, min=1)
                    loss = loss / answer_lengths

                    loss = loss.sum(dim=1).mean()

                    log["eval_loss"] += loss.item()
                    eval_steps += 1

            log["eval_loss"] /= eval_steps
            accelerator.log(log, step=global_step)
            model.train()

        if global_step % args.checkpoint_interval == 0:
            if accelerator.is_local_main_process:
                unwrapped_model = accelerator.unwrap_model(model)
                final_dir = os.path.join(path_to_experiment, "checkpoint_latest")
                os.makedirs(final_dir, exist_ok=True)

                save_path = os.path.join(final_dir, "qat_model_unconverted.pt")
                accelerator.save(unwrapped_model.state_dict(), save_path)

                accelerator.print(f"QAT Checkpoint saved to {save_path} (Convert this offline!)")
                tokenizer.save_pretrained(final_dir)
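
                # The checkpoint above still holds float weights plus fake-quant/observer
                # state; to turn it into an actual INT8 model offline, rebuild the model,
                # run prepare_qat() with the same qconfig, load this state_dict into the
                # prepared copy, and only then call convert().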

        if global_step >= args.num_training_steps:
            train = False
            break


if accelerator.is_local_main_process:
    unwrapped_model = accelerator.unwrap_model(model)
    final_dir = os.path.join(path_to_experiment, "final_model")
    os.makedirs(final_dir, exist_ok=True)

    unwrapped_model.to("cpu")
    unwrapped_model.eval()

    accelerator.print("Converting QAT model to INT8...")
    quantized_model = convert(unwrapped_model, inplace=False)

    save_path = os.path.join(final_dir, "quantized_model.pt")
    torch.save(quantized_model.state_dict(), save_path)

    accelerator.print(f"Quantized model saved to {save_path}")

    tokenizer.save_pretrained(final_dir)
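
# Example only (not executed here): reloading the INT8 checkpoint later requires
# rebuilding and converting the model with the same qconfig before loading weights:
#
#   model = AutoModelForMaskedLM.from_pretrained(args.hf_model_name)
#   model.resize_token_embeddings(len(tokenizer))
#   model.qconfig = get_default_qat_qconfig("x86")
#   for _, module in model.named_modules():
#       if isinstance(module, torch.nn.Embedding):
#           module.qconfig = None
#   model = prepare_qat(model.train(), inplace=False)
#   model = convert(model.eval(), inplace=False)
#   model.load_state_dict(torch.load(save_path))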