#!/usr/bin/env python3
"""Model and layer helpers for fuse_layers."""
import os
from typing import Dict, List, Optional, Tuple
import torch
try:
from tqdm import tqdm
except Exception: # pragma: no cover - optional dependency
tqdm = None
def _tqdm_enabled() -> bool:
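    """Return True unless DISABLE_TQDM or TQDM_DISABLE is set to a truthy value."""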
value = os.environ.get("DISABLE_TQDM", os.environ.get("TQDM_DISABLE", "0"))
return value.strip().lower() not in {"1", "true", "yes", "on"}
def get_dtype(dtype: str):
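    """Map a dtype string to a torch dtype; "auto" maps to None (defer to the loader)."""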
if dtype == "auto":
return None
if dtype == "float16":
return torch.float16
if dtype == "bfloat16":
return torch.bfloat16
return torch.float32
def resolve_attr(root: object, path: str) -> Optional[object]:
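    """Follow a dotted attribute path from root, returning None if any hop is missing."""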
cur = root
for part in path.split("."):
if not hasattr(cur, part):
return None
cur = getattr(cur, part)
return cur
def resolve_attr_with_parent(root: object, path: str) -> Tuple[object, str, object]:
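    """Resolve a dotted attribute path and return (parent, attribute name, value)."""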
parts = path.split(".")
cur = root
for part in parts[:-1]:
if not hasattr(cur, part):
raise ValueError(f"'{path}' not found on model")
cur = getattr(cur, part)
name = parts[-1]
if not hasattr(cur, name):
raise ValueError(f"'{path}' not found on model")
return cur, name, getattr(cur, name)
def find_layer_container(model, layer_path: Optional[str]) -> Tuple[object, str, object]:
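    """Locate the transformer layer list as (parent, attr_name, container).

    If layer_path is not given, probe the container paths used by common
    architectures (see candidate_paths below), keeping the first iterable hit.
    """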
if layer_path:
parent, name, container = resolve_attr_with_parent(model, layer_path)
return parent, name, container
candidate_paths = [
"model.layers", # LLaMA, Mistral, Qwen2, Gemma
"model.decoder.layers", # OPT
"transformer.h", # GPT-2, GPT-J, Bloom, Falcon
"transformer.blocks", # MPT
"gpt_neox.layers", # GPT-NeoX
"layers", # fallback
]
for path in candidate_paths:
candidate = resolve_attr(model, path)
if candidate is None:
continue
try:
list(candidate)
except TypeError:
continue
parent, name, container = resolve_attr_with_parent(model, path)
return parent, name, container
raise ValueError(
"Could not locate transformer layers. Pass --layer_path explicitly."
)
def find_attention_module(layer: torch.nn.Module) -> torch.nn.Module:
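    """Return the self-attention submodule of a decoder layer.

    Tries the common attribute names first, then falls back to scanning for a
    module that exposes q_proj/k_proj/v_proj/o_proj.
    """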
if hasattr(layer, "self_attn"):
return getattr(layer, "self_attn")
if hasattr(layer, "attn"):
return getattr(layer, "attn")
if hasattr(layer, "attention"):
return getattr(layer, "attention")
for _, module in layer.named_modules():
if all(
hasattr(module, attr) for attr in ("q_proj", "k_proj", "v_proj", "o_proj")
):
return module
raise ValueError("Could not find attention module with q_proj/k_proj/v_proj/o_proj")
def find_mlp_module(layer: torch.nn.Module) -> torch.nn.Module:
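    """Return the MLP/FFN submodule of a decoder layer.

    Tries common attribute names, then scans for modules with the projection
    names used by LLaMA-style, OPT/GPT-style, Bloom/Falcon-style, and
    Mixtral/MPT-style blocks.
    """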
if hasattr(layer, "mlp"):
return getattr(layer, "mlp")
for attr in ("feed_forward", "feedforward", "ffn", "ff"):
if hasattr(layer, attr):
return getattr(layer, attr)
for _, module in layer.named_modules():
if all(hasattr(module, attr) for attr in ("gate_proj", "up_proj", "down_proj")):
return module
if all(hasattr(module, attr) for attr in ("fc1", "fc2")):
return module
if all(
hasattr(module, attr)
for attr in ("dense_h_to_4h", "dense_4h_to_h")
):
return module
if all(hasattr(module, attr) for attr in ("w1", "w2")):
return module
raise ValueError("Could not find MLP/FFN module on layer")
def get_head_info(
attn: torch.nn.Module, hidden_size: int, config
) -> Tuple[int, int, int]:
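    """Infer (num_heads, num_key_value_heads, head_dim) for an attention module.

    Prefers attributes on the module itself, then the model config, and
    finally falls back to deriving the values from q_proj/k_proj weight
    shapes.
    """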
num_heads = getattr(attn, "num_heads", None)
if num_heads is None:
num_heads = getattr(attn, "num_attention_heads", None)
if num_heads is None and config is not None:
num_heads = getattr(
config,
"num_attention_heads",
getattr(config, "num_heads", getattr(config, "n_head", None)),
)
num_key_value_heads = getattr(attn, "num_key_value_heads", None)
if num_key_value_heads is None:
num_key_value_heads = getattr(attn, "num_kv_heads", None)
if num_key_value_heads is None and config is not None:
num_key_value_heads = getattr(
config,
"num_key_value_heads",
getattr(config, "num_kv_heads", getattr(config, "n_head_kv", None)),
)
head_dim = getattr(attn, "head_dim", None)
if head_dim is None and config is not None:
head_dim = getattr(config, "head_dim", None)
if num_heads is None:
if hasattr(attn, "q_proj"):
q_out = attn.q_proj.weight.shape[0]
if head_dim is not None:
num_heads = q_out // head_dim
elif num_key_value_heads is not None and hasattr(attn, "k_proj"):
k_out = attn.k_proj.weight.shape[0]
head_dim = k_out // max(int(num_key_value_heads), 1)
num_heads = q_out // head_dim
if num_heads is None:
raise ValueError(
"Attention module missing num_heads/num_attention_heads; "
"pass --layer_path or add config overrides."
)
    if head_dim is None:
        head_dim = hidden_size // int(num_heads)
    if num_key_value_heads is None and hasattr(attn, "k_proj"):
        # Derive the KV head count from k_proj before falling back to MHA.
        k_out = attn.k_proj.weight.shape[0]
        num_key_value_heads = k_out // int(head_dim)
    if num_key_value_heads is None:
        num_key_value_heads = num_heads
return int(num_heads), int(num_key_value_heads), int(head_dim)
def cosine_cost_matrix(
a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
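    """Return the pairwise cosine-distance matrix 1 - cos(a[i], b[j])."""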
a_norm = a / (a.norm(dim=1, keepdim=True) + eps)
b_norm = b / (b.norm(dim=1, keepdim=True) + eps)
sim = a_norm @ b_norm.t()
return 1.0 - sim
def hungarian(cost: torch.Tensor) -> List[int]:
    """Solve the square assignment problem via Kuhn-Munkres (minimization).

    O(n^3) potentials-based variant; returns ``assignment`` with
    ``assignment[row] = matched column index``.
    """
n = cost.size(0)
u = [0.0] * (n + 1)
v = [0.0] * (n + 1)
p = [0] * (n + 1)
way = [0] * (n + 1)
for i in range(1, n + 1):
p[0] = i
j0 = 0
minv = [float("inf")] * (n + 1)
used = [False] * (n + 1)
while True:
used[j0] = True
i0 = p[j0]
delta = float("inf")
j1 = 0
for j in range(1, n + 1):
if used[j]:
continue
cur = cost[i0 - 1, j - 1].item() - u[i0] - v[j]
if cur < minv[j]:
minv[j] = cur
way[j] = j0
if minv[j] < delta:
delta = minv[j]
j1 = j
for j in range(0, n + 1):
if used[j]:
u[p[j]] += delta
v[j] -= delta
else:
minv[j] -= delta
j0 = j1
if p[j0] == 0:
break
while True:
j1 = way[j0]
p[j0] = p[j1]
j0 = j1
if j0 == 0:
break
assignment = [-1] * n
for j in range(1, n + 1):
if p[j] > 0:
assignment[p[j] - 1] = j - 1
return assignment
def compute_head_means(
model,
attn_i: torch.nn.Module,
attn_j: torch.nn.Module,
dataloader,
device: str,
hidden_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, int, int, int]:
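    """Collect mean per-head activations entering o_proj for two attention modules.

    Registers forward hooks on each module's o_proj, runs the dataloader, and
    averages the hook inputs reshaped to (num_heads, head_dim). Returns the
    two means plus (num_heads, num_key_value_heads, head_dim).
    """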
num_heads_i, num_kv_i, head_dim_i = get_head_info(attn_i, hidden_size, model.config)
num_heads_j, num_kv_j, head_dim_j = get_head_info(attn_j, hidden_size, model.config)
    if (
        num_heads_i != num_heads_j
        or num_kv_i != num_kv_j
        or head_dim_i != head_dim_j
    ):
        raise ValueError(
            "Head counts, KV head counts, or head_dim differ between layers; "
            "cannot align"
        )
sums_i = torch.zeros(num_heads_i, head_dim_i, device="cpu")
sums_j = torch.zeros(num_heads_j, head_dim_j, device="cpu")
count_i = [0]
count_j = [0]
def make_hook(
sums: torch.Tensor, count_ref: List[int], num_heads: int, head_dim: int
):
def hook(_module, inputs, _output):
hidden = inputs[0].detach()
if hidden.dim() != 3:
return
batch, seq, width = hidden.shape
if width != num_heads * head_dim:
return
reshaped = hidden.view(batch, seq, num_heads, head_dim)
sums.add_(reshaped.sum(dim=(0, 1)).float().cpu())
count_ref[0] += batch * seq
return hook
hook_i = attn_i.o_proj.register_forward_hook(
make_hook(sums_i, count_i, num_heads_i, head_dim_i)
)
hook_j = attn_j.o_proj.register_forward_hook(
make_hook(sums_j, count_j, num_heads_j, head_dim_j)
)
model.eval()
iterator = dataloader
if tqdm is not None and _tqdm_enabled():
iterator = tqdm(dataloader, desc="Head stats", unit="batch")
with torch.no_grad():
for batch in iterator:
input_ids = batch[0].to(device)
_ = model(input_ids=input_ids)
hook_i.remove()
hook_j.remove()
if count_i[0] == 0 or count_j[0] == 0:
raise RuntimeError("Failed to capture head outputs; check attention modules.")
mean_i = sums_i / count_i[0]
mean_j = sums_j / count_j[0]
return mean_i, mean_j, num_heads_i, num_kv_i, head_dim_i
def build_head_permutation(
mean_i: torch.Tensor,
mean_j: torch.Tensor,
num_heads: int,
num_kv_heads: int,
eps: float,
) -> List[int]:
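    """Match heads of layer j to layer i within each key/value group.

    Runs the Hungarian solver on the cosine cost of the per-head means, one
    KV group at a time, so the resulting permutation never moves a head
    across group boundaries and GQA sharing stays valid.
    """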
group_size = num_heads // num_kv_heads
if group_size * num_kv_heads != num_heads:
raise ValueError("num_heads must be divisible by num_key_value_heads")
perm = list(range(num_heads))
for g in range(num_kv_heads):
start = g * group_size
end = start + group_size
cost = cosine_cost_matrix(mean_i[start:end], mean_j[start:end], eps=eps)
assignment = hungarian(cost)
for local_idx, match in enumerate(assignment):
perm[start + local_idx] = start + match
return perm
def permute_attention_heads(
attn: torch.nn.Module,
perm: List[int],
num_heads: int,
num_kv_heads: int,
head_dim: int,
) -> None:
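    """Apply a head permutation to an attention module in place.

    Permutes q_proj rows (and bias) and o_proj input columns; k_proj/v_proj
    are only permuted in the MHA case (num_kv_heads == num_heads), since for
    GQA the permutation stays within groups and the shared KV heads are
    untouched.
    """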
hidden_size = num_heads * head_dim
def permute_out_proj_weight(weight: torch.Tensor) -> torch.Tensor:
out_features, in_features = weight.shape
if in_features != hidden_size:
raise ValueError(
"o_proj in_features ({} ) != num_heads*head_dim ({})".format(
in_features, hidden_size
)
)
reshaped = weight.view(out_features, num_heads, head_dim)
reshaped = reshaped[:, perm, :]
return reshaped.reshape(out_features, in_features)
def permute_proj_weight(weight: torch.Tensor) -> torch.Tensor:
out_features, in_features = weight.shape
if out_features != hidden_size:
raise ValueError(
"proj out_features ({}) != num_heads*head_dim ({})".format(
out_features, hidden_size
)
)
reshaped = weight.view(num_heads, head_dim, in_features)
reshaped = reshaped[perm, :, :]
return reshaped.reshape(out_features, in_features)
def permute_proj_bias(bias: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
if bias is None:
return None
reshaped = bias.view(num_heads, head_dim)
reshaped = reshaped[perm, :]
return reshaped.reshape(num_heads * head_dim)
with torch.no_grad():
attn.q_proj.weight.copy_(permute_proj_weight(attn.q_proj.weight))
if attn.q_proj.bias is not None:
attn.q_proj.bias.copy_(permute_proj_bias(attn.q_proj.bias))
if num_kv_heads == num_heads:
attn.k_proj.weight.copy_(permute_proj_weight(attn.k_proj.weight))
if attn.k_proj.bias is not None:
attn.k_proj.bias.copy_(permute_proj_bias(attn.k_proj.bias))
attn.v_proj.weight.copy_(permute_proj_weight(attn.v_proj.weight))
if attn.v_proj.bias is not None:
attn.v_proj.bias.copy_(permute_proj_bias(attn.v_proj.bias))
attn.o_proj.weight.copy_(permute_out_proj_weight(attn.o_proj.weight))
def compute_fisher(
model,
layer_a: torch.nn.Module,
layer_b: torch.nn.Module,
dataloader,
fisher_mode: str,
device: str,
) -> Tuple[List[Dict[str, object]], int, List[Dict[str, int]]]:
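    """Accumulate empirical Fisher information for two layers.

    Squares the gradients of the LM loss (labels = inputs) over the
    dataloader. In "param" mode a per-parameter tensor is accumulated on CPU;
    otherwise a per-tensor scalar sum is kept. Returns
    (fisher_sums, num_batches, param_numels).
    """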
for param in model.parameters():
param.requires_grad_(False)
for layer in (layer_a, layer_b):
for param in layer.parameters():
param.requires_grad_(True)
fisher_sums: List[Dict[str, object]] = []
param_numels: List[Dict[str, int]] = []
for layer in (layer_a, layer_b):
layer_sums: Dict[str, object] = {}
layer_numels: Dict[str, int] = {}
for name, param in layer.named_parameters():
if not param.requires_grad:
continue
if fisher_mode == "param":
layer_sums[name] = torch.zeros_like(
param, dtype=torch.float32, device="cpu"
)
else:
layer_sums[name] = 0.0
layer_numels[name] = param.numel()
fisher_sums.append(layer_sums)
param_numels.append(layer_numels)
num_batches = 0
model.eval()
iterator = dataloader
if tqdm is not None and _tqdm_enabled():
iterator = tqdm(dataloader, desc="Fisher", unit="batch")
for batch in iterator:
input_ids = batch[0].to(device)
outputs = model(input_ids=input_ids, labels=input_ids)
loss = outputs.loss
loss.backward()
for layer_idx, layer in enumerate((layer_a, layer_b)):
layer_sums = fisher_sums[layer_idx]
for name, param in layer.named_parameters():
if not param.requires_grad or param.grad is None:
continue
grad_sq = param.grad.detach().float().pow(2)
if fisher_mode == "param":
layer_sums[name] += grad_sq.cpu()
else:
layer_sums[name] += float(grad_sq.sum().item())
model.zero_grad(set_to_none=True)
num_batches += 1
if num_batches == 0:
raise RuntimeError("No batches processed; check dataset or text inputs.")
return fisher_sums, num_batches, param_numels
def merge_layers(
layer_a: torch.nn.Module,
layer_b: torch.nn.Module,
fisher_a: Dict[str, object],
fisher_b: Dict[str, object],
num_batches: int,
numels_a: Dict[str, int],
numels_b: Dict[str, int],
fisher_mode: str,
eps: float,
) -> int:
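    """Fisher-weighted merge of layer_b into layer_a, parameter by parameter.

    Computes W = (F_a * W_a + F_b * W_b) / (F_a + F_b + eps) and falls back
    to a plain average when the combined Fisher mass is negligible. Returns
    the number of parameters merged.
    """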
merged = 0
params_b = {name: param for name, param in layer_b.named_parameters()}
with torch.no_grad():
for name, param_a in layer_a.named_parameters():
param_b = params_b.get(name)
if param_b is None or param_b.shape != param_a.shape:
continue
if fisher_mode == "param":
fa = fisher_a[name] / num_batches
fb = fisher_b[name] / num_batches
# Fisher tensors are accumulated on CPU to save VRAM; move to the
# parameter device for the actual merge.
if isinstance(fa, torch.Tensor) and fa.device != param_a.device:
fa = fa.to(param_a.device)
if isinstance(fb, torch.Tensor) and fb.device != param_a.device:
fb = fb.to(param_a.device)
denom = fa + fb
denom_mean = float(denom.mean().item())
if denom_mean <= eps:
merged_param = 0.5 * (param_a.float() + param_b.float())
else:
merged_param = (fa * param_a.float() + fb * param_b.float()) / (
denom + eps
)
else:
fa = fisher_a[name] / (num_batches * numels_a[name])
fb = fisher_b[name] / (num_batches * numels_b[name])
denom = fa + fb
if denom <= eps:
merged_param = 0.5 * (param_a.float() + param_b.float())
else:
merged_param = (
fa * param_a.float() + fb * param_b.float()
) / (denom + eps)
param_a.copy_(merged_param.to(dtype=param_a.dtype))
merged += 1
return merged
def merge_layers_with_gates(
layer_a: torch.nn.Module,
layer_b: torch.nn.Module,
gates: Dict[str, torch.Tensor],
) -> int:
"""Merge layer_b into layer_a using precomputed gates.
Each gate is a lambda in [0, 1] that mixes parameters as:
W = lambda * W_a + (1 - lambda) * W_b
Gate tensors may be scalars (per-tensor gating) or full tensors matching the
parameter shape (per-parameter gating).
"""
merged = 0
params_b = {name: param for name, param in layer_b.named_parameters()}
with torch.no_grad():
for name, param_a in layer_a.named_parameters():
gate = gates.get(name)
if gate is None:
continue
param_b = params_b.get(name)
if param_b is None or param_b.shape != param_a.shape:
continue
lam = gate
if not isinstance(lam, torch.Tensor):
lam = torch.tensor(lam)
if lam.device != param_a.device:
lam = lam.to(param_a.device)
merged_param = lam * param_a.float() + (1.0 - lam) * param_b.float()
param_a.copy_(merged_param.to(dtype=param_a.dtype))
merged += 1
return merged
def drop_layer(container: object, index: int) -> object:
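    """Remove the layer at index, rebuilding a ModuleList or mutating a plain list."""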
if isinstance(container, torch.nn.ModuleList):
return torch.nn.ModuleList(
[layer for idx, layer in enumerate(container) if idx != index]
)
if isinstance(container, list):
del container[index]
return container
raise TypeError("Layer container must be ModuleList or list")
def decrement_config(config) -> None:
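    """Decrease the layer-count attribute on the config after a layer drop."""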
for attr in ("num_hidden_layers", "n_layer", "num_layers"):
if hasattr(config, attr):
value = getattr(config, attr)
if isinstance(value, int) and value > 0:
setattr(config, attr, value - 1)
normalize_config(config)
def normalize_config(config) -> None:
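    """Keep per-layer config lists (e.g. layer_types) in sync with num_hidden_layers."""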
num_hidden_layers = getattr(config, "num_hidden_layers", None)
layer_types = getattr(config, "layer_types", None)
if (
isinstance(num_hidden_layers, int)
and num_hidden_layers >= 0
and isinstance(layer_types, (list, tuple))
and len(layer_types) != num_hidden_layers
):
config.layer_types = list(layer_types[:num_hidden_layers])
def find_colon_modules(module: torch.nn.Module) -> List[str]:
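    """Recursively list module names containing ':', which some serializers reject."""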
found: List[str] = []
for name, child in module._modules.items():
if ":" in name:
found.append(name)
if isinstance(child, torch.nn.Module):
for sub in find_colon_modules(child):
found.append(f"{name}.{sub}")
return found
def get_norm_pair(
layer: torch.nn.Module,
) -> Tuple[
Optional[torch.nn.Module],
Optional[torch.nn.Module],
Tuple[Optional[str], Optional[str]],
]:
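    """Return the (pre-attention, pre-MLP) norm pair of a layer and their names.

    Probes the attribute-name pairs used by common architectures; returns
    (None, None, (None, None)) when no pair matches.
    """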
candidates = [
("input_layernorm", "post_attention_layernorm"),
("ln_1", "ln_2"),
("norm1", "norm2"),
("norm_1", "norm_2"),
("layer_norm_1", "layer_norm_2"),
("self_attn_layer_norm", "final_layer_norm"),
]
for n1, n2 in candidates:
if hasattr(layer, n1) and hasattr(layer, n2):
return getattr(layer, n1), getattr(layer, n2), (n1, n2)
return None, None, (None, None)
def clone_state_dict(module: torch.nn.Module) -> Dict[str, torch.Tensor]:
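    """Return a detached, cloned copy of a module's state dict."""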
return {k: v.detach().clone() for k, v in module.state_dict().items()}
def apply_norm_policy(
layer: torch.nn.Module,
norm_policy: str,
norm1_state: Optional[Dict[str, torch.Tensor]],
norm2_state: Optional[Dict[str, torch.Tensor]],
norm_names: Tuple[Optional[str], Optional[str]],
) -> None:
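    """Restore saved norm weights onto a merged layer according to norm_policy.

    "copy_n1" and "hybrid" restore the first norm; "copy_n1_n2" restores
    both norms.
    """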
norm1, norm2, _ = get_norm_pair(layer)
if norm_policy in {"copy_n1", "hybrid"} and norm1_state is not None and norm1 is not None:
norm1.load_state_dict(norm1_state)
if norm_policy == "copy_n1_n2" and norm2_state is not None and norm2 is not None:
norm2.load_state_dict(norm2_state)
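

def _example_usage() -> None:  # pragma: no cover - illustrative only
    """Hypothetical end-to-end sketch showing how these helpers compose.

    Not part of the pipeline: the model id, layer indices, device, and
    calibration text below are placeholder assumptions, and transformers is
    assumed to be installed.
    """
    from torch.utils.data import DataLoader, TensorDataset
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained("some/causal-lm")  # placeholder id
    tok = AutoTokenizer.from_pretrained("some/causal-lm")
    ids = tok("calibration text for merging", return_tensors="pt").input_ids
    loader = DataLoader(TensorDataset(ids), batch_size=1)

    # Locate the layer stack and pick two adjacent layers to fuse.
    parent, name, layers = find_layer_container(model, layer_path=None)
    i, j = 10, 11  # fuse layer j into layer i (illustrative indices)
    attn_i = find_attention_module(layers[i])
    attn_j = find_attention_module(layers[j])

    # Align layer j's heads to layer i's ordering before merging.
    mean_i, mean_j, n_heads, n_kv, head_dim = compute_head_means(
        model, attn_i, attn_j, loader, "cpu", model.config.hidden_size
    )
    perm = build_head_permutation(mean_i, mean_j, n_heads, n_kv, eps=1e-8)
    permute_attention_heads(attn_j, perm, n_heads, n_kv, head_dim)

    # Fisher-weighted merge, then drop the absorbed layer and fix the config.
    fisher, n_batches, numels = compute_fisher(
        model, layers[i], layers[j], loader, "param", "cpu"
    )
    merge_layers(
        layers[i], layers[j], fisher[0], fisher[1], n_batches,
        numels[0], numels[1], "param", eps=1e-8,
    )
    setattr(parent, name, drop_layer(layers, j))
    decrement_config(model.config)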