"""Core export utilities for hard-pruning and kernel-aligned rounding.
This module is *family-agnostic*. Adapters (e.g., ViT, ResNet, LLM) should:
1) decide which gates map to which structural dims (heads, hidden groups, channels),
2) obtain KEEP indices using helpers in this file, and
3) rebuild family-specific modules with the sliced weights.
Provided here:
- Rounding policies and helpers (floors, multiples, warmup keep-all)
- KEEP index selection from a `Gate` (or gate-like) object
- Generic weight slicers for Linear / Conv2d / Embedding
- Small safe-guards for dtype/device and shape checks
The library avoids touching family internals here. Exporters in adapters should
use these primitives to assemble a clean pruned model.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional, Sequence, Tuple
import torch
import torch.nn as nn
from .gates import Gate, expand_group_indices
# -----------------------------------------------------------------------------
# Policies & rounding
# -----------------------------------------------------------------------------
@dataclass
class Rounding:
"""Rounding policy for a single gated axis.
Attributes
----------
floor_groups : int
Minimum number of groups to keep after rounding.
multiple_groups : int
Snap the number of groups kept down to a multiple of this (>=1).
min_keep_ratio : float
Optional fractional lower bound on expected keep; applied before rounding.
"""
floor_groups: int = 1
multiple_groups: int = 1
min_keep_ratio: float = 0.0
@dataclass
class ExportPolicy:
"""Export-time policy shared by families.
- `warmup_steps`: if current `step < warmup_steps`, keep-all.
- `rounding`: default rounding used unless adapter overrides per-axis.
"""
warmup_steps: int = 0
rounding: Rounding = Rounding()
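
# Illustrative sketch (hypothetical values, not part of the library API): a
# policy that keeps everything for the first 500 steps, then keeps at least 2
# groups and at least 10% of all groups, snapped down to a multiple of 8 for
# kernel alignment.
def _example_policy() -> ExportPolicy:
    return ExportPolicy(
        warmup_steps=500,
        rounding=Rounding(floor_groups=2, multiple_groups=8, min_keep_ratio=0.10),
    )
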
def _round_down_mult(n: int, m: int) -> int:
    """Round `n` down to a multiple of `m`; never returns less than `m` (or 1 if `m` <= 1)."""
    if m is None or m <= 1:
        return max(1, int(n))
    n = int(n)
    return max(m, (n // m) * m)
def _compute_keep_k(
expected_kept: float,
total_groups: int,
*,
rounding: Rounding,
) -> int:
# Start from nearest-integer expectation
k = int(round(expected_kept))
# Apply ratio floor, then absolute floor, then multiple snapping
k = max(k, int(rounding.min_keep_ratio * total_groups))
k = max(k, int(rounding.floor_groups))
k = min(k, total_groups)
k = _round_down_mult(k, int(rounding.multiple_groups))
return max(1, min(k, total_groups))
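
# Worked example: expected_kept=10.4 over total_groups=64 with
# Rounding(floor_groups=2, multiple_groups=8, min_keep_ratio=0.0):
# round(10.4) = 10 -> both floors leave 10 -> snap down to a multiple of 8 -> 8.
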
# -----------------------------------------------------------------------------
# KEEP index selection from a gate
# -----------------------------------------------------------------------------
@torch.no_grad()
def keep_group_indices_from_gate(
gate: Gate,
*,
policy: ExportPolicy,
step: Optional[int] = None,
custom_rounding: Optional[Rounding] = None,
) -> torch.Tensor:
"""Return sorted indices of groups to KEEP based on `gate` and policy.
If `step < warmup_steps`, returns all indices (keep-all). Otherwise, the
number of groups to keep is computed from the *expected keep* under the
current logits and snapped according to the rounding policy.
"""
G = int(gate.num_groups)
if step is not None and step < int(policy.warmup_steps):
return torch.arange(G, device=gate.logits.device)
rounding = custom_rounding or policy.rounding
p = torch.sigmoid(gate.logits.detach().float() / float(gate.tau))
k = _compute_keep_k(expected_kept=float(p.sum()), total_groups=G, rounding=rounding)
idx = torch.topk(p, k, largest=True).indices.sort().values
return idx.to(torch.long)
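
# Illustrative sketch: `keep_group_indices_from_gate` only reads `logits`, `tau`,
# and `num_groups`, so a minimal gate-like stand-in (duck-typed for the real
# `Gate` from `.gates`) is enough to demonstrate it. All values are hypothetical.
def _example_keep_group_indices() -> torch.Tensor:
    class _GateLike:
        logits = torch.tensor([2.0, -1.0, 0.5, 3.0])
        tau = 1.0
        num_groups = 4
        group_size = 16

    policy = ExportPolicy(rounding=Rounding(floor_groups=1, multiple_groups=2))
    # sigmoid(logits) sums to ~2.72 -> round to 3 -> snap down to a multiple of
    # 2 -> keep 2 groups; the largest logits sit at indices 0 and 3.
    return keep_group_indices_from_gate(_GateLike(), policy=policy, step=10)  # tensor([0, 3])
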
@torch.no_grad()
def keep_element_indices_from_gate(
gate: Gate,
*,
policy: ExportPolicy,
step: Optional[int] = None,
custom_rounding: Optional[Rounding] = None,
) -> torch.Tensor:
"""Expand kept *group* indices into element indices using `group_size`."""
grp_idx = keep_group_indices_from_gate(gate, policy=policy, step=step, custom_rounding=custom_rounding)
return expand_group_indices(grp_idx, gate.group_size)
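
# For example (assuming `expand_group_indices` maps group g to the contiguous
# block [g * group_size, (g + 1) * group_size)), kept groups [0, 3] with
# group_size=16 expand to elements 0..15 and 48..63.
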
# -----------------------------------------------------------------------------
# Generic slicers
# -----------------------------------------------------------------------------
@torch.no_grad()
def slice_linear(
    mat: nn.Linear,
    keep_in: Optional[Sequence[int]] = None,
    keep_out: Optional[Sequence[int]] = None,
) -> nn.Linear:
    """Create a new Linear with selected input/output features preserved.

    - `keep_out` selects rows (output features)
    - `keep_in` selects columns (input features)
    """
W = mat.weight.detach()
b = mat.bias.detach() if mat.bias is not None else None
if keep_out is not None:
W = W.index_select(0, torch.as_tensor(keep_out, device=W.device))
if b is not None:
b = b.index_select(0, torch.as_tensor(keep_out, device=b.device))
if keep_in is not None:
W = W.index_select(1, torch.as_tensor(keep_in, device=W.device))
out_f, in_f = W.shape
    new = nn.Linear(in_f, out_f, bias=(b is not None)).to(device=W.device, dtype=W.dtype)
new.weight.copy_(W)
if b is not None:
new.bias.copy_(b)
return new
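
# Illustrative sketch (hypothetical shapes): keep half the inputs and four of
# the outputs of a 16->8 Linear, yielding an 8->4 Linear with matching weights.
def _example_slice_linear() -> nn.Linear:
    lin = nn.Linear(16, 8)
    kept = slice_linear(lin, keep_in=list(range(0, 16, 2)), keep_out=[0, 2, 5, 7])
    assert kept.weight.shape == (4, 8)  # (out_features, in_features)
    return kept
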
@torch.no_grad()
def slice_conv2d(
    conv: nn.Conv2d,
    keep_in: Optional[Sequence[int]] = None,
    keep_out: Optional[Sequence[int]] = None,
) -> nn.Conv2d:
    """Create a new Conv2d with selected in/out channels preserved.

    Only supports standard conv2d (no groups/depthwise changes). For grouped
    convs, the adapter should handle group alignment before calling this.
    """
W = conv.weight.detach()
b = conv.bias.detach() if conv.bias is not None else None
if keep_out is not None:
W = W.index_select(0, torch.as_tensor(keep_out, device=W.device))
if b is not None:
b = b.index_select(0, torch.as_tensor(keep_out, device=b.device))
if keep_in is not None:
W = W.index_select(1, torch.as_tensor(keep_in, device=W.device))
out_c, in_c = W.shape[:2]
new = nn.Conv2d(
in_c,
out_c,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
dilation=conv.dilation,
groups=1,
bias=(b is not None),
padding_mode=conv.padding_mode,
    ).to(device=W.device, dtype=W.dtype)
new.weight.copy_(W)
if b is not None:
new.bias.copy_(b)
return new
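
# Illustrative sketch (hypothetical shapes): prune a 32->64 conv to 16->32
# channels; kernel size, stride, padding, and dilation carry over unchanged.
def _example_slice_conv2d() -> nn.Conv2d:
    conv = nn.Conv2d(32, 64, kernel_size=3, padding=1)
    kept = slice_conv2d(conv, keep_in=list(range(16)), keep_out=list(range(0, 64, 2)))
    assert kept.weight.shape == (32, 16, 3, 3)  # (out_c, in_c, kH, kW)
    return kept
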
@torch.no_grad()
def slice_embedding(
    emb: nn.Embedding,
    keep_rows: Optional[Sequence[int]] = None,
    keep_dim: Optional[Sequence[int]] = None,
) -> nn.Embedding:
    """Create a new Embedding with selected rows (vocab) and/or dims kept.

    Note: `padding_idx` is copied verbatim, so if `keep_rows` removes or
    reorders the padding row, the adapter must remap `padding_idx` itself.
    """
W = emb.weight.detach()
if keep_rows is not None:
W = W.index_select(0, torch.as_tensor(keep_rows, device=W.device))
if keep_dim is not None:
W = W.index_select(1, torch.as_tensor(keep_dim, device=W.device))
num, dim = W.shape
    new = nn.Embedding(
        num,
        dim,
        padding_idx=emb.padding_idx,
        max_norm=emb.max_norm,
        norm_type=emb.norm_type,
        scale_grad_by_freq=emb.scale_grad_by_freq,
        sparse=emb.sparse,
        device=W.device,
        dtype=W.dtype,
    )
new.weight.copy_(W)
return new
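
# Illustrative sketch (hypothetical shapes): halve the embedding dimension while
# keeping the whole vocabulary; `padding_idx` stays valid since no rows move.
def _example_slice_embedding() -> nn.Embedding:
    emb = nn.Embedding(1000, 256, padding_idx=0)
    kept = slice_embedding(emb, keep_dim=list(range(0, 256, 2)))
    assert kept.weight.shape == (1000, 128)
    return kept
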
# -----------------------------------------------------------------------------
# Small helpers for adapters
# -----------------------------------------------------------------------------
@torch.no_grad()
def concat_index_ranges(ranges: Sequence[Tuple[int, int]]) -> torch.Tensor:
"""Given [(start, end_exclusive), ...], return concatenated 1D indices."""
parts = [torch.arange(a, b, dtype=torch.long) for a, b in ranges if b > a]
return torch.cat(parts, dim=0) if parts else torch.empty(0, dtype=torch.long)
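
# Worked example: concat_index_ranges([(0, 3), (8, 10)]) -> tensor([0, 1, 2, 8, 9]);
# empty or inverted ranges are skipped.
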
@torch.no_grad()
def block_indices_from_groups(groups: Sequence[int], group_size: int) -> torch.Tensor:
"""Convert sorted group ids to expanded feature indices."""
groups = torch.as_tensor(groups, dtype=torch.long)
return expand_group_indices(groups, int(group_size))
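
# Worked example (same contiguous-block assumption as above): groups [0, 3] with
# group_size=4 -> tensor([0, 1, 2, 3, 12, 13, 14, 15]).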