abpt / src /model /adaptive_routing.py
Search
feat: add src/ module for script imports
8125804
"""Adaptive Routing β€” data flows where it needs processing.
Based on Equilibrium Signal, tokens are routed:
- Forward: next layer (normal)
- Branch: split into 2-3 routes through different layer paths
- Backward: re-process through earlier layers (selective forgetting)
- Plastic: activate plastic layer for adaptation
Uses SoA scatter/gather for efficient batching:
tokens grouped by route into dense arrays for GPU efficiency.
"""
import torch
import torch.nn as nn
from src.model.equilibrium import EquilibriumSignal, RoutingDecision, TokenEnergyBudget
class ScatterGather(nn.Module):
    """Group tokens by route into dense buckets, then restore their order.

    Scatter: collect the tokens assigned to each route into a dense
        per-route array.
    Compute: the caller processes each route's tokens as a dense batch
        (this class does not do the processing itself).
    Gather: write the (possibly updated) tokens back to their original
        positions.
    """

    @staticmethod
    def scatter(x: torch.Tensor, route: torch.Tensor, n_routes: int = 4) -> dict:
        """Group tokens by route.

        Args:
            x: [B, T, D] - token representations.
            route: [B, T] - route assignment (0..n_routes-1).
            n_routes: number of route ids to look for. Tokens whose route id
                falls outside ``range(n_routes)`` end up in no bucket.

        Returns:
            dict mapping route_id -> {"tokens": [N_i, D], "indices": [N_i, 2]}
            where each row of ``indices`` is a (batch_idx, seq_idx) pair.
            Routes with no tokens are omitted from the dict.

        Raises:
            ValueError: if ``x`` is not 3-dimensional.
        """
        # Explicit rank check (the previous implementation enforced this
        # implicitly by unpacking x.shape into three names).
        if x.dim() != 3:
            raise ValueError(
                f"expected x of shape [B, T, D], got {tuple(x.shape)}"
            )

        buckets = {}
        for r in range(n_routes):
            mask = route == r  # [B, T]
            if mask.any():
                # (batch_idx, seq_idx) of every token assigned to this route
                indices = mask.nonzero(as_tuple=False)  # [N_i, 2]
                tokens = x[indices[:, 0], indices[:, 1]]  # [N_i, D]
                buckets[r] = {"tokens": tokens, "indices": indices}
        return buckets

    @staticmethod
    def gather(buckets: dict, shape: tuple, device: torch.device) -> torch.Tensor:
        """Restore tokens to their original positions.

        Args:
            buckets: dict from ``scatter``, with (possibly updated) tokens.
            shape: (B, T, D) - original shape.
            device: target device.

        Returns:
            x: [B, T, D] - reconstructed tensor. Positions not covered by any
            bucket are zero. The dtype follows the bucket tokens; with no
            buckets the default dtype is used.
        """
        # Take the output dtype from the tokens instead of always using the
        # default dtype: indexed assignment into a float32 buffer from
        # fp16/bf16/fp64 tokens would otherwise fail or lose the token dtype.
        dtype = next(
            (bucket["tokens"].dtype for bucket in buckets.values()), None
        )
        x = torch.zeros(shape, dtype=dtype, device=device)
        for bucket in buckets.values():
            indices = bucket["indices"]
            # .to() is a no-op when dtypes already match; it only matters if
            # different routes produced tokens of different dtypes.
            x[indices[:, 0], indices[:, 1]] = bucket["tokens"].to(x.dtype)
        return x
class AdaptiveRouter(nn.Module):
    """Adaptive routing module used alongside the transformer layers.

    Holds one ``EquilibriumSignal`` per layer plus a single shared
    ``RoutingDecision`` and a single shared ``TokenEnergyBudget``.
    """

    def __init__(self, d_model: int, n_layers: int):
        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers

        # One equilibrium-signal module for every layer.
        per_layer_signals = [EquilibriumSignal(d_model) for _ in range(n_layers)]
        self.eq_signals = nn.ModuleList(per_layer_signals)

        self.router = RoutingDecision()
        self.energy = TokenEnergyBudget()

    def compute_route(self, x: torch.Tensor, layer_idx: int) -> dict:
        """Compute routing for the tokens at a given layer.

        Args:
            x: [B, T, D] - activations after the layer.
            layer_idx: index of the layer that just processed ``x``; selects
                which per-layer equilibrium signal is applied.

        Returns:
            dict with keys "ed", "route", "route_probs" and "budget".
        """
        signal = self.eq_signals[layer_idx](x)
        ed = signal["ed"]

        decision = self.router(ed)
        probs = decision["route_probs"]

        # The budget is derived from both the signal and the route probabilities.
        budget = self.energy(ed, probs)

        result = {}
        result["ed"] = ed
        result["route"] = decision["route"]
        result["route_probs"] = probs
        result["budget"] = budget
        return result

    def get_route_stats(self, route: torch.Tensor) -> dict:
        """Summarize how many tokens were assigned to each route.

        Args:
            route: [B, T] - route assignments.

        Returns:
            dict holding, for each of "forward" (id 0), "branch" (id 1),
            "backward" (id 2) and "plastic" (id 3), the token count under the
            route name and its share of all tokens under "<name>_ratio".
            Ratios are 0 when ``route`` has no elements.
        """
        n_tokens = route.numel()
        route_names = ("forward", "branch", "backward", "plastic")

        summary = {}
        for route_id, label in enumerate(route_names):
            hits = route.eq(route_id).sum().item()
            summary[label] = hits
            if n_tokens > 0:
                summary[f"{label}_ratio"] = hits / n_tokens
            else:
                summary[f"{label}_ratio"] = 0
        return summary