#!/usr/bin/env python3 """Fuse adjacent layers via attention head alignment + Fisher-barycentric merge.""" import argparse import copy import gc import json import os import random from dataclasses import dataclass from typing import Dict, List, Optional, Set, Tuple import torch try: import numpy as np except Exception: # pragma: no cover - optional dependency np = None try: from transformers import AutoModelForCausalLM, AutoTokenizer except Exception as exc: # pragma: no cover - fail early with clear error raise SystemExit("transformers is required: pip install transformers") from exc try: import ppl_eval except Exception as exc: # pragma: no cover - optional dependency raise SystemExit("ppl_eval.py is required (missing or invalid)") from exc from fuse_layers_data import ( FixedSeqDataset, build_token_chunks, expand_dataset_configs, load_instruction_records, load_texts, load_texts_from_datasets, ) from common_lm_data import SharedLMDataSpec, build_chunks, build_dataloader from fuse_layers_distill import ( commutator_precondition, compute_fisher_gate_priors, distill_reparam_merge, lora_ce_finetune, ) from fuse_layers_model import ( apply_norm_policy, build_head_permutation, clone_state_dict, compute_fisher, compute_head_means, decrement_config, drop_layer, find_attention_module, find_colon_modules, find_layer_container, get_dtype, get_norm_pair, merge_layers, permute_attention_heads, ) from fuse_layers_select import select_layer_auto from progressive_loader import load_causal_lm, load_progressive_model def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Fuse layer i and i+1 using head alignment + Fisher barycenter." ) parser.add_argument("--model", required=True, help="HF model id or local path") parser.add_argument( "--model_cache_dir", default=None, help="Optional cache dir for model/tokenizer downloads", ) parser.add_argument( "--layer", type=str, default="auto", help="Layer index i (int) or 'auto' to select via auto metric", ) parser.add_argument( "--selection_method", choices=["dwce", "sequential"], default="dwce", help=( "Pair selection policy for progressive pruning. " "'dwce' uses downstream-weighted composition error; " "'sequential' always takes the next available pair." ), ) parser.add_argument( "--exclude_pairs", "--exclude_layers", nargs="*", default=None, dest="exclude_pairs", help=( "Exclude pair indices from consideration for any fusion. Indices refer to " "pair start positions in [0..N-2]. Negative indices count from the end " "(-1 = last pair, -2 = second last). Accepts space- or comma-separated ints. " "Alias: --exclude_layers (deprecated)." ), ) parser.add_argument( "--output_dir", required=True, help="Directory to write fused model" ) parser.add_argument( "--dataset", action="append", default=[], help=( "HF dataset name (repeatable). Optional if using --text or --text_file." ), ) parser.add_argument( "--dataset_config", action="append", default=[], help="Optional dataset config (repeatable or single shared config).", ) parser.add_argument( "--dataset_split", default="train", help="Dataset split to use (default: train)", ) parser.add_argument( "--dataset_text_field", default=None, help="Text field in dataset (default: auto-detect, applies to all datasets)", ) parser.add_argument( "--text", action="append", default=[], help="Inline text samples (can pass multiple)", ) parser.add_argument( "--text_file", default=None, help="Path to a text file for calibration data", ) parser.add_argument( "--num_samples", type=int, default=128, help="Number of token sequences to use", ) parser.add_argument( "--target_tokens", type=int, default=0, help="Target token budget for common_lm_data-backed calibration/distillation (0 = disabled)", ) parser.add_argument("--seq_len", type=int, default=256, help="Sequence length") parser.add_argument("--batch_size", type=int, default=2, help="Batch size") parser.add_argument( "--device", default="cuda" if torch.cuda.is_available() else "cpu", help="Device for model + compute", ) parser.add_argument( "--dtype", default="auto", choices=["auto", "float32", "float16", "bfloat16"], help="Model dtype", ) parser.add_argument( "--layer_path", default=None, help="Override layer attribute path (e.g., model.layers)", ) parser.add_argument( "--fisher_mode", default="tensor", choices=["tensor", "param"], help="Fisher approximation granularity", ) parser.add_argument( "--no_head_permute", action="store_true", help=( "Deprecated alias for --no_head_permute_merge. " "Disables merge-stage head permutation only." ), ) parser.add_argument( "--no_head_permute_merge", action="store_true", help="Disable attention head permutation alignment before merge", ) parser.add_argument( "--no_head_permute_select", action="store_true", help="Disable attention head permutation alignment during auto selection", ) parser.add_argument("--eps", type=float, default=1e-8, help="Stability epsilon") parser.add_argument("--seed", type=int, default=0, help="Random seed") parser.add_argument( "--trust_remote_code", action="store_true", help="Allow custom model code from hub", ) parser.add_argument( "--save_metadata", action="store_true", help="Backward-compatible no-op; metadata is always written.", ) parser.add_argument( "--skip_eval", action="store_true", help="Skip pre/post perplexity evaluation", ) parser.add_argument( "--eval_dataset", action="append", default=[], help="Evaluation dataset name (repeatable). Defaults to wikitext.", ) parser.add_argument( "--eval_dataset_config", action="append", default=[], help="Evaluation dataset config (repeatable or single shared config).", ) parser.add_argument( "--eval_split", default="test", help="Evaluation dataset split (default: test)", ) parser.add_argument( "--eval_text_field", default=None, help="Evaluation text field override (default: auto-detect)", ) parser.add_argument( "--eval_model_family", type=str, choices=["auto", "llama", "qwen"], default="auto", help="Model family for BOS handling during eval", ) parser.add_argument( "--eval_add_bos", type=str, choices=["auto", "always", "never"], default="auto", help="Whether to prepend BOS to each eval sample", ) parser.add_argument( "--eval_num_samples", type=int, default=0, help="Number of token sequences per eval dataset (0 = all)", ) parser.add_argument( "--eval_seq_len", type=int, default=2048, help="Sequence length for eval", ) parser.add_argument( "--eval_batch_size", type=int, default=None, help="Batch size for eval (defaults to --batch_size)", ) parser.add_argument( "--eval_max_batches", type=int, default=None, help="Optional max number of eval batches per dataset", ) parser.add_argument( "--eval_cache_dir", default=None, help="Optional datasets cache dir for eval", ) parser.add_argument( "--eval_num_workers", type=int, default=0, help="Eval DataLoader workers", ) parser.add_argument( "--eval_device", default=None, help="Device for eval (defaults to --device)", ) parser.add_argument( "--skip_distill", action="store_true", help="Skip reparameterized distillation after head alignment/Fisher setup", ) parser.add_argument( "--distill_calib_samples", type=int, default=256, help="Number of distillation sequences from calibration datasets", ) parser.add_argument( "--distill_inst_samples", type=int, default=0, help="Number of distillation sequences from instruction dataset (0 = all)", ) parser.add_argument( "--distill_seq_len", type=int, default=512, help="Sequence length for distillation", ) parser.add_argument( "--distill_batch_size", type=int, default=2, help="Batch size for distillation", ) parser.add_argument( "--distill_epochs", type=float, default=1.0, help="Number of distillation epochs (float allowed, e.g. 0.5)", ) parser.add_argument( "--distill_lr", type=float, default=1e-4, help="Learning rate for distillation", ) parser.add_argument( "--distill_method", choices=["reparam"], default="reparam", help="Distillation strategy (reparam only).", ) parser.add_argument( "--distill_kl_weight", type=float, default=1e-2, help="Weight for KL loss on logits", ) parser.add_argument( "--distill_kl_temp", type=float, default=4.0, help="Temperature for KL distillation on logits", ) parser.add_argument( "--distill_hidden_mse_weight", type=float, default=1.0, help="Weight for hidden-state MSE in reparam distillation (0 disables it)", ) parser.add_argument( "--distill_attn_mse_weight", type=float, default=0.0, help="Weight for auxiliary attention-output MSE in reparam distillation", ) parser.add_argument( "--distill_mlp_mse_weight", type=float, default=0.0, help="Weight for auxiliary MLP-output MSE in reparam distillation", ) parser.add_argument( "--reparam_eta", type=float, default=1e-2, help="Eta: ||lambda - lambda_gate||^2 regularizer weight for --distill_method reparam", ) parser.add_argument( "--reparam_gamma", type=float, default=1e-4, help="Gamma: ||U - U0||^2 regularizer weight for --distill_method reparam", ) parser.add_argument( "--reparam_attn_reg_scale", type=float, default=1.0, help="Relative scale applied to attention-parameter reparam regularizers", ) parser.add_argument( "--reparam_mlp_reg_scale", type=float, default=1.0, help="Relative scale applied to MLP-parameter reparam regularizers", ) parser.add_argument( "--reparam_param_subset", type=str, choices=["all", "mlp", "attn"], default="all", help="Restrict reparam merge/recovery capacity to only this parameter family", ) parser.add_argument( "--norm_policy", type=str, choices=["hybrid", "merge_all", "copy_n1", "copy_n1_n2"], default="hybrid", help="Norm merge policy (default: hybrid)", ) parser.add_argument( "--distill_weight_decay", type=float, default=0.0, help="Weight decay for distillation", ) parser.add_argument( "--distill_max_grad_norm", type=float, default=1.0, help="Max grad norm for distillation", ) parser.add_argument( "--distill_grad_accum_steps", type=int, default=1, help="Gradient accumulation steps for distillation", ) parser.add_argument( "--distill_log_steps", type=int, default=100, help="Log distillation loss every N steps", ) parser.add_argument( "--distill_eval_every", type=int, default=0, help="Evaluate PPL every N distill steps (0 = disable)", ) parser.add_argument( "--distill_eval_max_batches", type=int, default=None, help="Max eval batches per dataset during distill (default: all)", ) parser.add_argument( "--distill_teacher_device", default=None, help="Device for teacher model during distillation (defaults to --device)", ) parser.add_argument( "--comm_enabled", action="store_true", help=( "Enable commutator-style preconditioning before each progressive " "cycle's fusion." ), ) parser.add_argument( "--comm_include_cycle1", action="store_true", help="Run commutator preconditioning for cycle 1 as well (default: skip cycle 1).", ) parser.add_argument( "--comm_topk", type=int, default=1, help="Top-K lowest-score pairs used as the commutator candidate set", ) parser.add_argument( "--comm_sample_eta", type=float, default=0.5, help="Mixture weight between uniform and score-biased candidate sampling", ) parser.add_argument( "--comm_sample_dwce_scale", type=float, default=1.0, help="Scale c in softmax(-c * score(i)) for commutator pair sampling", ) parser.add_argument( "--comm_temp", type=float, default=2.0, help="Temperature for teacher-anchor KL in commutator preconditioning", ) parser.add_argument( "--comm_steps_ratio", type=float, default=0.1, help="Run this fraction of distillation optimizer steps for commutator phase", ) parser.add_argument( "--comm_lr_scale", type=float, default=0.1, help="Commutator LR = --distill_lr * this scale", ) parser.add_argument( "--comm_train_mode", choices=["lora", "full"], default="lora", help=( "Commutator trainable parameter mode: " "'lora' updates LoRA adapters on sampled receiver layers; " "'full' updates full receiver-layer weights." ), ) parser.add_argument( "--comm_interaction_mode", choices=["mse", "relative"], default="relative", help="Interaction loss form: plain MSE or relative MSE", ) parser.add_argument( "--comm_interaction_eps", type=float, default=1e-8, help="Epsilon for relative commutator interaction normalization", ) parser.add_argument( "--comm_mu", type=float, default=None, help=( "Weight for interaction loss. Defaults to 0.1 for --comm_interaction_mode=mse " "and 0.5 for --comm_interaction_mode=relative." ), ) parser.add_argument( "--comm_mu_auto", action="store_true", help="Enable automatic mu scaling via gradient-norm balancing", ) parser.add_argument( "--comm_mu_auto_rho", type=float, default=0.1, help="Target anchor-to-interaction gradient ratio constant for auto-mu", ) parser.add_argument( "--comm_mu_auto_eps", type=float, default=1e-8, help="Numerical epsilon in auto-mu denominator", ) parser.add_argument( "--comm_log_steps", type=int, default=50, help="Log commutator preconditioning loss every N optimizer steps", ) parser.add_argument( "--comm_skip_post_reselect", action="store_true", help=( "Keep the pre-comm selected fusion pair and skip recomputing " "selection after commutator preconditioning." ), ) parser.add_argument( "--redistrib_teacher_source", type=str, choices=["base_model", "previous_cycle"], default="base_model", help=( "Teacher source for commutator preconditioning teacher loading. " "'base_model' uses --model for all cycles; " "'previous_cycle' uses cycle-1 checkpoint (cycle 1 falls back to base_model)." ), ) parser.add_argument( "--lora_epochs", type=float, default=1.0, help="LoRA CE finetuning epochs after distill (0 = disable)", ) parser.add_argument( "--lora_rank", type=int, default=8, help="LoRA rank (r)", ) parser.add_argument( "--lora_alpha", type=float, default=16.0, help="LoRA alpha", ) parser.add_argument( "--lora_dropout", type=float, default=0.0, help="LoRA dropout", ) parser.add_argument( "--lora_kl_enabled", action="store_true", help="Add KL regularization between pre/post LoRA logits", ) parser.add_argument( "--lora_kl_weight", type=float, default=1e-1, help="KL weight for LoRA regularization", ) parser.add_argument( "--lora_kl_temp", type=float, default=4.0, help="Temperature for LoRA KL regularization", ) parser.add_argument( "--lora_target_modules", nargs="*", default=[ "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj", ], help="Module name suffixes to LoRA-wrap", ) parser.add_argument( "--lora_respect_exclude_pairs", action="store_true", help=( "When attaching LoRA adapters, skip linear modules under layers touched by " "--exclude_pairs (i and i+1 for each excluded pair)." ), ) parser.add_argument( "--lora_lr", type=float, default=1e-4, help="Learning rate for LoRA finetuning", ) parser.add_argument( "--lora_weight_decay", type=float, default=0.0, help="Weight decay for LoRA finetuning", ) parser.add_argument( "--lora_max_grad_norm", type=float, default=1.0, help="Max grad norm for LoRA finetuning", ) parser.add_argument( "--lora_grad_accum_steps", type=int, default=1, help="Gradient accumulation steps for LoRA finetuning", ) parser.add_argument( "--lora_log_steps", type=int, default=100, help="Log LoRA loss every N steps", ) parser.add_argument( "--lora_eval_every", type=int, default=0, help="Evaluate PPL every N LoRA steps (0 = disable)", ) parser.add_argument( "--lora_eval_max_batches", type=int, default=None, help="Max eval batches per dataset during LoRA (default: all)", ) parser.add_argument( "--instruction_dataset", default=None, help="HF dataset name for alpaca-style instruction data", ) parser.add_argument( "--instruction_config", default=None, help="Optional instruction dataset config", ) parser.add_argument( "--instruction_split", default="train", help="Instruction dataset split", ) parser.add_argument( "--instruction_field_instruction", default="instruction", help="Instruction field name", ) parser.add_argument( "--instruction_field_input", default="input", help="Optional input field name", ) parser.add_argument( "--instruction_field_output", default="output", help="Response/output field name", ) parser.add_argument( "--auto_max_batches", type=int, default=0, help="Max calibration batches for auto selection scoring (0 = all)", ) parser.add_argument( "--auto_metric", type=str, choices=[ "dwce", "cosine", "hybrid", "hybrid_cosine", "hybrid_global_rel", ], default="dwce", help=( "Auto pair scoring metric. 'dwce' uses downstream-weighted composition error; " "'cosine' uses average token-level cosine distance between adjacent layer outputs; " "'hybrid'/'hybrid_cosine' use DWCE to shortlist then adjacent cosine for final scoring; " "'hybrid_global_rel' uses DWCE to shortlist then reranks by the change in " "pair-to-final-layer cosine relation after surrogate fusion." ), ) parser.add_argument( "--auto_cosine_topk", type=int, default=3, help="Top-K DWCE candidates to rescore with cosine in --auto_metric=hybrid", ) parser.add_argument( "--auto_norm", type=str, choices=["relative", "none"], default="relative", help="Normalization mode for DWCE scoring (ignored for cosine)", ) parser.add_argument( "--auto_dwce_mode", type=str, choices=["separate", "shared"], default="separate", help=( "DWCE implementation for auto scoring. " "'separate' runs distinct Fisher and DWCE backward passes; " "'shared' reuses one backward pass and replays DWCE with cached gradients." ), ) parser.add_argument( "--num_progressive", type=int, default=0, help="Number of progressive fusions (>0 required)", ) parser.add_argument( "--resume_from_cycle", type=int, default=0, help=( "Resume from this completed cycle index. When > 0, --model should point " "to the saved full model directory for that cycle." ), ) parser.add_argument( "--save_full_model_cycles", nargs="*", default=[], help=( "Cycle indices whose full models should be saved. Requesting cycle c " "also saves cycle c-1 automatically (c=1 saves only cycle 1)." ), ) return parser.parse_args() def parse_exclude_pairs(exclude_raw: Optional[List[str]], num_pairs: int) -> List[int]: """Parse --exclude_pairs into normalized pair indices for the current model. Indices refer to the start of an adjacent pair (i, i+1) and must be in [0..N-2]. Negative indices count from the end (-1 = last pair). """ if not exclude_raw: return [] exclude: List[int] = [] for item in exclude_raw: if item is None: continue for part in str(item).split(","): part = part.strip() if not part: continue try: idx = int(part) except ValueError as exc: raise SystemExit("--exclude_pairs must contain integers.") from exc if idx < 0: idx = num_pairs + idx if 0 <= idx < num_pairs: exclude.append(idx) return sorted(set(exclude)) def parse_cycle_list(raw_values: Optional[List[str]]) -> List[int]: if not raw_values: return [] cycles: List[int] = [] for item in raw_values: if item is None: continue for part in str(item).split(","): part = part.strip() if not part: continue try: cycles.append(int(part)) except ValueError as exc: raise SystemExit( "--save_full_model_cycles must contain integers." ) from exc return cycles def resolve_full_model_save_cycles( requested_cycles: List[int], num_progressive: int ) -> Set[int]: resolved: Set[int] = set() for cycle in requested_cycles: if cycle <= 0 or cycle > num_progressive: raise SystemExit( "--save_full_model_cycles entries must be within [1, --num_progressive]." ) resolved.add(cycle) if cycle > 1: resolved.add(cycle - 1) return resolved def load_resume_metadata(model_path: str) -> Optional[Dict[str, object]]: resume_meta_path = os.path.join(model_path, "resume_info.json") if not os.path.exists(resume_meta_path): return None with open(resume_meta_path, "r", encoding="utf-8") as handle: loaded = json.load(handle) return loaded if isinstance(loaded, dict) else None def build_generator(seed: int) -> torch.Generator: generator = torch.Generator(device="cpu") generator.manual_seed(int(seed)) return generator def capture_rng_state() -> Dict[str, object]: state: Dict[str, object] = { "python_random_state": random.getstate(), "torch_cpu_rng_state": torch.get_rng_state(), } if np is not None: state["numpy_random_state"] = np.random.get_state() if torch.cuda.is_available(): state["torch_cuda_rng_state_all"] = torch.cuda.get_rng_state_all() return state def restore_rng_state(state: Dict[str, object]) -> None: python_state = state.get("python_random_state") if python_state is not None: random.setstate(python_state) numpy_state = state.get("numpy_random_state") if numpy_state is not None and np is not None: np.random.set_state(numpy_state) torch_cpu_state = state.get("torch_cpu_rng_state") if torch_cpu_state is not None: torch.set_rng_state(torch_cpu_state) torch_cuda_state = state.get("torch_cuda_rng_state_all") if torch_cuda_state is not None and torch.cuda.is_available(): torch.cuda.set_rng_state_all(torch_cuda_state) def save_rng_state(path: str) -> None: torch.save(capture_rng_state(), path) def load_rng_state(path: str) -> Optional[Dict[str, object]]: if not os.path.exists(path): return None loaded = torch.load(path, map_location="cpu", weights_only=False) return loaded if isinstance(loaded, dict) else None def configure_reproducibility(seed: int) -> None: random.seed(seed) if np is not None: np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) if hasattr(torch.backends, "cudnn"): torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False if hasattr(torch, "use_deterministic_algorithms"): torch.use_deterministic_algorithms(True, warn_only=True) def save_loader_generator_state( base_dir: str, *, distill_generator: Optional[torch.Generator] = None, lora_generator: Optional[torch.Generator] = None, ) -> None: state: Dict[str, object] = {} if distill_generator is not None: state["distill_generator_state"] = distill_generator.get_state() if lora_generator is not None: state["lora_generator_state"] = lora_generator.get_state() if state: torch.save(state, os.path.join(base_dir, "loader_generators.pt")) def load_loader_generator_state(base_dir: str) -> Optional[Dict[str, object]]: path = os.path.join(base_dir, "loader_generators.pt") if not os.path.exists(path): return None loaded = torch.load(path, map_location="cpu") return loaded if isinstance(loaded, dict) else None def resolve_layer_idx( args: argparse.Namespace, model, layers: List[torch.nn.Module], dataloader, previous_scores, start_index: int, exclude_pairs: Set[int], ): layer_arg = str(getattr(args, "layer", "auto")).strip().lower() selection_method = str(getattr(args, "selection_method", "dwce")).strip().lower() if layer_arg != "auto": try: layer_idx = int(layer_arg) except ValueError as exc: raise SystemExit("--layer must be 'auto' or an integer index") from exc num_pairs = max(len(layers) - 1, 0) if layer_idx < 0: layer_idx += num_pairs if layer_idx in exclude_pairs: raise SystemExit(f"--layer resolved to excluded pair index {layer_idx}") return layer_idx, previous_scores, {"method": "manual", "exclude_pairs": sorted(exclude_pairs)} if selection_method == "sequential": num_pairs = len(layers) - 1 for layer_idx in range(max(start_index, 0), num_pairs): if layer_idx not in exclude_pairs: return layer_idx, previous_scores, { "method": "sequential", "start_index": max(start_index, 0), "exclude_pairs": sorted(exclude_pairs), } raise SystemExit("No eligible layer pairs remain after exclusions") layer_idx, dwce_scores, dwce_meta = select_layer_auto( model, layers, dataloader, args, previous_scores=previous_scores, start_index=start_index, exclude_pairs=exclude_pairs, ) return layer_idx, dwce_scores, dwce_meta @dataclass class PreparedData: calib_loader: torch.utils.data.DataLoader calib_num_sequences: int distill_loader: Optional[torch.utils.data.DataLoader] distill_generator: Optional[torch.Generator] distill_meta: Dict[str, object] lora_loader: Optional[torch.utils.data.DataLoader] lora_generator: Optional[torch.Generator] lora_meta: Dict[str, object] eval_datasets: List[str] eval_configs: List[Optional[str]] eval_dataloaders: Optional[Dict[str, torch.utils.data.DataLoader]] def resolve_eval_datasets(args: argparse.Namespace) -> Tuple[List[str], List[Optional[str]]]: eval_datasets = args.eval_dataset or ["wikitext"] eval_configs = args.eval_dataset_config or ["wikitext-2-raw-v1"] eval_configs = ppl_eval._expand_dataset_configs(eval_datasets, eval_configs) return eval_datasets, eval_configs def run_ppl_eval( model_id_or_path: str, eval_datasets: List[str], eval_configs: List[Optional[str]], args: argparse.Namespace, prepared_eval_dataloaders: Optional[Dict[str, torch.utils.data.DataLoader]] = None, ) -> Dict[str, float]: eval_device = args.eval_device or args.device dtype = get_dtype(args.dtype) eval_model = load_causal_lm( model_id_or_path, torch_dtype=dtype, trust_remote_code=args.trust_remote_code, ) eval_model.to(eval_device) if prepared_eval_dataloaders is not None: results = ppl_eval.evaluate_ppl_dataloaders( eval_model, prepared_eval_dataloaders, eval_device, max_batches=args.eval_max_batches, ) else: eval_batch_size = args.eval_batch_size or args.batch_size eval_tokenizer = AutoTokenizer.from_pretrained( model_id_or_path, trust_remote_code=args.trust_remote_code ) if eval_tokenizer.pad_token is None and eval_tokenizer.eos_token is not None: eval_tokenizer.pad_token = eval_tokenizer.eos_token results = ppl_eval.evaluate_ppl_datasets( eval_model, eval_tokenizer, datasets=eval_datasets, configs=eval_configs, split=args.eval_split, text_field=args.eval_text_field, num_samples=args.eval_num_samples, seq_len=args.eval_seq_len, batch_size=eval_batch_size, device=eval_device, seed=args.seed, shuffle=False, model_family=args.eval_model_family, add_bos=args.eval_add_bos, max_batches=args.eval_max_batches, cache_dir=args.eval_cache_dir, num_workers=args.eval_num_workers, ) del eval_model if torch.cuda.is_available(): torch.cuda.empty_cache() return results def build_calibration_dataloader( args: argparse.Namespace, tokenizer ) -> Tuple[List[str], List[torch.Tensor], torch.utils.data.DataLoader]: if args.dataset: datasets = list(args.dataset) configs = expand_dataset_configs(datasets, list(args.dataset_config)) chunks: List[torch.Tensor] = [] for idx, (dataset_name, config) in enumerate(zip(datasets, configs)): spec = SharedLMDataSpec( dataset=dataset_name, config=config, split=args.dataset_split, text_field=args.dataset_text_field, seq_len=args.seq_len, num_sequences=args.num_samples, seed=args.seed + idx, ) chunks.extend(build_chunks(spec, tokenizer)) if not chunks: raise SystemExit("Not enough text to build token sequences.") input_ids = torch.stack(chunks) attention_mask = torch.ones_like(input_ids) dataset = torch.utils.data.TensorDataset(input_ids, attention_mask) dataloader = torch.utils.data.DataLoader( dataset, batch_size=args.batch_size, shuffle=False ) return [], chunks, dataloader texts = load_texts(args) if not texts: raise SystemExit( "No calibration text found. Provide --dataset, --text, or --text_file." ) chunks = build_token_chunks(texts, tokenizer, args.seq_len, args.num_samples) if not chunks: raise SystemExit("Not enough text to build token sequences.") input_ids = torch.stack(chunks) attention_mask = torch.ones_like(input_ids) dataset = torch.utils.data.TensorDataset(input_ids, attention_mask) dataloader = torch.utils.data.DataLoader( dataset, batch_size=args.batch_size, shuffle=False ) return texts, chunks, dataloader def prepare_distillation_data( args: argparse.Namespace, tokenizer, include_instruction: bool = True ) -> Tuple[Optional[torch.utils.data.DataLoader], Optional[torch.Generator], Dict[str, object]]: if ( include_instruction and args.distill_inst_samples != 0 and not args.instruction_dataset ): print( "Warning: --distill_inst_samples > 0 but no --instruction_dataset " "provided; instruction distillation will be skipped." ) calib_texts: List[str] = [] calib_dataset = None if args.target_tokens > 0 and args.dataset: datasets = list(args.dataset) configs = expand_dataset_configs(datasets, list(args.dataset_config)) per_dataset = args.target_tokens // len(datasets) remainder = args.target_tokens % len(datasets) calib_chunks: List[torch.Tensor] = [] for idx, (dataset_name, config) in enumerate(zip(datasets, configs)): dataset_tokens = per_dataset + (remainder if idx == 0 else 0) spec = SharedLMDataSpec( dataset=dataset_name, config=config, split=args.dataset_split, text_field=args.dataset_text_field, seq_len=args.distill_seq_len, target_tokens=dataset_tokens, seed=args.seed + 17 + idx, ) calib_chunks.extend(build_chunks(spec, tokenizer)) if calib_chunks: input_ids = torch.stack(calib_chunks) attention_mask = torch.ones_like(input_ids) calib_dataset = torch.utils.data.TensorDataset(input_ids, attention_mask) else: calib_texts = load_texts_from_datasets( datasets=list(args.dataset), configs=expand_dataset_configs(list(args.dataset), list(args.dataset_config)), split=args.dataset_split, text_field=args.dataset_text_field, num_samples=args.distill_calib_samples, seed=args.seed + 17, ) inst_records = [] if include_instruction: inst_records = load_instruction_records(args, args.distill_inst_samples) distill_datasets = [] if calib_dataset is not None: distill_datasets.append(calib_dataset) elif calib_texts: calib_records = [{"text": text} for text in calib_texts] distill_datasets.append( FixedSeqDataset(calib_records, tokenizer, args.distill_seq_len) ) if inst_records: distill_datasets.append( FixedSeqDataset(inst_records, tokenizer, args.distill_seq_len) ) distill_meta: Dict[str, object] = { "calib_texts": len(calib_texts), "calib_sequences": len(calib_dataset) if calib_dataset is not None else len(calib_texts), "inst_sequences": len(inst_records), "total_sequences": 0, } if not distill_datasets: return None, None, distill_meta if len(distill_datasets) == 1: distill_dataset = distill_datasets[0] else: distill_dataset = torch.utils.data.ConcatDataset(distill_datasets) distill_meta["total_sequences"] = len(distill_dataset) distill_generator = build_generator( args.seed + 1000 + (1000000 if include_instruction else 0) ) distill_loader = torch.utils.data.DataLoader( distill_dataset, batch_size=args.distill_batch_size, shuffle=True, generator=distill_generator, ) return distill_loader, distill_generator, distill_meta def prepare_eval_dataloaders( args: argparse.Namespace, tokenizer, model: torch.nn.Module, eval_datasets: List[str], eval_configs: List[Optional[str]], ) -> Optional[Dict[str, torch.utils.data.DataLoader]]: needs_eval = (not args.skip_eval) or ( (not args.skip_distill and args.distill_eval_every) or (args.lora_epochs > 0 and args.lora_eval_every) ) if not needs_eval: return None eval_batch_size = args.eval_batch_size or args.batch_size resolved_family = args.eval_model_family if resolved_family == "auto": resolved_family = ppl_eval._infer_model_family(model) return ppl_eval.prepare_ppl_dataloaders( tokenizer=tokenizer, datasets=eval_datasets, configs=eval_configs, split=args.eval_split, text_field=args.eval_text_field, num_samples=args.eval_num_samples, seq_len=args.eval_seq_len, batch_size=eval_batch_size, seed=args.seed, shuffle=False, model_family=resolved_family, add_bos=args.eval_add_bos, cache_dir=args.eval_cache_dir, num_workers=args.eval_num_workers, model=model, ) def prepare_all_data( args: argparse.Namespace, tokenizer, model: torch.nn.Module, eval_datasets: List[str], eval_configs: List[Optional[str]], loader_generator_state: Optional[Dict[str, object]] = None, ) -> PreparedData: texts, chunks, calib_loader = build_calibration_dataloader(args, tokenizer) calib_num_sequences = len(chunks) del texts del chunks distill_loader = None distill_generator = None distill_meta: Dict[str, object] = { "calib_texts": 0, "calib_sequences": 0, "inst_sequences": 0, "total_sequences": 0, } lora_loader = None lora_generator = None lora_meta: Dict[str, object] = { "calib_texts": 0, "calib_sequences": 0, "inst_sequences": 0, "total_sequences": 0, } if (not args.skip_distill) or bool(getattr(args, "comm_enabled", False)): distill_loader, distill_generator, distill_meta = prepare_distillation_data( args, tokenizer, include_instruction=False ) if ( distill_generator is not None and loader_generator_state is not None and loader_generator_state.get("distill_generator_state") is not None ): distill_generator.set_state(loader_generator_state["distill_generator_state"]) if args.lora_epochs > 0: lora_loader, lora_generator, lora_meta = prepare_distillation_data( args, tokenizer, include_instruction=True ) if ( lora_generator is not None and loader_generator_state is not None and loader_generator_state.get("lora_generator_state") is not None ): lora_generator.set_state(loader_generator_state["lora_generator_state"]) eval_dataloaders = prepare_eval_dataloaders( args, tokenizer, model, eval_datasets, eval_configs ) return PreparedData( calib_loader=calib_loader, calib_num_sequences=calib_num_sequences, distill_loader=distill_loader, distill_generator=distill_generator, distill_meta=distill_meta, lora_loader=lora_loader, lora_generator=lora_generator, lora_meta=lora_meta, eval_datasets=eval_datasets, eval_configs=eval_configs, eval_dataloaders=eval_dataloaders, ) def evaluate_ppl_model( model: torch.nn.Module, tokenizer, eval_datasets: List[str], eval_configs: List[Optional[str]], args: argparse.Namespace, max_batches: Optional[int] = None, prepared_eval_dataloaders: Optional[Dict[str, torch.utils.data.DataLoader]] = None, ) -> Dict[str, float]: eval_device = args.eval_device or args.device prev_mode = model.training try: prev_device = next(model.parameters()).device except StopIteration: prev_device = torch.device(eval_device) model.eval() if str(prev_device) != eval_device: model.to(eval_device) if prepared_eval_dataloaders is not None: results = ppl_eval.evaluate_ppl_dataloaders( model, prepared_eval_dataloaders, eval_device, max_batches=max_batches if max_batches is not None else args.eval_max_batches, ) else: eval_batch_size = args.eval_batch_size or args.batch_size results = ppl_eval.evaluate_ppl_datasets( model, tokenizer, datasets=eval_datasets, configs=eval_configs, split=args.eval_split, text_field=args.eval_text_field, num_samples=args.eval_num_samples, seq_len=args.eval_seq_len, batch_size=eval_batch_size, device=eval_device, seed=args.seed, shuffle=False, model_family=args.eval_model_family, add_bos=args.eval_add_bos, max_batches=max_batches if max_batches is not None else args.eval_max_batches, cache_dir=args.eval_cache_dir, num_workers=args.eval_num_workers, ) if prev_mode: model.train() if str(prev_device) != eval_device: model.to(prev_device) if torch.cuda.is_available(): torch.cuda.empty_cache() return results def has_post_fusion_data( distill_loader: Optional[torch.utils.data.DataLoader], distill_meta: Optional[Dict[str, object]], ) -> bool: if distill_loader is None or distill_meta is None: return False return distill_meta.get("total_sequences", 0) > 0 def summarize_gate_lambdas(gates: Dict[str, torch.Tensor]) -> Dict[str, object]: if not gates: return {"num_tensors": 0, "num_elements": 0} total_sum = 0.0 total_elems = 0 per_tensor_mean: Dict[str, Optional[float]] = {} for name, gate in gates.items(): g = gate.detach().float() if g.numel() == 0: per_tensor_mean[name] = None continue per_tensor_mean[name] = float(g.mean().item()) total_sum += float(g.sum().item()) total_elems += int(g.numel()) global_mean = None if total_elems == 0 else total_sum / float(total_elems) return { "num_tensors": len(gates), "num_elements": total_elems, "global_mean": global_mean, "per_tensor_mean": per_tensor_mean, } def compute_path_bytes(path: str) -> int: if os.path.isfile(path): return os.path.getsize(path) total = 0 for root, _, files in os.walk(path): for name in files: file_path = os.path.join(root, name) if os.path.islink(file_path): continue try: total += os.path.getsize(file_path) except OSError: continue return total def save_stage_checkpoint( model: torch.nn.Module, tokenizer, stage_dir: str, stage_name: str, ppl_results: Optional[Dict[str, float]], ) -> Dict[str, object]: os.makedirs(stage_dir, exist_ok=True) colon_modules = find_colon_modules(model) if colon_modules: raise RuntimeError( "Unexpected module names with ':' detected before save: " f"{', '.join(colon_modules)}." ) model.save_pretrained(stage_dir) tokenizer.save_pretrained(stage_dir) stage_meta = { "stage": stage_name, "path": stage_dir, "weight_bytes": compute_path_bytes(stage_dir), "post_ppl": ppl_results, } with open( os.path.join(stage_dir, "stage_metrics.json"), "w", encoding="utf-8", ) as handle: json.dump(stage_meta, handle, indent=2) return stage_meta def save_cycle_full_model( model: torch.nn.Module, tokenizer, cycle_dir: str, cycle_idx: int, args: argparse.Namespace, ppl_results: Optional[Dict[str, float]], ) -> Dict[str, object]: full_model_dir = os.path.join(cycle_dir, "full_model") stage_meta = save_stage_checkpoint( model=model, tokenizer=tokenizer, stage_dir=full_model_dir, stage_name=f"cycle_{cycle_idx}_full_model", ppl_results=ppl_results, ) resume_meta = { "base_model": getattr(args, "base_model_id", args.model), "cycle": cycle_idx, "output_dir": args.output_dir, "layer_path": args.layer_path, "rng_state": "rng_state.pt", "loader_generators": "loader_generators.pt", } with open( os.path.join(full_model_dir, "resume_info.json"), "w", encoding="utf-8", ) as handle: json.dump(resume_meta, handle, indent=2) stage_meta["resume_info"] = "resume_info.json" return stage_meta def run_lora_phase( model: torch.nn.Module, tokenizer, eval_datasets: List[str], eval_configs: List[Optional[str]], args: argparse.Namespace, lora_loader: Optional[torch.utils.data.DataLoader] = None, lora_meta: Optional[Dict[str, object]] = None, eval_dataloaders: Optional[Dict[str, torch.utils.data.DataLoader]] = None, cycle_idx: Optional[int] = None, num_cycles: Optional[int] = None, ) -> List[Dict[str, object]]: lora_eval_history: List[Dict[str, object]] = [] if args.lora_epochs <= 0: return lora_eval_history if not has_post_fusion_data(lora_loader, lora_meta): print("No post-fusion sequences built; skipping LoRA finetuning.") return lora_eval_history lora_ce_finetune( model=model, dataloader=lora_loader, eval_tokenizer=tokenizer, eval_datasets=eval_datasets, eval_configs=eval_configs, eval_history=lora_eval_history, args=args, eval_dataloaders=eval_dataloaders, progressive_cycle=cycle_idx, progressive_total=num_cycles, ) return lora_eval_history def run_progressive( args: argparse.Namespace, model: torch.nn.Module, tokenizer, prepared: PreparedData, ) -> None: eval_datasets = prepared.eval_datasets eval_configs = prepared.eval_configs dataloader = prepared.calib_loader num_sequences = prepared.calib_num_sequences model.to(args.device) os.makedirs(args.output_dir, exist_ok=True) progressive_meta_path = os.path.join(args.output_dir, "progressive_metadata.json") existing_meta: Dict[str, object] = {} if args.resume_from_cycle > 0 and os.path.exists(progressive_meta_path): with open(progressive_meta_path, "r", encoding="utf-8") as handle: loaded_meta = json.load(handle) if isinstance(loaded_meta, dict): existing_meta = loaded_meta bootstrap_meta = { "base_model": getattr(args, "base_model_id", args.model), "num_progressive": args.num_progressive, "layer_path": args.layer_path, "resume_from_cycle": args.resume_from_cycle, "save_full_model_cycles": sorted(args.full_model_save_cycles), "cycles": ( existing_meta.get("cycles", []) if isinstance(existing_meta.get("cycles"), list) else [] ), } with open( progressive_meta_path, "w", encoding="utf-8", ) as handle: json.dump(bootstrap_meta, handle, indent=2) pre_eval = None if not args.skip_eval: pre_eval = evaluate_ppl_model( model, tokenizer, eval_datasets, eval_configs, args, prepared_eval_dataloaders=prepared.eval_dataloaders, ) print("Pre-pruning perplexity:") for dataset_name, ppl in pre_eval.items(): print(f"{dataset_name}: {ppl:.4f}") parent, name, container = find_layer_container(model, args.layer_path) layers = list(container) if args.num_progressive > (len(layers) - 1 + args.resume_from_cycle): raise SystemExit( f"--num_progressive ({args.num_progressive}) exceeds available pairs " f"after resume offset ({len(layers) - 1 + args.resume_from_cycle})" ) dwce_scores = None dwce_meta = None last_fused_idx = 0 cycle_summaries: List[Dict[str, object]] = [] existing_cycles = existing_meta.get("cycles", []) if isinstance(existing_cycles, list): for entry in existing_cycles: if not isinstance(entry, dict): continue cycle_value = entry.get("cycle") if isinstance(cycle_value, int) and cycle_value <= args.resume_from_cycle: cycle_summaries.append(entry) comm_enabled = bool(getattr(args, "comm_enabled", False)) comm_teacher_model = None comm_teacher_cycle: Optional[int] = None teacher_device = args.distill_teacher_device or args.device previous_cycle_teacher_model = None previous_cycle_teacher_cycle: Optional[int] = None def _release_comm_teacher() -> None: nonlocal comm_teacher_model, comm_teacher_cycle if comm_teacher_model is not None: del comm_teacher_model comm_teacher_model = None comm_teacher_cycle = None if torch.cuda.is_available(): torch.cuda.empty_cache() def _release_previous_cycle_teacher() -> None: nonlocal previous_cycle_teacher_model, previous_cycle_teacher_cycle if previous_cycle_teacher_model is not None: del previous_cycle_teacher_model previous_cycle_teacher_model = None previous_cycle_teacher_cycle = None if torch.cuda.is_available(): torch.cuda.empty_cache() def _snapshot_previous_cycle_teacher(cycle_idx: int) -> None: nonlocal previous_cycle_teacher_model, previous_cycle_teacher_cycle _release_previous_cycle_teacher() previous_cycle_teacher_model = copy.deepcopy(model) previous_cycle_teacher_model.to(teacher_device) previous_cycle_teacher_model.eval() previous_cycle_teacher_cycle = cycle_idx def _get_previous_cycle_teacher( cycle_idx: int, ) -> Tuple[Optional[torch.nn.Module], str, Optional[int]]: prev_cycle = cycle_idx - 1 if prev_cycle <= 0: return None, "base_model", 0 if ( previous_cycle_teacher_model is not None and previous_cycle_teacher_cycle == prev_cycle ): return previous_cycle_teacher_model, "previous_cycle_memory", prev_cycle teacher_model = load_progressive_model( getattr(args, "base_model_id", args.model), args.output_dir, cycle=prev_cycle, device=teacher_device, dtype=args.dtype, trust_remote_code=args.trust_remote_code, layer_path=args.layer_path, ) teacher_model.eval() return teacher_model, "previous_cycle_disk", prev_cycle def _get_comm_teacher(cycle_idx: int) -> Tuple[Optional[torch.nn.Module], str, Optional[int]]: nonlocal comm_teacher_model, comm_teacher_cycle if not comm_enabled: return None, "disabled", None source = str(getattr(args, "redistrib_teacher_source", "base_model")) if source == "base_model": if comm_teacher_model is None: print( "[comm] Loading fixed base teacher for anchor loss " f"(device={teacher_device})." ) comm_teacher_model = AutoModelForCausalLM.from_pretrained( getattr(args, "base_model_id", args.model), torch_dtype=get_dtype(args.dtype), trust_remote_code=args.trust_remote_code, ) comm_teacher_model.to(teacher_device) comm_teacher_model.eval() comm_teacher_cycle = 0 return comm_teacher_model, "base_model", 0 prev_cycle = cycle_idx - 1 if prev_cycle <= 0: if comm_teacher_model is None or comm_teacher_cycle != 0: _release_comm_teacher() print( "[comm] --redistrib_teacher_source=previous_cycle but cycle 1 " "has no prior checkpoint; using base teacher." ) comm_teacher_model = AutoModelForCausalLM.from_pretrained( getattr(args, "base_model_id", args.model), torch_dtype=get_dtype(args.dtype), trust_remote_code=args.trust_remote_code, ) comm_teacher_model.to(teacher_device) comm_teacher_model.eval() comm_teacher_cycle = 0 return comm_teacher_model, "base_model", 0 if ( previous_cycle_teacher_model is not None and previous_cycle_teacher_cycle == prev_cycle ): if comm_teacher_model is not previous_cycle_teacher_model: _release_comm_teacher() comm_teacher_model = previous_cycle_teacher_model comm_teacher_cycle = prev_cycle return comm_teacher_model, "previous_cycle_memory", prev_cycle if comm_teacher_model is None or comm_teacher_cycle != prev_cycle: _release_comm_teacher() print( "[comm] Loading teacher from previous cycle " f"{prev_cycle} (device={teacher_device})." ) comm_teacher_model = load_progressive_model( getattr(args, "base_model_id", args.model), args.output_dir, cycle=prev_cycle, device=teacher_device, dtype=args.dtype, trust_remote_code=args.trust_remote_code, layer_path=args.layer_path, ) comm_teacher_model.eval() comm_teacher_cycle = prev_cycle return comm_teacher_model, "previous_cycle_disk", prev_cycle if args.resume_from_cycle > 0: _snapshot_previous_cycle_teacher(args.resume_from_cycle) start_cycle = args.resume_from_cycle + 1 for cycle_idx in range(start_cycle, args.num_progressive + 1): print(f"[progressive] Cycle {cycle_idx}/{args.num_progressive}") run_comm = comm_enabled and ( cycle_idx > 1 or bool(getattr(args, "comm_include_cycle1", False)) ) comm_stats: Dict[str, object] = {"enabled": False} comm_post_eval = None if run_comm: # Preconditioning updates model weights, so DWCE reuse is unreliable. start_index = 0 reuse_scores = None else: start_index = last_fused_idx if cycle_idx > 1 else 0 reuse_scores = dwce_scores exclude_pairs = set(parse_exclude_pairs(args.exclude_pairs, max(len(layers) - 1, 0))) layer_idx, dwce_scores, dwce_meta = resolve_layer_idx( args, model, layers, dataloader, reuse_scores, start_index, exclude_pairs, ) if run_comm: dwce_scores_pre_comm = dwce_scores if prepared.calib_loader is None: print( "[comm] Enabled but no calibration sequences were built; skipping." ) else: ( comm_teacher_model_loaded, comm_teacher_source, comm_teacher_cycle_idx, ) = _get_comm_teacher(cycle_idx) if comm_teacher_model_loaded is None: raise RuntimeError("comm_enabled but teacher model was not loaded.") comm_stats = commutator_precondition( student_model=model, student_layers=layers, teacher_model=comm_teacher_model_loaded, dataloader=prepared.calib_loader, dwce_scores=dwce_scores_pre_comm, exclude_pairs=exclude_pairs, args=args, progressive_cycle=cycle_idx, progressive_total=args.num_progressive, ) if comm_stats.get("enabled"): comm_stats["teacher_source"] = comm_teacher_source comm_stats["teacher_cycle"] = comm_teacher_cycle_idx comm_stats["dwce_scores_pre"] = dwce_scores_pre_comm print( "[comm] Done:" f" opt_steps={comm_stats.get('opt_steps')}" f" lr={comm_stats.get('lr')}" ) if not args.skip_eval: comm_post_eval = evaluate_ppl_model( model, tokenizer, eval_datasets, eval_configs, args, prepared_eval_dataloaders=prepared.eval_dataloaders, ) comm_stats["post_ppl"] = comm_post_eval print(f"[progressive] Cycle {cycle_idx} post-comm perplexity:") for dataset_name, ppl in comm_post_eval.items(): print(f"{dataset_name}: {ppl:.4f}") if bool(getattr(args, "comm_skip_post_reselect", False)): comm_stats["post_selection_recomputed"] = False comm_stats["selected_layer_post"] = int(layer_idx) print( "[comm] Keeping pre-comm DWCE pair selection for fusion." ) else: print( "[comm] Recomputing DWCE after preconditioning for fusion selection." ) layer_idx, dwce_scores, dwce_meta = resolve_layer_idx( args, model, layers, dataloader, None, 0, exclude_pairs, ) comm_stats["post_selection_recomputed"] = True comm_stats["selected_layer_post"] = int(layer_idx) if layer_idx < 0 or layer_idx >= len(layers) - 1: raise SystemExit("--layer must be in [0, num_layers-2]") num_layers_before = len(layers) layer_a = layers[layer_idx] layer_b = layers[layer_idx + 1] norm1_state = None norm2_state = None norm1, norm2, norm_names = get_norm_pair(layer_a) if norm1 is not None: norm1_state = clone_state_dict(norm1) if norm2 is not None: norm2_state = clone_state_dict(norm2) attn_a = find_attention_module(layer_a) attn_b = find_attention_module(layer_b) hidden_size = getattr(model.config, "hidden_size", None) if hidden_size is None: hidden_size = getattr(model.config, "n_embd", None) if hidden_size is None: raise SystemExit("Model config missing hidden_size/n_embd") no_head_permute_merge = bool( getattr(args, "no_head_permute_merge", False) or getattr(args, "no_head_permute", False) ) if no_head_permute_merge: print("[fuse] Head permutation disabled; merging with original head order.") else: mean_a, mean_b, num_heads, num_kv_heads, head_dim = compute_head_means( model, attn_a, attn_b, dataloader, args.device, hidden_size, ) perm = build_head_permutation( mean_a, mean_b, num_heads=num_heads, num_kv_heads=num_kv_heads, eps=args.eps, ) permute_attention_heads( attn_b, perm, num_heads, num_kv_heads, head_dim=head_dim ) fisher_sums, num_batches, param_numels = compute_fisher( model, layer_a, layer_b, dataloader, fisher_mode=args.fisher_mode, device=args.device, ) distill_ready = has_post_fusion_data( prepared.distill_loader, prepared.distill_meta ) teacher_cycle = cycle_idx - 1 teacher_source = "previous_cycle" if teacher_cycle > 0 else "base_model" merge_method = "fisher" distill_method = str(getattr(args, "distill_method", "reparam")) reparam_stats: Optional[Dict[str, object]] = None reparam_gate_summary: Optional[Dict[str, object]] = None needs_teacher_for_reparam = ( (not args.skip_distill) and distill_ready and float(args.distill_epochs) > 0.0 ) teacher_model = None teacher_parent = None teacher_layer_attr = None teacher_layers: Optional[List[torch.nn.Module]] = None teacher_from_cache = False if needs_teacher_for_reparam: teacher_model, teacher_source, teacher_cycle = _get_previous_cycle_teacher( cycle_idx ) teacher_from_cache = ( teacher_source == "previous_cycle_memory" and teacher_model is previous_cycle_teacher_model ) if teacher_model is None: teacher_model = load_causal_lm( getattr(args, "base_model_id", args.model), torch_dtype=get_dtype(args.dtype), trust_remote_code=args.trust_remote_code, cache_dir=args.model_cache_dir, ) teacher_model.to(teacher_device) teacher_model.eval() teacher_source = "base_model" teacher_cycle = 0 teacher_parent, teacher_layer_attr, teacher_container = find_layer_container( teacher_model, args.layer_path ) teacher_layers = list(teacher_container) do_reparam = ( (not args.skip_distill) and distill_ready and prepared.distill_loader is not None ) if (not args.skip_distill) and not do_reparam: print("[reparam] No distillation sequences built; skipping reparam distill.") distill_post_eval = None if do_reparam: lambda_source = "fisher_prior" reparam_gate_targets: Dict[str, object] = compute_fisher_gate_priors( layer_a=layer_a, layer_b=layer_b, fisher_a=fisher_sums[0], fisher_b=fisher_sums[1], num_batches=num_batches, numels_a=param_numels[0], numels_b=param_numels[1], fisher_mode=args.fisher_mode, eps=float(args.eps), ) if not reparam_gate_targets: raise SystemExit("[reparam] No mergeable parameters found; cannot continue.") if float(args.distill_epochs) > 0.0 and ( teacher_model is None or teacher_layers is None ): raise SystemExit("--distill_method reparam requires a teacher model.") print( f"[reparam] Cycle {cycle_idx}: training U + gates for pair " f"{layer_idx}-{layer_idx + 1} (epochs={args.distill_epochs}, " f"hidden_mse_w={args.distill_hidden_mse_weight}, " f"attn_mse_w={args.distill_attn_mse_weight}, " f"mlp_mse_w={args.distill_mlp_mse_weight}, " f"eta={args.reparam_eta}, gamma={args.reparam_gamma}, " f"attn_reg_scale={args.reparam_attn_reg_scale}, " f"mlp_reg_scale={args.reparam_mlp_reg_scale}, " f"param_subset={args.reparam_param_subset}, " f"lambda_init={lambda_source})." ) merged, final_gates, reparam_stats = distill_reparam_merge( student_model=model, student_parent=parent, student_layer_attr=name, student_layers=layers, teacher_model=teacher_model, teacher_parent=teacher_parent, teacher_layer_attr=teacher_layer_attr, teacher_layers=teacher_layers, layer_idx=layer_idx, gate_lambdas=reparam_gate_targets, dataloader=prepared.distill_loader, args=args, progressive_cycle=cycle_idx, progressive_total=args.num_progressive, ) reparam_gate_summary = summarize_gate_lambdas(final_gates) merge_method = "reparam" if reparam_stats is not None: reparam_stats["lambda_init"] = lambda_source else: merged = merge_layers( layer_a, layer_b, fisher_sums[0], fisher_sums[1], num_batches, param_numels[0], param_numels[1], fisher_mode=args.fisher_mode, eps=args.eps, ) apply_norm_policy( layer_a, args.norm_policy, norm1_state, norm2_state, norm_names, ) if teacher_model is not None and not teacher_from_cache: del teacher_model teacher_model = None teacher_parent = None teacher_layer_attr = None teacher_layers = None if torch.cuda.is_available(): torch.cuda.empty_cache() new_container = drop_layer(container, layer_idx + 1) setattr(parent, name, new_container) decrement_config(model.config) layers = list(new_container) lora_post_eval = None if (not args.skip_eval) and (not args.skip_distill) and do_reparam: distill_post_eval = evaluate_ppl_model( model, tokenizer, eval_datasets, eval_configs, args, prepared_eval_dataloaders=prepared.eval_dataloaders, ) print(f"[progressive] Cycle {cycle_idx} post-distill perplexity:") for dataset_name, ppl in distill_post_eval.items(): print(f"{dataset_name}: {ppl:.4f}") post_eval = None if not args.skip_eval: if distill_post_eval is not None: post_eval = distill_post_eval else: post_eval = evaluate_ppl_model( model, tokenizer, eval_datasets, eval_configs, args, prepared_eval_dataloaders=prepared.eval_dataloaders, ) print(f"[progressive] Cycle {cycle_idx} perplexity:") for dataset_name, ppl in post_eval.items(): print(f"{dataset_name}: {ppl:.4f}") cycle_dir = os.path.join(args.output_dir, f"cycle_{cycle_idx}") os.makedirs(cycle_dir, exist_ok=True) fused_layer_file = "fused_layer.pt" fused_layer_path = os.path.join(cycle_dir, fused_layer_file) torch.save(layers[layer_idx].state_dict(), fused_layer_path) cycle_meta: Dict[str, object] = { "cycle": cycle_idx, "layer_merged": layer_idx, "num_layers_before": num_layers_before, "num_layers_after": num_layers_before - 1, "fused_layer_state": fused_layer_file, "dwce_score": dwce_scores[layer_idx] if dwce_scores else None, "dwce_scores": dwce_scores, "dwce_meta": dwce_meta, "fisher_num_batches": num_batches, "merge_method": merge_method, "merged_params": merged, "num_sequences": num_sequences, "teacher_source": teacher_source, "teacher_cycle": teacher_cycle, "eval": { "datasets": eval_datasets, "configs": eval_configs, "split": args.eval_split, "num_samples": args.eval_num_samples, "seq_len": args.eval_seq_len, "post_ppl": post_eval, }, "comm": comm_stats, "distill": { "enabled": not args.skip_distill, "method": distill_method, "calib_samples": args.distill_calib_samples, "inst_samples": args.distill_inst_samples, "seq_len": args.distill_seq_len, "batch_size": args.distill_batch_size, "epochs": args.distill_epochs, "lr": args.distill_lr, "kl_weight": args.distill_kl_weight, "kl_temp": args.distill_kl_temp, "hidden_mse_weight": args.distill_hidden_mse_weight, "attn_mse_weight": args.distill_attn_mse_weight, "mlp_mse_weight": args.distill_mlp_mse_weight, "reparam_eta": args.reparam_eta, "reparam_gamma": args.reparam_gamma, "reparam_attn_reg_scale": args.reparam_attn_reg_scale, "reparam_mlp_reg_scale": args.reparam_mlp_reg_scale, "reparam_param_subset": args.reparam_param_subset, "reparam_stats": reparam_stats, "reparam_gate_summary": reparam_gate_summary, "post_ppl": distill_post_eval, "weight_decay": args.distill_weight_decay, "max_grad_norm": args.distill_max_grad_norm, "grad_accum_steps": args.distill_grad_accum_steps, "instruction_dataset": args.instruction_dataset, "instruction_config": args.instruction_config, "instruction_split": args.instruction_split, }, "lora": { "enabled": args.lora_epochs > 0, "seq_len": args.distill_seq_len, "batch_size": args.distill_batch_size, "epochs": args.lora_epochs, "rank": args.lora_rank, "alpha": args.lora_alpha, "dropout": args.lora_dropout, "target_modules": args.lora_target_modules, "respect_exclude_pairs": args.lora_respect_exclude_pairs, "kl_enabled": args.lora_kl_enabled, "kl_weight": args.lora_kl_weight, "kl_temp": args.lora_kl_temp, "post_ppl": lora_post_eval, "lr": args.lora_lr, "weight_decay": args.lora_weight_decay, "max_grad_norm": args.lora_max_grad_norm, "grad_accum_steps": args.lora_grad_accum_steps, "log_steps": args.lora_log_steps, "eval_every": args.lora_eval_every, "eval_max_batches": args.lora_eval_max_batches, }, "norm_policy": args.norm_policy, } saved_full_model_dir = None if cycle_idx in args.full_model_save_cycles: cycle_meta["full_model_saved"] = True cycle_meta["full_model"] = save_cycle_full_model( model=model, tokenizer=tokenizer, cycle_dir=cycle_dir, cycle_idx=cycle_idx, args=args, ppl_results=post_eval, ) saved_full_model_dir = os.path.join(cycle_dir, "full_model") else: cycle_meta["full_model_saved"] = False with open( os.path.join(cycle_dir, "cycle_metadata.json"), "w", encoding="utf-8", ) as handle: json.dump(cycle_meta, handle, indent=2) cycle_summaries.append( { "cycle": cycle_idx, "layer_merged": layer_idx, "dwce_score": dwce_scores[layer_idx] if dwce_scores else None, "comm_post_ppl": comm_post_eval, "distill_post_ppl": distill_post_eval, "lora_post_ppl": lora_post_eval, "post_ppl": post_eval, "cycle_dir": f"cycle_{cycle_idx}", } ) last_fused_idx = layer_idx _snapshot_previous_cycle_teacher(cycle_idx) parent, name, container = find_layer_container(model, args.layer_path) layers = list(container) if dwce_scores: dwce_scores = dwce_scores[: max(len(layers) - 1, 0)] # Encourage allocator to release cached blocks between cycles. if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.ipc_collect() gc.collect() if saved_full_model_dir is not None: save_rng_state(os.path.join(saved_full_model_dir, "rng_state.pt")) save_loader_generator_state( saved_full_model_dir, distill_generator=prepared.distill_generator, lora_generator=prepared.lora_generator, ) _release_comm_teacher() _release_previous_cycle_teacher() final_pre_lora_eval = cycle_summaries[-1]["post_ppl"] if cycle_summaries else None final_pre_lora_dir = f"{os.path.abspath(args.output_dir.rstrip(os.sep))}_final_pre_lora_hf" final_pre_lora_meta = save_stage_checkpoint( model=model, tokenizer=tokenizer, stage_dir=final_pre_lora_dir, stage_name="final_pre_lora", ppl_results=final_pre_lora_eval, ) # Optional final LoRA finetune after all pruning cycles. lora_eval_history: List[Dict[str, object]] = [] lora_post_eval = None lora_ready = has_post_fusion_data(prepared.lora_loader, prepared.lora_meta) if args.lora_epochs > 0: if not lora_ready: print("No post-fusion sequences built; skipping LoRA finetuning.") else: print( f"[progressive] Running final LoRA finetuning (epochs={args.lora_epochs})." ) lora_eval_history = run_lora_phase( model=model, tokenizer=tokenizer, eval_datasets=eval_datasets, eval_configs=eval_configs, args=args, lora_loader=prepared.lora_loader, lora_meta=prepared.lora_meta, eval_dataloaders=prepared.eval_dataloaders, cycle_idx=args.num_progressive, num_cycles=args.num_progressive, ) if not args.skip_eval: lora_post_eval = evaluate_ppl_model( model, tokenizer, eval_datasets, eval_configs, args, prepared_eval_dataloaders=prepared.eval_dataloaders, ) print("[progressive] Post-LoRA perplexity:") for dataset_name, ppl in lora_post_eval.items(): print(f"{dataset_name}: {ppl:.4f}") # Update final cycle metadata and summary with the post-LoRA PPL. if cycle_summaries: cycle_summaries[-1]["lora_post_ppl"] = lora_post_eval if lora_post_eval is not None: cycle_summaries[-1]["post_ppl"] = lora_post_eval final_cycle_dir = os.path.join( args.output_dir, f"cycle_{args.num_progressive}" ) final_cycle_meta_path = os.path.join(final_cycle_dir, "cycle_metadata.json") if os.path.exists(final_cycle_meta_path): with open(final_cycle_meta_path, "r", encoding="utf-8") as handle: final_cycle_meta = json.load(handle) lora_meta_entry = final_cycle_meta.get("lora") if not isinstance(lora_meta_entry, dict): lora_meta_entry = {} final_cycle_meta["lora"] = lora_meta_entry lora_meta_entry["ran"] = True lora_meta_entry["post_ppl"] = lora_post_eval if lora_post_eval is not None and isinstance( final_cycle_meta.get("eval"), dict ): final_cycle_meta["eval"]["post_ppl"] = lora_post_eval if lora_eval_history: lora_path = os.path.join(final_cycle_dir, "ppl_over_lora.json") with open(lora_path, "w", encoding="utf-8") as handle: json.dump(lora_eval_history, handle, indent=2) lora_meta_entry["ppl_over_lora"] = "ppl_over_lora.json" with open(final_cycle_meta_path, "w", encoding="utf-8") as handle: json.dump(final_cycle_meta, handle, indent=2) os.makedirs(args.output_dir, exist_ok=True) final_post_lora_meta = save_stage_checkpoint( model=model, tokenizer=tokenizer, stage_dir=args.output_dir, stage_name="final_post_lora" if lora_post_eval is not None else "final_model", ppl_results=lora_post_eval, ) progressive_meta = { "base_model": getattr(args, "base_model_id", args.model), "num_progressive": args.num_progressive, "layer_path": args.layer_path, "resume_from_cycle": args.resume_from_cycle, "save_full_model_cycles": sorted(args.full_model_save_cycles), "num_sequences": num_sequences, "seq_len": args.seq_len, "lora": { "enabled": args.lora_epochs > 0, "ran": args.lora_epochs > 0 and lora_ready, "seq_len": args.distill_seq_len, "batch_size": args.distill_batch_size, "epochs": args.lora_epochs, "rank": args.lora_rank, "alpha": args.lora_alpha, "dropout": args.lora_dropout, "target_modules": args.lora_target_modules, "respect_exclude_pairs": args.lora_respect_exclude_pairs, "kl_enabled": args.lora_kl_enabled, "kl_weight": args.lora_kl_weight, "kl_temp": args.lora_kl_temp, "post_ppl": lora_post_eval, "ppl_over_lora": ( f"cycle_{args.num_progressive}/ppl_over_lora.json" if lora_eval_history else None ), "lr": args.lora_lr, "weight_decay": args.lora_weight_decay, "max_grad_norm": args.lora_max_grad_norm, "grad_accum_steps": args.lora_grad_accum_steps, "log_steps": args.lora_log_steps, "eval_every": args.lora_eval_every, "eval_max_batches": args.lora_eval_max_batches, }, "artifacts": { "final_pre_lora": final_pre_lora_meta, "final_post_lora": final_post_lora_meta, }, "eval": { "datasets": eval_datasets, "configs": eval_configs, "split": args.eval_split, "num_samples": args.eval_num_samples, "seq_len": args.eval_seq_len, "pre_ppl": pre_eval, "post_ppl": cycle_summaries[-1]["post_ppl"] if cycle_summaries else None, }, "cycles": cycle_summaries, "final_num_layers": len(layers), } with open( os.path.join(args.output_dir, "progressive_metadata.json"), "w", encoding="utf-8", ) as handle: json.dump(progressive_meta, handle, indent=2) print( f"[progressive] Completed {args.num_progressive} cycles. " f"Final model saved to {args.output_dir}." ) def main() -> None: args = parse_args() if args.num_progressive <= 0: raise SystemExit( "Single-cycle mode has been removed. Pass --num_progressive > 0." ) if args.resume_from_cycle < 0: raise SystemExit("--resume_from_cycle must be >= 0.") if args.resume_from_cycle >= args.num_progressive: raise SystemExit("--resume_from_cycle must be smaller than --num_progressive.") args.full_model_save_cycles = resolve_full_model_save_cycles( parse_cycle_list(args.save_full_model_cycles), args.num_progressive, ) args.base_model_id = args.model if args.resume_from_cycle > 0: resume_meta = load_resume_metadata(args.model) if resume_meta is None: raise SystemExit( "--resume_from_cycle requires --model to point to a saved cycle full model " "directory containing resume_info.json." ) resume_cycle = resume_meta.get("cycle") if resume_cycle is not None and int(resume_cycle) != args.resume_from_cycle: raise SystemExit( "resume_info.json cycle does not match --resume_from_cycle." ) base_model = resume_meta.get("base_model") if isinstance(base_model, str) and base_model: args.base_model_id = base_model configure_reproducibility(args.seed) eval_datasets, eval_configs = resolve_eval_datasets(args) dtype = get_dtype(args.dtype) model = load_causal_lm( args.model, torch_dtype=dtype, trust_remote_code=args.trust_remote_code, cache_dir=args.model_cache_dir, ) loader_generator_state = None if args.resume_from_cycle > 0: rng_state_path = os.path.join(args.model, "rng_state.pt") rng_state = load_rng_state(rng_state_path) if rng_state is not None: restore_rng_state(rng_state) loader_generator_state = load_loader_generator_state(args.model) tokenizer = AutoTokenizer.from_pretrained( args.model, trust_remote_code=args.trust_remote_code, cache_dir=args.model_cache_dir, ) print(model) if tokenizer.pad_token is None and tokenizer.eos_token is not None: tokenizer.pad_token = tokenizer.eos_token # for llama? model.config.use_cache = False prepared = prepare_all_data( args, tokenizer, model, eval_datasets, eval_configs, loader_generator_state=loader_generator_state, ) run_progressive(args, model, tokenizer, prepared) if __name__ == "__main__": main()