#!/usr/bin/env python3
"""Automatic adjacent-pair selection via configurable scoring metrics."""
import copy
import math
from contextlib import contextmanager
from typing import Dict, List, Optional, Set, Tuple
import torch
import torch.nn.functional as F
from fuse_layers_model import (
build_head_permutation,
compute_fisher,
compute_head_means,
find_attention_module,
find_layer_container,
merge_layers,
permute_attention_heads,
)
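# Memory budget (1 GiB) for the fp16 squared-gradient cache used by
# shared-backward DWCE scoring; overflow triggers the separate-mode fallback.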
_DWCE_GRAD_CACHE_MAX_BYTES = 1 << 30
class _DwceGradCacheOverflow(RuntimeError):
"""Raised when shared-backward DWCE caching exceeds the configured budget."""
def _get_hidden_size(model) -> int:
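    """Return the hidden size from ``hidden_size`` or the GPT-2-style ``n_embd``."""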
hidden_size = getattr(model.config, "hidden_size", None)
if hidden_size is None:
hidden_size = getattr(model.config, "n_embd", None)
if hidden_size is None:
raise SystemExit("Model config missing hidden_size/n_embd")
return int(hidden_size)
def _detach_arg(arg):
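    """Recursively detach tensors nested in lists, tuples, and dicts."""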
if torch.is_tensor(arg):
return arg.detach()
if isinstance(arg, (list, tuple)):
return type(arg)(_detach_arg(x) for x in arg)
if isinstance(arg, dict):
return {k: _detach_arg(v) for k, v in arg.items()}
return arg
def _register_forward_hook(layer, hook):
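    """Register ``hook`` on ``layer``, passing kwargs through when supported.

    ``register_forward_hook(..., with_kwargs=True)`` requires PyTorch >= 2.0;
    older versions raise TypeError, and we fall back to the positional-only
    form. Returns ``(handle, has_kwargs)``.
    """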
try:
def wrapper(module, inputs, kwargs, output):
return hook(module, inputs, output, kwargs)
handle = layer.register_forward_hook(wrapper, with_kwargs=True)
return handle, True
except TypeError:
def wrapper(module, inputs, output):
return hook(module, inputs, output, None)
handle = layer.register_forward_hook(wrapper)
return handle, False
@contextmanager
def _temporary_layers(parent: object, name: str, new_layers: object):
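    """Swap ``parent.<name>`` for ``new_layers`` within the ``with`` block."""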
original = getattr(parent, name)
setattr(parent, name, new_layers)
try:
yield
finally:
setattr(parent, name, original)
def _extract_hidden(output):
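    """Best-effort extraction of a hidden-state tensor from arbitrary layer output."""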
if torch.is_tensor(output):
return output
if isinstance(output, (tuple, list)):
if output and all(torch.is_tensor(item) for item in output):
return output[0]
for item in output:
hidden = _extract_hidden(item)
if hidden is not None:
return hidden
return None
if isinstance(output, dict):
for key in ("hidden_states", "last_hidden_state", "hidden_state"):
if key in output:
value = output[key]
if isinstance(value, (tuple, list)) and value and all(
torch.is_tensor(item) for item in value
):
return value[-1]
hidden = _extract_hidden(value)
if hidden is not None:
return hidden
for value in output.values():
hidden = _extract_hidden(value)
if hidden is not None:
return hidden
return None
for attr in ("hidden_states", "last_hidden_state"):
if hasattr(output, attr):
value = getattr(output, attr)
if isinstance(value, (tuple, list)) and value and all(
torch.is_tensor(item) for item in value
):
return value[-1]
hidden = _extract_hidden(value)
if hidden is not None:
return hidden
return None
def _build_fused_layer_for_pair(
model,
layer_a: torch.nn.Module,
layer_b: torch.nn.Module,
dataloader,
device: str,
fisher_mode: str,
eps: float,
hidden_size: int,
enable_head_permute: bool = True,
) -> Tuple[torch.nn.Module, Dict[str, float]]:
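    """Build a Fisher-weighted merge of ``layer_a`` and ``layer_b``.

    Optionally permutes ``layer_b``'s attention heads to align them with
    ``layer_a`` before computing Fisher information, merges deep copies of
    the two layers, and returns the fused module together with per-tensor
    mixing priors. The live layers are restored to their original state.
    """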
attn_a = find_attention_module(layer_a)
attn_b = find_attention_module(layer_b)
perm = None
inv_perm = None
num_heads = None
num_kv_heads = None
head_dim = None
if enable_head_permute:
mean_a, mean_b, num_heads, num_kv_heads, head_dim = compute_head_means(
model,
attn_a,
attn_b,
dataloader,
device,
hidden_size,
)
perm = build_head_permutation(
mean_a,
mean_b,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
eps=eps,
)
layer_a_copy = copy.deepcopy(layer_a)
layer_b_copy = copy.deepcopy(layer_b)
attn_b_copy = find_attention_module(layer_b_copy)
if perm is not None:
permute_attention_heads(
attn_b_copy, perm, num_heads, num_kv_heads, head_dim=head_dim
)
inv_perm = [0] * len(perm)
for idx, mapped in enumerate(perm):
inv_perm[mapped] = idx
permute_attention_heads(attn_b, perm, num_heads, num_kv_heads, head_dim=head_dim)
try:
fisher_sums, num_batches, param_numels = compute_fisher(
model,
layer_a,
layer_b,
dataloader,
fisher_mode=fisher_mode,
device=device,
)
finally:
if inv_perm is not None:
permute_attention_heads(
attn_b, inv_perm, num_heads, num_kv_heads, head_dim=head_dim
)
merge_layers(
layer_a_copy,
layer_b_copy,
fisher_sums[0],
fisher_sums[1],
num_batches,
param_numels[0],
param_numels[1],
fisher_mode=fisher_mode,
eps=eps,
)
# Scalar mixing coefficients per parameter tensor; used by pressure redistribution
# to simulate future fusions without running another Fisher pass.
    fuse_priors = _compute_fuse_priors(
        layer_a,
        layer_b,
        fisher_sums,
        num_batches,
        param_numels,
        fisher_mode,
        eps,
    )
layer_a_copy.eval()
return layer_a_copy, fuse_priors
def _init_fisher_accumulators(
layer_a: torch.nn.Module,
layer_b: torch.nn.Module,
fisher_mode: str,
device: str,
) -> Tuple[List[Dict[str, object]], List[Dict[str, int]]]:
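    """Create zeroed Fisher accumulators for both layers.

    ``"param"`` mode keeps a per-parameter float32 tensor on CPU; other modes
    keep a scalar on ``device``. Also records each parameter's numel.
    """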
fisher_sums: List[Dict[str, object]] = []
param_numels: List[Dict[str, int]] = []
for layer in (layer_a, layer_b):
layer_sums: Dict[str, object] = {}
layer_numels: Dict[str, int] = {}
for name, param in layer.named_parameters():
if not param.requires_grad:
continue
if fisher_mode == "param":
layer_sums[name] = torch.zeros_like(
param, dtype=torch.float32, device="cpu"
)
else:
layer_sums[name] = torch.zeros((), dtype=torch.float32, device=device)
layer_numels[name] = param.numel()
fisher_sums.append(layer_sums)
param_numels.append(layer_numels)
return fisher_sums, param_numels
def _accumulate_fisher_from_grads(
layer: torch.nn.Module,
layer_sums: Dict[str, object],
fisher_mode: str,
) -> None:
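    """Add the squared gradients of ``layer``'s parameters into ``layer_sums``."""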
for name, param in layer.named_parameters():
if not param.requires_grad or param.grad is None:
continue
grad_sq = param.grad.detach().float().pow(2)
if fisher_mode == "param":
layer_sums[name] += grad_sq.cpu()
else:
layer_sums[name] += grad_sq.sum()
def _finalize_fisher_sums(
fisher_sums: List[Dict[str, object]],
fisher_mode: str,
) -> List[Dict[str, object]]:
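    """Convert scalar accumulators to Python floats (``"param"`` mode passes through)."""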
if fisher_mode == "param":
return fisher_sums
finalized: List[Dict[str, object]] = []
for layer_sums in fisher_sums:
finalized_layer: Dict[str, object] = {}
for name, value in layer_sums.items():
if isinstance(value, torch.Tensor):
finalized_layer[name] = float(value.detach().cpu().item())
else:
finalized_layer[name] = float(value)
finalized.append(finalized_layer)
return finalized
def _compute_fuse_priors(
layer_a: torch.nn.Module,
layer_b: torch.nn.Module,
fisher_sums: List[Dict[str, object]],
num_batches: int,
param_numels: List[Dict[str, int]],
fisher_mode: str,
eps: float,
) -> Dict[str, float]:
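    """Per-tensor mixing coefficients ``lam = F_a / (F_a + F_b)``, clamped away
    from 0 and 1, for parameters shared (with matching shapes) by both layers."""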
fuse_priors: Dict[str, float] = {}
params_b = {name: param for name, param in layer_b.named_parameters()}
clamp_eps = 1e-4
for name, param_a in layer_a.named_parameters():
param_b = params_b.get(name)
if param_b is None or param_b.shape != param_a.shape:
continue
if fisher_mode == "param":
fa = fisher_sums[0][name] / max(num_batches, 1)
fb = fisher_sums[1][name] / max(num_batches, 1)
fa_val = float(fa.mean().item()) if isinstance(fa, torch.Tensor) else float(fa)
fb_val = float(fb.mean().item()) if isinstance(fb, torch.Tensor) else float(fb)
else:
fa_val = float(
fisher_sums[0][name]
/ (max(num_batches, 1) * max(param_numels[0].get(name, 1), 1))
)
fb_val = float(
fisher_sums[1][name]
/ (max(num_batches, 1) * max(param_numels[1].get(name, 1), 1))
)
denom = fa_val + fb_val
lam = 0.5 if denom <= eps else fa_val / (denom + eps)
fuse_priors[name] = min(max(lam, clamp_eps), 1.0 - clamp_eps)
return fuse_priors
def _score_dwce_with_shared_backward(
model,
layer_a: torch.nn.Module,
layer_b: torch.nn.Module,
dataloader,
device: str,
fisher_mode: str,
max_batches: int,
eps: float,
norm: str,
hidden_size: int,
enable_head_permute: bool = True,
) -> Tuple[float, Dict[str, object]]:
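    """Score a pair with DWCE using one backward pass shared across phases.

    Phase 1 runs a labeled forward/backward per batch, caching the squared
    gradient of the teacher (``layer_b``) hidden state and accumulating
    Fisher statistics for both layers. The Fisher-merged copy is then built
    once, and phase 2 replays the same batches without gradients to compare
    the fused output against the teacher, weighted by the cached gradients.
    Raises ``_DwceGradCacheOverflow`` if the gradient cache exceeds
    ``_DWCE_GRAD_CACHE_MAX_BYTES``.
    """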
attn_a = find_attention_module(layer_a)
attn_b = find_attention_module(layer_b)
perm = None
inv_perm = None
num_heads = None
num_kv_heads = None
head_dim = None
if enable_head_permute:
mean_a, mean_b, num_heads, num_kv_heads, head_dim = compute_head_means(
model,
attn_a,
attn_b,
dataloader,
device,
hidden_size,
)
perm = build_head_permutation(
mean_a,
mean_b,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
eps=eps,
)
layer_a_copy = copy.deepcopy(layer_a)
layer_b_copy = copy.deepcopy(layer_b)
attn_b_copy = find_attention_module(layer_b_copy)
if perm is not None:
permute_attention_heads(
attn_b_copy, perm, num_heads, num_kv_heads, head_dim=head_dim
)
inv_perm = [0] * len(perm)
for idx, mapped in enumerate(perm):
inv_perm[mapped] = idx
cache: Dict[str, Optional[torch.Tensor]] = {"teacher": None}
grad_sq_cache: List[torch.Tensor] = []
cached_bytes = 0
def hook_b(_module, _inputs, output, _kwargs=None):
teacher_hidden = _extract_hidden(output)
if teacher_hidden is None:
raise RuntimeError("Failed to extract teacher hidden state output.")
cache["teacher"] = teacher_hidden
if teacher_hidden.requires_grad:
teacher_hidden.retain_grad()
return output
handle_b, _ = _register_forward_hook(layer_b, hook_b)
for param in model.parameters():
param.requires_grad_(False)
for layer in (layer_a, layer_b):
for param in layer.parameters():
param.requires_grad_(True)
fisher_sums, param_numels = _init_fisher_accumulators(
layer_a, layer_b, fisher_mode, device
)
num_batches = 0
if perm is not None:
permute_attention_heads(attn_b, perm, num_heads, num_kv_heads, head_dim=head_dim)
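    # Phase 1: a labeled forward/backward pass per batch captures the squared
    # gradient of the teacher hidden state and accumulates Fisher statistics
    # for both layers of the pair.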
try:
model.eval()
for batch_idx, batch in enumerate(dataloader):
if max_batches and batch_idx >= max_batches:
break
cache["teacher"] = None
input_ids = batch[0].to(device)
attention_mask = batch[1].to(device) if len(batch) > 1 else None
model.zero_grad(set_to_none=True)
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=input_ids,
)
outputs.loss.backward()
teacher = cache["teacher"]
grad = None if teacher is None else teacher.grad
if teacher is None or grad is None:
raise RuntimeError(
"Auto selection hooks failed to capture outputs/gradients. "
"Try updating PyTorch or run with --layer <index>."
)
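            # Cache squared grads in fp16 to halve memory; overflowing the
            # budget aborts to the caller's separate-mode fallback.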
grad_sq = grad.detach().pow(2).to(device=device, dtype=torch.float16)
cached_bytes += grad_sq.numel() * grad_sq.element_size()
if cached_bytes > _DWCE_GRAD_CACHE_MAX_BYTES:
raise _DwceGradCacheOverflow(
"DWCE grad cache exceeded device-memory budget during shared-backward scoring."
)
grad_sq_cache.append(grad_sq)
_accumulate_fisher_from_grads(layer_a, fisher_sums[0], fisher_mode)
_accumulate_fisher_from_grads(layer_b, fisher_sums[1], fisher_mode)
model.zero_grad(set_to_none=True)
num_batches += 1
finally:
handle_b.remove()
if inv_perm is not None:
permute_attention_heads(
attn_b, inv_perm, num_heads, num_kv_heads, head_dim=head_dim
)
for param in model.parameters():
param.requires_grad_(True)
if num_batches == 0:
raise RuntimeError("No batches processed; check dataset or text inputs.")
fisher_sums = _finalize_fisher_sums(fisher_sums, fisher_mode)
merge_layers(
layer_a_copy,
layer_b_copy,
fisher_sums[0],
fisher_sums[1],
num_batches,
param_numels[0],
param_numels[1],
fisher_mode=fisher_mode,
eps=eps,
)
fuse_priors = _compute_fuse_priors(
layer_a,
layer_b,
fisher_sums,
num_batches,
param_numels,
fisher_mode,
eps,
)
fused_layer = layer_a_copy
fused_layer.eval()
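    # Phase 2: replay the same batches without gradients, comparing the fused
    # layer against the teacher weighted by the cached squared gradients.
    # This assumes the dataloader yields batches in the same order as phase 1
    # (i.e. no shuffling), so grad_sq_cache[batch_idx] stays aligned.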
phase2_cache = {"teacher": None, "fused": None}
def hook_a(_module, inputs, output, kwargs=None):
with torch.no_grad():
detached_inputs = tuple(_detach_arg(arg) for arg in inputs)
if kwargs:
detached_kwargs = {k: _detach_arg(v) for k, v in kwargs.items()}
fused_out = fused_layer(*detached_inputs, **detached_kwargs)
else:
fused_out = fused_layer(*detached_inputs)
fused_hidden = _extract_hidden(fused_out)
if fused_hidden is None:
raise RuntimeError("Failed to extract fused hidden state output.")
phase2_cache["fused"] = fused_hidden
return output
def hook_b_eval(_module, _inputs, output, _kwargs=None):
teacher_hidden = _extract_hidden(output)
if teacher_hidden is None:
raise RuntimeError("Failed to extract teacher hidden state output.")
phase2_cache["teacher"] = teacher_hidden
return output
handle_a, has_kwargs_a = _register_forward_hook(layer_a, hook_a)
handle_b_eval, has_kwargs_b = _register_forward_hook(layer_b, hook_b_eval)
supports_kwargs = has_kwargs_a and has_kwargs_b
score_num = 0.0
score_den = 0.0
token_count = 0.0
try:
model.eval()
for batch_idx, batch in enumerate(dataloader):
if batch_idx >= num_batches:
break
phase2_cache["teacher"] = None
phase2_cache["fused"] = None
input_ids = batch[0].to(device)
attention_mask = batch[1].to(device) if len(batch) > 1 else None
with torch.no_grad():
model(
input_ids=input_ids,
attention_mask=attention_mask,
use_cache=False,
)
teacher = phase2_cache["teacher"]
fused = phase2_cache["fused"]
if teacher is None or fused is None:
raise RuntimeError(
"Auto selection hooks failed to capture outputs during DWCE replay."
)
grad_sq = grad_sq_cache[batch_idx].to(dtype=torch.float32)
if attention_mask is not None:
mask = attention_mask.to(dtype=torch.float32).unsqueeze(-1)
batch_tokens = float(mask.sum().item())
grad_sq = grad_sq * mask
else:
mask = None
batch_tokens = float(input_ids.numel())
token_count += batch_tokens
delta = fused - teacher
if mask is not None:
delta = delta * mask
score_num += (delta.float().pow(2) * grad_sq).sum().item()
score_den += (teacher.float().pow(2) * grad_sq).sum().item()
finally:
handle_a.remove()
handle_b_eval.remove()
score = (
score_num / (score_den + eps)
if norm == "relative"
else score_num / max(token_count, 1.0)
)
meta = {
"num_batches": num_batches,
"token_count": token_count,
"norm": norm,
"supports_kwargs": supports_kwargs,
"fuse_priors": fuse_priors,
"metric": "dwce",
"dwce_mode": "shared",
}
return score, meta
def _compute_dwce_for_pair(
model,
layer_a: torch.nn.Module,
layer_b: torch.nn.Module,
fused_layer: torch.nn.Module,
dataloader,
device: str,
max_batches: int,
eps: float,
norm: str,
) -> Tuple[float, Dict[str, object]]:
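    """Score a fused layer with DWCE against the live (teacher) pair.

    For each batch, hooks capture the fused layer's output (recomputed from
    ``layer_a``'s inputs) and the teacher output of ``layer_b`` along with
    its gradient from a labeled backward pass. The score is
    ``sum((fused - teacher)^2 * grad^2)``, normalized either by
    ``sum(teacher^2 * grad^2)`` ("relative") or by the token count.
    """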
cache = {"teacher": None, "fused": None}
supports_kwargs = True
def hook_a(_module, inputs, output, kwargs=None):
with torch.no_grad():
detached_inputs = tuple(_detach_arg(arg) for arg in inputs)
if kwargs is not None and len(kwargs) > 0:
detached_kwargs = {k: _detach_arg(v) for k, v in kwargs.items()}
fused_out = fused_layer(*detached_inputs, **detached_kwargs)
else:
fused_out = fused_layer(*detached_inputs)
fused_hidden = _extract_hidden(fused_out)
if fused_hidden is None:
raise RuntimeError("Failed to extract fused hidden state output.")
cache["fused"] = fused_hidden
return output
def hook_b(_module, _inputs, output, _kwargs=None):
teacher_hidden = _extract_hidden(output)
if teacher_hidden is None:
raise RuntimeError("Failed to extract teacher hidden state output.")
cache["teacher"] = teacher_hidden
if teacher_hidden.requires_grad:
teacher_hidden.retain_grad()
return output
handle_a, has_kwargs_a = _register_forward_hook(layer_a, hook_a)
handle_b, has_kwargs_b = _register_forward_hook(layer_b, hook_b)
supports_kwargs = has_kwargs_a and has_kwargs_b
score_num = 0.0
score_den = 0.0
token_count = 0.0
num_batches = 0
    model.eval()
    try:
        for batch_idx, batch in enumerate(dataloader):
            if max_batches and batch_idx >= max_batches:
                break
            cache["teacher"] = None
            cache["fused"] = None
            input_ids = batch[0].to(device)
            attention_mask = batch[1].to(device) if len(batch) > 1 else None
            model.zero_grad(set_to_none=True)
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids,
            )
            loss = outputs.loss
            loss.backward()
            teacher = cache["teacher"]
            fused = cache["fused"]
            grad = None if teacher is None else teacher.grad
            if teacher is None or fused is None or grad is None:
                raise RuntimeError(
                    "Auto selection hooks failed to capture outputs/gradients. "
                    "Try updating PyTorch or run with --layer <index>."
                )
            with torch.no_grad():
                if attention_mask is not None:
                    mask = attention_mask.to(dtype=torch.float32).unsqueeze(-1)
                    batch_tokens = float(mask.sum().item())
                else:
                    mask = None
                    batch_tokens = float(input_ids.numel())
                token_count += batch_tokens
                delta = fused - teacher
                grad_sq = grad.pow(2)
                if mask is not None:
                    delta = delta * mask
                    grad_sq = grad_sq * mask
                score_num += (delta.pow(2) * grad_sq).sum().item()
                score_den += (teacher.pow(2) * grad_sq).sum().item()
            num_batches += 1
    finally:
        # Always detach hooks, even if a batch fails mid-loop.
        handle_a.remove()
        handle_b.remove()
if norm == "relative":
score = score_num / (score_den + eps)
else:
denom = token_count if token_count > 0 else 1.0
score = score_num / denom
meta = {
"num_batches": num_batches,
"token_count": token_count,
"norm": norm,
"supports_kwargs": supports_kwargs,
}
return score, meta
def _compute_cosine_for_pair(
model,
layer_a: torch.nn.Module,
layer_b: torch.nn.Module,
dataloader,
device: str,
max_batches: int,
eps: float,
) -> Tuple[float, Dict[str, object]]:
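    """Mean per-token cosine distance ``1 - cos`` between the outputs of
    ``layer_a`` and ``layer_b``."""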
cache = {"a": None, "b": None}
supports_kwargs = True
def hook_a(_module, _inputs, output, _kwargs=None):
hidden = _extract_hidden(output)
if hidden is None:
raise RuntimeError("Failed to extract layer_a hidden state output.")
cache["a"] = hidden
return output
def hook_b(_module, _inputs, output, _kwargs=None):
hidden = _extract_hidden(output)
if hidden is None:
raise RuntimeError("Failed to extract layer_b hidden state output.")
cache["b"] = hidden
return output
handle_a, has_kwargs_a = _register_forward_hook(layer_a, hook_a)
handle_b, has_kwargs_b = _register_forward_hook(layer_b, hook_b)
supports_kwargs = has_kwargs_a and has_kwargs_b
score_sum = 0.0
token_count = 0.0
num_batches = 0
    model.eval()
    try:
        for batch_idx, batch in enumerate(dataloader):
            if max_batches and batch_idx >= max_batches:
                break
            cache["a"] = None
            cache["b"] = None
            input_ids = batch[0].to(device)
            attention_mask = batch[1].to(device) if len(batch) > 1 else None
            with torch.no_grad():
                model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    use_cache=False,
                )
            hidden_a = cache["a"]
            hidden_b = cache["b"]
            if hidden_a is None or hidden_b is None:
                raise RuntimeError(
                    "Auto selection hooks failed to capture outputs for cosine scoring."
                )
            with torch.no_grad():
                a = hidden_a.float()
                b = hidden_b.float()
                cos = F.cosine_similarity(a, b, dim=-1, eps=eps)
                distance = 1.0 - cos
                if attention_mask is not None:
                    mask = attention_mask.to(dtype=torch.float32)
                    batch_tokens = float(mask.sum().item())
                    distance = distance * mask
                else:
                    batch_tokens = float(distance.numel())
                token_count += batch_tokens
                score_sum += float(distance.sum().item())
            num_batches += 1
    finally:
        # Always detach hooks, even if a batch fails mid-loop.
        handle_a.remove()
        handle_b.remove()
denom = token_count if token_count > 0 else 1.0
score = score_sum / denom
meta = {
"num_batches": num_batches,
"token_count": token_count,
"metric": "cosine",
"supports_kwargs": supports_kwargs,
}
return score, meta
def _compute_global_rel_change_for_pair(
model,
layers: List[torch.nn.Module],
pair_idx: int,
dataloader,
args,
max_batches: int,
eps: float,
) -> Tuple[float, Dict[str, object]]:
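    """Score a pair by how much fusing it perturbs the final hidden state.

    Builds the fused layer, swaps it into a virtual copy of the layer stack,
    and measures the mean absolute change in cosine similarity between the
    pair's teacher output and the model's final hidden state with and
    without the fusion.
    """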
hidden_size = _get_hidden_size(model)
head_permute_select = not bool(getattr(args, "no_head_permute_select", False))
layer_a = layers[pair_idx]
layer_b = layers[pair_idx + 1]
fused_layer, fuse_priors = _build_fused_layer_for_pair(
model,
layer_a,
layer_b,
dataloader,
device=args.device,
fisher_mode=args.fisher_mode,
eps=eps,
hidden_size=hidden_size,
enable_head_permute=head_permute_select,
)
fused_layer.to(args.device)
fused_layer.eval()
parent, name, container = find_layer_container(model, getattr(args, "layer_path", None))
if len(list(container)) != len(layers):
raise RuntimeError("Layer container changed during auto-selection; aborting rerank.")
virtual_layers = list(layers)
virtual_layers[pair_idx] = fused_layer
del virtual_layers[pair_idx + 1]
if isinstance(container, torch.nn.ModuleList):
virtual_container = torch.nn.ModuleList(virtual_layers)
elif isinstance(container, list):
virtual_container = virtual_layers
else:
raise TypeError("Layer container must be ModuleList or list")
teacher_cache = {"pair": None, "final": None}
supports_kwargs = True
def hook_pair(_module, _inputs, output, _kwargs=None):
hidden = _extract_hidden(output)
if hidden is None:
raise RuntimeError("Failed to extract pair output for global relation rerank.")
teacher_cache["pair"] = hidden
return output
handle_pair, has_kwargs_pair = _register_forward_hook(layer_b, hook_pair)
supports_kwargs = supports_kwargs and has_kwargs_pair
score_sum = 0.0
token_count = 0.0
num_batches = 0
    model.eval()
    try:
        for batch_idx, batch in enumerate(dataloader):
            if max_batches and batch_idx >= max_batches:
                break
            teacher_cache["pair"] = None
            input_ids = batch[0].to(args.device)
            attention_mask = batch[1].to(args.device) if len(batch) > 1 else None
            with torch.no_grad():
                teacher_outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    use_cache=False,
                )
            teacher_hidden_states = getattr(teacher_outputs, "hidden_states", None)
            if not teacher_hidden_states:
                raise RuntimeError("Teacher forward did not return hidden_states.")
            teacher_final = teacher_hidden_states[-1]
            teacher_pair = teacher_cache["pair"]
            if teacher_pair is None or teacher_final is None:
                raise RuntimeError(
                    "Failed to capture teacher pair/final hidden states for global rerank."
                )
            with torch.no_grad(), _temporary_layers(parent, name, virtual_container):
                fused_outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    use_cache=False,
                )
            fused_hidden_states = getattr(fused_outputs, "hidden_states", None)
            if not fused_hidden_states:
                raise RuntimeError("Fused forward did not return hidden_states.")
            fused_final = fused_hidden_states[-1]
            if fused_final is None:
                raise RuntimeError(
                    "Failed to capture fused final hidden state for global rerank."
                )
            with torch.no_grad():
                teacher_pair_f = teacher_pair.float()
                teacher_final_f = teacher_final.float()
                fused_final_f = fused_final.float()
                teacher_rel = F.cosine_similarity(
                    teacher_pair_f, teacher_final_f, dim=-1, eps=eps
                )
                fused_rel = F.cosine_similarity(
                    teacher_pair_f, fused_final_f, dim=-1, eps=eps
                )
                rel_change = (teacher_rel - fused_rel).abs()
                if attention_mask is not None:
                    mask = attention_mask.to(dtype=torch.float32)
                    batch_tokens = float(mask.sum().item())
                    rel_change = rel_change * mask
                else:
                    batch_tokens = float(rel_change.numel())
                token_count += batch_tokens
                score_sum += float(rel_change.sum().item())
            num_batches += 1
    finally:
        # Clean up the hook and the fused copy even if scoring fails mid-loop.
        handle_pair.remove()
        del fused_layer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
denom = token_count if token_count > 0 else 1.0
score = score_sum / denom
meta = {
"num_batches": num_batches,
"token_count": token_count,
"metric": "global_rel_change",
"supports_kwargs": supports_kwargs,
"fuse_priors": fuse_priors,
}
return score, meta
def select_layer_auto(
model,
layers: List[torch.nn.Module],
dataloader,
args,
previous_scores: Optional[List[float]] = None,
start_index: int = 0,
exclude_pairs: Optional[Set[int]] = None,
) -> Tuple[int, List[float], Dict[str, object]]:
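    """Pick the adjacent layer pair whose fusion is expected to be cheapest.

    Scores every candidate pair with the metric from ``args.auto_metric``,
    optionally reusing previous DWCE scores below ``start_index`` and
    skipping pairs in ``exclude_pairs``. Hybrid metrics prefilter with DWCE
    and re-rank the top-k shortlist. Returns ``(best_idx, scores, meta)``.
    """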
num_layers = len(layers)
if num_layers < 2:
raise SystemExit("Model must have at least 2 layers for auto selection.")
hidden_size = _get_hidden_size(model)
num_pairs = num_layers - 1
scores: List[float] = [float("inf")] * num_pairs
meta_per_pair: List[Optional[Dict[str, object]]] = [None] * num_pairs
supports_kwargs_all = True
head_permute_select = not bool(getattr(args, "no_head_permute_select", False))
exclude_set: Set[int] = {
int(idx)
for idx in (exclude_pairs or set())
if isinstance(idx, int) and 0 <= int(idx) < num_pairs
}
max_batches = args.auto_max_batches
start_index = max(0, min(start_index, num_pairs))
auto_metric = str(getattr(args, "auto_metric", "dwce")).strip().lower()
if auto_metric == "hybrid":
auto_metric = "hybrid_cosine"
if auto_metric not in {
"dwce",
"cosine",
"hybrid_cosine",
"hybrid_global_rel",
}:
raise SystemExit(
"--auto_metric must be one of: dwce, cosine, hybrid, "
"hybrid_cosine, hybrid_global_rel"
)
auto_cosine_topk = int(getattr(args, "auto_cosine_topk", 3))
if auto_cosine_topk <= 0:
raise SystemExit("--auto_cosine_topk must be >= 1")
print(
f"[auto] metric={auto_metric}; using "
f"{('all' if max_batches == 0 else max_batches)} batches "
"from calibration samples."
)
reuse_upto = 0
allow_reuse = auto_metric == "dwce"
if previous_scores:
reuse_upto = min(start_index, len(previous_scores), num_pairs) if allow_reuse else 0
for idx in range(reuse_upto):
if idx in exclude_set:
scores[idx] = float("inf")
meta_per_pair[idx] = {"excluded": True}
print(f"[auto] skipped excluded pair {idx}-{idx+1}.")
continue
scores[idx] = previous_scores[idx]
        meta_per_pair[idx] = {
            "num_batches": 0,
            "token_count": 0.0,
            "norm": args.auto_norm,
            "metric": auto_metric,
            "supports_kwargs": True,
            "reused": True,
        }
print(f"[auto] reused pair {idx}-{idx+1}: {scores[idx]:.6e}")
    compute_start = reuse_upto
pairs_to_score: List[int] = []
for idx in range(compute_start, num_pairs):
if idx in exclude_set:
scores[idx] = float("inf")
meta_per_pair[idx] = {"excluded": True}
print(f"[auto] skipped excluded pair {idx}-{idx+1}.")
continue
pairs_to_score.append(idx)
def _score_dwce_for_pair(idx: int) -> Tuple[float, Dict[str, object]]:
print(f"[auto] building fused pair {idx}-{idx+1} for DWCE...")
layer_a = layers[idx]
layer_b = layers[idx + 1]
dwce_mode = str(getattr(args, "auto_dwce_mode", "separate")).strip().lower()
if dwce_mode == "shared":
try:
return _score_dwce_with_shared_backward(
model,
layer_a,
layer_b,
dataloader,
device=args.device,
fisher_mode=args.fisher_mode,
max_batches=max_batches,
eps=args.eps,
norm=args.auto_norm,
hidden_size=hidden_size,
enable_head_permute=head_permute_select,
)
except _DwceGradCacheOverflow:
print(
"[auto] shared-backward DWCE cache exceeded budget; "
"falling back to separate mode."
)
fused, fuse_priors = _build_fused_layer_for_pair(
model,
layer_a,
layer_b,
dataloader,
device=args.device,
fisher_mode=args.fisher_mode,
eps=args.eps,
hidden_size=hidden_size,
enable_head_permute=head_permute_select,
)
fused.to(args.device)
fused.eval()
for param in model.parameters():
param.requires_grad_(True)
score, meta = _compute_dwce_for_pair(
model,
layer_a,
layer_b,
fused,
dataloader,
device=args.device,
max_batches=max_batches,
eps=args.eps,
norm=args.auto_norm,
)
meta["fuse_priors"] = fuse_priors
meta["metric"] = "dwce"
del fused
if torch.cuda.is_available():
torch.cuda.empty_cache()
return score, meta
def _score_cosine_for_pair(idx: int) -> Tuple[float, Dict[str, object]]:
print(f"[auto] scoring cosine for pair {idx}-{idx+1}...")
layer_a = layers[idx]
layer_b = layers[idx + 1]
return _compute_cosine_for_pair(
model,
layer_a,
layer_b,
dataloader,
device=args.device,
max_batches=max_batches,
eps=args.eps,
)
def _score_global_rel_for_pair(idx: int) -> Tuple[float, Dict[str, object]]:
print(f"[auto] scoring global relation change for pair {idx}-{idx+1}...")
return _compute_global_rel_change_for_pair(
model,
layers,
idx,
dataloader,
args=args,
max_batches=max_batches,
eps=args.eps,
)
if auto_metric in {"dwce", "cosine"}:
for idx in pairs_to_score:
if auto_metric == "dwce":
score, meta = _score_dwce_for_pair(idx)
else:
score, meta = _score_cosine_for_pair(idx)
supports_kwargs_all = supports_kwargs_all and meta.get("supports_kwargs", True)
scores[idx] = score
meta_per_pair[idx] = meta
print(f"[auto] {auto_metric} pair {idx}-{idx+1}: {score:.6e}")
else:
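        # Hybrid metrics: run a DWCE pass over all pairs as a cheap prefilter,
        # then re-rank only the top-k shortlist with the more expensive metric.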
dwce_prefilter: Dict[int, float] = {}
for idx in pairs_to_score:
score, meta = _score_dwce_for_pair(idx)
dwce_prefilter[idx] = score
supports_kwargs_all = supports_kwargs_all and meta.get("supports_kwargs", True)
meta_per_pair[idx] = {
"prefilter_dwce": score,
"dwce_meta": meta,
"metric": "hybrid",
}
print(f"[auto] hybrid prefilter DWCE pair {idx}-{idx+1}: {score:.6e}")
ranked = sorted(pairs_to_score, key=lambda i: float(dwce_prefilter[i]))
shortlist = ranked[: min(auto_cosine_topk, len(ranked))]
print(f"[auto] hybrid shortlist (dwce top-{len(shortlist)}): {shortlist}")
for idx in shortlist:
if auto_metric == "hybrid_global_rel":
score, rerank_meta = _score_global_rel_for_pair(idx)
score_metric = "global_rel_change"
else:
score, rerank_meta = _score_cosine_for_pair(idx)
score_metric = "cosine"
supports_kwargs_all = supports_kwargs_all and rerank_meta.get(
"supports_kwargs", True
)
scores[idx] = score
pair_meta = meta_per_pair[idx] or {}
pair_meta["rerank_meta"] = rerank_meta
pair_meta["score_metric"] = score_metric
meta_per_pair[idx] = pair_meta
print(f"[auto] hybrid {score_metric} pair {idx}-{idx+1}: {score:.6e}")
if not supports_kwargs_all:
print(
"[auto] Warning: forward hooks did not capture kwargs; "
"fused-layer calls may be approximate."
)
print(f"[auto] score summary (metric={auto_metric}, norm={args.auto_norm}):")
for idx, score in enumerate(scores):
if idx in exclude_set:
print(f"[auto] pair {idx}-{idx+1}: excluded")
elif math.isfinite(float(score)):
print(f"[auto] pair {idx}-{idx+1}: {score:.6e}")
else:
print(f"[auto] pair {idx}-{idx+1}: {score}")
candidates = [i for i in range(num_pairs) if i not in exclude_set]
if not candidates:
raise SystemExit("All pairs are excluded; cannot auto-select a fusion layer.")
best_idx = min(candidates, key=lambda i: scores[i])
best_score = float(scores[best_idx])
if not math.isfinite(best_score):
raise SystemExit(
"Auto selection failed: all candidate pairs have non-finite scores "
"(check --exclude_pairs and data)."
)
print(f"[auto] Selected layer {best_idx} (score={best_score:.6e})")
meta = {
"per_pair": meta_per_pair,
"supports_kwargs": supports_kwargs_all,
"max_batches": max_batches,
"norm": args.auto_norm,
"metric": auto_metric,
"cosine_topk": auto_cosine_topk,
"start_index": start_index,
"excluded_pairs": sorted(exclude_set),
}
return best_idx, scores, meta