"""Core export utilities for hard-pruning and kernel-aligned rounding.

This module is *family-agnostic*. Adapters (e.g., ViT, ResNet, LLM) should:

1) decide which gates map to which structural dims (heads, hidden groups, channels),
2) obtain KEEP indices using helpers in this file, and
3) rebuild family-specific modules with the sliced weights.

Provided here:

- Rounding policies and helpers (floors, multiples, warmup keep-all)
- KEEP index selection from a `Gate` (or gate-like) object
- Generic weight slicers for Linear / Conv2d / Embedding
- Small safeguards for dtype/device and shape checks

This module avoids touching family internals; exporters in adapters should
use these primitives to assemble a clean pruned model.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional, Sequence, Tuple

import torch
import torch.nn as nn

from .gates import Gate, expand_group_indices

# -----------------------------------------------------------------------------
# Policies & rounding
# -----------------------------------------------------------------------------

@dataclass
class Rounding:
    """Rounding policy for a single gated axis.

    Attributes
    ----------
    floor_groups : int
        Minimum number of groups to keep after rounding.
    multiple_groups : int
        Snap the number of kept groups down to a multiple of this (>= 1).
    min_keep_ratio : float
        Optional fractional lower bound on the expected keep; applied before rounding.
    """

    floor_groups: int = 1
    multiple_groups: int = 1
    min_keep_ratio: float = 0.0


@dataclass
class ExportPolicy:
    """Export-time policy shared by families.

    - `warmup_steps`: if the current `step < warmup_steps`, keep all groups.
    - `rounding`: default rounding, used unless the adapter overrides it per axis.
    """

    warmup_steps: int = 0
    # A plain `Rounding()` default is evaluated once at class-definition time
    # and, because dataclass instances are unhashable by default, is rejected
    # by `dataclasses` on Python 3.11+; a default_factory builds a fresh
    # instance per policy.
    rounding: Rounding = field(default_factory=Rounding)
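
# A minimal construction sketch (the values below are illustrative, not
# defaults of this library):
#
#     >>> policy = ExportPolicy(warmup_steps=500,
#     ...                       rounding=Rounding(floor_groups=2, multiple_groups=8))
#     >>> policy.rounding.multiple_groups
#     8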


def _round_down_mult(n: int, m: Optional[int]) -> int:
    if m is None or m <= 1:
        return max(1, int(n))
    n = int(n)
    # Snap down to a multiple of m, but keep at least one full multiple.
    return max(m, (n // m) * m)


def _compute_keep_k(
    expected_kept: float,
    total_groups: int,
    *,
    rounding: Rounding,
) -> int:
    # Start from the nearest-integer expectation.
    k = int(round(expected_kept))
    # Apply the ratio floor, then the absolute floor, then multiple snapping.
    k = max(k, int(rounding.min_keep_ratio * total_groups))
    k = max(k, int(rounding.floor_groups))
    k = min(k, total_groups)
    k = _round_down_mult(k, int(rounding.multiple_groups))
    return max(1, min(k, total_groups))
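
# Worked example of the rounding pipeline (illustrative values): an expected
# keep of 9.4 out of 16 groups rounds to 9, clears both floors, and is then
# snapped down to the nearest multiple of 4; a near-zero expectation is caught
# by the absolute floor.
#
#     >>> _compute_keep_k(9.4, 16, rounding=Rounding(multiple_groups=4))
#     8
#     >>> _compute_keep_k(0.2, 16, rounding=Rounding(floor_groups=2))
#     2
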
# -----------------------------------------------------------------------------
# KEEP index selection from a gate
# -----------------------------------------------------------------------------

@torch.no_grad()
def keep_group_indices_from_gate(
    gate: Gate,
    *,
    policy: ExportPolicy,
    step: Optional[int] = None,
    custom_rounding: Optional[Rounding] = None,
) -> torch.Tensor:
    """Return sorted indices of groups to KEEP based on `gate` and the policy.

    If `step < warmup_steps`, all indices are returned (keep-all). Otherwise,
    the number of groups to keep is computed from the *expected keep* under
    the current logits and snapped according to the rounding policy.
    """
    G = int(gate.num_groups)
    if step is not None and step < int(policy.warmup_steps):
        return torch.arange(G, device=gate.logits.device)
    rounding = custom_rounding or policy.rounding
    # Keep probabilities under the current temperature.
    p = torch.sigmoid(gate.logits.detach().float() / float(gate.tau))
    k = _compute_keep_k(expected_kept=float(p.sum()), total_groups=G, rounding=rounding)
    # Keep the k most-likely groups, reported in ascending index order.
    idx = torch.topk(p, k, largest=True).indices.sort().values
    return idx.to(torch.long)


@torch.no_grad()
def keep_element_indices_from_gate(
    gate: Gate,
    *,
    policy: ExportPolicy,
    step: Optional[int] = None,
    custom_rounding: Optional[Rounding] = None,
) -> torch.Tensor:
    """Expand kept *group* indices into element indices using `group_size`."""
    grp_idx = keep_group_indices_from_gate(
        gate, policy=policy, step=step, custom_rounding=custom_rounding
    )
    return expand_group_indices(grp_idx, gate.group_size)
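
# Usage sketch with a gate-like stand-in. The real `Gate` lives in `.gates`;
# the only attributes these helpers touch are `logits`, `tau`, `num_groups`,
# and `group_size`. The element expansion assumes `expand_group_indices` maps
# group g to the contiguous block [g * group_size, (g + 1) * group_size).
#
#     >>> from types import SimpleNamespace
#     >>> g = SimpleNamespace(logits=torch.tensor([4.0, -4.0, 3.0, -3.0]),
#     ...                     tau=1.0, num_groups=4, group_size=2)
#     >>> keep_group_indices_from_gate(g, policy=ExportPolicy())
#     tensor([0, 2])
#     >>> keep_element_indices_from_gate(g, policy=ExportPolicy())
#     tensor([0, 1, 4, 5])
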
# -----------------------------------------------------------------------------
# Generic slicers
# -----------------------------------------------------------------------------

@torch.no_grad()
def slice_linear(
    mat: nn.Linear,
    keep_in: Optional[Sequence[int]] = None,
    keep_out: Optional[Sequence[int]] = None,
) -> nn.Linear:
    """Create a new Linear with the selected input/output features preserved.

    - `keep_out` selects rows (output features)
    - `keep_in` selects columns (input features)
    """
    W = mat.weight.detach()
    b = mat.bias.detach() if mat.bias is not None else None
    if keep_out is not None:
        W = W.index_select(0, torch.as_tensor(keep_out, device=W.device))
        if b is not None:
            b = b.index_select(0, torch.as_tensor(keep_out, device=b.device))
    if keep_in is not None:
        W = W.index_select(1, torch.as_tensor(keep_in, device=W.device))
    out_f, in_f = W.shape
    # Match the source dtype as well as the device so mixed-precision
    # checkpoints survive the rebuild.
    new = nn.Linear(in_f, out_f, bias=(b is not None)).to(device=W.device, dtype=W.dtype)
    new.weight.copy_(W)
    if b is not None:
        new.bias.copy_(b)
    return new
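
# A quick self-check sketch: with the dropped inputs zeroed, the sliced layer
# reproduces the kept outputs of the original (indices here are illustrative).
#
#     >>> lin = nn.Linear(8, 6)
#     >>> small = slice_linear(lin, keep_in=[0, 2, 4, 6], keep_out=[1, 3])
#     >>> x = torch.randn(2, 8)
#     >>> x[:, [1, 3, 5, 7]] = 0.0  # zero the dropped input features
#     >>> torch.allclose(lin(x)[:, [1, 3]], small(x[:, [0, 2, 4, 6]]), atol=1e-6)
#     True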


@torch.no_grad()
def slice_conv2d(
    conv: nn.Conv2d,
    keep_in: Optional[Sequence[int]] = None,
    keep_out: Optional[Sequence[int]] = None,
) -> nn.Conv2d:
    """Create a new Conv2d with the selected in/out channels preserved.

    Only standard conv2d is supported (no groups/depthwise changes). For
    grouped convs, the adapter should handle group alignment before calling
    this.
    """
    W = conv.weight.detach()
    b = conv.bias.detach() if conv.bias is not None else None
    if keep_out is not None:
        W = W.index_select(0, torch.as_tensor(keep_out, device=W.device))
        if b is not None:
            b = b.index_select(0, torch.as_tensor(keep_out, device=b.device))
    if keep_in is not None:
        W = W.index_select(1, torch.as_tensor(keep_in, device=W.device))
    out_c, in_c = W.shape[:2]
    new = nn.Conv2d(
        in_c,
        out_c,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=1,
        bias=(b is not None),
        padding_mode=conv.padding_mode,
    ).to(device=W.device, dtype=W.dtype)
    new.weight.copy_(W)
    if b is not None:
        new.bias.copy_(b)
    return new
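
# Self-check sketch, mirroring the Linear case: zero the dropped input
# channels and compare the kept output channels.
#
#     >>> conv = nn.Conv2d(4, 6, kernel_size=3, padding=1)
#     >>> small = slice_conv2d(conv, keep_in=[0, 2], keep_out=[1, 4, 5])
#     >>> x = torch.randn(1, 4, 8, 8)
#     >>> x[:, [1, 3]] = 0.0  # zero the dropped input channels
#     >>> torch.allclose(conv(x)[:, [1, 4, 5]], small(x[:, [0, 2]]), atol=1e-5)
#     True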


@torch.no_grad()
def slice_embedding(
    emb: nn.Embedding,
    keep_rows: Optional[Sequence[int]] = None,
    keep_dim: Optional[Sequence[int]] = None,
) -> nn.Embedding:
    """Create a new Embedding with the selected rows (vocab) and/or dims kept.

    Note: `padding_idx` is copied verbatim, so callers using `keep_rows` must
    ensure the padding row is kept at an unchanged index (or remap token ids
    themselves).
    """
    W = emb.weight.detach()
    if keep_rows is not None:
        W = W.index_select(0, torch.as_tensor(keep_rows, device=W.device))
    if keep_dim is not None:
        W = W.index_select(1, torch.as_tensor(keep_dim, device=W.device))
    num, dim = W.shape
    new = nn.Embedding(
        num,
        dim,
        padding_idx=emb.padding_idx,
        max_norm=emb.max_norm,
        norm_type=emb.norm_type,
        scale_grad_by_freq=emb.scale_grad_by_freq,
        sparse=emb.sparse,
        device=W.device,
        dtype=W.dtype,
    )
    new.weight.copy_(W)
    return new
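
# Dim-only slicing sketch (row slicing additionally requires remapping token
# ids on the caller side, per the note above):
#
#     >>> emb = nn.Embedding(10, 8)
#     >>> small = slice_embedding(emb, keep_dim=[0, 1, 2, 3])
#     >>> small.weight.shape
#     torch.Size([10, 4])
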
# -----------------------------------------------------------------------------
# Small helpers for adapters
# -----------------------------------------------------------------------------

@torch.no_grad()
def concat_index_ranges(ranges: Sequence[Tuple[int, int]]) -> torch.Tensor:
    """Given [(start, end_exclusive), ...], return the concatenated 1D indices."""
    parts = [torch.arange(a, b, dtype=torch.long) for a, b in ranges if b > a]
    return torch.cat(parts, dim=0) if parts else torch.empty(0, dtype=torch.long)


@torch.no_grad()
def block_indices_from_groups(groups: Sequence[int], group_size: int) -> torch.Tensor:
    """Convert sorted group ids into expanded feature indices."""
    groups = torch.as_tensor(groups, dtype=torch.long)
    return expand_group_indices(groups, int(group_size))
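
# Example outputs (the `block_indices_from_groups` result assumes the
# contiguous-block expansion described earlier):
#
#     >>> concat_index_ranges([(0, 3), (8, 10)])
#     tensor([0, 1, 2, 8, 9])
#     >>> block_indices_from_groups([0, 2], group_size=2)
#     tensor([0, 1, 4, 5])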