Chucks90's picture
download
raw
2.47 kB
"""Phase 6 — coverage-routed adaptive depth (MoD-style), inference-time.
Tokens are routed by lesion-subspace coverage at a routing block L_route: the top-f fraction
by density-A membership continue through the remaining blocks (full depth); the rest exit
early at L_route. Lesion-candidate tokens (high coverage) keep full depth, so lesion features
are preserved; abundant non-lesion tokens are computed shallow, cutting FLOPs.
FLOP model for a ViT (per block ~ linear in active tokens for the MLP+projection terms, plus
a quadratic attention term). With n tokens, L total blocks, routing after L_route, retaining
fraction f for the deep blocks:
dense ~ L * (a*n + b*n^2)
routed ~ L_route*(a*n + b*n^2) + (L-L_route)*(a*f*n + b*(f*n)^2)
flop_reduction = dense / routed. Gate 6 (routed-depth) PASS: >= 1.5x at equal small-lesion
sensitivity (lesion-patch recall within tol of dense).
"""
from __future__ import annotations
import numpy as np
def flop_reduction(f: float, L_route: int, L_total: int = 12,
                   attn_frac: float = 0.0) -> float:
    """Return the dense/routed FLOP ratio for coverage-routed depth.

    The per-block cost is modelled on a normalized token fraction ``x`` as
    ``(1 - attn_frac) * x + attn_frac * x**2``: a linear (MLP/projection) term
    plus a quadratic (attention) term, so a block over all tokens costs 1.

    Args:
        f: Fraction of tokens retained for the blocks after the routing point.
        L_route: Number of blocks every token passes through before routing.
        L_total: Total number of blocks.
        attn_frac: Weight in [0, 1] of the quadratic attention term
            (0 = MLP/proj-dominated linear model; ~0.5 = attention-heavy).

    Returns:
        ``dense / routed`` as a float. When the routed cost is exactly zero
        (for example ``f == 0`` with ``L_route == 0``) the ratio is unbounded
        and ``inf`` is returned instead of raising ``ZeroDivisionError``.

    Note:
        Arguments are not range-checked; values outside the documented ranges
        are passed straight through the cost model.
    """
    def block_cost(token_frac):
        linear = (1 - attn_frac) * token_frac
        quadratic = attn_frac * token_frac * token_frac
        return linear + quadratic

    dense = L_total * block_cost(1.0)
    routed = L_route * block_cost(1.0) + (L_total - L_route) * block_cost(f)
    if routed == 0:
        # Nothing is computed on the routed path, so the reduction is unbounded.
        return float("inf")
    return float(dense / routed)
def route_topf(membership_scores: np.ndarray, f: float) -> np.ndarray:
    """Boolean mask of the top-f fraction of tokens by coverage membership (kept deep).

    Args:
        membership_scores: Per-token scores, higher meaning more likely to be
            kept. Any array-like convertible to float is accepted; it is
            expected to be 1-D, since the mask length is taken from ``len()``.
        f: Fraction of tokens to keep.

    Returns:
        A boolean array of length ``n = len(membership_scores)`` that is True
        for the ``k = max(1, round(f * n))`` highest-scoring tokens. At least
        one token is kept whenever ``n > 0``, even for ``f == 0``; if ``k``
        exceeds ``n`` every token is kept.
    """
    # Convert to float before negating: unary minus wraps around on unsigned
    # integers (mis-ranking the tokens), raises on boolean arrays, and is not
    # defined for plain Python lists.
    scores = np.asarray(membership_scores, dtype=float)
    n = len(scores)
    k = max(1, int(round(f * n)))
    keep = np.zeros(n, dtype=bool)
    # Ascending argsort of the negated scores lists the highest scores first.
    keep[np.argsort(-scores)[:k]] = True
    return keep
def best_reduction_at_equal_sensitivity(
        f_grid, sensitivities, L_route: int, L_total: int = 12,
        dense_sensitivity: float = 1.0, tol: float = 0.02, attn_frac: float = 0.0):
    """Select the retention fraction giving the largest FLOP reduction at dense-level sensitivity.

    ``f_grid`` and ``sensitivities`` are paired element-wise: the routed
    sensitivity measured at each retention fraction. A pair qualifies when its
    sensitivity is at least ``dense_sensitivity - tol``.

    Returns:
        ``(f*, reduction, sensitivity)`` for the qualifying pair with the
        largest ``flop_reduction``; equal reductions resolve to the smaller
        ``f`` because pairs are visited in ascending order. ``None`` when no
        pair qualifies.
    """
    qualifying = (
        (frac, flop_reduction(frac, L_route, L_total, attn_frac), sens)
        for frac, sens in sorted(zip(f_grid, sensitivities))  # ascending f
        if sens >= dense_sensitivity - tol
    )
    # max() keeps the first maximal element, matching a strict ">" update rule.
    return max(qualifying, key=lambda candidate: candidate[1], default=None)

Xet Storage Details

Size:
2.47 kB
·
Xet hash:
ecad8158e6d12bd8ca581b4210a9fc40df381764318a1726cff488c919179015

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.