Chucks90's picture
download
raw
4.28 kB
"""Constrained token pruner with an interpretable dual variable mu (formalization §4).
Solves, per image, the constrained problem

    min_m  sum_i m_i    s.t.    C*(x) - C(S;x) <= epsilon

via the Lagrangian

    J(m, mu) = sum_i m_i + mu * (C*(x) - C(S;x) - epsilon),    mu >= 0

with primal gradient descent on the gate logits (Gumbel straight-through mask) and dual
ascent on mu:

    mu <- [ mu + eta_mu * (C*(x) - C(S;x) - epsilon) ]_+ .

mu reads as the marginal token cost of one unit of preserved lesion coverage. When the
coverage floor is violated mu rises (retain more tokens); when satisfied it decays (prune
more). This is the controller — no RL (anti-goal §5). Operates on FROZEN features Z and a
label-free lesion subspace projector P_L; coverage is the RankMe functional (or coding-rate
surrogate). The contribution is this constraint, not the backbone.
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
from coverage.rankme import coverage as rankme_coverage
from .mask_gumbel import gumbel_sigmoid, threshold_mask
@dataclass
class PrunerResult:
    """Outcome of one per-image pruning run.

    Attributes:
        mask: Hard retention mask of shape (n,) used at inference.
        mu: Final value of the dual variable.
        delta_C: Coverage gap C*(x) - C(S;x) under the applied mask.
        k: Retained token budget |S|.
        C_star: Dense coverage reference.
        C_S: Coverage of the retained tokens.
        mu_trajectory: Dual values recorded over the run (for the Gate 4
            stability check).
        satisfied: Whether delta_C <= epsilon.
    """

    mask: torch.Tensor
    mu: float
    delta_C: float
    k: int
    C_star: float
    C_S: float
    mu_trajectory: list
    satisfied: bool
class ConstrainedPruner:
    """Per-image token pruner driven by a primal-dual (Lagrangian) controller.

    Minimizes the number of retained tokens subject to the coverage gap
    ``C*(x) - C(S;x)`` staying below ``epsilon``. Gate logits are updated by
    SGD on the Lagrangian; the dual variable ``mu`` is updated by projected
    ascent on the constraint violation.

    Args:
        epsilon: Allowed coverage gap in the constraint.
        steps: Number of primal/dual iterations per image.
        lr: Learning rate of the SGD optimizer on the gate logits.
        eta_mu: Step size of the dual ascent on ``mu``.
        tau: Temperature forwarded to ``gumbel_sigmoid``.
        mu_init: Initial value of the dual variable.
        keep_init: Initial value of every gate logit (> 0 starts by keeping
            most tokens).
        coverage_fn: Callable ``(Z, P_L) -> scalar tensor``; defaults to
            ``rankme_coverage`` when ``None``.
        momentum: Momentum of the SGD optimizer.
        mu_max: Upper bound used when clamping ``mu``.
        cost_scale: Multiplier on the token-count term of the objective.
        seed: Seed of the random generator handed to ``gumbel_sigmoid``.
    """

    def __init__(self, epsilon: float, steps: int = 200, lr: float = 0.5,
                 eta_mu: float = 0.2, tau: float = 0.5, mu_init: float = 1.0,
                 keep_init: float = 2.0, coverage_fn=None, momentum: float = 0.9,
                 mu_max: float = 1e4, cost_scale: float = 1.0, seed: int = 0):
        self.epsilon = epsilon
        self.steps = steps
        self.lr = lr  # SGD lr: dual mu must scale the step, so NOT Adam
        self.eta_mu = eta_mu
        self.tau = tau
        self.mu_init = mu_init
        self.keep_init = keep_init  # init logits > 0 => start by keeping most tokens
        # Explicit None check: a caller-supplied callable that happens to be
        # falsy (e.g. defines __len__ == 0) must not be silently replaced.
        self.coverage_fn = rankme_coverage if coverage_fn is None else coverage_fn
        self.momentum = momentum
        self.mu_max = mu_max
        self.cost_scale = cost_scale
        self.seed = seed

    def fit_image(self, Z: torch.Tensor, P_L: torch.Tensor) -> PrunerResult:
        """Optimize the per-image mask.

        Args:
            Z: (n, d) frozen tokens.
            P_L: (d, d) lesion projector.

        Returns:
            A ``PrunerResult`` holding the final hard mask, the final dual
            value, the coverage statistics under that mask, and the recorded
            dual trajectory.
        """
        device = Z.device
        # The features are frozen inputs: detach them so backward() neither
        # leaks gradients into the caller's graph nor tries to traverse (and
        # free) an upstream graph on every iteration.
        Z = Z.detach().float()
        P_L = P_L.detach().to(device).float()
        gen = torch.Generator(device=device).manual_seed(self.seed)
        n = Z.shape[0]
        cost_scale = self.cost_scale
        mu_traj = []

        # The inner optimization needs autograd even when the caller runs this
        # inside a torch.no_grad() evaluation loop.
        with torch.enable_grad():
            # Keep logits and dual in the same dtype as the (float32) features.
            theta = torch.full((n,), float(self.keep_init), dtype=Z.dtype,
                               device=device, requires_grad=True)
            # SGD (not Adam): the dual mu scales the constraint gradient, and only a
            # non-normalizing optimizer lets mu actually trade off coverage vs token cost.
            opt = torch.optim.SGD([theta], lr=self.lr, momentum=self.momentum)
            C_star = self.coverage_fn(Z, P_L).detach()
            mu = torch.tensor(float(self.mu_init), dtype=Z.dtype, device=device)

            for _ in range(self.steps):
                opt.zero_grad()
                # Gumbel straight-through mask sampled from the gate logits.
                m = gumbel_sigmoid(theta, tau=self.tau, hard=True, generator=gen)
                Z_S = Z * m[:, None]
                C_S = self.coverage_fn(Z_S, P_L)
                violation = C_star - C_S - self.epsilon
                # Lagrangian: token cost plus mu-weighted constraint violation.
                # mu is treated as a constant for the primal step.
                J = cost_scale * m.sum() + mu.detach() * violation
                J.backward()
                opt.step()
                with torch.no_grad():
                    # Projected dual ascent, additionally capped at mu_max.
                    mu = (mu + self.eta_mu * violation.detach()).clamp_(0.0, self.mu_max)
                mu_traj.append(float(mu))

        # Evaluate the constraint under the deterministic hard mask.
        with torch.no_grad():
            m_hard = threshold_mask(theta)
            Z_S = Z * m_hard[:, None]
            C_S_final = float(self.coverage_fn(Z_S, P_L))
            delta_C = float(C_star) - C_S_final

        return PrunerResult(
            mask=m_hard.detach(), mu=float(mu), delta_C=delta_C, k=int(m_hard.sum()),
            C_star=float(C_star), C_S=C_S_final, mu_trajectory=mu_traj,
            satisfied=bool(delta_C <= self.epsilon),
        )

Xet Storage Details

Size:
4.28 kB
·
Xet hash:
43d798f63941067ad7c28a456b5dcaf03187e89605e6c9015155695bae61e17b

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.