|
|
"""Core export utilities for hard-pruning and kernel-aligned rounding. |
|
|
|
|
|
This module is *family-agnostic*. Adapters (e.g., ViT, ResNet, LLM) should: |
|
|
1) decide which gates map to which structural dims (heads, hidden groups, channels), |
|
|
2) obtain KEEP indices using helpers in this file, and |
|
|
3) rebuild family-specific modules with the sliced weights. |
|
|
|
|
|
Provided here: |
|
|
- Rounding policies and helpers (floors, multiples, warmup keep-all) |
|
|
- KEEP index selection from a `Gate` (or gate-like) object |
|
|
- Generic weight slicers for Linear / Conv2d / Embedding |
|
|
- Small safe-guards for dtype/device and shape checks |
|
|
|
|
|
The library avoids touching family internals here. Exporters in adapters should |
|
|
use these primitives to assemble a clean pruned model. |
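
A minimal adapter-side sketch (illustrative only; ``gate``, ``proj_in``,
``proj_out``, and ``step`` are hypothetical objects owned by the adapter,
not defined in this module)::

    policy = ExportPolicy(warmup_steps=0, rounding=Rounding(floor_groups=1, multiple_groups=8))
    keep = keep_element_indices_from_gate(gate, policy=policy, step=step)
    # Slice the producer's output features and the consumer's input features with
    # the same KEEP indices so the pruned shapes remain consistent.
    proj_in_pruned = slice_linear(proj_in, keep_out=keep.tolist())
    proj_out_pruned = slice_linear(proj_out, keep_in=keep.tolist())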
|
|
""" |
|
|
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional, Sequence, Tuple

import torch
import torch.nn as nn

from .gates import Gate, expand_group_indices

|
|
@dataclass
class Rounding:
    """Rounding policy for a single gated axis.

    Attributes
    ----------
    floor_groups : int
        Minimum number of groups to keep after rounding.
    multiple_groups : int
        Snap the number of kept groups down to a multiple of this (must be >= 1).
    min_keep_ratio : float
        Optional fractional lower bound on the expected keep; applied before
        snapping to `multiple_groups`.
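
    For example (an illustrative walk-through of the policy above), with
    ``floor_groups=2`` and ``multiple_groups=4``, an expected keep of 10.6 out
    of 16 groups rounds to 11, stays above the floor of 2, and is then snapped
    down to 8.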
|
|
""" |
|
|
|
|
|
floor_groups: int = 1 |
|
|
multiple_groups: int = 1 |
|
|
min_keep_ratio: float = 0.0 |
|
|
|
|
|
|
|
|
@dataclass
class ExportPolicy:
    """Export-time policy shared by families.

    - `warmup_steps`: if the current `step < warmup_steps`, keep all groups.
    - `rounding`: default rounding, used unless the adapter overrides it per axis.
    """

    warmup_steps: int = 0
    # A dataclass instance is a mutable default, so it must go through default_factory.
    rounding: Rounding = field(default_factory=Rounding)

|
|
def _round_down_mult(n: int, m: int) -> int:
    """Round `n` down to a multiple of `m`, never below `m` (or below 1 if `m <= 1`)."""
    if m is None or m <= 1:
        return max(1, int(n))
    n = int(n)
    return max(m, (n // m) * m)

|
|
def _compute_keep_k(
    expected_kept: float,
    total_groups: int,
    *,
    rounding: Rounding,
) -> int:
    """Turn an expected (fractional) keep count into a concrete `k`, honoring `rounding`."""
    k = int(round(expected_kept))
    # Apply the fractional and absolute floors, then clamp to the available groups.
    k = max(k, int(rounding.min_keep_ratio * total_groups))
    k = max(k, int(rounding.floor_groups))
    k = min(k, total_groups)
    # Snap down to the requested multiple, keeping at least one group.
    k = _round_down_mult(k, int(rounding.multiple_groups))
    return max(1, min(k, total_groups))

|
|
@torch.no_grad()
def keep_group_indices_from_gate(
    gate: Gate,
    *,
    policy: ExportPolicy,
    step: Optional[int] = None,
    custom_rounding: Optional[Rounding] = None,
) -> torch.Tensor:
    """Return sorted indices of groups to KEEP based on `gate` and policy.

    If `step < warmup_steps`, returns all indices (keep-all). Otherwise, the
    number of groups to keep is computed from the *expected keep* under the
    current logits and snapped according to the rounding policy.
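
    Illustrative usage (``gate`` is a hypothetical trained ``Gate`` with 12 groups)::

        policy = ExportPolicy(warmup_steps=100, rounding=Rounding(multiple_groups=2))
        keep_group_indices_from_gate(gate, policy=policy, step=50)   # warmup: all 12 groups kept
        keep_group_indices_from_gate(gate, policy=policy, step=500)  # top-k groups, k snapped to a multiple of 2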
|
|
""" |
|
|
G = int(gate.num_groups) |
|
|
if step is not None and step < int(policy.warmup_steps): |
|
|
return torch.arange(G, device=gate.logits.device) |
|
|
|
|
|
rounding = custom_rounding or policy.rounding |
|
|
p = torch.sigmoid(gate.logits.detach().float() / float(gate.tau)) |
|
|
k = _compute_keep_k(expected_kept=float(p.sum()), total_groups=G, rounding=rounding) |
|
|
idx = torch.topk(p, k, largest=True).indices.sort().values |
|
|
return idx.to(torch.long) |
|
|
|
|
|
|
|
|
@torch.no_grad()
def keep_element_indices_from_gate(
    gate: Gate,
    *,
    policy: ExportPolicy,
    step: Optional[int] = None,
    custom_rounding: Optional[Rounding] = None,
) -> torch.Tensor:
    """Expand kept *group* indices into element indices using `group_size`."""
    grp_idx = keep_group_indices_from_gate(gate, policy=policy, step=step, custom_rounding=custom_rounding)
    return expand_group_indices(grp_idx, gate.group_size)

|
|
@torch.no_grad()
def slice_linear(mat: nn.Linear, keep_in: Optional[Sequence[int]] = None, keep_out: Optional[Sequence[int]] = None) -> nn.Linear:
    """Create a new Linear with selected input/output features preserved.

    - `keep_out` selects rows (output features)
    - `keep_in` selects columns (input features)
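
    Illustrative example (shapes only)::

        lin = nn.Linear(8, 4)
        sliced = slice_linear(lin, keep_in=[0, 2, 5], keep_out=[1, 3])
        assert sliced.weight.shape == (2, 3)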
|
|
""" |
|
|
W = mat.weight.detach() |
|
|
b = mat.bias.detach() if mat.bias is not None else None |
|
|
|
|
|
if keep_out is not None: |
|
|
W = W.index_select(0, torch.as_tensor(keep_out, device=W.device)) |
|
|
if b is not None: |
|
|
b = b.index_select(0, torch.as_tensor(keep_out, device=b.device)) |
|
|
if keep_in is not None: |
|
|
W = W.index_select(1, torch.as_tensor(keep_in, device=W.device)) |
|
|
|
|
|
out_f, in_f = W.shape |
|
|
new = nn.Linear(in_f, out_f, bias=(b is not None)).to(W.device) |
|
|
new.weight.copy_(W) |
|
|
if b is not None: |
|
|
new.bias.copy_(b) |
|
|
return new |
|
|
|
|
|
|
|
|
@torch.no_grad()
def slice_conv2d(conv: nn.Conv2d, keep_in: Optional[Sequence[int]] = None, keep_out: Optional[Sequence[int]] = None) -> nn.Conv2d:
    """Create a new Conv2d with selected in/out channels preserved.

    Only supports standard conv2d (no groups/depthwise changes). For grouped
    convs, the adapter should handle group alignment before calling this.
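
    Illustrative example (shapes only)::

        conv = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        sliced = slice_conv2d(conv, keep_in=list(range(8)), keep_out=list(range(24)))
        assert sliced.weight.shape == (24, 8, 3, 3)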
|
|
""" |
|
|
W = conv.weight.detach() |
|
|
b = conv.bias.detach() if conv.bias is not None else None |
|
|
|
|
|
if keep_out is not None: |
|
|
W = W.index_select(0, torch.as_tensor(keep_out, device=W.device)) |
|
|
if b is not None: |
|
|
b = b.index_select(0, torch.as_tensor(keep_out, device=b.device)) |
|
|
if keep_in is not None: |
|
|
W = W.index_select(1, torch.as_tensor(keep_in, device=W.device)) |
|
|
|
|
|
out_c, in_c = W.shape[:2] |
|
|
new = nn.Conv2d( |
|
|
in_c, |
|
|
out_c, |
|
|
kernel_size=conv.kernel_size, |
|
|
stride=conv.stride, |
|
|
padding=conv.padding, |
|
|
dilation=conv.dilation, |
|
|
groups=1, |
|
|
bias=(b is not None), |
|
|
padding_mode=conv.padding_mode, |
|
|
).to(W.device) |
|
|
new.weight.copy_(W) |
|
|
if b is not None: |
|
|
new.bias.copy_(b) |
|
|
return new |
|
|
|
|
|
|
|
|
@torch.no_grad()
def slice_embedding(emb: nn.Embedding, keep_rows: Optional[Sequence[int]] = None, keep_dim: Optional[Sequence[int]] = None) -> nn.Embedding:
    """Create a new Embedding with selected rows (vocab) and/or dims kept."""
    W = emb.weight.detach()
    if keep_rows is not None:
        W = W.index_select(0, torch.as_tensor(keep_rows, device=W.device))
    if keep_dim is not None:
        W = W.index_select(1, torch.as_tensor(keep_dim, device=W.device))

    # Remap padding_idx to its new position; drop it if the padding row was pruned.
    padding_idx = emb.padding_idx
    if keep_rows is not None and padding_idx is not None:
        kept = [int(r) for r in keep_rows]
        padding_idx = kept.index(int(padding_idx)) if int(padding_idx) in kept else None

    num, dim = W.shape
    new = nn.Embedding(num, dim, padding_idx=padding_idx, max_norm=emb.max_norm, norm_type=emb.norm_type, scale_grad_by_freq=emb.scale_grad_by_freq, sparse=emb.sparse, device=W.device, dtype=W.dtype)
    new.weight.copy_(W)
    return new

|
|
@torch.no_grad()
def concat_index_ranges(ranges: Sequence[Tuple[int, int]]) -> torch.Tensor:
    """Given [(start, end_exclusive), ...], return concatenated 1D indices."""
    parts = [torch.arange(a, b, dtype=torch.long) for a, b in ranges if b > a]
    return torch.cat(parts, dim=0) if parts else torch.empty(0, dtype=torch.long)

|
|
@torch.no_grad()
def block_indices_from_groups(groups: Sequence[int], group_size: int) -> torch.Tensor:
    """Convert sorted group ids to expanded feature indices."""
    groups = torch.as_tensor(groups, dtype=torch.long)
    return expand_group_indices(groups, int(group_size))
|
|
|