File size: 4,109 Bytes
8125804
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""Adaptive Routing — data flows to wherever it still needs processing.

Based on Equilibrium Signal, tokens are routed:
- Forward: next layer (normal)
- Branch: split into 2-3 routes through different layer paths
- Backward: re-process through earlier layers (selective forgetting)
- Plastic: activate plastic layer for adaptation

Uses SoA scatter/gather for efficient batching:
tokens grouped by route into dense arrays for GPU efficiency.
"""
import torch
import torch.nn as nn
from src.model.equilibrium import EquilibriumSignal, RoutingDecision, TokenEnergyBudget


class ScatterGather(nn.Module):
    """Groups tokens by route into dense buckets, processes, then restores order.

    Scatter: sort tokens by route β†’ dense per-route arrays
    Compute: process each route's tokens as a dense batch
    Gather: restore original token positions
    """

    @staticmethod
    def scatter(x: torch.Tensor, route: torch.Tensor, n_routes: int = 4) -> dict:
        """Group tokens by route.

        Args:
            x: [B, T, D] β€” token representations
            route: [B, T] β€” route assignment (0..n_routes-1)

        Returns:
            dict mapping route_id β†’ {tokens: [N_i, D], indices: [(b,t) pairs]}
        """
        B, T, D = x.shape
        buckets = {}
        for r in range(n_routes):
            mask = (route == r)  # [B, T]
            if mask.any():
                # Gather tokens for this route
                indices = mask.nonzero(as_tuple=False)  # [N_i, 2] β€” (batch_idx, seq_idx)
                tokens = x[indices[:, 0], indices[:, 1]]  # [N_i, D]
                buckets[r] = {"tokens": tokens, "indices": indices}
        return buckets

    @staticmethod
    def gather(buckets: dict, shape: tuple, device: torch.device) -> torch.Tensor:
        """Restore tokens to original positions.

        Args:
            buckets: dict from scatter, with updated tokens
            shape: (B, T, D) β€” original shape
            device: target device

        Returns:
            x: [B, T, D] β€” reconstructed tensor
        """
        x = torch.zeros(shape, device=device)
        for r, bucket in buckets.items():
            indices = bucket["indices"]
            x[indices[:, 0], indices[:, 1]] = bucket["tokens"]
        return x


class AdaptiveRouter(nn.Module):
    """Per-layer adaptive routing controller.

    Owns one ``EquilibriumSignal`` per transformer layer plus a single
    ``RoutingDecision`` and ``TokenEnergyBudget`` shared across layers, and
    chains them to produce the routing decision for the tokens that leave a
    given layer.
    """

    def __init__(self, d_model: int, n_layers: int):
        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers

        # One equilibrium-signal module per layer, selected by ``layer_idx``.
        # Attribute names and registration order determine state_dict keys.
        self.eq_signals = nn.ModuleList(
            EquilibriumSignal(d_model) for _ in range(n_layers)
        )
        # Shared by every layer.
        self.router = RoutingDecision()
        self.energy = TokenEnergyBudget()

    def compute_route(self, x: torch.Tensor, layer_idx: int) -> dict:
        """Run signal -> routing decision -> energy budget for one layer.

        Args:
            x: [B, T, D] — activations produced by layer ``layer_idx``.
            layer_idx: index of the layer that just ran; selects which
                equilibrium-signal module to use.

        Returns:
            dict with keys "ed" (the equilibrium module's "ed" output),
            "route" and "route_probs" (from the routing decision), and
            "budget" (the energy-budget output).
        """
        signal = self.eq_signals[layer_idx](x)
        ed = signal["ed"]
        decision = self.router(ed)
        probs = decision["route_probs"]
        budget = self.energy(ed, probs)

        return {
            "ed": ed,
            "route": decision["route"],
            "route_probs": probs,
            "budget": budget,
        }

    def get_route_stats(self, route: torch.Tensor) -> dict:
        """Count how many tokens took each route.

        Args:
            route: [B, T] — route id per token (0=forward, 1=branch,
                2=backward, 3=plastic).

        Returns:
            dict with, for each route name, the raw count under ``name`` and
            the fraction of all tokens under ``name + "_ratio"`` (0 when
            ``route`` is empty).
        """
        n_total = route.numel()
        summary = {}
        for idx, label in enumerate(("forward", "branch", "backward", "plastic")):
            hits = (route == idx).sum().item()
            summary[label] = hits
            summary[label + "_ratio"] = hits / n_total if n_total else 0
        return summary