abpt / src /model /abpt_b.py
Search
auto: sync run_testformer_wikitext_combo_remote.py
f37be5a
"""ABPT Stage B — Unified model with Equilibrium Signal + Adaptive Routing.
Stage A modules (AttnRes, Plastic, Branches, Verifier) are still present,
but now controlled by the unified equilibrium signal instead of being always-on.
Three fundamental mechanisms:
1. Equilibrium Signal — deviation from running mean drives everything
2. Adaptive Routing — tokens go forward/branch/backward/plastic based on ED
3. Token Energy Budget — limits compute per token
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.model.config import ModelConfig
from src.model.backbone import Backbone, TransformerBlock
from src.model.plastic import PlasticLayer
from src.model.branches import BranchRouter
from src.model.verifier import Verifier
from src.model.equilibrium import EquilibriumSignal, RoutingDecision, TokenEnergyBudget
from src.model.adaptive_routing import ScatterGather
class ABPTModelB(nn.Module):
"""Stage B: Unified ABPT with adaptive routing.
Key difference from Stage A:
- Equilibrium signal computed after each layer
- Routing determines which tokens get branches, backward pass, or plasticity
- Not all tokens go through all modules — only those that need it
"""
def __init__(self, cfg: ModelConfig):
super().__init__()
self.cfg = cfg
self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
self.drop = nn.Dropout(cfg.dropout)
self.blocks = nn.ModuleList([
TransformerBlock(cfg, i) for i in range(cfg.n_layers)
])
self.ln_final = nn.LayerNorm(cfg.d_model)
# Equilibrium signal per layer
self.eq_signals = nn.ModuleList([
EquilibriumSignal(cfg.d_model, momentum=cfg.eq_momentum, warmup_steps=cfg.eq_warmup_steps)
for _ in range(cfg.n_layers)
])
self.router = RoutingDecision(
target_fractions=(
cfg.route_forward_target,
cfg.route_branch_target,
cfg.route_backward_target,
cfg.route_plastic_target,
),
threshold_momentum=cfg.route_threshold_momentum,
temperature=cfg.route_temperature,
offset_scale=cfg.route_threshold_offset_scale,
)
self.energy = TokenEnergyBudget()
# Stage A modules — activated selectively by routing
if cfg.use_plastic:
self.plastic = PlasticLayer(cfg)
if cfg.use_branches:
self.branch_router = BranchRouter(cfg)
if cfg.use_verifier and cfg.use_branches:
self.verifier = Verifier(cfg)
# Single lm_head for non-branched tokens
self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
def forward(
self,
input_ids: torch.Tensor,
targets: torch.Tensor | None = None,
) -> dict:
B, T = input_ids.shape
pos = torch.arange(T, device=input_ids.device).unsqueeze(0)
x = self.drop(self.tok_emb(input_ids) + self.pos_emb(pos))
result = {}
layer_outputs = [x]
all_route_stats = []
for i, block in enumerate(self.blocks):
# Normal forward pass through block
x = block(x, layer_outputs)
layer_outputs.append(x)
# Compute equilibrium signal
eq_out = self.eq_signals[i](x)
ed = eq_out["ed"] # [B, T]
route_preview = self.router(ed)
thresholds = route_preview["thresholds"]
# During warmup: force all tokens to forward (accumulate stats only)
if eq_out.get("warming_up", False):
route = torch.zeros(B, T, dtype=torch.long, device=x.device)
route_probs = route_preview["route_probs"]
else:
route = route_preview["route"] # [B, T]: 0=fwd, 1=branch, 2=back, 3=plastic
route_probs = route_preview["route_probs"]
# Collect stats
total = route.numel()
stats = {
"forward": (route == 0).sum().item() / total,
"branch": (route == 1).sum().item() / total,
"backward": (route == 2).sum().item() / total,
"plastic": (route == 3).sum().item() / total,
"mean_ed": ed.mean().item(),
"theta1": thresholds[0].item(),
"theta2": thresholds[1].item(),
"theta3": thresholds[2].item(),
}
all_route_stats.append(stats)
# === Apply routing effects ===
# Backward tokens: re-process through previous layer (if exists)
if i > 0:
back_mask = (route == 2) # [B, T]
if back_mask.any():
back_indices = back_mask.nonzero(as_tuple=False)
back_tokens = x[back_indices[:, 0], back_indices[:, 1]]
# Re-process through previous block (selective forgetting + reinterpretation)
prev_layer_outs = layer_outputs[:i] # exclude current
back_tokens_unsq = back_tokens.unsqueeze(1) # [N, 1, D]
# Simple re-process: just run through previous block
prev_outs_for_back = [lo[back_indices[:, 0], back_indices[:, 1]].unsqueeze(1) for lo in prev_layer_outs]
reprocessed = self.blocks[i - 1](back_tokens_unsq, prev_outs_for_back)
x[back_indices[:, 0], back_indices[:, 1]] = reprocessed.squeeze(1)
# Plastic tokens: apply plastic adaptation
if self.cfg.use_plastic:
plastic_mask = (route == 3)
if plastic_mask.any():
p_indices = plastic_mask.nonzero(as_tuple=False)
p_tokens = x[p_indices[:, 0], p_indices[:, 1]].unsqueeze(1)
adapted = self.plastic(p_tokens)
x[p_indices[:, 0], p_indices[:, 1]] = adapted.squeeze(1)
# Final norm
hidden = self.ln_final(x)
# === Output heads ===
# Branch tokens get branch+verifier, others get lm_head
if self.cfg.use_branches:
# Use last layer's route for output decision
last_route = route # from final layer
branch_mask = (last_route == 1) # [B, T]
if branch_mask.any() and branch_mask.sum() > 0:
# Branch path
b_indices = branch_mask.nonzero(as_tuple=False)
b_tokens = hidden[b_indices[:, 0], b_indices[:, 1]].unsqueeze(1)
branch_out = self.branch_router(b_tokens)
result["diversity_loss"] = branch_out["diversity_loss"]
result["branch_logits"] = branch_out["branch_logits"]
if self.cfg.use_verifier:
v_out = self.verifier(branch_out["branch_logits"])
branch_logits = v_out["logits"] # [N, 1, V]
else:
branch_logits = branch_out["logits"]
# Non-branch path
all_logits = self.lm_head(hidden) # [B, T, V]
# Override branch positions
all_logits[b_indices[:, 0], b_indices[:, 1]] = branch_logits.squeeze(1)
logits = all_logits
else:
logits = self.lm_head(hidden)
result["diversity_loss"] = torch.tensor(0.0, device=hidden.device)
else:
logits = self.lm_head(hidden)
result["logits"] = logits
result["route_stats"] = all_route_stats
result["last_route_probs"] = route_probs
result["last_route_thresholds"] = thresholds
result["hidden"] = hidden
# Loss
if targets is not None:
ce_loss = F.cross_entropy(
logits.view(B * T, -1), targets.view(B * T)
)
total_loss = ce_loss
if self.cfg.use_branches and "diversity_loss" in result:
total_loss = total_loss + self.cfg.diversity_weight * result["diversity_loss"]
result["loss"] = total_loss
result["ce_loss"] = ce_loss
return result
def param_count(self) -> int:
return sum(p.numel() for p in self.parameters())
def param_count_str(self) -> str:
n = self.param_count()
if n >= 1_000_000:
return f"{n / 1_000_000:.1f}M"
return f"{n / 1_000:.1f}K"