#!/usr/bin/env python3
"""Automatic adjacent-pair selection via configurable scoring metrics."""

import copy
import math
from contextlib import contextmanager
from typing import Dict, List, Optional, Set, Tuple

import torch
import torch.nn.functional as F

from fuse_layers_model import (
    build_head_permutation,
    compute_fisher,
    compute_head_means,
    find_attention_module,
    find_layer_container,
    merge_layers,
    permute_attention_heads,
)

# Budget for the fp16 squared-gradient cache used by shared-backward DWCE (1 GiB).
_DWCE_GRAD_CACHE_MAX_BYTES = 1 << 30


class _DwceGradCacheOverflow(RuntimeError):
    """Raised when shared-backward DWCE caching exceeds the configured budget."""


def _get_hidden_size(model) -> int:
    """Return the hidden width from `hidden_size` or the GPT-2-style `n_embd`."""
    hidden_size = getattr(model.config, "hidden_size", None)
    if hidden_size is None:
        hidden_size = getattr(model.config, "n_embd", None)
    if hidden_size is None:
        raise SystemExit("Model config missing hidden_size/n_embd")
    return int(hidden_size)


def _detach_arg(arg):
    """Recursively detach tensors inside (possibly nested) containers."""
    if torch.is_tensor(arg):
        return arg.detach()
    if isinstance(arg, (list, tuple)):
        return type(arg)(_detach_arg(x) for x in arg)
    if isinstance(arg, dict):
        return {k: _detach_arg(v) for k, v in arg.items()}
    return arg


def _register_forward_hook(layer, hook):
    """Register `hook` with kwargs where PyTorch supports it; return (handle, has_kwargs)."""
    try:

        def wrapper(module, inputs, kwargs, output):
            return hook(module, inputs, output, kwargs)

        handle = layer.register_forward_hook(wrapper, with_kwargs=True)
        return handle, True
    except TypeError:
        # Older PyTorch: no `with_kwargs` support; hooks see positional args only.

        def wrapper(module, inputs, output):
            return hook(module, inputs, output, None)

        handle = layer.register_forward_hook(wrapper)
        return handle, False


@contextmanager
def _temporary_layers(parent: object, name: str, new_layers: object):
    """Temporarily swap `parent.<name>` for `new_layers`, restoring it on exit."""
    original = getattr(parent, name)
    setattr(parent, name, new_layers)
    try:
        yield
    finally:
        setattr(parent, name, original)


def _extract_hidden(output):
    """Pull a hidden-state tensor out of a tensor/tuple/dict/ModelOutput layer result."""
    if torch.is_tensor(output):
        return output
    if isinstance(output, (tuple, list)):
        if output and all(torch.is_tensor(item) for item in output):
            return output[0]
        for item in output:
            hidden = _extract_hidden(item)
            if hidden is not None:
                return hidden
        return None
    if isinstance(output, dict):
        for key in ("hidden_states", "last_hidden_state", "hidden_state"):
            if key in output:
                value = output[key]
                if isinstance(value, (tuple, list)) and value and all(
                    torch.is_tensor(item) for item in value
                ):
                    return value[-1]
                hidden = _extract_hidden(value)
                if hidden is not None:
                    return hidden
        for value in output.values():
            hidden = _extract_hidden(value)
            if hidden is not None:
                return hidden
        return None
    for attr in ("hidden_states", "last_hidden_state"):
        if hasattr(output, attr):
            value = getattr(output, attr)
            if isinstance(value, (tuple, list)) and value and all(
                torch.is_tensor(item) for item in value
            ):
                return value[-1]
            hidden = _extract_hidden(value)
            if hidden is not None:
                return hidden
    return None

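# Usage sketch (hedged, not part of the selection pipeline): `_temporary_layers`
# swaps a virtual layer stack in for one forward pass without mutating the model,
# and `_extract_hidden` recovers a tensor from whatever container a layer returns.
# `shorter_stack` below is a hypothetical, already-built list of modules:
#
#     parent, name, container = find_layer_container(model, None)
#     with _temporary_layers(parent, name, torch.nn.ModuleList(shorter_stack)):
#         out = model(input_ids=input_ids, use_cache=False)
#     hidden = _extract_hidden(out)
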
def _build_fused_layer_for_pair(
    model,
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    dataloader,
    device: str,
    fisher_mode: str,
    eps: float,
    hidden_size: int,
    enable_head_permute: bool = True,
) -> Tuple[torch.nn.Module, Dict[str, float]]:
    """Fisher-merge copies of two adjacent layers; return (fused layer, fuse priors)."""
    attn_a = find_attention_module(layer_a)
    attn_b = find_attention_module(layer_b)
    perm = None
    inv_perm = None
    num_heads = None
    num_kv_heads = None
    head_dim = None
    if enable_head_permute:
        mean_a, mean_b, num_heads, num_kv_heads, head_dim = compute_head_means(
            model,
            attn_a,
            attn_b,
            dataloader,
            device,
            hidden_size,
        )
        perm = build_head_permutation(
            mean_a,
            mean_b,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            eps=eps,
        )
    layer_a_copy = copy.deepcopy(layer_a)
    layer_b_copy = copy.deepcopy(layer_b)
    attn_b_copy = find_attention_module(layer_b_copy)
    if perm is not None:
        permute_attention_heads(
            attn_b_copy, perm, num_heads, num_kv_heads, head_dim=head_dim
        )
        inv_perm = [0] * len(perm)
        for idx, mapped in enumerate(perm):
            inv_perm[mapped] = idx
        permute_attention_heads(attn_b, perm, num_heads, num_kv_heads, head_dim=head_dim)
    try:
        fisher_sums, num_batches, param_numels = compute_fisher(
            model,
            layer_a,
            layer_b,
            dataloader,
            fisher_mode=fisher_mode,
            device=device,
        )
    finally:
        # Undo the in-place permutation of the live layer_b whether or not the
        # Fisher pass succeeded.
        if inv_perm is not None:
            permute_attention_heads(
                attn_b, inv_perm, num_heads, num_kv_heads, head_dim=head_dim
            )
    merge_layers(
        layer_a_copy,
        layer_b_copy,
        fisher_sums[0],
        fisher_sums[1],
        num_batches,
        param_numels[0],
        param_numels[1],
        fisher_mode=fisher_mode,
        eps=eps,
    )
    # Scalar mixing coefficients per parameter tensor; used by pressure
    # redistribution to simulate future fusions without running another
    # Fisher pass.
    fuse_priors = _compute_fuse_priors(
        layer_a, layer_b, fisher_sums, num_batches, param_numels, fisher_mode, eps
    )
    layer_a_copy.eval()
    return layer_a_copy, fuse_priors

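# Worked example (hedged) for the fuse priors above: if a parameter tensor has
# mean per-batch Fisher mass fa_val = 3.0 for layer_a and fb_val = 1.0 for
# layer_b, then lam = 3.0 / (4.0 + eps) ~= 0.75, i.e. a simulated future fusion
# would weight layer_a's weights three times as heavily as layer_b's. lam is
# clamped to [1e-4, 1 - 1e-4] so neither side is ever fully discarded.
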
def _init_fisher_accumulators(
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    fisher_mode: str,
    device: str,
) -> Tuple[List[Dict[str, object]], List[Dict[str, int]]]:
    """Zero-initialize per-layer Fisher accumulators (per-parameter or scalar)."""
    fisher_sums: List[Dict[str, object]] = []
    param_numels: List[Dict[str, int]] = []
    for layer in (layer_a, layer_b):
        layer_sums: Dict[str, object] = {}
        layer_numels: Dict[str, int] = {}
        for name, param in layer.named_parameters():
            if not param.requires_grad:
                continue
            if fisher_mode == "param":
                # Per-parameter accumulators live on CPU to bound device memory.
                layer_sums[name] = torch.zeros_like(
                    param, dtype=torch.float32, device="cpu"
                )
            else:
                layer_sums[name] = torch.zeros((), dtype=torch.float32, device=device)
            layer_numels[name] = param.numel()
        fisher_sums.append(layer_sums)
        param_numels.append(layer_numels)
    return fisher_sums, param_numels


def _accumulate_fisher_from_grads(
    layer: torch.nn.Module,
    layer_sums: Dict[str, object],
    fisher_mode: str,
) -> None:
    """Add squared gradients of `layer`'s parameters into `layer_sums`."""
    for name, param in layer.named_parameters():
        if not param.requires_grad or param.grad is None:
            continue
        grad_sq = param.grad.detach().float().pow(2)
        if fisher_mode == "param":
            layer_sums[name] += grad_sq.cpu()
        else:
            layer_sums[name] += grad_sq.sum()


def _finalize_fisher_sums(
    fisher_sums: List[Dict[str, object]],
    fisher_mode: str,
) -> List[Dict[str, object]]:
    """Convert scalar accumulators to floats; per-parameter tensors pass through."""
    if fisher_mode == "param":
        return fisher_sums
    finalized: List[Dict[str, object]] = []
    for layer_sums in fisher_sums:
        finalized_layer: Dict[str, object] = {}
        for name, value in layer_sums.items():
            if isinstance(value, torch.Tensor):
                finalized_layer[name] = float(value.detach().cpu().item())
            else:
                finalized_layer[name] = float(value)
        finalized.append(finalized_layer)
    return finalized


def _compute_fuse_priors(
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    fisher_sums: List[Dict[str, object]],
    num_batches: int,
    param_numels: List[Dict[str, int]],
    fisher_mode: str,
    eps: float,
) -> Dict[str, float]:
    """Compute the scalar lam = Fa / (Fa + Fb) mixing prior per shared parameter."""
    fuse_priors: Dict[str, float] = {}
    params_b = dict(layer_b.named_parameters())
    clamp_eps = 1e-4
    for name, param_a in layer_a.named_parameters():
        param_b = params_b.get(name)
        if param_b is None or param_b.shape != param_a.shape:
            continue
        if fisher_mode == "param":
            fa = fisher_sums[0][name] / max(num_batches, 1)
            fb = fisher_sums[1][name] / max(num_batches, 1)
            fa_val = float(fa.mean().item()) if isinstance(fa, torch.Tensor) else float(fa)
            fb_val = float(fb.mean().item()) if isinstance(fb, torch.Tensor) else float(fb)
        else:
            fa_val = float(
                fisher_sums[0][name]
                / (max(num_batches, 1) * max(param_numels[0].get(name, 1), 1))
            )
            fb_val = float(
                fisher_sums[1][name]
                / (max(num_batches, 1) * max(param_numels[1].get(name, 1), 1))
            )
        denom = fa_val + fb_val
        lam = 0.5 if denom <= eps else fa_val / (denom + eps)
        fuse_priors[name] = min(max(lam, clamp_eps), 1.0 - clamp_eps)
    return fuse_priors

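# The DWCE score computed below weights the fused-vs-teacher output delta by the
# squared gradient of the LM loss with respect to the pair's output hidden state:
#
#     score_num = sum_t g_t^2 * (fused_t - teacher_t)^2
#     score_den = sum_t g_t^2 * teacher_t^2
#     score     = score_num / (score_den + eps)    # norm == "relative"
#               = score_num / num_tokens           # otherwise
#
# The "shared" variant reuses a single backward pass per batch for both the
# Fisher accumulators and g_t, caching g_t^2 in fp16 under
# _DWCE_GRAD_CACHE_MAX_BYTES.
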
def _score_dwce_with_shared_backward(
    model,
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    dataloader,
    device: str,
    fisher_mode: str,
    max_batches: int,
    eps: float,
    norm: str,
    hidden_size: int,
    enable_head_permute: bool = True,
) -> Tuple[float, Dict[str, object]]:
    """Score a pair with DWCE using one backward per batch for Fisher and grads."""
    attn_a = find_attention_module(layer_a)
    attn_b = find_attention_module(layer_b)
    perm = None
    inv_perm = None
    num_heads = None
    num_kv_heads = None
    head_dim = None
    if enable_head_permute:
        mean_a, mean_b, num_heads, num_kv_heads, head_dim = compute_head_means(
            model,
            attn_a,
            attn_b,
            dataloader,
            device,
            hidden_size,
        )
        perm = build_head_permutation(
            mean_a,
            mean_b,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            eps=eps,
        )
    layer_a_copy = copy.deepcopy(layer_a)
    layer_b_copy = copy.deepcopy(layer_b)
    attn_b_copy = find_attention_module(layer_b_copy)
    if perm is not None:
        permute_attention_heads(
            attn_b_copy, perm, num_heads, num_kv_heads, head_dim=head_dim
        )
        inv_perm = [0] * len(perm)
        for idx, mapped in enumerate(perm):
            inv_perm[mapped] = idx

    cache: Dict[str, Optional[torch.Tensor]] = {"teacher": None}
    grad_sq_cache: List[torch.Tensor] = []
    cached_bytes = 0

    def hook_b(_module, _inputs, output, _kwargs=None):
        teacher_hidden = _extract_hidden(output)
        if teacher_hidden is None:
            raise RuntimeError("Failed to extract teacher hidden state output.")
        cache["teacher"] = teacher_hidden
        if teacher_hidden.requires_grad:
            teacher_hidden.retain_grad()
        return output

    handle_b, _ = _register_forward_hook(layer_b, hook_b)
    # Restrict gradient flow to the candidate pair so the shared backward only
    # pays for the parameters we actually accumulate.
    for param in model.parameters():
        param.requires_grad_(False)
    for layer in (layer_a, layer_b):
        for param in layer.parameters():
            param.requires_grad_(True)
    fisher_sums, param_numels = _init_fisher_accumulators(
        layer_a, layer_b, fisher_mode, device
    )
    num_batches = 0
    if perm is not None:
        permute_attention_heads(attn_b, perm, num_heads, num_kv_heads, head_dim=head_dim)
    try:
        model.eval()
        for batch_idx, batch in enumerate(dataloader):
            if max_batches and batch_idx >= max_batches:
                break
            cache["teacher"] = None
            input_ids = batch[0].to(device)
            attention_mask = batch[1].to(device) if len(batch) > 1 else None
            model.zero_grad(set_to_none=True)
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids,
            )
            outputs.loss.backward()
            teacher = cache["teacher"]
            grad = None if teacher is None else teacher.grad
            if teacher is None or grad is None:
                raise RuntimeError(
                    "Auto selection hooks failed to capture outputs/gradients. "
                    "Try updating PyTorch or run with an explicit --layer."
                )
            grad_sq = grad.detach().pow(2).to(device=device, dtype=torch.float16)
            cached_bytes += grad_sq.numel() * grad_sq.element_size()
            if cached_bytes > _DWCE_GRAD_CACHE_MAX_BYTES:
                raise _DwceGradCacheOverflow(
                    "DWCE grad cache exceeded device-memory budget during "
                    "shared-backward scoring."
                )
            grad_sq_cache.append(grad_sq)
            _accumulate_fisher_from_grads(layer_a, fisher_sums[0], fisher_mode)
            _accumulate_fisher_from_grads(layer_b, fisher_sums[1], fisher_mode)
            model.zero_grad(set_to_none=True)
            num_batches += 1
    finally:
        handle_b.remove()
        if inv_perm is not None:
            permute_attention_heads(
                attn_b, inv_perm, num_heads, num_kv_heads, head_dim=head_dim
            )
        for param in model.parameters():
            param.requires_grad_(True)
    if num_batches == 0:
        raise RuntimeError("No batches processed; check dataset or text inputs.")
    fisher_sums = _finalize_fisher_sums(fisher_sums, fisher_mode)
    merge_layers(
        layer_a_copy,
        layer_b_copy,
        fisher_sums[0],
        fisher_sums[1],
        num_batches,
        param_numels[0],
        param_numels[1],
        fisher_mode=fisher_mode,
        eps=eps,
    )
    fuse_priors = _compute_fuse_priors(
        layer_a,
        layer_b,
        fisher_sums,
        num_batches,
        param_numels,
        fisher_mode,
        eps,
    )
    fused_layer = layer_a_copy
    fused_layer.eval()

    # Phase 2: replay the same batches without grad, feeding each layer_a input
    # through the fused copy and comparing against layer_b's (teacher) output.
    phase2_cache = {"teacher": None, "fused": None}

    def hook_a(_module, inputs, output, kwargs=None):
        with torch.no_grad():
            detached_inputs = tuple(_detach_arg(arg) for arg in inputs)
            if kwargs:
                detached_kwargs = {k: _detach_arg(v) for k, v in kwargs.items()}
                fused_out = fused_layer(*detached_inputs, **detached_kwargs)
            else:
                fused_out = fused_layer(*detached_inputs)
            fused_hidden = _extract_hidden(fused_out)
            if fused_hidden is None:
                raise RuntimeError("Failed to extract fused hidden state output.")
            phase2_cache["fused"] = fused_hidden
        return output

    def hook_b_eval(_module, _inputs, output, _kwargs=None):
        teacher_hidden = _extract_hidden(output)
        if teacher_hidden is None:
            raise RuntimeError("Failed to extract teacher hidden state output.")
        phase2_cache["teacher"] = teacher_hidden
        return output

    handle_a, has_kwargs_a = _register_forward_hook(layer_a, hook_a)
    handle_b_eval, has_kwargs_b = _register_forward_hook(layer_b, hook_b_eval)
    supports_kwargs = has_kwargs_a and has_kwargs_b
    score_num = 0.0
    score_den = 0.0
    token_count = 0.0
    try:
        model.eval()
        for batch_idx, batch in enumerate(dataloader):
            # Replay exactly the batches scored in phase 1; grad_sq_cache is
            # indexed by position, so the dataloader must iterate deterministically.
            if batch_idx >= num_batches:
                break
            phase2_cache["teacher"] = None
            phase2_cache["fused"] = None
            input_ids = batch[0].to(device)
            attention_mask = batch[1].to(device) if len(batch) > 1 else None
            with torch.no_grad():
                model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    use_cache=False,
                )
            teacher = phase2_cache["teacher"]
            fused = phase2_cache["fused"]
            if teacher is None or fused is None:
                raise RuntimeError(
                    "Auto selection hooks failed to capture outputs during DWCE replay."
                )
            grad_sq = grad_sq_cache[batch_idx].to(dtype=torch.float32)
            if attention_mask is not None:
                mask = attention_mask.to(dtype=torch.float32).unsqueeze(-1)
                batch_tokens = float(mask.sum().item())
                grad_sq = grad_sq * mask
            else:
                mask = None
                batch_tokens = float(input_ids.numel())
            token_count += batch_tokens
            delta = fused - teacher
            if mask is not None:
                delta = delta * mask
            score_num += (delta.float().pow(2) * grad_sq).sum().item()
            score_den += (teacher.float().pow(2) * grad_sq).sum().item()
    finally:
        handle_a.remove()
        handle_b_eval.remove()
    score = (
        score_num / (score_den + eps)
        if norm == "relative"
        else score_num / max(token_count, 1.0)
    )
    meta = {
        "num_batches": num_batches,
        "token_count": token_count,
        "norm": norm,
        "supports_kwargs": supports_kwargs,
        "fuse_priors": fuse_priors,
        "metric": "dwce",
        "dwce_mode": "shared",
    }
    return score, meta

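# Cache-budget arithmetic (hedged example values): each cached g_t^2 tensor is
# batch * seq_len * hidden fp16 values, so with batch=8, seq_len=2048,
# hidden=4096 one batch costs 8 * 2048 * 4096 * 2 bytes = 128 MiB. Exactly 8
# such batches fit under the 1 GiB _DWCE_GRAD_CACHE_MAX_BYTES budget; the 9th
# raises _DwceGradCacheOverflow and triggers the separate-mode fallback.
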
def _compute_dwce_for_pair(
    model,
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    fused_layer: torch.nn.Module,
    dataloader,
    device: str,
    max_batches: int,
    eps: float,
    norm: str,
) -> Tuple[float, Dict[str, object]]:
    """Score a prebuilt fused layer with DWCE (separate backward per batch)."""
    cache = {"teacher": None, "fused": None}
    supports_kwargs = True

    def hook_a(_module, inputs, output, kwargs=None):
        with torch.no_grad():
            detached_inputs = tuple(_detach_arg(arg) for arg in inputs)
            if kwargs is not None and len(kwargs) > 0:
                detached_kwargs = {k: _detach_arg(v) for k, v in kwargs.items()}
                fused_out = fused_layer(*detached_inputs, **detached_kwargs)
            else:
                fused_out = fused_layer(*detached_inputs)
            fused_hidden = _extract_hidden(fused_out)
            if fused_hidden is None:
                raise RuntimeError("Failed to extract fused hidden state output.")
            cache["fused"] = fused_hidden
        return output

    def hook_b(_module, _inputs, output, _kwargs=None):
        teacher_hidden = _extract_hidden(output)
        if teacher_hidden is None:
            raise RuntimeError("Failed to extract teacher hidden state output.")
        cache["teacher"] = teacher_hidden
        if teacher_hidden.requires_grad:
            teacher_hidden.retain_grad()
        return output

    handle_a, has_kwargs_a = _register_forward_hook(layer_a, hook_a)
    handle_b, has_kwargs_b = _register_forward_hook(layer_b, hook_b)
    supports_kwargs = has_kwargs_a and has_kwargs_b
    score_num = 0.0
    score_den = 0.0
    token_count = 0.0
    num_batches = 0
    try:
        model.eval()
        for batch_idx, batch in enumerate(dataloader):
            if max_batches and batch_idx >= max_batches:
                break
            cache["teacher"] = None
            cache["fused"] = None
            input_ids = batch[0].to(device)
            attention_mask = batch[1].to(device) if len(batch) > 1 else None
            model.zero_grad(set_to_none=True)
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids,
            )
            loss = outputs.loss
            loss.backward()
            teacher = cache["teacher"]
            fused = cache["fused"]
            grad = None if teacher is None else teacher.grad
            if teacher is None or fused is None or grad is None:
                raise RuntimeError(
                    "Auto selection hooks failed to capture outputs/gradients. "
                    "Try updating PyTorch or run with an explicit --layer."
                )
            if not teacher.requires_grad:
                raise RuntimeError(
                    "Teacher hidden state does not require grad. "
                    "Ensure model parameters require grad for DWCE."
                )
            with torch.no_grad():
                if attention_mask is not None:
                    mask = attention_mask.to(dtype=torch.float32).unsqueeze(-1)
                    batch_tokens = float(mask.sum().item())
                else:
                    mask = None
                    batch_tokens = float(input_ids.numel())
                token_count += batch_tokens
                delta = fused - teacher
                grad_sq = grad.pow(2)
                if mask is not None:
                    delta = delta * mask
                    grad_sq = grad_sq * mask
                score_num += (delta.pow(2) * grad_sq).sum().item()
                score_den += (teacher.pow(2) * grad_sq).sum().item()
            num_batches += 1
    finally:
        # Remove hooks even if a batch raised, so the model is left clean.
        handle_a.remove()
        handle_b.remove()
    if norm == "relative":
        score = score_num / (score_den + eps)
    else:
        denom = token_count if token_count > 0 else 1.0
        score = score_num / denom
    meta = {
        "num_batches": num_batches,
        "token_count": token_count,
        "norm": norm,
        "supports_kwargs": supports_kwargs,
    }
    return score, meta

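# Cost note (from the code above): separate mode pays a full forward + backward
# per batch for every candidate pair, whereas shared mode amortizes one backward
# per batch across Fisher accumulation and gradient capture. Separate mode is
# the fallback whenever the fp16 gradient cache would exceed its budget.
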
def _compute_cosine_for_pair(
    model,
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    dataloader,
    device: str,
    max_batches: int,
    eps: float,
) -> Tuple[float, Dict[str, object]]:
    """Score a pair by mean per-token cosine distance between the layers' outputs."""
    cache = {"a": None, "b": None}
    supports_kwargs = True

    def hook_a(_module, _inputs, output, _kwargs=None):
        hidden = _extract_hidden(output)
        if hidden is None:
            raise RuntimeError("Failed to extract layer_a hidden state output.")
        cache["a"] = hidden
        return output

    def hook_b(_module, _inputs, output, _kwargs=None):
        hidden = _extract_hidden(output)
        if hidden is None:
            raise RuntimeError("Failed to extract layer_b hidden state output.")
        cache["b"] = hidden
        return output

    handle_a, has_kwargs_a = _register_forward_hook(layer_a, hook_a)
    handle_b, has_kwargs_b = _register_forward_hook(layer_b, hook_b)
    supports_kwargs = has_kwargs_a and has_kwargs_b
    score_sum = 0.0
    token_count = 0.0
    num_batches = 0
    try:
        model.eval()
        for batch_idx, batch in enumerate(dataloader):
            if max_batches and batch_idx >= max_batches:
                break
            cache["a"] = None
            cache["b"] = None
            input_ids = batch[0].to(device)
            attention_mask = batch[1].to(device) if len(batch) > 1 else None
            with torch.no_grad():
                model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    use_cache=False,
                )
            hidden_a = cache["a"]
            hidden_b = cache["b"]
            if hidden_a is None or hidden_b is None:
                raise RuntimeError(
                    "Auto selection hooks failed to capture outputs for cosine scoring."
                )
            with torch.no_grad():
                a = hidden_a.float()
                b = hidden_b.float()
                cos = F.cosine_similarity(a, b, dim=-1, eps=eps)
                distance = 1.0 - cos
                if attention_mask is not None:
                    mask = attention_mask.to(dtype=torch.float32)
                    batch_tokens = float(mask.sum().item())
                    distance = distance * mask
                else:
                    batch_tokens = float(distance.numel())
                token_count += batch_tokens
                score_sum += float(distance.sum().item())
            num_batches += 1
    finally:
        # Remove hooks even if a batch raised, so the model is left clean.
        handle_a.remove()
        handle_b.remove()
    denom = token_count if token_count > 0 else 1.0
    score = score_sum / denom
    meta = {
        "num_batches": num_batches,
        "token_count": token_count,
        "metric": "cosine",
        "supports_kwargs": supports_kwargs,
    }
    return score, meta

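# Worked example (hedged): if layer_a's and layer_b's outputs for a token have
# cosine similarity 0.98, that token contributes 1 - 0.98 = 0.02 to score_sum;
# averaging over non-padded tokens gives the pair's score, so a low score marks
# adjacent layers whose outputs point in nearly the same direction.
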
def _compute_global_rel_change_for_pair(
    model,
    layers: List[torch.nn.Module],
    pair_idx: int,
    dataloader,
    args,
    max_batches: int,
    eps: float,
) -> Tuple[float, Dict[str, object]]:
    """Score a pair by how much fusing it perturbs the pair-output/final-state relation."""
    hidden_size = _get_hidden_size(model)
    head_permute_select = not bool(getattr(args, "no_head_permute_select", False))
    layer_a = layers[pair_idx]
    layer_b = layers[pair_idx + 1]
    fused_layer, fuse_priors = _build_fused_layer_for_pair(
        model,
        layer_a,
        layer_b,
        dataloader,
        device=args.device,
        fisher_mode=args.fisher_mode,
        eps=eps,
        hidden_size=hidden_size,
        enable_head_permute=head_permute_select,
    )
    fused_layer.to(args.device)
    fused_layer.eval()
    parent, name, container = find_layer_container(model, getattr(args, "layer_path", None))
    if len(list(container)) != len(layers):
        raise RuntimeError("Layer container changed during auto-selection; aborting rerank.")
    # Build a virtual stack in which the pair is replaced by its fused copy.
    virtual_layers = list(layers)
    virtual_layers[pair_idx] = fused_layer
    del virtual_layers[pair_idx + 1]
    if isinstance(container, torch.nn.ModuleList):
        virtual_container = torch.nn.ModuleList(virtual_layers)
    elif isinstance(container, list):
        virtual_container = virtual_layers
    else:
        raise TypeError("Layer container must be ModuleList or list")
    teacher_cache = {"pair": None, "final": None}
    supports_kwargs = True

    def hook_pair(_module, _inputs, output, _kwargs=None):
        hidden = _extract_hidden(output)
        if hidden is None:
            raise RuntimeError("Failed to extract pair output for global relation rerank.")
        teacher_cache["pair"] = hidden
        return output

    handle_pair, has_kwargs_pair = _register_forward_hook(layer_b, hook_pair)
    supports_kwargs = supports_kwargs and has_kwargs_pair
    score_sum = 0.0
    token_count = 0.0
    num_batches = 0
    try:
        model.eval()
        for batch_idx, batch in enumerate(dataloader):
            if max_batches and batch_idx >= max_batches:
                break
            teacher_cache["pair"] = None
            input_ids = batch[0].to(args.device)
            attention_mask = batch[1].to(args.device) if len(batch) > 1 else None
            with torch.no_grad():
                teacher_outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    use_cache=False,
                )
            teacher_hidden_states = getattr(teacher_outputs, "hidden_states", None)
            if not teacher_hidden_states:
                raise RuntimeError("Teacher forward did not return hidden_states.")
            teacher_final = teacher_hidden_states[-1]
            teacher_pair = teacher_cache["pair"]
            if teacher_pair is None or teacher_final is None:
                raise RuntimeError(
                    "Failed to capture teacher pair/final hidden states for global rerank."
                )
            with torch.no_grad(), _temporary_layers(parent, name, virtual_container):
                fused_outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    use_cache=False,
                )
            fused_hidden_states = getattr(fused_outputs, "hidden_states", None)
            if not fused_hidden_states:
                raise RuntimeError("Fused forward did not return hidden_states.")
            fused_final = fused_hidden_states[-1]
            if fused_final is None:
                raise RuntimeError(
                    "Failed to capture fused final hidden state for global rerank."
                )
            with torch.no_grad():
                teacher_pair_f = teacher_pair.float()
                teacher_final_f = teacher_final.float()
                fused_final_f = fused_final.float()
                teacher_rel = F.cosine_similarity(
                    teacher_pair_f, teacher_final_f, dim=-1, eps=eps
                )
                fused_rel = F.cosine_similarity(
                    teacher_pair_f, fused_final_f, dim=-1, eps=eps
                )
                rel_change = (teacher_rel - fused_rel).abs()
                if attention_mask is not None:
                    mask = attention_mask.to(dtype=torch.float32)
                    batch_tokens = float(mask.sum().item())
                    rel_change = rel_change * mask
                else:
                    batch_tokens = float(rel_change.numel())
                token_count += batch_tokens
                score_sum += float(rel_change.sum().item())
            num_batches += 1
    finally:
        handle_pair.remove()
    del fused_layer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    denom = token_count if token_count > 0 else 1.0
    score = score_sum / denom
    meta = {
        "num_batches": num_batches,
        "token_count": token_count,
        "metric": "global_rel_change",
        "supports_kwargs": supports_kwargs,
        "fuse_priors": fuse_priors,
    }
    return score, meta

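# Metric intuition (from the code above): for every token we compare
# cos(teacher_pair, teacher_final) against cos(teacher_pair, fused_final). The
# mean absolute change measures how much replacing the pair with its fused copy
# disturbs the relation between the pair's output and the model's final hidden
# state, rather than just the local output delta the DWCE metrics look at.
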
def select_layer_auto(
    model,
    layers: List[torch.nn.Module],
    dataloader,
    args,
    previous_scores: Optional[List[float]] = None,
    start_index: int = 0,
    exclude_pairs: Optional[Set[int]] = None,
) -> Tuple[int, List[float], Dict[str, object]]:
    """Score all adjacent pairs with the configured metric and return the best one."""
    num_layers = len(layers)
    if num_layers < 2:
        raise SystemExit("Model must have at least 2 layers for auto selection.")
    hidden_size = _get_hidden_size(model)
    num_pairs = num_layers - 1
    scores: List[float] = [float("inf")] * num_pairs
    meta_per_pair: List[Optional[Dict[str, object]]] = [None] * num_pairs
    supports_kwargs_all = True
    head_permute_select = not bool(getattr(args, "no_head_permute_select", False))
    exclude_set: Set[int] = {
        int(idx)
        for idx in (exclude_pairs or set())
        if isinstance(idx, int) and 0 <= int(idx) < num_pairs
    }
    max_batches = args.auto_max_batches
    start_index = max(0, min(start_index, num_pairs))
    auto_metric = str(getattr(args, "auto_metric", "dwce")).strip().lower()
    if auto_metric == "hybrid":
        auto_metric = "hybrid_cosine"
    if auto_metric not in {
        "dwce",
        "cosine",
        "hybrid_cosine",
        "hybrid_global_rel",
    }:
        raise SystemExit(
            "--auto_metric must be one of: dwce, cosine, hybrid, "
            "hybrid_cosine, hybrid_global_rel"
        )
    auto_cosine_topk = int(getattr(args, "auto_cosine_topk", 3))
    if auto_cosine_topk <= 0:
        raise SystemExit("--auto_cosine_topk must be >= 1")
    print(
        f"[auto] metric={auto_metric}; using "
        f"{('all' if max_batches == 0 else max_batches)} batches "
        "from calibration samples."
    )
    # Previous scores are only reusable for pure DWCE, whose per-pair scores do
    # not depend on which other pairs end up shortlisted.
    reuse_upto = 0
    allow_reuse = auto_metric == "dwce"
    if previous_scores:
        reuse_upto = min(start_index, len(previous_scores), num_pairs) if allow_reuse else 0
    for idx in range(reuse_upto):
        if idx in exclude_set:
            scores[idx] = float("inf")
            meta_per_pair[idx] = {"excluded": True}
            print(f"[auto] skipped excluded pair {idx}-{idx+1}.")
            continue
        scores[idx] = previous_scores[idx]
        meta_per_pair[idx] = {
            "num_batches": 0,
            "token_count": 0.0,
            "norm": args.auto_norm,
            "metric": auto_metric,
            "supports_kwargs": True,
            "reused": True,
        }
        print(f"[auto] reused pair {idx}-{idx+1}: {scores[idx]:.6e}")
    compute_start = start_index if reuse_upto == start_index else reuse_upto
    pairs_to_score: List[int] = []
    for idx in range(compute_start, num_pairs):
        if idx in exclude_set:
            scores[idx] = float("inf")
            meta_per_pair[idx] = {"excluded": True}
            print(f"[auto] skipped excluded pair {idx}-{idx+1}.")
            continue
        pairs_to_score.append(idx)

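    # Hedged usage sketch: a hypothetical driver script would typically call
    #     best_idx, scores, meta = select_layer_auto(model, layers, loader, args)
    # and fuse the pair (best_idx, best_idx + 1); `scores` can be passed back in
    # as `previous_scores` on the next round to skip re-scoring earlier pairs.
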
    def _score_dwce_for_pair(idx: int) -> Tuple[float, Dict[str, object]]:
        print(f"[auto] building fused pair {idx}-{idx+1} for DWCE...")
        layer_a = layers[idx]
        layer_b = layers[idx + 1]
        dwce_mode = str(getattr(args, "auto_dwce_mode", "separate")).strip().lower()
        if dwce_mode == "shared":
            try:
                return _score_dwce_with_shared_backward(
                    model,
                    layer_a,
                    layer_b,
                    dataloader,
                    device=args.device,
                    fisher_mode=args.fisher_mode,
                    max_batches=max_batches,
                    eps=args.eps,
                    norm=args.auto_norm,
                    hidden_size=hidden_size,
                    enable_head_permute=head_permute_select,
                )
            except _DwceGradCacheOverflow:
                print(
                    "[auto] shared-backward DWCE cache exceeded budget; "
                    "falling back to separate mode."
                )
        fused, fuse_priors = _build_fused_layer_for_pair(
            model,
            layer_a,
            layer_b,
            dataloader,
            device=args.device,
            fisher_mode=args.fisher_mode,
            eps=args.eps,
            hidden_size=hidden_size,
            enable_head_permute=head_permute_select,
        )
        fused.to(args.device)
        fused.eval()
        for param in model.parameters():
            param.requires_grad_(True)
        score, meta = _compute_dwce_for_pair(
            model,
            layer_a,
            layer_b,
            fused,
            dataloader,
            device=args.device,
            max_batches=max_batches,
            eps=args.eps,
            norm=args.auto_norm,
        )
        meta["fuse_priors"] = fuse_priors
        meta["metric"] = "dwce"
        del fused
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return score, meta

    def _score_cosine_for_pair(idx: int) -> Tuple[float, Dict[str, object]]:
        print(f"[auto] scoring cosine for pair {idx}-{idx+1}...")
        layer_a = layers[idx]
        layer_b = layers[idx + 1]
        return _compute_cosine_for_pair(
            model,
            layer_a,
            layer_b,
            dataloader,
            device=args.device,
            max_batches=max_batches,
            eps=args.eps,
        )

    def _score_global_rel_for_pair(idx: int) -> Tuple[float, Dict[str, object]]:
        print(f"[auto] scoring global relation change for pair {idx}-{idx+1}...")
        return _compute_global_rel_change_for_pair(
            model,
            layers,
            idx,
            dataloader,
            args=args,
            max_batches=max_batches,
            eps=args.eps,
        )

    if auto_metric in {"dwce", "cosine"}:
        for idx in pairs_to_score:
            if auto_metric == "dwce":
                score, meta = _score_dwce_for_pair(idx)
            else:
                score, meta = _score_cosine_for_pair(idx)
            supports_kwargs_all = supports_kwargs_all and meta.get("supports_kwargs", True)
            scores[idx] = score
            meta_per_pair[idx] = meta
            print(f"[auto] {auto_metric} pair {idx}-{idx+1}: {score:.6e}")
    else:
        # Hybrid metrics: cheap DWCE prefilter over all pairs, then an expensive
        # rerank over the shortlist below.
        dwce_prefilter: Dict[int, float] = {}
        for idx in pairs_to_score:
            score, meta = _score_dwce_for_pair(idx)
            dwce_prefilter[idx] = score
            supports_kwargs_all = supports_kwargs_all and meta.get("supports_kwargs", True)
            meta_per_pair[idx] = {
                "prefilter_dwce": score,
                "dwce_meta": meta,
                "metric": "hybrid",
            }
            print(f"[auto] hybrid prefilter DWCE pair {idx}-{idx+1}: {score:.6e}")
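        # Worked example (hedged): with prefilter scores {0: 2e-3, 1: 5e-4, 2: 1e-3}
        # and auto_cosine_topk=2, `ranked` is [1, 2, 0] and the shortlist [1, 2];
        # only those two pairs pay for the expensive rerank metric.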
        ranked = sorted(pairs_to_score, key=lambda i: float(dwce_prefilter[i]))
        shortlist = ranked[: min(auto_cosine_topk, len(ranked))]
        print(f"[auto] hybrid shortlist (dwce top-{len(shortlist)}): {shortlist}")
        for idx in shortlist:
            if auto_metric == "hybrid_global_rel":
                score, rerank_meta = _score_global_rel_for_pair(idx)
                score_metric = "global_rel_change"
            else:
                score, rerank_meta = _score_cosine_for_pair(idx)
                score_metric = "cosine"
            supports_kwargs_all = supports_kwargs_all and rerank_meta.get(
                "supports_kwargs", True
            )
            scores[idx] = score
            pair_meta = meta_per_pair[idx] or {}
            pair_meta["rerank_meta"] = rerank_meta
            pair_meta["score_metric"] = score_metric
            meta_per_pair[idx] = pair_meta
            print(f"[auto] hybrid {score_metric} pair {idx}-{idx+1}: {score:.6e}")
    if not supports_kwargs_all:
        print(
            "[auto] Warning: forward hooks did not capture kwargs; "
            "fused-layer calls may be approximate."
        )
    print(f"[auto] score summary (metric={auto_metric}, norm={args.auto_norm}):")
    for idx, score in enumerate(scores):
        if idx in exclude_set:
            print(f"[auto] pair {idx}-{idx+1}: excluded")
        elif math.isfinite(float(score)):
            print(f"[auto] pair {idx}-{idx+1}: {score:.6e}")
        else:
            print(f"[auto] pair {idx}-{idx+1}: {score}")
    candidates = [i for i in range(num_pairs) if i not in exclude_set]
    if not candidates:
        raise SystemExit("All pairs are excluded; cannot auto-select a fusion layer.")
    best_idx = min(candidates, key=lambda i: scores[i])
    best_score = float(scores[best_idx])
    if not math.isfinite(best_score):
        raise SystemExit(
            "Auto selection failed: all candidate pairs have non-finite scores "
            "(check --exclude_pairs and data)."
        )
    print(f"[auto] Selected layer {best_idx} (score={best_score:.6e})")
    meta = {
        "per_pair": meta_per_pair,
        "supports_kwargs": supports_kwargs_all,
        "max_batches": max_batches,
        "norm": args.auto_norm,
        "metric": auto_metric,
        "cosine_topk": auto_cosine_topk,
        "start_index": start_index,
        "excluded_pairs": sorted(exclude_set),
    }
    return best_idx, scores, meta