| """Metapragmatic Attention Head (MAH). |
| |
| Detects where meaning diverges across positions by computing the gap between |
| direct (local) interpretation and contextual (global) interpretation of each |
| token's hidden state. This is Peirce's "unlimited semiosis" made computational: |
| each sign (representamen) receives an interpretation (interpretant) that depends |
| on the surrounding discourse context. MAH quantifies where that context |
| *changes* the interpretation — i.e., where meaning forks. |
| |
| The divergence vector d_t at position t captures: |
| d_t = f(interp_t) - g(attend(interp_{0..t})) |
| where f is direct projection, g is the contextual output after causal attention. |
| High ||d_t|| → the sign at position t means something different in context |
| than it would in isolation. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from srt.config import MAHConfig |
|
|
|
|
@dataclass
class MAHOutput:
    """Result bundle returned by one Metapragmatic Attention Head layer.

    Attributes:
        divergence: Per-position divergence vectors. As produced by
            ``MetapragmaticAttentionHead.forward`` this has shape
            ``(B, T, d_divergence)``.
        attention_weights: Attention probabilities used to build the
            contextual interpretation, or ``None`` when not supplied. As
            produced by ``MetapragmaticAttentionHead.forward`` this has shape
            ``(B, num_heads, T, T)``, is taken after attention dropout, and
            is detached from the autograd graph.
    """

    # Required: the divergence signal itself.
    divergence: torch.Tensor
    # Optional diagnostic payload; defaults to "not provided".
    attention_weights: torch.Tensor | None = None
|
|
|
|
class MetapragmaticAttentionHead(nn.Module):
    """Single MAH layer that reads hidden states and produces divergence vectors.

    The layer projects backbone hidden states into a ``d_sub``-dimensional
    "interpretation" space, optionally shifts them by a projected community
    vector, runs multi-head self-attention over those interpretations, and
    finally projects the difference between the direct interpretation and the
    attended (contextual) interpretation down to ``d_divergence`` dimensions.
    """

    def __init__(self, cfg: MAHConfig, d_backbone: int, d_community: int = 0) -> None:
        """Build the projection and attention sub-modules.

        Args:
            cfg: Configuration object; the fields ``d_sub``, ``num_heads``,
                ``dropout`` and ``d_divergence`` are read.
            d_backbone: Feature size of the incoming hidden states.
            d_community: Feature size of the community vector. When ``0``
                (the default) no community projection is created.

        Raises:
            ValueError: If ``cfg.num_heads`` is not positive or does not
                evenly divide ``cfg.d_sub``.
        """
        super().__init__()
        d_sub = cfg.d_sub
        num_heads = cfg.num_heads

        # Validate with a real exception rather than ``assert``: assertions
        # are removed under ``python -O``, which would let a bad config
        # through with a silently truncated ``head_dim``. Checking before the
        # integer division also turns ``num_heads == 0`` into a clear error
        # instead of a ZeroDivisionError.
        if num_heads <= 0 or d_sub % num_heads != 0:
            raise ValueError(
                f"d_sub ({d_sub}) must be divisible by a positive "
                f"num_heads (got {num_heads})"
            )

        # Direct (local) interpretation of each position's hidden state.
        self.interp_proj = nn.Linear(d_backbone, d_sub, bias=False)

        # Optional projection of a community vector into interpretation space.
        self.comm_proj: nn.Module | None = None
        if d_community > 0:
            self.comm_proj = nn.Linear(d_community, d_sub, bias=False)

        # Multi-head attention over the interpretations.
        self.num_heads = num_heads
        self.head_dim = d_sub // num_heads

        self.q_proj = nn.Linear(d_sub, d_sub, bias=False)
        self.k_proj = nn.Linear(d_sub, d_sub, bias=False)
        self.v_proj = nn.Linear(d_sub, d_sub, bias=False)
        self.out_proj = nn.Linear(d_sub, d_sub, bias=False)
        self.attn_dropout = nn.Dropout(cfg.dropout)

        # Maps the (direct - contextual) gap to the divergence space.
        self.div_proj = nn.Linear(d_sub, cfg.d_divergence, bias=False)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape ``(B, T, d_sub)`` into ``(B, num_heads, T, head_dim)``."""
        B, T, _ = x.shape
        return x.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        community_vec: torch.Tensor | None = None,
        causal_mask: torch.Tensor | None = None,
    ) -> MAHOutput:
        """Compute divergence from backbone hidden states.

        Args:
            hidden_states: (B, T, d_backbone) from a transformer layer.
            community_vec: (B, d_community) soft community vector. Ignored
                when it is ``None`` or when the layer was built without a
                community projection (``d_community == 0``).
            causal_mask: (1, 1, T, T) additive causal mask; it is added to
                the scaled attention scores before the softmax. When ``None``
                no mask is applied and every position attends to all others.

        Returns:
            MAHOutput with ``divergence`` of shape (B, T, d_divergence) and
            ``attention_weights`` of shape (B, num_heads, T, T). The weights
            are the post-dropout probabilities, detached from the graph.
        """
        B, T, _ = hidden_states.shape

        # Direct interpretation of every position.
        interp = self.interp_proj(hidden_states)

        # Shift all positions by the same community-dependent bias.
        if community_vec is not None and self.comm_proj is not None:
            comm_bias = self.comm_proj(community_vec)
            interp = interp + comm_bias.unsqueeze(1)

        # Queries, keys and values all come from the interpretations.
        q = self._split_heads(self.q_proj(interp))
        k = self._split_heads(self.k_proj(interp))
        v = self._split_heads(self.v_proj(interp))

        # Scaled dot-product attention scores: (B, num_heads, T, T).
        attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if causal_mask is not None:
            attn = attn + causal_mask
        attn_weights = F.softmax(attn, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)

        # Merge heads back to (B, T, d_sub) and mix them.
        contextual = (attn_weights @ v).transpose(1, 2).reshape(B, T, -1)
        contextual = self.out_proj(contextual)

        # Divergence is the projected gap between direct and contextual views.
        divergence = self.div_proj(interp - contextual)

        return MAHOutput(divergence=divergence, attention_weights=attn_weights.detach())
|
|