Spaces:
Sleeping
Sleeping
"""ABPT Stage B — unified model with Equilibrium Signal + Adaptive Routing.

The Stage A modules (AttnRes, Plastic, Branches, Verifier) are still present,
but they are now controlled by the unified equilibrium signal instead of being
always-on.

Three fundamental mechanisms:
  1. Equilibrium Signal — deviation from a running mean drives everything.
  2. Adaptive Routing — each token goes forward/branch/backward/plastic based
     on its equilibrium deviation (ED).
  3. Token Energy Budget — limits compute per token.
"""
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from src.model.config import ModelConfig | |
| from src.model.backbone import Backbone, TransformerBlock | |
| from src.model.plastic import PlasticLayer | |
| from src.model.branches import BranchRouter | |
| from src.model.verifier import Verifier | |
| from src.model.equilibrium import EquilibriumSignal, RoutingDecision, TokenEnergyBudget | |
| from src.model.adaptive_routing import ScatterGather | |
class ABPTModelB(nn.Module):
    """Stage B: unified ABPT model with adaptive, per-token routing.

    Key differences from Stage A:
      - An equilibrium signal is computed after every transformer block.
      - A routing decision assigns each token one of four routes
        (0=forward, 1=branch, 2=backward, 3=plastic).
      - Only the tokens routed to a Stage A module (previous-block
        re-processing, plasticity, branches/verifier) go through it.
    """

    # Route ids as produced by ``RoutingDecision`` (see the route comment in
    # forward()).
    ROUTE_FORWARD = 0
    ROUTE_BRANCH = 1
    ROUTE_BACKWARD = 2
    ROUTE_PLASTIC = 3

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([
            TransformerBlock(cfg, i) for i in range(cfg.n_layers)
        ])
        self.ln_final = nn.LayerNorm(cfg.d_model)
        # One equilibrium signal per layer.
        self.eq_signals = nn.ModuleList([
            EquilibriumSignal(cfg.d_model, momentum=cfg.eq_momentum, warmup_steps=cfg.eq_warmup_steps)
            for _ in range(cfg.n_layers)
        ])
        self.router = RoutingDecision(
            target_fractions=(
                cfg.route_forward_target,
                cfg.route_branch_target,
                cfg.route_backward_target,
                cfg.route_plastic_target,
            ),
            threshold_momentum=cfg.route_threshold_momentum,
            temperature=cfg.route_temperature,
            offset_scale=cfg.route_threshold_offset_scale,
        )
        # NOTE(review): constructed here but not referenced by forward().
        self.energy = TokenEnergyBudget()
        # Stage A modules; forward() applies them only to routed tokens.
        if cfg.use_plastic:
            self.plastic = PlasticLayer(cfg)
        if cfg.use_branches:
            self.branch_router = BranchRouter(cfg)
        if cfg.use_verifier and cfg.use_branches:
            self.verifier = Verifier(cfg)
        # Single lm_head for non-branched tokens.
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: torch.Tensor | None = None,
    ) -> dict:
        """Run the model, routing tokens after every block.

        Args:
            input_ids: token ids of shape [B, T]. T must not exceed
                ``cfg.max_seq_len``, because positions index a learned
                embedding table of that size.
            targets: optional token ids of shape [B, T]. When given, the
                cross-entropy loss is computed. When branches are enabled,
                the weighted diversity loss is added to it.

        Returns:
            A dict containing ``logits`` [B, T, V], ``hidden`` (the final
            layer-normed states), ``route_stats`` (one dict per layer),
            ``last_route_probs`` and ``last_route_thresholds``. It also holds
            ``diversity_loss``/``branch_logits`` when branches ran, and
            ``loss``/``ce_loss`` when ``targets`` was given.
        """
        B, T = input_ids.shape
        pos = torch.arange(T, device=input_ids.device).unsqueeze(0)
        x = self.drop(self.tok_emb(input_ids) + self.pos_emb(pos))

        result = {}
        layer_outputs = [x]
        all_route_stats = []

        for i, block in enumerate(self.blocks):
            # Normal forward pass through the block.
            x = block(x, layer_outputs)
            layer_outputs.append(x)

            # Equilibrium signal and routing for this layer.
            eq_out = self.eq_signals[i](x)
            ed = eq_out["ed"]  # [B, T]
            route_preview = self.router(ed)
            thresholds = route_preview["thresholds"]
            route_probs = route_preview["route_probs"]
            if eq_out.get("warming_up", False):
                # Warmup: force every token forward (accumulate stats only).
                # The signal and router have still been called above.
                route = torch.zeros(B, T, dtype=torch.long, device=x.device)
            else:
                route = route_preview["route"]  # [B, T]: 0=fwd, 1=branch, 2=back, 3=plastic

            all_route_stats.append(self._route_stats(route, ed, thresholds))

            # === Apply routing effects (backward first, then plastic) ===
            if i > 0:
                # layer_outputs[:i] is exactly the history block i-1 saw.
                x = self._reprocess_backward(x, route, layer_outputs[:i], self.blocks[i - 1])
            if self.cfg.use_plastic:
                x = self._apply_plastic(x, route)
            # Later blocks must see the routed version of this layer's output.
            layer_outputs[-1] = x

        hidden = self.ln_final(x)

        # === Output heads (decided by the final layer's route) ===
        logits = self._output_logits(hidden, route, result)

        result["logits"] = logits
        result["route_stats"] = all_route_stats
        result["last_route_probs"] = route_probs
        result["last_route_thresholds"] = thresholds
        result["hidden"] = hidden

        if targets is not None:
            # reshape (not view): targets are often non-contiguous slices
            # such as ids[:, 1:].
            ce_loss = F.cross_entropy(
                logits.reshape(B * T, -1), targets.reshape(B * T)
            )
            total_loss = ce_loss
            if self.cfg.use_branches and "diversity_loss" in result:
                total_loss = total_loss + self.cfg.diversity_weight * result["diversity_loss"]
            result["loss"] = total_loss
            result["ce_loss"] = ce_loss
        return result

    @staticmethod
    def _route_stats(route: torch.Tensor, ed: torch.Tensor, thresholds) -> dict:
        """Summarize one layer's routing as plain Python numbers for logging."""
        total = route.numel()
        # A single device->host transfer for all four route counts.
        counts = torch.stack([(route == k).sum() for k in range(4)]).tolist()
        return {
            "forward": counts[0] / total,
            "branch": counts[1] / total,
            "backward": counts[2] / total,
            "plastic": counts[3] / total,
            "mean_ed": ed.mean().item(),
            "theta1": thresholds[0].item(),
            "theta2": thresholds[1].item(),
            "theta3": thresholds[2].item(),
        }

    @staticmethod
    def _scatter_tokens(
        dest: torch.Tensor,
        b_idx: torch.Tensor,
        t_idx: torch.Tensor,
        values: torch.Tensor,
    ) -> torch.Tensor:
        """Return a copy of ``dest`` whose rows [b_idx, t_idx] are ``values``.

        The write is out-of-place so that a tensor other ops already consumed,
        and may have saved for autograd, is never mutated. ``values`` is cast
        to ``dest.dtype`` because index-put requires matching dtypes, which
        can differ under autocast.
        """
        return dest.index_put((b_idx, t_idx), values.to(dest.dtype))

    def _reprocess_backward(
        self,
        x: torch.Tensor,
        route: torch.Tensor,
        prev_outputs: list,
        prev_block,
    ) -> torch.Tensor:
        """Re-run backward-routed tokens through the previous block.

        Each selected token is fed as its own length-1 sequence [N, 1, D],
        together with its own entries from ``prev_outputs``. The block
        therefore sees no other positions during this re-processing.
        """
        back_mask = route == self.ROUTE_BACKWARD
        if not back_mask.any():
            return x
        b_idx, t_idx = back_mask.nonzero(as_tuple=True)
        tokens = x[b_idx, t_idx].unsqueeze(1)  # [N, 1, D]
        prev_for_tokens = [lo[b_idx, t_idx].unsqueeze(1) for lo in prev_outputs]
        reprocessed = prev_block(tokens, prev_for_tokens)
        return self._scatter_tokens(x, b_idx, t_idx, reprocessed.squeeze(1))

    def _apply_plastic(self, x: torch.Tensor, route: torch.Tensor) -> torch.Tensor:
        """Pass plastic-routed tokens through ``self.plastic`` as [N, 1, D]."""
        plastic_mask = route == self.ROUTE_PLASTIC
        if not plastic_mask.any():
            return x
        b_idx, t_idx = plastic_mask.nonzero(as_tuple=True)
        adapted = self.plastic(x[b_idx, t_idx].unsqueeze(1))
        return self._scatter_tokens(x, b_idx, t_idx, adapted.squeeze(1))

    def _output_logits(self, hidden: torch.Tensor, route: torch.Tensor, result: dict) -> torch.Tensor:
        """Compute [B, T, V] logits from the final hidden states.

        Branch-routed tokens (per the final layer's route) get branch logits,
        passed through the verifier when enabled. All other tokens use
        ``lm_head``. When branches are enabled, this method records
        ``diversity_loss`` (and ``branch_logits`` if any token branched) into
        ``result``.
        """
        logits = self.lm_head(hidden)  # [B, T, V]
        if not self.cfg.use_branches:
            return logits

        branch_mask = route == self.ROUTE_BRANCH
        if not branch_mask.any():
            result["diversity_loss"] = torch.tensor(0.0, device=hidden.device)
            return logits

        b_idx, t_idx = branch_mask.nonzero(as_tuple=True)
        branch_out = self.branch_router(hidden[b_idx, t_idx].unsqueeze(1))
        result["diversity_loss"] = branch_out["diversity_loss"]
        result["branch_logits"] = branch_out["branch_logits"]
        if self.cfg.use_verifier:
            branch_logits = self.verifier(branch_out["branch_logits"])["logits"]  # [N, 1, V]
        else:
            branch_logits = branch_out["logits"]
        # Override the branch positions with their branch logits.
        return self._scatter_tokens(logits, b_idx, t_idx, branch_logits.squeeze(1))

    def param_count(self) -> int:
        """Total number of parameters (trainable or not)."""
        return sum(p.numel() for p in self.parameters())

    def param_count_str(self) -> str:
        """Parameter count formatted as e.g. '12.3M' or '512.0K'."""
        n = self.param_count()
        if n >= 1_000_000:
            return f"{n / 1_000_000:.1f}M"
        return f"{n / 1_000:.1f}K"