| """Reflexive Recurrent Module (RRM). |
| |
| Tracks per-position semiotic meta-state via a GRU that processes divergence |
| observations from MAH layers. At injection points, produces a FiLM-style |
modulation (gamma, beta) that multiplicatively + additively biases the
backbone's hidden states, scaled by s = cfg.inject_scale:

    h' = h * (1 + s * gamma(meta_state)) + s * beta(meta_state)
| |
| The meta-state h_meta_t represents the model's accumulated awareness of |
| semiotic divergence at position t. Each MAH observation updates it: |
| h_meta_t^{l+1} = GRU(divergence_t^l, h_meta_t^l) |
| |
| v3 used a single low-rank linear inject (gate * proj * scale) with |
| zero-initialized projection. Ablation showed the inject-back arm contributed |
| exactly nothing (every benchmark metric was identical to four decimal places |
| with injection forced to zero). The diagnosis was that the zero init plus |
the inject-norm regularizer (which rewarded ||inj|| ≈ 1 regardless of
| direction) drove the optimizer to satisfy the norm penalty with arbitrary |
| directions that were then orthogonal to the gradient signal from the frozen |
| backbone's CE. |
| |
v4 fixes both: FiLM modulation has a non-zero gradient pathway from the first
step. beta_proj is zero-initialized, and gamma_proj has a small Gaussian
weight init (std 0.02) with a zero bias, so the forward pass is close to --
but not exactly -- the identity at init, while dL/d(gamma_proj) flows
immediately when the downstream MAH layer's divergence is supervised by the
bif/regime losses.
| The inject-norm regularizer is dropped at the loss layer (LossConfig |
| inject_reg_weight = 0.0 by default in v4). |
| """ |
|
|
| from __future__ import annotations |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from srt.config import RRMConfig |
|
|
|
|
class ReflexiveRecurrentModule(nn.Module):
    """GRU-based reflexive meta-state tracker with FiLM-style injection.

    Two entry points:

    * ``step`` folds one divergence observation into the per-position
      meta-state using a single shared ``nn.GRUCell``.
    * ``inject`` turns a meta-state into an additive correction for the
      backbone hidden states, ``s * (h * gamma + beta)`` with
      ``s = cfg.inject_scale``.

    Args:
        cfg: RRM config. This module reads ``cfg.d_meta`` (GRU hidden size)
            and ``cfg.inject_scale`` (multiplier applied to the correction).
        d_divergence: Feature size of the divergence observations passed to
            ``step``.
        d_backbone: Hidden size of the backbone states passed to ``inject``.
    """

    def __init__(self, cfg: RRMConfig, d_divergence: int, d_backbone: int) -> None:
        super().__init__()
        self.d_meta = cfg.d_meta
        self.inject_scale = cfg.inject_scale
        self.d_backbone = d_backbone

        # One GRU cell shared by all positions. ``step`` flattens positions
        # into the GRU batch, so each position evolves independently.
        self.gru = nn.GRUCell(d_divergence, cfg.d_meta)

        # FiLM heads. beta starts exactly at zero. gamma gets a small Gaussian
        # weight init, so the correction is small (but non-zero) at init while
        # gradients through gamma_proj reach the meta-state from the first step.
        # NOTE: creation/init order is kept fixed so seeded runs reproduce.
        self.gamma_proj = nn.Linear(cfg.d_meta, d_backbone, bias=True)
        self.beta_proj = nn.Linear(cfg.d_meta, d_backbone, bias=True)
        nn.init.normal_(self.gamma_proj.weight, std=0.02)
        nn.init.zeros_(self.gamma_proj.bias)
        nn.init.zeros_(self.beta_proj.weight)
        nn.init.zeros_(self.beta_proj.bias)

    def step(
        self, divergence: torch.Tensor, meta_state: torch.Tensor | None
    ) -> torch.Tensor:
        """Update the per-position meta-state with a new divergence observation.

        All leading dimensions are flattened into the GRU batch, so each
        position is updated independently (no mixing across positions).

        Args:
            divergence: (B, T, d_divergence) from MAH. Any number of leading
                dimensions is accepted (e.g. (B, d_divergence)); only the last
                dimension is treated as the feature axis.
            meta_state: Previous meta-state with the same leading dimensions
                as ``divergence`` and last dimension d_meta, or None to start
                from an all-zero state.

        Returns:
            Updated meta-state of shape (*divergence.shape[:-1], d_meta),
            i.e. (B, T, d_meta) for the usual 3-D input.
        """
        lead = divergence.shape[:-1]
        # Explicit count (not -1) so a mis-sized meta_state still fails loudly
        # in reshape, and zero-size dimensions stay unambiguous.
        n_positions = lead.numel()
        div_flat = divergence.reshape(n_positions, divergence.shape[-1])

        if meta_state is None:
            # Zero initial state on the observation's device/dtype.
            meta_flat = torch.zeros(
                n_positions,
                self.d_meta,
                device=divergence.device,
                dtype=divergence.dtype,
            )
        else:
            meta_flat = meta_state.reshape(n_positions, self.d_meta)

        meta_flat = self.gru(div_flat, meta_flat)
        return meta_flat.reshape(*lead, self.d_meta)

    def inject(
        self, meta_state: torch.Tensor, hidden_states: torch.Tensor
    ) -> torch.Tensor:
        """Produce the FiLM modulation correction for backbone hidden states.

        Returns the additive *correction* ``s * (h * gamma + beta)`` with
        ``s = self.inject_scale``, NOT the modulated states. The caller adds it
        to ``h``, giving ``h' = h * (1 + s * gamma) + s * beta``. Returning the
        correction keeps the rest of the adapter and the diagnostic logging
        (injection norm tracking) unchanged: the "injection" tensor is still
        the additive correction applied to h.

        Args:
            meta_state: (B, T, d_meta) current reflexive awareness.
            hidden_states: (B, T, d_backbone) current hidden states.

        Returns:
            Correction (B, T, d_backbone) to add to hidden_states.
        """
        gamma = self.gamma_proj(meta_state)
        beta = self.beta_proj(meta_state)
        return (hidden_states * gamma + beta) * self.inject_scale
|
|