vigneshwar234 committed on
Commit
18cff1d
·
verified ·
1 Parent(s): 9d0034a

Add source: tmt/model/memory.py

Browse files
Files changed (1) hide show
  1. tmt/model/memory.py +90 -0
tmt/model/memory.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ memory.py — MemoryAnchorCross: persistent cross-attention to global memory nodes.
3
+
4
+ Novel vs standard: vanilla transformers have no persistent state across forward
5
+ passes. MemoryAnchorCross maintains cfg.memory_anchors (e.g. 16) learnable nn.Parameter rows as
6
+ global "anchor" nodes that every token can attend to. After each forward pass
7
+ the anchors are updated via exponential moving average (EMA) of the current
8
+ token representations, giving the model a form of fast-weight memory without
9
+ recurrence.
10
+ """
11
+ from __future__ import annotations
12
+
13
+ from typing import Tuple
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from einops import rearrange
19
+ from torch import Tensor
20
+
21
+ from .config import TMTConfig
22
+
23
+
24
class MemoryAnchorCross(nn.Module):
    """Cross-attention from the token stream to persistent memory anchor nodes.

    Tokens supply the queries; keys and values are projections of a learnable
    ``(M, D)`` memory matrix (``M = cfg.memory_anchors``, ``D = cfg.d_model``).
    After attention is computed, the memory is nudged towards the mean token
    representation of the current batch with an exponential moving average.

    The memory is an ``nn.Parameter``, so it also receives gradients through
    the key/value projections in addition to the EMA update.

    NOTE(review): the EMA update runs on every forward call, including in
    ``eval()`` mode, so inference mutates the module's parameters — confirm
    this is the intended "fast-weight" behaviour.
    NOTE(review): every anchor row receives the same EMA term (the batch-wide
    token mean is broadcast over all ``M`` rows) — confirm this is intended.
    """

    def __init__(self, cfg: TMTConfig) -> None:
        """Build the projections and the memory matrix from ``cfg``.

        Args:
            cfg: configuration providing ``d_model``, ``n_heads``,
                ``memory_anchors`` and ``dropout``.

        Raises:
            ValueError: if ``cfg.d_model`` is not divisible by ``cfg.n_heads``.
        """
        super().__init__()
        # Fail at construction time instead of deep inside the first forward.
        if cfg.d_model % cfg.n_heads != 0:
            raise ValueError(
                f"d_model ({cfg.d_model}) must be divisible by "
                f"n_heads ({cfg.n_heads})"
            )

        self.d_model = cfg.d_model
        self.n_heads = cfg.n_heads
        self.d_head = cfg.d_model // cfg.n_heads
        self.n_anchors = cfg.memory_anchors
        self.ema_alpha = 0.9  # EMA decay for memory update

        # Persistent memory parameters — shape (M, D)
        self.memory = nn.Parameter(
            torch.randn(cfg.memory_anchors, cfg.d_model) * 0.02
        )

        self.q_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.out_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)

        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Attend from tokens to the memory anchors, then update the anchors.

        Args:
            x: (B, S, D) token representations.

        Returns:
            out: (B, S, D) tokens enhanced by memory cross-attention.
            memory_state: (M, D) updated memory anchors (detached for logging).

        Raises:
            ValueError: if ``x`` is not a 3-D tensor.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected x of shape (B, S, D), got {tuple(x.shape)}"
            )

        scale = self.d_head ** -0.5
        head_shape = (self.n_heads, self.d_head)

        # Queries come from tokens: (B, S, D) -> (B, H, S, d)
        Q = self.q_proj(x).unflatten(-1, head_shape).transpose(1, 2)
        # Keys/values come from the memory anchors. They do not depend on the
        # batch, so project once and let matmul broadcast over B:
        # (M, D) -> (H, M, d)
        K = self.k_proj(self.memory).unflatten(-1, head_shape).transpose(0, 1)
        V = self.v_proj(self.memory).unflatten(-1, head_shape).transpose(0, 1)

        # Attention over memory anchors: (B, H, S, M)
        attn = torch.matmul(Q, K.transpose(-2, -1)) * scale
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        out = torch.matmul(attn, V)            # (B, H, S, d)
        out = out.transpose(1, 2).flatten(2)   # (B, S, D)
        out = self.out_proj(out)

        # EMA update of memory anchors using the mean token representation.
        # Skipped for empty input: the mean of zero elements is NaN and would
        # permanently poison every anchor.
        if x.numel() > 0:
            with torch.no_grad():
                token_mean = x.mean(dim=1).mean(dim=0)  # (D,) across batch
                # Keep the parameter's dtype stable regardless of x's dtype.
                token_mean = token_mean.to(self.memory.dtype)
                # Rebind .data rather than modifying in place: an in-place
                # write would bump the version counter of a tensor the
                # autograd graph built above still needs for backward.
                self.memory.data = (
                    self.ema_alpha * self.memory.data
                    + (1 - self.ema_alpha) * token_mean.unsqueeze(0)
                )

        return out, self.memory.detach()

    def __repr__(self) -> str:
        """Return a compact summary with anchor count and parameter total."""
        p = sum(p.numel() for p in self.parameters())
        return f"MemoryAnchorCross(anchors={self.n_anchors}, params={p:,})"