"""Distillation helpers for fuse_layers."""

import argparse
import itertools
import math
import os
from contextlib import contextmanager, nullcontext
from typing import Dict, List, Optional, Set, Tuple

import torch
import torch.nn.functional as F

try:
    import ppl_eval
except Exception as exc:
    raise SystemExit("ppl_eval.py is required (missing or invalid)") from exc

try:
    from tqdm import tqdm
except Exception:
    tqdm = None

# torch.func.functional_call is the public API on PyTorch >= 2.0; older
# releases expose the same helper under torch.nn.utils.stateless.
try:
    from torch.func import functional_call as _functional_call
except Exception:
    try:
        from torch.nn.utils.stateless import functional_call as _functional_call
    except Exception:
        _functional_call = None

from fuse_layers_model import find_attention_module, find_mlp_module


def _tqdm_enabled() -> bool:
    """Return True unless progress bars are disabled via the environment."""
    value = os.environ.get("DISABLE_TQDM", os.environ.get("TQDM_DISABLE", "0"))
    return value.strip().lower() not in {"1", "true", "yes", "on"}


@contextmanager
def temporary_layers(parent: object, name: str, new_layers: torch.nn.Module):
    """Temporarily swap ``parent.<name>`` for ``new_layers``, restoring on exit."""
    original = getattr(parent, name)
    setattr(parent, name, new_layers)
    try:
        yield
    finally:
        setattr(parent, name, original)
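
# Usage sketch (hypothetical names; ``model.model`` and ``layers`` stand in
# for whatever parent/attribute pair your architecture exposes):
#
#     truncated = torch.nn.ModuleList(list(model.model.layers)[:4])
#     with temporary_layers(model.model, "layers", truncated):
#         out = model.model(input_ids=ids, use_cache=False)
#     # model.model.layers is restored here, even if the forward raised.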


@contextmanager
def temporary_norm(parent: object):
    """Temporarily replace ``parent.norm`` with an identity, if it exists."""
    if hasattr(parent, "norm"):
        original = getattr(parent, "norm")
        setattr(parent, "norm", torch.nn.Identity())
        try:
            yield
        finally:
            setattr(parent, "norm", original)
    else:
        yield


def forward_truncated(
    parent: torch.nn.Module,
    layer_attr: str,
    layers: List[torch.nn.Module],
    upto: int,
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run ``parent`` with only the first ``upto`` layers and no final norm."""
    truncated = torch.nn.ModuleList(layers[:upto])
    with temporary_layers(parent, layer_attr, truncated), temporary_norm(parent):
        outputs = parent(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=False,
        )
    if hasattr(outputs, "last_hidden_state"):
        return outputs.last_hidden_state
    return outputs[0]
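
# Minimal sketch, assuming a HF-style decoder whose base module accepts
# input_ids/attention_mask and owns a ``layers`` ModuleList:
#
#     hidden = forward_truncated(
#         model.model, "layers", list(model.model.layers),
#         upto=8, input_ids=ids, attention_mask=mask,
#     )  # hidden states after layer 8, with the final norm bypassed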


def _masked_hidden_mse(diff: torch.Tensor, attention_mask: torch.Tensor) -> Optional[torch.Tensor]:
    """Mean squared error over unmasked tokens; None when the mask is empty."""
    diff_f = diff.float()
    mask = attention_mask.to(device=diff.device, dtype=torch.float32)
    denom = mask.sum() * diff_f.size(-1)
    if denom.item() == 0:
        return None
    return (diff_f.pow(2) * mask.unsqueeze(-1)).sum() / denom
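
# Example (shapes only; values are illustrative):
#
#     diff = student_hidden - teacher_hidden   # [batch, seq, hidden]
#     mask = attention_mask                    # [batch, seq], 1 = real token
#     loss = _masked_hidden_mse(diff, mask)
#
# Padding positions contribute nothing: the squared error is weighted by the
# mask and normalized by (#unmasked tokens * hidden_size).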


def _extract_hidden_like(output) -> Optional[torch.Tensor]:
    """Best-effort extraction of a hidden-state tensor from a module output."""
    if torch.is_tensor(output):
        return output
    if isinstance(output, (tuple, list)) and output:
        first = output[0]
        if torch.is_tensor(first):
            return first
    if hasattr(output, "last_hidden_state"):
        hidden = getattr(output, "last_hidden_state")
        if torch.is_tensor(hidden):
            return hidden
    return None


@contextmanager
def capture_module_output(module: torch.nn.Module):
    """Capture ``module``'s most recent forward output into a small dict."""
    cache: Dict[str, Optional[torch.Tensor]] = {"output": None}

    def hook(_module, _inputs, output):
        cache["output"] = _extract_hidden_like(output)

    handle = module.register_forward_hook(hook)
    try:
        yield cache
    finally:
        handle.remove()
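
# Usage sketch (``layer.mlp`` is a placeholder for any submodule):
#
#     with capture_module_output(layer.mlp) as cache:
#         _ = model(input_ids=ids, attention_mask=mask, use_cache=False)
#     mlp_out = cache["output"]  # None if the hook saw nothing tensor-like
#
# The hook is always removed, so the module is safe to reuse afterwards.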


_ATTN_NAME_FRAGMENTS = (
    "self_attn.",
    "attn.",
    "attention.",
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "q_norm",
    "k_norm",
)
_MLP_NAME_FRAGMENTS = (
    "mlp.",
    "ffn.",
    "feed_forward",
    "feedforward",
    "gate_proj",
    "up_proj",
    "down_proj",
    "fc1",
    "fc2",
    "dense_h_to_4h",
    "dense_4h_to_h",
    "w1",
    "w2",
    "w3",
)


def _classify_param_family(name: str) -> str:
    """Classify a parameter name as "mlp", "attn", or "other" by substring."""
    lowered = name.lower()
    if any(fragment in lowered for fragment in _MLP_NAME_FRAGMENTS):
        return "mlp"
    if any(fragment in lowered for fragment in _ATTN_NAME_FRAGMENTS):
        return "attn"
    return "other"


def _family_reg_scale(family: str, attn_scale: float, mlp_scale: float) -> float:
    if family == "attn":
        return attn_scale
    if family == "mlp":
        return mlp_scale
    return 1.0


def _subset_allows_param(name: str, subset: str) -> bool:
    if subset == "all":
        return True
    return _classify_param_family(name) == subset


def _gate_logit_from_prior(prior: torch.Tensor) -> torch.Tensor:
    # Inverse sigmoid: logit(p) = log(p) - log(1 - p).
    return torch.log(prior) - torch.log1p(-prior)
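
# Sanity check: sigmoid(logit(p)) recovers p, so a prior of 0.9 initializes a
# gate whose sigmoid starts at 0.9:
#
#     p = torch.tensor(0.9)
#     s0 = _gate_logit_from_prior(p)   # ~2.1972
#     torch.sigmoid(s0)                # ~0.9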


def _build_gate_priors(
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    fisher_a: Dict[str, object],
    fisher_b: Dict[str, object],
    num_batches: int,
    numels_a: Dict[str, int],
    numels_b: Dict[str, int],
    fisher_mode: str,
    eps: float,
    clamp_eps: float,
) -> Dict[str, torch.Tensor]:
    """Return lambda priors for parameters that can be merged."""
    priors: Dict[str, torch.Tensor] = {}
    params_b = {name: param for name, param in layer_b.named_parameters()}
    for name, param_a in layer_a.named_parameters():
        param_b = params_b.get(name)
        if param_b is None or param_b.shape != param_a.shape:
            continue
        if fisher_mode == "param":
            fa = fisher_a[name] / max(num_batches, 1)
            fb = fisher_b[name] / max(num_batches, 1)
            denom = fa + fb
            if not isinstance(denom, torch.Tensor):
                denom = torch.tensor(float(denom))
            # Fall back to an uninformative 0.5 where the Fisher mass vanishes.
            prior = torch.where(
                denom > eps,
                fa / (denom + eps),
                torch.full_like(denom, 0.5),
            )
            prior = prior.clamp(clamp_eps, 1.0 - clamp_eps)
            priors[name] = prior
        else:
            fa = fisher_a[name] / (max(num_batches, 1) * numels_a[name])
            fb = fisher_b[name] / (max(num_batches, 1) * numels_b[name])
            denom = fa + fb
            if denom <= eps:
                prior_val = 0.5
            else:
                prior_val = float(fa / (denom + eps))
            prior_val = min(max(prior_val, clamp_eps), 1.0 - clamp_eps)
            priors[name] = torch.tensor(prior_val, dtype=torch.float32)
    return priors


def compute_fisher_gate_priors(
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    fisher_a: Dict[str, object],
    fisher_b: Dict[str, object],
    num_batches: int,
    numels_a: Dict[str, int],
    numels_b: Dict[str, int],
    fisher_mode: str,
    eps: float,
    clamp_eps: float = 1e-4,
) -> Dict[str, torch.Tensor]:
    """Compute Fisher prior gate lambdas (lambda_prior) for mergeable parameters."""
    return _build_gate_priors(
        layer_a=layer_a,
        layer_b=layer_b,
        fisher_a=fisher_a,
        fisher_b=fisher_b,
        num_batches=num_batches,
        numels_a=numels_a,
        numels_b=numels_b,
        fisher_mode=fisher_mode,
        eps=eps,
        clamp_eps=clamp_eps,
    )
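
# Usage sketch (hypothetical inputs): ``fisher_a``/``fisher_b`` map parameter
# names to accumulated squared-gradient sums and ``numels_*`` to element
# counts, as produced by whatever Fisher pass precedes fusion:
#
#     priors = compute_fisher_gate_priors(
#         layer_a, layer_b, fisher_a, fisher_b,
#         num_batches=32, numels_a=numels_a, numels_b=numels_b,
#         fisher_mode="param", eps=1e-12,
#     )
#     # priors[name] -> lambda in (0, 1); larger means layer_a dominates.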


class ReparamMergedLayer(torch.nn.Module):
    """Virtual layer that merges parameters via W0/U reparameterization.

    Parameters of layer_a/layer_b are treated as frozen (detached). We train:
    - gate logits s (lambda = sigmoid(s))
    - U (initialized as U0 = (W_a - W_b) / 2)

    Forward uses:
        W_merge = W0 + (2 * lambda - 1) * U
    where W0 = (W_a + W_b) / 2.
    """

    def __init__(
        self,
        layer_a: torch.nn.Module,
        layer_b: torch.nn.Module,
        gate_targets: Dict[str, object],
        param_subset: str = "all",
        clamp_eps: float = 1e-4,
    ) -> None:
        super().__init__()
        self.layer_a = layer_a
        self.layer_b = layer_b
        self.param_subset = param_subset
        self._name_map: Dict[str, str] = {}

        self.gates = torch.nn.ParameterDict()
        self.u = torch.nn.ParameterDict()

        params_b = {name: param for name, param in layer_b.named_parameters()}
        try:
            device = next(layer_a.parameters()).device
        except StopIteration:
            device = torch.device("cpu")

        for name, param_a in layer_a.named_parameters():
            param_b = params_b.get(name)
            if param_b is None or param_b.shape != param_a.shape:
                continue
            if not _subset_allows_param(name, self.param_subset):
                continue

            target = gate_targets.get(name)
            if target is None:
                target_t = torch.tensor(0.5, device=device, dtype=torch.float32)
            elif isinstance(target, torch.Tensor):
                target_t = target.detach().to(device=device, dtype=torch.float32)
            else:
                target_t = torch.tensor(float(target), device=device, dtype=torch.float32)

            target_t = target_t.clamp(clamp_eps, 1.0 - clamp_eps)
            s0 = _gate_logit_from_prior(target_t)
            u0 = 0.5 * (param_a.detach().float() - param_b.detach().float())

            # ParameterDict keys cannot contain dots; also guard against the
            # (unlikely) collision after replacement.
            safe = name.replace(".", "__")
            if safe in self.gates:
                safe = f"{safe}_{len(self.gates)}"
            self._name_map[name] = safe
            self.gates[safe] = torch.nn.Parameter(s0)
            self.u[safe] = torch.nn.Parameter(u0)

    def __getattr__(self, name: str):
        # Fall back to attributes of the wrapped layers so callers that poke
        # at layer internals (e.g. config or rotary caches) keep working.
        try:
            return super().__getattr__(name)
        except AttributeError as exc:
            try:
                layer_a = super().__getattr__("layer_a")
                if hasattr(layer_a, name):
                    return getattr(layer_a, name)
            except AttributeError:
                pass
            try:
                layer_b = super().__getattr__("layer_b")
                if hasattr(layer_b, name):
                    return getattr(layer_b, name)
            except AttributeError:
                pass
            raise exc

    def _safe_for(self, orig: str) -> Optional[str]:
        return self._name_map.get(orig)

    def gate_lambdas(self) -> Dict[str, torch.Tensor]:
        out: Dict[str, torch.Tensor] = {}
        for orig, safe in self._name_map.items():
            out[orig] = torch.sigmoid(self.gates[safe]).detach()
        return out

    def _merged_params(self) -> Dict[str, torch.Tensor]:
        params_a = {name: p for name, p in self.layer_a.named_parameters()}
        params_b = {name: p for name, p in self.layer_b.named_parameters()}
        merged_params: Dict[str, torch.Tensor] = {}

        for name, param_a in params_a.items():
            param_b = params_b.get(name)
            safe = self._safe_for(name)
            if safe is None or param_b is None or param_b.shape != param_a.shape:
                merged_params[name] = param_a.detach()
                continue

            lam = torch.sigmoid(self.gates[safe]).to(dtype=torch.float32)
            u = self.u[safe].to(dtype=torch.float32)
            w0 = 0.5 * (param_a.detach().float() + param_b.detach().float())
            merged = w0 + (2.0 * lam - 1.0) * u
            merged_params[name] = merged.to(dtype=param_a.dtype)
        return merged_params

    def forward(self, *args, **kwargs):
        if _functional_call is None:
            raise RuntimeError(
                "Reparam distillation requires torch.func.functional_call"
            )

        merged_params = self._merged_params()
        return _functional_call(self.layer_a, merged_params, args, kwargs)

    def materialize_into_layer_a(self) -> int:
        """Write the merged weights into layer_a in place; returns the count."""
        merged = 0
        params_a = {name: p for name, p in self.layer_a.named_parameters()}
        params_b = {name: p for name, p in self.layer_b.named_parameters()}
        with torch.no_grad():
            for orig, safe in self._name_map.items():
                param_a = params_a.get(orig)
                param_b = params_b.get(orig)
                if param_a is None or param_b is None or param_b.shape != param_a.shape:
                    continue
                lam = torch.sigmoid(self.gates[safe]).to(device=param_a.device, dtype=torch.float32)
                u = self.u[safe].to(device=param_a.device, dtype=torch.float32)
                w0 = 0.5 * (param_a.detach().float() + param_b.detach().float())
                merged_param = w0 + (2.0 * lam - 1.0) * u
                param_a.copy_(merged_param.to(dtype=param_a.dtype))
                merged += 1
        return merged
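
# A note on the parameterization: at lambda = 1 the merged weight is
# W0 + U0 = W_a, at lambda = 0 it is W0 - U0 = W_b (while U sits at its init
# U0), and at lambda = 0.5 it is the plain average W0. A tiny sketch with
# throwaway linear layers (assumes torch.func.functional_call is available):
#
#     a, b = torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)
#     merged = ReparamMergedLayer(a, b, gate_targets={})  # priors default 0.5
#     y = merged(torch.randn(2, 4))  # forward through a with W0 +/- U weights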


def distill_reparam_merge(
    student_model: torch.nn.Module,
    student_parent: object,
    student_layer_attr: str,
    student_layers: List[torch.nn.Module],
    teacher_model: torch.nn.Module,
    teacher_parent: object,
    teacher_layer_attr: str,
    teacher_layers: List[torch.nn.Module],
    layer_idx: int,
    gate_lambdas: Dict[str, object],
    dataloader,
    args: argparse.Namespace,
    progressive_cycle: Optional[int] = None,
    progressive_total: Optional[int] = None,
) -> Tuple[int, Dict[str, torch.Tensor], Dict[str, object]]:
    """Reparameterized distillation that materializes a fused layer into layer_a.

    Trains U and gate logits s (lambda = sigmoid(s)) using:
    - composition MSE + distill-KL
    - eta * ||lambda - lambda_gate||^2 + gamma * ||U - U0||^2
    """
    total_epochs = float(args.distill_epochs)

    hidden_mse_weight = float(getattr(args, "distill_hidden_mse_weight", 1.0))
    if hidden_mse_weight < 0.0:
        raise SystemExit("--distill_hidden_mse_weight must be >= 0")
    attn_mse_weight = float(getattr(args, "distill_attn_mse_weight", 0.0))
    if attn_mse_weight < 0.0:
        raise SystemExit("--distill_attn_mse_weight must be >= 0")
    mlp_mse_weight = float(getattr(args, "distill_mlp_mse_weight", 0.0))
    if mlp_mse_weight < 0.0:
        raise SystemExit("--distill_mlp_mse_weight must be >= 0")
    param_subset = str(getattr(args, "reparam_param_subset", "all"))
    if param_subset not in {"all", "mlp", "attn"}:
        raise SystemExit("--reparam_param_subset must be one of: all, mlp, attn")

    kl_weight = float(args.distill_kl_weight)
    kl_temp = float(args.distill_kl_temp)
    if kl_weight < 0.0:
        raise SystemExit("--distill_kl_weight must be >= 0")
    if kl_temp <= 0.0:
        raise SystemExit("--distill_kl_temp must be > 0")

    eta = float(getattr(args, "reparam_eta", 0.0))
    gamma = float(getattr(args, "reparam_gamma", 0.0))
    if eta < 0.0:
        raise SystemExit("--reparam_eta must be >= 0")
    if gamma < 0.0:
        raise SystemExit("--reparam_gamma must be >= 0")
    attn_reg_scale = float(getattr(args, "reparam_attn_reg_scale", 1.0))
    mlp_reg_scale = float(getattr(args, "reparam_mlp_reg_scale", 1.0))
    if attn_reg_scale < 0.0:
        raise SystemExit("--reparam_attn_reg_scale must be >= 0")
    if mlp_reg_scale < 0.0:
        raise SystemExit("--reparam_mlp_reg_scale must be >= 0")
    if (
        total_epochs > 0.0
        and hidden_mse_weight == 0.0
        and attn_mse_weight == 0.0
        and mlp_mse_weight == 0.0
        and kl_weight == 0.0
        and eta == 0.0
        and gamma == 0.0
    ):
        raise SystemExit(
            "Reparam distillation has no active loss terms. "
            "Enable hidden/attention/MLP MSE, KL, or at least one reparam regularizer."
        )

    if not gate_lambdas:
        raise SystemExit("Reparam distillation requires non-empty gate lambdas.")

    layer_a = student_layers[layer_idx]
    layer_b = student_layers[layer_idx + 1]

    reparam_layer = ReparamMergedLayer(
        layer_a,
        layer_b,
        gate_lambdas,
        param_subset=param_subset,
        clamp_eps=1e-4,
    )
    if not reparam_layer._name_map:
        raise RuntimeError(
            "No mergeable parameters found for reparam distillation under "
            f"--reparam_param_subset={param_subset!r}."
        )

    teacher_attn = None
    student_attn = None
    if attn_mse_weight > 0.0:
        try:
            teacher_attn = find_attention_module(teacher_layers[layer_idx + 1])
            student_attn = find_attention_module(reparam_layer.layer_a)
        except ValueError as exc:
            raise SystemExit(
                "Attention-output preservation was requested but an attention module "
                f"could not be resolved: {exc}"
            ) from exc

    teacher_mlp = None
    student_mlp = None
    if mlp_mse_weight > 0.0:
        try:
            teacher_mlp = find_mlp_module(teacher_layers[layer_idx + 1])
            student_mlp = find_mlp_module(reparam_layer.layer_a)
        except ValueError as exc:
            raise SystemExit(
                "MLP-output preservation was requested but an MLP module could not be "
                f"resolved: {exc}"
            ) from exc

    # Build the virtual student stack: the merged layer replaces layer i and
    # layer i+1 is dropped.
    virtual_layers = list(student_layers)
    virtual_layers[layer_idx] = reparam_layer
    del virtual_layers[layer_idx + 1]

    # Freeze everything; only the gate logits and U are trained.
    for param in student_model.parameters():
        param.requires_grad_(False)
    for param in reparam_layer.gates.parameters():
        param.requires_grad_(True)
    for param in reparam_layer.u.parameters():
        param.requires_grad_(True)

    do_train = total_epochs > 0.0
    if do_train:
        teacher_model.eval()
        student_model.train()

    # Report the size of the trainable reparam state (U dominates).
    total_gate_elems = sum(int(p.numel()) for p in reparam_layer.gates.parameters())
    total_u_elems = sum(int(p.numel()) for p in reparam_layer.u.parameters())
    gate_mib = total_gate_elems * 4.0 / (1024.0 * 1024.0)
    u_mib = total_u_elems * 4.0 / (1024.0 * 1024.0)
    family_counts: Dict[str, int] = {"attn": 0, "mlp": 0, "other": 0}
    for orig in reparam_layer._name_map:
        family_counts[_classify_param_family(orig)] += 1
    print(
        f"[reparam] subset={param_subset} gates={len(reparam_layer.gates)} "
        f"(attn={family_counts['attn']}, mlp={family_counts['mlp']}, other={family_counts['other']}) "
        f"elems={total_gate_elems} (~{gate_mib:.1f} MiB), "
        f"U_elems={total_u_elems} (~{u_mib:.1f} MiB; +optimizer state)"
    )

    optimizer = None
    if do_train:
        optimizer = torch.optim.AdamW(
            [*reparam_layer.gates.parameters(), *reparam_layer.u.parameters()],
            lr=float(args.distill_lr),
            weight_decay=float(args.distill_weight_decay),
        )

    device_type = torch.device(args.device).type
    amp_dtype = None
    if args.dtype == "float16":
        amp_dtype = torch.float16
    elif args.dtype == "bfloat16":
        amp_dtype = torch.bfloat16
    use_amp = do_train and amp_dtype is not None and device_type == "cuda"
    use_scaler = use_amp and amp_dtype == torch.float16
    scaler = torch.cuda.amp.GradScaler() if use_scaler else None

    full_epochs = int(total_epochs) if do_train else 0
    fractional = (total_epochs - full_epochs) if do_train else 0.0
    if fractional < 1e-8:
        fractional = 0.0

    epoch_plan = [(epoch_idx, None) for epoch_idx in range(full_epochs)]
    if fractional > 0:
        try:
            batches_per_epoch = len(dataloader)
        except TypeError as exc:
            raise SystemExit(
                "Fractional distill epochs require a dataloader with finite length."
            ) from exc
        if batches_per_epoch > 0:
            frac_batches = int(round(fractional * batches_per_epoch))
            if frac_batches <= 0:
                frac_batches = 1
            epoch_plan.append((full_epochs, frac_batches))

    grad_accum = int(getattr(args, "distill_grad_accum_steps", 1))
    if grad_accum <= 0:
        raise SystemExit("--distill_grad_accum_steps must be >= 1")

    log_steps = int(getattr(args, "distill_log_steps", 100))
    max_grad_norm = getattr(args, "distill_max_grad_norm", 1.0)

    params_a = {name: p for name, p in layer_a.named_parameters()}
    params_b = {name: p for name, p in layer_b.named_parameters()}

    step = 0
    for epoch_idx, max_batches in epoch_plan:
        if max_batches is None:
            epoch_iter = dataloader
        else:
            epoch_iter = itertools.islice(dataloader, max_batches)
        iterator = epoch_iter
        if tqdm is not None and _tqdm_enabled():
            if progressive_cycle is not None:
                if progressive_total is not None:
                    desc = (
                        f"Reparam (cycle {progressive_cycle}/{progressive_total}, "
                        f"epoch {epoch_idx+1})"
                    )
                else:
                    desc = f"Reparam (cycle {progressive_cycle}, epoch {epoch_idx+1})"
            else:
                desc = f"Reparam (epoch {epoch_idx+1})"
            iterator = tqdm(epoch_iter, desc=desc, unit="batch", total=max_batches)

        for batch in iterator:
            input_ids = batch[0].to(args.device)
            attention_mask = batch[1].to(args.device)
            teacher_ids = input_ids.to(args.distill_teacher_device or args.device)
            teacher_mask = attention_mask.to(args.distill_teacher_device or args.device)

            # The teacher runs one layer deeper: its blocks i and i+1 are what
            # the student's merged block i must reproduce.
            teacher_depth = layer_idx + 2
            student_depth = layer_idx + 1

            autocast_ctx = (
                torch.autocast(device_type=device_type, dtype=amp_dtype)
                if use_amp
                else nullcontext()
            )
            with autocast_ctx:
                teacher_attn_ctx = (
                    capture_module_output(teacher_attn)
                    if teacher_attn is not None
                    else nullcontext({"output": None})
                )
                teacher_mlp_ctx = (
                    capture_module_output(teacher_mlp)
                    if teacher_mlp is not None
                    else nullcontext({"output": None})
                )
                with torch.no_grad():
                    with teacher_attn_ctx as teacher_attn_cache, teacher_mlp_ctx as teacher_mlp_cache:
                        teacher_hidden = forward_truncated(
                            teacher_parent,
                            teacher_layer_attr,
                            teacher_layers,
                            teacher_depth,
                            teacher_ids,
                            attention_mask=teacher_mask,
                        )

                student_attn_ctx = (
                    capture_module_output(student_attn)
                    if student_attn is not None
                    else nullcontext({"output": None})
                )
                student_mlp_ctx = (
                    capture_module_output(student_mlp)
                    if student_mlp is not None
                    else nullcontext({"output": None})
                )
                with student_attn_ctx as student_attn_cache, student_mlp_ctx as student_mlp_cache:
                    student_hidden = forward_truncated(
                        student_parent,
                        student_layer_attr,
                        virtual_layers,
                        student_depth,
                        input_ids,
                        attention_mask=attention_mask,
                    )

                if teacher_hidden.device != student_hidden.device:
                    teacher_hidden = teacher_hidden.to(student_hidden.device)

                mse_loss = None
                if hidden_mse_weight > 0.0:
                    diff = student_hidden - teacher_hidden
                    mse_loss = _masked_hidden_mse(diff, attention_mask)
                    if mse_loss is None:
                        continue

                attn_aux_loss = None
                if attn_mse_weight > 0.0:
                    teacher_attn_hidden = teacher_attn_cache.get("output")
                    student_attn_hidden = student_attn_cache.get("output")
                    if teacher_attn_hidden is None or student_attn_hidden is None:
                        raise RuntimeError(
                            "Attention-output preservation is enabled, but the forward "
                            "hook did not capture attention outputs."
                        )
                    if teacher_attn_hidden.device != student_attn_hidden.device:
                        teacher_attn_hidden = teacher_attn_hidden.to(student_attn_hidden.device)
                    attn_aux_loss = _masked_hidden_mse(
                        student_attn_hidden - teacher_attn_hidden,
                        attention_mask,
                    )
                    if attn_aux_loss is None:
                        continue

                mlp_aux_loss = None
                if mlp_mse_weight > 0.0:
                    teacher_mlp_hidden = teacher_mlp_cache.get("output")
                    student_mlp_hidden = student_mlp_cache.get("output")
                    if teacher_mlp_hidden is None or student_mlp_hidden is None:
                        raise RuntimeError(
                            "MLP-output preservation is enabled, but the forward hook "
                            "did not capture MLP outputs."
                        )
                    if teacher_mlp_hidden.device != student_mlp_hidden.device:
                        teacher_mlp_hidden = teacher_mlp_hidden.to(student_mlp_hidden.device)
                    mlp_aux_loss = _masked_hidden_mse(
                        student_mlp_hidden - teacher_mlp_hidden,
                        attention_mask,
                    )
                    if mlp_aux_loss is None:
                        continue

                kl_loss = None
                if kl_weight > 0.0:
                    with torch.no_grad():
                        teacher_outputs = teacher_model(
                            input_ids=teacher_ids,
                            attention_mask=teacher_mask,
                            use_cache=False,
                        )
                        teacher_logits = teacher_outputs.logits

                    virtual_container = torch.nn.ModuleList(virtual_layers)
                    with temporary_layers(
                        student_parent, student_layer_attr, virtual_container
                    ):
                        student_outputs = student_model(
                            input_ids=input_ids,
                            attention_mask=attention_mask,
                            use_cache=False,
                        )
                    student_logits = student_outputs.logits
                    if teacher_logits.device != student_logits.device:
                        teacher_logits = teacher_logits.to(student_logits.device)

                    # Next-token alignment: logits at position t predict token
                    # t+1, so drop the last logit and the first mask entry.
                    shift_teacher_logits = teacher_logits[:, :-1, :].contiguous()
                    shift_student_logits = student_logits[:, :-1, :].contiguous()
                    shift_mask = attention_mask[:, 1:].contiguous()
                    log_p_t = F.log_softmax(shift_teacher_logits / kl_temp, dim=-1)
                    log_p_s = F.log_softmax(shift_student_logits / kl_temp, dim=-1)
                    p_t = log_p_t.exp()
                    kl_flat = (p_t * (log_p_t - log_p_s)).sum(dim=-1)
                    kl_denom = shift_mask.sum()
                    if kl_denom.item() == 0:
                        continue
                    kl_loss = (
                        kl_flat * shift_mask.to(kl_flat.dtype)
                    ).sum() / kl_denom
                lambda_reg = None
                if eta > 0.0:
                    reg_sum: Optional[torch.Tensor] = None
                    reg_elems = 0
                    for orig, safe in reparam_layer._name_map.items():
                        lam = torch.sigmoid(reparam_layer.gates[safe]).float()
                        target = gate_lambdas.get(orig)
                        if target is None:
                            target_t = 0.5
                        elif isinstance(target, torch.Tensor):
                            target_t = target.to(device=lam.device, dtype=lam.dtype)
                        else:
                            target_t = float(target)
                        diff_lam = lam - target_t
                        family = _classify_param_family(orig)
                        scale = _family_reg_scale(
                            family,
                            attn_scale=attn_reg_scale,
                            mlp_scale=mlp_reg_scale,
                        )
                        if scale <= 0.0:
                            continue
                        part = diff_lam.pow(2).sum() * scale
                        reg_sum = part if reg_sum is None else reg_sum + part
                        reg_elems += int(float(diff_lam.numel()) * scale)
                    if reg_elems > 0 and reg_sum is not None:
                        lambda_reg = reg_sum / float(reg_elems)

                u_reg = None
                if gamma > 0.0:
                    reg_sum = None
                    reg_elems = 0
                    for orig, safe in reparam_layer._name_map.items():
                        u = reparam_layer.u[safe].float()
                        param_a = params_a.get(orig)
                        param_b = params_b.get(orig)
                        if param_a is None or param_b is None or param_b.shape != param_a.shape:
                            continue
                        u0 = 0.5 * (param_a.detach().float() - param_b.detach().float())
                        diff_u = u - u0
                        family = _classify_param_family(orig)
                        scale = _family_reg_scale(
                            family,
                            attn_scale=attn_reg_scale,
                            mlp_scale=mlp_reg_scale,
                        )
                        if scale <= 0.0:
                            continue
                        part = diff_u.pow(2).sum() * scale
                        reg_sum = part if reg_sum is None else reg_sum + part
                        reg_elems += int(float(diff_u.numel()) * scale)
                    if reg_elems > 0 and reg_sum is not None:
                        u_reg = reg_sum / float(reg_elems)

                total_loss = None
                if mse_loss is not None:
                    total_loss = hidden_mse_weight * mse_loss
                if attn_aux_loss is not None:
                    term = attn_mse_weight * attn_aux_loss
                    total_loss = term if total_loss is None else total_loss + term
                if mlp_aux_loss is not None:
                    term = mlp_mse_weight * mlp_aux_loss
                    total_loss = term if total_loss is None else total_loss + term
                if kl_loss is not None:
                    term = kl_weight * (kl_temp ** 2) * kl_loss
                    total_loss = term if total_loss is None else total_loss + term
                if lambda_reg is not None:
                    term = eta * lambda_reg
                    total_loss = term if total_loss is None else total_loss + term
                if u_reg is not None:
                    term = gamma * u_reg
                    total_loss = term if total_loss is None else total_loss + term
                if total_loss is None:
                    continue

            if grad_accum > 1:
                total_loss = total_loss / grad_accum
            if use_scaler:
                scaler.scale(total_loss).backward()
            else:
                total_loss.backward()

            if (step + 1) % grad_accum == 0:
                if max_grad_norm is not None:
                    if use_scaler:
                        scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        [*reparam_layer.gates.parameters(), *reparam_layer.u.parameters()],
                        float(max_grad_norm),
                    )
                if use_scaler:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad(set_to_none=True)

            if log_steps and (step == 0 or (step + 1) % log_steps == 0):
                log_parts = [f"loss={total_loss.item():.6e}"]
                if mse_loss is not None:
                    log_parts.append(f"mse={mse_loss.item():.6e}")
                else:
                    log_parts.append("mse=disabled")
                if attn_aux_loss is not None:
                    log_parts.append(f"attn_mse={attn_aux_loss.item():.6e}")
                elif attn_mse_weight > 0.0:
                    log_parts.append("attn_mse=nan")
                if mlp_aux_loss is not None:
                    log_parts.append(f"mlp_mse={mlp_aux_loss.item():.6e}")
                elif mlp_mse_weight > 0.0:
                    log_parts.append("mlp_mse=nan")
                if kl_loss is not None:
                    log_parts.append(f"kl={kl_loss.item():.6e}")
                if lambda_reg is not None:
                    log_parts.append(f"lam_reg={lambda_reg.item():.6e}")
                if u_reg is not None:
                    log_parts.append(f"u_reg={u_reg.item():.6e}")
                print(
                    f"[reparam] epoch={epoch_idx+1} step={step+1} " + " ".join(log_parts)
                )
            step += 1

    merged = reparam_layer.materialize_into_layer_a()
    final_lambdas = reparam_layer.gate_lambdas()
    stats: Dict[str, object] = {
        "enabled": True,
        "epochs": total_epochs,
        "lr": float(args.distill_lr),
        "hidden_mse_weight": hidden_mse_weight,
        "attn_mse_weight": attn_mse_weight,
        "mlp_mse_weight": mlp_mse_weight,
        "eta": eta,
        "gamma": gamma,
        "attn_reg_scale": attn_reg_scale,
        "mlp_reg_scale": mlp_reg_scale,
        "param_subset": param_subset,
        "num_gates": len(final_lambdas),
        "num_attn_gates": family_counts["attn"],
        "num_mlp_gates": family_counts["mlp"],
        "num_other_gates": family_counts["other"],
    }
    return merged, final_lambdas, stats


class LoRALinear(torch.nn.Module):
    """Linear layer wrapped with a trainable low-rank (LoRA) residual."""

    def __init__(
        self,
        base: torch.nn.Linear,
        rank: int,
        alpha: float,
        dropout: float,
    ) -> None:
        super().__init__()
        if rank <= 0:
            raise ValueError("LoRA rank must be positive")
        self.base = base
        self.rank = int(rank)
        self.alpha = float(alpha)
        self.scaling = self.alpha / float(self.rank)
        self.enabled = True
        if dropout > 0:
            self.dropout = torch.nn.Dropout(dropout)
        else:
            self.dropout = torch.nn.Identity()

        # Standard LoRA init: A is random, B starts at zero so the wrapped
        # module initially computes exactly the base mapping.
        self.lora_A = torch.nn.Linear(base.in_features, self.rank, bias=False)
        self.lora_B = torch.nn.Linear(self.rank, base.out_features, bias=False)
        torch.nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_B.weight)

        self.lora_A.to(device=base.weight.device, dtype=base.weight.dtype)
        self.lora_B.to(device=base.weight.device, dtype=base.weight.dtype)
        self.merged = False

    def lora_parameters(self) -> List[torch.nn.Parameter]:
        return [*self.lora_A.parameters(), *self.lora_B.parameters()]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        result = self.base(x)
        if self.merged or not self.enabled:
            return result
        lora_out = self.lora_B(self.lora_A(self.dropout(x)))
        return result + lora_out * self.scaling

    def merge(self) -> None:
        """Fold the low-rank update into the base weight (idempotent)."""
        if self.merged:
            return
        delta = torch.matmul(self.lora_B.weight, self.lora_A.weight)
        delta = delta.to(dtype=self.base.weight.dtype) * self.scaling
        self.base.weight.data.add_(delta)
        self.merged = True
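
# After merge(), the base layer alone reproduces the adapted mapping, since
#     (W + scaling * B @ A) x == W x + scaling * B (A x).
# A quick equivalence check (dropout must be inactive for exact equality):
#
#     lin = LoRALinear(torch.nn.Linear(8, 8), rank=2, alpha=4.0, dropout=0.0)
#     x = torch.randn(3, 8)
#     before = lin(x)
#     lin.merge()
#     torch.allclose(lin(x), before, atol=1e-6)  # True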


def _get_child_module(parent: torch.nn.Module, part: str) -> torch.nn.Module:
    if isinstance(parent, (torch.nn.ModuleList, torch.nn.Sequential)) and part.isdigit():
        return parent[int(part)]
    if isinstance(parent, torch.nn.ModuleDict):
        return parent[part]
    return getattr(parent, part)


def _set_child_module(parent: torch.nn.Module, part: str, module: torch.nn.Module) -> None:
    if isinstance(parent, (torch.nn.ModuleList, torch.nn.Sequential)) and part.isdigit():
        parent[int(part)] = module
        return
    if isinstance(parent, torch.nn.ModuleDict):
        parent[part] = module
        return
    setattr(parent, part, module)


def _resolve_parent_module(
    root: torch.nn.Module, module_name: str
) -> Optional[tuple]:
    """Return ``(parent, leaf_attr)`` for a dotted module name, or None."""
    if not module_name:
        return None
    parts = module_name.split(".")
    parent = root
    for part in parts[:-1]:
        try:
            parent = _get_child_module(parent, part)
        except Exception:
            # Mirror _resolve_module_by_path: unresolvable names yield None
            # instead of propagating attribute/key errors.
            return None
    return parent, parts[-1]


def _resolve_module_by_path(root: torch.nn.Module, module_path: str) -> Optional[torch.nn.Module]:
    if not module_path:
        return None
    parts = [part for part in module_path.split(".") if part]
    node = root
    for part in parts:
        try:
            node = _get_child_module(node, part)
        except Exception:
            return None
    return node


def _resolve_layer_container_for_lora(
    model: torch.nn.Module, layer_path: Optional[str]
) -> Tuple[Optional[str], Optional[object]]:
    """Resolve transformer layer container with optional auto-detection.

    Mirrors the candidate path strategy used elsewhere, so LoRA filtering can
    work even when --layer_path is not provided.
    """
    if isinstance(layer_path, str) and layer_path and layer_path.lower() != "none":
        container = _resolve_module_by_path(model, layer_path)
        if container is not None:
            try:
                list(container)
                return layer_path, container
            except TypeError:
                pass

    candidate_paths = [
        "model.layers",
        "model.decoder.layers",
        "transformer.h",
        "transformer.blocks",
        "gpt_neox.layers",
        "layers",
    ]
    for path in candidate_paths:
        container = _resolve_module_by_path(model, path)
        if container is None:
            continue
        try:
            list(container)
        except TypeError:
            continue
        return path, container

    return None, None


def _parse_exclude_pairs_local(raw_values, num_pairs: int) -> Set[int]:
    """Parse --exclude_pairs values (ints, comma-separated, negatives wrap)."""
    if not raw_values or num_pairs <= 0:
        return set()
    exclude: Set[int] = set()
    for item in raw_values:
        if item is None:
            continue
        for part in str(item).split(","):
            part = part.strip()
            if not part:
                continue
            try:
                idx = int(part)
            except ValueError as exc:
                raise SystemExit("--exclude_pairs must contain integers.") from exc
            if idx < 0:
                idx = num_pairs + idx
            if 0 <= idx < num_pairs:
                exclude.add(idx)
    return exclude
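
# Examples:
#
#     _parse_exclude_pairs_local(["0,2"], num_pairs=4)   # {0, 2}
#     _parse_exclude_pairs_local(["-1"], num_pairs=4)    # {3} (negative wraps)
#     _parse_exclude_pairs_local(["7"], num_pairs=4)     # set() (out of range)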


def _extract_layer_index_from_module_name(
    module_name: str, layer_path: str
) -> Optional[int]:
    """Extract the integer layer index from a module name under layer_path."""
    if not layer_path:
        return None
    prefix = f"{layer_path}."
    if not module_name.startswith(prefix):
        return None
    rest = module_name[len(prefix) :]
    if not rest:
        return None
    idx_text = rest.split(".", 1)[0]
    if not idx_text.isdigit():
        return None
    return int(idx_text)
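
# Examples:
#
#     _extract_layer_index_from_module_name(
#         "model.layers.3.self_attn.q_proj", "model.layers")   # 3
#     _extract_layer_index_from_module_name(
#         "model.embed_tokens", "model.layers")                # None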


def _select_linear_modules_for_lora_targets(
    model: torch.nn.Module,
    args: argparse.Namespace,
    *,
    log_tag: str,
) -> Tuple[List[Tuple[str, torch.nn.Linear]], Optional[Set[str]], Set[int], Optional[str]]:
    """Select Linear modules eligible for LoRA, honoring target/exclude filters."""
    raw_targets = getattr(args, "lora_target_modules", None)
    target_modules: Optional[Set[str]] = None
    if raw_targets:
        target_modules = {str(item) for item in raw_targets if str(item)}

    exclude_layer_indices: Set[int] = set()
    resolved_layer_path: Optional[str] = None
    if bool(getattr(args, "lora_respect_exclude_pairs", False)):
        requested_layer_path = getattr(args, "layer_path", None)
        resolved_layer_path, layer_container = _resolve_layer_container_for_lora(
            model, requested_layer_path
        )
        if isinstance(layer_container, (torch.nn.ModuleList, list, tuple)):
            num_pairs = max(len(layer_container) - 1, 0)
            exclude_pairs = _parse_exclude_pairs_local(
                getattr(args, "exclude_pairs", None), num_pairs
            )
            # Excluding pair i protects both layers of the pair.
            for pair_idx in exclude_pairs:
                exclude_layer_indices.add(pair_idx)
                exclude_layer_indices.add(pair_idx + 1)
        else:
            print(
                f"[{log_tag}] Warning: --lora_respect_exclude_pairs enabled, but "
                f"could not resolve layer path '{requested_layer_path}'."
            )

    linear_modules = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
        and (target_modules is None or name.split(".")[-1] in target_modules)
        and (
            not exclude_layer_indices
            or _extract_layer_index_from_module_name(name, resolved_layer_path or "")
            not in exclude_layer_indices
        )
    ]
    return linear_modules, target_modules, exclude_layer_indices, resolved_layer_path


def apply_lora_adapters(
    model: torch.nn.Module, args: argparse.Namespace
) -> List[LoRALinear]:
    """Wrap eligible Linear modules with LoRA adapters and freeze the rest."""
    if args.lora_rank <= 0:
        raise SystemExit("--lora_rank must be > 0 when --lora_epochs > 0")
    linear_modules, target_modules, exclude_layer_indices, _ = (
        _select_linear_modules_for_lora_targets(model, args, log_tag="lora")
    )
    if not linear_modules:
        raise SystemExit(
            "No Linear modules found for LoRA adapters "
            "(check --lora_target_modules / --exclude_pairs / --lora_respect_exclude_pairs)."
        )

    lora_modules: List[LoRALinear] = []
    for name, module in linear_modules:
        resolved = _resolve_parent_module(model, name)
        if resolved is None:
            continue
        parent, attr = resolved
        wrapped = LoRALinear(
            base=module,
            rank=args.lora_rank,
            alpha=args.lora_alpha,
            dropout=args.lora_dropout,
        )
        _set_child_module(parent, attr, wrapped)
        lora_modules.append(wrapped)

    # Only the adapter weights train; the base model stays frozen.
    for param in model.parameters():
        param.requires_grad_(False)
    for lora_module in lora_modules:
        for param in lora_module.lora_parameters():
            param.requires_grad_(True)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    percent = 100.0 * trainable_params / max(total_params, 1)
    target_note = ""
    if target_modules is not None:
        target_note = f" target={sorted(target_modules)}"
    exclude_note = ""
    if exclude_layer_indices:
        exclude_note = f" excluded_layers={sorted(exclude_layer_indices)}"
    print(
        "[lora] Applied adapters to "
        f"{len(lora_modules)} linear modules "
        f"({trainable_params}/{total_params} trainable, {percent:.4f}%)."
        f"{target_note}{exclude_note}"
    )
    return lora_modules


def merge_lora_adapters(model: torch.nn.Module) -> None:
    """Fold every LoRALinear back into its base Linear and unwrap it."""
    lora_entries = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, LoRALinear)
    ]
    for name, module in lora_entries:
        module.merge()
        resolved = _resolve_parent_module(model, name)
        if resolved is None:
            continue
        parent, attr = resolved
        _set_child_module(parent, attr, module.base)


def set_lora_enabled(lora_modules: List[LoRALinear], enabled: bool) -> None:
    for module in lora_modules:
        module.enabled = enabled


def lora_ce_finetune(
    model: torch.nn.Module,
    dataloader,
    eval_tokenizer,
    eval_datasets: List[str],
    eval_configs: List[Optional[str]],
    eval_history: List[Dict[str, object]],
    args: argparse.Namespace,
    eval_dataloaders: Optional[Dict[str, object]] = None,
    progressive_cycle: Optional[int] = None,
    progressive_total: Optional[int] = None,
) -> None:
    """LoRA fine-tuning with CE loss, optional self-KL, and periodic PPL eval."""
    total_epochs = float(args.lora_epochs)
    if total_epochs <= 0:
        return

    use_kl = bool(getattr(args, "lora_kl_enabled", False))
    kl_weight = float(getattr(args, "lora_kl_weight", 0.0))
    kl_temp = float(getattr(args, "lora_kl_temp", 1.0))
    if use_kl:
        if kl_weight < 0.0:
            raise SystemExit("--lora_kl_weight must be >= 0")
        if kl_temp <= 0.0:
            raise SystemExit("--lora_kl_temp must be > 0")
        if kl_weight == 0.0:
            use_kl = False

    lora_modules = apply_lora_adapters(model, args)
    if not lora_modules:
        return

    model.train()

    lora_params = []
    for module in lora_modules:
        lora_params.extend(module.lora_parameters())

    optimizer = torch.optim.AdamW(
        lora_params,
        lr=args.lora_lr,
        weight_decay=args.lora_weight_decay,
    )

    device_type = torch.device(args.device).type
    amp_dtype = None
    if args.dtype == "float16":
        amp_dtype = torch.float16
    elif args.dtype == "bfloat16":
        amp_dtype = torch.bfloat16
    use_amp = amp_dtype is not None and device_type == "cuda"
    use_scaler = use_amp and amp_dtype == torch.float16
    scaler = torch.cuda.amp.GradScaler() if use_scaler else None

    full_epochs = int(total_epochs)
    fractional = total_epochs - full_epochs
    if fractional < 1e-8:
        fractional = 0.0

    epoch_plan = [(epoch_idx, None) for epoch_idx in range(full_epochs)]
    if fractional > 0:
        try:
            batches_per_epoch = len(dataloader)
        except TypeError as exc:
            raise SystemExit(
                "Fractional lora epochs require a dataloader with finite length."
            ) from exc
        if batches_per_epoch > 0:
            frac_batches = int(round(fractional * batches_per_epoch))
            if frac_batches <= 0:
                frac_batches = 1
            epoch_plan.append((full_epochs, frac_batches))
    step = 0
    for epoch_idx, max_batches in epoch_plan:
        if max_batches is None:
            epoch_iter = dataloader
        else:
            epoch_iter = itertools.islice(dataloader, max_batches)
        iterator = epoch_iter
        if tqdm is not None and _tqdm_enabled():
            if progressive_cycle is not None:
                if progressive_total is not None:
                    desc = (
                        f"LoRA (cycle {progressive_cycle}/{progressive_total}, "
                        f"epoch {epoch_idx+1})"
                    )
                else:
                    desc = f"LoRA (cycle {progressive_cycle}, epoch {epoch_idx+1})"
            else:
                desc = f"LoRA (epoch {epoch_idx+1})"
            iterator = tqdm(
                epoch_iter,
                desc=desc,
                unit="batch",
                total=max_batches,
            )
        for batch in iterator:
            input_ids = batch[0].to(args.device)
            attention_mask = batch[1].to(args.device)
            autocast_ctx = (
                torch.autocast(device_type=device_type, dtype=amp_dtype)
                if use_amp
                else nullcontext()
            )
            with autocast_ctx:
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    use_cache=False,
                )
                logits = outputs.logits
                shift_logits = logits[:, :-1, :].contiguous()
                shift_labels = input_ids[:, 1:].contiguous()
                shift_mask = attention_mask[:, 1:].contiguous()
                ce_flat = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    reduction="none",
                )
                ce_denom = shift_mask.sum()
                if ce_denom.item() == 0:
                    continue
                ce_loss = (
                    ce_flat * shift_mask.view(-1).to(ce_flat.dtype)
                ).sum() / ce_denom
                kl_loss = None
                if use_kl:
                    # Self-KL anchor: compare against the frozen base model by
                    # disabling the adapters for a second, no-grad forward.
                    set_lora_enabled(lora_modules, False)
                    with torch.no_grad():
                        base_outputs = model(
                            input_ids=input_ids,
                            attention_mask=attention_mask,
                            use_cache=False,
                        )
                        base_logits = base_outputs.logits
                    set_lora_enabled(lora_modules, True)
                    if base_logits.device != shift_logits.device:
                        base_logits = base_logits.to(shift_logits.device)
                    shift_base_logits = base_logits[:, :-1, :].contiguous()
                    log_p_pre = F.log_softmax(shift_base_logits / kl_temp, dim=-1)
                    log_p_post = F.log_softmax(shift_logits / kl_temp, dim=-1)
                    p_pre = log_p_pre.exp()
                    kl_flat = (p_pre * (log_p_pre - log_p_post)).sum(dim=-1)
                    kl_loss = (
                        kl_flat * shift_mask.to(kl_flat.dtype)
                    ).sum() / ce_denom

                total_loss = ce_loss
                if kl_loss is not None:
                    total_loss = total_loss + (kl_weight * (kl_temp ** 2) * kl_loss)

            if args.lora_grad_accum_steps > 1:
                total_loss = total_loss / args.lora_grad_accum_steps
            if use_scaler:
                scaler.scale(total_loss).backward()
            else:
                total_loss.backward()

            if (step + 1) % args.lora_grad_accum_steps == 0:
                if args.lora_max_grad_norm is not None:
                    if use_scaler:
                        scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        lora_params,
                        args.lora_max_grad_norm,
                    )
                if use_scaler:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad(set_to_none=True)

            if args.lora_eval_every and (step + 1) % args.lora_eval_every == 0:
                prev_mode = model.training
                model.eval()
                eval_device = args.eval_device or args.device
                if eval_dataloaders is not None:
                    results = ppl_eval.evaluate_ppl_dataloaders(
                        model,
                        eval_dataloaders,
                        eval_device,
                        max_batches=args.lora_eval_max_batches,
                    )
                else:
                    results = ppl_eval.evaluate_ppl_datasets(
                        model,
                        eval_tokenizer,
                        datasets=eval_datasets,
                        configs=eval_configs,
                        split=args.eval_split,
                        text_field=args.eval_text_field,
                        num_samples=args.eval_num_samples,
                        seq_len=args.eval_seq_len,
                        batch_size=args.eval_batch_size or args.batch_size,
                        device=eval_device,
                        seed=args.seed,
                        shuffle=False,
                        model_family=args.eval_model_family,
                        add_bos=args.eval_add_bos,
                        max_batches=args.lora_eval_max_batches,
                        cache_dir=args.eval_cache_dir,
                        num_workers=args.eval_num_workers,
                    )
                eval_history.append({"step": step + 1, "ppl": results})
                print(f"[lora] eval step={step+1}: {results}")
                if prev_mode:
                    model.train()

            if args.lora_log_steps and (
                step == 0 or (step + 1) % args.lora_log_steps == 0
            ):
                log_parts = [f"loss={total_loss.item():.6f}"]
                if kl_loss is not None:
                    log_parts.append(f"kl={kl_loss.item():.6f}")
                print(
                    f"[lora] epoch={epoch_idx+1} step={step+1} "
                    + " ".join(log_parts)
                )
            step += 1

    merge_lora_adapters(model)


def _masked_kl(
    logits_p: torch.Tensor,
    logits_q: torch.Tensor,
    attention_mask: torch.Tensor,
    temp: float,
    detach_p: bool = True,
) -> Optional[torch.Tensor]:
    """Token-masked KL(p || q) over next-token positions; None on empty mask."""
    shift_mask = attention_mask[:, 1:].contiguous()
    denom = shift_mask.sum()
    if denom.item() == 0:
        return None

    p = logits_p[:, :-1, :].contiguous()
    q = logits_q[:, :-1, :].contiguous()
    if p.device != q.device:
        p = p.to(q.device)

    # Temperature-scaled log-probabilities for both distributions.
    log_p = F.log_softmax(p / temp, dim=-1)
    log_q = F.log_softmax(q / temp, dim=-1)
    if detach_p:
        log_p = log_p.detach()
    p_probs = log_p.exp()
    kl_flat = (p_probs * (log_p - log_q)).sum(dim=-1)
    return (kl_flat * shift_mask.to(kl_flat.dtype)).sum() / denom
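
# The value computed is the per-token average of
#     KL(p || q) = sum_v p(v) * (log p(v) - log q(v)),
# with both distributions softened by ``temp``. Callers that use this as a
# distillation loss multiply by temp**2 to keep gradient magnitudes comparable
# across temperatures (see the anchor loss in commutator_precondition below).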


def _extract_hidden_tensor(output: object) -> Optional[torch.Tensor]:
    if isinstance(output, torch.Tensor):
        return output
    if isinstance(output, (tuple, list)) and output:
        first = output[0]
        if isinstance(first, torch.Tensor):
            return first
    return None


def _grad_l2_norm(grads: List[Optional[torch.Tensor]]) -> float:
    """Global L2 norm over a list of gradient tensors (None entries skipped)."""
    total = 0.0
    for grad in grads:
        if grad is None:
            continue
        total += float(grad.detach().float().pow(2).sum().item())
    if total <= 0.0:
        return 0.0
    return float(math.sqrt(total))
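
# Example:
#
#     _grad_l2_norm([torch.tensor([3.0]), None, torch.tensor([4.0])])  # 5.0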


def _register_forward_pre_hook_with_optional_kwargs(layer, hook):
    """Register a pre-hook that also receives kwargs, falling back gracefully.

    ``with_kwargs=True`` requires PyTorch >= 2.0; on older releases the hook
    is wrapped so it is still called, with ``kwargs`` reported as None.
    """
    try:
        handle = layer.register_forward_pre_hook(hook, with_kwargs=True)
        return handle
    except TypeError:
        def wrapper(module, inputs):
            return hook(module, inputs, None)

        return layer.register_forward_pre_hook(wrapper)


def commutator_precondition(
    student_model: torch.nn.Module,
    student_layers: List[torch.nn.Module],
    teacher_model: torch.nn.Module,
    dataloader,
    dwce_scores: Optional[List[float]],
    args: argparse.Namespace,
    exclude_pairs: Optional[Set[int]] = None,
    progressive_cycle: Optional[int] = None,
    progressive_total: Optional[int] = None,
) -> Dict[str, object]:
    """Run commutator-style preconditioning before pair fusion.

    Objective on each sampled pair i:
        L = T^2 * KL(p_teacher || p_student) + mu * L_interaction(i)

    The interaction loss is computed locally on block (i+1):
        r1 = B_{i+1}(h_{i+1}) - h_{i+1}
        r0 = B_{i+1}(h_i) - h_i
        L_interaction = ||r1 - r0||^2 (or the relative form).
    """
    if not bool(getattr(args, "comm_enabled", False)):
        return {"enabled": False}
    if not student_layers or len(student_layers) < 2:
        return {"enabled": False, "reason": "need_at_least_2_layers"}

    temp = float(getattr(args, "comm_temp", 2.0))
    steps_ratio = float(getattr(args, "comm_steps_ratio", 0.1))
    lr_scale = float(getattr(args, "comm_lr_scale", 0.1))
    sample_eta = float(getattr(args, "comm_sample_eta", 0.5))
    sample_dwce_scale = float(getattr(args, "comm_sample_dwce_scale", 1.0))
    top_k = int(getattr(args, "comm_topk", 1))
    interaction_mode = str(getattr(args, "comm_interaction_mode", "relative")).strip().lower()
    interaction_eps = float(getattr(args, "comm_interaction_eps", 1e-8))
    mu_cfg = getattr(args, "comm_mu", None)
    mu_auto = bool(getattr(args, "comm_mu_auto", False))
    mu_auto_rho = float(getattr(args, "comm_mu_auto_rho", 0.1))
    mu_auto_eps = float(getattr(args, "comm_mu_auto_eps", 1e-8))
    comm_train_mode = str(getattr(args, "comm_train_mode", "lora")).strip().lower()
    log_steps = int(getattr(args, "comm_log_steps", 50))

    if temp <= 0.0:
        raise SystemExit("--comm_temp must be > 0")
    if steps_ratio < 0.0:
        raise SystemExit("--comm_steps_ratio must be >= 0")
    if lr_scale <= 0.0:
        raise SystemExit("--comm_lr_scale must be > 0")
    if not (0.0 <= sample_eta <= 1.0):
        raise SystemExit("--comm_sample_eta must be in [0, 1]")
    if top_k <= 0:
        raise SystemExit("--comm_topk must be >= 1")
    if interaction_mode not in {"mse", "relative"}:
        raise SystemExit("--comm_interaction_mode must be one of: mse, relative")
    if comm_train_mode not in {"lora", "full"}:
        raise SystemExit("--comm_train_mode must be one of: lora, full")
    if interaction_eps <= 0.0:
        raise SystemExit("--comm_interaction_eps must be > 0")
    if mu_auto_rho < 0.0:
        raise SystemExit("--comm_mu_auto_rho must be >= 0")
    if mu_auto_eps <= 0.0:
        raise SystemExit("--comm_mu_auto_eps must be > 0")

    if mu_cfg is None:
        base_mu = 0.5 if interaction_mode == "relative" else 0.1
    else:
        base_mu = float(mu_cfg)
    if base_mu < 0.0:
        raise SystemExit("--comm_mu must be >= 0")

    distill_epochs = float(getattr(args, "distill_epochs", 1.0))
    if distill_epochs <= 0.0:
        distill_epochs = 1.0
    grad_accum = int(getattr(args, "distill_grad_accum_steps", 1))
    if grad_accum <= 0:
        grad_accum = 1

    try:
        batches_per_epoch = len(dataloader)
    except TypeError as exc:
        raise SystemExit(
            "Commutator preconditioning requires a finite-length distillation dataloader."
        ) from exc
    if batches_per_epoch <= 0:
        return {"enabled": False, "reason": "empty_dataloader"}

    full_epochs = int(distill_epochs)
    fractional = distill_epochs - full_epochs
    if fractional < 1e-8:
        fractional = 0.0
    total_batches = full_epochs * batches_per_epoch
    if fractional > 0.0:
        frac_batches = int(round(fractional * batches_per_epoch))
        if frac_batches <= 0:
            frac_batches = 1
        total_batches += frac_batches

    # Budget the preconditioning phase as a fraction of the distillation steps.
    distill_opt_steps = int(math.ceil(total_batches / float(grad_accum)))
    target_opt_steps = int(round(steps_ratio * distill_opt_steps))
    if target_opt_steps <= 0:
        target_opt_steps = 1

    num_pairs = max(len(student_layers) - 1, 0)
    exclude_set = {
        int(idx)
        for idx in (exclude_pairs or set())
        if isinstance(idx, int) and 0 <= int(idx) < num_pairs
    }
    allowed_pairs = [i for i in range(num_pairs) if i not in exclude_set]
    if not allowed_pairs:
        return {"enabled": False, "reason": "all_pairs_excluded"}

    # Rank pairs by DWCE score (lower is preferred) when scores are available.
    ranked_pairs = list(allowed_pairs)
    if dwce_scores is not None and len(dwce_scores) >= num_pairs:
        finite_pairs = []
        for idx in allowed_pairs:
            value = float(dwce_scores[idx])
            if math.isfinite(value):
                finite_pairs.append(idx)
        if finite_pairs:
            ranked_pairs = sorted(finite_pairs, key=lambda i: float(dwce_scores[i]))
        else:
            ranked_pairs = list(allowed_pairs)
    candidate_pairs = ranked_pairs[: min(top_k, len(ranked_pairs))]
    if not candidate_pairs:
        return {"enabled": False, "reason": "no_candidate_pairs"}

    layer_trainable_params: List[List[torch.nn.Parameter]] = []
    trainable_params: List[torch.nn.Parameter] = []
    if comm_train_mode == "lora":
        # LoRA mode: train only low-rank adapters attached to the student.
        lora_modules = apply_lora_adapters(student_model, args)
        if not lora_modules:
            return {"enabled": False, "reason": "no_lora_modules"}

        trainable_seen: Set[int] = set()
        for module in lora_modules:
            for param in module.lora_parameters():
                pid = id(param)
                if pid in trainable_seen:
                    continue
                trainable_seen.add(pid)
                trainable_params.append(param)

        for layer in student_layers:
            seen: Set[int] = set()
            params: List[torch.nn.Parameter] = []
            for module in layer.modules():
                if not isinstance(module, LoRALinear):
                    continue
                for param in module.lora_parameters():
                    pid = id(param)
                    if pid in seen:
                        continue
                    seen.add(pid)
                    params.append(param)
            layer_trainable_params.append(params)
    else:
        # Full mode: train the receiver layers' own parameters.
        for layer in student_layers:
            seen: Set[int] = set()
            params: List[torch.nn.Parameter] = []
            for param in layer.parameters():
                if not isinstance(param, torch.nn.Parameter):
                    continue
                pid = id(param)
                if pid in seen:
                    continue
                seen.add(pid)
                params.append(param)
            layer_trainable_params.append(params)

    candidate_pairs = [
        i
        for i in candidate_pairs
        if (i + 1) < len(layer_trainable_params) and layer_trainable_params[i + 1]
    ]
    if not candidate_pairs:
        if comm_train_mode == "lora":
            merge_lora_adapters(student_model)
        return {"enabled": False, "reason": "no_trainable_receiver_layers"}

    if comm_train_mode == "full":
        trainable_seen = set()
        for pair_idx in candidate_pairs:
            for param in layer_trainable_params[pair_idx + 1]:
                pid = id(param)
                if pid in trainable_seen:
                    continue
                trainable_seen.add(pid)
                trainable_params.append(param)
        if not trainable_params:
            return {"enabled": False, "reason": "no_trainable_receiver_layers"}

    # Freeze everything except the selected trainable parameters.
    for param in student_model.parameters():
        param.requires_grad_(False)
    for param in trainable_params:
        param.requires_grad_(True)

    if not trainable_params:
        if comm_train_mode == "lora":
            merge_lora_adapters(student_model)
        return {"enabled": False, "reason": "no_trainable_params"}

    candidate_probs = torch.full(
        (len(candidate_pairs),),
        1.0 / float(len(candidate_pairs)),
        dtype=torch.float32,
    )
    if dwce_scores is not None and len(dwce_scores) >= num_pairs and sample_eta > 0.0:
        # Blend the uniform distribution with a softmax that favors pairs with
        # lower DWCE scores.
        score_vec = torch.tensor(
            [float(dwce_scores[i]) for i in candidate_pairs], dtype=torch.float32
        )
        score_vec = torch.nan_to_num(score_vec, nan=1e9, posinf=1e9, neginf=-1e9)
        biased = torch.softmax(-float(sample_dwce_scale) * score_vec, dim=0)
        candidate_probs = (1.0 - sample_eta) * candidate_probs + sample_eta * biased
        candidate_probs = candidate_probs / candidate_probs.sum()

    probs_by_pair = [0.0 for _ in range(num_pairs)]
    for pos, pair_idx in enumerate(candidate_pairs):
        probs_by_pair[pair_idx] = float(candidate_probs[pos].item())

    lr = float(getattr(args, "distill_lr", 1e-4)) * lr_scale
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=lr,
        weight_decay=float(getattr(args, "distill_weight_decay", 0.0)),
    )

    device_type = torch.device(args.device).type
    amp_dtype = None
    if args.dtype == "float16":
        amp_dtype = torch.float16
    elif args.dtype == "bfloat16":
        amp_dtype = torch.bfloat16
    use_amp = amp_dtype is not None and device_type == "cuda"
    use_scaler = use_amp and amp_dtype == torch.float16
    scaler = torch.cuda.amp.GradScaler() if use_scaler else None

    teacher_device = next(teacher_model.parameters()).device
    teacher_model.eval()
    student_model.train()

    # Seed the pair-sampling generator; offset per progressive cycle so each
    # cycle draws a different sequence.
    gen = torch.Generator(device="cpu")
    seed = int(getattr(args, "seed", 0))
    if progressive_cycle is not None:
        seed += int(progressive_cycle) * 100003
    gen.manual_seed(seed)

    opt_step = 0
    total_loss_sum = 0.0
    anchor_sum = 0.0
    interaction_sum = 0.0
    mu_sum = 0.0
    counted = 0
    pair_counts = [0 for _ in range(num_pairs)]

    desc = "Comm"
    if progressive_cycle is not None:
        if progressive_total is not None:
            desc = f"Comm (cycle {progressive_cycle}/{progressive_total})"
        else:
            desc = f"Comm (cycle {progressive_cycle})"
    iterator = range(target_opt_steps)
    if tqdm is not None and _tqdm_enabled():
        iterator = tqdm(iterator, desc=desc, unit="step")

    data_iter = iter(dataloader)
    autocast_ctx = (
        torch.autocast(device_type=device_type, dtype=amp_dtype)
        if use_amp
        else nullcontext()
    )
|
|
| for _ in iterator: |
| optimizer.zero_grad(set_to_none=True) |
| accum_done = 0 |
| while accum_done < grad_accum: |
| try: |
| batch = next(data_iter) |
| except StopIteration: |
| data_iter = iter(dataloader) |
| batch = next(data_iter) |
|
|
| input_ids = batch[0].to(args.device) |
| attention_mask = batch[1].to(args.device) |
| sampled_pos = int(torch.multinomial(candidate_probs, 1, generator=gen).item()) |
| pair_idx = int(candidate_pairs[sampled_pos]) |
| pair_counts[pair_idx] += 1 |
|
|
| receiver_params = layer_trainable_params[pair_idx + 1] |
| receiver_param_ids = {id(param) for param in receiver_params} |
|
|
| teacher_ids = input_ids.to(teacher_device) |
| teacher_mask = attention_mask.to(teacher_device) |
| with torch.no_grad(), autocast_ctx: |
| teacher_outputs = teacher_model( |
| input_ids=teacher_ids, |
| attention_mask=teacher_mask, |
| use_cache=False, |
| ) |
| teacher_logits = teacher_outputs.logits |
|
|
| capture: Dict[str, object] = { |
| "h_l": None, |
| "h_lp1": None, |
| "y1": None, |
| "recv_args": None, |
| "recv_kwargs": None, |
| } |
|
|
| def _hook_l(_module, inputs, _output): |
| if inputs and isinstance(inputs[0], torch.Tensor): |
| capture["h_l"] = inputs[0] |
|
|
| def _hook_recv_pre(_module, inputs, kwargs): |
| capture["recv_args"] = inputs |
| capture["recv_kwargs"] = kwargs |
|
|
| def _hook_recv(_module, inputs, output): |
| if inputs and isinstance(inputs[0], torch.Tensor): |
| capture["h_lp1"] = inputs[0] |
| capture["y1"] = _extract_hidden_tensor(output) |
|
|
| handles: List[object] = [ |
| student_layers[pair_idx].register_forward_hook(_hook_l), |
| _register_forward_pre_hook_with_optional_kwargs( |
| student_layers[pair_idx + 1], _hook_recv_pre |
| ), |
| student_layers[pair_idx + 1].register_forward_hook(_hook_recv), |
| ] |
| try: |
| with autocast_ctx: |
| student_outputs = student_model( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| use_cache=False, |
| ) |
| student_logits = student_outputs.logits |
| finally: |
| for handle in handles: |
| try: |
| handle.remove() |
| except Exception: |
| pass |
|
|
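| # Anchor term: temperature-scaled KL between teacher and student logits over
| # valid tokens; the temp**2 factor keeps gradient magnitudes comparable
| # across temperatures. Batches with no valid tokens are skipped without
| # counting toward the accumulation budget.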
| with autocast_ctx: |
| anchor_kl = _masked_kl( |
| teacher_logits, |
| student_logits, |
| attention_mask, |
| temp=temp, |
| detach_p=True, |
| ) |
| if anchor_kl is None: |
| continue |
| anchor_loss = (temp ** 2) * anchor_kl |
|
|
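| # Interaction term: replay layer l+1 on the detached input of layer l and
| # compare residual contributions, r1 = y1 - h_{l+1} vs r0 = f_{l+1}(h_l) - h_l,
| # so the receiver stays well-behaved if layer l is later fused away.
| # "relative" normalizes the squared difference by ||r1||^2 per token;
| # otherwise a masked MSE over the hidden dimension is used.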
| interaction_loss = None |
| h_l = capture.get("h_l") |
| h_lp1 = capture.get("h_lp1") |
| y1 = capture.get("y1") |
| recv_args = capture.get("recv_args") |
| recv_kwargs = capture.get("recv_kwargs") |
| if ( |
| isinstance(h_l, torch.Tensor) |
| and isinstance(h_lp1, torch.Tensor) |
| and isinstance(y1, torch.Tensor) |
| and isinstance(recv_args, tuple) |
| and len(recv_args) > 0 |
| and isinstance(recv_args[0], torch.Tensor) |
| ): |
| call_args = list(recv_args) |
| first_hidden = call_args[0] |
| h_l_detached = h_l.detach().to( |
| device=first_hidden.device, |
| dtype=first_hidden.dtype, |
| ) |
| call_args[0] = h_l_detached |
| call_kwargs = dict(recv_kwargs) if isinstance(recv_kwargs, dict) else {} |
|
|
| y0_raw = student_layers[pair_idx + 1](*tuple(call_args), **call_kwargs) |
| y0 = _extract_hidden_tensor(y0_raw) |
| if isinstance(y0, torch.Tensor): |
| if y0.device != y1.device: |
| y0 = y0.to(y1.device) |
| h_lp1_detached = h_lp1.detach().to(device=y1.device, dtype=y1.dtype) |
| h_l_for_res = h_l.detach().to(device=y0.device, dtype=y0.dtype) |
| r1 = y1 - h_lp1_detached |
| r0 = y0 - h_l_for_res |
| mask = attention_mask.to(dtype=r1.dtype) |
| mask_sum = mask.sum() |
| if mask_sum.item() > 0: |
| if interaction_mode == "relative": |
| num = (r1 - r0).float().pow(2).sum(dim=-1) |
| den = r1.float().pow(2).sum(dim=-1) + float(interaction_eps) |
| ratio = (num / den) * mask.to(num.dtype) |
| interaction_loss = ratio.sum() / (mask_sum + 1e-8) |
| else: |
| denom = mask_sum * r1.size(-1) |
| if denom.item() > 0: |
| interaction_loss = ( |
| (r1 - r0).pow(2) * mask.unsqueeze(-1) |
| ).sum() / denom |
|
|
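| # Auto-balanced mu: measure the gradient norms of both loss terms on the
| # receiver parameters and set mu = rho * |g_anchor| / (|g_interaction| + eps),
| # falling back to the fixed base_mu when the estimate degenerates.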
| mu_effective = float(base_mu) |
| if ( |
| mu_auto |
| and interaction_loss is not None |
| and receiver_params |
| and mu_auto_rho > 0.0 |
| ): |
| anchor_grads = torch.autograd.grad( |
| anchor_loss, |
| receiver_params, |
| retain_graph=True, |
| allow_unused=True, |
| ) |
| interaction_grads = torch.autograd.grad( |
| interaction_loss, |
| receiver_params, |
| retain_graph=True, |
| allow_unused=True, |
| ) |
| anchor_norm = _grad_l2_norm(list(anchor_grads)) |
| interaction_norm = _grad_l2_norm(list(interaction_grads)) |
| if interaction_norm > 0.0: |
| mu_effective = float( |
| mu_auto_rho |
| * (anchor_norm / (interaction_norm + float(mu_auto_eps))) |
| ) |
| else: |
| mu_effective = float(base_mu) |
| if not math.isfinite(mu_effective): |
| mu_effective = float(base_mu) |
|
|
| total_loss = anchor_loss
| if interaction_loss is not None:
| total_loss = total_loss + (float(mu_effective) * interaction_loss)
|
|
| # Keep the unscaled value for logging so the reported loss stays comparable
| # with the anchor/interaction components when grad accumulation rescales it.
| unscaled_loss = float(total_loss.detach().float().item())
| if grad_accum > 1:
| total_loss = total_loss / float(grad_accum)
|
|
| if use_scaler: |
| scaler.scale(total_loss).backward() |
| else: |
| total_loss.backward() |
|
|
| |
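| # Restrict this micro-batch's update to the sampled receiver layer: zero the
| # other gradients in LoRA mode (keeping tensors for optimizer state) and
| # drop them entirely otherwise.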
| for param in trainable_params: |
| if id(param) in receiver_param_ids: |
| continue |
| if param.grad is not None: |
| if comm_train_mode == "lora": |
| param.grad.zero_() |
| else: |
| param.grad = None |
|
|
| total_loss_sum += unscaled_loss
| anchor_sum += float(anchor_loss.detach().float().item()) |
| if interaction_loss is not None: |
| interaction_sum += float(interaction_loss.detach().float().item()) |
| mu_sum += float(mu_effective) |
| counted += 1 |
| accum_done += 1 |
|
|
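| # clip_grad_norm_ must see unscaled gradients, so unscale first under fp16.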
| if args.distill_max_grad_norm is not None: |
| if use_scaler: |
| scaler.unscale_(optimizer) |
| torch.nn.utils.clip_grad_norm_( |
| trainable_params, |
| float(args.distill_max_grad_norm), |
| ) |
|
|
| if use_scaler: |
| scaler.step(optimizer) |
| scaler.update() |
| else: |
| optimizer.step() |
|
|
| opt_step += 1 |
| if log_steps and (opt_step == 1 or opt_step % log_steps == 0): |
| denom = max(counted, 1) |
| print( |
| f"[comm] step={opt_step}/{target_opt_steps} " |
| f"loss={total_loss_sum/denom:.6f} " |
| f"anchor={anchor_sum/denom:.6f} " |
| f"int={interaction_sum/denom:.6f} " |
| f"mu={mu_sum/denom:.6f}" |
| ) |
|
|
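| # Fold the trained LoRA adapters back into the base weights so downstream
| # stages see a plain, adapter-free model.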
| if comm_train_mode == "lora": |
| merge_lora_adapters(student_model) |
|
|
| stats: Dict[str, object] = { |
| "enabled": True, |
| "train_mode": comm_train_mode, |
| "opt_steps": int(target_opt_steps), |
| "grad_accum_steps": int(grad_accum), |
| "lr": float(lr), |
| "temp": float(temp), |
| "steps_ratio": float(steps_ratio), |
| "lr_scale": float(lr_scale), |
| "interaction_mode": interaction_mode, |
| "interaction_eps": float(interaction_eps), |
| "mu": float(base_mu), |
| "mu_auto": bool(mu_auto), |
| "mu_auto_rho": float(mu_auto_rho), |
| "mu_auto_eps": float(mu_auto_eps), |
| "sample_eta": float(sample_eta), |
| "sample_dwce_scale": float(sample_dwce_scale), |
| "topk": int(top_k), |
| "candidate_pairs": [int(i) for i in candidate_pairs], |
| "trainable_params": int(sum(int(param.numel()) for param in trainable_params)), |
| } |
| total_samples = int(sum(pair_counts)) |
| probs_list = [float(x) for x in probs_by_pair] |
| freqs = ( |
| [float(c) / float(total_samples) for c in pair_counts] |
| if total_samples > 0 |
| else [0.0 for _ in pair_counts] |
| ) |
| top_show = min(10, num_pairs) |
| top_indices = sorted(range(num_pairs), key=lambda i: pair_counts[i], reverse=True)[:top_show] |
| top_pairs = [ |
| { |
| "pair": int(i), |
| "count": int(pair_counts[i]), |
| "freq": float(freqs[i]), |
| "prob": float(probs_list[i]) if i < len(probs_list) else None, |
| } |
| for i in top_indices |
| if pair_counts[i] > 0 |
| ] |
| stats["pair_selection"] = { |
| "num_pairs": int(num_pairs), |
| "excluded_pairs": sorted(exclude_set), |
| "candidate_pairs": [int(i) for i in candidate_pairs], |
| "total_samples": total_samples, |
| "unique_pairs": int(sum(1 for c in pair_counts if c > 0)), |
| "counts": [int(c) for c in pair_counts], |
| "freqs": freqs, |
| "probs": probs_list, |
| "top_pairs": top_pairs, |
| } |
|
|
| if total_samples > 0 and top_pairs: |
| top_str = ", ".join( |
| f"{entry['pair']}-{entry['pair'] + 1}: {entry['count']} " |
| f"(obs={entry['freq']:.3f}, exp={entry['prob']:.3f})" |
| for entry in top_pairs |
| if entry.get("prob") is not None |
| ) |
| if not top_str: |
| top_str = ", ".join( |
| f"{entry['pair']}-{entry['pair'] + 1}: {entry['count']} " |
| f"(obs={entry['freq']:.3f})" |
| for entry in top_pairs |
| ) |
| print( |
| f"[comm] Pair sampling stats: total={total_samples} " |
| f"unique={stats['pair_selection']['unique_pairs']}/{num_pairs} " |
| f"top={top_str}" |
| ) |
|
|
| if counted > 0: |
| stats["avg_loss"] = float(total_loss_sum / float(counted)) |
| stats["avg_anchor"] = float(anchor_sum / float(counted)) |
| stats["avg_interaction"] = float(interaction_sum / float(counted)) |
| stats["avg_mu"] = float(mu_sum / float(counted)) |
| return stats |
|
|