File size: 7,990 Bytes
70b8d48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
"""Core export utilities for hard-pruning and kernel-aligned rounding.

This module is *family-agnostic*. Adapters (e.g., ViT, ResNet, LLM) should:
  1) decide which gates map to which structural dims (heads, hidden groups, channels),
  2) obtain KEEP indices using helpers in this file, and
  3) rebuild family-specific modules with the sliced weights.

Provided here:
  - Rounding policies and helpers (floors, multiples, warmup keep-all)
  - KEEP index selection from a `Gate` (or gate-like) object
  - Generic weight slicers for Linear / Conv2d / Embedding
  - Small safe-guards for dtype/device and shape checks

The library avoids touching family internals here. Exporters in adapters should
use these primitives to assemble a clean pruned model.
"""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Iterable, Optional, Sequence, Tuple

import torch
import torch.nn as nn

from .gates import Gate, expand_group_indices

# -----------------------------------------------------------------------------
# Policies & rounding
# -----------------------------------------------------------------------------

@dataclass
class Rounding:
    """Per-axis rounding policy used when converting an expected keep count
    into an integer number of groups.

    Attributes
    ----------
    floor_groups : int
        Absolute minimum number of groups to keep after rounding.
    multiple_groups : int
        The kept count is snapped down to a multiple of this value (>=1).
    min_keep_ratio : float
        Fractional lower bound on the expected keep, applied before rounding.
    """

    # Defaults are the identity policy: keep at least one group, no snapping.
    floor_groups: int = 1
    multiple_groups: int = 1
    min_keep_ratio: float = 0.0


@dataclass
class ExportPolicy:
    """Export-time policy shared by families.

    - `warmup_steps`: if current `step < warmup_steps`, keep-all.
    - `rounding`: default rounding used unless adapter overrides per-axis.
    """

    warmup_steps: int = 0
    # Use a default_factory: `rounding: Rounding = Rounding()` was a mutable
    # dataclass default — it raises ValueError on Python >= 3.11 (unhashable
    # default) and on older versions shares one Rounding instance across every
    # ExportPolicy, so mutating one policy's rounding mutated them all.
    rounding: Rounding = field(default_factory=Rounding)


def _round_down_mult(n: int, m: int) -> int:
    if m is None or m <= 1:
        return max(1, int(n))
    n = int(n)
    return max(m, (n // m) * m)


def _compute_keep_k(
    expected_kept: float,
    total_groups: int,
    *,
    rounding: Rounding,
) -> int:
    # Start from nearest-integer expectation
    k = int(round(expected_kept))
    # Apply ratio floor, then absolute floor, then multiple snapping
    k = max(k, int(rounding.min_keep_ratio * total_groups))
    k = max(k, int(rounding.floor_groups))
    k = min(k, total_groups)
    k = _round_down_mult(k, int(rounding.multiple_groups))
    return max(1, min(k, total_groups))


# -----------------------------------------------------------------------------
# KEEP index selection from a gate
# -----------------------------------------------------------------------------

@torch.no_grad()
def keep_group_indices_from_gate(
    gate: Gate,
    *,
    policy: ExportPolicy,
    step: Optional[int] = None,
    custom_rounding: Optional[Rounding] = None,
) -> torch.Tensor:
    """Return the sorted LongTensor of group indices to KEEP for `gate`.

    During warmup (`step < policy.warmup_steps`) every group is kept.
    Afterwards, the keep count is derived from the expected keep under the
    current logits and snapped per the active rounding policy; the
    highest-probability groups are selected.
    """
    total = int(gate.num_groups)

    # Warmup phase: prune nothing.
    if step is not None and step < int(policy.warmup_steps):
        return torch.arange(total, device=gate.logits.device)

    active_rounding = custom_rounding if custom_rounding is not None else policy.rounding
    # Keep probabilities under the current temperature.
    probs = torch.sigmoid(gate.logits.detach().float() / float(gate.tau))
    keep_k = _compute_keep_k(
        expected_kept=float(probs.sum()),
        total_groups=total,
        rounding=active_rounding,
    )
    kept = torch.topk(probs, keep_k, largest=True).indices
    return kept.sort().values.to(torch.long)


@torch.no_grad()
def keep_element_indices_from_gate(
    gate: Gate,
    *,
    policy: ExportPolicy,
    step: Optional[int] = None,
    custom_rounding: Optional[Rounding] = None,
) -> torch.Tensor:
    """Like `keep_group_indices_from_gate`, but expanded to element indices.

    Each kept group id is expanded into its member elements via
    `gate.group_size`.
    """
    groups = keep_group_indices_from_gate(
        gate,
        policy=policy,
        step=step,
        custom_rounding=custom_rounding,
    )
    return expand_group_indices(groups, gate.group_size)


# -----------------------------------------------------------------------------
# Generic slicers
# -----------------------------------------------------------------------------

@torch.no_grad()
def slice_linear(mat: nn.Linear, keep_in: Optional[Sequence[int]] = None, keep_out: Optional[Sequence[int]] = None) -> nn.Linear:
    """Create a new Linear with selected input/output features preserved.

    Parameters
    ----------
    mat : nn.Linear
        Source layer; never mutated.
    keep_in : optional sequence of int
        Column (input-feature) indices to keep; None keeps all.
    keep_out : optional sequence of int
        Row (output-feature) indices to keep; None keeps all.

    Returns
    -------
    A fresh `nn.Linear` on the same device *and dtype* as `mat.weight`.
    """
    W = mat.weight.detach()
    b = mat.bias.detach() if mat.bias is not None else None

    if keep_out is not None:
        W = W.index_select(0, torch.as_tensor(keep_out, device=W.device))
        if b is not None:
            b = b.index_select(0, torch.as_tensor(keep_out, device=b.device))
    if keep_in is not None:
        W = W.index_select(1, torch.as_tensor(keep_in, device=W.device))

    out_f, in_f = W.shape
    # Match dtype as well as device: `.to(W.device)` alone silently promoted
    # fp16/bf16 layers back to fp32 on export (slice_embedding already
    # preserves dtype; this makes the slicers consistent).
    new = nn.Linear(in_f, out_f, bias=(b is not None)).to(device=W.device, dtype=W.dtype)
    new.weight.copy_(W)
    if b is not None:
        new.bias.copy_(b)
    return new


@torch.no_grad()
def slice_conv2d(conv: nn.Conv2d, keep_in: Optional[Sequence[int]] = None, keep_out: Optional[Sequence[int]] = None) -> nn.Conv2d:
    """Create a new Conv2d with selected in/out channels preserved.

    Only supports standard conv2d (`groups` is forced to 1). For grouped or
    depthwise convs, the adapter should handle group alignment before calling
    this.

    Parameters
    ----------
    conv : nn.Conv2d
        Source layer; never mutated.
    keep_in : optional sequence of int
        Input-channel indices to keep; None keeps all.
    keep_out : optional sequence of int
        Output-channel indices to keep; None keeps all.
    """
    W = conv.weight.detach()
    b = conv.bias.detach() if conv.bias is not None else None

    if keep_out is not None:
        W = W.index_select(0, torch.as_tensor(keep_out, device=W.device))
        if b is not None:
            b = b.index_select(0, torch.as_tensor(keep_out, device=b.device))
    if keep_in is not None:
        W = W.index_select(1, torch.as_tensor(keep_in, device=W.device))

    out_c, in_c = W.shape[:2]
    # Match dtype as well as device: `.to(W.device)` alone re-created
    # fp16/bf16 convs in fp32 (consistent with slice_embedding).
    new = nn.Conv2d(
        in_c,
        out_c,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=1,
        bias=(b is not None),
        padding_mode=conv.padding_mode,
    ).to(device=W.device, dtype=W.dtype)
    new.weight.copy_(W)
    if b is not None:
        new.bias.copy_(b)
    return new


@torch.no_grad()
def slice_embedding(emb: nn.Embedding, keep_rows: Optional[Sequence[int]] = None, keep_dim: Optional[Sequence[int]] = None) -> nn.Embedding:
    """Create a new Embedding with selected rows (vocab) and/or dims kept.

    Parameters
    ----------
    emb : nn.Embedding
        Source embedding; never mutated.
    keep_rows : optional sequence of int
        Vocabulary row indices to keep; None keeps all.
    keep_dim : optional sequence of int
        Embedding-dimension indices to keep; None keeps all.

    Notes
    -----
    `padding_idx` is remapped into the sliced vocabulary when `keep_rows` is
    given; if the padding row itself is pruned, the result has no padding_idx.
    (The original code copied the old index verbatim, which could point at the
    wrong row or fall out of range in the new, smaller vocabulary.)
    """
    W = emb.weight.detach()
    padding_idx = emb.padding_idx
    if keep_rows is not None:
        rows = [int(r) for r in keep_rows]
        W = W.index_select(0, torch.as_tensor(rows, device=W.device))
        if padding_idx is not None:
            padding_idx = rows.index(padding_idx) if padding_idx in rows else None
    if keep_dim is not None:
        W = W.index_select(1, torch.as_tensor(keep_dim, device=W.device))
    num, dim = W.shape
    new = nn.Embedding(
        num,
        dim,
        padding_idx=padding_idx,
        max_norm=emb.max_norm,
        norm_type=emb.norm_type,
        scale_grad_by_freq=emb.scale_grad_by_freq,
        sparse=emb.sparse,
        device=W.device,
        dtype=W.dtype,
    )
    new.weight.copy_(W)
    return new


# -----------------------------------------------------------------------------
# Small helpers for adapters
# -----------------------------------------------------------------------------

@torch.no_grad()
def concat_index_ranges(ranges: Sequence[Tuple[int, int]]) -> torch.Tensor:
    """Flatten half-open ranges [(start, stop), ...] into one 1D LongTensor.

    Empty or inverted ranges contribute nothing; an all-empty input yields an
    empty LongTensor.
    """
    pieces = [torch.arange(lo, hi, dtype=torch.long) for lo, hi in ranges if hi > lo]
    if not pieces:
        return torch.empty(0, dtype=torch.long)
    return torch.cat(pieces, dim=0)


@torch.no_grad()
def block_indices_from_groups(groups: Sequence[int], group_size: int) -> torch.Tensor:
    """Expand sorted group ids into their member feature indices."""
    group_tensor = torch.as_tensor(groups, dtype=torch.long)
    return expand_group_indices(group_tensor, int(group_size))