# Uploaded by: vigneshwar234
# Commit message: Add source: tmt/model/memory.py
# Commit: 18cff1d (verified)
"""
memory.py — MemoryAnchorCross: persistent cross-attention to global memory nodes.
Novel vs standard: vanilla transformers have no persistent state across forward
passes. MemoryAnchorCross maintains `cfg.memory_anchors` learnable vectors,
stored as a single (M, D) nn.Parameter, as global "anchor" nodes that every
token can attend to. After each forward pass the anchors are updated via an
exponential moving average (EMA) toward the mean of the current token
representations (averaged over batch and sequence), giving the model a form of
fast-weight memory without recurrence.
"""
from __future__ import annotations
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor
from .config import TMTConfig
class MemoryAnchorCross(nn.Module):
    """Cross-attention from the token stream to persistent memory anchor nodes.

    Tokens supply the queries; keys and values are projections of a learnable
    ``(M, D)`` anchor matrix ``self.memory``. After attention has been
    computed, the anchors are moved toward the mean token representation of
    the current input by an exponential moving average (EMA), outside of
    autograd.
    """

    def __init__(self, cfg: "TMTConfig") -> None:
        """Build the anchors and the bias-free Q/K/V/output projections.

        Args:
            cfg: configuration object; this class reads ``d_model``,
                ``n_heads``, ``memory_anchors`` and ``dropout`` from it.

        Raises:
            ValueError: if ``cfg.n_heads`` is not positive or does not divide
                ``cfg.d_model`` (the per-head width would be ill-defined).
        """
        super().__init__()
        if cfg.n_heads <= 0 or cfg.d_model % cfg.n_heads != 0:
            raise ValueError(
                f"d_model ({cfg.d_model}) must be a positive multiple of "
                f"n_heads ({cfg.n_heads})"
            )
        self.d_model = cfg.d_model
        self.n_heads = cfg.n_heads
        self.d_head = cfg.d_model // cfg.n_heads
        self.n_anchors = cfg.memory_anchors
        self.ema_alpha = 0.9  # EMA decay: weight kept on the previous anchors

        # Persistent memory anchors, shape (M, D), small random init.
        self.memory = nn.Parameter(
            torch.randn(cfg.memory_anchors, cfg.d_model) * 0.02
        )
        self.q_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.out_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.dropout = nn.Dropout(cfg.dropout)

    def _split_heads(self, t: Tensor) -> Tensor:
        """Reshape ``(B, N, D)`` to ``(B, H, N, D // H)`` (head-major split of D)."""
        lead, length, _ = t.shape
        return t.reshape(lead, length, self.n_heads, self.d_head).transpose(1, 2)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Attend from tokens to the memory anchors, then EMA-update the anchors.

        Args:
            x: (B, S, D) token representations.

        Returns:
            out: (B, S, D) result of the memory cross-attention followed by
                the output projection.
            memory_state: (M, D) anchors after the EMA update, detached from
                autograd (intended for logging).
        """
        batch, seq_len, _ = x.shape
        scale = self.d_head ** -0.5

        # Queries come from tokens.
        q = self._split_heads(self.q_proj(x))  # (B, H, S, d)

        # Keys/values come from the anchors. They do not depend on the batch,
        # so project them once with a singleton batch dim and let matmul
        # broadcast, instead of projecting B identical copies.
        # Going through a view (not the Parameter object itself) matters
        # because the EMA step below swaps ``self.memory.data``; tensors saved
        # for backward should keep referring to the pre-update values.
        anchors = self.memory.unsqueeze(0)  # (1, M, D)
        k = self._split_heads(self.k_proj(anchors))  # (1, H, M, d)
        v = self._split_heads(self.v_proj(anchors))  # (1, H, M, d)

        # Attention over memory anchors: (B, H, S, M)
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        ctx = torch.matmul(attn, v)  # (B, H, S, d)
        out = ctx.transpose(1, 2).reshape(batch, seq_len, self.d_model)
        out = self.out_proj(out)

        # EMA update of the anchors toward the mean token representation.
        # Skipped for empty input: the mean over an empty batch or sequence is
        # NaN and would permanently corrupt the anchors.
        # NOTE(review): this runs in eval mode as well, so inference mutates
        # the parameter — confirm that is intended before gating on
        # ``self.training``.
        if x.numel() > 0:
            with torch.no_grad():
                # Same target for every anchor; cast so the parameter's dtype
                # cannot change when ``.data`` is reassigned.
                token_mean = x.mean(dim=(0, 1)).to(self.memory.dtype)  # (D,)
                self.memory.data = (
                    self.ema_alpha * self.memory.data
                    + (1 - self.ema_alpha) * token_mean.unsqueeze(0)
                )

        return out, self.memory.detach()

    def __repr__(self) -> str:
        """Return a compact summary: anchor count and total parameter count."""
        n_params = sum(param.numel() for param in self.parameters())
        return f"MemoryAnchorCross(anchors={self.n_anchors}, params={n_params:,})"