"""Automatic adjacent-pair selection via configurable scoring metrics."""

import copy
import math
from contextlib import contextmanager
from typing import Dict, List, Optional, Set, Tuple

import torch
import torch.nn.functional as F

from fuse_layers_model import (
    build_head_permutation,
    compute_fisher,
    compute_head_means,
    find_attention_module,
    find_layer_container,
    merge_layers,
    permute_attention_heads,
)

_DWCE_GRAD_CACHE_MAX_BYTES = 1 << 30


class _DwceGradCacheOverflow(RuntimeError):
    """Raised when shared-backward DWCE caching exceeds the configured budget."""


def _get_hidden_size(model) -> int:
    """Return the hidden size from `hidden_size` or the GPT-2-style `n_embd`."""
    hidden_size = getattr(model.config, "hidden_size", None)
    if hidden_size is None:
        hidden_size = getattr(model.config, "n_embd", None)
    if hidden_size is None:
        raise SystemExit("Model config missing hidden_size/n_embd")
    return int(hidden_size)


def _detach_arg(arg):
    """Recursively detach tensors nested inside lists, tuples, and dicts."""
    if torch.is_tensor(arg):
        return arg.detach()
    if isinstance(arg, (list, tuple)):
        return type(arg)(_detach_arg(x) for x in arg)
    if isinstance(arg, dict):
        return {k: _detach_arg(v) for k, v in arg.items()}
    return arg


def _register_forward_hook(layer, hook):
    """Register `hook` with kwargs support when available; return (handle, with_kwargs)."""
    try:
        def wrapper(module, inputs, kwargs, output):
            return hook(module, inputs, output, kwargs)

        handle = layer.register_forward_hook(wrapper, with_kwargs=True)
        return handle, True
    except TypeError:
        def wrapper(module, inputs, output):
            return hook(module, inputs, output, None)

        handle = layer.register_forward_hook(wrapper)
        return handle, False
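

# Illustrative sketch (not used by the selection pipeline): a no-op probe
# registered through the helper above. `_example_probe_hook` is a hypothetical
# name kept only as documentation of the hook signature both paths expect.
def _example_probe_hook(layer: torch.nn.Module) -> bool:
    """Illustrative only: attach and remove a no-op probe; return kwargs support."""
    def probe(_module, _inputs, output, _kwargs=None):
        # The hook always receives (module, inputs, output, kwargs-or-None).
        return output

    handle, saw_kwargs = _register_forward_hook(layer, probe)
    handle.remove()
    return saw_kwargs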


@contextmanager
def _temporary_layers(parent: object, name: str, new_layers: object):
    original = getattr(parent, name)
    setattr(parent, name, new_layers)
    try:
        yield
    finally:
        setattr(parent, name, original)
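

# Usage sketch for `_temporary_layers` (illustrative; `model`, `input_ids`, and
# `layer_path` are assumed to exist in the calling code, not in this module):
#
#     parent, name, container = find_layer_container(model, layer_path)
#     shorter = torch.nn.ModuleList(list(container)[:-1])
#     with _temporary_layers(parent, name, shorter):
#         model(input_ids=input_ids, use_cache=False)  # forward with one layer dropped
#     # the original container is restored here, even if the forward raised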


def _extract_hidden(output):
    if torch.is_tensor(output):
        return output
    if isinstance(output, (tuple, list)):
        if output and all(torch.is_tensor(item) for item in output):
            return output[0]
        for item in output:
            hidden = _extract_hidden(item)
            if hidden is not None:
                return hidden
        return None
    if isinstance(output, dict):
        for key in ("hidden_states", "last_hidden_state", "hidden_state"):
            if key in output:
                value = output[key]
                if isinstance(value, (tuple, list)) and value and all(
                    torch.is_tensor(item) for item in value
                ):
                    return value[-1]
                hidden = _extract_hidden(value)
                if hidden is not None:
                    return hidden
        for value in output.values():
            hidden = _extract_hidden(value)
            if hidden is not None:
                return hidden
        return None
    for attr in ("hidden_states", "last_hidden_state"):
        if hasattr(output, attr):
            value = getattr(output, attr)
            if isinstance(value, (tuple, list)) and value and all(
                torch.is_tensor(item) for item in value
            ):
                return value[-1]
            hidden = _extract_hidden(value)
            if hidden is not None:
                return hidden
    return None
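

# Small runnable illustration of `_extract_hidden` on the container shapes this
# module expects (bare tensor, tuple of tensors, dict with `hidden_states`).
# `_example_extract_hidden` is a hypothetical helper kept only as documentation
# and is never called by the selection code.
def _example_extract_hidden() -> None:
    """Illustrative only: show which element `_extract_hidden` picks."""
    h = torch.zeros(1, 4, 8)
    assert _extract_hidden(h) is h                          # bare tensor
    assert _extract_hidden((h, torch.zeros(1))) is h        # first tensor of a tuple
    assert _extract_hidden({"hidden_states": (h,)}) is h    # last entry of hidden_states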


def _build_fused_layer_for_pair(
    model,
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    dataloader,
    device: str,
    fisher_mode: str,
    eps: float,
    hidden_size: int,
    enable_head_permute: bool = True,
) -> Tuple[torch.nn.Module, Dict[str, float]]:
    """Build a fused copy of `layer_a`/`layer_b` via Fisher-weighted merging.

    Returns the fused layer together with per-parameter fuse priors.
    """
    attn_a = find_attention_module(layer_a)
    attn_b = find_attention_module(layer_b)
    perm = None
    inv_perm = None
    num_heads = None
    num_kv_heads = None
    head_dim = None
    if enable_head_permute:
        mean_a, mean_b, num_heads, num_kv_heads, head_dim = compute_head_means(
            model,
            attn_a,
            attn_b,
            dataloader,
            device,
            hidden_size,
        )
        perm = build_head_permutation(
            mean_a,
            mean_b,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            eps=eps,
        )

    layer_a_copy = copy.deepcopy(layer_a)
    layer_b_copy = copy.deepcopy(layer_b)
    attn_b_copy = find_attention_module(layer_b_copy)
    if perm is not None:
        permute_attention_heads(
            attn_b_copy, perm, num_heads, num_kv_heads, head_dim=head_dim
        )

        inv_perm = [0] * len(perm)
        for idx, mapped in enumerate(perm):
            inv_perm[mapped] = idx

        # Temporarily permute the live layer so Fisher statistics are gathered in
        # the aligned head order; it is restored in the finally block below.
        permute_attention_heads(attn_b, perm, num_heads, num_kv_heads, head_dim=head_dim)
    try:
        fisher_sums, num_batches, param_numels = compute_fisher(
            model,
            layer_a,
            layer_b,
            dataloader,
            fisher_mode=fisher_mode,
            device=device,
        )
    finally:
        if inv_perm is not None:
            permute_attention_heads(
                attn_b, inv_perm, num_heads, num_kv_heads, head_dim=head_dim
            )

    merge_layers(
        layer_a_copy,
        layer_b_copy,
        fisher_sums[0],
        fisher_sums[1],
        num_batches,
        param_numels[0],
        param_numels[1],
        fisher_mode=fisher_mode,
        eps=eps,
    )

    fuse_priors = _compute_fuse_priors(
        layer_a,
        layer_b,
        fisher_sums,
        num_batches,
        param_numels,
        fisher_mode,
        eps,
    )

    layer_a_copy.eval()
    return layer_a_copy, fuse_priors
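

# Worked illustration of the inverse-permutation construction used above and in
# `_score_dwce_with_shared_backward`. `_example_inverse_permutation` is a
# hypothetical helper kept only as documentation.
def _example_inverse_permutation() -> None:
    """Illustrative only: applying a permutation then its inverse is the identity."""
    perm = [2, 0, 1]
    inv_perm = [0] * len(perm)
    for idx, mapped in enumerate(perm):
        inv_perm[mapped] = idx
    assert [perm[i] for i in inv_perm] == [0, 1, 2]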


def _init_fisher_accumulators(
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    fisher_mode: str,
    device: str,
) -> Tuple[List[Dict[str, object]], List[Dict[str, int]]]:
    """Create zeroed Fisher accumulators (per-parameter tensors or scalars) for both layers."""
    fisher_sums: List[Dict[str, object]] = []
    param_numels: List[Dict[str, int]] = []
    for layer in (layer_a, layer_b):
        layer_sums: Dict[str, object] = {}
        layer_numels: Dict[str, int] = {}
        for name, param in layer.named_parameters():
            if not param.requires_grad:
                continue
            if fisher_mode == "param":
                layer_sums[name] = torch.zeros_like(
                    param, dtype=torch.float32, device="cpu"
                )
            else:
                layer_sums[name] = torch.zeros((), dtype=torch.float32, device=device)
            layer_numels[name] = param.numel()
        fisher_sums.append(layer_sums)
        param_numels.append(layer_numels)
    return fisher_sums, param_numels


def _accumulate_fisher_from_grads(
    layer: torch.nn.Module,
    layer_sums: Dict[str, object],
    fisher_mode: str,
) -> None:
    """Add squared parameter gradients of `layer` into the running Fisher sums."""
    for name, param in layer.named_parameters():
        if not param.requires_grad or param.grad is None:
            continue
        grad_sq = param.grad.detach().float().pow(2)
        if fisher_mode == "param":
            layer_sums[name] += grad_sq.cpu()
        else:
            layer_sums[name] += grad_sq.sum()


def _finalize_fisher_sums(
    fisher_sums: List[Dict[str, object]],
    fisher_mode: str,
) -> List[Dict[str, object]]:
    """Convert scalar accumulators to floats; per-parameter tensors are returned as-is."""
    if fisher_mode == "param":
        return fisher_sums

    finalized: List[Dict[str, object]] = []
    for layer_sums in fisher_sums:
        finalized_layer: Dict[str, object] = {}
        for name, value in layer_sums.items():
            if isinstance(value, torch.Tensor):
                finalized_layer[name] = float(value.detach().cpu().item())
            else:
                finalized_layer[name] = float(value)
        finalized.append(finalized_layer)
    return finalized


def _compute_fuse_priors(
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    fisher_sums: List[Dict[str, object]],
    num_batches: int,
    param_numels: List[Dict[str, int]],
    fisher_mode: str,
    eps: float,
) -> Dict[str, float]:
    """Return per-parameter merge weights lambda = F_a / (F_a + F_b), clamped away from 0 and 1."""
    fuse_priors: Dict[str, float] = {}
    params_b = {name: param for name, param in layer_b.named_parameters()}
    clamp_eps = 1e-4
    for name, param_a in layer_a.named_parameters():
        param_b = params_b.get(name)
        if param_b is None or param_b.shape != param_a.shape:
            continue
        if fisher_mode == "param":
            fa = fisher_sums[0][name] / max(num_batches, 1)
            fb = fisher_sums[1][name] / max(num_batches, 1)
            fa_val = float(fa.mean().item()) if isinstance(fa, torch.Tensor) else float(fa)
            fb_val = float(fb.mean().item()) if isinstance(fb, torch.Tensor) else float(fb)
        else:
            fa_val = float(
                fisher_sums[0][name]
                / (max(num_batches, 1) * max(param_numels[0].get(name, 1), 1))
            )
            fb_val = float(
                fisher_sums[1][name]
                / (max(num_batches, 1) * max(param_numels[1].get(name, 1), 1))
            )
        denom = fa_val + fb_val
        lam = 0.5 if denom <= eps else fa_val / (denom + eps)
        fuse_priors[name] = min(max(lam, clamp_eps), 1.0 - clamp_eps)
    return fuse_priors
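

# Tiny numeric illustration of the fuse prior above. `_example_fuse_prior` is a
# hypothetical helper kept only as documentation.
def _example_fuse_prior() -> None:
    """Illustrative only: lambda leans toward the layer with more Fisher mass."""
    fa_val, fb_val, eps, clamp_eps = 3.0, 1.0, 1e-8, 1e-4
    lam = fa_val / (fa_val + fb_val + eps)
    lam = min(max(lam, clamp_eps), 1.0 - clamp_eps)
    assert abs(lam - 0.75) < 1e-6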


def _score_dwce_with_shared_backward(
    model,
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    dataloader,
    device: str,
    fisher_mode: str,
    max_batches: int,
    eps: float,
    norm: str,
    hidden_size: int,
    enable_head_permute: bool = True,
) -> Tuple[float, Dict[str, object]]:
    """Score a pair with DWCE while reusing one backward pass for both Fisher and gradients.

    Phase 1 runs one backward pass per batch, accumulating Fisher statistics
    and caching the squared gradients of the teacher output. Phase 2 replays
    the same batches without gradients to compare the fused layer against the
    teacher.
    """
    attn_a = find_attention_module(layer_a)
    attn_b = find_attention_module(layer_b)
    perm = None
    inv_perm = None
    num_heads = None
    num_kv_heads = None
    head_dim = None
    if enable_head_permute:
        mean_a, mean_b, num_heads, num_kv_heads, head_dim = compute_head_means(
            model,
            attn_a,
            attn_b,
            dataloader,
            device,
            hidden_size,
        )
        perm = build_head_permutation(
            mean_a,
            mean_b,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            eps=eps,
        )

    layer_a_copy = copy.deepcopy(layer_a)
    layer_b_copy = copy.deepcopy(layer_b)
    attn_b_copy = find_attention_module(layer_b_copy)
    if perm is not None:
        permute_attention_heads(
            attn_b_copy, perm, num_heads, num_kv_heads, head_dim=head_dim
        )

        inv_perm = [0] * len(perm)
        for idx, mapped in enumerate(perm):
            inv_perm[mapped] = idx

    cache: Dict[str, Optional[torch.Tensor]] = {"teacher": None}
    grad_sq_cache: List[torch.Tensor] = []
    cached_bytes = 0

    def hook_b(_module, _inputs, output, _kwargs=None):
        teacher_hidden = _extract_hidden(output)
        if teacher_hidden is None:
            raise RuntimeError("Failed to extract teacher hidden state output.")
        cache["teacher"] = teacher_hidden
        if teacher_hidden.requires_grad:
            teacher_hidden.retain_grad()
        return output

    handle_b, _ = _register_forward_hook(layer_b, hook_b)
    for param in model.parameters():
        param.requires_grad_(False)
    for layer in (layer_a, layer_b):
        for param in layer.parameters():
            param.requires_grad_(True)
    fisher_sums, param_numels = _init_fisher_accumulators(
        layer_a, layer_b, fisher_mode, device
    )
    num_batches = 0

    if perm is not None:
        permute_attention_heads(attn_b, perm, num_heads, num_kv_heads, head_dim=head_dim)
    try:
        model.eval()
        for batch_idx, batch in enumerate(dataloader):
            if max_batches and batch_idx >= max_batches:
                break
            cache["teacher"] = None
            input_ids = batch[0].to(device)
            attention_mask = batch[1].to(device) if len(batch) > 1 else None

            model.zero_grad(set_to_none=True)
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids,
            )
            outputs.loss.backward()

            teacher = cache["teacher"]
            grad = None if teacher is None else teacher.grad
            if teacher is None or grad is None:
                raise RuntimeError(
                    "Auto selection hooks failed to capture outputs/gradients. "
                    "Try updating PyTorch or run with --layer <index>."
                )
            grad_sq = grad.detach().pow(2).to(device=device, dtype=torch.float16)
            cached_bytes += grad_sq.numel() * grad_sq.element_size()
            if cached_bytes > _DWCE_GRAD_CACHE_MAX_BYTES:
                raise _DwceGradCacheOverflow(
                    "DWCE grad cache exceeded device-memory budget during shared-backward scoring."
                )
            grad_sq_cache.append(grad_sq)
            _accumulate_fisher_from_grads(layer_a, fisher_sums[0], fisher_mode)
            _accumulate_fisher_from_grads(layer_b, fisher_sums[1], fisher_mode)
            model.zero_grad(set_to_none=True)
            num_batches += 1
    finally:
        handle_b.remove()
        if inv_perm is not None:
            permute_attention_heads(
                attn_b, inv_perm, num_heads, num_kv_heads, head_dim=head_dim
            )
        for param in model.parameters():
            param.requires_grad_(True)

    if num_batches == 0:
        raise RuntimeError("No batches processed; check dataset or text inputs.")

    fisher_sums = _finalize_fisher_sums(fisher_sums, fisher_mode)
    merge_layers(
        layer_a_copy,
        layer_b_copy,
        fisher_sums[0],
        fisher_sums[1],
        num_batches,
        param_numels[0],
        param_numels[1],
        fisher_mode=fisher_mode,
        eps=eps,
    )
    fuse_priors = _compute_fuse_priors(
        layer_a,
        layer_b,
        fisher_sums,
        num_batches,
        param_numels,
        fisher_mode,
        eps,
    )

    fused_layer = layer_a_copy
    fused_layer.eval()
    phase2_cache = {"teacher": None, "fused": None}

    def hook_a(_module, inputs, output, kwargs=None):
        with torch.no_grad():
            detached_inputs = tuple(_detach_arg(arg) for arg in inputs)
            if kwargs:
                detached_kwargs = {k: _detach_arg(v) for k, v in kwargs.items()}
                fused_out = fused_layer(*detached_inputs, **detached_kwargs)
            else:
                fused_out = fused_layer(*detached_inputs)
            fused_hidden = _extract_hidden(fused_out)
            if fused_hidden is None:
                raise RuntimeError("Failed to extract fused hidden state output.")
            phase2_cache["fused"] = fused_hidden
        return output

    def hook_b_eval(_module, _inputs, output, _kwargs=None):
        teacher_hidden = _extract_hidden(output)
        if teacher_hidden is None:
            raise RuntimeError("Failed to extract teacher hidden state output.")
        phase2_cache["teacher"] = teacher_hidden
        return output

    handle_a, has_kwargs_a = _register_forward_hook(layer_a, hook_a)
    handle_b_eval, has_kwargs_b = _register_forward_hook(layer_b, hook_b_eval)
    supports_kwargs = has_kwargs_a and has_kwargs_b

    score_num = 0.0
    score_den = 0.0
    token_count = 0.0
    try:
        model.eval()
        # The replay assumes the dataloader yields batches in the same order as
        # the gradient pass above (i.e. it is not shuffled between iterations).
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= num_batches:
                break
            phase2_cache["teacher"] = None
            phase2_cache["fused"] = None
            input_ids = batch[0].to(device)
            attention_mask = batch[1].to(device) if len(batch) > 1 else None

            with torch.no_grad():
                model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    use_cache=False,
                )

            teacher = phase2_cache["teacher"]
            fused = phase2_cache["fused"]
            if teacher is None or fused is None:
                raise RuntimeError(
                    "Auto selection hooks failed to capture outputs during DWCE replay."
                )
            grad_sq = grad_sq_cache[batch_idx].to(dtype=torch.float32)
            if attention_mask is not None:
                mask = attention_mask.to(dtype=torch.float32).unsqueeze(-1)
                batch_tokens = float(mask.sum().item())
                grad_sq = grad_sq * mask
            else:
                mask = None
                batch_tokens = float(input_ids.numel())
            token_count += batch_tokens

            delta = fused - teacher
            if mask is not None:
                delta = delta * mask
            score_num += (delta.float().pow(2) * grad_sq).sum().item()
            score_den += (teacher.float().pow(2) * grad_sq).sum().item()
    finally:
        handle_a.remove()
        handle_b_eval.remove()

    score = (
        score_num / (score_den + eps)
        if norm == "relative"
        else score_num / max(token_count, 1.0)
    )
    meta = {
        "num_batches": num_batches,
        "token_count": token_count,
        "norm": norm,
        "supports_kwargs": supports_kwargs,
        "fuse_priors": fuse_priors,
        "metric": "dwce",
        "dwce_mode": "shared",
    }
    return score, meta


def _compute_dwce_for_pair(
    model,
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    fused_layer: torch.nn.Module,
    dataloader,
    device: str,
    max_batches: int,
    eps: float,
    norm: str,
) -> Tuple[float, Dict[str, object]]:
    """Score `fused_layer` against the teacher pair, running a backward pass per batch."""
    cache = {"teacher": None, "fused": None}
    supports_kwargs = True

    def hook_a(_module, inputs, output, kwargs=None):
        with torch.no_grad():
            detached_inputs = tuple(_detach_arg(arg) for arg in inputs)
            if kwargs is not None and len(kwargs) > 0:
                detached_kwargs = {k: _detach_arg(v) for k, v in kwargs.items()}
                fused_out = fused_layer(*detached_inputs, **detached_kwargs)
            else:
                fused_out = fused_layer(*detached_inputs)
            fused_hidden = _extract_hidden(fused_out)
            if fused_hidden is None:
                raise RuntimeError("Failed to extract fused hidden state output.")
            cache["fused"] = fused_hidden
        return output

    def hook_b(_module, _inputs, output, _kwargs=None):
        teacher_hidden = _extract_hidden(output)
        if teacher_hidden is None:
            raise RuntimeError("Failed to extract teacher hidden state output.")
        cache["teacher"] = teacher_hidden
        if teacher_hidden.requires_grad:
            teacher_hidden.retain_grad()
        return output

    handle_a, has_kwargs_a = _register_forward_hook(layer_a, hook_a)
    handle_b, has_kwargs_b = _register_forward_hook(layer_b, hook_b)
    supports_kwargs = has_kwargs_a and has_kwargs_b

    score_num = 0.0
    score_den = 0.0
    token_count = 0.0
    num_batches = 0

    model.eval()
    for batch_idx, batch in enumerate(dataloader):
        if max_batches and batch_idx >= max_batches:
            break
        cache["teacher"] = None
        cache["fused"] = None

        input_ids = batch[0].to(device)
        attention_mask = batch[1].to(device) if len(batch) > 1 else None

        model.zero_grad(set_to_none=True)
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=input_ids,
        )
        loss = outputs.loss
        loss.backward()

        teacher = cache["teacher"]
        fused = cache["fused"]
        grad = None if teacher is None else teacher.grad
        if teacher is None or fused is None or grad is None:
            raise RuntimeError(
                "Auto selection hooks failed to capture outputs/gradients. "
                "Try updating PyTorch or run with --layer <index>."
            )
        if not teacher.requires_grad:
            raise RuntimeError(
                "Teacher hidden state does not require grad. "
                "Ensure model parameters require grad for DWCE."
            )

        with torch.no_grad():
            if attention_mask is not None:
                mask = attention_mask.to(dtype=torch.float32).unsqueeze(-1)
                batch_tokens = float(mask.sum().item())
            else:
                mask = None
                batch_tokens = float(input_ids.numel())
            token_count += batch_tokens

            delta = fused - teacher
            grad_sq = grad.pow(2)
            if mask is not None:
                delta = delta * mask
                grad_sq = grad_sq * mask

            score_num += (delta.pow(2) * grad_sq).sum().item()
            score_den += (teacher.pow(2) * grad_sq).sum().item()
        num_batches += 1

    handle_a.remove()
    handle_b.remove()

    if norm == "relative":
        score = score_num / (score_den + eps)
    else:
        denom = token_count if token_count > 0 else 1.0
        score = score_num / denom

    meta = {
        "num_batches": num_batches,
        "token_count": token_count,
        "norm": norm,
        "supports_kwargs": supports_kwargs,
    }
    return score, meta
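

# Both DWCE scorers above compute the same quantity: with h the teacher hidden
# state, h' the fused layer's output on the same inputs, and g = dL/dh the
# language-model loss gradient at the teacher output,
#
#     score_num = sum over tokens/dims of (h' - h)^2 * g^2
#     score_den = sum over tokens/dims of h^2 * g^2
#
# and the reported score is score_num / (score_den + eps) under the "relative"
# norm, or score_num divided by the unmasked token count otherwise. Lower is
# better: it means the fusion barely perturbs the directions the loss is
# sensitive to.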


def _compute_cosine_for_pair(
    model,
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    dataloader,
    device: str,
    max_batches: int,
    eps: float,
) -> Tuple[float, Dict[str, object]]:
    """Score a pair by the token-averaged cosine distance between the two layers' outputs."""
    cache = {"a": None, "b": None}
    supports_kwargs = True

    def hook_a(_module, _inputs, output, _kwargs=None):
        hidden = _extract_hidden(output)
        if hidden is None:
            raise RuntimeError("Failed to extract layer_a hidden state output.")
        cache["a"] = hidden
        return output

    def hook_b(_module, _inputs, output, _kwargs=None):
        hidden = _extract_hidden(output)
        if hidden is None:
            raise RuntimeError("Failed to extract layer_b hidden state output.")
        cache["b"] = hidden
        return output

    handle_a, has_kwargs_a = _register_forward_hook(layer_a, hook_a)
    handle_b, has_kwargs_b = _register_forward_hook(layer_b, hook_b)
    supports_kwargs = has_kwargs_a and has_kwargs_b

    score_sum = 0.0
    token_count = 0.0
    num_batches = 0

    model.eval()
    for batch_idx, batch in enumerate(dataloader):
        if max_batches and batch_idx >= max_batches:
            break
        cache["a"] = None
        cache["b"] = None

        input_ids = batch[0].to(device)
        attention_mask = batch[1].to(device) if len(batch) > 1 else None

        with torch.no_grad():
            model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                use_cache=False,
            )

        hidden_a = cache["a"]
        hidden_b = cache["b"]
        if hidden_a is None or hidden_b is None:
            raise RuntimeError(
                "Auto selection hooks failed to capture outputs for cosine scoring."
            )

        with torch.no_grad():
            a = hidden_a.float()
            b = hidden_b.float()
            cos = F.cosine_similarity(a, b, dim=-1, eps=eps)
            distance = 1.0 - cos

            if attention_mask is not None:
                mask = attention_mask.to(dtype=torch.float32)
                batch_tokens = float(mask.sum().item())
                distance = distance * mask
            else:
                batch_tokens = float(distance.numel())

            token_count += batch_tokens
            score_sum += float(distance.sum().item())
        num_batches += 1

    handle_a.remove()
    handle_b.remove()

    denom = token_count if token_count > 0 else 1.0
    score = score_sum / denom
    meta = {
        "num_batches": num_batches,
        "token_count": token_count,
        "metric": "cosine",
        "supports_kwargs": supports_kwargs,
    }
    return score, meta
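

# Small runnable illustration of the cosine distance used above.
# `_example_cosine_distance` is a hypothetical helper kept only as
# documentation and is never called by the selection code.
def _example_cosine_distance() -> None:
    """Illustrative only: identical hidden states give a cosine distance of zero."""
    h = torch.randn(2, 5, 8)
    cos = F.cosine_similarity(h, h, dim=-1, eps=1e-8)
    assert float((1.0 - cos).abs().max().item()) < 1e-5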


def _compute_global_rel_change_for_pair(
    model,
    layers: List[torch.nn.Module],
    pair_idx: int,
    dataloader,
    args,
    max_batches: int,
    eps: float,
) -> Tuple[float, Dict[str, object]]:
    """Score a pair by how much a virtual fusion shifts the final state's relation to the pair output."""
    hidden_size = _get_hidden_size(model)
    head_permute_select = not bool(getattr(args, "no_head_permute_select", False))
    layer_a = layers[pair_idx]
    layer_b = layers[pair_idx + 1]
    fused_layer, fuse_priors = _build_fused_layer_for_pair(
        model,
        layer_a,
        layer_b,
        dataloader,
        device=args.device,
        fisher_mode=args.fisher_mode,
        eps=eps,
        hidden_size=hidden_size,
        enable_head_permute=head_permute_select,
    )
    fused_layer.to(args.device)
    fused_layer.eval()

    parent, name, container = find_layer_container(model, getattr(args, "layer_path", None))
    if len(list(container)) != len(layers):
        raise RuntimeError("Layer container changed during auto-selection; aborting rerank.")

    virtual_layers = list(layers)
    virtual_layers[pair_idx] = fused_layer
    del virtual_layers[pair_idx + 1]
    if isinstance(container, torch.nn.ModuleList):
        virtual_container = torch.nn.ModuleList(virtual_layers)
    elif isinstance(container, list):
        virtual_container = virtual_layers
    else:
        raise TypeError("Layer container must be ModuleList or list")

    teacher_cache = {"pair": None, "final": None}
    supports_kwargs = True

    def hook_pair(_module, _inputs, output, _kwargs=None):
        hidden = _extract_hidden(output)
        if hidden is None:
            raise RuntimeError("Failed to extract pair output for global relation rerank.")
        teacher_cache["pair"] = hidden
        return output

    handle_pair, has_kwargs_pair = _register_forward_hook(layer_b, hook_pair)
    supports_kwargs = supports_kwargs and has_kwargs_pair

    score_sum = 0.0
    token_count = 0.0
    num_batches = 0

    model.eval()
    for batch_idx, batch in enumerate(dataloader):
        if max_batches and batch_idx >= max_batches:
            break

        teacher_cache["pair"] = None

        input_ids = batch[0].to(args.device)
        attention_mask = batch[1].to(args.device) if len(batch) > 1 else None

        with torch.no_grad():
            teacher_outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                use_cache=False,
            )
            teacher_hidden_states = getattr(teacher_outputs, "hidden_states", None)
            if not teacher_hidden_states:
                raise RuntimeError("Teacher forward did not return hidden_states.")
            teacher_final = teacher_hidden_states[-1]
            teacher_pair = teacher_cache["pair"]

        if teacher_pair is None or teacher_final is None:
            raise RuntimeError(
                "Failed to capture teacher pair/final hidden states for global rerank."
            )

        with torch.no_grad(), _temporary_layers(parent, name, virtual_container):
            fused_outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                use_cache=False,
            )
            fused_hidden_states = getattr(fused_outputs, "hidden_states", None)
            if not fused_hidden_states:
                raise RuntimeError("Fused forward did not return hidden_states.")
            fused_final = fused_hidden_states[-1]

        if fused_final is None:
            raise RuntimeError("Failed to capture fused final hidden state for global rerank.")

        with torch.no_grad():
            teacher_pair_f = teacher_pair.float()
            teacher_final_f = teacher_final.float()
            fused_final_f = fused_final.float()

            teacher_rel = F.cosine_similarity(
                teacher_pair_f, teacher_final_f, dim=-1, eps=eps
            )
            fused_rel = F.cosine_similarity(
                teacher_pair_f, fused_final_f, dim=-1, eps=eps
            )
            rel_change = (teacher_rel - fused_rel).abs()

            if attention_mask is not None:
                mask = attention_mask.to(dtype=torch.float32)
                batch_tokens = float(mask.sum().item())
                rel_change = rel_change * mask
            else:
                batch_tokens = float(rel_change.numel())

            token_count += batch_tokens
            score_sum += float(rel_change.sum().item())
        num_batches += 1

    handle_pair.remove()
    del fused_layer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    denom = token_count if token_count > 0 else 1.0
    score = score_sum / denom
    meta = {
        "num_batches": num_batches,
        "token_count": token_count,
        "metric": "global_rel_change",
        "supports_kwargs": supports_kwargs,
        "fuse_priors": fuse_priors,
    }
    return score, meta
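

# The global-relation score above compares, per unmasked token,
#
#     | cos(h_pair, h_final_teacher) - cos(h_pair, h_final_fused) |
#
# where h_pair is the teacher output of the second layer of the pair, and the
# two final states come from the unmodified model and from a forward pass with
# the pair virtually replaced by the fused layer. It therefore captures
# downstream drift that a purely local comparison at the pair would miss.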


def select_layer_auto(
    model,
    layers: List[torch.nn.Module],
    dataloader,
    args,
    previous_scores: Optional[List[float]] = None,
    start_index: int = 0,
    exclude_pairs: Optional[Set[int]] = None,
) -> Tuple[int, List[float], Dict[str, object]]:
    """Score every adjacent layer pair with the configured metric and return the best index.

    Returns (best pair index, per-pair scores, selection metadata). Pairs in
    `exclude_pairs` are skipped, and previously computed DWCE scores can be
    reused for indices below `start_index`.
    """
    num_layers = len(layers)
    if num_layers < 2:
        raise SystemExit("Model must have at least 2 layers for auto selection.")

    hidden_size = _get_hidden_size(model)
    num_pairs = num_layers - 1
    scores: List[float] = [float("inf")] * num_pairs
    meta_per_pair: List[Optional[Dict[str, object]]] = [None] * num_pairs
    supports_kwargs_all = True
    head_permute_select = not bool(getattr(args, "no_head_permute_select", False))
    exclude_set: Set[int] = {
        int(idx)
        for idx in (exclude_pairs or set())
        if isinstance(idx, int) and 0 <= int(idx) < num_pairs
    }

    max_batches = args.auto_max_batches
    start_index = max(0, min(start_index, num_pairs))
    auto_metric = str(getattr(args, "auto_metric", "dwce")).strip().lower()
    if auto_metric == "hybrid":
        auto_metric = "hybrid_cosine"
    if auto_metric not in {
        "dwce",
        "cosine",
        "hybrid_cosine",
        "hybrid_global_rel",
    }:
        raise SystemExit(
            "--auto_metric must be one of: dwce, cosine, hybrid, "
            "hybrid_cosine, hybrid_global_rel"
        )
    auto_cosine_topk = int(getattr(args, "auto_cosine_topk", 3))
    if auto_cosine_topk <= 0:
        raise SystemExit("--auto_cosine_topk must be >= 1")
    print(
        f"[auto] metric={auto_metric}; using "
        f"{('all' if max_batches == 0 else max_batches)} batches "
        "from calibration samples."
    )

    reuse_upto = 0
    allow_reuse = auto_metric == "dwce"
    if previous_scores:
        reuse_upto = min(start_index, len(previous_scores), num_pairs) if allow_reuse else 0
    for idx in range(reuse_upto):
        if idx in exclude_set:
            scores[idx] = float("inf")
            meta_per_pair[idx] = {"excluded": True}
            print(f"[auto] skipped excluded pair {idx}-{idx+1}.")
            continue
        scores[idx] = previous_scores[idx]
        meta_per_pair[idx] = {
            "num_batches": 0,
            "token_count": 0.0,
            "norm": args.auto_norm,
            "metric": auto_metric,
            "supports_kwargs": True,
            "reused": True,
        }
        print(f"[auto] reused pair {idx}-{idx+1}: {scores[idx]:.6e}")

    compute_start = start_index if reuse_upto == start_index else reuse_upto
    pairs_to_score: List[int] = []
    for idx in range(compute_start, num_pairs):
        if idx in exclude_set:
            scores[idx] = float("inf")
            meta_per_pair[idx] = {"excluded": True}
            print(f"[auto] skipped excluded pair {idx}-{idx+1}.")
            continue
        pairs_to_score.append(idx)

    def _score_dwce_for_pair(idx: int) -> Tuple[float, Dict[str, object]]:
        print(f"[auto] building fused pair {idx}-{idx+1} for DWCE...")
        layer_a = layers[idx]
        layer_b = layers[idx + 1]
        dwce_mode = str(getattr(args, "auto_dwce_mode", "separate")).strip().lower()
        if dwce_mode == "shared":
            try:
                return _score_dwce_with_shared_backward(
                    model,
                    layer_a,
                    layer_b,
                    dataloader,
                    device=args.device,
                    fisher_mode=args.fisher_mode,
                    max_batches=max_batches,
                    eps=args.eps,
                    norm=args.auto_norm,
                    hidden_size=hidden_size,
                    enable_head_permute=head_permute_select,
                )
            except _DwceGradCacheOverflow:
                print(
                    "[auto] shared-backward DWCE cache exceeded budget; "
                    "falling back to separate mode."
                )
        fused, fuse_priors = _build_fused_layer_for_pair(
            model,
            layer_a,
            layer_b,
            dataloader,
            device=args.device,
            fisher_mode=args.fisher_mode,
            eps=args.eps,
            hidden_size=hidden_size,
            enable_head_permute=head_permute_select,
        )
        fused.to(args.device)
        fused.eval()
        for param in model.parameters():
            param.requires_grad_(True)
        score, meta = _compute_dwce_for_pair(
            model,
            layer_a,
            layer_b,
            fused,
            dataloader,
            device=args.device,
            max_batches=max_batches,
            eps=args.eps,
            norm=args.auto_norm,
        )
        meta["fuse_priors"] = fuse_priors
        meta["metric"] = "dwce"
        del fused
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return score, meta

    def _score_cosine_for_pair(idx: int) -> Tuple[float, Dict[str, object]]:
        print(f"[auto] scoring cosine for pair {idx}-{idx+1}...")
        layer_a = layers[idx]
        layer_b = layers[idx + 1]
        return _compute_cosine_for_pair(
            model,
            layer_a,
            layer_b,
            dataloader,
            device=args.device,
            max_batches=max_batches,
            eps=args.eps,
        )

    def _score_global_rel_for_pair(idx: int) -> Tuple[float, Dict[str, object]]:
        print(f"[auto] scoring global relation change for pair {idx}-{idx+1}...")
        return _compute_global_rel_change_for_pair(
            model,
            layers,
            idx,
            dataloader,
            args=args,
            max_batches=max_batches,
            eps=args.eps,
        )

    if auto_metric in {"dwce", "cosine"}:
        for idx in pairs_to_score:
            if auto_metric == "dwce":
                score, meta = _score_dwce_for_pair(idx)
            else:
                score, meta = _score_cosine_for_pair(idx)
            supports_kwargs_all = supports_kwargs_all and meta.get("supports_kwargs", True)
            scores[idx] = score
            meta_per_pair[idx] = meta
            print(f"[auto] {auto_metric} pair {idx}-{idx+1}: {score:.6e}")
    else:
        dwce_prefilter: Dict[int, float] = {}
        for idx in pairs_to_score:
            score, meta = _score_dwce_for_pair(idx)
            dwce_prefilter[idx] = score
            supports_kwargs_all = supports_kwargs_all and meta.get("supports_kwargs", True)
            meta_per_pair[idx] = {
                "prefilter_dwce": score,
                "dwce_meta": meta,
                "metric": "hybrid",
            }
            print(f"[auto] hybrid prefilter DWCE pair {idx}-{idx+1}: {score:.6e}")
        ranked = sorted(pairs_to_score, key=lambda i: float(dwce_prefilter[i]))
        shortlist = ranked[: min(auto_cosine_topk, len(ranked))]
        print(f"[auto] hybrid shortlist (dwce top-{len(shortlist)}): {shortlist}")
        for idx in shortlist:
            if auto_metric == "hybrid_global_rel":
                score, rerank_meta = _score_global_rel_for_pair(idx)
                score_metric = "global_rel_change"
            else:
                score, rerank_meta = _score_cosine_for_pair(idx)
                score_metric = "cosine"
            supports_kwargs_all = supports_kwargs_all and rerank_meta.get(
                "supports_kwargs", True
            )
            scores[idx] = score
            pair_meta = meta_per_pair[idx] or {}
            pair_meta["rerank_meta"] = rerank_meta
            pair_meta["score_metric"] = score_metric
            meta_per_pair[idx] = pair_meta
            print(f"[auto] hybrid {score_metric} pair {idx}-{idx+1}: {score:.6e}")

    if not supports_kwargs_all:
        print(
            "[auto] Warning: forward hooks did not capture kwargs; "
            "fused-layer calls may be approximate."
        )

    print(f"[auto] score summary (metric={auto_metric}, norm={args.auto_norm}):")
    for idx, score in enumerate(scores):
        if idx in exclude_set:
            print(f"[auto] pair {idx}-{idx+1}: excluded")
        elif math.isfinite(float(score)):
            print(f"[auto] pair {idx}-{idx+1}: {score:.6e}")
        else:
            print(f"[auto] pair {idx}-{idx+1}: {score}")

    candidates = [i for i in range(num_pairs) if i not in exclude_set]
    if not candidates:
        raise SystemExit("All pairs are excluded; cannot auto-select a fusion layer.")
    best_idx = min(candidates, key=lambda i: scores[i])
    best_score = float(scores[best_idx])
    if not math.isfinite(best_score):
        raise SystemExit(
            "Auto selection failed: all candidate pairs have non-finite scores "
            "(check --exclude_pairs and data)."
        )
    print(f"[auto] Selected layer {best_idx} (score={best_score:.6e})")

    meta = {
        "per_pair": meta_per_pair,
        "supports_kwargs": supports_kwargs_all,
        "max_batches": max_batches,
        "norm": args.auto_norm,
        "metric": auto_metric,
        "cosine_topk": auto_cosine_topk,
        "start_index": start_index,
        "excluded_pairs": sorted(exclude_set),
    }
    return best_idx, scores, meta
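

# Usage sketch (illustrative only; `model`, `dataloader`, and the argparse-style
# `args` namespace with fields such as device, fisher_mode, eps, auto_norm,
# auto_max_batches, and auto_metric are assumptions about the calling CLI, not
# part of this module):
#
#     _parent, _name, container = find_layer_container(model, getattr(args, "layer_path", None))
#     layers = list(container)
#     best_idx, scores, meta = select_layer_auto(model, layers, dataloader, args)
#     print(f"fuse layers {best_idx} and {best_idx + 1} (metric={meta['metric']})")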