File size: 8,504 Bytes
8125804
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f37be5a
 
 
 
 
 
 
 
 
 
 
8125804
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f37be5a
 
8125804
 
 
 
f37be5a
8125804
f37be5a
 
8125804
 
 
 
 
 
 
 
 
f37be5a
 
 
8125804
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f37be5a
 
 
8125804
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
"""ABPT Stage B — Unified model with Equilibrium Signal + Adaptive Routing.

Stage A modules (AttnRes, Plastic, Branches, Verifier) are still present,
but now controlled by the unified equilibrium signal instead of being always-on.

Three fundamental mechanisms:
1. Equilibrium Signal — deviation from running mean drives everything
2. Adaptive Routing — tokens go forward/branch/backward/plastic based on ED
3. Token Energy Budget — limits compute per token (instantiated by the model in this file, but not yet consulted in forward())
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.model.config import ModelConfig
from src.model.backbone import Backbone, TransformerBlock
from src.model.plastic import PlasticLayer
from src.model.branches import BranchRouter
from src.model.verifier import Verifier
from src.model.equilibrium import EquilibriumSignal, RoutingDecision, TokenEnergyBudget
from src.model.adaptive_routing import ScatterGather


class ABPTModelB(nn.Module):
    """Stage B: Unified ABPT with adaptive routing.

    Key difference from Stage A:
    - Equilibrium signal computed after each layer
    - Routing determines which tokens get branches, backward pass, or plasticity
    - Not all tokens go through all modules — only those that need it
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)

        self.blocks = nn.ModuleList([
            TransformerBlock(cfg, i) for i in range(cfg.n_layers)
        ])
        self.ln_final = nn.LayerNorm(cfg.d_model)

        # Equilibrium signal per layer
        self.eq_signals = nn.ModuleList([
            EquilibriumSignal(cfg.d_model, momentum=cfg.eq_momentum, warmup_steps=cfg.eq_warmup_steps)
            for _ in range(cfg.n_layers)
        ])
        self.router = RoutingDecision(
            target_fractions=(
                cfg.route_forward_target,
                cfg.route_branch_target,
                cfg.route_backward_target,
                cfg.route_plastic_target,
            ),
            threshold_momentum=cfg.route_threshold_momentum,
            temperature=cfg.route_temperature,
            offset_scale=cfg.route_threshold_offset_scale,
        )
        self.energy = TokenEnergyBudget()

        # Stage A modules — activated selectively by routing
        if cfg.use_plastic:
            self.plastic = PlasticLayer(cfg)
        if cfg.use_branches:
            self.branch_router = BranchRouter(cfg)
        if cfg.use_verifier and cfg.use_branches:
            self.verifier = Verifier(cfg)

        # Single lm_head for non-branched tokens
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: torch.Tensor | None = None,
    ) -> dict:
        B, T = input_ids.shape
        pos = torch.arange(T, device=input_ids.device).unsqueeze(0)
        x = self.drop(self.tok_emb(input_ids) + self.pos_emb(pos))

        result = {}
        layer_outputs = [x]
        all_route_stats = []

        for i, block in enumerate(self.blocks):
            # Normal forward pass through block
            x = block(x, layer_outputs)
            layer_outputs.append(x)

            # Compute equilibrium signal
            eq_out = self.eq_signals[i](x)
            ed = eq_out["ed"]  # [B, T]
            route_preview = self.router(ed)
            thresholds = route_preview["thresholds"]

            # During warmup: force all tokens to forward (accumulate stats only)
            if eq_out.get("warming_up", False):
                route = torch.zeros(B, T, dtype=torch.long, device=x.device)
                route_probs = route_preview["route_probs"]
            else:
                route = route_preview["route"]  # [B, T]: 0=fwd, 1=branch, 2=back, 3=plastic
                route_probs = route_preview["route_probs"]

            # Collect stats
            total = route.numel()
            stats = {
                "forward": (route == 0).sum().item() / total,
                "branch": (route == 1).sum().item() / total,
                "backward": (route == 2).sum().item() / total,
                "plastic": (route == 3).sum().item() / total,
                "mean_ed": ed.mean().item(),
                "theta1": thresholds[0].item(),
                "theta2": thresholds[1].item(),
                "theta3": thresholds[2].item(),
            }
            all_route_stats.append(stats)

            # === Apply routing effects ===

            # Backward tokens: re-process through previous layer (if exists)
            if i > 0:
                back_mask = (route == 2)  # [B, T]
                if back_mask.any():
                    back_indices = back_mask.nonzero(as_tuple=False)
                    back_tokens = x[back_indices[:, 0], back_indices[:, 1]]
                    # Re-process through previous block (selective forgetting + reinterpretation)
                    prev_layer_outs = layer_outputs[:i]  # exclude current
                    back_tokens_unsq = back_tokens.unsqueeze(1)  # [N, 1, D]
                    # Simple re-process: just run through previous block
                    prev_outs_for_back = [lo[back_indices[:, 0], back_indices[:, 1]].unsqueeze(1) for lo in prev_layer_outs]
                    reprocessed = self.blocks[i - 1](back_tokens_unsq, prev_outs_for_back)
                    x[back_indices[:, 0], back_indices[:, 1]] = reprocessed.squeeze(1)

            # Plastic tokens: apply plastic adaptation
            if self.cfg.use_plastic:
                plastic_mask = (route == 3)
                if plastic_mask.any():
                    p_indices = plastic_mask.nonzero(as_tuple=False)
                    p_tokens = x[p_indices[:, 0], p_indices[:, 1]].unsqueeze(1)
                    adapted = self.plastic(p_tokens)
                    x[p_indices[:, 0], p_indices[:, 1]] = adapted.squeeze(1)

        # Final norm
        hidden = self.ln_final(x)

        # === Output heads ===
        # Branch tokens get branch+verifier, others get lm_head
        if self.cfg.use_branches:
            # Use last layer's route for output decision
            last_route = route  # from final layer
            branch_mask = (last_route == 1)  # [B, T]

            if branch_mask.any() and branch_mask.sum() > 0:
                # Branch path
                b_indices = branch_mask.nonzero(as_tuple=False)
                b_tokens = hidden[b_indices[:, 0], b_indices[:, 1]].unsqueeze(1)
                branch_out = self.branch_router(b_tokens)
                result["diversity_loss"] = branch_out["diversity_loss"]
                result["branch_logits"] = branch_out["branch_logits"]

                if self.cfg.use_verifier:
                    v_out = self.verifier(branch_out["branch_logits"])
                    branch_logits = v_out["logits"]  # [N, 1, V]
                else:
                    branch_logits = branch_out["logits"]

                # Non-branch path
                all_logits = self.lm_head(hidden)  # [B, T, V]
                # Override branch positions
                all_logits[b_indices[:, 0], b_indices[:, 1]] = branch_logits.squeeze(1)
                logits = all_logits
            else:
                logits = self.lm_head(hidden)
                result["diversity_loss"] = torch.tensor(0.0, device=hidden.device)
        else:
            logits = self.lm_head(hidden)

        result["logits"] = logits
        result["route_stats"] = all_route_stats
        result["last_route_probs"] = route_probs
        result["last_route_thresholds"] = thresholds
        result["hidden"] = hidden

        # Loss
        if targets is not None:
            ce_loss = F.cross_entropy(
                logits.view(B * T, -1), targets.view(B * T)
            )
            total_loss = ce_loss
            if self.cfg.use_branches and "diversity_loss" in result:
                total_loss = total_loss + self.cfg.diversity_weight * result["diversity_loss"]
            result["loss"] = total_loss
            result["ce_loss"] = ce_loss

        return result

    def param_count(self) -> int:
        return sum(p.numel() for p in self.parameters())

    def param_count_str(self) -> str:
        n = self.param_count()
        if n >= 1_000_000:
            return f"{n / 1_000_000:.1f}M"
        return f"{n / 1_000:.1f}K"