Buckets:
| """Constrained token pruner with an interpretable dual variable mu (formalization §4). | |
| Solves, per image, the constrained problem | |
| min_m sum_i m_i s.t. C*(x) - C(S;x) <= epsilon | |
| via the Lagrangian | |
| J(m, mu) = sum_i m_i + mu * (C*(x) - C(S;x) - epsilon), mu >= 0 | |
| with primal gradient descent on the gate logits (Gumbel straight-through mask) and dual | |
| ascent on mu: | |
| mu <- [ mu + eta_mu * (C*(x) - C(S;x) - epsilon) ]_+ . | |
| mu reads as the marginal token cost of one unit of preserved lesion coverage. When the | |
| coverage floor is violated mu rises (retain more tokens); when satisfied it decays (prune | |
| more). This is the controller — no RL (anti-goal §5). Operates on FROZEN features Z and a | |
| label-free lesion subspace projector P_L; coverage is the RankMe functional (or coding-rate | |
| surrogate). The contribution is this constraint, not the backbone. | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| import torch | |
| from coverage.rankme import coverage as rankme_coverage | |
| from .mask_gumbel import gumbel_sigmoid, threshold_mask | |
@dataclass
class PrunerResult:
    """Outcome of one per-image run of ``ConstrainedPruner.fit_image``.

    Declared as a dataclass so that it can be built with the keyword arguments
    ``fit_image`` passes; a plain class with bare annotations has no such
    constructor and would raise ``TypeError`` on instantiation.
    """

    mask: torch.Tensor   # (n,) hard retention mask at inference
    mu: float            # final dual value
    delta_C: float       # C*(x) - C(S;x) under the applied mask
    k: int               # retained budget |S|
    C_star: float        # dense coverage reference
    C_S: float           # retained coverage
    mu_trajectory: list  # dual trajectory (for Gate 4 stability check)
    satisfied: bool      # delta_C <= epsilon
class ConstrainedPruner:
    """Per-image token pruner: primal descent on gate logits, dual ascent on ``mu``.

    Targets the constrained problem ``min_m sum_i m_i`` subject to
    ``C*(x) - C(S;x) <= epsilon`` through the Lagrangian
    ``J(m, mu) = cost_scale * sum_i m_i + mu * (C*(x) - C(S;x) - epsilon)``.
    The gate logits are updated with SGD on ``J``; ``mu`` is updated by
    projected ascent on the constraint violation and clamped to ``[0, mu_max]``.
    """

    def __init__(self, epsilon: float, steps: int = 200, lr: float = 0.5,
                 eta_mu: float = 0.2, tau: float = 0.5, mu_init: float = 1.0,
                 keep_init: float = 2.0, coverage_fn=None, momentum: float = 0.9,
                 mu_max: float = 1e4, cost_scale: float = 1.0, seed: int = 0):
        """Store the optimisation hyper-parameters.

        Args:
            epsilon: Allowed coverage drop ``C*(x) - C(S;x)``.
            steps: Number of primal/dual iterations per image.
            lr: Learning rate of the SGD optimiser on the gate logits.
            eta_mu: Step size of the dual ascent on ``mu``.
            tau: Temperature forwarded to ``gumbel_sigmoid``.
            mu_init: Initial value of the dual variable.
            keep_init: Initial value shared by all gate logits.
            coverage_fn: Callable ``(Z, P_L) -> scalar tensor``; ``None`` selects
                ``rankme_coverage``.
            momentum: Momentum of the SGD optimiser.
            mu_max: Upper clamp applied to ``mu`` after every dual step.
            cost_scale: Multiplier on the token-count term of the Lagrangian.
            seed: Seed of the generator used for the Gumbel noise; it is re-seeded
                on every ``fit_image`` call.
        """
        self.epsilon = epsilon
        self.steps = steps
        self.lr = lr  # SGD lr: dual mu must scale the step, so NOT Adam
        self.eta_mu = eta_mu
        self.tau = tau
        self.mu_init = mu_init
        self.keep_init = keep_init  # init logits > 0 => start by keeping most tokens
        # Explicit None test: a truthiness check (`coverage_fn or default`) would
        # silently discard a user-supplied callable that happens to be falsy,
        # e.g. an object defining __bool__/__len__.
        self.coverage_fn = rankme_coverage if coverage_fn is None else coverage_fn
        self.momentum = momentum
        self.mu_max = mu_max
        self.cost_scale = cost_scale
        self.seed = seed

    def fit_image(self, Z: torch.Tensor, P_L: torch.Tensor) -> PrunerResult:
        """Optimize the per-image mask. Z: (n,d) frozen tokens; P_L: (d,d) lesion projector.

        Returns:
            A ``PrunerResult`` holding the hard mask produced by ``threshold_mask``
            on the final logits, the final ``mu`` and its per-step trajectory, the
            dense and retained coverage values, their gap ``delta_C``, the number
            of retained tokens, and whether ``delta_C <= epsilon``. All coverage
            figures in the result are evaluated on the hard mask, not on the
            stochastic masks sampled during optimisation.
        """
        device = Z.device
        Z = Z.float()
        P_L = P_L.to(device).float()
        gen = torch.Generator(device=device).manual_seed(self.seed)
        n = Z.shape[0]
        theta = torch.full((n,), float(self.keep_init), device=device, requires_grad=True)
        # SGD (not Adam): the dual mu scales the constraint gradient, and only a
        # non-normalizing optimizer lets mu actually trade off coverage vs token cost.
        opt = torch.optim.SGD([theta], lr=self.lr, momentum=self.momentum)
        # Dense (unpruned) coverage is the fixed reference; no gradient is needed.
        C_star = self.coverage_fn(Z, P_L).detach()
        mu = torch.tensor(float(self.mu_init), device=device)
        mu_traj = []
        for _ in range(self.steps):
            # Primal step: descend J with respect to the gate logits.
            opt.zero_grad()
            m = gumbel_sigmoid(theta, tau=self.tau, hard=True, generator=gen)
            Z_S = Z * m[:, None]
            C_S = self.coverage_fn(Z_S, P_L)
            violation = C_star - C_S - self.epsilon
            # mu is held constant (detached) inside the primal objective.
            J = self.cost_scale * m.sum() + mu.detach() * violation
            J.backward()
            opt.step()
            # Dual step: ascend on the violation, then project onto [0, mu_max].
            with torch.no_grad():
                mu = (mu + self.eta_mu * violation.detach()).clamp_(0.0, self.mu_max)
            mu_traj.append(float(mu))
        # Final evaluation on the deterministic mask derived from the learned logits
        # (thresholding rule lives in .mask_gumbel and is not visible here).
        with torch.no_grad():
            m_hard = threshold_mask(theta)
            Z_S = Z * m_hard[:, None]
            C_S_final = float(self.coverage_fn(Z_S, P_L))
        delta_C = float(C_star) - C_S_final
        return PrunerResult(
            mask=m_hard.detach(), mu=float(mu), delta_C=delta_C, k=int(m_hard.sum()),
            C_star=float(C_star), C_S=C_S_final, mu_trajectory=mu_traj,
            satisfied=bool(delta_C <= self.epsilon),
        )
Xet Storage Details
- Size:
- 4.28 kB
- Xet hash:
- 43d798f63941067ad7c28a456b5dcaf03187e89605e6c9015155695bae61e17b
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.