File size: 3,376 Bytes
18cff1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""
memory.py — MemoryAnchorCross: persistent cross-attention to global memory nodes.

Novel vs standard: vanilla transformers have no persistent state across forward
passes.  MemoryAnchorCross maintains a learnable nn.Parameter of global "anchor"
vectors (``cfg.memory_anchors`` of them, e.g. 16) that every token can attend
to.  After each forward pass the anchors are updated via an exponential moving
average (EMA) toward the mean of the current token representations, giving the
model a form of fast-weight memory without recurrence.
"""
from __future__ import annotations

from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor

from .config import TMTConfig


class MemoryAnchorCross(nn.Module):
    """Cross-attention from the token stream to persistent memory anchor nodes.

    Tokens supply the queries; a learnable ``(M, D)`` anchor matrix supplies
    the keys and values.  At the end of every ``forward`` call the anchors are
    blended toward the mean token representation with an exponential moving
    average whose decay is ``ema_alpha``.
    """

    def __init__(self, cfg: TMTConfig) -> None:
        """Create the anchor parameter, the four projections and the dropout.

        Args:
            cfg: configuration object; ``d_model``, ``n_heads``,
                ``memory_anchors`` and ``dropout`` are read from it.

        Raises:
            ValueError: if ``cfg.d_model`` is not divisible by ``cfg.n_heads``.
        """
        super().__init__()
        # Fail at construction time instead of with a reshape error inside
        # forward(): the head split below needs n_heads * d_head == d_model.
        if cfg.d_model % cfg.n_heads != 0:
            raise ValueError(
                f"d_model ({cfg.d_model}) must be divisible by "
                f"n_heads ({cfg.n_heads})"
            )

        self.d_model = cfg.d_model
        self.n_heads = cfg.n_heads
        self.d_head = cfg.d_model // cfg.n_heads
        self.n_anchors = cfg.memory_anchors
        self.ema_alpha = 0.9  # EMA decay for memory update

        # Persistent memory parameters — shape (M, D)
        self.memory = nn.Parameter(
            torch.randn(cfg.memory_anchors, cfg.d_model) * 0.02
        )

        self.q_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.out_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)

        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Attend from tokens to the memory anchors, then EMA-update the anchors.

        Args:
            x: (B, S, D) token representations
        Returns:
            out:          (B, S, D) tokens enhanced by memory cross-attention
            memory_state: (M, D) updated memory anchors (detached for logging)
        Raises:
            ValueError: if ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected x of shape (B, S, D), got {tuple(x.shape)}"
            )
        B, S, _ = x.shape
        H, d = self.n_heads, self.d_head
        M = self.memory.shape[0]
        scale = d ** -0.5

        # Queries come from tokens, Keys/Values from memory anchors.
        # NOTE(review): K/V are projected from an expanded *view* of the
        # parameter rather than from the parameter itself.  Keep it that way
        # unless the `.data` rebinding below is revisited — autograd may hold
        # on to the exact tensor object passed to a projection.
        mem = self.memory.unsqueeze(0).expand(B, -1, -1)  # (B, M, D)

        # Split the model dimension into heads: (B, L, H*d) -> (B, H, L, d).
        # Explicit sizes (no -1) keep the reshape valid for empty inputs.
        Q = self.q_proj(x).reshape(B, S, H, d).transpose(1, 2)
        K = self.k_proj(mem).reshape(B, M, H, d).transpose(1, 2)
        V = self.v_proj(mem).reshape(B, M, H, d).transpose(1, 2)

        # Attention over memory anchors: (B, H, S, M)
        attn = torch.einsum("bhsd,bhmd->bhsm", Q, K) * scale
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        out = torch.einsum("bhsm,bhmd->bhsd", attn, V)  # (B, H, S, d)
        out = out.transpose(1, 2).reshape(B, S, H * d)  # merge heads again
        out = self.out_proj(out)

        # EMA update of memory anchors using the mean token representation.
        # Skipped for empty inputs: the mean of zero tokens is NaN and would
        # permanently poison the anchors.
        # NOTE(review): this also runs in eval mode, so inference calls mutate
        # the anchors — confirm that is intended.
        if x.numel() > 0:
            with torch.no_grad():
                token_mean = x.mean(dim=1).mean(dim=0)  # (D,) across batch
                # Keep the parameter's dtype stable even when the activations
                # arrive in a different precision.
                token_mean = token_mean.to(self.memory.dtype)
                self.memory.data = (
                    self.ema_alpha * self.memory.data
                    + (1 - self.ema_alpha) * token_mean.unsqueeze(0)
                )

        return out, self.memory.detach()

    def __repr__(self) -> str:
        """Return a short summary: anchor count and total parameter count."""
        n_params = sum(t.numel() for t in self.parameters())
        return f"MemoryAnchorCross(anchors={self.n_anchors}, params={n_params:,})"