"""Fuse adjacent layers via attention head alignment + Fisher-barycentric merge."""

import argparse
import copy
import gc
import json
import os
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple

import torch

try:
    import numpy as np
except Exception:
    np = None

try:
    from transformers import AutoModelForCausalLM, AutoTokenizer
except Exception as exc:
    raise SystemExit("transformers is required: pip install transformers") from exc

try:
    import ppl_eval
except Exception as exc:
    raise SystemExit("ppl_eval.py is required (missing or invalid)") from exc

from fuse_layers_data import (
    FixedSeqDataset,
    build_token_chunks,
    expand_dataset_configs,
    load_instruction_records,
    load_texts,
    load_texts_from_datasets,
)
from common_lm_data import SharedLMDataSpec, build_chunks, build_dataloader
from fuse_layers_distill import (
    commutator_precondition,
    compute_fisher_gate_priors,
    distill_reparam_merge,
    lora_ce_finetune,
)
from fuse_layers_model import (
    apply_norm_policy,
    build_head_permutation,
    clone_state_dict,
    compute_fisher,
    compute_head_means,
    decrement_config,
    drop_layer,
    find_attention_module,
    find_colon_modules,
    find_layer_container,
    get_dtype,
    get_norm_pair,
    merge_layers,
    permute_attention_heads,
)
from fuse_layers_select import select_layer_auto
from progressive_loader import load_causal_lm, load_progressive_model


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Fuse layer i and i+1 using head alignment + Fisher barycenter."
    )
    parser.add_argument("--model", required=True, help="HF model id or local path")
    parser.add_argument(
        "--model_cache_dir",
        default=None,
        help="Optional cache dir for model/tokenizer downloads",
    )
    parser.add_argument(
        "--layer",
        type=str,
        default="auto",
        help="Layer index i (int) or 'auto' to select via auto metric",
    )
    parser.add_argument(
        "--selection_method",
        choices=["dwce", "sequential"],
        default="dwce",
        help=(
            "Pair selection policy for progressive pruning. "
            "'dwce' uses downstream-weighted composition error; "
            "'sequential' always takes the next available pair."
        ),
    )
    parser.add_argument(
        "--exclude_pairs",
        "--exclude_layers",
        nargs="*",
        default=None,
        dest="exclude_pairs",
        help=(
            "Exclude pair indices from consideration for any fusion. Indices refer to "
            "pair start positions in [0..N-2]. Negative indices count from the end "
            "(-1 = last pair, -2 = second last). Accepts space- or comma-separated "
            "ints. Alias: --exclude_layers (deprecated)."
        ),
    )
    parser.add_argument(
        "--output_dir", required=True, help="Directory to write fused model"
    )
    parser.add_argument(
        "--dataset",
        action="append",
        default=[],
        help="HF dataset name (repeatable). Optional if using --text or --text_file.",
    )
    parser.add_argument(
        "--dataset_config",
        action="append",
        default=[],
        help="Optional dataset config (repeatable or single shared config).",
    )
    parser.add_argument(
        "--dataset_split",
        default="train",
        help="Dataset split to use (default: train)",
    )
    parser.add_argument(
        "--dataset_text_field",
        default=None,
        help="Text field in dataset (default: auto-detect, applies to all datasets)",
    )
    parser.add_argument(
        "--text",
        action="append",
        default=[],
        help="Inline text samples (can pass multiple)",
    )
    parser.add_argument(
        "--text_file",
        default=None,
        help="Path to a text file for calibration data",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=128,
        help="Number of token sequences to use",
    )
    parser.add_argument(
        "--target_tokens",
        type=int,
        default=0,
        help=(
            "Target token budget for common_lm_data-backed calibration/distillation "
            "(0 = disabled)"
        ),
    )
    parser.add_argument("--seq_len", type=int, default=256, help="Sequence length")
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size")
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device for model + compute",
    )
    parser.add_argument(
        "--dtype",
        default="auto",
        choices=["auto", "float32", "float16", "bfloat16"],
        help="Model dtype",
    )
    parser.add_argument(
        "--layer_path",
        default=None,
        help="Override layer attribute path (e.g., model.layers)",
    )
    parser.add_argument(
        "--fisher_mode",
        default="tensor",
        choices=["tensor", "param"],
        help="Fisher approximation granularity",
    )
    parser.add_argument(
        "--no_head_permute",
        action="store_true",
        help=(
            "Deprecated alias for --no_head_permute_merge. "
            "Disables merge-stage head permutation only."
        ),
    )
    parser.add_argument(
        "--no_head_permute_merge",
        action="store_true",
        help="Disable attention head permutation alignment before merge",
    )
    parser.add_argument(
        "--no_head_permute_select",
        action="store_true",
        help="Disable attention head permutation alignment during auto selection",
    )
    parser.add_argument("--eps", type=float, default=1e-8, help="Stability epsilon")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help="Allow custom model code from hub",
    )
    parser.add_argument(
        "--save_metadata",
        action="store_true",
        help="Backward-compatible no-op; metadata is always written.",
    )
    parser.add_argument(
        "--skip_eval",
        action="store_true",
        help="Skip pre/post perplexity evaluation",
    )
    parser.add_argument(
        "--eval_dataset",
        action="append",
        default=[],
        help="Evaluation dataset name (repeatable). Defaults to wikitext.",
    )
    parser.add_argument(
        "--eval_dataset_config",
        action="append",
        default=[],
        help="Evaluation dataset config (repeatable or single shared config).",
    )
    parser.add_argument(
        "--eval_split",
        default="test",
        help="Evaluation dataset split (default: test)",
    )
    parser.add_argument(
        "--eval_text_field",
        default=None,
        help="Evaluation text field override (default: auto-detect)",
    )
    parser.add_argument(
        "--eval_model_family",
        type=str,
        choices=["auto", "llama", "qwen"],
        default="auto",
        help="Model family for BOS handling during eval",
    )
    parser.add_argument(
        "--eval_add_bos",
        type=str,
        choices=["auto", "always", "never"],
        default="auto",
        help="Whether to prepend BOS to each eval sample",
    )
    parser.add_argument(
        "--eval_num_samples",
        type=int,
        default=0,
        help="Number of token sequences per eval dataset (0 = all)",
    )
    parser.add_argument(
        "--eval_seq_len",
        type=int,
        default=2048,
        help="Sequence length for eval",
    )
    parser.add_argument(
        "--eval_batch_size",
        type=int,
        default=None,
        help="Batch size for eval (defaults to --batch_size)",
    )
    parser.add_argument(
        "--eval_max_batches",
        type=int,
        default=None,
        help="Optional max number of eval batches per dataset",
    )
    parser.add_argument(
        "--eval_cache_dir",
        default=None,
        help="Optional datasets cache dir for eval",
    )
    parser.add_argument(
        "--eval_num_workers",
        type=int,
        default=0,
        help="Eval DataLoader workers",
    )
    parser.add_argument(
        "--eval_device",
        default=None,
        help="Device for eval (defaults to --device)",
    )
    parser.add_argument(
        "--skip_distill",
        action="store_true",
        help="Skip reparameterized distillation after head alignment/Fisher setup",
    )
    parser.add_argument(
        "--distill_calib_samples",
        type=int,
        default=256,
        help="Number of distillation sequences from calibration datasets",
    )
    parser.add_argument(
        "--distill_inst_samples",
        type=int,
        default=0,
        help="Number of distillation sequences from instruction dataset (0 = all)",
    )
    parser.add_argument(
        "--distill_seq_len",
        type=int,
        default=512,
        help="Sequence length for distillation",
    )
    parser.add_argument(
        "--distill_batch_size",
        type=int,
        default=2,
        help="Batch size for distillation",
    )
    parser.add_argument(
        "--distill_epochs",
        type=float,
        default=1.0,
        help="Number of distillation epochs (float allowed, e.g. 0.5)",
    )
    parser.add_argument(
        "--distill_lr",
        type=float,
        default=1e-4,
        help="Learning rate for distillation",
    )
    parser.add_argument(
        "--distill_method",
        choices=["reparam"],
        default="reparam",
        help="Distillation strategy (reparam only).",
    )
    parser.add_argument(
        "--distill_kl_weight",
        type=float,
        default=1e-2,
        help="Weight for KL loss on logits",
    )
    parser.add_argument(
        "--distill_kl_temp",
        type=float,
        default=4.0,
        help="Temperature for KL distillation on logits",
    )
    parser.add_argument(
        "--distill_hidden_mse_weight",
        type=float,
        default=1.0,
        help="Weight for hidden-state MSE in reparam distillation (0 disables it)",
    )
    parser.add_argument(
        "--distill_attn_mse_weight",
        type=float,
        default=0.0,
        help="Weight for auxiliary attention-output MSE in reparam distillation",
    )
    parser.add_argument(
        "--distill_mlp_mse_weight",
        type=float,
        default=0.0,
        help="Weight for auxiliary MLP-output MSE in reparam distillation",
    )
    parser.add_argument(
        "--reparam_eta",
        type=float,
        default=1e-2,
        help=(
            "Eta: ||lambda - lambda_gate||^2 regularizer weight for "
            "--distill_method reparam"
        ),
    )
    parser.add_argument(
        "--reparam_gamma",
        type=float,
        default=1e-4,
        help="Gamma: ||U - U0||^2 regularizer weight for --distill_method reparam",
    )
    parser.add_argument(
        "--reparam_attn_reg_scale",
        type=float,
        default=1.0,
        help="Relative scale applied to attention-parameter reparam regularizers",
    )
    parser.add_argument(
        "--reparam_mlp_reg_scale",
        type=float,
        default=1.0,
        help="Relative scale applied to MLP-parameter reparam regularizers",
    )
    parser.add_argument(
        "--reparam_param_subset",
        type=str,
        choices=["all", "mlp", "attn"],
        default="all",
        help="Restrict reparam merge/recovery capacity to only this parameter family",
    )
    parser.add_argument(
        "--norm_policy",
        type=str,
        choices=["hybrid", "merge_all", "copy_n1", "copy_n1_n2"],
        default="hybrid",
        help="Norm merge policy (default: hybrid)",
    )
    parser.add_argument(
        "--distill_weight_decay",
        type=float,
        default=0.0,
        help="Weight decay for distillation",
    )
    parser.add_argument(
        "--distill_max_grad_norm",
        type=float,
        default=1.0,
        help="Max grad norm for distillation",
    )
    parser.add_argument(
        "--distill_grad_accum_steps",
        type=int,
        default=1,
        help="Gradient accumulation steps for distillation",
    )
    parser.add_argument(
        "--distill_log_steps",
        type=int,
        default=100,
        help="Log distillation loss every N steps",
    )
    parser.add_argument(
        "--distill_eval_every",
        type=int,
        default=0,
        help="Evaluate PPL every N distill steps (0 = disable)",
    )
    parser.add_argument(
        "--distill_eval_max_batches",
        type=int,
        default=None,
        help="Max eval batches per dataset during distill (default: all)",
    )
    parser.add_argument(
        "--distill_teacher_device",
        default=None,
        help="Device for teacher model during distillation (defaults to --device)",
    )
    parser.add_argument(
        "--comm_enabled",
        action="store_true",
        help=(
            "Enable commutator-style preconditioning before each progressive "
            "cycle's fusion."
        ),
    )
    parser.add_argument(
        "--comm_include_cycle1",
        action="store_true",
        help="Run commutator preconditioning for cycle 1 as well (default: skip cycle 1).",
    )
    parser.add_argument(
        "--comm_topk",
        type=int,
        default=1,
        help="Top-K lowest-score pairs used as the commutator candidate set",
    )
    parser.add_argument(
        "--comm_sample_eta",
        type=float,
        default=0.5,
        help="Mixture weight between uniform and score-biased candidate sampling",
    )
    parser.add_argument(
        "--comm_sample_dwce_scale",
        type=float,
        default=1.0,
        help="Scale c in softmax(-c * score(i)) for commutator pair sampling",
    )
    parser.add_argument(
        "--comm_temp",
        type=float,
        default=2.0,
        help="Temperature for teacher-anchor KL in commutator preconditioning",
    )
    parser.add_argument(
        "--comm_steps_ratio",
        type=float,
        default=0.1,
        help="Run this fraction of distillation optimizer steps for commutator phase",
    )
    parser.add_argument(
        "--comm_lr_scale",
        type=float,
        default=0.1,
        help="Commutator LR = --distill_lr * this scale",
    )
    parser.add_argument(
        "--comm_train_mode",
        choices=["lora", "full"],
        default="lora",
        help=(
            "Commutator trainable parameter mode: "
            "'lora' updates LoRA adapters on sampled receiver layers; "
            "'full' updates full receiver-layer weights."
        ),
    )
    parser.add_argument(
        "--comm_interaction_mode",
        choices=["mse", "relative"],
        default="relative",
        help="Interaction loss form: plain MSE or relative MSE",
    )
    parser.add_argument(
        "--comm_interaction_eps",
        type=float,
        default=1e-8,
        help="Epsilon for relative commutator interaction normalization",
    )
    parser.add_argument(
        "--comm_mu",
        type=float,
        default=None,
        help=(
            "Weight for interaction loss. Defaults to 0.1 for "
            "--comm_interaction_mode=mse and 0.5 for --comm_interaction_mode=relative."
        ),
    )
    parser.add_argument(
        "--comm_mu_auto",
        action="store_true",
        help="Enable automatic mu scaling via gradient-norm balancing",
    )
    parser.add_argument(
        "--comm_mu_auto_rho",
        type=float,
        default=0.1,
        help="Target anchor-to-interaction gradient ratio constant for auto-mu",
    )
    parser.add_argument(
        "--comm_mu_auto_eps",
        type=float,
        default=1e-8,
        help="Numerical epsilon in auto-mu denominator",
    )
    parser.add_argument(
        "--comm_log_steps",
        type=int,
        default=50,
        help="Log commutator preconditioning loss every N optimizer steps",
    )
    parser.add_argument(
        "--comm_skip_post_reselect",
        action="store_true",
        help=(
            "Keep the pre-comm selected fusion pair and skip recomputing "
            "selection after commutator preconditioning."
        ),
    )
    parser.add_argument(
        "--redistrib_teacher_source",
        type=str,
        choices=["base_model", "previous_cycle"],
        default="base_model",
        help=(
            "Teacher source for commutator preconditioning teacher loading. "
            "'base_model' uses --model for all cycles; 'previous_cycle' uses the "
            "previous cycle's checkpoint (cycle 1 falls back to base_model)."
        ),
    )
    parser.add_argument(
        "--lora_epochs",
        type=float,
        default=1.0,
        help="LoRA CE finetuning epochs after distill (0 = disable)",
    )
    parser.add_argument(
        "--lora_rank",
        type=int,
        default=8,
        help="LoRA rank (r)",
    )
    parser.add_argument(
        "--lora_alpha",
        type=float,
        default=16.0,
        help="LoRA alpha",
    )
    parser.add_argument(
        "--lora_dropout",
        type=float,
        default=0.0,
        help="LoRA dropout",
    )
    parser.add_argument(
        "--lora_kl_enabled",
        action="store_true",
        help="Add KL regularization between pre/post LoRA logits",
    )
    parser.add_argument(
        "--lora_kl_weight",
        type=float,
        default=1e-1,
        help="KL weight for LoRA regularization",
    )
    parser.add_argument(
        "--lora_kl_temp",
        type=float,
        default=4.0,
        help="Temperature for LoRA KL regularization",
    )
    parser.add_argument(
        "--lora_target_modules",
        nargs="*",
        default=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "down_proj",
            "up_proj",
        ],
        help="Module name suffixes to LoRA-wrap",
    )
    parser.add_argument(
        "--lora_respect_exclude_pairs",
        action="store_true",
        help=(
            "When attaching LoRA adapters, skip linear modules under layers touched "
            "by --exclude_pairs (i and i+1 for each excluded pair)."
        ),
    )
    parser.add_argument(
        "--lora_lr",
        type=float,
        default=1e-4,
        help="Learning rate for LoRA finetuning",
    )
    parser.add_argument(
        "--lora_weight_decay",
        type=float,
        default=0.0,
        help="Weight decay for LoRA finetuning",
    )
    parser.add_argument(
        "--lora_max_grad_norm",
        type=float,
        default=1.0,
        help="Max grad norm for LoRA finetuning",
    )
    parser.add_argument(
        "--lora_grad_accum_steps",
        type=int,
        default=1,
        help="Gradient accumulation steps for LoRA finetuning",
    )
    parser.add_argument(
        "--lora_log_steps",
        type=int,
        default=100,
        help="Log LoRA loss every N steps",
    )
    parser.add_argument(
        "--lora_eval_every",
        type=int,
        default=0,
        help="Evaluate PPL every N LoRA steps (0 = disable)",
    )
    parser.add_argument(
        "--lora_eval_max_batches",
        type=int,
        default=None,
        help="Max eval batches per dataset during LoRA (default: all)",
    )
    parser.add_argument(
        "--instruction_dataset",
        default=None,
        help="HF dataset name for alpaca-style instruction data",
    )
    parser.add_argument(
        "--instruction_config",
        default=None,
        help="Optional instruction dataset config",
    )
    parser.add_argument(
        "--instruction_split",
        default="train",
        help="Instruction dataset split",
    )
    parser.add_argument(
        "--instruction_field_instruction",
        default="instruction",
        help="Instruction field name",
    )
    parser.add_argument(
        "--instruction_field_input",
        default="input",
        help="Optional input field name",
    )
    parser.add_argument(
        "--instruction_field_output",
        default="output",
        help="Response/output field name",
    )
    parser.add_argument(
        "--auto_max_batches",
        type=int,
        default=0,
        help="Max calibration batches for auto selection scoring (0 = all)",
    )
    parser.add_argument(
        "--auto_metric",
        type=str,
        choices=[
            "dwce",
            "cosine",
            "hybrid",
            "hybrid_cosine",
            "hybrid_global_rel",
        ],
        default="dwce",
        help=(
            "Auto pair scoring metric. 'dwce' uses downstream-weighted composition "
            "error; 'cosine' uses average token-level cosine distance between "
            "adjacent layer outputs; 'hybrid'/'hybrid_cosine' use DWCE to shortlist, "
            "then adjacent cosine for final scoring; 'hybrid_global_rel' uses DWCE "
            "to shortlist, then reranks by the change in pair-to-final-layer cosine "
            "relation after surrogate fusion."
        ),
    )
    parser.add_argument(
        "--auto_cosine_topk",
        type=int,
        default=3,
        help="Top-K DWCE candidates to rescore with cosine in --auto_metric=hybrid",
    )
    parser.add_argument(
        "--auto_norm",
        type=str,
        choices=["relative", "none"],
        default="relative",
        help="Normalization mode for DWCE scoring (ignored for cosine)",
    )
    parser.add_argument(
        "--auto_dwce_mode",
        type=str,
        choices=["separate", "shared"],
        default="separate",
        help=(
            "DWCE implementation for auto scoring. "
            "'separate' runs distinct Fisher and DWCE backward passes; "
            "'shared' reuses one backward pass and replays DWCE with cached gradients."
        ),
    )
    parser.add_argument(
        "--num_progressive",
        type=int,
        default=0,
        help="Number of progressive fusions (must be > 0)",
    )
    parser.add_argument(
        "--resume_from_cycle",
        type=int,
        default=0,
        help=(
            "Resume from this completed cycle index. When > 0, --model should point "
            "to the saved full model directory for that cycle."
        ),
    )
    parser.add_argument(
        "--save_full_model_cycles",
        nargs="*",
        default=[],
        help=(
            "Cycle indices whose full models should be saved. Requesting cycle c "
            "also saves cycle c-1 automatically (c=1 saves only cycle 1)."
        ),
    )
    return parser.parse_args()


def parse_exclude_pairs(exclude_raw: Optional[List[str]], num_pairs: int) -> List[int]:
    """Parse --exclude_pairs into normalized pair indices for the current model.

    Indices refer to the start of an adjacent pair (i, i+1) and must be in
    [0..N-2]. Negative indices count from the end (-1 = last pair).
    """
    if not exclude_raw:
        return []
    exclude: List[int] = []
    for item in exclude_raw:
        if item is None:
            continue
        for part in str(item).split(","):
            part = part.strip()
            if not part:
                continue
            try:
                idx = int(part)
            except ValueError as exc:
                raise SystemExit("--exclude_pairs must contain integers.") from exc
            if idx < 0:
                idx = num_pairs + idx
            if 0 <= idx < num_pairs:
                exclude.append(idx)
    return sorted(set(exclude))
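
# Example (mirrors the logic above): with num_pairs=8,
# parse_exclude_pairs(["0,2", "-1"], 8) -> [0, 2, 7]; out-of-range indices
# are silently dropped rather than raising.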


def parse_cycle_list(raw_values: Optional[List[str]]) -> List[int]:
    if not raw_values:
        return []
    cycles: List[int] = []
    for item in raw_values:
        if item is None:
            continue
        for part in str(item).split(","):
            part = part.strip()
            if not part:
                continue
            try:
                cycles.append(int(part))
            except ValueError as exc:
                raise SystemExit(
                    "--save_full_model_cycles must contain integers."
                ) from exc
    return cycles


def resolve_full_model_save_cycles(
    requested_cycles: List[int], num_progressive: int
) -> Set[int]:
    resolved: Set[int] = set()
    for cycle in requested_cycles:
        if cycle <= 0 or cycle > num_progressive:
            raise SystemExit(
                "--save_full_model_cycles entries must be within [1, --num_progressive]."
            )
        resolved.add(cycle)
        if cycle > 1:
            resolved.add(cycle - 1)
    return resolved
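
# Example: resolve_full_model_save_cycles([3], num_progressive=4) -> {2, 3};
# cycle c-1 is pulled in automatically, presumably so the previous cycle is
# available as a resume/teacher checkpoint.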


def load_resume_metadata(model_path: str) -> Optional[Dict[str, object]]:
    resume_meta_path = os.path.join(model_path, "resume_info.json")
    if not os.path.exists(resume_meta_path):
        return None
    with open(resume_meta_path, "r", encoding="utf-8") as handle:
        loaded = json.load(handle)
    return loaded if isinstance(loaded, dict) else None


def build_generator(seed: int) -> torch.Generator:
    generator = torch.Generator(device="cpu")
    generator.manual_seed(int(seed))
    return generator


def capture_rng_state() -> Dict[str, object]:
    state: Dict[str, object] = {
        "python_random_state": random.getstate(),
        "torch_cpu_rng_state": torch.get_rng_state(),
    }
    if np is not None:
        state["numpy_random_state"] = np.random.get_state()
    if torch.cuda.is_available():
        state["torch_cuda_rng_state_all"] = torch.cuda.get_rng_state_all()
    return state


def restore_rng_state(state: Dict[str, object]) -> None:
    python_state = state.get("python_random_state")
    if python_state is not None:
        random.setstate(python_state)

    numpy_state = state.get("numpy_random_state")
    if numpy_state is not None and np is not None:
        np.random.set_state(numpy_state)

    torch_cpu_state = state.get("torch_cpu_rng_state")
    if torch_cpu_state is not None:
        torch.set_rng_state(torch_cpu_state)

    torch_cuda_state = state.get("torch_cuda_rng_state_all")
    if torch_cuda_state is not None and torch.cuda.is_available():
        torch.cuda.set_rng_state_all(torch_cuda_state)


def save_rng_state(path: str) -> None:
    torch.save(capture_rng_state(), path)


def load_rng_state(path: str) -> Optional[Dict[str, object]]:
    if not os.path.exists(path):
        return None
    loaded = torch.load(path, map_location="cpu", weights_only=False)
    return loaded if isinstance(loaded, dict) else None


def configure_reproducibility(seed: int) -> None:
    random.seed(seed)
    if np is not None:
        np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    if hasattr(torch.backends, "cudnn"):
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    if hasattr(torch, "use_deterministic_algorithms"):
        torch.use_deterministic_algorithms(True, warn_only=True)
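
# Note: with warn_only=True, ops that lack deterministic kernels fall back to
# nondeterministic implementations and emit a warning instead of raising.
# Fully deterministic CUDA matmuls may additionally require setting
# CUBLAS_WORKSPACE_CONFIG=":4096:8" (or ":16:8"); this script does not set it.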


def save_loader_generator_state(
    base_dir: str,
    *,
    distill_generator: Optional[torch.Generator] = None,
    lora_generator: Optional[torch.Generator] = None,
) -> None:
    state: Dict[str, object] = {}
    if distill_generator is not None:
        state["distill_generator_state"] = distill_generator.get_state()
    if lora_generator is not None:
        state["lora_generator_state"] = lora_generator.get_state()
    if state:
        torch.save(state, os.path.join(base_dir, "loader_generators.pt"))


def load_loader_generator_state(base_dir: str) -> Optional[Dict[str, object]]:
    path = os.path.join(base_dir, "loader_generators.pt")
    if not os.path.exists(path):
        return None
    loaded = torch.load(path, map_location="cpu")
    return loaded if isinstance(loaded, dict) else None


def resolve_layer_idx(
    args: argparse.Namespace,
    model,
    layers: List[torch.nn.Module],
    dataloader,
    previous_scores,
    start_index: int,
    exclude_pairs: Set[int],
):
    layer_arg = str(getattr(args, "layer", "auto")).strip().lower()
    selection_method = str(getattr(args, "selection_method", "dwce")).strip().lower()

    if layer_arg != "auto":
        try:
            layer_idx = int(layer_arg)
        except ValueError as exc:
            raise SystemExit("--layer must be 'auto' or an integer index") from exc
        num_pairs = max(len(layers) - 1, 0)
        if layer_idx < 0:
            layer_idx += num_pairs
        if layer_idx in exclude_pairs:
            raise SystemExit(f"--layer resolved to excluded pair index {layer_idx}")
        return (
            layer_idx,
            previous_scores,
            {"method": "manual", "exclude_pairs": sorted(exclude_pairs)},
        )

    if selection_method == "sequential":
        num_pairs = len(layers) - 1
        for layer_idx in range(max(start_index, 0), num_pairs):
            if layer_idx not in exclude_pairs:
                return layer_idx, previous_scores, {
                    "method": "sequential",
                    "start_index": max(start_index, 0),
                    "exclude_pairs": sorted(exclude_pairs),
                }
        raise SystemExit("No eligible layer pairs remain after exclusions")

    layer_idx, dwce_scores, dwce_meta = select_layer_auto(
        model,
        layers,
        dataloader,
        args,
        previous_scores=previous_scores,
        start_index=start_index,
        exclude_pairs=exclude_pairs,
    )
    return layer_idx, dwce_scores, dwce_meta


@dataclass
class PreparedData:
    calib_loader: torch.utils.data.DataLoader
    calib_num_sequences: int
    distill_loader: Optional[torch.utils.data.DataLoader]
    distill_generator: Optional[torch.Generator]
    distill_meta: Dict[str, object]
    lora_loader: Optional[torch.utils.data.DataLoader]
    lora_generator: Optional[torch.Generator]
    lora_meta: Dict[str, object]
    eval_datasets: List[str]
    eval_configs: List[Optional[str]]
    eval_dataloaders: Optional[Dict[str, torch.utils.data.DataLoader]]


def resolve_eval_datasets(
    args: argparse.Namespace,
) -> Tuple[List[str], List[Optional[str]]]:
    eval_datasets = args.eval_dataset or ["wikitext"]
    eval_configs = args.eval_dataset_config or ["wikitext-2-raw-v1"]
    eval_configs = ppl_eval._expand_dataset_configs(eval_datasets, eval_configs)
    return eval_datasets, eval_configs
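
# With no --eval_dataset/--eval_dataset_config flags, this resolves to
# ("wikitext", "wikitext-2-raw-v1"), i.e. WikiText-2 perplexity is the
# default benchmark.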


def run_ppl_eval(
    model_id_or_path: str,
    eval_datasets: List[str],
    eval_configs: List[Optional[str]],
    args: argparse.Namespace,
    prepared_eval_dataloaders: Optional[Dict[str, torch.utils.data.DataLoader]] = None,
) -> Dict[str, float]:
    eval_device = args.eval_device or args.device
    dtype = get_dtype(args.dtype)

    eval_model = load_causal_lm(
        model_id_or_path,
        torch_dtype=dtype,
        trust_remote_code=args.trust_remote_code,
    )

    eval_model.to(eval_device)
    if prepared_eval_dataloaders is not None:
        results = ppl_eval.evaluate_ppl_dataloaders(
            eval_model,
            prepared_eval_dataloaders,
            eval_device,
            max_batches=args.eval_max_batches,
        )
    else:
        eval_batch_size = args.eval_batch_size or args.batch_size
        eval_tokenizer = AutoTokenizer.from_pretrained(
            model_id_or_path, trust_remote_code=args.trust_remote_code
        )
        if eval_tokenizer.pad_token is None and eval_tokenizer.eos_token is not None:
            eval_tokenizer.pad_token = eval_tokenizer.eos_token

        results = ppl_eval.evaluate_ppl_datasets(
            eval_model,
            eval_tokenizer,
            datasets=eval_datasets,
            configs=eval_configs,
            split=args.eval_split,
            text_field=args.eval_text_field,
            num_samples=args.eval_num_samples,
            seq_len=args.eval_seq_len,
            batch_size=eval_batch_size,
            device=eval_device,
            seed=args.seed,
            shuffle=False,
            model_family=args.eval_model_family,
            add_bos=args.eval_add_bos,
            max_batches=args.eval_max_batches,
            cache_dir=args.eval_cache_dir,
            num_workers=args.eval_num_workers,
        )

    del eval_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return results
def build_calibration_dataloader(
    args: argparse.Namespace, tokenizer
) -> Tuple[List[str], List[torch.Tensor], torch.utils.data.DataLoader]:
    if args.dataset:
        datasets = list(args.dataset)
        configs = expand_dataset_configs(datasets, list(args.dataset_config))
        chunks: List[torch.Tensor] = []
        for idx, (dataset_name, config) in enumerate(zip(datasets, configs)):
            spec = SharedLMDataSpec(
                dataset=dataset_name,
                config=config,
                split=args.dataset_split,
                text_field=args.dataset_text_field,
                seq_len=args.seq_len,
                num_sequences=args.num_samples,
                seed=args.seed + idx,
            )
            chunks.extend(build_chunks(spec, tokenizer))
        if not chunks:
            raise SystemExit("Not enough text to build token sequences.")
        input_ids = torch.stack(chunks)
        attention_mask = torch.ones_like(input_ids)
        dataset = torch.utils.data.TensorDataset(input_ids, attention_mask)
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=args.batch_size, shuffle=False
        )
        return [], chunks, dataloader

    texts = load_texts(args)
    if not texts:
        raise SystemExit(
            "No calibration text found. Provide --dataset, --text, or --text_file."
        )

    chunks = build_token_chunks(texts, tokenizer, args.seq_len, args.num_samples)
    if not chunks:
        raise SystemExit("Not enough text to build token sequences.")

    input_ids = torch.stack(chunks)
    attention_mask = torch.ones_like(input_ids)
    dataset = torch.utils.data.TensorDataset(input_ids, attention_mask)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False
    )
    return texts, chunks, dataloader


def prepare_distillation_data(
    args: argparse.Namespace, tokenizer, include_instruction: bool = True
) -> Tuple[
    Optional[torch.utils.data.DataLoader],
    Optional[torch.Generator],
    Dict[str, object],
]:
    if (
        include_instruction
        and args.distill_inst_samples != 0
        and not args.instruction_dataset
    ):
        print(
            "Warning: --distill_inst_samples is set but no --instruction_dataset "
            "was provided; instruction distillation will be skipped."
        )

    calib_texts: List[str] = []
    calib_dataset = None
    if args.target_tokens > 0 and args.dataset:
        datasets = list(args.dataset)
        configs = expand_dataset_configs(datasets, list(args.dataset_config))
        per_dataset = args.target_tokens // len(datasets)
        remainder = args.target_tokens % len(datasets)
        calib_chunks: List[torch.Tensor] = []
        for idx, (dataset_name, config) in enumerate(zip(datasets, configs)):
            dataset_tokens = per_dataset + (remainder if idx == 0 else 0)
            spec = SharedLMDataSpec(
                dataset=dataset_name,
                config=config,
                split=args.dataset_split,
                text_field=args.dataset_text_field,
                seq_len=args.distill_seq_len,
                target_tokens=dataset_tokens,
                seed=args.seed + 17 + idx,
            )
            calib_chunks.extend(build_chunks(spec, tokenizer))
        if calib_chunks:
            input_ids = torch.stack(calib_chunks)
            attention_mask = torch.ones_like(input_ids)
            calib_dataset = torch.utils.data.TensorDataset(input_ids, attention_mask)
    else:
        calib_texts = load_texts_from_datasets(
            datasets=list(args.dataset),
            configs=expand_dataset_configs(
                list(args.dataset), list(args.dataset_config)
            ),
            split=args.dataset_split,
            text_field=args.dataset_text_field,
            num_samples=args.distill_calib_samples,
            seed=args.seed + 17,
        )
    inst_records = []
    if include_instruction:
        inst_records = load_instruction_records(args, args.distill_inst_samples)

    distill_datasets = []
    if calib_dataset is not None:
        distill_datasets.append(calib_dataset)
    elif calib_texts:
        calib_records = [{"text": text} for text in calib_texts]
        distill_datasets.append(
            FixedSeqDataset(calib_records, tokenizer, args.distill_seq_len)
        )
    if inst_records:
        distill_datasets.append(
            FixedSeqDataset(inst_records, tokenizer, args.distill_seq_len)
        )

    distill_meta: Dict[str, object] = {
        "calib_texts": len(calib_texts),
        "calib_sequences": (
            len(calib_dataset) if calib_dataset is not None else len(calib_texts)
        ),
        "inst_sequences": len(inst_records),
        "total_sequences": 0,
    }

    if not distill_datasets:
        return None, None, distill_meta
    if len(distill_datasets) == 1:
        distill_dataset = distill_datasets[0]
    else:
        distill_dataset = torch.utils.data.ConcatDataset(distill_datasets)

    distill_meta["total_sequences"] = len(distill_dataset)
    distill_generator = build_generator(
        args.seed + 1000 + (1000000 if include_instruction else 0)
    )
    distill_loader = torch.utils.data.DataLoader(
        distill_dataset,
        batch_size=args.distill_batch_size,
        shuffle=True,
        generator=distill_generator,
    )
    return distill_loader, distill_generator, distill_meta
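
# Token-budget split example (mirrors the arithmetic above): with
# --target_tokens 10000 over three datasets, per_dataset = 3333 and the
# remainder of 1 goes to the first dataset, giving budgets 3334/3333/3333.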


def prepare_eval_dataloaders(
    args: argparse.Namespace,
    tokenizer,
    model: torch.nn.Module,
    eval_datasets: List[str],
    eval_configs: List[Optional[str]],
) -> Optional[Dict[str, torch.utils.data.DataLoader]]:
    needs_eval = (not args.skip_eval) or (
        (not args.skip_distill and args.distill_eval_every)
        or (args.lora_epochs > 0 and args.lora_eval_every)
    )
    if not needs_eval:
        return None

    eval_batch_size = args.eval_batch_size or args.batch_size
    resolved_family = args.eval_model_family
    if resolved_family == "auto":
        resolved_family = ppl_eval._infer_model_family(model)

    return ppl_eval.prepare_ppl_dataloaders(
        tokenizer=tokenizer,
        datasets=eval_datasets,
        configs=eval_configs,
        split=args.eval_split,
        text_field=args.eval_text_field,
        num_samples=args.eval_num_samples,
        seq_len=args.eval_seq_len,
        batch_size=eval_batch_size,
        seed=args.seed,
        shuffle=False,
        model_family=resolved_family,
        add_bos=args.eval_add_bos,
        cache_dir=args.eval_cache_dir,
        num_workers=args.eval_num_workers,
        model=model,
    )


def prepare_all_data(
    args: argparse.Namespace,
    tokenizer,
    model: torch.nn.Module,
    eval_datasets: List[str],
    eval_configs: List[Optional[str]],
    loader_generator_state: Optional[Dict[str, object]] = None,
) -> PreparedData:
    texts, chunks, calib_loader = build_calibration_dataloader(args, tokenizer)
    calib_num_sequences = len(chunks)
    del texts
    del chunks

    distill_loader = None
    distill_generator = None
    distill_meta: Dict[str, object] = {
        "calib_texts": 0,
        "calib_sequences": 0,
        "inst_sequences": 0,
        "total_sequences": 0,
    }
    lora_loader = None
    lora_generator = None
    lora_meta: Dict[str, object] = {
        "calib_texts": 0,
        "calib_sequences": 0,
        "inst_sequences": 0,
        "total_sequences": 0,
    }
    if (not args.skip_distill) or bool(getattr(args, "comm_enabled", False)):
        distill_loader, distill_generator, distill_meta = prepare_distillation_data(
            args, tokenizer, include_instruction=False
        )
        if (
            distill_generator is not None
            and loader_generator_state is not None
            and loader_generator_state.get("distill_generator_state") is not None
        ):
            distill_generator.set_state(
                loader_generator_state["distill_generator_state"]
            )
    if args.lora_epochs > 0:
        lora_loader, lora_generator, lora_meta = prepare_distillation_data(
            args, tokenizer, include_instruction=True
        )
        if (
            lora_generator is not None
            and loader_generator_state is not None
            and loader_generator_state.get("lora_generator_state") is not None
        ):
            lora_generator.set_state(loader_generator_state["lora_generator_state"])

    eval_dataloaders = prepare_eval_dataloaders(
        args, tokenizer, model, eval_datasets, eval_configs
    )

    return PreparedData(
        calib_loader=calib_loader,
        calib_num_sequences=calib_num_sequences,
        distill_loader=distill_loader,
        distill_generator=distill_generator,
        distill_meta=distill_meta,
        lora_loader=lora_loader,
        lora_generator=lora_generator,
        lora_meta=lora_meta,
        eval_datasets=eval_datasets,
        eval_configs=eval_configs,
        eval_dataloaders=eval_dataloaders,
    )
def evaluate_ppl_model(
    model: torch.nn.Module,
    tokenizer,
    eval_datasets: List[str],
    eval_configs: List[Optional[str]],
    args: argparse.Namespace,
    max_batches: Optional[int] = None,
    prepared_eval_dataloaders: Optional[Dict[str, torch.utils.data.DataLoader]] = None,
) -> Dict[str, float]:
    eval_device = args.eval_device or args.device
    prev_mode = model.training
    try:
        prev_device = next(model.parameters()).device
    except StopIteration:
        prev_device = torch.device(eval_device)

    model.eval()
    if str(prev_device) != eval_device:
        model.to(eval_device)

    if prepared_eval_dataloaders is not None:
        results = ppl_eval.evaluate_ppl_dataloaders(
            model,
            prepared_eval_dataloaders,
            eval_device,
            max_batches=(
                max_batches if max_batches is not None else args.eval_max_batches
            ),
        )
    else:
        eval_batch_size = args.eval_batch_size or args.batch_size
        results = ppl_eval.evaluate_ppl_datasets(
            model,
            tokenizer,
            datasets=eval_datasets,
            configs=eval_configs,
            split=args.eval_split,
            text_field=args.eval_text_field,
            num_samples=args.eval_num_samples,
            seq_len=args.eval_seq_len,
            batch_size=eval_batch_size,
            device=eval_device,
            seed=args.seed,
            shuffle=False,
            model_family=args.eval_model_family,
            add_bos=args.eval_add_bos,
            max_batches=(
                max_batches if max_batches is not None else args.eval_max_batches
            ),
            cache_dir=args.eval_cache_dir,
            num_workers=args.eval_num_workers,
        )

    if prev_mode:
        model.train()
    if str(prev_device) != eval_device:
        model.to(prev_device)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return results
def has_post_fusion_data(
    distill_loader: Optional[torch.utils.data.DataLoader],
    distill_meta: Optional[Dict[str, object]],
) -> bool:
    if distill_loader is None or distill_meta is None:
        return False
    return distill_meta.get("total_sequences", 0) > 0


def summarize_gate_lambdas(gates: Dict[str, torch.Tensor]) -> Dict[str, object]:
    if not gates:
        return {"num_tensors": 0, "num_elements": 0}

    total_sum = 0.0
    total_elems = 0
    per_tensor_mean: Dict[str, Optional[float]] = {}
    for name, gate in gates.items():
        g = gate.detach().float()
        if g.numel() == 0:
            per_tensor_mean[name] = None
            continue
        per_tensor_mean[name] = float(g.mean().item())
        total_sum += float(g.sum().item())
        total_elems += int(g.numel())

    global_mean = None if total_elems == 0 else total_sum / float(total_elems)
    return {
        "num_tensors": len(gates),
        "num_elements": total_elems,
        "global_mean": global_mean,
        "per_tensor_mean": per_tensor_mean,
    }
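
# Illustrative output shape (tensor names and values are hypothetical):
#   {"num_tensors": 2, "num_elements": 8192, "global_mean": 0.47,
#    "per_tensor_mean": {"attn.q_proj": 0.45, "mlp.up_proj": 0.49}}
# A global mean near 0.5 suggests the learned gates blend the two source
# layers roughly evenly.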


def compute_path_bytes(path: str) -> int:
    if os.path.isfile(path):
        return os.path.getsize(path)

    total = 0
    for root, _, files in os.walk(path):
        for name in files:
            file_path = os.path.join(root, name)
            if os.path.islink(file_path):
                continue
            try:
                total += os.path.getsize(file_path)
            except OSError:
                continue
    return total


def save_stage_checkpoint(
    model: torch.nn.Module,
    tokenizer,
    stage_dir: str,
    stage_name: str,
    ppl_results: Optional[Dict[str, float]],
) -> Dict[str, object]:
    os.makedirs(stage_dir, exist_ok=True)
    colon_modules = find_colon_modules(model)
    if colon_modules:
        raise RuntimeError(
            "Unexpected module names with ':' detected before save: "
            f"{', '.join(colon_modules)}."
        )

    model.save_pretrained(stage_dir)
    tokenizer.save_pretrained(stage_dir)

    stage_meta = {
        "stage": stage_name,
        "path": stage_dir,
        "weight_bytes": compute_path_bytes(stage_dir),
        "post_ppl": ppl_results,
    }
    with open(
        os.path.join(stage_dir, "stage_metrics.json"),
        "w",
        encoding="utf-8",
    ) as handle:
        json.dump(stage_meta, handle, indent=2)
    return stage_meta


def save_cycle_full_model(
    model: torch.nn.Module,
    tokenizer,
    cycle_dir: str,
    cycle_idx: int,
    args: argparse.Namespace,
    ppl_results: Optional[Dict[str, float]],
) -> Dict[str, object]:
    full_model_dir = os.path.join(cycle_dir, "full_model")
    stage_meta = save_stage_checkpoint(
        model=model,
        tokenizer=tokenizer,
        stage_dir=full_model_dir,
        stage_name=f"cycle_{cycle_idx}_full_model",
        ppl_results=ppl_results,
    )
    resume_meta = {
        "base_model": getattr(args, "base_model_id", args.model),
        "cycle": cycle_idx,
        "output_dir": args.output_dir,
        "layer_path": args.layer_path,
        "rng_state": "rng_state.pt",
        "loader_generators": "loader_generators.pt",
    }
    with open(
        os.path.join(full_model_dir, "resume_info.json"),
        "w",
        encoding="utf-8",
    ) as handle:
        json.dump(resume_meta, handle, indent=2)
    stage_meta["resume_info"] = "resume_info.json"
    return stage_meta


def run_lora_phase(
    model: torch.nn.Module,
    tokenizer,
    eval_datasets: List[str],
    eval_configs: List[Optional[str]],
    args: argparse.Namespace,
    lora_loader: Optional[torch.utils.data.DataLoader] = None,
    lora_meta: Optional[Dict[str, object]] = None,
    eval_dataloaders: Optional[Dict[str, torch.utils.data.DataLoader]] = None,
    cycle_idx: Optional[int] = None,
    num_cycles: Optional[int] = None,
) -> List[Dict[str, object]]:
    lora_eval_history: List[Dict[str, object]] = []
    if args.lora_epochs <= 0:
        return lora_eval_history
    if not has_post_fusion_data(lora_loader, lora_meta):
        print("No post-fusion sequences built; skipping LoRA finetuning.")
        return lora_eval_history

    lora_ce_finetune(
        model=model,
        dataloader=lora_loader,
        eval_tokenizer=tokenizer,
        eval_datasets=eval_datasets,
        eval_configs=eval_configs,
        eval_history=lora_eval_history,
        args=args,
        eval_dataloaders=eval_dataloaders,
        progressive_cycle=cycle_idx,
        progressive_total=num_cycles,
    )
    return lora_eval_history
def run_progressive(
    args: argparse.Namespace,
    model: torch.nn.Module,
    tokenizer,
    prepared: PreparedData,
) -> None:
    eval_datasets = prepared.eval_datasets
    eval_configs = prepared.eval_configs

    dataloader = prepared.calib_loader
    num_sequences = prepared.calib_num_sequences

    model.to(args.device)

    os.makedirs(args.output_dir, exist_ok=True)
    progressive_meta_path = os.path.join(args.output_dir, "progressive_metadata.json")
    existing_meta: Dict[str, object] = {}
    if args.resume_from_cycle > 0 and os.path.exists(progressive_meta_path):
        with open(progressive_meta_path, "r", encoding="utf-8") as handle:
            loaded_meta = json.load(handle)
        if isinstance(loaded_meta, dict):
            existing_meta = loaded_meta

    bootstrap_meta = {
        "base_model": getattr(args, "base_model_id", args.model),
        "num_progressive": args.num_progressive,
        "layer_path": args.layer_path,
        "resume_from_cycle": args.resume_from_cycle,
        "save_full_model_cycles": sorted(args.full_model_save_cycles),
        "cycles": (
            existing_meta.get("cycles", [])
            if isinstance(existing_meta.get("cycles"), list)
            else []
        ),
    }
    with open(progressive_meta_path, "w", encoding="utf-8") as handle:
        json.dump(bootstrap_meta, handle, indent=2)

    pre_eval = None
    if not args.skip_eval:
        pre_eval = evaluate_ppl_model(
            model,
            tokenizer,
            eval_datasets,
            eval_configs,
            args,
            prepared_eval_dataloaders=prepared.eval_dataloaders,
        )
        print("Pre-pruning perplexity:")
        for dataset_name, ppl in pre_eval.items():
            print(f"{dataset_name}: {ppl:.4f}")

    parent, name, container = find_layer_container(model, args.layer_path)
    layers = list(container)
    if args.num_progressive > (len(layers) - 1 + args.resume_from_cycle):
        raise SystemExit(
            f"--num_progressive ({args.num_progressive}) exceeds available pairs "
            f"after resume offset ({len(layers) - 1 + args.resume_from_cycle})"
        )

    dwce_scores = None
    dwce_meta = None
    last_fused_idx = 0
    cycle_summaries: List[Dict[str, object]] = []
    existing_cycles = existing_meta.get("cycles", [])
    if isinstance(existing_cycles, list):
        for entry in existing_cycles:
            if not isinstance(entry, dict):
                continue
            cycle_value = entry.get("cycle")
            if isinstance(cycle_value, int) and cycle_value <= args.resume_from_cycle:
                cycle_summaries.append(entry)

    comm_enabled = bool(getattr(args, "comm_enabled", False))
    comm_teacher_model = None
    comm_teacher_cycle: Optional[int] = None
    teacher_device = args.distill_teacher_device or args.device
    previous_cycle_teacher_model = None
    previous_cycle_teacher_cycle: Optional[int] = None

    def _release_comm_teacher() -> None:
        nonlocal comm_teacher_model, comm_teacher_cycle
        if comm_teacher_model is not None:
            del comm_teacher_model
            comm_teacher_model = None
        comm_teacher_cycle = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _release_previous_cycle_teacher() -> None:
        nonlocal previous_cycle_teacher_model, previous_cycle_teacher_cycle
        if previous_cycle_teacher_model is not None:
            del previous_cycle_teacher_model
            previous_cycle_teacher_model = None
        previous_cycle_teacher_cycle = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _snapshot_previous_cycle_teacher(cycle_idx: int) -> None:
        nonlocal previous_cycle_teacher_model, previous_cycle_teacher_cycle
        _release_previous_cycle_teacher()
        previous_cycle_teacher_model = copy.deepcopy(model)
        previous_cycle_teacher_model.to(teacher_device)
        previous_cycle_teacher_model.eval()
        previous_cycle_teacher_cycle = cycle_idx

    def _get_previous_cycle_teacher(
        cycle_idx: int,
    ) -> Tuple[Optional[torch.nn.Module], str, Optional[int]]:
        prev_cycle = cycle_idx - 1
        if prev_cycle <= 0:
            return None, "base_model", 0
        if (
            previous_cycle_teacher_model is not None
            and previous_cycle_teacher_cycle == prev_cycle
        ):
            return previous_cycle_teacher_model, "previous_cycle_memory", prev_cycle
        teacher_model = load_progressive_model(
            getattr(args, "base_model_id", args.model),
            args.output_dir,
            cycle=prev_cycle,
            device=teacher_device,
            dtype=args.dtype,
            trust_remote_code=args.trust_remote_code,
            layer_path=args.layer_path,
        )
        teacher_model.eval()
        return teacher_model, "previous_cycle_disk", prev_cycle
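
    # Teacher resolution order used by _get_previous_cycle_teacher above: an
    # in-memory snapshot of the previous cycle when available, then that
    # cycle's on-disk checkpoint; when it returns None (cycle 1), callers fall
    # back to loading the base model themselves.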

    def _get_comm_teacher(
        cycle_idx: int,
    ) -> Tuple[Optional[torch.nn.Module], str, Optional[int]]:
        nonlocal comm_teacher_model, comm_teacher_cycle
        if not comm_enabled:
            return None, "disabled", None

        source = str(getattr(args, "redistrib_teacher_source", "base_model"))
        if source == "base_model":
            if comm_teacher_model is None:
                print(
                    "[comm] Loading fixed base teacher for anchor loss "
                    f"(device={teacher_device})."
                )
                comm_teacher_model = AutoModelForCausalLM.from_pretrained(
                    getattr(args, "base_model_id", args.model),
                    torch_dtype=get_dtype(args.dtype),
                    trust_remote_code=args.trust_remote_code,
                )
                comm_teacher_model.to(teacher_device)
                comm_teacher_model.eval()
                comm_teacher_cycle = 0
            return comm_teacher_model, "base_model", 0

        prev_cycle = cycle_idx - 1
        if prev_cycle <= 0:
            if comm_teacher_model is None or comm_teacher_cycle != 0:
                _release_comm_teacher()
                print(
                    "[comm] --redistrib_teacher_source=previous_cycle but cycle 1 "
                    "has no prior checkpoint; using base teacher."
                )
                comm_teacher_model = AutoModelForCausalLM.from_pretrained(
                    getattr(args, "base_model_id", args.model),
                    torch_dtype=get_dtype(args.dtype),
                    trust_remote_code=args.trust_remote_code,
                )
                comm_teacher_model.to(teacher_device)
                comm_teacher_model.eval()
                comm_teacher_cycle = 0
            return comm_teacher_model, "base_model", 0

        if (
            previous_cycle_teacher_model is not None
            and previous_cycle_teacher_cycle == prev_cycle
        ):
            if comm_teacher_model is not previous_cycle_teacher_model:
                _release_comm_teacher()
            comm_teacher_model = previous_cycle_teacher_model
            comm_teacher_cycle = prev_cycle
            return comm_teacher_model, "previous_cycle_memory", prev_cycle

        if comm_teacher_model is None or comm_teacher_cycle != prev_cycle:
            _release_comm_teacher()
            print(
                "[comm] Loading teacher from previous cycle "
                f"{prev_cycle} (device={teacher_device})."
            )
            comm_teacher_model = load_progressive_model(
                getattr(args, "base_model_id", args.model),
                args.output_dir,
                cycle=prev_cycle,
                device=teacher_device,
                dtype=args.dtype,
                trust_remote_code=args.trust_remote_code,
                layer_path=args.layer_path,
            )
            comm_teacher_model.eval()
            comm_teacher_cycle = prev_cycle
        return comm_teacher_model, "previous_cycle_disk", prev_cycle

    if args.resume_from_cycle > 0:
        _snapshot_previous_cycle_teacher(args.resume_from_cycle)

    start_cycle = args.resume_from_cycle + 1
    for cycle_idx in range(start_cycle, args.num_progressive + 1):
        print(f"[progressive] Cycle {cycle_idx}/{args.num_progressive}")
        run_comm = comm_enabled and (
            cycle_idx > 1 or bool(getattr(args, "comm_include_cycle1", False))
        )
        comm_stats: Dict[str, object] = {"enabled": False}
        comm_post_eval = None
        if run_comm:
            start_index = 0
            reuse_scores = None
        else:
            start_index = last_fused_idx if cycle_idx > 1 else 0
            reuse_scores = dwce_scores
        exclude_pairs = set(
            parse_exclude_pairs(args.exclude_pairs, max(len(layers) - 1, 0))
        )
        layer_idx, dwce_scores, dwce_meta = resolve_layer_idx(
            args,
            model,
            layers,
            dataloader,
            reuse_scores,
            start_index,
            exclude_pairs,
        )

        if run_comm:
            dwce_scores_pre_comm = dwce_scores
            if prepared.calib_loader is None:
                print(
                    "[comm] Enabled but no calibration sequences were built; skipping."
                )
            else:
                (
                    comm_teacher_model_loaded,
                    comm_teacher_source,
                    comm_teacher_cycle_idx,
                ) = _get_comm_teacher(cycle_idx)
                if comm_teacher_model_loaded is None:
                    raise RuntimeError("comm_enabled but teacher model was not loaded.")
                comm_stats = commutator_precondition(
                    student_model=model,
                    student_layers=layers,
                    teacher_model=comm_teacher_model_loaded,
                    dataloader=prepared.calib_loader,
                    dwce_scores=dwce_scores_pre_comm,
                    exclude_pairs=exclude_pairs,
                    args=args,
                    progressive_cycle=cycle_idx,
                    progressive_total=args.num_progressive,
                )
                if comm_stats.get("enabled"):
                    comm_stats["teacher_source"] = comm_teacher_source
                    comm_stats["teacher_cycle"] = comm_teacher_cycle_idx
                    comm_stats["dwce_scores_pre"] = dwce_scores_pre_comm
                    print(
                        "[comm] Done:"
                        f" opt_steps={comm_stats.get('opt_steps')}"
                        f" lr={comm_stats.get('lr')}"
                    )
                if not args.skip_eval:
                    comm_post_eval = evaluate_ppl_model(
                        model,
                        tokenizer,
                        eval_datasets,
                        eval_configs,
                        args,
                        prepared_eval_dataloaders=prepared.eval_dataloaders,
                    )
                    comm_stats["post_ppl"] = comm_post_eval
                    print(f"[progressive] Cycle {cycle_idx} post-comm perplexity:")
                    for dataset_name, ppl in comm_post_eval.items():
                        print(f"{dataset_name}: {ppl:.4f}")

            if bool(getattr(args, "comm_skip_post_reselect", False)):
                comm_stats["post_selection_recomputed"] = False
                comm_stats["selected_layer_post"] = int(layer_idx)
                print("[comm] Keeping pre-comm DWCE pair selection for fusion.")
            else:
                print(
                    "[comm] Recomputing DWCE after preconditioning for fusion selection."
                )
                layer_idx, dwce_scores, dwce_meta = resolve_layer_idx(
                    args,
                    model,
                    layers,
                    dataloader,
                    None,
                    0,
                    exclude_pairs,
                )
                comm_stats["post_selection_recomputed"] = True
                comm_stats["selected_layer_post"] = int(layer_idx)

        if layer_idx < 0 or layer_idx >= len(layers) - 1:
            raise SystemExit("--layer must be in [0, num_layers-2]")

        num_layers_before = len(layers)
        layer_a = layers[layer_idx]
        layer_b = layers[layer_idx + 1]

        norm1_state = None
        norm2_state = None
        norm1, norm2, norm_names = get_norm_pair(layer_a)
        if norm1 is not None:
            norm1_state = clone_state_dict(norm1)
        if norm2 is not None:
            norm2_state = clone_state_dict(norm2)

        attn_a = find_attention_module(layer_a)
        attn_b = find_attention_module(layer_b)
        hidden_size = getattr(model.config, "hidden_size", None)
        if hidden_size is None:
            hidden_size = getattr(model.config, "n_embd", None)
        if hidden_size is None:
            raise SystemExit("Model config missing hidden_size/n_embd")

        no_head_permute_merge = bool(
            getattr(args, "no_head_permute_merge", False)
            or getattr(args, "no_head_permute", False)
        )
        if no_head_permute_merge:
            print("[fuse] Head permutation disabled; merging with original head order.")
        else:
            mean_a, mean_b, num_heads, num_kv_heads, head_dim = compute_head_means(
                model,
                attn_a,
                attn_b,
                dataloader,
                args.device,
                hidden_size,
            )

            perm = build_head_permutation(
                mean_a,
                mean_b,
                num_heads=num_heads,
                num_kv_heads=num_kv_heads,
                eps=args.eps,
            )
            permute_attention_heads(
                attn_b, perm, num_heads, num_kv_heads, head_dim=head_dim
            )
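
        # Head-alignment rationale: before any weight averaging, layer_b's
        # attention heads are permuted to best match layer_a's, using the
        # per-head mean activations gathered by compute_head_means on the
        # calibration loader, so the merge averages functionally similar heads
        # rather than whatever happens to share a head index.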

        fisher_sums, num_batches, param_numels = compute_fisher(
            model,
            layer_a,
            layer_b,
            dataloader,
            fisher_mode=args.fisher_mode,
            device=args.device,
        )
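
        # fisher_sums[0]/fisher_sums[1] hold accumulated squared gradients for
        # layer_a/layer_b over num_batches calibration batches, an empirical
        # diagonal-Fisher importance estimate ('param' keeps per-element
        # values; 'tensor' presumably collapses each tensor to one scalar).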
|
|
| distill_ready = has_post_fusion_data( |
| prepared.distill_loader, prepared.distill_meta |
| ) |
| teacher_cycle = cycle_idx - 1 |
| teacher_source = "previous_cycle" if teacher_cycle > 0 else "base_model" |
|
|
| merge_method = "fisher" |
| distill_method = str(getattr(args, "distill_method", "reparam")) |
| reparam_stats: Optional[Dict[str, object]] = None |
| reparam_gate_summary: Optional[Dict[str, object]] = None |
|
|
| needs_teacher_for_reparam = ( |
| (not args.skip_distill) |
| and distill_ready |
| and float(args.distill_epochs) > 0.0 |
| ) |
| teacher_model = None |
| teacher_parent = None |
| teacher_layer_attr = None |
| teacher_layers: Optional[List[torch.nn.Module]] = None |
| teacher_from_cache = False |
| if needs_teacher_for_reparam: |
| teacher_model, teacher_source, teacher_cycle = _get_previous_cycle_teacher( |
| cycle_idx |
| ) |
| teacher_from_cache = ( |
| teacher_source == "previous_cycle_memory" |
| and teacher_model is previous_cycle_teacher_model |
| ) |
| if teacher_model is None: |
| teacher_model = load_causal_lm( |
| getattr(args, "base_model_id", args.model), |
| torch_dtype=get_dtype(args.dtype), |
| trust_remote_code=args.trust_remote_code, |
| cache_dir=args.model_cache_dir, |
| ) |
| teacher_model.to(teacher_device) |
| teacher_model.eval() |
| teacher_source = "base_model" |
| teacher_cycle = 0 |
| teacher_parent, teacher_layer_attr, teacher_container = find_layer_container( |
| teacher_model, args.layer_path |
| ) |
| teacher_layers = list(teacher_container) |
|
|
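# Reparam path: rather than a static average, train a small
# reparameterisation ("U") plus merge gates (lambdas initialised from the
# Fisher priors) against the teacher's activations.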
| do_reparam = ( |
| (not args.skip_distill) |
| and distill_ready |
| and prepared.distill_loader is not None |
| ) |
| if (not args.skip_distill) and not do_reparam: |
| print("[reparam] No distillation sequences built; skipping reparam distill.") |
| distill_post_eval = None |
| if do_reparam: |
| lambda_source = "fisher_prior" |
| reparam_gate_targets: Dict[str, object] = compute_fisher_gate_priors( |
| layer_a=layer_a, |
| layer_b=layer_b, |
| fisher_a=fisher_sums[0], |
| fisher_b=fisher_sums[1], |
| num_batches=num_batches, |
| numels_a=param_numels[0], |
| numels_b=param_numels[1], |
| fisher_mode=args.fisher_mode, |
| eps=float(args.eps), |
| ) |
| if not reparam_gate_targets: |
| raise SystemExit("[reparam] No mergeable parameters found; cannot continue.") |
| if float(args.distill_epochs) > 0.0 and ( |
| teacher_model is None or teacher_layers is None |
| ): |
| raise SystemExit("--distill_method reparam requires a teacher model.") |
| print( |
| f"[reparam] Cycle {cycle_idx}: training U + gates for pair " |
| f"{layer_idx}-{layer_idx + 1} (epochs={args.distill_epochs}, " |
| f"hidden_mse_w={args.distill_hidden_mse_weight}, " |
| f"attn_mse_w={args.distill_attn_mse_weight}, " |
| f"mlp_mse_w={args.distill_mlp_mse_weight}, " |
| f"eta={args.reparam_eta}, gamma={args.reparam_gamma}, " |
| f"attn_reg_scale={args.reparam_attn_reg_scale}, " |
| f"mlp_reg_scale={args.reparam_mlp_reg_scale}, " |
| f"param_subset={args.reparam_param_subset}, " |
| f"lambda_init={lambda_source})." |
| ) |
| merged, final_gates, reparam_stats = distill_reparam_merge( |
| student_model=model, |
| student_parent=parent, |
| student_layer_attr=name, |
| student_layers=layers, |
| teacher_model=teacher_model, |
| teacher_parent=teacher_parent, |
| teacher_layer_attr=teacher_layer_attr, |
| teacher_layers=teacher_layers, |
| layer_idx=layer_idx, |
| gate_lambdas=reparam_gate_targets, |
| dataloader=prepared.distill_loader, |
| args=args, |
| progressive_cycle=cycle_idx, |
| progressive_total=args.num_progressive, |
| ) |
| reparam_gate_summary = summarize_gate_lambdas(final_gates) |
| merge_method = "reparam" |
| if reparam_stats is not None: |
| reparam_stats["lambda_init"] = lambda_source |
else:
    # No distillation ran: fall back to the closed-form Fisher-weighted
    # average of the two layers.
    merged = merge_layers(
| layer_a, |
| layer_b, |
| fisher_sums[0], |
| fisher_sums[1], |
| num_batches, |
| param_numels[0], |
| param_numels[1], |
| fisher_mode=args.fisher_mode, |
| eps=args.eps, |
| ) |
|
|
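# The merge overwrote layer_a in place; re-apply the snapshotted norm
# weights according to --norm_policy.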
| apply_norm_policy( |
| layer_a, |
| args.norm_policy, |
| norm1_state, |
| norm2_state, |
| norm_names, |
| ) |
if teacher_model is not None and not teacher_from_cache:
    # Drop the throwaway teacher (a cached previous-cycle teacher is reused
    # across cycles), then reclaim its host and device memory.
    teacher_model = None
    teacher_parent = None
    teacher_layer_attr = None
    teacher_layers = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
|
|
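# The fused weights live in layer_a; physically drop layer_b and decrement
# the config so the checkpoint is a genuine (N-1)-layer model.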
| new_container = drop_layer(container, layer_idx + 1) |
| setattr(parent, name, new_container) |
| decrement_config(model.config) |
| layers = list(new_container) |
|
|
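# Evaluate once right after distillation and reuse the numbers as the
# cycle's headline perplexity below, avoiding a duplicate eval pass.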
| lora_post_eval = None |
| if (not args.skip_eval) and (not args.skip_distill) and do_reparam: |
| distill_post_eval = evaluate_ppl_model( |
| model, |
| tokenizer, |
| eval_datasets, |
| eval_configs, |
| args, |
| prepared_eval_dataloaders=prepared.eval_dataloaders, |
| ) |
| print(f"[progressive] Cycle {cycle_idx} post-distill perplexity:") |
| for dataset_name, ppl in distill_post_eval.items(): |
| print(f"{dataset_name}: {ppl:.4f}") |
|
|
| post_eval = None |
| if not args.skip_eval: |
| if distill_post_eval is not None: |
| post_eval = distill_post_eval |
| else: |
| post_eval = evaluate_ppl_model( |
| model, |
| tokenizer, |
| eval_datasets, |
| eval_configs, |
| args, |
| prepared_eval_dataloaders=prepared.eval_dataloaders, |
| ) |
| print(f"[progressive] Cycle {cycle_idx} perplexity:") |
| for dataset_name, ppl in post_eval.items(): |
| print(f"{dataset_name}: {ppl:.4f}") |
|
|
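# Persist per-cycle artifacts: the fused layer's state dict plus a metadata
# record of every setting that produced it.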
| cycle_dir = os.path.join(args.output_dir, f"cycle_{cycle_idx}") |
| os.makedirs(cycle_dir, exist_ok=True) |
| fused_layer_file = "fused_layer.pt" |
| fused_layer_path = os.path.join(cycle_dir, fused_layer_file) |
| torch.save(layers[layer_idx].state_dict(), fused_layer_path) |
|
|
| cycle_meta: Dict[str, object] = { |
| "cycle": cycle_idx, |
| "layer_merged": layer_idx, |
| "num_layers_before": num_layers_before, |
| "num_layers_after": num_layers_before - 1, |
| "fused_layer_state": fused_layer_file, |
| "dwce_score": dwce_scores[layer_idx] if dwce_scores else None, |
| "dwce_scores": dwce_scores, |
| "dwce_meta": dwce_meta, |
| "fisher_num_batches": num_batches, |
| "merge_method": merge_method, |
| "merged_params": merged, |
| "num_sequences": num_sequences, |
| "teacher_source": teacher_source, |
| "teacher_cycle": teacher_cycle, |
| "eval": { |
| "datasets": eval_datasets, |
| "configs": eval_configs, |
| "split": args.eval_split, |
| "num_samples": args.eval_num_samples, |
| "seq_len": args.eval_seq_len, |
| "post_ppl": post_eval, |
| }, |
| "comm": comm_stats, |
| "distill": { |
| "enabled": not args.skip_distill, |
| "method": distill_method, |
| "calib_samples": args.distill_calib_samples, |
| "inst_samples": args.distill_inst_samples, |
| "seq_len": args.distill_seq_len, |
| "batch_size": args.distill_batch_size, |
| "epochs": args.distill_epochs, |
| "lr": args.distill_lr, |
| "kl_weight": args.distill_kl_weight, |
| "kl_temp": args.distill_kl_temp, |
| "hidden_mse_weight": args.distill_hidden_mse_weight, |
| "attn_mse_weight": args.distill_attn_mse_weight, |
| "mlp_mse_weight": args.distill_mlp_mse_weight, |
| "reparam_eta": args.reparam_eta, |
| "reparam_gamma": args.reparam_gamma, |
| "reparam_attn_reg_scale": args.reparam_attn_reg_scale, |
| "reparam_mlp_reg_scale": args.reparam_mlp_reg_scale, |
| "reparam_param_subset": args.reparam_param_subset, |
| "reparam_stats": reparam_stats, |
| "reparam_gate_summary": reparam_gate_summary, |
| "post_ppl": distill_post_eval, |
| "weight_decay": args.distill_weight_decay, |
| "max_grad_norm": args.distill_max_grad_norm, |
| "grad_accum_steps": args.distill_grad_accum_steps, |
| "instruction_dataset": args.instruction_dataset, |
| "instruction_config": args.instruction_config, |
| "instruction_split": args.instruction_split, |
| }, |
| "lora": { |
| "enabled": args.lora_epochs > 0, |
| "seq_len": args.distill_seq_len, |
| "batch_size": args.distill_batch_size, |
| "epochs": args.lora_epochs, |
| "rank": args.lora_rank, |
| "alpha": args.lora_alpha, |
| "dropout": args.lora_dropout, |
| "target_modules": args.lora_target_modules, |
| "respect_exclude_pairs": args.lora_respect_exclude_pairs, |
| "kl_enabled": args.lora_kl_enabled, |
| "kl_weight": args.lora_kl_weight, |
| "kl_temp": args.lora_kl_temp, |
| "post_ppl": lora_post_eval, |
| "lr": args.lora_lr, |
| "weight_decay": args.lora_weight_decay, |
| "max_grad_norm": args.lora_max_grad_norm, |
| "grad_accum_steps": args.lora_grad_accum_steps, |
| "log_steps": args.lora_log_steps, |
| "eval_every": args.lora_eval_every, |
| "eval_max_batches": args.lora_eval_max_batches, |
| }, |
| "norm_policy": args.norm_policy, |
| } |
|
|
| saved_full_model_dir = None |
| if cycle_idx in args.full_model_save_cycles: |
| cycle_meta["full_model_saved"] = True |
| cycle_meta["full_model"] = save_cycle_full_model( |
| model=model, |
| tokenizer=tokenizer, |
| cycle_dir=cycle_dir, |
| cycle_idx=cycle_idx, |
| args=args, |
| ppl_results=post_eval, |
| ) |
| saved_full_model_dir = os.path.join(cycle_dir, "full_model") |
| else: |
| cycle_meta["full_model_saved"] = False |
|
|
| with open( |
| os.path.join(cycle_dir, "cycle_metadata.json"), |
| "w", |
| encoding="utf-8", |
| ) as handle: |
| json.dump(cycle_meta, handle, indent=2) |
|
|
| cycle_summaries.append( |
| { |
| "cycle": cycle_idx, |
| "layer_merged": layer_idx, |
| "dwce_score": dwce_scores[layer_idx] if dwce_scores else None, |
| "comm_post_ppl": comm_post_eval, |
| "distill_post_ppl": distill_post_eval, |
| "lora_post_ppl": lora_post_eval, |
| "post_ppl": post_eval, |
| "cycle_dir": f"cycle_{cycle_idx}", |
| } |
| ) |
|
|
| last_fused_idx = layer_idx |
| _snapshot_previous_cycle_teacher(cycle_idx) |
|
|
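# Re-resolve the layer container after the structural edit and truncate the
# DWCE scores to the new pair count.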
| parent, name, container = find_layer_container(model, args.layer_path) |
| layers = list(container) |
| if dwce_scores: |
| dwce_scores = dwce_scores[: max(len(layers) - 1, 0)] |
|
|
| |
# Collect Python garbage first so dead tensors can be released, then return
# cached blocks to the driver between cycles.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
|
|
# Co-locate RNG and dataloader-generator state with resumable snapshots so
# --resume_from_cycle can replay the exact data order.
if saved_full_model_dir is not None:
| save_rng_state(os.path.join(saved_full_model_dir, "rng_state.pt")) |
| save_loader_generator_state( |
| saved_full_model_dir, |
| distill_generator=prepared.distill_generator, |
| lora_generator=prepared.lora_generator, |
| ) |
|
|
# All cycles finished: release any cached teacher models before the final
# LoRA and evaluation phases.
_release_comm_teacher()
_release_previous_cycle_teacher()
|
|
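# Checkpoint the fully fused model before LoRA so pre- and post-LoRA
# artifacts can be compared or resumed independently.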
| final_pre_lora_eval = cycle_summaries[-1]["post_ppl"] if cycle_summaries else None |
final_pre_lora_dir = (
    f"{os.path.abspath(args.output_dir.rstrip(os.sep))}_final_pre_lora_hf"
)
| final_pre_lora_meta = save_stage_checkpoint( |
| model=model, |
| tokenizer=tokenizer, |
| stage_dir=final_pre_lora_dir, |
| stage_name="final_pre_lora", |
| ppl_results=final_pre_lora_eval, |
| ) |
|
|
| |
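# Final phase: one LoRA pass over the fused model (run once at the end, not
# per cycle), followed by the definitive evaluation.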
| lora_eval_history: List[Dict[str, object]] = [] |
| lora_post_eval = None |
| lora_ready = has_post_fusion_data(prepared.lora_loader, prepared.lora_meta) |
| if args.lora_epochs > 0: |
| if not lora_ready: |
| print("No post-fusion sequences built; skipping LoRA finetuning.") |
| else: |
| print( |
| f"[progressive] Running final LoRA finetuning (epochs={args.lora_epochs})." |
| ) |
| lora_eval_history = run_lora_phase( |
| model=model, |
| tokenizer=tokenizer, |
| eval_datasets=eval_datasets, |
| eval_configs=eval_configs, |
| args=args, |
| lora_loader=prepared.lora_loader, |
| lora_meta=prepared.lora_meta, |
| eval_dataloaders=prepared.eval_dataloaders, |
| cycle_idx=args.num_progressive, |
| num_cycles=args.num_progressive, |
| ) |
|
|
if not args.skip_eval:
    lora_post_eval = evaluate_ppl_model(
        model,
        tokenizer,
        eval_datasets,
        eval_configs,
        args,
        prepared_eval_dataloaders=prepared.eval_dataloaders,
    )
    # Label honestly: this evaluation also runs when LoRA was skipped.
    label = "Post-LoRA" if (args.lora_epochs > 0 and lora_ready) else "Final"
    print(f"[progressive] {label} perplexity:")
    for dataset_name, ppl in lora_post_eval.items():
        print(f"{dataset_name}: {ppl:.4f}")
|
|
| |
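# Fold the post-LoRA numbers back into the last cycle's summary and metadata
# so downstream tooling finds the final perplexities in one place.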
| if cycle_summaries: |
| cycle_summaries[-1]["lora_post_ppl"] = lora_post_eval |
| if lora_post_eval is not None: |
| cycle_summaries[-1]["post_ppl"] = lora_post_eval |
|
|
| final_cycle_dir = os.path.join( |
| args.output_dir, f"cycle_{args.num_progressive}" |
| ) |
| final_cycle_meta_path = os.path.join(final_cycle_dir, "cycle_metadata.json") |
| if os.path.exists(final_cycle_meta_path): |
| with open(final_cycle_meta_path, "r", encoding="utf-8") as handle: |
| final_cycle_meta = json.load(handle) |
|
|
lora_meta_entry = final_cycle_meta.get("lora")
if not isinstance(lora_meta_entry, dict):
    lora_meta_entry = {}
    final_cycle_meta["lora"] = lora_meta_entry
# Record whether LoRA actually ran instead of hard-coding True.
lora_meta_entry["ran"] = args.lora_epochs > 0 and lora_ready
lora_meta_entry["post_ppl"] = lora_post_eval
| if lora_post_eval is not None and isinstance( |
| final_cycle_meta.get("eval"), dict |
| ): |
| final_cycle_meta["eval"]["post_ppl"] = lora_post_eval |
|
|
| if lora_eval_history: |
| lora_path = os.path.join(final_cycle_dir, "ppl_over_lora.json") |
| with open(lora_path, "w", encoding="utf-8") as handle: |
| json.dump(lora_eval_history, handle, indent=2) |
| lora_meta_entry["ppl_over_lora"] = "ppl_over_lora.json" |
|
|
| with open(final_cycle_meta_path, "w", encoding="utf-8") as handle: |
| json.dump(final_cycle_meta, handle, indent=2) |
|
|
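# Write the final model into --output_dir itself, then a run-level summary
# covering every cycle plus the LoRA phase.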
| os.makedirs(args.output_dir, exist_ok=True) |
final_post_lora_meta = save_stage_checkpoint(
    model=model,
    tokenizer=tokenizer,
    stage_dir=args.output_dir,
    # Name the stage by whether LoRA ran, not by whether eval ran.
    stage_name=(
        "final_post_lora"
        if (args.lora_epochs > 0 and lora_ready)
        else "final_model"
    ),
    ppl_results=lora_post_eval,
)
|
|
| progressive_meta = { |
| "base_model": getattr(args, "base_model_id", args.model), |
| "num_progressive": args.num_progressive, |
| "layer_path": args.layer_path, |
| "resume_from_cycle": args.resume_from_cycle, |
| "save_full_model_cycles": sorted(args.full_model_save_cycles), |
| "num_sequences": num_sequences, |
| "seq_len": args.seq_len, |
| "lora": { |
| "enabled": args.lora_epochs > 0, |
| "ran": args.lora_epochs > 0 and lora_ready, |
| "seq_len": args.distill_seq_len, |
| "batch_size": args.distill_batch_size, |
| "epochs": args.lora_epochs, |
| "rank": args.lora_rank, |
| "alpha": args.lora_alpha, |
| "dropout": args.lora_dropout, |
| "target_modules": args.lora_target_modules, |
| "respect_exclude_pairs": args.lora_respect_exclude_pairs, |
| "kl_enabled": args.lora_kl_enabled, |
| "kl_weight": args.lora_kl_weight, |
| "kl_temp": args.lora_kl_temp, |
| "post_ppl": lora_post_eval, |
| "ppl_over_lora": ( |
| f"cycle_{args.num_progressive}/ppl_over_lora.json" |
| if lora_eval_history |
| else None |
| ), |
| "lr": args.lora_lr, |
| "weight_decay": args.lora_weight_decay, |
| "max_grad_norm": args.lora_max_grad_norm, |
| "grad_accum_steps": args.lora_grad_accum_steps, |
| "log_steps": args.lora_log_steps, |
| "eval_every": args.lora_eval_every, |
| "eval_max_batches": args.lora_eval_max_batches, |
| }, |
| "artifacts": { |
| "final_pre_lora": final_pre_lora_meta, |
| "final_post_lora": final_post_lora_meta, |
| }, |
| "eval": { |
| "datasets": eval_datasets, |
| "configs": eval_configs, |
| "split": args.eval_split, |
| "num_samples": args.eval_num_samples, |
| "seq_len": args.eval_seq_len, |
| "pre_ppl": pre_eval, |
| "post_ppl": cycle_summaries[-1]["post_ppl"] if cycle_summaries else None, |
| }, |
| "cycles": cycle_summaries, |
| "final_num_layers": len(layers), |
| } |
|
|
| with open( |
| os.path.join(args.output_dir, "progressive_metadata.json"), |
| "w", |
| encoding="utf-8", |
| ) as handle: |
| json.dump(progressive_meta, handle, indent=2) |
|
|
| print( |
| f"[progressive] Completed {args.num_progressive} cycles. " |
| f"Final model saved to {args.output_dir}." |
| ) |
|
|
|
|
| def main() -> None: |
| args = parse_args() |
| if args.num_progressive <= 0: |
| raise SystemExit( |
| "Single-cycle mode has been removed. Pass --num_progressive > 0." |
| ) |
| if args.resume_from_cycle < 0: |
| raise SystemExit("--resume_from_cycle must be >= 0.") |
| if args.resume_from_cycle >= args.num_progressive: |
| raise SystemExit("--resume_from_cycle must be smaller than --num_progressive.") |
| args.full_model_save_cycles = resolve_full_model_save_cycles( |
| parse_cycle_list(args.save_full_model_cycles), |
| args.num_progressive, |
| ) |
| args.base_model_id = args.model |
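# When resuming, --model points at a cycle snapshot; recover the true base
# model id (used for base-model teachers) from the snapshot's resume_info.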
| if args.resume_from_cycle > 0: |
| resume_meta = load_resume_metadata(args.model) |
| if resume_meta is None: |
| raise SystemExit( |
| "--resume_from_cycle requires --model to point to a saved cycle full model " |
| "directory containing resume_info.json." |
| ) |
| resume_cycle = resume_meta.get("cycle") |
| if resume_cycle is not None and int(resume_cycle) != args.resume_from_cycle: |
| raise SystemExit( |
| "resume_info.json cycle does not match --resume_from_cycle." |
| ) |
| base_model = resume_meta.get("base_model") |
| if isinstance(base_model, str) and base_model: |
| args.base_model_id = base_model |
| configure_reproducibility(args.seed) |
|
|
| eval_datasets, eval_configs = resolve_eval_datasets(args) |
| dtype = get_dtype(args.dtype) |
| model = load_causal_lm( |
| args.model, |
| torch_dtype=dtype, |
| trust_remote_code=args.trust_remote_code, |
| cache_dir=args.model_cache_dir, |
| ) |
loader_generator_state = None
if args.resume_from_cycle > 0:
    # Restore RNG and dataloader-generator state saved with the snapshot so
    # the resumed run replays the same sampling order.
| rng_state_path = os.path.join(args.model, "rng_state.pt") |
| rng_state = load_rng_state(rng_state_path) |
| if rng_state is not None: |
| restore_rng_state(rng_state) |
| loader_generator_state = load_loader_generator_state(args.model) |
| tokenizer = AutoTokenizer.from_pretrained( |
| args.model, |
| trust_remote_code=args.trust_remote_code, |
| cache_dir=args.model_cache_dir, |
| ) |
print(model)  # log the loaded architecture for the run record
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
    tokenizer.pad_token = tokenizer.eos_token
| |
| |
# Only full forward passes are needed here; the KV cache just wastes memory.
model.config.use_cache = False
|
|
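# Build all loaders once up front; every cycle reuses the same prepared
# batches, eval dataloaders, and seeded generators.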
| prepared = prepare_all_data( |
| args, |
| tokenizer, |
| model, |
| eval_datasets, |
| eval_configs, |
| loader_generator_state=loader_generator_state, |
| ) |
| run_progressive(args, model, tokenizer, prepared) |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|