#!/usr/bin/env python3
"""Distillation helpers for fuse_layers."""
import argparse
import itertools
import math
import os
from contextlib import contextmanager, nullcontext
from typing import Dict, List, Optional, Set, Tuple
import torch
import torch.nn.functional as F
try:
    import ppl_eval
except Exception as exc:  # pragma: no cover - local module must be importable
    raise SystemExit("ppl_eval.py is required (missing or invalid)") from exc
try:
from tqdm import tqdm
except Exception: # pragma: no cover - optional dependency
tqdm = None
try:
from torch.func import functional_call as _functional_call
except Exception: # pragma: no cover - depends on torch version
try:
from torch.nn.utils.stateless import functional_call as _functional_call
except Exception: # pragma: no cover - depends on torch version
_functional_call = None
from fuse_layers_model import find_attention_module, find_mlp_module
def _tqdm_enabled() -> bool:
value = os.environ.get("DISABLE_TQDM", os.environ.get("TQDM_DISABLE", "0"))
return value.strip().lower() not in {"1", "true", "yes", "on"}
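# Context managers that temporarily swap a sub-module on `parent` (the layer
# list, or the final norm) and always restore the original attribute, even if
# the body of the `with` block raises.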
@contextmanager
def temporary_layers(parent: object, name: str, new_layers: torch.nn.Module):
original = getattr(parent, name)
setattr(parent, name, new_layers)
try:
yield
finally:
setattr(parent, name, original)
@contextmanager
def temporary_norm(parent: object):
if hasattr(parent, "norm"):
original = getattr(parent, "norm")
setattr(parent, "norm", torch.nn.Identity())
try:
yield
finally:
setattr(parent, "norm", original)
else:
yield
def forward_truncated(
parent: torch.nn.Module,
layer_attr: str,
layers: List[torch.nn.Module],
upto: int,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
truncated = torch.nn.ModuleList(layers[:upto])
with temporary_layers(parent, layer_attr, truncated), temporary_norm(parent):
outputs = parent(
input_ids=input_ids,
attention_mask=attention_mask,
use_cache=False,
)
if hasattr(outputs, "last_hidden_state"):
return outputs.last_hidden_state
return outputs[0]
def _masked_hidden_mse(diff: torch.Tensor, attention_mask: torch.Tensor) -> Optional[torch.Tensor]:
diff_f = diff.float()
mask = attention_mask.to(device=diff.device, dtype=torch.float32)
denom = mask.sum() * diff_f.size(-1)
if denom.item() == 0:
return None
return (diff_f.pow(2) * mask.unsqueeze(-1)).sum() / denom
def _extract_hidden_like(output) -> Optional[torch.Tensor]:
if torch.is_tensor(output):
return output
if isinstance(output, (tuple, list)) and output:
first = output[0]
if torch.is_tensor(first):
return first
if hasattr(output, "last_hidden_state"):
hidden = getattr(output, "last_hidden_state")
if torch.is_tensor(hidden):
return hidden
return None
@contextmanager
def capture_module_output(module: torch.nn.Module):
cache: Dict[str, Optional[torch.Tensor]] = {"output": None}
def hook(_module, _inputs, output):
cache["output"] = _extract_hidden_like(output)
handle = module.register_forward_hook(hook)
try:
yield cache
finally:
handle.remove()
_ATTN_NAME_FRAGMENTS = (
"self_attn.",
"attn.",
"attention.",
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"q_norm",
"k_norm",
)
_MLP_NAME_FRAGMENTS = (
"mlp.",
"ffn.",
"feed_forward",
"feedforward",
"gate_proj",
"up_proj",
"down_proj",
"fc1",
"fc2",
"dense_h_to_4h",
"dense_4h_to_h",
"w1",
"w2",
"w3",
)
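# Classify a parameter name into "attn" / "mlp" / "other" by substring match.
# MLP fragments are checked first, so a name matching both families counts as MLP.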
def _classify_param_family(name: str) -> str:
lowered = name.lower()
if any(fragment in lowered for fragment in _MLP_NAME_FRAGMENTS):
return "mlp"
if any(fragment in lowered for fragment in _ATTN_NAME_FRAGMENTS):
return "attn"
return "other"
def _family_reg_scale(family: str, attn_scale: float, mlp_scale: float) -> float:
if family == "attn":
return attn_scale
if family == "mlp":
return mlp_scale
return 1.0
def _subset_allows_param(name: str, subset: str) -> bool:
if subset == "all":
return True
return _classify_param_family(name) == subset
def _gate_logit_from_prior(prior: torch.Tensor) -> torch.Tensor:
# Stable logit: log(p) - log(1 - p).
return torch.log(prior) - torch.log1p(-prior)
def _build_gate_priors(
layer_a: torch.nn.Module,
layer_b: torch.nn.Module,
fisher_a: Dict[str, object],
fisher_b: Dict[str, object],
num_batches: int,
numels_a: Dict[str, int],
numels_b: Dict[str, int],
fisher_mode: str,
eps: float,
clamp_eps: float,
) -> Dict[str, torch.Tensor]:
"""Return lambda priors for parameters that can be merged."""
priors: Dict[str, torch.Tensor] = {}
params_b = {name: param for name, param in layer_b.named_parameters()}
for name, param_a in layer_a.named_parameters():
param_b = params_b.get(name)
if param_b is None or param_b.shape != param_a.shape:
continue
if fisher_mode == "param":
fa = fisher_a[name] / max(num_batches, 1)
fb = fisher_b[name] / max(num_batches, 1)
denom = fa + fb
if not isinstance(denom, torch.Tensor):
denom = torch.tensor(float(denom))
# If Fisher is uninformative, default to symmetric init.
prior = torch.where(
denom > eps,
fa / (denom + eps),
torch.full_like(denom, 0.5),
)
prior = prior.clamp(clamp_eps, 1.0 - clamp_eps)
priors[name] = prior
else:
fa = fisher_a[name] / (max(num_batches, 1) * numels_a[name])
fb = fisher_b[name] / (max(num_batches, 1) * numels_b[name])
denom = fa + fb
if denom <= eps:
prior_val = 0.5
else:
prior_val = float(fa / (denom + eps))
prior_val = min(max(prior_val, clamp_eps), 1.0 - clamp_eps)
priors[name] = torch.tensor(prior_val, dtype=torch.float32)
return priors
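# The prior for each mergeable parameter is lambda = F_a / (F_a + F_b): gates
# start biased toward whichever layer carries more Fisher information, and fall
# back to a symmetric 0.5 when the Fisher estimates are uninformative.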
def compute_fisher_gate_priors(
layer_a: torch.nn.Module,
layer_b: torch.nn.Module,
fisher_a: Dict[str, object],
fisher_b: Dict[str, object],
num_batches: int,
numels_a: Dict[str, int],
numels_b: Dict[str, int],
fisher_mode: str,
eps: float,
clamp_eps: float = 1e-4,
) -> Dict[str, torch.Tensor]:
"""Compute Fisher prior gate lambdas (lambda_prior) for mergeable parameters."""
return _build_gate_priors(
layer_a=layer_a,
layer_b=layer_b,
fisher_a=fisher_a,
fisher_b=fisher_b,
num_batches=num_batches,
numels_a=numels_a,
numels_b=numels_b,
fisher_mode=fisher_mode,
eps=eps,
clamp_eps=clamp_eps,
)
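# Minimal usage sketch (assumes the caller has already accumulated per-parameter
# squared-gradient Fisher sums for both layers over `num_batches` batches):
#
#   priors = compute_fisher_gate_priors(
#       layer_a, layer_b, fisher_a, fisher_b, num_batches,
#       numels_a, numels_b, fisher_mode="param", eps=1e-8,
#   )
#   merged = ReparamMergedLayer(layer_a, layer_b, priors)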
class ReparamMergedLayer(torch.nn.Module):
"""Virtual layer that merges parameters via W0/U reparameterization.
Parameters of layer_a/layer_b are treated as frozen (detached). We train:
- gate logits s (lambda = sigmoid(s))
- U (initialized as U0 = (W_a - W_b) / 2)
Forward uses:
W_merge = W0 + (2 * lambda - 1) * U
where W0 = (W_a + W_b) / 2
"""
def __init__(
self,
layer_a: torch.nn.Module,
layer_b: torch.nn.Module,
gate_targets: Dict[str, object],
param_subset: str = "all",
clamp_eps: float = 1e-4,
) -> None:
super().__init__()
self.layer_a = layer_a
self.layer_b = layer_b
self.param_subset = param_subset
self._name_map: Dict[str, str] = {}
self.gates = torch.nn.ParameterDict()
self.u = torch.nn.ParameterDict()
params_b = {name: param for name, param in layer_b.named_parameters()}
try:
device = next(layer_a.parameters()).device
except StopIteration:
device = torch.device("cpu")
for name, param_a in layer_a.named_parameters():
param_b = params_b.get(name)
if param_b is None or param_b.shape != param_a.shape:
continue
if not _subset_allows_param(name, self.param_subset):
continue
target = gate_targets.get(name)
if target is None:
target_t = torch.tensor(0.5, device=device, dtype=torch.float32)
elif isinstance(target, torch.Tensor):
target_t = target.detach().to(device=device, dtype=torch.float32)
else:
target_t = torch.tensor(float(target), device=device, dtype=torch.float32)
target_t = target_t.clamp(clamp_eps, 1.0 - clamp_eps)
s0 = _gate_logit_from_prior(target_t)
u0 = 0.5 * (param_a.detach().float() - param_b.detach().float())
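            # ParameterDict keys cannot contain ".", so each parameter name is
            # mapped to a collision-free "safe" key kept in self._name_map.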
safe = name.replace(".", "__")
if safe in self.gates:
safe = f"{safe}_{len(self.gates)}"
self._name_map[name] = safe
self.gates[safe] = torch.nn.Parameter(s0)
self.u[safe] = torch.nn.Parameter(u0)
def __getattr__(self, name: str):
# Delegate model-specific attributes (e.g. Qwen's `attention_type`) to
# the underlying layer so the parent forward doesn't break.
try:
return super().__getattr__(name)
except AttributeError as exc:
try:
layer_a = super().__getattr__("layer_a")
if hasattr(layer_a, name):
return getattr(layer_a, name)
except AttributeError:
pass
try:
layer_b = super().__getattr__("layer_b")
if hasattr(layer_b, name):
return getattr(layer_b, name)
except AttributeError:
pass
raise exc
def _safe_for(self, orig: str) -> Optional[str]:
return self._name_map.get(orig)
def gate_lambdas(self) -> Dict[str, torch.Tensor]:
out: Dict[str, torch.Tensor] = {}
for orig, safe in self._name_map.items():
out[orig] = torch.sigmoid(self.gates[safe]).detach()
return out
def _merged_params(self) -> Dict[str, torch.Tensor]:
params_a = {name: p for name, p in self.layer_a.named_parameters()}
params_b = {name: p for name, p in self.layer_b.named_parameters()}
merged_params: Dict[str, torch.Tensor] = {}
for name, param_a in params_a.items():
param_b = params_b.get(name)
safe = self._safe_for(name)
if safe is None or param_b is None or param_b.shape != param_a.shape:
merged_params[name] = param_a.detach()
continue
lam = torch.sigmoid(self.gates[safe]).to(dtype=torch.float32)
u = self.u[safe].to(dtype=torch.float32)
w0 = 0.5 * (param_a.detach().float() + param_b.detach().float())
merged = w0 + (2.0 * lam - 1.0) * u
merged_params[name] = merged.to(dtype=param_a.dtype)
return merged_params
def forward(self, *args, **kwargs):
if _functional_call is None:
raise RuntimeError(
"Reparam distillation requires torch.func.functional_call"
)
merged_params = self._merged_params()
return _functional_call(self.layer_a, merged_params, args, kwargs)
def materialize_into_layer_a(self) -> int:
merged = 0
params_a = {name: p for name, p in self.layer_a.named_parameters()}
params_b = {name: p for name, p in self.layer_b.named_parameters()}
with torch.no_grad():
for orig, safe in self._name_map.items():
param_a = params_a.get(orig)
param_b = params_b.get(orig)
if param_a is None or param_b is None or param_b.shape != param_a.shape:
continue
lam = torch.sigmoid(self.gates[safe]).to(device=param_a.device, dtype=torch.float32)
u = self.u[safe].to(device=param_a.device, dtype=torch.float32)
w0 = 0.5 * (param_a.detach().float() + param_b.detach().float())
merged_param = w0 + (2.0 * lam - 1.0) * u
param_a.copy_(merged_param.to(dtype=param_a.dtype))
merged += 1
return merged
def distill_reparam_merge(
student_model: torch.nn.Module,
student_parent: object,
student_layer_attr: str,
student_layers: List[torch.nn.Module],
teacher_model: torch.nn.Module,
teacher_parent: object,
teacher_layer_attr: str,
teacher_layers: List[torch.nn.Module],
layer_idx: int,
gate_lambdas: Dict[str, object],
dataloader,
args: argparse.Namespace,
progressive_cycle: Optional[int] = None,
progressive_total: Optional[int] = None,
) -> Tuple[int, Dict[str, torch.Tensor], Dict[str, object]]:
"""Reparameterized distillation that materializes a fused layer into layer_a.
Trains U and gate logits s (lambda = sigmoid(s)) using:
- composition MSE + distill-KL
- eta * ||lambda - lambda_gate||^2 + gamma * ||U - U0||^2
"""
total_epochs = float(args.distill_epochs)
hidden_mse_weight = float(getattr(args, "distill_hidden_mse_weight", 1.0))
if hidden_mse_weight < 0.0:
raise SystemExit("--distill_hidden_mse_weight must be >= 0")
attn_mse_weight = float(getattr(args, "distill_attn_mse_weight", 0.0))
if attn_mse_weight < 0.0:
raise SystemExit("--distill_attn_mse_weight must be >= 0")
mlp_mse_weight = float(getattr(args, "distill_mlp_mse_weight", 0.0))
if mlp_mse_weight < 0.0:
raise SystemExit("--distill_mlp_mse_weight must be >= 0")
param_subset = str(getattr(args, "reparam_param_subset", "all"))
if param_subset not in {"all", "mlp", "attn"}:
raise SystemExit("--reparam_param_subset must be one of: all, mlp, attn")
kl_weight = float(args.distill_kl_weight)
kl_temp = float(args.distill_kl_temp)
if kl_weight < 0.0:
raise SystemExit("--distill_kl_weight must be >= 0")
if kl_temp <= 0.0:
raise SystemExit("--distill_kl_temp must be > 0")
eta = float(getattr(args, "reparam_eta", 0.0))
gamma = float(getattr(args, "reparam_gamma", 0.0))
if eta < 0.0:
raise SystemExit("--reparam_eta must be >= 0")
if gamma < 0.0:
raise SystemExit("--reparam_gamma must be >= 0")
attn_reg_scale = float(getattr(args, "reparam_attn_reg_scale", 1.0))
mlp_reg_scale = float(getattr(args, "reparam_mlp_reg_scale", 1.0))
if attn_reg_scale < 0.0:
raise SystemExit("--reparam_attn_reg_scale must be >= 0")
if mlp_reg_scale < 0.0:
raise SystemExit("--reparam_mlp_reg_scale must be >= 0")
if (
total_epochs > 0.0
and hidden_mse_weight == 0.0
and attn_mse_weight == 0.0
and mlp_mse_weight == 0.0
and kl_weight == 0.0
and eta == 0.0
and gamma == 0.0
):
raise SystemExit(
"Reparam distillation has no active loss terms. "
"Enable hidden/attention/MLP MSE, KL, or at least one reparam regularizer."
)
if not gate_lambdas:
raise SystemExit("Reparam distillation requires non-empty gate lambdas.")
layer_a = student_layers[layer_idx]
layer_b = student_layers[layer_idx + 1]
reparam_layer = ReparamMergedLayer(
layer_a,
layer_b,
gate_lambdas,
param_subset=param_subset,
clamp_eps=1e-4,
)
if not reparam_layer._name_map:
raise RuntimeError(
"No mergeable parameters found for reparam distillation under "
f"--reparam_param_subset={param_subset!r}."
)
teacher_attn = None
student_attn = None
if attn_mse_weight > 0.0:
try:
teacher_attn = find_attention_module(teacher_layers[layer_idx + 1])
student_attn = find_attention_module(reparam_layer.layer_a)
except ValueError as exc:
raise SystemExit(
"Attention-output preservation was requested but an attention module "
f"could not be resolved: {exc}"
) from exc
teacher_mlp = None
student_mlp = None
if mlp_mse_weight > 0.0:
try:
teacher_mlp = find_mlp_module(teacher_layers[layer_idx + 1])
student_mlp = find_mlp_module(reparam_layer.layer_a)
except ValueError as exc:
raise SystemExit(
"MLP-output preservation was requested but an MLP module could not be "
f"resolved: {exc}"
) from exc
# Virtual layer list: replace layer_a with reparam layer and remove layer_b.
virtual_layers = list(student_layers)
virtual_layers[layer_idx] = reparam_layer
del virtual_layers[layer_idx + 1]
# Only (U, s) are trainable.
for param in student_model.parameters():
param.requires_grad_(False)
for param in reparam_layer.gates.parameters():
param.requires_grad_(True)
for param in reparam_layer.u.parameters():
param.requires_grad_(True)
    family_counts: Dict[str, int] = {"attn": 0, "mlp": 0, "other": 0}
    for orig in reparam_layer._name_map:
        family_counts[_classify_param_family(orig)] += 1
    do_train = total_epochs > 0.0
    if do_train:
        teacher_model.eval()
        student_model.train()
        # Rough memory estimate (especially when --fisher_mode=param creates per-element gates).
        total_gate_elems = sum(int(p.numel()) for p in reparam_layer.gates.parameters())
        total_u_elems = sum(int(p.numel()) for p in reparam_layer.u.parameters())
        gate_mib = total_gate_elems * 4.0 / (1024.0 * 1024.0)
        u_mib = total_u_elems * 4.0 / (1024.0 * 1024.0)
print(
f"[reparam] subset={param_subset} gates={len(reparam_layer.gates)} "
f"(attn={family_counts['attn']}, mlp={family_counts['mlp']}, other={family_counts['other']}) "
f"elems={total_gate_elems} (~{gate_mib:.1f} MiB), "
f"U_elems={total_u_elems} (~{u_mib:.1f} MiB; +optimizer state)"
)
optimizer = None
if do_train:
optimizer = torch.optim.AdamW(
[*reparam_layer.gates.parameters(), *reparam_layer.u.parameters()],
lr=float(args.distill_lr),
weight_decay=float(args.distill_weight_decay),
)
device_type = torch.device(args.device).type
amp_dtype = None
if args.dtype == "float16":
amp_dtype = torch.float16
elif args.dtype == "bfloat16":
amp_dtype = torch.bfloat16
use_amp = do_train and amp_dtype is not None and device_type == "cuda"
use_scaler = use_amp and amp_dtype == torch.float16
scaler = torch.cuda.amp.GradScaler() if use_scaler else None
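    # Fractional --distill_epochs: run the integer part as full passes over the
    # dataloader, then one extra partial pass of round(frac * len(dataloader)) batches.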
full_epochs = int(total_epochs) if do_train else 0
fractional = (total_epochs - full_epochs) if do_train else 0.0
if fractional < 1e-8:
fractional = 0.0
epoch_plan = [(epoch_idx, None) for epoch_idx in range(full_epochs)]
if fractional > 0:
try:
batches_per_epoch = len(dataloader)
except TypeError as exc:
raise SystemExit(
"Fractional distill epochs require a dataloader with finite length."
) from exc
if batches_per_epoch > 0:
frac_batches = int(round(fractional * batches_per_epoch))
if frac_batches <= 0:
frac_batches = 1
epoch_plan.append((full_epochs, frac_batches))
grad_accum = int(getattr(args, "distill_grad_accum_steps", 1))
if grad_accum <= 0:
raise SystemExit("--distill_grad_accum_steps must be >= 1")
log_steps = int(getattr(args, "distill_log_steps", 100))
max_grad_norm = getattr(args, "distill_max_grad_norm", 1.0)
params_a = {name: p for name, p in layer_a.named_parameters()}
params_b = {name: p for name, p in layer_b.named_parameters()}
step = 0
for epoch_idx, max_batches in epoch_plan:
if max_batches is None:
epoch_iter = dataloader
else:
epoch_iter = itertools.islice(dataloader, max_batches)
iterator = epoch_iter
if tqdm is not None and _tqdm_enabled():
if progressive_cycle is not None:
if progressive_total is not None:
desc = (
f"Reparam (cycle {progressive_cycle}/{progressive_total}, "
f"epoch {epoch_idx+1})"
)
else:
desc = f"Reparam (cycle {progressive_cycle}, epoch {epoch_idx+1})"
else:
desc = f"Reparam (epoch {epoch_idx+1})"
iterator = tqdm(epoch_iter, desc=desc, unit="batch", total=max_batches)
for batch in iterator:
input_ids = batch[0].to(args.device)
attention_mask = batch[1].to(args.device)
teacher_ids = input_ids.to(args.distill_teacher_device or args.device)
teacher_mask = attention_mask.to(args.distill_teacher_device or args.device)
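            # The teacher consumes both layers of the pair (depth layer_idx + 2);
            # the student consumes the virtual stack where the pair is a single
            # merged layer (depth layer_idx + 1), so both hidden states are taken
            # at the same effective depth.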
teacher_depth = layer_idx + 2
student_depth = layer_idx + 1
autocast_ctx = (
torch.autocast(device_type=device_type, dtype=amp_dtype)
if use_amp
else nullcontext()
)
with autocast_ctx:
teacher_attn_ctx = (
capture_module_output(teacher_attn)
if teacher_attn is not None
else nullcontext({"output": None})
)
teacher_mlp_ctx = (
capture_module_output(teacher_mlp)
if teacher_mlp is not None
else nullcontext({"output": None})
)
with torch.no_grad():
with teacher_attn_ctx as teacher_attn_cache, teacher_mlp_ctx as teacher_mlp_cache:
teacher_hidden = forward_truncated(
teacher_parent,
teacher_layer_attr,
teacher_layers,
teacher_depth,
teacher_ids,
attention_mask=teacher_mask,
)
student_attn_ctx = (
capture_module_output(student_attn)
if student_attn is not None
else nullcontext({"output": None})
)
student_mlp_ctx = (
capture_module_output(student_mlp)
if student_mlp is not None
else nullcontext({"output": None})
)
with student_attn_ctx as student_attn_cache, student_mlp_ctx as student_mlp_cache:
student_hidden = forward_truncated(
student_parent,
student_layer_attr,
virtual_layers,
student_depth,
input_ids,
attention_mask=attention_mask,
)
if teacher_hidden.device != student_hidden.device:
teacher_hidden = teacher_hidden.to(student_hidden.device)
mse_loss = None
if hidden_mse_weight > 0.0:
diff = student_hidden - teacher_hidden
mse_loss = _masked_hidden_mse(diff, attention_mask)
if mse_loss is None:
continue
attn_aux_loss = None
if attn_mse_weight > 0.0:
teacher_attn_hidden = teacher_attn_cache.get("output")
student_attn_hidden = student_attn_cache.get("output")
if teacher_attn_hidden is None or student_attn_hidden is None:
raise RuntimeError(
"Attention-output preservation is enabled, but the forward "
"hook did not capture attention outputs."
)
if teacher_attn_hidden.device != student_attn_hidden.device:
teacher_attn_hidden = teacher_attn_hidden.to(student_attn_hidden.device)
attn_aux_loss = _masked_hidden_mse(
student_attn_hidden - teacher_attn_hidden,
attention_mask,
)
if attn_aux_loss is None:
continue
mlp_aux_loss = None
if mlp_mse_weight > 0.0:
teacher_mlp_hidden = teacher_mlp_cache.get("output")
student_mlp_hidden = student_mlp_cache.get("output")
if teacher_mlp_hidden is None or student_mlp_hidden is None:
raise RuntimeError(
"MLP-output preservation is enabled, but the forward hook "
"did not capture MLP outputs."
)
if teacher_mlp_hidden.device != student_mlp_hidden.device:
teacher_mlp_hidden = teacher_mlp_hidden.to(student_mlp_hidden.device)
mlp_aux_loss = _masked_hidden_mse(
student_mlp_hidden - teacher_mlp_hidden,
attention_mask,
)
if mlp_aux_loss is None:
continue
kl_loss = None
if kl_weight > 0.0:
with torch.no_grad():
teacher_outputs = teacher_model(
input_ids=teacher_ids,
attention_mask=teacher_mask,
use_cache=False,
)
teacher_logits = teacher_outputs.logits
virtual_container = torch.nn.ModuleList(virtual_layers)
with temporary_layers(
student_parent, student_layer_attr, virtual_container
):
student_outputs = student_model(
input_ids=input_ids,
attention_mask=attention_mask,
use_cache=False,
)
student_logits = student_outputs.logits
if teacher_logits.device != student_logits.device:
teacher_logits = teacher_logits.to(student_logits.device)
shift_teacher_logits = teacher_logits[:, :-1, :].contiguous()
shift_student_logits = student_logits[:, :-1, :].contiguous()
shift_mask = attention_mask[:, 1:].contiguous()
log_p_t = F.log_softmax(shift_teacher_logits / kl_temp, dim=-1)
log_p_s = F.log_softmax(shift_student_logits / kl_temp, dim=-1)
p_t = log_p_t.exp()
kl_flat = (p_t * (log_p_t - log_p_s)).sum(dim=-1)
kl_denom = shift_mask.sum()
if kl_denom.item() == 0:
continue
kl_loss = (
kl_flat * shift_mask.to(kl_flat.dtype)
).sum() / kl_denom
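                # Regularizers: eta pulls each lambda toward its Fisher prior and
                # gamma pulls U toward its initialization U0, each weighted per
                # parameter family and normalized by the scaled element count.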
lambda_reg = None
if eta > 0.0:
reg_sum: Optional[torch.Tensor] = None
reg_elems = 0
for orig, safe in reparam_layer._name_map.items():
lam = torch.sigmoid(reparam_layer.gates[safe]).float()
target = gate_lambdas.get(orig)
if target is None:
target_t = 0.5
elif isinstance(target, torch.Tensor):
target_t = target.to(device=lam.device, dtype=lam.dtype)
else:
target_t = float(target)
diff_lam = lam - target_t
family = _classify_param_family(orig)
scale = _family_reg_scale(
family,
attn_scale=attn_reg_scale,
mlp_scale=mlp_reg_scale,
)
if scale <= 0.0:
continue
part = diff_lam.pow(2).sum() * scale
reg_sum = part if reg_sum is None else reg_sum + part
reg_elems += int(float(diff_lam.numel()) * scale)
if reg_elems > 0 and reg_sum is not None:
lambda_reg = reg_sum / float(reg_elems)
u_reg = None
if gamma > 0.0:
reg_sum: Optional[torch.Tensor] = None
reg_elems = 0
for orig, safe in reparam_layer._name_map.items():
u = reparam_layer.u[safe].float()
param_a = params_a.get(orig)
param_b = params_b.get(orig)
if param_a is None or param_b is None or param_b.shape != param_a.shape:
continue
u0 = 0.5 * (param_a.detach().float() - param_b.detach().float())
diff_u = u - u0
family = _classify_param_family(orig)
scale = _family_reg_scale(
family,
attn_scale=attn_reg_scale,
mlp_scale=mlp_reg_scale,
)
if scale <= 0.0:
continue
part = diff_u.pow(2).sum() * scale
reg_sum = part if reg_sum is None else reg_sum + part
reg_elems += int(float(diff_u.numel()) * scale)
if reg_elems > 0 and reg_sum is not None:
u_reg = reg_sum / float(reg_elems)
total_loss = None
if mse_loss is not None:
total_loss = hidden_mse_weight * mse_loss
if attn_aux_loss is not None:
total_loss = attn_mse_weight * attn_aux_loss if total_loss is None else total_loss + (attn_mse_weight * attn_aux_loss)
if mlp_aux_loss is not None:
total_loss = mlp_mse_weight * mlp_aux_loss if total_loss is None else total_loss + (mlp_mse_weight * mlp_aux_loss)
if kl_loss is not None:
total_loss = kl_weight * (kl_temp ** 2) * kl_loss if total_loss is None else total_loss + (kl_weight * (kl_temp ** 2) * kl_loss)
if lambda_reg is not None:
total_loss = eta * lambda_reg if total_loss is None else total_loss + (eta * lambda_reg)
if u_reg is not None:
total_loss = gamma * u_reg if total_loss is None else total_loss + (gamma * u_reg)
if total_loss is None:
continue
if grad_accum > 1:
total_loss = total_loss / grad_accum
if use_scaler:
scaler.scale(total_loss).backward()
else:
total_loss.backward()
if (step + 1) % grad_accum == 0:
if max_grad_norm is not None:
if use_scaler:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(
[*reparam_layer.gates.parameters(), *reparam_layer.u.parameters()],
float(max_grad_norm),
)
if use_scaler:
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
optimizer.zero_grad(set_to_none=True)
if log_steps and (step == 0 or (step + 1) % log_steps == 0):
log_parts = [f"loss={total_loss.item():.6e}"]
if mse_loss is not None:
log_parts.append(f"mse={mse_loss.item():.6e}")
else:
log_parts.append("mse=disabled")
if attn_aux_loss is not None:
log_parts.append(f"attn_mse={attn_aux_loss.item():.6e}")
elif attn_mse_weight > 0.0:
log_parts.append("attn_mse=nan")
if mlp_aux_loss is not None:
log_parts.append(f"mlp_mse={mlp_aux_loss.item():.6e}")
elif mlp_mse_weight > 0.0:
log_parts.append("mlp_mse=nan")
if kl_loss is not None:
log_parts.append(f"kl={kl_loss.item():.6e}")
if lambda_reg is not None:
log_parts.append(f"lam_reg={lambda_reg.item():.6e}")
if u_reg is not None:
log_parts.append(f"u_reg={u_reg.item():.6e}")
print(
f"[reparam] epoch={epoch_idx+1} step={step+1} " + " ".join(log_parts)
)
step += 1
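    # Write the learned merged weights back into layer_a in place; the returned
    # stats and per-parameter lambdas let the caller log the final gate values.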
merged = reparam_layer.materialize_into_layer_a()
final_lambdas = reparam_layer.gate_lambdas()
stats: Dict[str, object] = {
"enabled": True,
"epochs": total_epochs,
"lr": float(args.distill_lr),
"hidden_mse_weight": hidden_mse_weight,
"attn_mse_weight": attn_mse_weight,
"mlp_mse_weight": mlp_mse_weight,
"eta": eta,
"gamma": gamma,
"attn_reg_scale": attn_reg_scale,
"mlp_reg_scale": mlp_reg_scale,
"param_subset": param_subset,
"num_gates": len(final_lambdas),
"num_attn_gates": family_counts["attn"],
"num_mlp_gates": family_counts["mlp"],
"num_other_gates": family_counts["other"],
}
return merged, final_lambdas, stats
class LoRALinear(torch.nn.Module):
def __init__(
self,
base: torch.nn.Linear,
rank: int,
alpha: float,
dropout: float,
) -> None:
super().__init__()
if rank <= 0:
raise ValueError("LoRA rank must be positive")
self.base = base
self.rank = int(rank)
self.alpha = float(alpha)
self.scaling = self.alpha / float(self.rank)
self.enabled = True
if dropout > 0:
self.dropout = torch.nn.Dropout(dropout)
else:
self.dropout = torch.nn.Identity()
self.lora_A = torch.nn.Linear(base.in_features, self.rank, bias=False)
self.lora_B = torch.nn.Linear(self.rank, base.out_features, bias=False)
torch.nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
torch.nn.init.zeros_(self.lora_B.weight)
self.lora_A.to(device=base.weight.device, dtype=base.weight.dtype)
self.lora_B.to(device=base.weight.device, dtype=base.weight.dtype)
self.merged = False
def lora_parameters(self) -> List[torch.nn.Parameter]:
return [*self.lora_A.parameters(), *self.lora_B.parameters()]
def forward(self, x: torch.Tensor) -> torch.Tensor:
result = self.base(x)
if self.merged or not self.enabled:
return result
lora_out = self.lora_B(self.lora_A(self.dropout(x)))
return result + lora_out * self.scaling
def merge(self) -> None:
if self.merged:
return
delta = torch.matmul(self.lora_B.weight, self.lora_A.weight)
delta = delta.to(dtype=self.base.weight.dtype) * self.scaling
self.base.weight.data.add_(delta)
self.merged = True
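# Helpers for navigating a module tree by dotted path, including integer
# indices into ModuleList / Sequential containers and keys of ModuleDict.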
def _get_child_module(parent: torch.nn.Module, part: str) -> torch.nn.Module:
if isinstance(parent, (torch.nn.ModuleList, torch.nn.Sequential)) and part.isdigit():
return parent[int(part)]
if isinstance(parent, torch.nn.ModuleDict):
return parent[part]
return getattr(parent, part)
def _set_child_module(parent: torch.nn.Module, part: str, module: torch.nn.Module) -> None:
if isinstance(parent, (torch.nn.ModuleList, torch.nn.Sequential)) and part.isdigit():
parent[int(part)] = module
return
if isinstance(parent, torch.nn.ModuleDict):
parent[part] = module
return
setattr(parent, part, module)
def _resolve_parent_module(
    root: torch.nn.Module, module_name: str
) -> Optional[Tuple[torch.nn.Module, str]]:
if not module_name:
return None
parts = module_name.split(".")
parent = root
for part in parts[:-1]:
parent = _get_child_module(parent, part)
return parent, parts[-1]
def _resolve_module_by_path(root: torch.nn.Module, module_path: str) -> Optional[torch.nn.Module]:
if not module_path:
return None
parts = [part for part in module_path.split(".") if part]
node = root
for part in parts:
try:
node = _get_child_module(node, part)
except Exception:
return None
return node
def _resolve_layer_container_for_lora(
model: torch.nn.Module, layer_path: Optional[str]
) -> Tuple[Optional[str], Optional[object]]:
"""Resolve transformer layer container with optional auto-detection.
Mirrors the candidate path strategy used elsewhere, so LoRA filtering can work
even when --layer_path is not provided.
"""
if isinstance(layer_path, str) and layer_path and layer_path.lower() != "none":
container = _resolve_module_by_path(model, layer_path)
if container is not None:
try:
list(container)
return layer_path, container
except TypeError:
pass
candidate_paths = [
"model.layers", # LLaMA, Mistral, Qwen2, Gemma
"model.decoder.layers", # OPT
"transformer.h", # GPT-2, GPT-J, Bloom, Falcon
"transformer.blocks", # MPT
"gpt_neox.layers", # GPT-NeoX
"layers", # fallback
]
for path in candidate_paths:
container = _resolve_module_by_path(model, path)
if container is None:
continue
try:
list(container)
except TypeError:
continue
return path, container
return None, None
def _parse_exclude_pairs_local(raw_values, num_pairs: int) -> Set[int]:
if not raw_values or num_pairs <= 0:
return set()
exclude: Set[int] = set()
for item in raw_values:
if item is None:
continue
for part in str(item).split(","):
part = part.strip()
if not part:
continue
try:
idx = int(part)
except ValueError as exc:
raise SystemExit("--exclude_pairs must contain integers.") from exc
if idx < 0:
idx = num_pairs + idx
if 0 <= idx < num_pairs:
exclude.add(idx)
return exclude
def _extract_layer_index_from_module_name(
module_name: str, layer_path: str
) -> Optional[int]:
if not layer_path:
return None
prefix = f"{layer_path}."
if not module_name.startswith(prefix):
return None
rest = module_name[len(prefix) :]
if not rest:
return None
idx_text = rest.split(".", 1)[0]
if not idx_text.isdigit():
return None
return int(idx_text)
def _select_linear_modules_for_lora_targets(
model: torch.nn.Module,
args: argparse.Namespace,
*,
log_tag: str,
) -> Tuple[List[Tuple[str, torch.nn.Linear]], Optional[Set[str]], Set[int], Optional[str]]:
raw_targets = getattr(args, "lora_target_modules", None)
target_modules: Optional[Set[str]] = None
if raw_targets:
target_modules = {str(item) for item in raw_targets if str(item)}
exclude_layer_indices: Set[int] = set()
resolved_layer_path: Optional[str] = None
if bool(getattr(args, "lora_respect_exclude_pairs", False)):
requested_layer_path = getattr(args, "layer_path", None)
resolved_layer_path, layer_container = _resolve_layer_container_for_lora(
model, requested_layer_path
)
if isinstance(layer_container, (torch.nn.ModuleList, list, tuple)):
num_pairs = max(len(layer_container) - 1, 0)
exclude_pairs = _parse_exclude_pairs_local(
getattr(args, "exclude_pairs", None), num_pairs
)
for pair_idx in exclude_pairs:
exclude_layer_indices.add(pair_idx)
exclude_layer_indices.add(pair_idx + 1)
else:
print(
f"[{log_tag}] Warning: --lora_respect_exclude_pairs enabled, but "
f"could not resolve layer path '{requested_layer_path}'."
)
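    # Keep only nn.Linear modules whose leaf name matches --lora_target_modules
    # (when given) and which do not sit inside a layer excluded via --exclude_pairs.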
linear_modules = [
(name, module)
for name, module in model.named_modules()
if isinstance(module, torch.nn.Linear)
and (target_modules is None or name.split(".")[-1] in target_modules)
and (
not exclude_layer_indices
or _extract_layer_index_from_module_name(name, resolved_layer_path or "")
not in exclude_layer_indices
)
]
return linear_modules, target_modules, exclude_layer_indices, resolved_layer_path
def apply_lora_adapters(
model: torch.nn.Module, args: argparse.Namespace
) -> List[LoRALinear]:
if args.lora_rank <= 0:
raise SystemExit("--lora_rank must be > 0 when --lora_epochs > 0")
linear_modules, target_modules, exclude_layer_indices, _ = (
_select_linear_modules_for_lora_targets(model, args, log_tag="lora")
)
if not linear_modules:
raise SystemExit(
"No Linear modules found for LoRA adapters "
"(check --lora_target_modules / --exclude_pairs / --lora_respect_exclude_pairs)."
)
lora_modules: List[LoRALinear] = []
for name, module in linear_modules:
resolved = _resolve_parent_module(model, name)
if resolved is None:
continue
parent, attr = resolved
wrapped = LoRALinear(
base=module,
rank=args.lora_rank,
alpha=args.lora_alpha,
dropout=args.lora_dropout,
)
_set_child_module(parent, attr, wrapped)
lora_modules.append(wrapped)
for param in model.parameters():
param.requires_grad_(False)
for lora_module in lora_modules:
for param in lora_module.lora_parameters():
param.requires_grad_(True)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
percent = 100.0 * trainable_params / max(total_params, 1)
target_note = ""
if target_modules is not None:
target_note = f" target={sorted(target_modules)}"
exclude_note = ""
if exclude_layer_indices:
exclude_note = f" excluded_layers={sorted(exclude_layer_indices)}"
print(
"[lora] Applied adapters to "
f"{len(lora_modules)} linear modules "
f"({trainable_params}/{total_params} trainable, {percent:.4f}%)."
f"{target_note}{exclude_note}"
)
return lora_modules
def merge_lora_adapters(model: torch.nn.Module) -> None:
lora_entries = [
(name, module)
for name, module in model.named_modules()
if isinstance(module, LoRALinear)
]
for name, module in lora_entries:
module.merge()
resolved = _resolve_parent_module(model, name)
if resolved is None:
continue
parent, attr = resolved
_set_child_module(parent, attr, module.base)
def set_lora_enabled(lora_modules: List[LoRALinear], enabled: bool) -> None:
for module in lora_modules:
module.enabled = enabled
def lora_ce_finetune(
model: torch.nn.Module,
dataloader,
eval_tokenizer,
eval_datasets: List[str],
eval_configs: List[Optional[str]],
eval_history: List[Dict[str, object]],
args: argparse.Namespace,
eval_dataloaders: Optional[Dict[str, object]] = None,
progressive_cycle: Optional[int] = None,
progressive_total: Optional[int] = None,
) -> None:
total_epochs = float(args.lora_epochs)
if total_epochs <= 0:
return
use_kl = bool(getattr(args, "lora_kl_enabled", False))
kl_weight = float(getattr(args, "lora_kl_weight", 0.0))
kl_temp = float(getattr(args, "lora_kl_temp", 1.0))
if use_kl:
if kl_weight < 0.0:
raise SystemExit("--lora_kl_weight must be >= 0")
if kl_temp <= 0.0:
raise SystemExit("--lora_kl_temp must be > 0")
if kl_weight == 0.0:
use_kl = False
lora_modules = apply_lora_adapters(model, args)
if not lora_modules:
return
model.train()
lora_params = []
for module in lora_modules:
lora_params.extend(module.lora_parameters())
optimizer = torch.optim.AdamW(
lora_params,
lr=args.lora_lr,
weight_decay=args.lora_weight_decay,
)
device_type = torch.device(args.device).type
amp_dtype = None
if args.dtype == "float16":
amp_dtype = torch.float16
elif args.dtype == "bfloat16":
amp_dtype = torch.bfloat16
use_amp = amp_dtype is not None and device_type == "cuda"
use_scaler = use_amp and amp_dtype == torch.float16
scaler = torch.cuda.amp.GradScaler() if use_scaler else None
full_epochs = int(total_epochs)
fractional = total_epochs - full_epochs
if fractional < 1e-8:
fractional = 0.0
epoch_plan = [(epoch_idx, None) for epoch_idx in range(full_epochs)]
if fractional > 0:
try:
batches_per_epoch = len(dataloader)
except TypeError as exc:
        raise SystemExit(
            "Fractional LoRA epochs require a dataloader with finite length."
        ) from exc
if batches_per_epoch > 0:
frac_batches = int(round(fractional * batches_per_epoch))
if frac_batches <= 0:
frac_batches = 1
epoch_plan.append((full_epochs, frac_batches))
step = 0
for epoch_idx, max_batches in epoch_plan:
if max_batches is None:
epoch_iter = dataloader
else:
epoch_iter = itertools.islice(dataloader, max_batches)
iterator = epoch_iter
if tqdm is not None and _tqdm_enabled():
if progressive_cycle is not None:
if progressive_total is not None:
desc = (
f"LoRA (cycle {progressive_cycle}/{progressive_total}, "
f"epoch {epoch_idx+1})"
)
else:
desc = f"LoRA (cycle {progressive_cycle}, epoch {epoch_idx+1})"
else:
desc = f"LoRA (epoch {epoch_idx+1})"
iterator = tqdm(
epoch_iter,
desc=desc,
unit="batch",
total=max_batches,
)
for batch in iterator:
input_ids = batch[0].to(args.device)
attention_mask = batch[1].to(args.device)
autocast_ctx = (
torch.autocast(device_type=device_type, dtype=amp_dtype)
if use_amp
else nullcontext()
)
with autocast_ctx:
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
use_cache=False,
)
logits = outputs.logits
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = input_ids[:, 1:].contiguous()
shift_mask = attention_mask[:, 1:].contiguous()
ce_flat = F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
reduction="none",
)
ce_denom = shift_mask.sum()
if ce_denom.item() == 0:
continue
ce_loss = (
ce_flat * shift_mask.view(-1).to(ce_flat.dtype)
).sum() / ce_denom
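                # Optional KL regularizer: temporarily disable the adapters to get
                # the pre-LoRA logits, then penalize KL(pre || post) so the
                # fine-tune stays close to the merged model's behavior.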
kl_loss = None
if use_kl:
set_lora_enabled(lora_modules, False)
with torch.no_grad():
base_outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
use_cache=False,
)
base_logits = base_outputs.logits
set_lora_enabled(lora_modules, True)
if base_logits.device != shift_logits.device:
base_logits = base_logits.to(shift_logits.device)
shift_base_logits = base_logits[:, :-1, :].contiguous()
log_p_pre = F.log_softmax(shift_base_logits / kl_temp, dim=-1)
log_p_post = F.log_softmax(shift_logits / kl_temp, dim=-1)
p_pre = log_p_pre.exp()
kl_flat = (p_pre * (log_p_pre - log_p_post)).sum(dim=-1)
kl_loss = (
kl_flat * shift_mask.to(kl_flat.dtype)
).sum() / ce_denom
total_loss = ce_loss
if kl_loss is not None:
total_loss = total_loss + (kl_weight * (kl_temp ** 2) * kl_loss)
if args.lora_grad_accum_steps > 1:
total_loss = total_loss / args.lora_grad_accum_steps
if use_scaler:
scaler.scale(total_loss).backward()
else:
total_loss.backward()
if (step + 1) % args.lora_grad_accum_steps == 0:
if args.lora_max_grad_norm is not None:
if use_scaler:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(
lora_params,
args.lora_max_grad_norm,
)
if use_scaler:
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
optimizer.zero_grad(set_to_none=True)
if args.lora_eval_every and (step + 1) % args.lora_eval_every == 0:
prev_mode = model.training
model.eval()
eval_device = args.eval_device or args.device
if eval_dataloaders is not None:
results = ppl_eval.evaluate_ppl_dataloaders(
model,
eval_dataloaders,
eval_device,
max_batches=args.lora_eval_max_batches,
)
else:
results = ppl_eval.evaluate_ppl_datasets(
model,
eval_tokenizer,
datasets=eval_datasets,
configs=eval_configs,
split=args.eval_split,
text_field=args.eval_text_field,
num_samples=args.eval_num_samples,
seq_len=args.eval_seq_len,
batch_size=args.eval_batch_size or args.batch_size,
device=eval_device,
seed=args.seed,
shuffle=False,
model_family=args.eval_model_family,
add_bos=args.eval_add_bos,
max_batches=args.lora_eval_max_batches,
cache_dir=args.eval_cache_dir,
num_workers=args.eval_num_workers,
)
eval_history.append({"step": step + 1, "ppl": results})
print(f"[lora] eval step={step+1}: {results}")
if prev_mode:
model.train()
if args.lora_log_steps and (
step == 0 or (step + 1) % args.lora_log_steps == 0
):
log_parts = [f"loss={total_loss.item():.6f}"]
if kl_loss is not None:
log_parts.append(f"kl={kl_loss.item():.6f}")
print(
f"[lora] epoch={epoch_idx+1} step={step+1} "
+ " ".join(log_parts)
)
step += 1
merge_lora_adapters(model)
def _masked_kl(
logits_p: torch.Tensor,
logits_q: torch.Tensor,
attention_mask: torch.Tensor,
temp: float,
detach_p: bool = True,
) -> Optional[torch.Tensor]:
shift_mask = attention_mask[:, 1:].contiguous()
denom = shift_mask.sum()
if denom.item() == 0:
return None
p = logits_p[:, :-1, :].contiguous()
q = logits_q[:, :-1, :].contiguous()
if p.device != q.device:
p = p.to(q.device)
# Keep dtype to avoid blowing up memory on large vocab models.
log_p = F.log_softmax(p / temp, dim=-1)
log_q = F.log_softmax(q / temp, dim=-1)
if detach_p:
log_p = log_p.detach()
p_probs = log_p.exp()
kl_flat = (p_probs * (log_p - log_q)).sum(dim=-1)
return (kl_flat * shift_mask.to(kl_flat.dtype)).sum() / denom
def _extract_hidden_tensor(output: object) -> Optional[torch.Tensor]:
if isinstance(output, torch.Tensor):
return output
if isinstance(output, (tuple, list)) and output:
first = output[0]
if isinstance(first, torch.Tensor):
return first
return None
def _grad_l2_norm(grads: List[Optional[torch.Tensor]]) -> float:
total = 0.0
for grad in grads:
if grad is None:
continue
total += float(grad.detach().float().pow(2).sum().item())
if total <= 0.0:
return 0.0
return float(math.sqrt(total))
def _register_forward_pre_hook_with_optional_kwargs(layer, hook):
try:
handle = layer.register_forward_pre_hook(hook, with_kwargs=True)
return handle
except TypeError:
def wrapper(module, inputs):
return hook(module, inputs, None)
return layer.register_forward_pre_hook(wrapper)
def commutator_precondition(
student_model: torch.nn.Module,
student_layers: List[torch.nn.Module],
teacher_model: torch.nn.Module,
dataloader,
dwce_scores: Optional[List[float]],
args: argparse.Namespace,
exclude_pairs: Optional[Set[int]] = None,
progressive_cycle: Optional[int] = None,
progressive_total: Optional[int] = None,
) -> Dict[str, object]:
"""Run commutator-style preconditioning before pair fusion.
Objective on each sampled pair i:
L = T^2 * KL(p_teacher || p_student) + mu * L_interaction(i)
Interaction loss is computed locally on block (i+1):
r1 = B_{i+1}(h_{i+1}) - h_{i+1}
r0 = B_{i+1}(h_i) - h_i
L_interaction = ||r1-r0||^2 (or relative form).
"""
if not bool(getattr(args, "comm_enabled", False)):
return {"enabled": False}
if not student_layers or len(student_layers) < 2:
return {"enabled": False, "reason": "need_at_least_2_layers"}
temp = float(getattr(args, "comm_temp", 2.0))
steps_ratio = float(getattr(args, "comm_steps_ratio", 0.1))
lr_scale = float(getattr(args, "comm_lr_scale", 0.1))
sample_eta = float(getattr(args, "comm_sample_eta", 0.5))
sample_dwce_scale = float(getattr(args, "comm_sample_dwce_scale", 1.0))
top_k = int(getattr(args, "comm_topk", 1))
interaction_mode = str(getattr(args, "comm_interaction_mode", "relative")).strip().lower()
interaction_eps = float(getattr(args, "comm_interaction_eps", 1e-8))
mu_cfg = getattr(args, "comm_mu", None)
mu_auto = bool(getattr(args, "comm_mu_auto", False))
mu_auto_rho = float(getattr(args, "comm_mu_auto_rho", 0.1))
mu_auto_eps = float(getattr(args, "comm_mu_auto_eps", 1e-8))
comm_train_mode = str(getattr(args, "comm_train_mode", "lora")).strip().lower()
log_steps = int(getattr(args, "comm_log_steps", 50))
if temp <= 0.0:
raise SystemExit("--comm_temp must be > 0")
if steps_ratio < 0.0:
raise SystemExit("--comm_steps_ratio must be >= 0")
if lr_scale <= 0.0:
raise SystemExit("--comm_lr_scale must be > 0")
if not (0.0 <= sample_eta <= 1.0):
raise SystemExit("--comm_sample_eta must be in [0, 1]")
if top_k <= 0:
raise SystemExit("--comm_topk must be >= 1")
if interaction_mode not in {"mse", "relative"}:
raise SystemExit("--comm_interaction_mode must be one of: mse, relative")
if comm_train_mode not in {"lora", "full"}:
raise SystemExit("--comm_train_mode must be one of: lora, full")
if interaction_eps <= 0.0:
raise SystemExit("--comm_interaction_eps must be > 0")
if mu_auto_rho < 0.0:
raise SystemExit("--comm_mu_auto_rho must be >= 0")
if mu_auto_eps <= 0.0:
raise SystemExit("--comm_mu_auto_eps must be > 0")
if mu_cfg is None:
base_mu = 0.5 if interaction_mode == "relative" else 0.1
else:
base_mu = float(mu_cfg)
if base_mu < 0.0:
raise SystemExit("--comm_mu must be >= 0")
distill_epochs = float(getattr(args, "distill_epochs", 1.0))
if distill_epochs <= 0.0:
distill_epochs = 1.0
grad_accum = int(getattr(args, "distill_grad_accum_steps", 1))
if grad_accum <= 0:
grad_accum = 1
try:
batches_per_epoch = len(dataloader)
except TypeError as exc:
raise SystemExit(
"Commutator preconditioning requires a finite-length distillation dataloader."
) from exc
if batches_per_epoch <= 0:
return {"enabled": False, "reason": "empty_dataloader"}
full_epochs = int(distill_epochs)
fractional = distill_epochs - full_epochs
if fractional < 1e-8:
fractional = 0.0
total_batches = full_epochs * batches_per_epoch
if fractional > 0.0:
frac_batches = int(round(fractional * batches_per_epoch))
if frac_batches <= 0:
frac_batches = 1
total_batches += frac_batches
distill_opt_steps = int(math.ceil(total_batches / float(grad_accum)))
target_opt_steps = int(round(steps_ratio * distill_opt_steps))
if target_opt_steps <= 0:
target_opt_steps = 1
num_pairs = max(len(student_layers) - 1, 0)
exclude_set = {
int(idx)
for idx in (exclude_pairs or set())
if isinstance(idx, int) and 0 <= int(idx) < num_pairs
}
allowed_pairs = [i for i in range(num_pairs) if i not in exclude_set]
if not allowed_pairs:
return {"enabled": False, "reason": "all_pairs_excluded"}
ranked_pairs = list(allowed_pairs)
if dwce_scores is not None and len(dwce_scores) >= num_pairs:
finite_pairs = []
for idx in allowed_pairs:
value = float(dwce_scores[idx])
if math.isfinite(value):
finite_pairs.append(idx)
if finite_pairs:
ranked_pairs = sorted(finite_pairs, key=lambda i: float(dwce_scores[i]))
else:
ranked_pairs = list(allowed_pairs)
candidate_pairs = ranked_pairs[: min(top_k, len(ranked_pairs))]
if not candidate_pairs:
return {"enabled": False, "reason": "no_candidate_pairs"}
layer_trainable_params: List[List[torch.nn.Parameter]] = []
trainable_params: List[torch.nn.Parameter] = []
if comm_train_mode == "lora":
# LoRA comm preconditioning: update LoRA adapters on receiver layer (i+1).
lora_modules = apply_lora_adapters(student_model, args)
if not lora_modules:
return {"enabled": False, "reason": "no_lora_modules"}
trainable_seen: Set[int] = set()
for module in lora_modules:
for param in module.lora_parameters():
pid = id(param)
if pid in trainable_seen:
continue
trainable_seen.add(pid)
trainable_params.append(param)
for layer in student_layers:
seen: Set[int] = set()
params: List[torch.nn.Parameter] = []
for module in layer.modules():
if not isinstance(module, LoRALinear):
continue
for param in module.lora_parameters():
pid = id(param)
if pid in seen:
continue
seen.add(pid)
params.append(param)
layer_trainable_params.append(params)
else:
# Full-weight comm preconditioning: update full receiver-layer weights.
for layer in student_layers:
seen: Set[int] = set()
params: List[torch.nn.Parameter] = []
for param in layer.parameters():
if not isinstance(param, torch.nn.Parameter):
continue
pid = id(param)
if pid in seen:
continue
seen.add(pid)
params.append(param)
layer_trainable_params.append(params)
candidate_pairs = [
i
for i in candidate_pairs
if (i + 1) < len(layer_trainable_params) and layer_trainable_params[i + 1]
]
if not candidate_pairs:
if comm_train_mode == "lora":
merge_lora_adapters(student_model)
return {"enabled": False, "reason": "no_trainable_receiver_layers"}
if comm_train_mode == "full":
trainable_seen: Set[int] = set()
for pair_idx in candidate_pairs:
for param in layer_trainable_params[pair_idx + 1]:
pid = id(param)
if pid in trainable_seen:
continue
trainable_seen.add(pid)
trainable_params.append(param)
if not trainable_params:
return {"enabled": False, "reason": "no_trainable_receiver_layers"}
# Freeze non-comm params to reduce grad memory.
for param in student_model.parameters():
param.requires_grad_(False)
for param in trainable_params:
param.requires_grad_(True)
if not trainable_params:
if comm_train_mode == "lora":
merge_lora_adapters(student_model)
return {"enabled": False, "reason": "no_trainable_params"}
candidate_probs = torch.full(
(len(candidate_pairs),),
1.0 / float(len(candidate_pairs)),
dtype=torch.float32,
)
if dwce_scores is not None and len(dwce_scores) >= num_pairs and sample_eta > 0.0:
score_vec = torch.tensor(
[float(dwce_scores[i]) for i in candidate_pairs], dtype=torch.float32
)
score_vec = torch.nan_to_num(score_vec, nan=1e9, posinf=1e9, neginf=-1e9)
biased = torch.softmax(-float(sample_dwce_scale) * score_vec, dim=0)
candidate_probs = (1.0 - sample_eta) * candidate_probs + sample_eta * biased
candidate_probs = candidate_probs / candidate_probs.sum()
probs_by_pair = [0.0 for _ in range(num_pairs)]
for pos, pair_idx in enumerate(candidate_pairs):
probs_by_pair[pair_idx] = float(candidate_probs[pos].item())
lr = float(getattr(args, "distill_lr", 1e-4)) * lr_scale
optimizer = torch.optim.AdamW(
trainable_params,
lr=lr,
weight_decay=float(getattr(args, "distill_weight_decay", 0.0)),
)
device_type = torch.device(args.device).type
amp_dtype = None
if args.dtype == "float16":
amp_dtype = torch.float16
elif args.dtype == "bfloat16":
amp_dtype = torch.bfloat16
use_amp = amp_dtype is not None and device_type == "cuda"
use_scaler = use_amp and amp_dtype == torch.float16
scaler = torch.cuda.amp.GradScaler() if use_scaler else None
teacher_device = next(teacher_model.parameters()).device
teacher_model.eval()
student_model.train()
gen = torch.Generator(device="cpu")
seed = int(getattr(args, "seed", 0))
if progressive_cycle is not None:
seed += int(progressive_cycle) * 100003
gen.manual_seed(seed)
opt_step = 0
total_loss_sum = 0.0
anchor_sum = 0.0
interaction_sum = 0.0
mu_sum = 0.0
counted = 0
pair_counts = [0 for _ in range(num_pairs)]
desc = "Comm"
if progressive_cycle is not None:
if progressive_total is not None:
desc = f"Comm (cycle {progressive_cycle}/{progressive_total})"
else:
desc = f"Comm (cycle {progressive_cycle})"
iterator = range(target_opt_steps)
if tqdm is not None and _tqdm_enabled():
iterator = tqdm(iterator, desc=desc, unit="step")
data_iter = iter(dataloader)
autocast_ctx = (
torch.autocast(device_type=device_type, dtype=amp_dtype)
if use_amp
else nullcontext()
)
for _ in iterator:
optimizer.zero_grad(set_to_none=True)
accum_done = 0
while accum_done < grad_accum:
try:
batch = next(data_iter)
except StopIteration:
data_iter = iter(dataloader)
batch = next(data_iter)
input_ids = batch[0].to(args.device)
attention_mask = batch[1].to(args.device)
sampled_pos = int(torch.multinomial(candidate_probs, 1, generator=gen).item())
pair_idx = int(candidate_pairs[sampled_pos])
pair_counts[pair_idx] += 1
receiver_params = layer_trainable_params[pair_idx + 1]
receiver_param_ids = {id(param) for param in receiver_params}
teacher_ids = input_ids.to(teacher_device)
teacher_mask = attention_mask.to(teacher_device)
with torch.no_grad(), autocast_ctx:
teacher_outputs = teacher_model(
input_ids=teacher_ids,
attention_mask=teacher_mask,
use_cache=False,
)
teacher_logits = teacher_outputs.logits
capture: Dict[str, object] = {
"h_l": None,
"h_lp1": None,
"y1": None,
"recv_args": None,
"recv_kwargs": None,
}
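            # A single student forward pass captures, via hooks: the input to
            # layer i (h_l), the input to layer i+1 (h_lp1), and the output of
            # layer i+1 (y1). y0 is then recomputed below by feeding h_l directly
            # into layer i+1 to isolate the interaction term.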
def _hook_l(_module, inputs, _output):
if inputs and isinstance(inputs[0], torch.Tensor):
capture["h_l"] = inputs[0]
def _hook_recv_pre(_module, inputs, kwargs):
capture["recv_args"] = inputs
capture["recv_kwargs"] = kwargs
def _hook_recv(_module, inputs, output):
if inputs and isinstance(inputs[0], torch.Tensor):
capture["h_lp1"] = inputs[0]
capture["y1"] = _extract_hidden_tensor(output)
handles: List[object] = [
student_layers[pair_idx].register_forward_hook(_hook_l),
_register_forward_pre_hook_with_optional_kwargs(
student_layers[pair_idx + 1], _hook_recv_pre
),
student_layers[pair_idx + 1].register_forward_hook(_hook_recv),
]
try:
with autocast_ctx:
student_outputs = student_model(
input_ids=input_ids,
attention_mask=attention_mask,
use_cache=False,
)
student_logits = student_outputs.logits
finally:
for handle in handles:
try:
handle.remove()
except Exception:
pass
with autocast_ctx:
anchor_kl = _masked_kl(
teacher_logits,
student_logits,
attention_mask,
temp=temp,
detach_p=True,
)
if anchor_kl is None:
continue
anchor_loss = (temp ** 2) * anchor_kl
interaction_loss = None
h_l = capture.get("h_l")
h_lp1 = capture.get("h_lp1")
y1 = capture.get("y1")
recv_args = capture.get("recv_args")
recv_kwargs = capture.get("recv_kwargs")
if (
isinstance(h_l, torch.Tensor)
and isinstance(h_lp1, torch.Tensor)
and isinstance(y1, torch.Tensor)
and isinstance(recv_args, tuple)
and len(recv_args) > 0
and isinstance(recv_args[0], torch.Tensor)
):
call_args = list(recv_args)
first_hidden = call_args[0]
h_l_detached = h_l.detach().to(
device=first_hidden.device,
dtype=first_hidden.dtype,
)
call_args[0] = h_l_detached
call_kwargs = dict(recv_kwargs) if isinstance(recv_kwargs, dict) else {}
y0_raw = student_layers[pair_idx + 1](*tuple(call_args), **call_kwargs)
y0 = _extract_hidden_tensor(y0_raw)
if isinstance(y0, torch.Tensor):
if y0.device != y1.device:
y0 = y0.to(y1.device)
h_lp1_detached = h_lp1.detach().to(device=y1.device, dtype=y1.dtype)
h_l_for_res = h_l.detach().to(device=y0.device, dtype=y0.dtype)
r1 = y1 - h_lp1_detached
r0 = y0 - h_l_for_res
mask = attention_mask.to(dtype=r1.dtype)
mask_sum = mask.sum()
if mask_sum.item() > 0:
if interaction_mode == "relative":
num = (r1 - r0).float().pow(2).sum(dim=-1)
den = r1.float().pow(2).sum(dim=-1) + float(interaction_eps)
ratio = (num / den) * mask.to(num.dtype)
interaction_loss = ratio.sum() / (mask_sum + 1e-8)
else:
denom = mask_sum * r1.size(-1)
if denom.item() > 0:
interaction_loss = (
(r1 - r0).pow(2) * mask.unsqueeze(-1)
).sum() / denom
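            # Auto-tune mu: scale the interaction term so its gradient norm on
            # the receiver parameters is roughly mu_auto_rho times the anchor-KL
            # gradient norm.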
mu_effective = float(base_mu)
if (
mu_auto
and interaction_loss is not None
and receiver_params
and mu_auto_rho > 0.0
):
anchor_grads = torch.autograd.grad(
anchor_loss,
receiver_params,
retain_graph=True,
allow_unused=True,
)
interaction_grads = torch.autograd.grad(
interaction_loss,
receiver_params,
retain_graph=True,
allow_unused=True,
)
anchor_norm = _grad_l2_norm(list(anchor_grads))
interaction_norm = _grad_l2_norm(list(interaction_grads))
if interaction_norm > 0.0:
mu_effective = float(
mu_auto_rho
* (anchor_norm / (interaction_norm + float(mu_auto_eps)))
)
else:
mu_effective = float(base_mu)
if not math.isfinite(mu_effective):
mu_effective = float(base_mu)
total_loss = anchor_loss
if interaction_loss is not None:
total_loss = total_loss + (float(mu_effective) * interaction_loss)
if grad_accum > 1:
total_loss = total_loss / float(grad_accum)
if use_scaler:
scaler.scale(total_loss).backward()
else:
total_loss.backward()
# Only the sampled receiver layer updates on this micro-batch.
for param in trainable_params:
if id(param) in receiver_param_ids:
continue
if param.grad is not None:
if comm_train_mode == "lora":
param.grad.zero_()
else:
param.grad = None
total_loss_sum += float(total_loss.detach().float().item())
anchor_sum += float(anchor_loss.detach().float().item())
if interaction_loss is not None:
interaction_sum += float(interaction_loss.detach().float().item())
mu_sum += float(mu_effective)
counted += 1
accum_done += 1
if args.distill_max_grad_norm is not None:
if use_scaler:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(
trainable_params,
float(args.distill_max_grad_norm),
)
if use_scaler:
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
opt_step += 1
if log_steps and (opt_step == 1 or opt_step % log_steps == 0):
denom = max(counted, 1)
print(
f"[comm] step={opt_step}/{target_opt_steps} "
f"loss={total_loss_sum/denom:.6f} "
f"anchor={anchor_sum/denom:.6f} "
f"int={interaction_sum/denom:.6f} "
f"mu={mu_sum/denom:.6f}"
)
if comm_train_mode == "lora":
merge_lora_adapters(student_model)
stats: Dict[str, object] = {
"enabled": True,
"train_mode": comm_train_mode,
"opt_steps": int(target_opt_steps),
"grad_accum_steps": int(grad_accum),
"lr": float(lr),
"temp": float(temp),
"steps_ratio": float(steps_ratio),
"lr_scale": float(lr_scale),
"interaction_mode": interaction_mode,
"interaction_eps": float(interaction_eps),
"mu": float(base_mu),
"mu_auto": bool(mu_auto),
"mu_auto_rho": float(mu_auto_rho),
"mu_auto_eps": float(mu_auto_eps),
"sample_eta": float(sample_eta),
"sample_dwce_scale": float(sample_dwce_scale),
"topk": int(top_k),
"candidate_pairs": [int(i) for i in candidate_pairs],
"trainable_params": int(sum(int(param.numel()) for param in trainable_params)),
}
total_samples = int(sum(pair_counts))
probs_list = [float(x) for x in probs_by_pair]
freqs = (
[float(c) / float(total_samples) for c in pair_counts]
if total_samples > 0
else [0.0 for _ in pair_counts]
)
top_show = min(10, num_pairs)
top_indices = sorted(range(num_pairs), key=lambda i: pair_counts[i], reverse=True)[:top_show]
top_pairs = [
{
"pair": int(i),
"count": int(pair_counts[i]),
"freq": float(freqs[i]),
"prob": float(probs_list[i]) if i < len(probs_list) else None,
}
for i in top_indices
if pair_counts[i] > 0
]
stats["pair_selection"] = {
"num_pairs": int(num_pairs),
"excluded_pairs": sorted(exclude_set),
"candidate_pairs": [int(i) for i in candidate_pairs],
"total_samples": total_samples,
"unique_pairs": int(sum(1 for c in pair_counts if c > 0)),
"counts": [int(c) for c in pair_counts],
"freqs": freqs,
"probs": probs_list,
"top_pairs": top_pairs,
}
if total_samples > 0 and top_pairs:
top_str = ", ".join(
f"{entry['pair']}-{entry['pair'] + 1}: {entry['count']} "
f"(obs={entry['freq']:.3f}, exp={entry['prob']:.3f})"
for entry in top_pairs
if entry.get("prob") is not None
)
if not top_str:
top_str = ", ".join(
f"{entry['pair']}-{entry['pair'] + 1}: {entry['count']} "
f"(obs={entry['freq']:.3f})"
for entry in top_pairs
)
print(
f"[comm] Pair sampling stats: total={total_samples} "
f"unique={stats['pair_selection']['unique_pairs']}/{num_pairs} "
f"top={top_str}"
)
if counted > 0:
stats["avg_loss"] = float(total_loss_sum / float(counted))
stats["avg_anchor"] = float(anchor_sum / float(counted))
stats["avg_interaction"] = float(interaction_sum / float(counted))
stats["avg_mu"] = float(mu_sum / float(counted))
return stats