#!/usr/bin/env python3
"""Fuse adjacent layers via attention head alignment + Fisher-barycentric merge."""
import argparse
import copy
import gc
import json
import os
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple
import torch
try:
import numpy as np
except Exception: # pragma: no cover - optional dependency
np = None
try:
from transformers import AutoModelForCausalLM, AutoTokenizer
except Exception as exc: # pragma: no cover - fail early with clear error
raise SystemExit("transformers is required: pip install transformers") from exc
try:
import ppl_eval
except Exception as exc: # pragma: no cover - optional dependency
raise SystemExit("ppl_eval.py is required (missing or invalid)") from exc
from fuse_layers_data import (
FixedSeqDataset,
build_token_chunks,
expand_dataset_configs,
load_instruction_records,
load_texts,
load_texts_from_datasets,
)
from common_lm_data import SharedLMDataSpec, build_chunks, build_dataloader
from fuse_layers_distill import (
commutator_precondition,
compute_fisher_gate_priors,
distill_reparam_merge,
lora_ce_finetune,
)
from fuse_layers_model import (
apply_norm_policy,
build_head_permutation,
clone_state_dict,
compute_fisher,
compute_head_means,
decrement_config,
drop_layer,
find_attention_module,
find_colon_modules,
find_layer_container,
get_dtype,
get_norm_pair,
merge_layers,
permute_attention_heads,
)
from fuse_layers_select import select_layer_auto
from progressive_loader import load_causal_lm, load_progressive_model
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Fuse layer i and i+1 using head alignment + Fisher barycenter."
)
parser.add_argument("--model", required=True, help="HF model id or local path")
parser.add_argument(
"--model_cache_dir",
default=None,
help="Optional cache dir for model/tokenizer downloads",
)
parser.add_argument(
"--layer",
type=str,
default="auto",
help="Layer index i (int) or 'auto' to select via auto metric",
)
parser.add_argument(
"--selection_method",
choices=["dwce", "sequential"],
default="dwce",
help=(
"Pair selection policy for progressive pruning. "
"'dwce' uses downstream-weighted composition error; "
"'sequential' always takes the next available pair."
),
)
parser.add_argument(
"--exclude_pairs",
"--exclude_layers",
nargs="*",
default=None,
dest="exclude_pairs",
help=(
"Exclude pair indices from consideration for any fusion. Indices refer to "
"pair start positions in [0..N-2]. Negative indices count from the end "
"(-1 = last pair, -2 = second last). Accepts space- or comma-separated ints. "
"Alias: --exclude_layers (deprecated)."
),
)
parser.add_argument(
"--output_dir", required=True, help="Directory to write fused model"
)
parser.add_argument(
"--dataset",
action="append",
default=[],
help=(
"HF dataset name (repeatable). Optional if using --text or --text_file."
),
)
parser.add_argument(
"--dataset_config",
action="append",
default=[],
help="Optional dataset config (repeatable or single shared config).",
)
parser.add_argument(
"--dataset_split",
default="train",
help="Dataset split to use (default: train)",
)
parser.add_argument(
"--dataset_text_field",
default=None,
help="Text field in dataset (default: auto-detect, applies to all datasets)",
)
parser.add_argument(
"--text",
action="append",
default=[],
help="Inline text samples (can pass multiple)",
)
parser.add_argument(
"--text_file",
default=None,
help="Path to a text file for calibration data",
)
parser.add_argument(
"--num_samples",
type=int,
default=128,
help="Number of token sequences to use",
)
parser.add_argument(
"--target_tokens",
type=int,
default=0,
help="Target token budget for common_lm_data-backed calibration/distillation (0 = disabled)",
)
parser.add_argument("--seq_len", type=int, default=256, help="Sequence length")
parser.add_argument("--batch_size", type=int, default=2, help="Batch size")
parser.add_argument(
"--device",
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device for model + compute",
)
parser.add_argument(
"--dtype",
default="auto",
choices=["auto", "float32", "float16", "bfloat16"],
help="Model dtype",
)
parser.add_argument(
"--layer_path",
default=None,
help="Override layer attribute path (e.g., model.layers)",
)
parser.add_argument(
"--fisher_mode",
default="tensor",
choices=["tensor", "param"],
help="Fisher approximation granularity",
)
parser.add_argument(
"--no_head_permute",
action="store_true",
help=(
"Deprecated alias for --no_head_permute_merge. "
"Disables merge-stage head permutation only."
),
)
parser.add_argument(
"--no_head_permute_merge",
action="store_true",
help="Disable attention head permutation alignment before merge",
)
parser.add_argument(
"--no_head_permute_select",
action="store_true",
help="Disable attention head permutation alignment during auto selection",
)
parser.add_argument("--eps", type=float, default=1e-8, help="Stability epsilon")
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--trust_remote_code",
action="store_true",
help="Allow custom model code from hub",
)
parser.add_argument(
"--save_metadata",
action="store_true",
help="Backward-compatible no-op; metadata is always written.",
)
parser.add_argument(
"--skip_eval",
action="store_true",
help="Skip pre/post perplexity evaluation",
)
parser.add_argument(
"--eval_dataset",
action="append",
default=[],
help="Evaluation dataset name (repeatable). Defaults to wikitext.",
)
parser.add_argument(
"--eval_dataset_config",
action="append",
default=[],
help="Evaluation dataset config (repeatable or single shared config).",
)
parser.add_argument(
"--eval_split",
default="test",
help="Evaluation dataset split (default: test)",
)
parser.add_argument(
"--eval_text_field",
default=None,
help="Evaluation text field override (default: auto-detect)",
)
parser.add_argument(
"--eval_model_family",
type=str,
choices=["auto", "llama", "qwen"],
default="auto",
help="Model family for BOS handling during eval",
)
parser.add_argument(
"--eval_add_bos",
type=str,
choices=["auto", "always", "never"],
default="auto",
help="Whether to prepend BOS to each eval sample",
)
parser.add_argument(
"--eval_num_samples",
type=int,
default=0,
help="Number of token sequences per eval dataset (0 = all)",
)
parser.add_argument(
"--eval_seq_len",
type=int,
default=2048,
help="Sequence length for eval",
)
parser.add_argument(
"--eval_batch_size",
type=int,
default=None,
help="Batch size for eval (defaults to --batch_size)",
)
parser.add_argument(
"--eval_max_batches",
type=int,
default=None,
help="Optional max number of eval batches per dataset",
)
parser.add_argument(
"--eval_cache_dir",
default=None,
help="Optional datasets cache dir for eval",
)
parser.add_argument(
"--eval_num_workers",
type=int,
default=0,
help="Eval DataLoader workers",
)
parser.add_argument(
"--eval_device",
default=None,
help="Device for eval (defaults to --device)",
)
parser.add_argument(
"--skip_distill",
action="store_true",
help="Skip reparameterized distillation after head alignment/Fisher setup",
)
parser.add_argument(
"--distill_calib_samples",
type=int,
default=256,
help="Number of distillation sequences from calibration datasets",
)
parser.add_argument(
"--distill_inst_samples",
type=int,
default=0,
help="Number of distillation sequences from instruction dataset (0 = all)",
)
parser.add_argument(
"--distill_seq_len",
type=int,
default=512,
help="Sequence length for distillation",
)
parser.add_argument(
"--distill_batch_size",
type=int,
default=2,
help="Batch size for distillation",
)
parser.add_argument(
"--distill_epochs",
type=float,
default=1.0,
help="Number of distillation epochs (float allowed, e.g. 0.5)",
)
parser.add_argument(
"--distill_lr",
type=float,
default=1e-4,
help="Learning rate for distillation",
)
parser.add_argument(
"--distill_method",
choices=["reparam"],
default="reparam",
help="Distillation strategy (reparam only).",
)
parser.add_argument(
"--distill_kl_weight",
type=float,
default=1e-2,
help="Weight for KL loss on logits",
)
parser.add_argument(
"--distill_kl_temp",
type=float,
default=4.0,
help="Temperature for KL distillation on logits",
)
parser.add_argument(
"--distill_hidden_mse_weight",
type=float,
default=1.0,
help="Weight for hidden-state MSE in reparam distillation (0 disables it)",
)
parser.add_argument(
"--distill_attn_mse_weight",
type=float,
default=0.0,
help="Weight for auxiliary attention-output MSE in reparam distillation",
)
parser.add_argument(
"--distill_mlp_mse_weight",
type=float,
default=0.0,
help="Weight for auxiliary MLP-output MSE in reparam distillation",
)
parser.add_argument(
"--reparam_eta",
type=float,
default=1e-2,
help="Eta: ||lambda - lambda_gate||^2 regularizer weight for --distill_method reparam",
)
parser.add_argument(
"--reparam_gamma",
type=float,
default=1e-4,
help="Gamma: ||U - U0||^2 regularizer weight for --distill_method reparam",
)
parser.add_argument(
"--reparam_attn_reg_scale",
type=float,
default=1.0,
help="Relative scale applied to attention-parameter reparam regularizers",
)
parser.add_argument(
"--reparam_mlp_reg_scale",
type=float,
default=1.0,
help="Relative scale applied to MLP-parameter reparam regularizers",
)
parser.add_argument(
"--reparam_param_subset",
type=str,
choices=["all", "mlp", "attn"],
default="all",
help="Restrict reparam merge/recovery capacity to only this parameter family",
)
parser.add_argument(
"--norm_policy",
type=str,
choices=["hybrid", "merge_all", "copy_n1", "copy_n1_n2"],
default="hybrid",
help="Norm merge policy (default: hybrid)",
)
parser.add_argument(
"--distill_weight_decay",
type=float,
default=0.0,
help="Weight decay for distillation",
)
parser.add_argument(
"--distill_max_grad_norm",
type=float,
default=1.0,
help="Max grad norm for distillation",
)
parser.add_argument(
"--distill_grad_accum_steps",
type=int,
default=1,
help="Gradient accumulation steps for distillation",
)
parser.add_argument(
"--distill_log_steps",
type=int,
default=100,
help="Log distillation loss every N steps",
)
parser.add_argument(
"--distill_eval_every",
type=int,
default=0,
help="Evaluate PPL every N distill steps (0 = disable)",
)
parser.add_argument(
"--distill_eval_max_batches",
type=int,
default=None,
help="Max eval batches per dataset during distill (default: all)",
)
parser.add_argument(
"--distill_teacher_device",
default=None,
help="Device for teacher model during distillation (defaults to --device)",
)
parser.add_argument(
"--comm_enabled",
action="store_true",
help=(
"Enable commutator-style preconditioning before each progressive "
"cycle's fusion."
),
)
parser.add_argument(
"--comm_include_cycle1",
action="store_true",
help="Run commutator preconditioning for cycle 1 as well (default: skip cycle 1).",
)
parser.add_argument(
"--comm_topk",
type=int,
default=1,
help="Top-K lowest-score pairs used as the commutator candidate set",
)
parser.add_argument(
"--comm_sample_eta",
type=float,
default=0.5,
help="Mixture weight between uniform and score-biased candidate sampling",
)
parser.add_argument(
"--comm_sample_dwce_scale",
type=float,
default=1.0,
help="Scale c in softmax(-c * score(i)) for commutator pair sampling",
)
parser.add_argument(
"--comm_temp",
type=float,
default=2.0,
help="Temperature for teacher-anchor KL in commutator preconditioning",
)
parser.add_argument(
"--comm_steps_ratio",
type=float,
default=0.1,
help="Run this fraction of distillation optimizer steps for commutator phase",
)
parser.add_argument(
"--comm_lr_scale",
type=float,
default=0.1,
help="Commutator LR = --distill_lr * this scale",
)
parser.add_argument(
"--comm_train_mode",
choices=["lora", "full"],
default="lora",
help=(
"Commutator trainable parameter mode: "
"'lora' updates LoRA adapters on sampled receiver layers; "
"'full' updates full receiver-layer weights."
),
)
parser.add_argument(
"--comm_interaction_mode",
choices=["mse", "relative"],
default="relative",
help="Interaction loss form: plain MSE or relative MSE",
)
parser.add_argument(
"--comm_interaction_eps",
type=float,
default=1e-8,
help="Epsilon for relative commutator interaction normalization",
)
parser.add_argument(
"--comm_mu",
type=float,
default=None,
help=(
"Weight for interaction loss. Defaults to 0.1 for --comm_interaction_mode=mse "
"and 0.5 for --comm_interaction_mode=relative."
),
)
parser.add_argument(
"--comm_mu_auto",
action="store_true",
help="Enable automatic mu scaling via gradient-norm balancing",
)
parser.add_argument(
"--comm_mu_auto_rho",
type=float,
default=0.1,
help="Target anchor-to-interaction gradient ratio constant for auto-mu",
)
parser.add_argument(
"--comm_mu_auto_eps",
type=float,
default=1e-8,
help="Numerical epsilon in auto-mu denominator",
)
parser.add_argument(
"--comm_log_steps",
type=int,
default=50,
help="Log commutator preconditioning loss every N optimizer steps",
)
parser.add_argument(
"--comm_skip_post_reselect",
action="store_true",
help=(
"Keep the pre-comm selected fusion pair and skip recomputing "
"selection after commutator preconditioning."
),
)
parser.add_argument(
"--redistrib_teacher_source",
type=str,
choices=["base_model", "previous_cycle"],
default="base_model",
help=(
"Teacher source for commutator preconditioning teacher loading. "
"'base_model' uses --model for all cycles; "
"'previous_cycle' uses cycle-1 checkpoint (cycle 1 falls back to base_model)."
),
)
parser.add_argument(
"--lora_epochs",
type=float,
default=1.0,
help="LoRA CE finetuning epochs after distill (0 = disable)",
)
parser.add_argument(
"--lora_rank",
type=int,
default=8,
help="LoRA rank (r)",
)
parser.add_argument(
"--lora_alpha",
type=float,
default=16.0,
help="LoRA alpha",
)
parser.add_argument(
"--lora_dropout",
type=float,
default=0.0,
help="LoRA dropout",
)
parser.add_argument(
"--lora_kl_enabled",
action="store_true",
help="Add KL regularization between pre/post LoRA logits",
)
parser.add_argument(
"--lora_kl_weight",
type=float,
default=1e-1,
help="KL weight for LoRA regularization",
)
parser.add_argument(
"--lora_kl_temp",
type=float,
default=4.0,
help="Temperature for LoRA KL regularization",
)
parser.add_argument(
"--lora_target_modules",
nargs="*",
default=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"down_proj",
"up_proj",
],
help="Module name suffixes to LoRA-wrap",
)
parser.add_argument(
"--lora_respect_exclude_pairs",
action="store_true",
help=(
"When attaching LoRA adapters, skip linear modules under layers touched by "
"--exclude_pairs (i and i+1 for each excluded pair)."
),
)
parser.add_argument(
"--lora_lr",
type=float,
default=1e-4,
help="Learning rate for LoRA finetuning",
)
parser.add_argument(
"--lora_weight_decay",
type=float,
default=0.0,
help="Weight decay for LoRA finetuning",
)
parser.add_argument(
"--lora_max_grad_norm",
type=float,
default=1.0,
help="Max grad norm for LoRA finetuning",
)
parser.add_argument(
"--lora_grad_accum_steps",
type=int,
default=1,
help="Gradient accumulation steps for LoRA finetuning",
)
parser.add_argument(
"--lora_log_steps",
type=int,
default=100,
help="Log LoRA loss every N steps",
)
parser.add_argument(
"--lora_eval_every",
type=int,
default=0,
help="Evaluate PPL every N LoRA steps (0 = disable)",
)
parser.add_argument(
"--lora_eval_max_batches",
type=int,
default=None,
help="Max eval batches per dataset during LoRA (default: all)",
)
parser.add_argument(
"--instruction_dataset",
default=None,
help="HF dataset name for alpaca-style instruction data",
)
parser.add_argument(
"--instruction_config",
default=None,
help="Optional instruction dataset config",
)
parser.add_argument(
"--instruction_split",
default="train",
help="Instruction dataset split",
)
parser.add_argument(
"--instruction_field_instruction",
default="instruction",
help="Instruction field name",
)
parser.add_argument(
"--instruction_field_input",
default="input",
help="Optional input field name",
)
parser.add_argument(
"--instruction_field_output",
default="output",
help="Response/output field name",
)
parser.add_argument(
"--auto_max_batches",
type=int,
default=0,
help="Max calibration batches for auto selection scoring (0 = all)",
)
parser.add_argument(
"--auto_metric",
type=str,
choices=[
"dwce",
"cosine",
"hybrid",
"hybrid_cosine",
"hybrid_global_rel",
],
default="dwce",
help=(
"Auto pair scoring metric. 'dwce' uses downstream-weighted composition error; "
"'cosine' uses average token-level cosine distance between adjacent layer outputs; "
"'hybrid'/'hybrid_cosine' use DWCE to shortlist then adjacent cosine for final scoring; "
"'hybrid_global_rel' uses DWCE to shortlist then reranks by the change in "
"pair-to-final-layer cosine relation after surrogate fusion."
),
)
parser.add_argument(
"--auto_cosine_topk",
type=int,
default=3,
help="Top-K DWCE candidates to rescore with cosine in --auto_metric=hybrid",
)
parser.add_argument(
"--auto_norm",
type=str,
choices=["relative", "none"],
default="relative",
help="Normalization mode for DWCE scoring (ignored for cosine)",
)
parser.add_argument(
"--auto_dwce_mode",
type=str,
choices=["separate", "shared"],
default="separate",
help=(
"DWCE implementation for auto scoring. "
"'separate' runs distinct Fisher and DWCE backward passes; "
"'shared' reuses one backward pass and replays DWCE with cached gradients."
),
)
parser.add_argument(
"--num_progressive",
type=int,
default=0,
help="Number of progressive fusions (>0 required)",
)
parser.add_argument(
"--resume_from_cycle",
type=int,
default=0,
help=(
"Resume from this completed cycle index. When > 0, --model should point "
"to the saved full model directory for that cycle."
),
)
parser.add_argument(
"--save_full_model_cycles",
nargs="*",
default=[],
help=(
"Cycle indices whose full models should be saved. Requesting cycle c "
"also saves cycle c-1 automatically (c=1 saves only cycle 1)."
),
)
return parser.parse_args()
def parse_exclude_pairs(exclude_raw: Optional[List[str]], num_pairs: int) -> List[int]:
"""Parse --exclude_pairs into normalized pair indices for the current model.
Indices refer to the start of an adjacent pair (i, i+1) and must be in [0..N-2].
Negative indices count from the end (-1 = last pair).
"""
if not exclude_raw:
return []
exclude: List[int] = []
for item in exclude_raw:
if item is None:
continue
for part in str(item).split(","):
part = part.strip()
if not part:
continue
try:
idx = int(part)
except ValueError as exc:
raise SystemExit("--exclude_pairs must contain integers.") from exc
if idx < 0:
idx = num_pairs + idx
if 0 <= idx < num_pairs:
exclude.append(idx)
return sorted(set(exclude))
def parse_cycle_list(raw_values: Optional[List[str]]) -> List[int]:
if not raw_values:
return []
cycles: List[int] = []
for item in raw_values:
if item is None:
continue
for part in str(item).split(","):
part = part.strip()
if not part:
continue
try:
cycles.append(int(part))
except ValueError as exc:
raise SystemExit(
"--save_full_model_cycles must contain integers."
) from exc
return cycles
def resolve_full_model_save_cycles(
requested_cycles: List[int], num_progressive: int
) -> Set[int]:
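    """Validate and expand --save_full_model_cycles.

    Requesting cycle c > 1 also marks cycle c-1 for saving; all entries must
    lie in [1, --num_progressive].
    """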
resolved: Set[int] = set()
for cycle in requested_cycles:
if cycle <= 0 or cycle > num_progressive:
raise SystemExit(
"--save_full_model_cycles entries must be within [1, --num_progressive]."
)
resolved.add(cycle)
if cycle > 1:
resolved.add(cycle - 1)
return resolved
def load_resume_metadata(model_path: str) -> Optional[Dict[str, object]]:
resume_meta_path = os.path.join(model_path, "resume_info.json")
if not os.path.exists(resume_meta_path):
return None
with open(resume_meta_path, "r", encoding="utf-8") as handle:
loaded = json.load(handle)
return loaded if isinstance(loaded, dict) else None
def build_generator(seed: int) -> torch.Generator:
generator = torch.Generator(device="cpu")
generator.manual_seed(int(seed))
return generator
def capture_rng_state() -> Dict[str, object]:
state: Dict[str, object] = {
"python_random_state": random.getstate(),
"torch_cpu_rng_state": torch.get_rng_state(),
}
if np is not None:
state["numpy_random_state"] = np.random.get_state()
if torch.cuda.is_available():
state["torch_cuda_rng_state_all"] = torch.cuda.get_rng_state_all()
return state
def restore_rng_state(state: Dict[str, object]) -> None:
python_state = state.get("python_random_state")
if python_state is not None:
random.setstate(python_state)
numpy_state = state.get("numpy_random_state")
if numpy_state is not None and np is not None:
np.random.set_state(numpy_state)
torch_cpu_state = state.get("torch_cpu_rng_state")
if torch_cpu_state is not None:
torch.set_rng_state(torch_cpu_state)
torch_cuda_state = state.get("torch_cuda_rng_state_all")
if torch_cuda_state is not None and torch.cuda.is_available():
torch.cuda.set_rng_state_all(torch_cuda_state)
def save_rng_state(path: str) -> None:
torch.save(capture_rng_state(), path)
def load_rng_state(path: str) -> Optional[Dict[str, object]]:
if not os.path.exists(path):
return None
loaded = torch.load(path, map_location="cpu", weights_only=False)
return loaded if isinstance(loaded, dict) else None
def configure_reproducibility(seed: int) -> None:
random.seed(seed)
if np is not None:
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
if hasattr(torch.backends, "cudnn"):
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
if hasattr(torch, "use_deterministic_algorithms"):
torch.use_deterministic_algorithms(True, warn_only=True)
def save_loader_generator_state(
base_dir: str,
*,
distill_generator: Optional[torch.Generator] = None,
lora_generator: Optional[torch.Generator] = None,
) -> None:
state: Dict[str, object] = {}
if distill_generator is not None:
state["distill_generator_state"] = distill_generator.get_state()
if lora_generator is not None:
state["lora_generator_state"] = lora_generator.get_state()
if state:
torch.save(state, os.path.join(base_dir, "loader_generators.pt"))
def load_loader_generator_state(base_dir: str) -> Optional[Dict[str, object]]:
path = os.path.join(base_dir, "loader_generators.pt")
if not os.path.exists(path):
return None
loaded = torch.load(path, map_location="cpu")
return loaded if isinstance(loaded, dict) else None
def resolve_layer_idx(
args: argparse.Namespace,
model,
layers: List[torch.nn.Module],
dataloader,
previous_scores,
start_index: int,
exclude_pairs: Set[int],
):
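    """Resolve the pair index (i, i+1) to fuse this cycle.

    A manual integer --layer bypasses scoring (negative indices count back from
    the last pair); 'sequential' takes the first non-excluded pair at or after
    start_index; otherwise selection defers to select_layer_auto. Returns
    (layer_idx, scores, selection_meta).
    """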
layer_arg = str(getattr(args, "layer", "auto")).strip().lower()
selection_method = str(getattr(args, "selection_method", "dwce")).strip().lower()
if layer_arg != "auto":
try:
layer_idx = int(layer_arg)
except ValueError as exc:
raise SystemExit("--layer must be 'auto' or an integer index") from exc
num_pairs = max(len(layers) - 1, 0)
if layer_idx < 0:
layer_idx += num_pairs
if layer_idx in exclude_pairs:
raise SystemExit(f"--layer resolved to excluded pair index {layer_idx}")
return layer_idx, previous_scores, {"method": "manual", "exclude_pairs": sorted(exclude_pairs)}
if selection_method == "sequential":
num_pairs = len(layers) - 1
for layer_idx in range(max(start_index, 0), num_pairs):
if layer_idx not in exclude_pairs:
return layer_idx, previous_scores, {
"method": "sequential",
"start_index": max(start_index, 0),
"exclude_pairs": sorted(exclude_pairs),
}
raise SystemExit("No eligible layer pairs remain after exclusions")
layer_idx, dwce_scores, dwce_meta = select_layer_auto(
model,
layers,
dataloader,
args,
previous_scores=previous_scores,
start_index=start_index,
exclude_pairs=exclude_pairs,
)
return layer_idx, dwce_scores, dwce_meta
@dataclass
class PreparedData:
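    """Dataloaders and bookkeeping built once and reused across progressive fusion cycles."""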
calib_loader: torch.utils.data.DataLoader
calib_num_sequences: int
distill_loader: Optional[torch.utils.data.DataLoader]
distill_generator: Optional[torch.Generator]
distill_meta: Dict[str, object]
lora_loader: Optional[torch.utils.data.DataLoader]
lora_generator: Optional[torch.Generator]
lora_meta: Dict[str, object]
eval_datasets: List[str]
eval_configs: List[Optional[str]]
eval_dataloaders: Optional[Dict[str, torch.utils.data.DataLoader]]
def resolve_eval_datasets(args: argparse.Namespace) -> Tuple[List[str], List[Optional[str]]]:
eval_datasets = args.eval_dataset or ["wikitext"]
eval_configs = args.eval_dataset_config or ["wikitext-2-raw-v1"]
eval_configs = ppl_eval._expand_dataset_configs(eval_datasets, eval_configs)
return eval_datasets, eval_configs
def run_ppl_eval(
model_id_or_path: str,
eval_datasets: List[str],
eval_configs: List[Optional[str]],
args: argparse.Namespace,
prepared_eval_dataloaders: Optional[Dict[str, torch.utils.data.DataLoader]] = None,
) -> Dict[str, float]:
eval_device = args.eval_device or args.device
dtype = get_dtype(args.dtype)
eval_model = load_causal_lm(
model_id_or_path,
torch_dtype=dtype,
trust_remote_code=args.trust_remote_code,
)
eval_model.to(eval_device)
if prepared_eval_dataloaders is not None:
results = ppl_eval.evaluate_ppl_dataloaders(
eval_model,
prepared_eval_dataloaders,
eval_device,
max_batches=args.eval_max_batches,
)
else:
eval_batch_size = args.eval_batch_size or args.batch_size
eval_tokenizer = AutoTokenizer.from_pretrained(
model_id_or_path, trust_remote_code=args.trust_remote_code
)
if eval_tokenizer.pad_token is None and eval_tokenizer.eos_token is not None:
eval_tokenizer.pad_token = eval_tokenizer.eos_token
results = ppl_eval.evaluate_ppl_datasets(
eval_model,
eval_tokenizer,
datasets=eval_datasets,
configs=eval_configs,
split=args.eval_split,
text_field=args.eval_text_field,
num_samples=args.eval_num_samples,
seq_len=args.eval_seq_len,
batch_size=eval_batch_size,
device=eval_device,
seed=args.seed,
shuffle=False,
model_family=args.eval_model_family,
add_bos=args.eval_add_bos,
max_batches=args.eval_max_batches,
cache_dir=args.eval_cache_dir,
num_workers=args.eval_num_workers,
)
del eval_model
if torch.cuda.is_available():
torch.cuda.empty_cache()
return results
def build_calibration_dataloader(
args: argparse.Namespace, tokenizer
) -> Tuple[List[str], List[torch.Tensor], torch.utils.data.DataLoader]:
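    """Build fixed-length calibration batches for pair selection, head alignment, and Fisher estimation.

    Uses common_lm_data chunks when --dataset is given, otherwise chunks
    --text/--text_file via build_token_chunks. Returns (texts, chunks, dataloader);
    texts is empty on the dataset path.
    """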
if args.dataset:
datasets = list(args.dataset)
configs = expand_dataset_configs(datasets, list(args.dataset_config))
chunks: List[torch.Tensor] = []
for idx, (dataset_name, config) in enumerate(zip(datasets, configs)):
spec = SharedLMDataSpec(
dataset=dataset_name,
config=config,
split=args.dataset_split,
text_field=args.dataset_text_field,
seq_len=args.seq_len,
num_sequences=args.num_samples,
seed=args.seed + idx,
)
chunks.extend(build_chunks(spec, tokenizer))
if not chunks:
raise SystemExit("Not enough text to build token sequences.")
input_ids = torch.stack(chunks)
attention_mask = torch.ones_like(input_ids)
dataset = torch.utils.data.TensorDataset(input_ids, attention_mask)
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=args.batch_size, shuffle=False
)
return [], chunks, dataloader
texts = load_texts(args)
if not texts:
raise SystemExit(
"No calibration text found. Provide --dataset, --text, or --text_file."
)
chunks = build_token_chunks(texts, tokenizer, args.seq_len, args.num_samples)
if not chunks:
raise SystemExit("Not enough text to build token sequences.")
input_ids = torch.stack(chunks)
attention_mask = torch.ones_like(input_ids)
dataset = torch.utils.data.TensorDataset(input_ids, attention_mask)
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=args.batch_size, shuffle=False
)
return texts, chunks, dataloader
def prepare_distillation_data(
args: argparse.Namespace, tokenizer, include_instruction: bool = True
) -> Tuple[Optional[torch.utils.data.DataLoader], Optional[torch.Generator], Dict[str, object]]:
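    """Assemble the distillation/LoRA dataloader.

    Calibration sequences come from --dataset (token-budgeted via --target_tokens
    or sampled via --distill_calib_samples), optionally mixed with instruction
    records. Returns (loader, generator, meta); loader is None when nothing was built.
    """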
if (
include_instruction
and args.distill_inst_samples != 0
and not args.instruction_dataset
):
print(
"Warning: --distill_inst_samples > 0 but no --instruction_dataset "
"provided; instruction distillation will be skipped."
)
calib_texts: List[str] = []
calib_dataset = None
if args.target_tokens > 0 and args.dataset:
datasets = list(args.dataset)
configs = expand_dataset_configs(datasets, list(args.dataset_config))
per_dataset = args.target_tokens // len(datasets)
remainder = args.target_tokens % len(datasets)
calib_chunks: List[torch.Tensor] = []
for idx, (dataset_name, config) in enumerate(zip(datasets, configs)):
dataset_tokens = per_dataset + (remainder if idx == 0 else 0)
spec = SharedLMDataSpec(
dataset=dataset_name,
config=config,
split=args.dataset_split,
text_field=args.dataset_text_field,
seq_len=args.distill_seq_len,
target_tokens=dataset_tokens,
seed=args.seed + 17 + idx,
)
calib_chunks.extend(build_chunks(spec, tokenizer))
if calib_chunks:
input_ids = torch.stack(calib_chunks)
attention_mask = torch.ones_like(input_ids)
calib_dataset = torch.utils.data.TensorDataset(input_ids, attention_mask)
else:
calib_texts = load_texts_from_datasets(
datasets=list(args.dataset),
configs=expand_dataset_configs(list(args.dataset), list(args.dataset_config)),
split=args.dataset_split,
text_field=args.dataset_text_field,
num_samples=args.distill_calib_samples,
seed=args.seed + 17,
)
inst_records = []
if include_instruction:
inst_records = load_instruction_records(args, args.distill_inst_samples)
distill_datasets = []
if calib_dataset is not None:
distill_datasets.append(calib_dataset)
elif calib_texts:
calib_records = [{"text": text} for text in calib_texts]
distill_datasets.append(
FixedSeqDataset(calib_records, tokenizer, args.distill_seq_len)
)
if inst_records:
distill_datasets.append(
FixedSeqDataset(inst_records, tokenizer, args.distill_seq_len)
)
distill_meta: Dict[str, object] = {
"calib_texts": len(calib_texts),
"calib_sequences": len(calib_dataset) if calib_dataset is not None else len(calib_texts),
"inst_sequences": len(inst_records),
"total_sequences": 0,
}
if not distill_datasets:
return None, None, distill_meta
if len(distill_datasets) == 1:
distill_dataset = distill_datasets[0]
else:
distill_dataset = torch.utils.data.ConcatDataset(distill_datasets)
distill_meta["total_sequences"] = len(distill_dataset)
distill_generator = build_generator(
args.seed + 1000 + (1000000 if include_instruction else 0)
)
distill_loader = torch.utils.data.DataLoader(
distill_dataset,
batch_size=args.distill_batch_size,
shuffle=True,
generator=distill_generator,
)
return distill_loader, distill_generator, distill_meta
def prepare_eval_dataloaders(
args: argparse.Namespace,
tokenizer,
model: torch.nn.Module,
eval_datasets: List[str],
eval_configs: List[Optional[str]],
) -> Optional[Dict[str, torch.utils.data.DataLoader]]:
needs_eval = (not args.skip_eval) or (
(not args.skip_distill and args.distill_eval_every)
or (args.lora_epochs > 0 and args.lora_eval_every)
)
if not needs_eval:
return None
eval_batch_size = args.eval_batch_size or args.batch_size
resolved_family = args.eval_model_family
if resolved_family == "auto":
resolved_family = ppl_eval._infer_model_family(model)
return ppl_eval.prepare_ppl_dataloaders(
tokenizer=tokenizer,
datasets=eval_datasets,
configs=eval_configs,
split=args.eval_split,
text_field=args.eval_text_field,
num_samples=args.eval_num_samples,
seq_len=args.eval_seq_len,
batch_size=eval_batch_size,
seed=args.seed,
shuffle=False,
model_family=resolved_family,
add_bos=args.eval_add_bos,
cache_dir=args.eval_cache_dir,
num_workers=args.eval_num_workers,
model=model,
)
def prepare_all_data(
args: argparse.Namespace,
tokenizer,
model: torch.nn.Module,
eval_datasets: List[str],
eval_configs: List[Optional[str]],
loader_generator_state: Optional[Dict[str, object]] = None,
) -> PreparedData:
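    """Build the calibration loader and, as configured, the distillation, LoRA,
    and eval loaders up front; restore saved DataLoader generator states when resuming."""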
texts, chunks, calib_loader = build_calibration_dataloader(args, tokenizer)
calib_num_sequences = len(chunks)
del texts
del chunks
distill_loader = None
distill_generator = None
distill_meta: Dict[str, object] = {
"calib_texts": 0,
"calib_sequences": 0,
"inst_sequences": 0,
"total_sequences": 0,
}
lora_loader = None
lora_generator = None
lora_meta: Dict[str, object] = {
"calib_texts": 0,
"calib_sequences": 0,
"inst_sequences": 0,
"total_sequences": 0,
}
if (not args.skip_distill) or bool(getattr(args, "comm_enabled", False)):
distill_loader, distill_generator, distill_meta = prepare_distillation_data(
args, tokenizer, include_instruction=False
)
if (
distill_generator is not None
and loader_generator_state is not None
and loader_generator_state.get("distill_generator_state") is not None
):
distill_generator.set_state(loader_generator_state["distill_generator_state"])
if args.lora_epochs > 0:
lora_loader, lora_generator, lora_meta = prepare_distillation_data(
args, tokenizer, include_instruction=True
)
if (
lora_generator is not None
and loader_generator_state is not None
and loader_generator_state.get("lora_generator_state") is not None
):
lora_generator.set_state(loader_generator_state["lora_generator_state"])
eval_dataloaders = prepare_eval_dataloaders(
args, tokenizer, model, eval_datasets, eval_configs
)
return PreparedData(
calib_loader=calib_loader,
calib_num_sequences=calib_num_sequences,
distill_loader=distill_loader,
distill_generator=distill_generator,
distill_meta=distill_meta,
lora_loader=lora_loader,
lora_generator=lora_generator,
lora_meta=lora_meta,
eval_datasets=eval_datasets,
eval_configs=eval_configs,
eval_dataloaders=eval_dataloaders,
)
def evaluate_ppl_model(
model: torch.nn.Module,
tokenizer,
eval_datasets: List[str],
eval_configs: List[Optional[str]],
args: argparse.Namespace,
max_batches: Optional[int] = None,
prepared_eval_dataloaders: Optional[Dict[str, torch.utils.data.DataLoader]] = None,
) -> Dict[str, float]:
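    """Evaluate perplexity with the in-memory model, then restore its original
    train/eval mode and device."""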
eval_device = args.eval_device or args.device
prev_mode = model.training
try:
prev_device = next(model.parameters()).device
except StopIteration:
prev_device = torch.device(eval_device)
model.eval()
if str(prev_device) != eval_device:
model.to(eval_device)
if prepared_eval_dataloaders is not None:
results = ppl_eval.evaluate_ppl_dataloaders(
model,
prepared_eval_dataloaders,
eval_device,
max_batches=max_batches if max_batches is not None else args.eval_max_batches,
)
else:
eval_batch_size = args.eval_batch_size or args.batch_size
results = ppl_eval.evaluate_ppl_datasets(
model,
tokenizer,
datasets=eval_datasets,
configs=eval_configs,
split=args.eval_split,
text_field=args.eval_text_field,
num_samples=args.eval_num_samples,
seq_len=args.eval_seq_len,
batch_size=eval_batch_size,
device=eval_device,
seed=args.seed,
shuffle=False,
model_family=args.eval_model_family,
add_bos=args.eval_add_bos,
max_batches=max_batches if max_batches is not None else args.eval_max_batches,
cache_dir=args.eval_cache_dir,
num_workers=args.eval_num_workers,
)
if prev_mode:
model.train()
if str(prev_device) != eval_device:
model.to(prev_device)
if torch.cuda.is_available():
torch.cuda.empty_cache()
return results
def has_post_fusion_data(
distill_loader: Optional[torch.utils.data.DataLoader],
distill_meta: Optional[Dict[str, object]],
) -> bool:
if distill_loader is None or distill_meta is None:
return False
return distill_meta.get("total_sequences", 0) > 0
def summarize_gate_lambdas(gates: Dict[str, torch.Tensor]) -> Dict[str, object]:
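    """Summarize merge-gate tensors: element counts plus per-tensor and global means."""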
if not gates:
return {"num_tensors": 0, "num_elements": 0}
total_sum = 0.0
total_elems = 0
per_tensor_mean: Dict[str, Optional[float]] = {}
for name, gate in gates.items():
g = gate.detach().float()
if g.numel() == 0:
per_tensor_mean[name] = None
continue
per_tensor_mean[name] = float(g.mean().item())
total_sum += float(g.sum().item())
total_elems += int(g.numel())
global_mean = None if total_elems == 0 else total_sum / float(total_elems)
return {
"num_tensors": len(gates),
"num_elements": total_elems,
"global_mean": global_mean,
"per_tensor_mean": per_tensor_mean,
}
def compute_path_bytes(path: str) -> int:
if os.path.isfile(path):
return os.path.getsize(path)
total = 0
for root, _, files in os.walk(path):
for name in files:
file_path = os.path.join(root, name)
if os.path.islink(file_path):
continue
try:
total += os.path.getsize(file_path)
except OSError:
continue
return total
def save_stage_checkpoint(
model: torch.nn.Module,
tokenizer,
stage_dir: str,
stage_name: str,
ppl_results: Optional[Dict[str, float]],
) -> Dict[str, object]:
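    """Save the model and tokenizer to stage_dir and write stage_metrics.json
    (stage name, on-disk size, post-stage PPL). Raises if any module name contains ':'."""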
os.makedirs(stage_dir, exist_ok=True)
colon_modules = find_colon_modules(model)
if colon_modules:
raise RuntimeError(
"Unexpected module names with ':' detected before save: "
f"{', '.join(colon_modules)}."
)
model.save_pretrained(stage_dir)
tokenizer.save_pretrained(stage_dir)
stage_meta = {
"stage": stage_name,
"path": stage_dir,
"weight_bytes": compute_path_bytes(stage_dir),
"post_ppl": ppl_results,
}
with open(
os.path.join(stage_dir, "stage_metrics.json"),
"w",
encoding="utf-8",
) as handle:
json.dump(stage_meta, handle, indent=2)
return stage_meta
def save_cycle_full_model(
model: torch.nn.Module,
tokenizer,
cycle_dir: str,
cycle_idx: int,
args: argparse.Namespace,
ppl_results: Optional[Dict[str, float]],
) -> Dict[str, object]:
full_model_dir = os.path.join(cycle_dir, "full_model")
stage_meta = save_stage_checkpoint(
model=model,
tokenizer=tokenizer,
stage_dir=full_model_dir,
stage_name=f"cycle_{cycle_idx}_full_model",
ppl_results=ppl_results,
)
resume_meta = {
"base_model": getattr(args, "base_model_id", args.model),
"cycle": cycle_idx,
"output_dir": args.output_dir,
"layer_path": args.layer_path,
"rng_state": "rng_state.pt",
"loader_generators": "loader_generators.pt",
}
with open(
os.path.join(full_model_dir, "resume_info.json"),
"w",
encoding="utf-8",
) as handle:
json.dump(resume_meta, handle, indent=2)
stage_meta["resume_info"] = "resume_info.json"
return stage_meta
def run_lora_phase(
model: torch.nn.Module,
tokenizer,
eval_datasets: List[str],
eval_configs: List[Optional[str]],
args: argparse.Namespace,
lora_loader: Optional[torch.utils.data.DataLoader] = None,
lora_meta: Optional[Dict[str, object]] = None,
eval_dataloaders: Optional[Dict[str, torch.utils.data.DataLoader]] = None,
cycle_idx: Optional[int] = None,
num_cycles: Optional[int] = None,
) -> List[Dict[str, object]]:
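    """Run the optional LoRA cross-entropy finetuning phase and return the
    PPL-over-steps eval history (empty when LoRA is disabled or no data was built)."""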
lora_eval_history: List[Dict[str, object]] = []
if args.lora_epochs <= 0:
return lora_eval_history
if not has_post_fusion_data(lora_loader, lora_meta):
print("No post-fusion sequences built; skipping LoRA finetuning.")
return lora_eval_history
lora_ce_finetune(
model=model,
dataloader=lora_loader,
eval_tokenizer=tokenizer,
eval_datasets=eval_datasets,
eval_configs=eval_configs,
eval_history=lora_eval_history,
args=args,
eval_dataloaders=eval_dataloaders,
progressive_cycle=cycle_idx,
progressive_total=num_cycles,
)
return lora_eval_history
def run_progressive(
args: argparse.Namespace,
model: torch.nn.Module,
tokenizer,
prepared: PreparedData,
) -> None:
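    """Run the progressive layer-fusion loop.

    Each cycle: select a candidate pair, optionally run commutator
    preconditioning (and re-select unless --comm_skip_post_reselect), align
    attention heads, merge the pair via reparam distillation or the closed-form
    Fisher merge, drop the absorbed layer, evaluate, and write per-cycle
    metadata/checkpoints. An optional LoRA finetune runs after the final cycle.
    """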
eval_datasets = prepared.eval_datasets
eval_configs = prepared.eval_configs
dataloader = prepared.calib_loader
num_sequences = prepared.calib_num_sequences
model.to(args.device)
os.makedirs(args.output_dir, exist_ok=True)
progressive_meta_path = os.path.join(args.output_dir, "progressive_metadata.json")
existing_meta: Dict[str, object] = {}
if args.resume_from_cycle > 0 and os.path.exists(progressive_meta_path):
with open(progressive_meta_path, "r", encoding="utf-8") as handle:
loaded_meta = json.load(handle)
if isinstance(loaded_meta, dict):
existing_meta = loaded_meta
bootstrap_meta = {
"base_model": getattr(args, "base_model_id", args.model),
"num_progressive": args.num_progressive,
"layer_path": args.layer_path,
"resume_from_cycle": args.resume_from_cycle,
"save_full_model_cycles": sorted(args.full_model_save_cycles),
"cycles": (
existing_meta.get("cycles", [])
if isinstance(existing_meta.get("cycles"), list)
else []
),
}
with open(
progressive_meta_path,
"w",
encoding="utf-8",
) as handle:
json.dump(bootstrap_meta, handle, indent=2)
pre_eval = None
if not args.skip_eval:
pre_eval = evaluate_ppl_model(
model,
tokenizer,
eval_datasets,
eval_configs,
args,
prepared_eval_dataloaders=prepared.eval_dataloaders,
)
print("Pre-pruning perplexity:")
for dataset_name, ppl in pre_eval.items():
print(f"{dataset_name}: {ppl:.4f}")
parent, name, container = find_layer_container(model, args.layer_path)
layers = list(container)
if args.num_progressive > (len(layers) - 1 + args.resume_from_cycle):
raise SystemExit(
f"--num_progressive ({args.num_progressive}) exceeds available pairs "
f"after resume offset ({len(layers) - 1 + args.resume_from_cycle})"
)
dwce_scores = None
dwce_meta = None
last_fused_idx = 0
cycle_summaries: List[Dict[str, object]] = []
existing_cycles = existing_meta.get("cycles", [])
if isinstance(existing_cycles, list):
for entry in existing_cycles:
if not isinstance(entry, dict):
continue
cycle_value = entry.get("cycle")
if isinstance(cycle_value, int) and cycle_value <= args.resume_from_cycle:
cycle_summaries.append(entry)
comm_enabled = bool(getattr(args, "comm_enabled", False))
comm_teacher_model = None
comm_teacher_cycle: Optional[int] = None
teacher_device = args.distill_teacher_device or args.device
previous_cycle_teacher_model = None
previous_cycle_teacher_cycle: Optional[int] = None
def _release_comm_teacher() -> None:
nonlocal comm_teacher_model, comm_teacher_cycle
if comm_teacher_model is not None:
del comm_teacher_model
comm_teacher_model = None
comm_teacher_cycle = None
if torch.cuda.is_available():
torch.cuda.empty_cache()
def _release_previous_cycle_teacher() -> None:
nonlocal previous_cycle_teacher_model, previous_cycle_teacher_cycle
if previous_cycle_teacher_model is not None:
del previous_cycle_teacher_model
previous_cycle_teacher_model = None
previous_cycle_teacher_cycle = None
if torch.cuda.is_available():
torch.cuda.empty_cache()
def _snapshot_previous_cycle_teacher(cycle_idx: int) -> None:
nonlocal previous_cycle_teacher_model, previous_cycle_teacher_cycle
_release_previous_cycle_teacher()
previous_cycle_teacher_model = copy.deepcopy(model)
previous_cycle_teacher_model.to(teacher_device)
previous_cycle_teacher_model.eval()
previous_cycle_teacher_cycle = cycle_idx
def _get_previous_cycle_teacher(
cycle_idx: int,
) -> Tuple[Optional[torch.nn.Module], str, Optional[int]]:
prev_cycle = cycle_idx - 1
if prev_cycle <= 0:
return None, "base_model", 0
if (
previous_cycle_teacher_model is not None
and previous_cycle_teacher_cycle == prev_cycle
):
return previous_cycle_teacher_model, "previous_cycle_memory", prev_cycle
teacher_model = load_progressive_model(
getattr(args, "base_model_id", args.model),
args.output_dir,
cycle=prev_cycle,
device=teacher_device,
dtype=args.dtype,
trust_remote_code=args.trust_remote_code,
layer_path=args.layer_path,
)
teacher_model.eval()
return teacher_model, "previous_cycle_disk", prev_cycle
def _get_comm_teacher(cycle_idx: int) -> Tuple[Optional[torch.nn.Module], str, Optional[int]]:
nonlocal comm_teacher_model, comm_teacher_cycle
if not comm_enabled:
return None, "disabled", None
source = str(getattr(args, "redistrib_teacher_source", "base_model"))
if source == "base_model":
if comm_teacher_model is None:
print(
"[comm] Loading fixed base teacher for anchor loss "
f"(device={teacher_device})."
)
comm_teacher_model = AutoModelForCausalLM.from_pretrained(
getattr(args, "base_model_id", args.model),
torch_dtype=get_dtype(args.dtype),
trust_remote_code=args.trust_remote_code,
)
comm_teacher_model.to(teacher_device)
comm_teacher_model.eval()
comm_teacher_cycle = 0
return comm_teacher_model, "base_model", 0
prev_cycle = cycle_idx - 1
if prev_cycle <= 0:
if comm_teacher_model is None or comm_teacher_cycle != 0:
_release_comm_teacher()
print(
"[comm] --redistrib_teacher_source=previous_cycle but cycle 1 "
"has no prior checkpoint; using base teacher."
)
comm_teacher_model = AutoModelForCausalLM.from_pretrained(
getattr(args, "base_model_id", args.model),
torch_dtype=get_dtype(args.dtype),
trust_remote_code=args.trust_remote_code,
)
comm_teacher_model.to(teacher_device)
comm_teacher_model.eval()
comm_teacher_cycle = 0
return comm_teacher_model, "base_model", 0
if (
previous_cycle_teacher_model is not None
and previous_cycle_teacher_cycle == prev_cycle
):
if comm_teacher_model is not previous_cycle_teacher_model:
_release_comm_teacher()
comm_teacher_model = previous_cycle_teacher_model
comm_teacher_cycle = prev_cycle
return comm_teacher_model, "previous_cycle_memory", prev_cycle
if comm_teacher_model is None or comm_teacher_cycle != prev_cycle:
_release_comm_teacher()
print(
"[comm] Loading teacher from previous cycle "
f"{prev_cycle} (device={teacher_device})."
)
comm_teacher_model = load_progressive_model(
getattr(args, "base_model_id", args.model),
args.output_dir,
cycle=prev_cycle,
device=teacher_device,
dtype=args.dtype,
trust_remote_code=args.trust_remote_code,
layer_path=args.layer_path,
)
comm_teacher_model.eval()
comm_teacher_cycle = prev_cycle
return comm_teacher_model, "previous_cycle_disk", prev_cycle
if args.resume_from_cycle > 0:
_snapshot_previous_cycle_teacher(args.resume_from_cycle)
start_cycle = args.resume_from_cycle + 1
for cycle_idx in range(start_cycle, args.num_progressive + 1):
print(f"[progressive] Cycle {cycle_idx}/{args.num_progressive}")
run_comm = comm_enabled and (
cycle_idx > 1 or bool(getattr(args, "comm_include_cycle1", False))
)
comm_stats: Dict[str, object] = {"enabled": False}
comm_post_eval = None
if run_comm:
# Preconditioning updates model weights, so DWCE reuse is unreliable.
start_index = 0
reuse_scores = None
else:
start_index = last_fused_idx if cycle_idx > 1 else 0
reuse_scores = dwce_scores
exclude_pairs = set(parse_exclude_pairs(args.exclude_pairs, max(len(layers) - 1, 0)))
layer_idx, dwce_scores, dwce_meta = resolve_layer_idx(
args,
model,
layers,
dataloader,
reuse_scores,
start_index,
exclude_pairs,
)
if run_comm:
dwce_scores_pre_comm = dwce_scores
if prepared.calib_loader is None:
print(
"[comm] Enabled but no calibration sequences were built; skipping."
)
else:
(
comm_teacher_model_loaded,
comm_teacher_source,
comm_teacher_cycle_idx,
) = _get_comm_teacher(cycle_idx)
if comm_teacher_model_loaded is None:
raise RuntimeError("comm_enabled but teacher model was not loaded.")
comm_stats = commutator_precondition(
student_model=model,
student_layers=layers,
teacher_model=comm_teacher_model_loaded,
dataloader=prepared.calib_loader,
dwce_scores=dwce_scores_pre_comm,
exclude_pairs=exclude_pairs,
args=args,
progressive_cycle=cycle_idx,
progressive_total=args.num_progressive,
)
if comm_stats.get("enabled"):
comm_stats["teacher_source"] = comm_teacher_source
comm_stats["teacher_cycle"] = comm_teacher_cycle_idx
comm_stats["dwce_scores_pre"] = dwce_scores_pre_comm
print(
"[comm] Done:"
f" opt_steps={comm_stats.get('opt_steps')}"
f" lr={comm_stats.get('lr')}"
)
if not args.skip_eval:
comm_post_eval = evaluate_ppl_model(
model,
tokenizer,
eval_datasets,
eval_configs,
args,
prepared_eval_dataloaders=prepared.eval_dataloaders,
)
comm_stats["post_ppl"] = comm_post_eval
print(f"[progressive] Cycle {cycle_idx} post-comm perplexity:")
for dataset_name, ppl in comm_post_eval.items():
print(f"{dataset_name}: {ppl:.4f}")
if bool(getattr(args, "comm_skip_post_reselect", False)):
comm_stats["post_selection_recomputed"] = False
comm_stats["selected_layer_post"] = int(layer_idx)
print(
"[comm] Keeping pre-comm DWCE pair selection for fusion."
)
else:
print(
"[comm] Recomputing DWCE after preconditioning for fusion selection."
)
layer_idx, dwce_scores, dwce_meta = resolve_layer_idx(
args,
model,
layers,
dataloader,
None,
0,
exclude_pairs,
)
comm_stats["post_selection_recomputed"] = True
comm_stats["selected_layer_post"] = int(layer_idx)
if layer_idx < 0 or layer_idx >= len(layers) - 1:
raise SystemExit("--layer must be in [0, num_layers-2]")
num_layers_before = len(layers)
layer_a = layers[layer_idx]
layer_b = layers[layer_idx + 1]
norm1_state = None
norm2_state = None
norm1, norm2, norm_names = get_norm_pair(layer_a)
if norm1 is not None:
norm1_state = clone_state_dict(norm1)
if norm2 is not None:
norm2_state = clone_state_dict(norm2)
attn_a = find_attention_module(layer_a)
attn_b = find_attention_module(layer_b)
hidden_size = getattr(model.config, "hidden_size", None)
if hidden_size is None:
hidden_size = getattr(model.config, "n_embd", None)
if hidden_size is None:
raise SystemExit("Model config missing hidden_size/n_embd")
no_head_permute_merge = bool(
getattr(args, "no_head_permute_merge", False)
or getattr(args, "no_head_permute", False)
)
if no_head_permute_merge:
print("[fuse] Head permutation disabled; merging with original head order.")
else:
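            # Align layer i+1's attention heads to layer i's using mean per-head
            # activations over the calibration batches, then permute layer i+1's
            # heads before merging.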
mean_a, mean_b, num_heads, num_kv_heads, head_dim = compute_head_means(
model,
attn_a,
attn_b,
dataloader,
args.device,
hidden_size,
)
perm = build_head_permutation(
mean_a,
mean_b,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
eps=args.eps,
)
permute_attention_heads(
attn_b, perm, num_heads, num_kv_heads, head_dim=head_dim
)
fisher_sums, num_batches, param_numels = compute_fisher(
model,
layer_a,
layer_b,
dataloader,
fisher_mode=args.fisher_mode,
device=args.device,
)
distill_ready = has_post_fusion_data(
prepared.distill_loader, prepared.distill_meta
)
teacher_cycle = cycle_idx - 1
teacher_source = "previous_cycle" if teacher_cycle > 0 else "base_model"
merge_method = "fisher"
distill_method = str(getattr(args, "distill_method", "reparam"))
reparam_stats: Optional[Dict[str, object]] = None
reparam_gate_summary: Optional[Dict[str, object]] = None
needs_teacher_for_reparam = (
(not args.skip_distill)
and distill_ready
and float(args.distill_epochs) > 0.0
)
teacher_model = None
teacher_parent = None
teacher_layer_attr = None
teacher_layers: Optional[List[torch.nn.Module]] = None
teacher_from_cache = False
if needs_teacher_for_reparam:
teacher_model, teacher_source, teacher_cycle = _get_previous_cycle_teacher(
cycle_idx
)
teacher_from_cache = (
teacher_source == "previous_cycle_memory"
and teacher_model is previous_cycle_teacher_model
)
if teacher_model is None:
teacher_model = load_causal_lm(
getattr(args, "base_model_id", args.model),
torch_dtype=get_dtype(args.dtype),
trust_remote_code=args.trust_remote_code,
cache_dir=args.model_cache_dir,
)
teacher_model.to(teacher_device)
teacher_model.eval()
teacher_source = "base_model"
teacher_cycle = 0
teacher_parent, teacher_layer_attr, teacher_container = find_layer_container(
teacher_model, args.layer_path
)
teacher_layers = list(teacher_container)
do_reparam = (
(not args.skip_distill)
and distill_ready
and prepared.distill_loader is not None
)
if (not args.skip_distill) and not do_reparam:
print("[reparam] No distillation sequences built; skipping reparam distill.")
distill_post_eval = None
if do_reparam:
lambda_source = "fisher_prior"
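            # Fisher statistics seed the merge-gate lambdas; --reparam_eta then
            # penalizes drift of the learned lambdas from these priors during
            # reparam distillation.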
reparam_gate_targets: Dict[str, object] = compute_fisher_gate_priors(
layer_a=layer_a,
layer_b=layer_b,
fisher_a=fisher_sums[0],
fisher_b=fisher_sums[1],
num_batches=num_batches,
numels_a=param_numels[0],
numels_b=param_numels[1],
fisher_mode=args.fisher_mode,
eps=float(args.eps),
)
if not reparam_gate_targets:
raise SystemExit("[reparam] No mergeable parameters found; cannot continue.")
if float(args.distill_epochs) > 0.0 and (
teacher_model is None or teacher_layers is None
):
raise SystemExit("--distill_method reparam requires a teacher model.")
print(
f"[reparam] Cycle {cycle_idx}: training U + gates for pair "
f"{layer_idx}-{layer_idx + 1} (epochs={args.distill_epochs}, "
f"hidden_mse_w={args.distill_hidden_mse_weight}, "
f"attn_mse_w={args.distill_attn_mse_weight}, "
f"mlp_mse_w={args.distill_mlp_mse_weight}, "
f"eta={args.reparam_eta}, gamma={args.reparam_gamma}, "
f"attn_reg_scale={args.reparam_attn_reg_scale}, "
f"mlp_reg_scale={args.reparam_mlp_reg_scale}, "
f"param_subset={args.reparam_param_subset}, "
f"lambda_init={lambda_source})."
)
merged, final_gates, reparam_stats = distill_reparam_merge(
student_model=model,
student_parent=parent,
student_layer_attr=name,
student_layers=layers,
teacher_model=teacher_model,
teacher_parent=teacher_parent,
teacher_layer_attr=teacher_layer_attr,
teacher_layers=teacher_layers,
layer_idx=layer_idx,
gate_lambdas=reparam_gate_targets,
dataloader=prepared.distill_loader,
args=args,
progressive_cycle=cycle_idx,
progressive_total=args.num_progressive,
)
reparam_gate_summary = summarize_gate_lambdas(final_gates)
merge_method = "reparam"
if reparam_stats is not None:
reparam_stats["lambda_init"] = lambda_source
else:
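            # Closed-form fallback: Fisher-weighted merge of the two layers.
            # Rough sketch of the usual barycenter form (the exact implementation
            # lives in fuse_layers_model.merge_layers):
            #   W_merged ≈ (F_a * W_a + F_b * W_b) / (F_a + F_b + eps)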
merged = merge_layers(
layer_a,
layer_b,
fisher_sums[0],
fisher_sums[1],
num_batches,
param_numels[0],
param_numels[1],
fisher_mode=args.fisher_mode,
eps=args.eps,
)
apply_norm_policy(
layer_a,
args.norm_policy,
norm1_state,
norm2_state,
norm_names,
)
if teacher_model is not None and not teacher_from_cache:
del teacher_model
teacher_model = None
teacher_parent = None
teacher_layer_attr = None
teacher_layers = None
if torch.cuda.is_available():
torch.cuda.empty_cache()
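        # Layer i now holds the merged weights; drop layer i+1 and shrink the
        # layer count recorded in the model config.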
new_container = drop_layer(container, layer_idx + 1)
setattr(parent, name, new_container)
decrement_config(model.config)
layers = list(new_container)
lora_post_eval = None
if (not args.skip_eval) and (not args.skip_distill) and do_reparam:
distill_post_eval = evaluate_ppl_model(
model,
tokenizer,
eval_datasets,
eval_configs,
args,
prepared_eval_dataloaders=prepared.eval_dataloaders,
)
print(f"[progressive] Cycle {cycle_idx} post-distill perplexity:")
for dataset_name, ppl in distill_post_eval.items():
print(f"{dataset_name}: {ppl:.4f}")
post_eval = None
if not args.skip_eval:
if distill_post_eval is not None:
post_eval = distill_post_eval
else:
post_eval = evaluate_ppl_model(
model,
tokenizer,
eval_datasets,
eval_configs,
args,
prepared_eval_dataloaders=prepared.eval_dataloaders,
)
print(f"[progressive] Cycle {cycle_idx} perplexity:")
for dataset_name, ppl in post_eval.items():
print(f"{dataset_name}: {ppl:.4f}")
cycle_dir = os.path.join(args.output_dir, f"cycle_{cycle_idx}")
os.makedirs(cycle_dir, exist_ok=True)
fused_layer_file = "fused_layer.pt"
fused_layer_path = os.path.join(cycle_dir, fused_layer_file)
torch.save(layers[layer_idx].state_dict(), fused_layer_path)
cycle_meta: Dict[str, object] = {
"cycle": cycle_idx,
"layer_merged": layer_idx,
"num_layers_before": num_layers_before,
"num_layers_after": num_layers_before - 1,
"fused_layer_state": fused_layer_file,
"dwce_score": dwce_scores[layer_idx] if dwce_scores else None,
"dwce_scores": dwce_scores,
"dwce_meta": dwce_meta,
"fisher_num_batches": num_batches,
"merge_method": merge_method,
"merged_params": merged,
"num_sequences": num_sequences,
"teacher_source": teacher_source,
"teacher_cycle": teacher_cycle,
"eval": {
"datasets": eval_datasets,
"configs": eval_configs,
"split": args.eval_split,
"num_samples": args.eval_num_samples,
"seq_len": args.eval_seq_len,
"post_ppl": post_eval,
},
"comm": comm_stats,
"distill": {
"enabled": not args.skip_distill,
"method": distill_method,
"calib_samples": args.distill_calib_samples,
"inst_samples": args.distill_inst_samples,
"seq_len": args.distill_seq_len,
"batch_size": args.distill_batch_size,
"epochs": args.distill_epochs,
"lr": args.distill_lr,
"kl_weight": args.distill_kl_weight,
"kl_temp": args.distill_kl_temp,
"hidden_mse_weight": args.distill_hidden_mse_weight,
"attn_mse_weight": args.distill_attn_mse_weight,
"mlp_mse_weight": args.distill_mlp_mse_weight,
"reparam_eta": args.reparam_eta,
"reparam_gamma": args.reparam_gamma,
"reparam_attn_reg_scale": args.reparam_attn_reg_scale,
"reparam_mlp_reg_scale": args.reparam_mlp_reg_scale,
"reparam_param_subset": args.reparam_param_subset,
"reparam_stats": reparam_stats,
"reparam_gate_summary": reparam_gate_summary,
"post_ppl": distill_post_eval,
"weight_decay": args.distill_weight_decay,
"max_grad_norm": args.distill_max_grad_norm,
"grad_accum_steps": args.distill_grad_accum_steps,
"instruction_dataset": args.instruction_dataset,
"instruction_config": args.instruction_config,
"instruction_split": args.instruction_split,
},
"lora": {
"enabled": args.lora_epochs > 0,
"seq_len": args.distill_seq_len,
"batch_size": args.distill_batch_size,
"epochs": args.lora_epochs,
"rank": args.lora_rank,
"alpha": args.lora_alpha,
"dropout": args.lora_dropout,
"target_modules": args.lora_target_modules,
"respect_exclude_pairs": args.lora_respect_exclude_pairs,
"kl_enabled": args.lora_kl_enabled,
"kl_weight": args.lora_kl_weight,
"kl_temp": args.lora_kl_temp,
"post_ppl": lora_post_eval,
"lr": args.lora_lr,
"weight_decay": args.lora_weight_decay,
"max_grad_norm": args.lora_max_grad_norm,
"grad_accum_steps": args.lora_grad_accum_steps,
"log_steps": args.lora_log_steps,
"eval_every": args.lora_eval_every,
"eval_max_batches": args.lora_eval_max_batches,
},
"norm_policy": args.norm_policy,
}
saved_full_model_dir = None
if cycle_idx in args.full_model_save_cycles:
cycle_meta["full_model_saved"] = True
cycle_meta["full_model"] = save_cycle_full_model(
model=model,
tokenizer=tokenizer,
cycle_dir=cycle_dir,
cycle_idx=cycle_idx,
args=args,
ppl_results=post_eval,
)
saved_full_model_dir = os.path.join(cycle_dir, "full_model")
else:
cycle_meta["full_model_saved"] = False
with open(
os.path.join(cycle_dir, "cycle_metadata.json"),
"w",
encoding="utf-8",
) as handle:
json.dump(cycle_meta, handle, indent=2)
cycle_summaries.append(
{
"cycle": cycle_idx,
"layer_merged": layer_idx,
"dwce_score": dwce_scores[layer_idx] if dwce_scores else None,
"comm_post_ppl": comm_post_eval,
"distill_post_ppl": distill_post_eval,
"lora_post_ppl": lora_post_eval,
"post_ppl": post_eval,
"cycle_dir": f"cycle_{cycle_idx}",
}
)
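# Remember which index was fused, snapshot the current model as the
# previous-cycle teacher for the next iteration, and refresh the layer list now
# that the container holds one fewer layer.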
last_fused_idx = layer_idx
_snapshot_previous_cycle_teacher(cycle_idx)
parent, name, container = find_layer_container(model, args.layer_path)
layers = list(container)
if dwce_scores:
dwce_scores = dwce_scores[: max(len(layers) - 1, 0)]
# Encourage allocator to release cached blocks between cycles.
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
gc.collect()
if saved_full_model_dir is not None:
save_rng_state(os.path.join(saved_full_model_dir, "rng_state.pt"))
save_loader_generator_state(
saved_full_model_dir,
distill_generator=prepared.distill_generator,
lora_generator=prepared.lora_generator,
)
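# Drop teacher model copies that are no longer needed to free memory.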
_release_comm_teacher()
_release_previous_cycle_teacher()
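# All pruning cycles are done: save an HF-format checkpoint of the pruned model
# before the optional final LoRA phase.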
final_pre_lora_eval = cycle_summaries[-1]["post_ppl"] if cycle_summaries else None
final_pre_lora_dir = f"{os.path.abspath(args.output_dir.rstrip(os.sep))}_final_pre_lora_hf"
final_pre_lora_meta = save_stage_checkpoint(
model=model,
tokenizer=tokenizer,
stage_dir=final_pre_lora_dir,
stage_name="final_pre_lora",
ppl_results=final_pre_lora_eval,
)
# Optional final LoRA finetune after all pruning cycles.
lora_eval_history: List[Dict[str, object]] = []
lora_post_eval = None
lora_ready = has_post_fusion_data(prepared.lora_loader, prepared.lora_meta)
if args.lora_epochs > 0:
if not lora_ready:
print("No post-fusion sequences built; skipping LoRA finetuning.")
else:
print(
f"[progressive] Running final LoRA finetuning (epochs={args.lora_epochs})."
)
lora_eval_history = run_lora_phase(
model=model,
tokenizer=tokenizer,
eval_datasets=eval_datasets,
eval_configs=eval_configs,
args=args,
lora_loader=prepared.lora_loader,
lora_meta=prepared.lora_meta,
eval_dataloaders=prepared.eval_dataloaders,
cycle_idx=args.num_progressive,
num_cycles=args.num_progressive,
)
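# Re-evaluate perplexity after the final LoRA pass to record the end quality.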
if not args.skip_eval:
lora_post_eval = evaluate_ppl_model(
model,
tokenizer,
eval_datasets,
eval_configs,
args,
prepared_eval_dataloaders=prepared.eval_dataloaders,
)
print("[progressive] Post-LoRA perplexity:")
for dataset_name, ppl in lora_post_eval.items():
print(f"{dataset_name}: {ppl:.4f}")
# Update final cycle metadata and summary with the post-LoRA PPL.
if cycle_summaries:
cycle_summaries[-1]["lora_post_ppl"] = lora_post_eval
if lora_post_eval is not None:
cycle_summaries[-1]["post_ppl"] = lora_post_eval
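# Mirror the post-LoRA numbers into the final cycle's on-disk metadata as well.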
final_cycle_dir = os.path.join(
args.output_dir, f"cycle_{args.num_progressive}"
)
final_cycle_meta_path = os.path.join(final_cycle_dir, "cycle_metadata.json")
if os.path.exists(final_cycle_meta_path):
with open(final_cycle_meta_path, "r", encoding="utf-8") as handle:
final_cycle_meta = json.load(handle)
lora_meta_entry = final_cycle_meta.get("lora")
if not isinstance(lora_meta_entry, dict):
lora_meta_entry = {}
final_cycle_meta["lora"] = lora_meta_entry
lora_meta_entry["ran"] = True
lora_meta_entry["post_ppl"] = lora_post_eval
if lora_post_eval is not None and isinstance(
final_cycle_meta.get("eval"), dict
):
final_cycle_meta["eval"]["post_ppl"] = lora_post_eval
if lora_eval_history:
lora_path = os.path.join(final_cycle_dir, "ppl_over_lora.json")
with open(lora_path, "w", encoding="utf-8") as handle:
json.dump(lora_eval_history, handle, indent=2)
lora_meta_entry["ppl_over_lora"] = "ppl_over_lora.json"
with open(final_cycle_meta_path, "w", encoding="utf-8") as handle:
json.dump(final_cycle_meta, handle, indent=2)
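# Persist the final model (post-LoRA weights when LoRA ran) into the top-level
# output directory.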
os.makedirs(args.output_dir, exist_ok=True)
final_post_lora_meta = save_stage_checkpoint(
model=model,
tokenizer=tokenizer,
stage_dir=args.output_dir,
stage_name="final_post_lora" if lora_post_eval is not None else "final_model",
ppl_results=lora_post_eval,
)
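# Run-level metadata: configuration, per-cycle summaries, and pointers to the
# saved artifacts.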
progressive_meta = {
"base_model": getattr(args, "base_model_id", args.model),
"num_progressive": args.num_progressive,
"layer_path": args.layer_path,
"resume_from_cycle": args.resume_from_cycle,
"save_full_model_cycles": sorted(args.full_model_save_cycles),
"num_sequences": num_sequences,
"seq_len": args.seq_len,
"lora": {
"enabled": args.lora_epochs > 0,
"ran": args.lora_epochs > 0 and lora_ready,
"seq_len": args.distill_seq_len,
"batch_size": args.distill_batch_size,
"epochs": args.lora_epochs,
"rank": args.lora_rank,
"alpha": args.lora_alpha,
"dropout": args.lora_dropout,
"target_modules": args.lora_target_modules,
"respect_exclude_pairs": args.lora_respect_exclude_pairs,
"kl_enabled": args.lora_kl_enabled,
"kl_weight": args.lora_kl_weight,
"kl_temp": args.lora_kl_temp,
"post_ppl": lora_post_eval,
"ppl_over_lora": (
f"cycle_{args.num_progressive}/ppl_over_lora.json"
if lora_eval_history
else None
),
"lr": args.lora_lr,
"weight_decay": args.lora_weight_decay,
"max_grad_norm": args.lora_max_grad_norm,
"grad_accum_steps": args.lora_grad_accum_steps,
"log_steps": args.lora_log_steps,
"eval_every": args.lora_eval_every,
"eval_max_batches": args.lora_eval_max_batches,
},
"artifacts": {
"final_pre_lora": final_pre_lora_meta,
"final_post_lora": final_post_lora_meta,
},
"eval": {
"datasets": eval_datasets,
"configs": eval_configs,
"split": args.eval_split,
"num_samples": args.eval_num_samples,
"seq_len": args.eval_seq_len,
"pre_ppl": pre_eval,
"post_ppl": cycle_summaries[-1]["post_ppl"] if cycle_summaries else None,
},
"cycles": cycle_summaries,
"final_num_layers": len(layers),
}
with open(
os.path.join(args.output_dir, "progressive_metadata.json"),
"w",
encoding="utf-8",
) as handle:
json.dump(progressive_meta, handle, indent=2)
print(
f"[progressive] Completed {args.num_progressive} cycles. "
f"Final model saved to {args.output_dir}."
)
def main() -> None:
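"""CLI entry point: validate arguments, load the model and tokenizer, then run progressive layer fusion."""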
args = parse_args()
if args.num_progressive <= 0:
raise SystemExit(
"Single-cycle mode has been removed. Pass --num_progressive > 0."
)
if args.resume_from_cycle < 0:
raise SystemExit("--resume_from_cycle must be >= 0.")
if args.resume_from_cycle >= args.num_progressive:
raise SystemExit("--resume_from_cycle must be smaller than --num_progressive.")
args.full_model_save_cycles = resolve_full_model_save_cycles(
parse_cycle_list(args.save_full_model_cycles),
args.num_progressive,
)
args.base_model_id = args.model
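# When resuming, --model points at a cycle's saved full-model directory; recover
# the original base model id recorded in its resume_info.json.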
if args.resume_from_cycle > 0:
resume_meta = load_resume_metadata(args.model)
if resume_meta is None:
raise SystemExit(
"--resume_from_cycle requires --model to point to a saved cycle full model "
"directory containing resume_info.json."
)
resume_cycle = resume_meta.get("cycle")
if resume_cycle is not None and int(resume_cycle) != args.resume_from_cycle:
raise SystemExit(
"resume_info.json cycle does not match --resume_from_cycle."
)
base_model = resume_meta.get("base_model")
if isinstance(base_model, str) and base_model:
args.base_model_id = base_model
configure_reproducibility(args.seed)
eval_datasets, eval_configs = resolve_eval_datasets(args)
dtype = get_dtype(args.dtype)
model = load_causal_lm(
args.model,
torch_dtype=dtype,
trust_remote_code=args.trust_remote_code,
cache_dir=args.model_cache_dir,
)
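# When resuming, restore the RNG and dataloader generator states saved with the
# checkpoint so sampling continues deterministically from the original run.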
loader_generator_state = None
if args.resume_from_cycle > 0:
rng_state_path = os.path.join(args.model, "rng_state.pt")
rng_state = load_rng_state(rng_state_path)
if rng_state is not None:
restore_rng_state(rng_state)
loader_generator_state = load_loader_generator_state(args.model)
tokenizer = AutoTokenizer.from_pretrained(
args.model,
trust_remote_code=args.trust_remote_code,
cache_dir=args.model_cache_dir,
)
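# Log the architecture so the run log records the layer layout before pruning.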
print(model)
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
# Disable the KV cache for calibration/finetuning forward passes (needed at
# least for Llama-style models and saves memory in general).
model.config.use_cache = False
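# Build the calibration, distillation, LoRA, and evaluation dataloaders once;
# they are reused across all progressive cycles.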
prepared = prepare_all_data(
args,
tokenizer,
model,
eval_datasets,
eval_configs,
loader_generator_state=loader_generator_state,
)
run_progressive(args, model, tokenizer, prepared)
if __name__ == "__main__":
main()