Spaces:
Sleeping
Sleeping
| """Adaptive Routing — data flows where it needs processing. | |
| Based on Equilibrium Signal, tokens are routed: | |
| - Forward: next layer (normal) | |
| - Branch: split into 2-3 routes through different layer paths | |
| - Backward: re-process through earlier layers (selective forgetting) | |
| - Plastic: activate plastic layer for adaptation | |
| Uses SoA scatter/gather for efficient batching: | |
| tokens grouped by route into dense arrays for GPU efficiency. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from src.model.equilibrium import EquilibriumSignal, RoutingDecision, TokenEnergyBudget | |
class ScatterGather(nn.Module):
    """Group tokens by route into dense buckets, then restore their order.

    Scatter: select the tokens of each route -> one dense array per route.
    Compute: the caller processes each route's tokens as a dense batch.
    Gather:  write the (possibly updated) tokens back to their original
             (batch, sequence) positions.

    Both operations are stateless, so they are static methods and can be
    called either on the class or on an instance.
    """

    @staticmethod
    def scatter(x: torch.Tensor, route: torch.Tensor, n_routes: int = 4) -> dict:
        """Group tokens by route.

        Args:
            x: [B, T, D] token representations. Only the first two
                dimensions are indexed.
            route: [B, T] route assignment (0..n_routes-1).
            n_routes: number of route ids to look for.

        Returns:
            dict mapping route_id -> {"tokens": [N_i, D], "indices": [N_i, 2]}
            where each row of "indices" is a (batch_idx, seq_idx) pair.
            Routes with no tokens are omitted. Tokens whose route id is
            outside 0..n_routes-1 appear in no bucket.
        """
        buckets = {}
        for r in range(n_routes):
            # One nonzero() call yields both the "is this route used?" answer
            # and the positions, instead of a separate mask.any() check.
            indices = (route == r).nonzero(as_tuple=False)  # [N_i, 2]
            if indices.shape[0] == 0:
                continue
            tokens = x[indices[:, 0], indices[:, 1]]  # [N_i, D]
            buckets[r] = {"tokens": tokens, "indices": indices}
        return buckets

    @staticmethod
    def gather(buckets: dict, shape: tuple, device: torch.device) -> torch.Tensor:
        """Restore tokens to their original positions.

        Args:
            buckets: dict as produced by scatter, with updated "tokens".
            shape: (B, T, D), the original shape.
            device: target device for the reconstructed tensor.

        Returns:
            x: tensor of the given shape. Positions not covered by any bucket
            are zero. The dtype is the promotion of all bucket token dtypes
            (the default dtype when buckets is empty).
        """
        # Indexed assignment requires source and destination dtypes to match,
        # so the output dtype is derived from the tokens rather than left at
        # the float32 default (which fails for fp16/bf16/float64 tokens).
        dtype = None
        for bucket in buckets.values():
            tok_dtype = bucket["tokens"].dtype
            dtype = tok_dtype if dtype is None else torch.promote_types(dtype, tok_dtype)

        x = torch.zeros(shape, device=device, dtype=dtype)
        for bucket in buckets.values():
            indices = bucket["indices"]
            x[indices[:, 0], indices[:, 1]] = bucket["tokens"].to(device=device, dtype=dtype)
        return x
class AdaptiveRouter(nn.Module):
    """Adaptive routing module used alongside the transformer layers.

    Holds one EquilibriumSignal per layer plus a shared RoutingDecision and
    TokenEnergyBudget, and chains them to produce a routing result for the
    activations of a given layer.
    """

    def __init__(self, d_model: int, n_layers: int):
        """Build the per-layer signals and the shared router / energy budget.

        Args:
            d_model: passed to each EquilibriumSignal.
            n_layers: number of per-layer equilibrium signals to create.
        """
        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers

        # One equilibrium signal per layer, indexed by layer_idx.
        per_layer_signals = [EquilibriumSignal(d_model) for _ in range(n_layers)]
        self.eq_signals = nn.ModuleList(per_layer_signals)

        self.router = RoutingDecision()
        self.energy = TokenEnergyBudget()

    def compute_route(self, x: torch.Tensor, layer_idx: int) -> dict:
        """Compute routing for the tokens at a given layer.

        Args:
            x: [B, T, D] activations after the layer.
            layer_idx: index of the layer that just processed x; selects
                which equilibrium signal is applied.

        Returns:
            dict with keys "ed", "route", "route_probs" and "budget".
        """
        signal = self.eq_signals[layer_idx](x)
        ed = signal["ed"]

        decision = self.router(ed)
        probs = decision["route_probs"]

        # The budget is derived from both the signal and the route probabilities.
        budget = self.energy(ed, probs)

        result = {}
        result["ed"] = ed
        result["route"] = decision["route"]
        result["route_probs"] = probs
        result["budget"] = budget
        return result

    def get_route_stats(self, route: torch.Tensor) -> dict:
        """Summarise how many tokens were sent down each route.

        Args:
            route: [B, T] route assignments.

        Returns:
            dict with, for each of "forward", "branch", "backward" and
            "plastic" (route ids 0..3 in that order), the token count under
            the route name and its share of all tokens under
            "<name>_ratio". The ratio is 0 when route is empty.
        """
        n_tokens = route.numel()
        summary = {}
        for route_id, label in enumerate(("forward", "branch", "backward", "plastic")):
            hits = (route == route_id).sum().item()
            summary[label] = hits
            if n_tokens > 0:
                summary[f"{label}_ratio"] = hits / n_tokens
            else:
                summary[f"{label}_ratio"] = 0
        return summary