File size: 8,198 Bytes
8125804
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f37be5a
 
 
 
8125804
 
f37be5a
 
 
 
 
 
 
 
8125804
f37be5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8125804
 
 
f37be5a
 
8125804
 
 
f37be5a
 
8125804
 
 
f37be5a
 
8125804
 
 
 
 
 
 
 
 
 
 
 
f37be5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8125804
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
"""Equilibrium Signal — unified routing/confidence/plasticity trigger.

Computes the deviation of activations from running per-feature statistics
(mean and variance) accumulated during training.
Low overhead: one BatchNorm-style mean/variance reduction per training batch.

ED(x) = || (x - mu_running) / sigma_running || / sqrt(d_model)

Small ED → forward pass (confident)
Medium ED → branching (uncertain, explore alternatives)
Large ED → backward pass (re-process through earlier layers)
Critical ED → plastic activation (adapt to new context)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class EquilibriumSignal(nn.Module):
    """Per-token equilibrium deviation (ED) against running feature statistics.

    In training mode, an exponential moving average of the per-feature mean
    and variance of the incoming activations is maintained (BatchNorm-style).
    Each token is scored by the L2 norm of its standardized features divided
    by ``sqrt(d_model)``, i.e. the RMS of the standardized features.

    Args:
        d_model: size of the last (feature) dimension.
        momentum: EMA weight given to each new batch's statistics.
        warmup_steps: number of tracked training batches before
            ``is_warming_up`` turns False.
    """

    def __init__(self, d_model: int, momentum: float = 0.1, warmup_steps: int = 50):
        super().__init__()
        self.d_model = d_model
        self.momentum = momentum
        self.warmup_steps = warmup_steps
        # Running statistics (like BatchNorm)
        self.register_buffer("running_mean", torch.zeros(d_model))
        self.register_buffer("running_var", torch.ones(d_model))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))

    @property
    def is_warming_up(self) -> bool:
        """True while fewer than ``warmup_steps`` batches have updated the stats."""
        return self.num_batches_tracked.item() < self.warmup_steps

    def _update_running_stats(self, x: torch.Tensor) -> None:
        """Fold one batch of activations into the running mean/variance."""
        with torch.no_grad():
            # Treat every leading position (batch, time, ...) as one sample so
            # the statistics are always per-feature, whatever the input rank.
            flat = x.detach().reshape(-1, x.shape[-1])
            num_samples = flat.shape[0]
            if num_samples == 0:
                # mean() of an empty tensor is NaN; nothing to learn here.
                return
            self.running_mean.mul_(1 - self.momentum).add_(flat.mean(dim=0), alpha=self.momentum)
            if num_samples > 1:
                # The unbiased variance of a single sample is NaN and would
                # poison running_var forever, so only update with >= 2 samples.
                self.running_var.mul_(1 - self.momentum).add_(flat.var(dim=0), alpha=self.momentum)
            self.num_batches_tracked += 1

    def forward(self, x: torch.Tensor) -> dict:
        """Compute equilibrium deviation for each token.

        Args:
            x: [..., D] layer output activations, typically [B, T, D].

        Returns:
            dict with:
                ed: [...] (e.g. [B, T]) equilibrium deviation per token
                x: the unchanged input (pass-through)
                warming_up: bool, True if still in the warmup phase
        """
        # Running statistics only move in training mode.
        if self.training:
            self._update_running_stats(x)

        # [..., D]: standardize against the running statistics.
        normalized = (x - self.running_mean) / (self.running_var.sqrt() + 1e-8)
        # [...]: L2 norm over the feature dim, normalized by sqrt(d_model).
        ed = normalized.norm(dim=-1) / (self.d_model ** 0.5)

        return {"ed": ed, "x": x, "warming_up": self.is_warming_up}


class RoutingDecision(nn.Module):
    """Converts equilibrium deviation into routing decisions.

    Buckets are calibrated from running ED quantiles so Stage B does not collapse
    to a single route just because the absolute ED scale shifted.
    Small learnable offsets allow training to nudge boundaries around the
    quantile-derived defaults.

    Args:
        init_thresholds: initial values of the three running thresholds.
        target_fractions: desired share of tokens per route (forward, branch,
            backward, plastic). It is normalized to sum to 1. Its cumulative
            sums, minus the last, are the quantile levels for the thresholds.
        threshold_momentum: EMA weight given to each batch's quantiles.
        temperature: sharpness of the soft routing probabilities.
        offset_scale: maximum magnitude of each learnable offset (via tanh).
    """

    def __init__(
        self,
        init_thresholds: tuple[float, float, float] = (0.75, 1.0, 1.35),
        target_fractions: tuple[float, float, float, float] = (0.55, 0.25, 0.15, 0.05),
        threshold_momentum: float = 0.2,
        temperature: float = 8.0,
        offset_scale: float = 0.2,
    ):
        super().__init__()
        fractions = torch.tensor(target_fractions, dtype=torch.float32)
        fractions = fractions / fractions.sum().clamp_min(1e-8)
        self.register_buffer("target_cdf", fractions.cumsum(dim=0)[:-1])
        self.register_buffer("running_thresholds", torch.tensor(init_thresholds, dtype=torch.float32))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
        self.threshold_momentum = threshold_momentum
        self.temperature = temperature
        self.offset_scale = offset_scale
        self.threshold_offsets = nn.Parameter(torch.zeros(3))

    def _batch_thresholds(self, ed: torch.Tensor) -> torch.Tensor:
        """Return the target-CDF quantiles of the finite ED values in ``ed``.

        Falls back to a *copy* of the running thresholds when no finite
        values are present.
        """
        flat = ed.detach().reshape(-1)
        # torch.quantile only accepts float32/float64 input.
        if flat.dtype not in (torch.float32, torch.float64):
            flat = flat.float()
        # A single NaN/inf would make every quantile NaN/inf and permanently
        # poison the running EMA, so only finite values are used.
        flat = flat[torch.isfinite(flat)]
        if flat.numel() == 0:
            # Clone: returning the buffer itself would alias it inside the
            # in-place EMA update in forward() and shrink the thresholds.
            return self.running_thresholds.clone()
        return torch.quantile(flat, self.target_cdf.to(device=flat.device, dtype=flat.dtype))

    def _ordered_thresholds(self, base: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply bounded learnable offsets and enforce t1 < t2 < t3 (min gap 1e-3)."""
        offsets = self.offset_scale * torch.tanh(self.threshold_offsets).to(base.device, base.dtype)
        raw = base + offsets
        min_gap = torch.tensor(1e-3, device=base.device, dtype=base.dtype)
        t1 = raw[0]
        t2 = torch.maximum(raw[1], t1 + min_gap)
        t3 = torch.maximum(raw[2], t2 + min_gap)
        return t1, t2, t3

    @property
    def theta1(self) -> torch.Tensor:
        """Effective forward/branch boundary from the running thresholds."""
        thresholds = self._ordered_thresholds(self.running_thresholds)
        return thresholds[0]

    @property
    def theta2(self) -> torch.Tensor:
        """Effective branch/backward boundary from the running thresholds."""
        thresholds = self._ordered_thresholds(self.running_thresholds)
        return thresholds[1]

    @property
    def theta3(self) -> torch.Tensor:
        """Effective backward/plastic boundary from the running thresholds."""
        thresholds = self._ordered_thresholds(self.running_thresholds)
        return thresholds[2]

    def forward(self, ed: torch.Tensor) -> dict:
        """Classify each token into routing buckets.

        Args:
            ed: [B, T] equilibrium deviation

        Returns:
            dict with:
                route: [B, T] long, 0=forward, 1=branch, 2=backward, 3=plastic
                route_probs: [B, T, 4] soft routing probabilities
                thresholds: [3] effective (offset, ordered) bucket boundaries
        """
        if self.training:
            batch_thresholds = self._batch_thresholds(ed)
            with torch.no_grad():
                self.running_thresholds.mul_(1 - self.threshold_momentum).add_(
                    batch_thresholds.to(self.running_thresholds.device, self.running_thresholds.dtype),
                    alpha=self.threshold_momentum,
                )
                self.num_batches_tracked += 1
            base_thresholds = self.running_thresholds.to(device=ed.device, dtype=ed.dtype)
        elif self.num_batches_tracked.item() > 0:
            base_thresholds = self.running_thresholds.to(device=ed.device, dtype=ed.dtype)
        else:
            # Never calibrated: fall back to this batch's own quantiles.
            base_thresholds = self._batch_thresholds(ed).to(device=ed.device, dtype=ed.dtype)

        t1, t2, t3 = self._ordered_thresholds(base_thresholds)
        # Soft routing: each route has a center, and probabilities decay with
        # distance from it. The outer centers sit one bucket-width beyond the
        # edge thresholds.
        left_width = (t2 - t1).clamp_min(1e-3)
        right_width = (t3 - t2).clamp_min(1e-3)
        centers = torch.stack(
            [
                t1 - left_width,
                (t1 + t2) * 0.5,
                (t2 + t3) * 0.5,
                t3 + right_width,
            ]
        )
        logits = -self.temperature * (ed.unsqueeze(-1) - centers).abs()
        probs = torch.softmax(logits, dim=-1)

        # Hard routing: bucket index against the ordered thresholds.
        thresholds = torch.stack([t1, t2, t3])
        route = torch.bucketize(ed, thresholds)

        return {"route": route, "route_probs": probs, "thresholds": thresholds}


class TokenEnergyBudget(nn.Module):
    """Assign each token an integer compute budget derived from its routing mix.

    Every route carries a fixed pass cost (forward=1, branch=2, backward=3,
    plastic=4). A token's expected cost is the probability-weighted sum of
    those costs. Within each sequence, expected costs are scaled down
    uniformly (never up) so that their sum fits ``T * total_budget_ratio``.
    Each token is then clamped to ``[1, max_budget_per_token]`` and rounded.

    Note: the floor of 1 and the rounding are applied after the scaling, so
    a sequence's integer budgets may sum to slightly more than the nominal
    total (and never to less than ``T``).
    """

    def __init__(self, max_budget_per_token: int = 4, total_budget_ratio: float = 2.0):
        super().__init__()
        self.max_per_token = max_budget_per_token
        self.total_budget_ratio = total_budget_ratio

    def forward(self, ed: torch.Tensor, route_probs: torch.Tensor) -> torch.Tensor:
        """Compute energy budget per token.

        Args:
            ed: [B, T] equilibrium deviation (only its shape and device are used)
            route_probs: [B, T, 4] routing probabilities

        Returns:
            budget: [B, T] long tensor, per-token compute budget
                (1 to max_per_token)
        """
        _, seq_len = ed.shape
        cap = int(seq_len * self.total_budget_ratio)

        # One fixed cost per route, in route order.
        pass_costs = torch.tensor([1.0, 2.0, 3.0, 4.0], device=ed.device)
        per_token = torch.sum(route_probs * pass_costs, dim=-1)

        # Per-sequence shrink factor; capped at 1 so budgets are never inflated.
        denom = per_token.sum(dim=-1, keepdim=True) + 1e-8
        shrink = torch.clamp(cap / denom, max=1.0)

        scaled = torch.clamp(per_token * shrink, min=1, max=self.max_per_token)
        return torch.round(scaled).long()