#!/usr/bin/env python3 """Distillation helpers for fuse_layers.""" import argparse import itertools import math import os from contextlib import contextmanager, nullcontext from typing import Dict, List, Optional, Set, Tuple import torch import torch.nn.functional as F try: import ppl_eval except Exception as exc: # pragma: no cover - optional dependency raise SystemExit("ppl_eval.py is required (missing or invalid)") from exc try: from tqdm import tqdm except Exception: # pragma: no cover - optional dependency tqdm = None try: from torch.func import functional_call as _functional_call except Exception: # pragma: no cover - depends on torch version try: from torch.nn.utils.stateless import functional_call as _functional_call except Exception: # pragma: no cover - depends on torch version _functional_call = None from fuse_layers_model import find_attention_module, find_mlp_module def _tqdm_enabled() -> bool: value = os.environ.get("DISABLE_TQDM", os.environ.get("TQDM_DISABLE", "0")) return value.strip().lower() not in {"1", "true", "yes", "on"} @contextmanager def temporary_layers(parent: object, name: str, new_layers: torch.nn.Module): original = getattr(parent, name) setattr(parent, name, new_layers) try: yield finally: setattr(parent, name, original) @contextmanager def temporary_norm(parent: object): if hasattr(parent, "norm"): original = getattr(parent, "norm") setattr(parent, "norm", torch.nn.Identity()) try: yield finally: setattr(parent, "norm", original) else: yield def forward_truncated( parent: torch.nn.Module, layer_attr: str, layers: List[torch.nn.Module], upto: int, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: truncated = torch.nn.ModuleList(layers[:upto]) with temporary_layers(parent, layer_attr, truncated), temporary_norm(parent): outputs = parent( input_ids=input_ids, attention_mask=attention_mask, use_cache=False, ) if hasattr(outputs, "last_hidden_state"): return outputs.last_hidden_state return outputs[0] def _masked_hidden_mse(diff: torch.Tensor, attention_mask: torch.Tensor) -> Optional[torch.Tensor]: diff_f = diff.float() mask = attention_mask.to(device=diff.device, dtype=torch.float32) denom = mask.sum() * diff_f.size(-1) if denom.item() == 0: return None return (diff_f.pow(2) * mask.unsqueeze(-1)).sum() / denom def _extract_hidden_like(output) -> Optional[torch.Tensor]: if torch.is_tensor(output): return output if isinstance(output, (tuple, list)) and output: first = output[0] if torch.is_tensor(first): return first if hasattr(output, "last_hidden_state"): hidden = getattr(output, "last_hidden_state") if torch.is_tensor(hidden): return hidden return None @contextmanager def capture_module_output(module: torch.nn.Module): cache: Dict[str, Optional[torch.Tensor]] = {"output": None} def hook(_module, _inputs, output): cache["output"] = _extract_hidden_like(output) handle = module.register_forward_hook(hook) try: yield cache finally: handle.remove() _ATTN_NAME_FRAGMENTS = ( "self_attn.", "attn.", "attention.", "q_proj", "k_proj", "v_proj", "o_proj", "q_norm", "k_norm", ) _MLP_NAME_FRAGMENTS = ( "mlp.", "ffn.", "feed_forward", "feedforward", "gate_proj", "up_proj", "down_proj", "fc1", "fc2", "dense_h_to_4h", "dense_4h_to_h", "w1", "w2", "w3", ) def _classify_param_family(name: str) -> str: lowered = name.lower() if any(fragment in lowered for fragment in _MLP_NAME_FRAGMENTS): return "mlp" if any(fragment in lowered for fragment in _ATTN_NAME_FRAGMENTS): return "attn" return "other" def _family_reg_scale(family: str, attn_scale: float, mlp_scale: float) -> float: if family == "attn": return attn_scale if family == "mlp": return mlp_scale return 1.0 def _subset_allows_param(name: str, subset: str) -> bool: if subset == "all": return True return _classify_param_family(name) == subset def _gate_logit_from_prior(prior: torch.Tensor) -> torch.Tensor: # Stable logit: log(p) - log(1 - p). return torch.log(prior) - torch.log1p(-prior) def _build_gate_priors( layer_a: torch.nn.Module, layer_b: torch.nn.Module, fisher_a: Dict[str, object], fisher_b: Dict[str, object], num_batches: int, numels_a: Dict[str, int], numels_b: Dict[str, int], fisher_mode: str, eps: float, clamp_eps: float, ) -> Dict[str, torch.Tensor]: """Return lambda priors for parameters that can be merged.""" priors: Dict[str, torch.Tensor] = {} params_b = {name: param for name, param in layer_b.named_parameters()} for name, param_a in layer_a.named_parameters(): param_b = params_b.get(name) if param_b is None or param_b.shape != param_a.shape: continue if fisher_mode == "param": fa = fisher_a[name] / max(num_batches, 1) fb = fisher_b[name] / max(num_batches, 1) denom = fa + fb if not isinstance(denom, torch.Tensor): denom = torch.tensor(float(denom)) # If Fisher is uninformative, default to symmetric init. prior = torch.where( denom > eps, fa / (denom + eps), torch.full_like(denom, 0.5), ) prior = prior.clamp(clamp_eps, 1.0 - clamp_eps) priors[name] = prior else: fa = fisher_a[name] / (max(num_batches, 1) * numels_a[name]) fb = fisher_b[name] / (max(num_batches, 1) * numels_b[name]) denom = fa + fb if denom <= eps: prior_val = 0.5 else: prior_val = float(fa / (denom + eps)) prior_val = min(max(prior_val, clamp_eps), 1.0 - clamp_eps) priors[name] = torch.tensor(prior_val, dtype=torch.float32) return priors def compute_fisher_gate_priors( layer_a: torch.nn.Module, layer_b: torch.nn.Module, fisher_a: Dict[str, object], fisher_b: Dict[str, object], num_batches: int, numels_a: Dict[str, int], numels_b: Dict[str, int], fisher_mode: str, eps: float, clamp_eps: float = 1e-4, ) -> Dict[str, torch.Tensor]: """Compute Fisher prior gate lambdas (lambda_prior) for mergeable parameters.""" return _build_gate_priors( layer_a=layer_a, layer_b=layer_b, fisher_a=fisher_a, fisher_b=fisher_b, num_batches=num_batches, numels_a=numels_a, numels_b=numels_b, fisher_mode=fisher_mode, eps=eps, clamp_eps=clamp_eps, ) class ReparamMergedLayer(torch.nn.Module): """Virtual layer that merges parameters via W0/U reparameterization. Parameters of layer_a/layer_b are treated as frozen (detached). We train: - gate logits s (lambda = sigmoid(s)) - U (initialized as U0 = (W_a - W_b) / 2) Forward uses: W_merge = W0 + (2 * lambda - 1) * U where W0 = (W_a + W_b) / 2 """ def __init__( self, layer_a: torch.nn.Module, layer_b: torch.nn.Module, gate_targets: Dict[str, object], param_subset: str = "all", clamp_eps: float = 1e-4, ) -> None: super().__init__() self.layer_a = layer_a self.layer_b = layer_b self.param_subset = param_subset self._name_map: Dict[str, str] = {} self.gates = torch.nn.ParameterDict() self.u = torch.nn.ParameterDict() params_b = {name: param for name, param in layer_b.named_parameters()} try: device = next(layer_a.parameters()).device except StopIteration: device = torch.device("cpu") for name, param_a in layer_a.named_parameters(): param_b = params_b.get(name) if param_b is None or param_b.shape != param_a.shape: continue if not _subset_allows_param(name, self.param_subset): continue target = gate_targets.get(name) if target is None: target_t = torch.tensor(0.5, device=device, dtype=torch.float32) elif isinstance(target, torch.Tensor): target_t = target.detach().to(device=device, dtype=torch.float32) else: target_t = torch.tensor(float(target), device=device, dtype=torch.float32) target_t = target_t.clamp(clamp_eps, 1.0 - clamp_eps) s0 = _gate_logit_from_prior(target_t) u0 = 0.5 * (param_a.detach().float() - param_b.detach().float()) safe = name.replace(".", "__") if safe in self.gates: safe = f"{safe}_{len(self.gates)}" self._name_map[name] = safe self.gates[safe] = torch.nn.Parameter(s0) self.u[safe] = torch.nn.Parameter(u0) def __getattr__(self, name: str): # Delegate model-specific attributes (e.g. Qwen's `attention_type`) to # the underlying layer so the parent forward doesn't break. try: return super().__getattr__(name) except AttributeError as exc: try: layer_a = super().__getattr__("layer_a") if hasattr(layer_a, name): return getattr(layer_a, name) except AttributeError: pass try: layer_b = super().__getattr__("layer_b") if hasattr(layer_b, name): return getattr(layer_b, name) except AttributeError: pass raise exc def _safe_for(self, orig: str) -> Optional[str]: return self._name_map.get(orig) def gate_lambdas(self) -> Dict[str, torch.Tensor]: out: Dict[str, torch.Tensor] = {} for orig, safe in self._name_map.items(): out[orig] = torch.sigmoid(self.gates[safe]).detach() return out def _merged_params(self) -> Dict[str, torch.Tensor]: params_a = {name: p for name, p in self.layer_a.named_parameters()} params_b = {name: p for name, p in self.layer_b.named_parameters()} merged_params: Dict[str, torch.Tensor] = {} for name, param_a in params_a.items(): param_b = params_b.get(name) safe = self._safe_for(name) if safe is None or param_b is None or param_b.shape != param_a.shape: merged_params[name] = param_a.detach() continue lam = torch.sigmoid(self.gates[safe]).to(dtype=torch.float32) u = self.u[safe].to(dtype=torch.float32) w0 = 0.5 * (param_a.detach().float() + param_b.detach().float()) merged = w0 + (2.0 * lam - 1.0) * u merged_params[name] = merged.to(dtype=param_a.dtype) return merged_params def forward(self, *args, **kwargs): if _functional_call is None: raise RuntimeError( "Reparam distillation requires torch.func.functional_call" ) merged_params = self._merged_params() return _functional_call(self.layer_a, merged_params, args, kwargs) def materialize_into_layer_a(self) -> int: merged = 0 params_a = {name: p for name, p in self.layer_a.named_parameters()} params_b = {name: p for name, p in self.layer_b.named_parameters()} with torch.no_grad(): for orig, safe in self._name_map.items(): param_a = params_a.get(orig) param_b = params_b.get(orig) if param_a is None or param_b is None or param_b.shape != param_a.shape: continue lam = torch.sigmoid(self.gates[safe]).to(device=param_a.device, dtype=torch.float32) u = self.u[safe].to(device=param_a.device, dtype=torch.float32) w0 = 0.5 * (param_a.detach().float() + param_b.detach().float()) merged_param = w0 + (2.0 * lam - 1.0) * u param_a.copy_(merged_param.to(dtype=param_a.dtype)) merged += 1 return merged def distill_reparam_merge( student_model: torch.nn.Module, student_parent: object, student_layer_attr: str, student_layers: List[torch.nn.Module], teacher_model: torch.nn.Module, teacher_parent: object, teacher_layer_attr: str, teacher_layers: List[torch.nn.Module], layer_idx: int, gate_lambdas: Dict[str, object], dataloader, args: argparse.Namespace, progressive_cycle: Optional[int] = None, progressive_total: Optional[int] = None, ) -> Tuple[int, Dict[str, torch.Tensor], Dict[str, object]]: """Reparameterized distillation that materializes a fused layer into layer_a. Trains U and gate logits s (lambda = sigmoid(s)) using: - composition MSE + distill-KL - eta * ||lambda - lambda_gate||^2 + gamma * ||U - U0||^2 """ total_epochs = float(args.distill_epochs) hidden_mse_weight = float(getattr(args, "distill_hidden_mse_weight", 1.0)) if hidden_mse_weight < 0.0: raise SystemExit("--distill_hidden_mse_weight must be >= 0") attn_mse_weight = float(getattr(args, "distill_attn_mse_weight", 0.0)) if attn_mse_weight < 0.0: raise SystemExit("--distill_attn_mse_weight must be >= 0") mlp_mse_weight = float(getattr(args, "distill_mlp_mse_weight", 0.0)) if mlp_mse_weight < 0.0: raise SystemExit("--distill_mlp_mse_weight must be >= 0") param_subset = str(getattr(args, "reparam_param_subset", "all")) if param_subset not in {"all", "mlp", "attn"}: raise SystemExit("--reparam_param_subset must be one of: all, mlp, attn") kl_weight = float(args.distill_kl_weight) kl_temp = float(args.distill_kl_temp) if kl_weight < 0.0: raise SystemExit("--distill_kl_weight must be >= 0") if kl_temp <= 0.0: raise SystemExit("--distill_kl_temp must be > 0") eta = float(getattr(args, "reparam_eta", 0.0)) gamma = float(getattr(args, "reparam_gamma", 0.0)) if eta < 0.0: raise SystemExit("--reparam_eta must be >= 0") if gamma < 0.0: raise SystemExit("--reparam_gamma must be >= 0") attn_reg_scale = float(getattr(args, "reparam_attn_reg_scale", 1.0)) mlp_reg_scale = float(getattr(args, "reparam_mlp_reg_scale", 1.0)) if attn_reg_scale < 0.0: raise SystemExit("--reparam_attn_reg_scale must be >= 0") if mlp_reg_scale < 0.0: raise SystemExit("--reparam_mlp_reg_scale must be >= 0") if ( total_epochs > 0.0 and hidden_mse_weight == 0.0 and attn_mse_weight == 0.0 and mlp_mse_weight == 0.0 and kl_weight == 0.0 and eta == 0.0 and gamma == 0.0 ): raise SystemExit( "Reparam distillation has no active loss terms. " "Enable hidden/attention/MLP MSE, KL, or at least one reparam regularizer." ) if not gate_lambdas: raise SystemExit("Reparam distillation requires non-empty gate lambdas.") layer_a = student_layers[layer_idx] layer_b = student_layers[layer_idx + 1] reparam_layer = ReparamMergedLayer( layer_a, layer_b, gate_lambdas, param_subset=param_subset, clamp_eps=1e-4, ) if not reparam_layer._name_map: raise RuntimeError( "No mergeable parameters found for reparam distillation under " f"--reparam_param_subset={param_subset!r}." ) teacher_attn = None student_attn = None if attn_mse_weight > 0.0: try: teacher_attn = find_attention_module(teacher_layers[layer_idx + 1]) student_attn = find_attention_module(reparam_layer.layer_a) except ValueError as exc: raise SystemExit( "Attention-output preservation was requested but an attention module " f"could not be resolved: {exc}" ) from exc teacher_mlp = None student_mlp = None if mlp_mse_weight > 0.0: try: teacher_mlp = find_mlp_module(teacher_layers[layer_idx + 1]) student_mlp = find_mlp_module(reparam_layer.layer_a) except ValueError as exc: raise SystemExit( "MLP-output preservation was requested but an MLP module could not be " f"resolved: {exc}" ) from exc # Virtual layer list: replace layer_a with reparam layer and remove layer_b. virtual_layers = list(student_layers) virtual_layers[layer_idx] = reparam_layer del virtual_layers[layer_idx + 1] # Only (U, s) are trainable. for param in student_model.parameters(): param.requires_grad_(False) for param in reparam_layer.gates.parameters(): param.requires_grad_(True) for param in reparam_layer.u.parameters(): param.requires_grad_(True) do_train = total_epochs > 0.0 if do_train: teacher_model.eval() student_model.train() # Rough memory heads-up (esp. when --fisher_mode param makes per-element gates). total_gate_elems = sum(int(p.numel()) for p in reparam_layer.gates.parameters()) total_u_elems = sum(int(p.numel()) for p in reparam_layer.u.parameters()) gate_mib = total_gate_elems * 4.0 / (1024.0 * 1024.0) u_mib = total_u_elems * 4.0 / (1024.0 * 1024.0) family_counts: Dict[str, int] = {"attn": 0, "mlp": 0, "other": 0} for orig in reparam_layer._name_map: family_counts[_classify_param_family(orig)] += 1 print( f"[reparam] subset={param_subset} gates={len(reparam_layer.gates)} " f"(attn={family_counts['attn']}, mlp={family_counts['mlp']}, other={family_counts['other']}) " f"elems={total_gate_elems} (~{gate_mib:.1f} MiB), " f"U_elems={total_u_elems} (~{u_mib:.1f} MiB; +optimizer state)" ) optimizer = None if do_train: optimizer = torch.optim.AdamW( [*reparam_layer.gates.parameters(), *reparam_layer.u.parameters()], lr=float(args.distill_lr), weight_decay=float(args.distill_weight_decay), ) device_type = torch.device(args.device).type amp_dtype = None if args.dtype == "float16": amp_dtype = torch.float16 elif args.dtype == "bfloat16": amp_dtype = torch.bfloat16 use_amp = do_train and amp_dtype is not None and device_type == "cuda" use_scaler = use_amp and amp_dtype == torch.float16 scaler = torch.cuda.amp.GradScaler() if use_scaler else None full_epochs = int(total_epochs) if do_train else 0 fractional = (total_epochs - full_epochs) if do_train else 0.0 if fractional < 1e-8: fractional = 0.0 epoch_plan = [(epoch_idx, None) for epoch_idx in range(full_epochs)] if fractional > 0: try: batches_per_epoch = len(dataloader) except TypeError as exc: raise SystemExit( "Fractional distill epochs require a dataloader with finite length." ) from exc if batches_per_epoch > 0: frac_batches = int(round(fractional * batches_per_epoch)) if frac_batches <= 0: frac_batches = 1 epoch_plan.append((full_epochs, frac_batches)) grad_accum = int(getattr(args, "distill_grad_accum_steps", 1)) if grad_accum <= 0: raise SystemExit("--distill_grad_accum_steps must be >= 1") log_steps = int(getattr(args, "distill_log_steps", 100)) max_grad_norm = getattr(args, "distill_max_grad_norm", 1.0) params_a = {name: p for name, p in layer_a.named_parameters()} params_b = {name: p for name, p in layer_b.named_parameters()} step = 0 for epoch_idx, max_batches in epoch_plan: if max_batches is None: epoch_iter = dataloader else: epoch_iter = itertools.islice(dataloader, max_batches) iterator = epoch_iter if tqdm is not None and _tqdm_enabled(): if progressive_cycle is not None: if progressive_total is not None: desc = ( f"Reparam (cycle {progressive_cycle}/{progressive_total}, " f"epoch {epoch_idx+1})" ) else: desc = f"Reparam (cycle {progressive_cycle}, epoch {epoch_idx+1})" else: desc = f"Reparam (epoch {epoch_idx+1})" iterator = tqdm(epoch_iter, desc=desc, unit="batch", total=max_batches) for batch in iterator: input_ids = batch[0].to(args.device) attention_mask = batch[1].to(args.device) teacher_ids = input_ids.to(args.distill_teacher_device or args.device) teacher_mask = attention_mask.to(args.distill_teacher_device or args.device) teacher_depth = layer_idx + 2 student_depth = layer_idx + 1 autocast_ctx = ( torch.autocast(device_type=device_type, dtype=amp_dtype) if use_amp else nullcontext() ) with autocast_ctx: teacher_attn_ctx = ( capture_module_output(teacher_attn) if teacher_attn is not None else nullcontext({"output": None}) ) teacher_mlp_ctx = ( capture_module_output(teacher_mlp) if teacher_mlp is not None else nullcontext({"output": None}) ) with torch.no_grad(): with teacher_attn_ctx as teacher_attn_cache, teacher_mlp_ctx as teacher_mlp_cache: teacher_hidden = forward_truncated( teacher_parent, teacher_layer_attr, teacher_layers, teacher_depth, teacher_ids, attention_mask=teacher_mask, ) student_attn_ctx = ( capture_module_output(student_attn) if student_attn is not None else nullcontext({"output": None}) ) student_mlp_ctx = ( capture_module_output(student_mlp) if student_mlp is not None else nullcontext({"output": None}) ) with student_attn_ctx as student_attn_cache, student_mlp_ctx as student_mlp_cache: student_hidden = forward_truncated( student_parent, student_layer_attr, virtual_layers, student_depth, input_ids, attention_mask=attention_mask, ) if teacher_hidden.device != student_hidden.device: teacher_hidden = teacher_hidden.to(student_hidden.device) mse_loss = None if hidden_mse_weight > 0.0: diff = student_hidden - teacher_hidden mse_loss = _masked_hidden_mse(diff, attention_mask) if mse_loss is None: continue attn_aux_loss = None if attn_mse_weight > 0.0: teacher_attn_hidden = teacher_attn_cache.get("output") student_attn_hidden = student_attn_cache.get("output") if teacher_attn_hidden is None or student_attn_hidden is None: raise RuntimeError( "Attention-output preservation is enabled, but the forward " "hook did not capture attention outputs." ) if teacher_attn_hidden.device != student_attn_hidden.device: teacher_attn_hidden = teacher_attn_hidden.to(student_attn_hidden.device) attn_aux_loss = _masked_hidden_mse( student_attn_hidden - teacher_attn_hidden, attention_mask, ) if attn_aux_loss is None: continue mlp_aux_loss = None if mlp_mse_weight > 0.0: teacher_mlp_hidden = teacher_mlp_cache.get("output") student_mlp_hidden = student_mlp_cache.get("output") if teacher_mlp_hidden is None or student_mlp_hidden is None: raise RuntimeError( "MLP-output preservation is enabled, but the forward hook " "did not capture MLP outputs." ) if teacher_mlp_hidden.device != student_mlp_hidden.device: teacher_mlp_hidden = teacher_mlp_hidden.to(student_mlp_hidden.device) mlp_aux_loss = _masked_hidden_mse( student_mlp_hidden - teacher_mlp_hidden, attention_mask, ) if mlp_aux_loss is None: continue kl_loss = None if kl_weight > 0.0: with torch.no_grad(): teacher_outputs = teacher_model( input_ids=teacher_ids, attention_mask=teacher_mask, use_cache=False, ) teacher_logits = teacher_outputs.logits virtual_container = torch.nn.ModuleList(virtual_layers) with temporary_layers( student_parent, student_layer_attr, virtual_container ): student_outputs = student_model( input_ids=input_ids, attention_mask=attention_mask, use_cache=False, ) student_logits = student_outputs.logits if teacher_logits.device != student_logits.device: teacher_logits = teacher_logits.to(student_logits.device) shift_teacher_logits = teacher_logits[:, :-1, :].contiguous() shift_student_logits = student_logits[:, :-1, :].contiguous() shift_mask = attention_mask[:, 1:].contiguous() log_p_t = F.log_softmax(shift_teacher_logits / kl_temp, dim=-1) log_p_s = F.log_softmax(shift_student_logits / kl_temp, dim=-1) p_t = log_p_t.exp() kl_flat = (p_t * (log_p_t - log_p_s)).sum(dim=-1) kl_denom = shift_mask.sum() if kl_denom.item() == 0: continue kl_loss = ( kl_flat * shift_mask.to(kl_flat.dtype) ).sum() / kl_denom lambda_reg = None if eta > 0.0: reg_sum: Optional[torch.Tensor] = None reg_elems = 0 for orig, safe in reparam_layer._name_map.items(): lam = torch.sigmoid(reparam_layer.gates[safe]).float() target = gate_lambdas.get(orig) if target is None: target_t = 0.5 elif isinstance(target, torch.Tensor): target_t = target.to(device=lam.device, dtype=lam.dtype) else: target_t = float(target) diff_lam = lam - target_t family = _classify_param_family(orig) scale = _family_reg_scale( family, attn_scale=attn_reg_scale, mlp_scale=mlp_reg_scale, ) if scale <= 0.0: continue part = diff_lam.pow(2).sum() * scale reg_sum = part if reg_sum is None else reg_sum + part reg_elems += int(float(diff_lam.numel()) * scale) if reg_elems > 0 and reg_sum is not None: lambda_reg = reg_sum / float(reg_elems) u_reg = None if gamma > 0.0: reg_sum: Optional[torch.Tensor] = None reg_elems = 0 for orig, safe in reparam_layer._name_map.items(): u = reparam_layer.u[safe].float() param_a = params_a.get(orig) param_b = params_b.get(orig) if param_a is None or param_b is None or param_b.shape != param_a.shape: continue u0 = 0.5 * (param_a.detach().float() - param_b.detach().float()) diff_u = u - u0 family = _classify_param_family(orig) scale = _family_reg_scale( family, attn_scale=attn_reg_scale, mlp_scale=mlp_reg_scale, ) if scale <= 0.0: continue part = diff_u.pow(2).sum() * scale reg_sum = part if reg_sum is None else reg_sum + part reg_elems += int(float(diff_u.numel()) * scale) if reg_elems > 0 and reg_sum is not None: u_reg = reg_sum / float(reg_elems) total_loss = None if mse_loss is not None: total_loss = hidden_mse_weight * mse_loss if attn_aux_loss is not None: total_loss = attn_mse_weight * attn_aux_loss if total_loss is None else total_loss + (attn_mse_weight * attn_aux_loss) if mlp_aux_loss is not None: total_loss = mlp_mse_weight * mlp_aux_loss if total_loss is None else total_loss + (mlp_mse_weight * mlp_aux_loss) if kl_loss is not None: total_loss = kl_weight * (kl_temp ** 2) * kl_loss if total_loss is None else total_loss + (kl_weight * (kl_temp ** 2) * kl_loss) if lambda_reg is not None: total_loss = eta * lambda_reg if total_loss is None else total_loss + (eta * lambda_reg) if u_reg is not None: total_loss = gamma * u_reg if total_loss is None else total_loss + (gamma * u_reg) if total_loss is None: continue if grad_accum > 1: total_loss = total_loss / grad_accum if use_scaler: scaler.scale(total_loss).backward() else: total_loss.backward() if (step + 1) % grad_accum == 0: if max_grad_norm is not None: if use_scaler: scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_( [*reparam_layer.gates.parameters(), *reparam_layer.u.parameters()], float(max_grad_norm), ) if use_scaler: scaler.step(optimizer) scaler.update() else: optimizer.step() optimizer.zero_grad(set_to_none=True) if log_steps and (step == 0 or (step + 1) % log_steps == 0): log_parts = [f"loss={total_loss.item():.6e}"] if mse_loss is not None: log_parts.append(f"mse={mse_loss.item():.6e}") else: log_parts.append("mse=disabled") if attn_aux_loss is not None: log_parts.append(f"attn_mse={attn_aux_loss.item():.6e}") elif attn_mse_weight > 0.0: log_parts.append("attn_mse=nan") if mlp_aux_loss is not None: log_parts.append(f"mlp_mse={mlp_aux_loss.item():.6e}") elif mlp_mse_weight > 0.0: log_parts.append("mlp_mse=nan") if kl_loss is not None: log_parts.append(f"kl={kl_loss.item():.6e}") if lambda_reg is not None: log_parts.append(f"lam_reg={lambda_reg.item():.6e}") if u_reg is not None: log_parts.append(f"u_reg={u_reg.item():.6e}") print( f"[reparam] epoch={epoch_idx+1} step={step+1} " + " ".join(log_parts) ) step += 1 merged = reparam_layer.materialize_into_layer_a() final_lambdas = reparam_layer.gate_lambdas() stats: Dict[str, object] = { "enabled": True, "epochs": total_epochs, "lr": float(args.distill_lr), "hidden_mse_weight": hidden_mse_weight, "attn_mse_weight": attn_mse_weight, "mlp_mse_weight": mlp_mse_weight, "eta": eta, "gamma": gamma, "attn_reg_scale": attn_reg_scale, "mlp_reg_scale": mlp_reg_scale, "param_subset": param_subset, "num_gates": len(final_lambdas), "num_attn_gates": family_counts["attn"], "num_mlp_gates": family_counts["mlp"], "num_other_gates": family_counts["other"], } return merged, final_lambdas, stats class LoRALinear(torch.nn.Module): def __init__( self, base: torch.nn.Linear, rank: int, alpha: float, dropout: float, ) -> None: super().__init__() if rank <= 0: raise ValueError("LoRA rank must be positive") self.base = base self.rank = int(rank) self.alpha = float(alpha) self.scaling = self.alpha / float(self.rank) self.enabled = True if dropout > 0: self.dropout = torch.nn.Dropout(dropout) else: self.dropout = torch.nn.Identity() self.lora_A = torch.nn.Linear(base.in_features, self.rank, bias=False) self.lora_B = torch.nn.Linear(self.rank, base.out_features, bias=False) torch.nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5)) torch.nn.init.zeros_(self.lora_B.weight) self.lora_A.to(device=base.weight.device, dtype=base.weight.dtype) self.lora_B.to(device=base.weight.device, dtype=base.weight.dtype) self.merged = False def lora_parameters(self) -> List[torch.nn.Parameter]: return [*self.lora_A.parameters(), *self.lora_B.parameters()] def forward(self, x: torch.Tensor) -> torch.Tensor: result = self.base(x) if self.merged or not self.enabled: return result lora_out = self.lora_B(self.lora_A(self.dropout(x))) return result + lora_out * self.scaling def merge(self) -> None: if self.merged: return delta = torch.matmul(self.lora_B.weight, self.lora_A.weight) delta = delta.to(dtype=self.base.weight.dtype) * self.scaling self.base.weight.data.add_(delta) self.merged = True def _get_child_module(parent: torch.nn.Module, part: str) -> torch.nn.Module: if isinstance(parent, (torch.nn.ModuleList, torch.nn.Sequential)) and part.isdigit(): return parent[int(part)] if isinstance(parent, torch.nn.ModuleDict): return parent[part] return getattr(parent, part) def _set_child_module(parent: torch.nn.Module, part: str, module: torch.nn.Module) -> None: if isinstance(parent, (torch.nn.ModuleList, torch.nn.Sequential)) and part.isdigit(): parent[int(part)] = module return if isinstance(parent, torch.nn.ModuleDict): parent[part] = module return setattr(parent, part, module) def _resolve_parent_module( root: torch.nn.Module, module_name: str ) -> Optional[tuple]: if not module_name: return None parts = module_name.split(".") parent = root for part in parts[:-1]: parent = _get_child_module(parent, part) return parent, parts[-1] def _resolve_module_by_path(root: torch.nn.Module, module_path: str) -> Optional[torch.nn.Module]: if not module_path: return None parts = [part for part in module_path.split(".") if part] node = root for part in parts: try: node = _get_child_module(node, part) except Exception: return None return node def _resolve_layer_container_for_lora( model: torch.nn.Module, layer_path: Optional[str] ) -> Tuple[Optional[str], Optional[object]]: """Resolve transformer layer container with optional auto-detection. Mirrors the candidate path strategy used elsewhere, so LoRA filtering can work even when --layer_path is not provided. """ if isinstance(layer_path, str) and layer_path and layer_path.lower() != "none": container = _resolve_module_by_path(model, layer_path) if container is not None: try: list(container) return layer_path, container except TypeError: pass candidate_paths = [ "model.layers", # LLaMA, Mistral, Qwen2, Gemma "model.decoder.layers", # OPT "transformer.h", # GPT-2, GPT-J, Bloom, Falcon "transformer.blocks", # MPT "gpt_neox.layers", # GPT-NeoX "layers", # fallback ] for path in candidate_paths: container = _resolve_module_by_path(model, path) if container is None: continue try: list(container) except TypeError: continue return path, container return None, None def _parse_exclude_pairs_local(raw_values, num_pairs: int) -> Set[int]: if not raw_values or num_pairs <= 0: return set() exclude: Set[int] = set() for item in raw_values: if item is None: continue for part in str(item).split(","): part = part.strip() if not part: continue try: idx = int(part) except ValueError as exc: raise SystemExit("--exclude_pairs must contain integers.") from exc if idx < 0: idx = num_pairs + idx if 0 <= idx < num_pairs: exclude.add(idx) return exclude def _extract_layer_index_from_module_name( module_name: str, layer_path: str ) -> Optional[int]: if not layer_path: return None prefix = f"{layer_path}." if not module_name.startswith(prefix): return None rest = module_name[len(prefix) :] if not rest: return None idx_text = rest.split(".", 1)[0] if not idx_text.isdigit(): return None return int(idx_text) def _select_linear_modules_for_lora_targets( model: torch.nn.Module, args: argparse.Namespace, *, log_tag: str, ) -> Tuple[List[Tuple[str, torch.nn.Linear]], Optional[Set[str]], Set[int], Optional[str]]: raw_targets = getattr(args, "lora_target_modules", None) target_modules: Optional[Set[str]] = None if raw_targets: target_modules = {str(item) for item in raw_targets if str(item)} exclude_layer_indices: Set[int] = set() resolved_layer_path: Optional[str] = None if bool(getattr(args, "lora_respect_exclude_pairs", False)): requested_layer_path = getattr(args, "layer_path", None) resolved_layer_path, layer_container = _resolve_layer_container_for_lora( model, requested_layer_path ) if isinstance(layer_container, (torch.nn.ModuleList, list, tuple)): num_pairs = max(len(layer_container) - 1, 0) exclude_pairs = _parse_exclude_pairs_local( getattr(args, "exclude_pairs", None), num_pairs ) for pair_idx in exclude_pairs: exclude_layer_indices.add(pair_idx) exclude_layer_indices.add(pair_idx + 1) else: print( f"[{log_tag}] Warning: --lora_respect_exclude_pairs enabled, but " f"could not resolve layer path '{requested_layer_path}'." ) linear_modules = [ (name, module) for name, module in model.named_modules() if isinstance(module, torch.nn.Linear) and (target_modules is None or name.split(".")[-1] in target_modules) and ( not exclude_layer_indices or _extract_layer_index_from_module_name(name, resolved_layer_path or "") not in exclude_layer_indices ) ] return linear_modules, target_modules, exclude_layer_indices, resolved_layer_path def apply_lora_adapters( model: torch.nn.Module, args: argparse.Namespace ) -> List[LoRALinear]: if args.lora_rank <= 0: raise SystemExit("--lora_rank must be > 0 when --lora_epochs > 0") linear_modules, target_modules, exclude_layer_indices, _ = ( _select_linear_modules_for_lora_targets(model, args, log_tag="lora") ) if not linear_modules: raise SystemExit( "No Linear modules found for LoRA adapters " "(check --lora_target_modules / --exclude_pairs / --lora_respect_exclude_pairs)." ) lora_modules: List[LoRALinear] = [] for name, module in linear_modules: resolved = _resolve_parent_module(model, name) if resolved is None: continue parent, attr = resolved wrapped = LoRALinear( base=module, rank=args.lora_rank, alpha=args.lora_alpha, dropout=args.lora_dropout, ) _set_child_module(parent, attr, wrapped) lora_modules.append(wrapped) for param in model.parameters(): param.requires_grad_(False) for lora_module in lora_modules: for param in lora_module.lora_parameters(): param.requires_grad_(True) total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) percent = 100.0 * trainable_params / max(total_params, 1) target_note = "" if target_modules is not None: target_note = f" target={sorted(target_modules)}" exclude_note = "" if exclude_layer_indices: exclude_note = f" excluded_layers={sorted(exclude_layer_indices)}" print( "[lora] Applied adapters to " f"{len(lora_modules)} linear modules " f"({trainable_params}/{total_params} trainable, {percent:.4f}%)." f"{target_note}{exclude_note}" ) return lora_modules def merge_lora_adapters(model: torch.nn.Module) -> None: lora_entries = [ (name, module) for name, module in model.named_modules() if isinstance(module, LoRALinear) ] for name, module in lora_entries: module.merge() resolved = _resolve_parent_module(model, name) if resolved is None: continue parent, attr = resolved _set_child_module(parent, attr, module.base) def set_lora_enabled(lora_modules: List[LoRALinear], enabled: bool) -> None: for module in lora_modules: module.enabled = enabled def lora_ce_finetune( model: torch.nn.Module, dataloader, eval_tokenizer, eval_datasets: List[str], eval_configs: List[Optional[str]], eval_history: List[Dict[str, object]], args: argparse.Namespace, eval_dataloaders: Optional[Dict[str, object]] = None, progressive_cycle: Optional[int] = None, progressive_total: Optional[int] = None, ) -> None: total_epochs = float(args.lora_epochs) if total_epochs <= 0: return use_kl = bool(getattr(args, "lora_kl_enabled", False)) kl_weight = float(getattr(args, "lora_kl_weight", 0.0)) kl_temp = float(getattr(args, "lora_kl_temp", 1.0)) if use_kl: if kl_weight < 0.0: raise SystemExit("--lora_kl_weight must be >= 0") if kl_temp <= 0.0: raise SystemExit("--lora_kl_temp must be > 0") if kl_weight == 0.0: use_kl = False lora_modules = apply_lora_adapters(model, args) if not lora_modules: return model.train() lora_params = [] for module in lora_modules: lora_params.extend(module.lora_parameters()) optimizer = torch.optim.AdamW( lora_params, lr=args.lora_lr, weight_decay=args.lora_weight_decay, ) device_type = torch.device(args.device).type amp_dtype = None if args.dtype == "float16": amp_dtype = torch.float16 elif args.dtype == "bfloat16": amp_dtype = torch.bfloat16 use_amp = amp_dtype is not None and device_type == "cuda" use_scaler = use_amp and amp_dtype == torch.float16 scaler = torch.cuda.amp.GradScaler() if use_scaler else None full_epochs = int(total_epochs) fractional = total_epochs - full_epochs if fractional < 1e-8: fractional = 0.0 epoch_plan = [(epoch_idx, None) for epoch_idx in range(full_epochs)] if fractional > 0: try: batches_per_epoch = len(dataloader) except TypeError as exc: raise SystemExit( "Fractional lora epochs require a dataloader with finite length." ) from exc if batches_per_epoch > 0: frac_batches = int(round(fractional * batches_per_epoch)) if frac_batches <= 0: frac_batches = 1 epoch_plan.append((full_epochs, frac_batches)) step = 0 for epoch_idx, max_batches in epoch_plan: if max_batches is None: epoch_iter = dataloader else: epoch_iter = itertools.islice(dataloader, max_batches) iterator = epoch_iter if tqdm is not None and _tqdm_enabled(): if progressive_cycle is not None: if progressive_total is not None: desc = ( f"LoRA (cycle {progressive_cycle}/{progressive_total}, " f"epoch {epoch_idx+1})" ) else: desc = f"LoRA (cycle {progressive_cycle}, epoch {epoch_idx+1})" else: desc = f"LoRA (epoch {epoch_idx+1})" iterator = tqdm( epoch_iter, desc=desc, unit="batch", total=max_batches, ) for batch in iterator: input_ids = batch[0].to(args.device) attention_mask = batch[1].to(args.device) autocast_ctx = ( torch.autocast(device_type=device_type, dtype=amp_dtype) if use_amp else nullcontext() ) with autocast_ctx: outputs = model( input_ids=input_ids, attention_mask=attention_mask, use_cache=False, ) logits = outputs.logits shift_logits = logits[:, :-1, :].contiguous() shift_labels = input_ids[:, 1:].contiguous() shift_mask = attention_mask[:, 1:].contiguous() ce_flat = F.cross_entropy( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), reduction="none", ) ce_denom = shift_mask.sum() if ce_denom.item() == 0: continue ce_loss = ( ce_flat * shift_mask.view(-1).to(ce_flat.dtype) ).sum() / ce_denom kl_loss = None if use_kl: set_lora_enabled(lora_modules, False) with torch.no_grad(): base_outputs = model( input_ids=input_ids, attention_mask=attention_mask, use_cache=False, ) base_logits = base_outputs.logits set_lora_enabled(lora_modules, True) if base_logits.device != shift_logits.device: base_logits = base_logits.to(shift_logits.device) shift_base_logits = base_logits[:, :-1, :].contiguous() log_p_pre = F.log_softmax(shift_base_logits / kl_temp, dim=-1) log_p_post = F.log_softmax(shift_logits / kl_temp, dim=-1) p_pre = log_p_pre.exp() kl_flat = (p_pre * (log_p_pre - log_p_post)).sum(dim=-1) kl_loss = ( kl_flat * shift_mask.to(kl_flat.dtype) ).sum() / ce_denom total_loss = ce_loss if kl_loss is not None: total_loss = total_loss + (kl_weight * (kl_temp ** 2) * kl_loss) if args.lora_grad_accum_steps > 1: total_loss = total_loss / args.lora_grad_accum_steps if use_scaler: scaler.scale(total_loss).backward() else: total_loss.backward() if (step + 1) % args.lora_grad_accum_steps == 0: if args.lora_max_grad_norm is not None: if use_scaler: scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_( lora_params, args.lora_max_grad_norm, ) if use_scaler: scaler.step(optimizer) scaler.update() else: optimizer.step() optimizer.zero_grad(set_to_none=True) if args.lora_eval_every and (step + 1) % args.lora_eval_every == 0: prev_mode = model.training model.eval() eval_device = args.eval_device or args.device if eval_dataloaders is not None: results = ppl_eval.evaluate_ppl_dataloaders( model, eval_dataloaders, eval_device, max_batches=args.lora_eval_max_batches, ) else: results = ppl_eval.evaluate_ppl_datasets( model, eval_tokenizer, datasets=eval_datasets, configs=eval_configs, split=args.eval_split, text_field=args.eval_text_field, num_samples=args.eval_num_samples, seq_len=args.eval_seq_len, batch_size=args.eval_batch_size or args.batch_size, device=eval_device, seed=args.seed, shuffle=False, model_family=args.eval_model_family, add_bos=args.eval_add_bos, max_batches=args.lora_eval_max_batches, cache_dir=args.eval_cache_dir, num_workers=args.eval_num_workers, ) eval_history.append({"step": step + 1, "ppl": results}) print(f"[lora] eval step={step+1}: {results}") if prev_mode: model.train() if args.lora_log_steps and ( step == 0 or (step + 1) % args.lora_log_steps == 0 ): log_parts = [f"loss={total_loss.item():.6f}"] if kl_loss is not None: log_parts.append(f"kl={kl_loss.item():.6f}") print( f"[lora] epoch={epoch_idx+1} step={step+1} " + " ".join(log_parts) ) step += 1 merge_lora_adapters(model) def _masked_kl( logits_p: torch.Tensor, logits_q: torch.Tensor, attention_mask: torch.Tensor, temp: float, detach_p: bool = True, ) -> Optional[torch.Tensor]: shift_mask = attention_mask[:, 1:].contiguous() denom = shift_mask.sum() if denom.item() == 0: return None p = logits_p[:, :-1, :].contiguous() q = logits_q[:, :-1, :].contiguous() if p.device != q.device: p = p.to(q.device) # Keep dtype to avoid blowing up memory on large vocab models. log_p = F.log_softmax(p / temp, dim=-1) log_q = F.log_softmax(q / temp, dim=-1) if detach_p: log_p = log_p.detach() p_probs = log_p.exp() kl_flat = (p_probs * (log_p - log_q)).sum(dim=-1) return (kl_flat * shift_mask.to(kl_flat.dtype)).sum() / denom def _extract_hidden_tensor(output: object) -> Optional[torch.Tensor]: if isinstance(output, torch.Tensor): return output if isinstance(output, (tuple, list)) and output: first = output[0] if isinstance(first, torch.Tensor): return first return None def _grad_l2_norm(grads: List[Optional[torch.Tensor]]) -> float: total = 0.0 for grad in grads: if grad is None: continue total += float(grad.detach().float().pow(2).sum().item()) if total <= 0.0: return 0.0 return float(math.sqrt(total)) def _register_forward_pre_hook_with_optional_kwargs(layer, hook): try: handle = layer.register_forward_pre_hook(hook, with_kwargs=True) return handle except TypeError: def wrapper(module, inputs): return hook(module, inputs, None) return layer.register_forward_pre_hook(wrapper) def commutator_precondition( student_model: torch.nn.Module, student_layers: List[torch.nn.Module], teacher_model: torch.nn.Module, dataloader, dwce_scores: Optional[List[float]], args: argparse.Namespace, exclude_pairs: Optional[Set[int]] = None, progressive_cycle: Optional[int] = None, progressive_total: Optional[int] = None, ) -> Dict[str, object]: """Run commutator-style preconditioning before pair fusion. Objective on each sampled pair i: L = T^2 * KL(p_teacher || p_student) + mu * L_interaction(i) Interaction loss is computed locally on block (i+1): r1 = B_{i+1}(h_{i+1}) - h_{i+1} r0 = B_{i+1}(h_i) - h_i L_interaction = ||r1-r0||^2 (or relative form). """ if not bool(getattr(args, "comm_enabled", False)): return {"enabled": False} if not student_layers or len(student_layers) < 2: return {"enabled": False, "reason": "need_at_least_2_layers"} temp = float(getattr(args, "comm_temp", 2.0)) steps_ratio = float(getattr(args, "comm_steps_ratio", 0.1)) lr_scale = float(getattr(args, "comm_lr_scale", 0.1)) sample_eta = float(getattr(args, "comm_sample_eta", 0.5)) sample_dwce_scale = float(getattr(args, "comm_sample_dwce_scale", 1.0)) top_k = int(getattr(args, "comm_topk", 1)) interaction_mode = str(getattr(args, "comm_interaction_mode", "relative")).strip().lower() interaction_eps = float(getattr(args, "comm_interaction_eps", 1e-8)) mu_cfg = getattr(args, "comm_mu", None) mu_auto = bool(getattr(args, "comm_mu_auto", False)) mu_auto_rho = float(getattr(args, "comm_mu_auto_rho", 0.1)) mu_auto_eps = float(getattr(args, "comm_mu_auto_eps", 1e-8)) comm_train_mode = str(getattr(args, "comm_train_mode", "lora")).strip().lower() log_steps = int(getattr(args, "comm_log_steps", 50)) if temp <= 0.0: raise SystemExit("--comm_temp must be > 0") if steps_ratio < 0.0: raise SystemExit("--comm_steps_ratio must be >= 0") if lr_scale <= 0.0: raise SystemExit("--comm_lr_scale must be > 0") if not (0.0 <= sample_eta <= 1.0): raise SystemExit("--comm_sample_eta must be in [0, 1]") if top_k <= 0: raise SystemExit("--comm_topk must be >= 1") if interaction_mode not in {"mse", "relative"}: raise SystemExit("--comm_interaction_mode must be one of: mse, relative") if comm_train_mode not in {"lora", "full"}: raise SystemExit("--comm_train_mode must be one of: lora, full") if interaction_eps <= 0.0: raise SystemExit("--comm_interaction_eps must be > 0") if mu_auto_rho < 0.0: raise SystemExit("--comm_mu_auto_rho must be >= 0") if mu_auto_eps <= 0.0: raise SystemExit("--comm_mu_auto_eps must be > 0") if mu_cfg is None: base_mu = 0.5 if interaction_mode == "relative" else 0.1 else: base_mu = float(mu_cfg) if base_mu < 0.0: raise SystemExit("--comm_mu must be >= 0") distill_epochs = float(getattr(args, "distill_epochs", 1.0)) if distill_epochs <= 0.0: distill_epochs = 1.0 grad_accum = int(getattr(args, "distill_grad_accum_steps", 1)) if grad_accum <= 0: grad_accum = 1 try: batches_per_epoch = len(dataloader) except TypeError as exc: raise SystemExit( "Commutator preconditioning requires a finite-length distillation dataloader." ) from exc if batches_per_epoch <= 0: return {"enabled": False, "reason": "empty_dataloader"} full_epochs = int(distill_epochs) fractional = distill_epochs - full_epochs if fractional < 1e-8: fractional = 0.0 total_batches = full_epochs * batches_per_epoch if fractional > 0.0: frac_batches = int(round(fractional * batches_per_epoch)) if frac_batches <= 0: frac_batches = 1 total_batches += frac_batches distill_opt_steps = int(math.ceil(total_batches / float(grad_accum))) target_opt_steps = int(round(steps_ratio * distill_opt_steps)) if target_opt_steps <= 0: target_opt_steps = 1 num_pairs = max(len(student_layers) - 1, 0) exclude_set = { int(idx) for idx in (exclude_pairs or set()) if isinstance(idx, int) and 0 <= int(idx) < num_pairs } allowed_pairs = [i for i in range(num_pairs) if i not in exclude_set] if not allowed_pairs: return {"enabled": False, "reason": "all_pairs_excluded"} ranked_pairs = list(allowed_pairs) if dwce_scores is not None and len(dwce_scores) >= num_pairs: finite_pairs = [] for idx in allowed_pairs: value = float(dwce_scores[idx]) if math.isfinite(value): finite_pairs.append(idx) if finite_pairs: ranked_pairs = sorted(finite_pairs, key=lambda i: float(dwce_scores[i])) else: ranked_pairs = list(allowed_pairs) candidate_pairs = ranked_pairs[: min(top_k, len(ranked_pairs))] if not candidate_pairs: return {"enabled": False, "reason": "no_candidate_pairs"} layer_trainable_params: List[List[torch.nn.Parameter]] = [] trainable_params: List[torch.nn.Parameter] = [] if comm_train_mode == "lora": # LoRA comm preconditioning: update LoRA adapters on receiver layer (i+1). lora_modules = apply_lora_adapters(student_model, args) if not lora_modules: return {"enabled": False, "reason": "no_lora_modules"} trainable_seen: Set[int] = set() for module in lora_modules: for param in module.lora_parameters(): pid = id(param) if pid in trainable_seen: continue trainable_seen.add(pid) trainable_params.append(param) for layer in student_layers: seen: Set[int] = set() params: List[torch.nn.Parameter] = [] for module in layer.modules(): if not isinstance(module, LoRALinear): continue for param in module.lora_parameters(): pid = id(param) if pid in seen: continue seen.add(pid) params.append(param) layer_trainable_params.append(params) else: # Full-weight comm preconditioning: update full receiver-layer weights. for layer in student_layers: seen: Set[int] = set() params: List[torch.nn.Parameter] = [] for param in layer.parameters(): if not isinstance(param, torch.nn.Parameter): continue pid = id(param) if pid in seen: continue seen.add(pid) params.append(param) layer_trainable_params.append(params) candidate_pairs = [ i for i in candidate_pairs if (i + 1) < len(layer_trainable_params) and layer_trainable_params[i + 1] ] if not candidate_pairs: if comm_train_mode == "lora": merge_lora_adapters(student_model) return {"enabled": False, "reason": "no_trainable_receiver_layers"} if comm_train_mode == "full": trainable_seen: Set[int] = set() for pair_idx in candidate_pairs: for param in layer_trainable_params[pair_idx + 1]: pid = id(param) if pid in trainable_seen: continue trainable_seen.add(pid) trainable_params.append(param) if not trainable_params: return {"enabled": False, "reason": "no_trainable_receiver_layers"} # Freeze non-comm params to reduce grad memory. for param in student_model.parameters(): param.requires_grad_(False) for param in trainable_params: param.requires_grad_(True) if not trainable_params: if comm_train_mode == "lora": merge_lora_adapters(student_model) return {"enabled": False, "reason": "no_trainable_params"} candidate_probs = torch.full( (len(candidate_pairs),), 1.0 / float(len(candidate_pairs)), dtype=torch.float32, ) if dwce_scores is not None and len(dwce_scores) >= num_pairs and sample_eta > 0.0: score_vec = torch.tensor( [float(dwce_scores[i]) for i in candidate_pairs], dtype=torch.float32 ) score_vec = torch.nan_to_num(score_vec, nan=1e9, posinf=1e9, neginf=-1e9) biased = torch.softmax(-float(sample_dwce_scale) * score_vec, dim=0) candidate_probs = (1.0 - sample_eta) * candidate_probs + sample_eta * biased candidate_probs = candidate_probs / candidate_probs.sum() probs_by_pair = [0.0 for _ in range(num_pairs)] for pos, pair_idx in enumerate(candidate_pairs): probs_by_pair[pair_idx] = float(candidate_probs[pos].item()) lr = float(getattr(args, "distill_lr", 1e-4)) * lr_scale optimizer = torch.optim.AdamW( trainable_params, lr=lr, weight_decay=float(getattr(args, "distill_weight_decay", 0.0)), ) device_type = torch.device(args.device).type amp_dtype = None if args.dtype == "float16": amp_dtype = torch.float16 elif args.dtype == "bfloat16": amp_dtype = torch.bfloat16 use_amp = amp_dtype is not None and device_type == "cuda" use_scaler = use_amp and amp_dtype == torch.float16 scaler = torch.cuda.amp.GradScaler() if use_scaler else None teacher_device = next(teacher_model.parameters()).device teacher_model.eval() student_model.train() gen = torch.Generator(device="cpu") seed = int(getattr(args, "seed", 0)) if progressive_cycle is not None: seed += int(progressive_cycle) * 100003 gen.manual_seed(seed) opt_step = 0 total_loss_sum = 0.0 anchor_sum = 0.0 interaction_sum = 0.0 mu_sum = 0.0 counted = 0 pair_counts = [0 for _ in range(num_pairs)] desc = "Comm" if progressive_cycle is not None: if progressive_total is not None: desc = f"Comm (cycle {progressive_cycle}/{progressive_total})" else: desc = f"Comm (cycle {progressive_cycle})" iterator = range(target_opt_steps) if tqdm is not None and _tqdm_enabled(): iterator = tqdm(iterator, desc=desc, unit="step") data_iter = iter(dataloader) autocast_ctx = ( torch.autocast(device_type=device_type, dtype=amp_dtype) if use_amp else nullcontext() ) for _ in iterator: optimizer.zero_grad(set_to_none=True) accum_done = 0 while accum_done < grad_accum: try: batch = next(data_iter) except StopIteration: data_iter = iter(dataloader) batch = next(data_iter) input_ids = batch[0].to(args.device) attention_mask = batch[1].to(args.device) sampled_pos = int(torch.multinomial(candidate_probs, 1, generator=gen).item()) pair_idx = int(candidate_pairs[sampled_pos]) pair_counts[pair_idx] += 1 receiver_params = layer_trainable_params[pair_idx + 1] receiver_param_ids = {id(param) for param in receiver_params} teacher_ids = input_ids.to(teacher_device) teacher_mask = attention_mask.to(teacher_device) with torch.no_grad(), autocast_ctx: teacher_outputs = teacher_model( input_ids=teacher_ids, attention_mask=teacher_mask, use_cache=False, ) teacher_logits = teacher_outputs.logits capture: Dict[str, object] = { "h_l": None, "h_lp1": None, "y1": None, "recv_args": None, "recv_kwargs": None, } def _hook_l(_module, inputs, _output): if inputs and isinstance(inputs[0], torch.Tensor): capture["h_l"] = inputs[0] def _hook_recv_pre(_module, inputs, kwargs): capture["recv_args"] = inputs capture["recv_kwargs"] = kwargs def _hook_recv(_module, inputs, output): if inputs and isinstance(inputs[0], torch.Tensor): capture["h_lp1"] = inputs[0] capture["y1"] = _extract_hidden_tensor(output) handles: List[object] = [ student_layers[pair_idx].register_forward_hook(_hook_l), _register_forward_pre_hook_with_optional_kwargs( student_layers[pair_idx + 1], _hook_recv_pre ), student_layers[pair_idx + 1].register_forward_hook(_hook_recv), ] try: with autocast_ctx: student_outputs = student_model( input_ids=input_ids, attention_mask=attention_mask, use_cache=False, ) student_logits = student_outputs.logits finally: for handle in handles: try: handle.remove() except Exception: pass with autocast_ctx: anchor_kl = _masked_kl( teacher_logits, student_logits, attention_mask, temp=temp, detach_p=True, ) if anchor_kl is None: continue anchor_loss = (temp ** 2) * anchor_kl interaction_loss = None h_l = capture.get("h_l") h_lp1 = capture.get("h_lp1") y1 = capture.get("y1") recv_args = capture.get("recv_args") recv_kwargs = capture.get("recv_kwargs") if ( isinstance(h_l, torch.Tensor) and isinstance(h_lp1, torch.Tensor) and isinstance(y1, torch.Tensor) and isinstance(recv_args, tuple) and len(recv_args) > 0 and isinstance(recv_args[0], torch.Tensor) ): call_args = list(recv_args) first_hidden = call_args[0] h_l_detached = h_l.detach().to( device=first_hidden.device, dtype=first_hidden.dtype, ) call_args[0] = h_l_detached call_kwargs = dict(recv_kwargs) if isinstance(recv_kwargs, dict) else {} y0_raw = student_layers[pair_idx + 1](*tuple(call_args), **call_kwargs) y0 = _extract_hidden_tensor(y0_raw) if isinstance(y0, torch.Tensor): if y0.device != y1.device: y0 = y0.to(y1.device) h_lp1_detached = h_lp1.detach().to(device=y1.device, dtype=y1.dtype) h_l_for_res = h_l.detach().to(device=y0.device, dtype=y0.dtype) r1 = y1 - h_lp1_detached r0 = y0 - h_l_for_res mask = attention_mask.to(dtype=r1.dtype) mask_sum = mask.sum() if mask_sum.item() > 0: if interaction_mode == "relative": num = (r1 - r0).float().pow(2).sum(dim=-1) den = r1.float().pow(2).sum(dim=-1) + float(interaction_eps) ratio = (num / den) * mask.to(num.dtype) interaction_loss = ratio.sum() / (mask_sum + 1e-8) else: denom = mask_sum * r1.size(-1) if denom.item() > 0: interaction_loss = ( (r1 - r0).pow(2) * mask.unsqueeze(-1) ).sum() / denom mu_effective = float(base_mu) if ( mu_auto and interaction_loss is not None and receiver_params and mu_auto_rho > 0.0 ): anchor_grads = torch.autograd.grad( anchor_loss, receiver_params, retain_graph=True, allow_unused=True, ) interaction_grads = torch.autograd.grad( interaction_loss, receiver_params, retain_graph=True, allow_unused=True, ) anchor_norm = _grad_l2_norm(list(anchor_grads)) interaction_norm = _grad_l2_norm(list(interaction_grads)) if interaction_norm > 0.0: mu_effective = float( mu_auto_rho * (anchor_norm / (interaction_norm + float(mu_auto_eps))) ) else: mu_effective = float(base_mu) if not math.isfinite(mu_effective): mu_effective = float(base_mu) total_loss = anchor_loss if interaction_loss is not None: total_loss = total_loss + (float(mu_effective) * interaction_loss) if grad_accum > 1: total_loss = total_loss / float(grad_accum) if use_scaler: scaler.scale(total_loss).backward() else: total_loss.backward() # Only the sampled receiver layer updates on this micro-batch. for param in trainable_params: if id(param) in receiver_param_ids: continue if param.grad is not None: if comm_train_mode == "lora": param.grad.zero_() else: param.grad = None total_loss_sum += float(total_loss.detach().float().item()) anchor_sum += float(anchor_loss.detach().float().item()) if interaction_loss is not None: interaction_sum += float(interaction_loss.detach().float().item()) mu_sum += float(mu_effective) counted += 1 accum_done += 1 if args.distill_max_grad_norm is not None: if use_scaler: scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_( trainable_params, float(args.distill_max_grad_norm), ) if use_scaler: scaler.step(optimizer) scaler.update() else: optimizer.step() opt_step += 1 if log_steps and (opt_step == 1 or opt_step % log_steps == 0): denom = max(counted, 1) print( f"[comm] step={opt_step}/{target_opt_steps} " f"loss={total_loss_sum/denom:.6f} " f"anchor={anchor_sum/denom:.6f} " f"int={interaction_sum/denom:.6f} " f"mu={mu_sum/denom:.6f}" ) if comm_train_mode == "lora": merge_lora_adapters(student_model) stats: Dict[str, object] = { "enabled": True, "train_mode": comm_train_mode, "opt_steps": int(target_opt_steps), "grad_accum_steps": int(grad_accum), "lr": float(lr), "temp": float(temp), "steps_ratio": float(steps_ratio), "lr_scale": float(lr_scale), "interaction_mode": interaction_mode, "interaction_eps": float(interaction_eps), "mu": float(base_mu), "mu_auto": bool(mu_auto), "mu_auto_rho": float(mu_auto_rho), "mu_auto_eps": float(mu_auto_eps), "sample_eta": float(sample_eta), "sample_dwce_scale": float(sample_dwce_scale), "topk": int(top_k), "candidate_pairs": [int(i) for i in candidate_pairs], "trainable_params": int(sum(int(param.numel()) for param in trainable_params)), } total_samples = int(sum(pair_counts)) probs_list = [float(x) for x in probs_by_pair] freqs = ( [float(c) / float(total_samples) for c in pair_counts] if total_samples > 0 else [0.0 for _ in pair_counts] ) top_show = min(10, num_pairs) top_indices = sorted(range(num_pairs), key=lambda i: pair_counts[i], reverse=True)[:top_show] top_pairs = [ { "pair": int(i), "count": int(pair_counts[i]), "freq": float(freqs[i]), "prob": float(probs_list[i]) if i < len(probs_list) else None, } for i in top_indices if pair_counts[i] > 0 ] stats["pair_selection"] = { "num_pairs": int(num_pairs), "excluded_pairs": sorted(exclude_set), "candidate_pairs": [int(i) for i in candidate_pairs], "total_samples": total_samples, "unique_pairs": int(sum(1 for c in pair_counts if c > 0)), "counts": [int(c) for c in pair_counts], "freqs": freqs, "probs": probs_list, "top_pairs": top_pairs, } if total_samples > 0 and top_pairs: top_str = ", ".join( f"{entry['pair']}-{entry['pair'] + 1}: {entry['count']} " f"(obs={entry['freq']:.3f}, exp={entry['prob']:.3f})" for entry in top_pairs if entry.get("prob") is not None ) if not top_str: top_str = ", ".join( f"{entry['pair']}-{entry['pair'] + 1}: {entry['count']} " f"(obs={entry['freq']:.3f})" for entry in top_pairs ) print( f"[comm] Pair sampling stats: total={total_samples} " f"unique={stats['pair_selection']['unique_pairs']}/{num_pairs} " f"top={top_str}" ) if counted > 0: stats["avg_loss"] = float(total_loss_sum / float(counted)) stats["avg_anchor"] = float(anchor_sum / float(counted)) stats["avg_interaction"] = float(interaction_sum / float(counted)) stats["avg_mu"] = float(mu_sum / float(counted)) return stats