"""Centralized Alpaca LoRA finetuning for post-pruned models."""
|
|
import argparse
import itertools
import json
import os
from contextlib import nullcontext
from pathlib import Path
from types import SimpleNamespace

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig
from transformers.models.auto.configuration_auto import CONFIG_MAPPING

import ppl_eval
from fuse_layers_data import FixedSeqDataset, load_instruction_records
from fuse_layers_distill import LoRALinear, apply_lora_adapters, merge_lora_adapters

# tqdm is optional; fall back to plain iteration when it is unavailable.
try:
    from tqdm import tqdm
except ImportError:
    tqdm = None


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Run centralized Alpaca LoRA finetuning.")
    parser.add_argument("--base_model", required=True, help="Path or HF model id to finetune")
    parser.add_argument("--output_dir", required=True, help="Directory to save merged model")
    parser.add_argument("--device", default="cuda", help="Training device")
    parser.add_argument(
        "--dtype",
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
        help="Model load/training dtype",
    )
    parser.add_argument("--trust_remote_code", action="store_true", help="Enable trust_remote_code")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    parser.add_argument(
        "--instruction_dataset",
        default="yahma/alpaca-cleaned",
        help="HF dataset name for Alpaca-style instruction data",
    )
    parser.add_argument("--instruction_config", default=None, help="Optional dataset config")
    parser.add_argument("--instruction_split", default="train", help="Dataset split")
    parser.add_argument("--instruction_field_instruction", default="instruction", help="Record field holding the instruction")
    parser.add_argument("--instruction_field_input", default="input", help="Record field holding the optional input")
    parser.add_argument("--instruction_field_output", default="output", help="Record field holding the target output")
    parser.add_argument("--max_samples", type=int, default=0, help="Limit instruction samples (0 = all)")
    parser.add_argument("--seq_len", type=int, default=1024, help="Training sequence length")
    parser.add_argument("--batch_size", type=int, default=64, help="Global batch size")
    parser.add_argument("--micro_batch_size", type=int, default=4, help="Per-step micro-batch size")
    parser.add_argument("--epochs", type=float, default=1.0, help="Training epochs (may be fractional)")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay")
    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Gradient clipping norm")
    parser.add_argument("--log_steps", type=int, default=100, help="Log every N optimizer steps")
    parser.add_argument(
        "--save_steps",
        type=int,
        default=200,
        help="Save LoRA adapter checkpoints every N optimizer steps (0 = disable)",
    )
    parser.add_argument(
        "--no_wikitext2_ppl_on_log",
        dest="wikitext2_ppl_on_log",
        action="store_false",
        help="Disable Wikitext-2 perplexity evaluation at loss log steps",
    )
    parser.set_defaults(wikitext2_ppl_on_log=True)
    parser.add_argument("--wikitext2_ppl_seq_len", type=int, default=128, help="Eval sequence length")
    parser.add_argument("--wikitext2_ppl_batch_size", type=int, default=8, help="Eval batch size")
    parser.add_argument("--wikitext2_ppl_max_batches", type=int, default=None, help="Cap eval batches (None = all)")

    parser.add_argument("--lora_rank", type=int, default=8, help="LoRA rank")
    parser.add_argument("--lora_alpha", type=float, default=16.0, help="LoRA alpha")
    parser.add_argument("--lora_dropout", type=float, default=0.0, help="LoRA dropout")
    parser.add_argument(
        "--lora_target_modules",
        nargs="*",
        default=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"],
        help="Linear module suffixes to LoRA-wrap",
    )

    return parser.parse_args()


def get_dtype(name: str) -> torch.dtype:
    return {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }[name]


def seed_all(seed: int) -> None:
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
|
def normalize_config(config):
    """Align config fields that can be left stale by layer pruning."""
    layer_types = getattr(config, "layer_types", None)
    num_hidden_layers = getattr(config, "num_hidden_layers", None)
    if layer_types is not None and num_hidden_layers is not None and len(layer_types) != num_hidden_layers:
        config.layer_types = list(layer_types[:num_hidden_layers])
    if getattr(config, "_attn_implementation", None) is None:
        config._attn_implementation = "eager"
    return config


def load_normalized_config(base_model: str, trust_remote_code: bool):
    config_dict, unused_kwargs = PretrainedConfig.get_config_dict(base_model, trust_remote_code=trust_remote_code)
    layer_types = config_dict.get("layer_types")
    num_hidden_layers = config_dict.get("num_hidden_layers")
    if layer_types is not None and num_hidden_layers is not None and len(layer_types) != num_hidden_layers:
        config_dict["layer_types"] = list(layer_types[:num_hidden_layers])
    if config_dict.get("_attn_implementation") is None:
        config_dict["_attn_implementation"] = "eager"
    model_type = config_dict["model_type"]
    config_class = CONFIG_MAPPING[model_type]
    config = config_class.from_dict(config_dict, **unused_kwargs)
    return normalize_config(config)


def validate_local_model_dir(base_path: Path) -> None:
    if not base_path.exists() or not base_path.is_dir():
        return

    has_config = (base_path / "config.json").is_file()
    has_weights = any(
        (base_path / name).is_file()
        for name in (
            "model.safetensors",
            "model.safetensors.index.json",
            "pytorch_model.bin",
            "pytorch_model.bin.index.json",
        )
    )
    if has_config and has_weights:
        return

    raise SystemExit(
        "Local --base_model points to an incomplete HF model directory: "
        f"{base_path}. Expected at least config.json and model weights. "
        "Set --base_model/BASE_MODEL to a saved HF model directory."
    )


def load_base_artifacts(args: argparse.Namespace):
    base_path = Path(args.base_model)
    if base_path.is_file() and base_path.suffix == ".bin":
        # A .bin checkpoint bundles the pickled model and tokenizer objects.
        checkpoint = torch.load(str(base_path), map_location="cpu", weights_only=False)
        if not isinstance(checkpoint, dict) or "model" not in checkpoint or "tokenizer" not in checkpoint:
            raise SystemExit("Expected a .bin checkpoint dict with `model` and `tokenizer` entries.")
        model = checkpoint["model"]
        tokenizer = checkpoint["tokenizer"]
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
        return model, tokenizer

    validate_local_model_dir(base_path)
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=args.trust_remote_code)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
    config = load_normalized_config(args.base_model, trust_remote_code=args.trust_remote_code)
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        config=config,
        torch_dtype=get_dtype(args.dtype),
        trust_remote_code=args.trust_remote_code,
    )
    return model, tokenizer


def build_training_loader(args: argparse.Namespace, tokenizer) -> torch.utils.data.DataLoader:
    num_samples = max(args.max_samples, 0)  # 0 means "use all records"
    records = load_instruction_records(args, num_samples)
    if not records:
        raise SystemExit("No instruction records were loaded.")
    dataset = FixedSeqDataset(records, tokenizer, args.seq_len)
    return torch.utils.data.DataLoader(dataset, batch_size=args.micro_batch_size, shuffle=True)


def save_lora_adapters(
    model: torch.nn.Module, args: argparse.Namespace, subdir: str = "lora_adapter"
) -> str:
    adapter_dir = os.path.join(args.output_dir, subdir)
    os.makedirs(adapter_dir, exist_ok=True)

    adapter_state = {}
    adapter_modules = {}
    for module_name, module in model.named_modules():
        if not isinstance(module, LoRALinear):
            continue
        adapter_modules[module_name] = {
            "rank": module.rank,
            "alpha": module.alpha,
            "scaling": module.scaling,
            "dropout": getattr(module.dropout, "p", 0.0),
            "base_layer_class": type(module.base).__name__,
            "in_features": module.base.in_features,
            "out_features": module.base.out_features,
        }
        adapter_state[f"{module_name}.lora_A.weight"] = module.lora_A.weight.detach().cpu()
        adapter_state[f"{module_name}.lora_B.weight"] = module.lora_B.weight.detach().cpu()

    torch.save(adapter_state, os.path.join(adapter_dir, "adapter_model.bin"))
    with open(os.path.join(adapter_dir, "adapter_config.json"), "w", encoding="utf-8") as handle:
        json.dump(
            {
                "base_model": args.base_model,
                "lora_rank": args.lora_rank,
                "lora_alpha": args.lora_alpha,
                "lora_dropout": args.lora_dropout,
                "lora_target_modules": list(args.lora_target_modules),
                "batch_size": args.batch_size,
                "micro_batch_size": args.micro_batch_size,
                "grad_accum_steps": args.grad_accum_steps,
                "modules": adapter_modules,
            },
            handle,
            indent=2,
        )
    return adapter_dir
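# Reloading a saved adapter is roughly the reverse of save_lora_adapters. A
# sketch, assuming apply_lora_adapters recreates LoRALinear wrappers at the
# same module paths (not verified against fuse_layers_distill):
#
#     apply_lora_adapters(model, lora_args)
#     state = torch.load(os.path.join(adapter_dir, "adapter_model.bin"), map_location="cpu")
#     missing, unexpected = model.load_state_dict(state, strict=False)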
|
|
|
|
def prepare_wikitext2_eval(args: argparse.Namespace, model, tokenizer):
    if not args.wikitext2_ppl_on_log:
        return None
    return ppl_eval.prepare_ppl_dataloaders(
        tokenizer=tokenizer,
        datasets=["wikitext"],
        configs=["wikitext-2-raw-v1"],
        split="test",
        text_field=None,
        num_samples=0,
        seq_len=args.wikitext2_ppl_seq_len,
        batch_size=args.wikitext2_ppl_batch_size,
        seed=args.seed,
        shuffle=False,
        model_family="auto",
        add_bos="auto",
        cache_dir=None,
        num_workers=0,
        model=model,
    )


def train(model: torch.nn.Module, dataloader, args: argparse.Namespace, wikitext2_eval_dataloaders=None) -> dict:
    lora_args = SimpleNamespace(
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        lora_target_modules=args.lora_target_modules,
        lora_respect_exclude_pairs=False,
        layer_path=None,
        exclude_pairs=None,
    )
    lora_modules = apply_lora_adapters(model, lora_args)
    lora_params = [param for module in lora_modules for param in module.lora_parameters()]

    optimizer = torch.optim.AdamW(
        lora_params,
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
    )
    model.train()

    device = torch.device(args.device)
    device_type = device.type
    amp_dtype = None
    if args.dtype == "float16":
        amp_dtype = torch.float16
    elif args.dtype == "bfloat16":
        amp_dtype = torch.bfloat16
    use_amp = amp_dtype is not None and device_type == "cuda"
    # Loss scaling is only needed for float16; bfloat16 has enough dynamic range.
    use_scaler = use_amp and amp_dtype == torch.float16
    scaler = torch.cuda.amp.GradScaler() if use_scaler else None

    # Fractional epochs run a truncated final pass over the dataloader.
    full_epochs = int(args.epochs)
    fractional = args.epochs - full_epochs
    epoch_plan = [None] * full_epochs
    if fractional > 1e-8:
        frac_batches = max(1, int(round(fractional * len(dataloader))))
        epoch_plan.append(frac_batches)

    optimizer.zero_grad(set_to_none=True)
    optimizer_step = 0
    seen_batches = 0
    last_loss = None
    ppl_history = []

    for epoch_idx, max_batches in enumerate(epoch_plan, start=1):
        iterator = dataloader if max_batches is None else itertools.islice(dataloader, max_batches)
        if tqdm is not None:
            iterator = tqdm(iterator, desc=f"LoRA epoch {epoch_idx}", unit="batch", total=max_batches)
        for batch in iterator:
            input_ids = batch[0].to(device)
            attention_mask = batch[1].to(device)

            autocast_ctx = (
                torch.autocast(device_type=device_type, dtype=amp_dtype)
                if use_amp
                else nullcontext()
            )
            with autocast_ctx:
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False)
                # Next-token prediction: drop the last logit, shift labels by one.
                logits = outputs.logits[:, :-1, :].contiguous()
                labels = input_ids[:, 1:].contiguous()
                mask = attention_mask[:, 1:].contiguous()
                ce_flat = torch.nn.functional.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    labels.view(-1),
                    reduction="none",
                )
                denom = mask.sum()
                if denom.item() == 0:
                    # Skip micro-batches that are all padding.
                    continue
                loss = (ce_flat * mask.reshape(-1).to(ce_flat.dtype)).sum() / denom

            last_loss = float(loss.detach().item())
            scaled_loss = loss / max(args.grad_accum_steps, 1)
            if use_scaler:
                scaler.scale(scaled_loss).backward()
            else:
                scaled_loss.backward()

            seen_batches += 1
            if seen_batches % max(args.grad_accum_steps, 1) != 0:
                continue

            if args.max_grad_norm is not None:
                if use_scaler:
                    scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(lora_params, args.max_grad_norm)
            if use_scaler:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            optimizer_step += 1

            if args.log_steps and optimizer_step % args.log_steps == 0:
                print(f"[loratune] step={optimizer_step} loss={last_loss:.6f}")
                if wikitext2_eval_dataloaders is not None:
                    prev_mode = model.training
                    model.eval()
                    ppl_results = ppl_eval.evaluate_ppl_dataloaders(
                        model,
                        wikitext2_eval_dataloaders,
                        args.device,
                        max_batches=args.wikitext2_ppl_max_batches,
                    )
                    ppl_history.append({"step": optimizer_step, "ppl": ppl_results})
                    print(f"[loratune] ppl step={optimizer_step} {ppl_results}")
                    if prev_mode:
                        model.train()

            if args.save_steps and optimizer_step % args.save_steps == 0:
                checkpoint_dir = save_lora_adapters(
                    model,
                    args,
                    subdir=os.path.join("checkpoints", f"step_{optimizer_step}"),
                )
                print(f"[loratune] saved adapter checkpoint to {checkpoint_dir}")

    adapter_dir = save_lora_adapters(model, args)
    # Fold the adapters into the base weights so the result saves as a plain HF model.
    merge_lora_adapters(model)
    return {
        "adapter_dir": adapter_dir,
        "optimizer_steps": optimizer_step,
        "seen_batches": seen_batches,
        "last_loss": last_loss,
        "wikitext2_ppl_history": ppl_history,
    }


def main() -> None:
    args = parse_args()
    if args.batch_size < 1:
        raise SystemExit("--batch_size must be >= 1")
    if args.micro_batch_size < 1:
        raise SystemExit("--micro_batch_size must be >= 1")
    # Integer division: if batch_size is not a multiple of micro_batch_size, the
    # effective global batch is grad_accum_steps * micro_batch_size.
    args.grad_accum_steps = args.batch_size // args.micro_batch_size
    if args.grad_accum_steps < 1:
        raise SystemExit("--batch_size must be >= --micro_batch_size")

    seed_all(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    model, tokenizer = load_base_artifacts(args)
    if args.dtype != "float32":
        model = model.to(get_dtype(args.dtype))
    model.to(args.device)

    dataloader = build_training_loader(args, tokenizer)
    wikitext2_eval_dataloaders = prepare_wikitext2_eval(args, model, tokenizer)
    metrics = train(model, dataloader, args, wikitext2_eval_dataloaders=wikitext2_eval_dataloaders)

    model.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    with open(os.path.join(args.output_dir, "loratune_metrics.json"), "w", encoding="utf-8") as handle:
        json.dump(
            {
                "base_model": args.base_model,
                "instruction_dataset": args.instruction_dataset,
                "seq_len": args.seq_len,
                "batch_size": args.batch_size,
                "micro_batch_size": args.micro_batch_size,
                "grad_accum_steps": args.grad_accum_steps,
                "epochs": args.epochs,
                "learning_rate": args.learning_rate,
                "save_steps": args.save_steps,
                "lora_rank": args.lora_rank,
                "lora_alpha": args.lora_alpha,
                "lora_dropout": args.lora_dropout,
                **metrics,
            },
            handle,
            indent=2,
        )


if __name__ == "__main__":
    main()
|
|