Spaces:
Sleeping
Sleeping
File size: 8,198 Bytes
8125804 f37be5a 8125804 f37be5a 8125804 f37be5a 8125804 f37be5a 8125804 f37be5a 8125804 f37be5a 8125804 f37be5a 8125804 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 | """Equilibrium Signal — unified routing/confidence/plasticity trigger.
Computes each token's deviation from running activation statistics (accumulated during training).
Low overhead: the signal needs only BatchNorm-style running mean/variance buffers and no learned parameters.
ED(x) = || (x - mu_running) / sigma_running ||_2 / sqrt(d_model)
Small ED → forward pass (confident)
Medium ED → branching (uncertain, explore alternatives)
Large ED → backward pass (re-process through earlier layers)
Critical ED → plastic activation (adapt to new context)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class EquilibriumSignal(nn.Module):
    """Score how far each token's activations sit from running statistics.

    ED(x) = || (x - running_mean) / sqrt(running_var) ||_2 / sqrt(d_model)

    The running mean/variance are exponential moving averages (BatchNorm
    style, weighted by ``momentum``) updated only while the module is in
    training mode. The input is returned unchanged alongside the score.
    """

    def __init__(self, d_model: int, momentum: float = 0.1, warmup_steps: int = 50):
        super().__init__()
        self.d_model = d_model
        self.momentum = momentum
        self.warmup_steps = warmup_steps
        # Running statistics (like BatchNorm)
        self.register_buffer("running_mean", torch.zeros(d_model))
        self.register_buffer("running_var", torch.ones(d_model))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))

    @property
    def is_warming_up(self) -> bool:
        """True until ``warmup_steps`` training batches have updated the statistics."""
        return self.num_batches_tracked.item() < self.warmup_steps

    @torch.no_grad()
    def _update_running_stats(self, x: torch.Tensor) -> None:
        """EMA-update the running mean/variance from every token in ``x``.

        All leading dimensions are treated as tokens, so ``[B, T, D]`` and
        ``[T, D]`` inputs are both reduced over everything but the last dim.
        """
        flat = x.detach().reshape(-1, x.shape[-1])
        n_tokens = flat.shape[0]
        if n_tokens == 0:
            # Nothing to learn from; a mean over zero tokens would be NaN.
            return
        self.running_mean.mul_(1 - self.momentum).add_(flat.mean(dim=0), alpha=self.momentum)
        # The unbiased variance of a single sample is NaN; blending it in would
        # poison running_var for the rest of training, so skip it instead.
        if n_tokens > 1:
            self.running_var.mul_(1 - self.momentum).add_(flat.var(dim=0), alpha=self.momentum)
        self.num_batches_tracked += 1

    def forward(self, x: torch.Tensor) -> dict:
        """Compute equilibrium deviation for each token.

        Args:
            x: [..., D] layer output activations (typically [B, T, D]).

        Returns:
            dict with:
                ed: [...] equilibrium deviation per token (x.shape[:-1]).
                x: the input, passed through unchanged.
                warming_up: bool, True while still in the warmup phase.
        """
        if self.training:
            self._update_running_stats(x)
        # Standardize against running stats; the epsilon is added after the
        # sqrt so a zero-variance feature cannot divide by zero.
        normalized = (x - self.running_mean) / (self.running_var.sqrt() + 1e-8)
        # L2 norm over the feature dim, scaled by sqrt(d_model) so the value
        # does not grow with model width.
        ed = normalized.norm(dim=-1) / (self.d_model ** 0.5)
        return {"ed": ed, "x": x, "warming_up": self.is_warming_up}
class RoutingDecision(nn.Module):
    """Convert equilibrium deviation (ED) into per-token routing decisions.

    Bucket boundaries follow running quantiles of ED (an EMA weighted by
    ``threshold_momentum``) taken at the cumulative ``target_fractions``, so
    routing does not collapse to a single bucket just because the absolute ED
    scale shifted. Small learnable offsets, bounded to +/- ``offset_scale`` via
    tanh, let training nudge the boundaries around the quantile-derived values.
    """

    def __init__(
        self,
        init_thresholds: tuple[float, float, float] = (0.75, 1.0, 1.35),
        target_fractions: tuple[float, float, float, float] = (0.55, 0.25, 0.15, 0.05),
        threshold_momentum: float = 0.2,
        temperature: float = 8.0,
        offset_scale: float = 0.2,
    ):
        super().__init__()
        fractions = torch.tensor(target_fractions, dtype=torch.float32)
        fractions = fractions / fractions.sum().clamp_min(1e-8)
        # Interior CDF points: the ED quantiles separating the four buckets.
        self.register_buffer("target_cdf", fractions.cumsum(dim=0)[:-1])
        self.register_buffer("running_thresholds", torch.tensor(init_thresholds, dtype=torch.float32))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
        self.threshold_momentum = threshold_momentum
        self.temperature = temperature
        self.offset_scale = offset_scale
        self.threshold_offsets = nn.Parameter(torch.zeros(3))

    def _batch_thresholds(self, ed: torch.Tensor) -> torch.Tensor:
        """Return this batch's ED quantiles at ``target_cdf``.

        For an empty batch, returns a *copy* of ``running_thresholds``. The result
        never aliases the buffer, so it is safe to blend into it in place.
        """
        flat = ed.detach().reshape(-1)
        if flat.numel() == 0:
            return self.running_thresholds.clone()
        # torch.quantile only accepts float32/float64 input (rejects half/bfloat16).
        if flat.dtype not in (torch.float32, torch.float64):
            flat = flat.float()
        return torch.quantile(flat, self.target_cdf.to(device=flat.device, dtype=flat.dtype))

    def _ordered_thresholds(self, base: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Add bounded learnable offsets and force t1 < t2 < t3 (gap >= 1e-3)."""
        offsets = self.offset_scale * torch.tanh(self.threshold_offsets).to(base.device, base.dtype)
        raw = base + offsets
        min_gap = torch.tensor(1e-3, device=base.device, dtype=base.dtype)
        t1 = raw[0]
        t2 = torch.maximum(raw[1], t1 + min_gap)
        t3 = torch.maximum(raw[2], t2 + min_gap)
        return t1, t2, t3

    @property
    def theta1(self) -> torch.Tensor:
        thresholds = self._ordered_thresholds(self.running_thresholds)
        return thresholds[0]

    @property
    def theta2(self) -> torch.Tensor:
        thresholds = self._ordered_thresholds(self.running_thresholds)
        return thresholds[1]

    @property
    def theta3(self) -> torch.Tensor:
        thresholds = self._ordered_thresholds(self.running_thresholds)
        return thresholds[2]

    def _update_running_thresholds(self, ed: torch.Tensor) -> None:
        """EMA-blend this batch's quantiles into ``running_thresholds``."""
        batch_thresholds = self._batch_thresholds(ed)
        with torch.no_grad():
            self.running_thresholds.mul_(1 - self.threshold_momentum).add_(
                batch_thresholds.to(self.running_thresholds.device, self.running_thresholds.dtype),
                alpha=self.threshold_momentum,
            )
            self.num_batches_tracked += 1

    def forward(self, ed: torch.Tensor) -> dict:
        """Classify each token into routing buckets.

        Args:
            ed: [B, T] equilibrium deviation.

        Returns:
            dict with:
                route: [B, T] bucket index: 0=forward, 1=branch, 2=backward, 3=plastic.
                route_probs: [B, T, 4] soft routing probabilities.
                thresholds: [3] ordered boundaries (t1, t2, t3) used for ``route``.
        """
        # Empty batches carry no quantile information, so they leave the
        # running thresholds and the batch counter untouched.
        if self.training and ed.numel() > 0:
            self._update_running_thresholds(ed)

        if self.training or self.num_batches_tracked.item() > 0:
            base_thresholds = self.running_thresholds.to(device=ed.device, dtype=ed.dtype)
        else:
            # Never calibrated: fall back to this batch's own quantiles.
            base_thresholds = self._batch_thresholds(ed).to(device=ed.device, dtype=ed.dtype)

        t1, t2, t3 = self._ordered_thresholds(base_thresholds)
        left_width = (t2 - t1).clamp_min(1e-3)
        right_width = (t3 - t2).clamp_min(1e-3)
        # Inner buckets are centred between their boundaries. The outer buckets
        # are centred one neighbouring-bucket width beyond t1 / t3.
        centers = torch.stack(
            [
                t1 - left_width,
                (t1 + t2) * 0.5,
                (t2 + t3) * 0.5,
                t3 + right_width,
            ]
        )
        logits = -self.temperature * (ed.unsqueeze(-1) - centers).abs()
        probs = torch.softmax(logits, dim=-1)

        thresholds = torch.stack([t1, t2, t3])
        route = torch.bucketize(ed, thresholds)
        return {"route": route, "route_probs": probs, "thresholds": thresholds}
class TokenEnergyBudget(nn.Module):
    """Turn soft routing probabilities into an integer compute budget per token.

    Each route has a fixed cost: forward=1, branch=2, backward=3, plastic=4.
    A token's expected cost is the probability-weighted sum of those costs.
    When a row's summed expected cost exceeds ``int(T * total_budget_ratio)``,
    every token in that row is scaled down proportionally (never up).
    The per-token clamp to ``[1, max_budget_per_token]`` and the final rounding
    come afterwards, so the integer total is a soft cap and can still exceed the
    target.
    """

    def __init__(self, max_budget_per_token: int = 4, total_budget_ratio: float = 2.0):
        super().__init__()
        self.max_per_token = max_budget_per_token
        self.total_budget_ratio = total_budget_ratio

    def forward(self, ed: torch.Tensor, route_probs: torch.Tensor) -> torch.Tensor:
        """Compute the integer compute budget for each token.

        Args:
            ed: [B, T] equilibrium deviation. Only its shape and device are used.
            route_probs: [B, T, 4] routing probabilities.

        Returns:
            [B, T] long tensor of budgets in [1, max_budget_per_token].
        """
        _, seq_len = ed.shape
        row_cap = int(seq_len * self.total_budget_ratio)

        route_cost = torch.tensor([1.0, 2.0, 3.0, 4.0], device=ed.device)
        token_cost = torch.sum(route_probs * route_cost, dim=-1)

        # Shrink factor per row, capped at 1 so budgets are never inflated.
        row_total = torch.sum(token_cost, dim=-1, keepdim=True)
        shrink = torch.clamp(row_cap / (row_total + 1e-8), max=1.0)

        clipped = torch.clamp(token_cost * shrink, min=1, max=self.max_per_token)
        return torch.round(clipped).to(torch.long)
|