| |
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
@dataclass
class AnchorMemoryConfig:
    """Hyperparameters for the anchor-memory components below."""

    d_model: int  # embedding width of the token stream and of each anchor
    heads: int  # attention heads used when reading from the anchor bank
    anchor_stride: int = 256  # tokens pooled into a single anchor vector
    max_anchors: int = 2048  # cap on bank length; oldest anchors are evicted first
    dropout: float = 0.0  # attention dropout applied inside the memory read
|
|
|
|
class AnchorCompressor(nn.Module):
    """Compress local token spans into trainable anchor vectors.

    The sequence is split into non-overlapping windows of ``anchor_stride``
    tokens. Each window is reduced to one vector via learned softmax pooling
    over per-token scores, then refined with a residual MLP.
    """

    def __init__(self, d_model: int, anchor_stride: int):
        super().__init__()
        self.anchor_stride = anchor_stride
        # Scalar relevance score per token; used as softmax pooling logits.
        self.score = nn.Linear(d_model, 1)
        # Residual refinement applied to each pooled anchor.
        self.mix = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` of shape (batch, seq, d_model) into
        (batch, ceil(seq / anchor_stride), d_model) anchors."""
        bsz, seq, dim = x.shape
        pad = (-seq) % self.anchor_stride
        if pad:
            # Zero-pad the tail so the sequence splits into whole windows.
            x = F.pad(x, (0, 0, 0, pad))
        chunks = x.view(bsz, -1, self.anchor_stride, dim)
        logits = self.score(chunks)
        if pad:
            # BUGFIX: pad positions previously received softmax mass (score of a
            # zero vector is the linear layer's bias), diluting the weights of
            # real tokens in the final partial window. Exclude them explicitly.
            pad_mask = torch.zeros(x.size(1), dtype=torch.bool, device=x.device)
            pad_mask[seq:] = True
            pad_mask = pad_mask.view(1, -1, self.anchor_stride, 1)
            logits = logits.masked_fill(pad_mask, float("-inf"))
        weights = logits.softmax(dim=2)
        pooled = (chunks * weights).sum(dim=2)
        return pooled + self.mix(pooled)
|
|
|
|
class AnchorMemoryLayer(nn.Module):
    """Local-token stream reads from a bounded bank of learned anchors.

    Each forward pass compresses the current segment into new anchors,
    appends them to the running bank (evicting the oldest beyond
    ``max_anchors``), and lets every token attend over the bank. The
    recalled vector is blended into the residual stream through a
    learned sigmoid gate.
    """

    def __init__(self, cfg: AnchorMemoryConfig):
        super().__init__()
        self.cfg = cfg
        self.compress = AnchorCompressor(cfg.d_model, cfg.anchor_stride)
        self.q_ln = nn.LayerNorm(cfg.d_model)
        self.mem_ln = nn.LayerNorm(cfg.d_model)
        self.read = nn.MultiheadAttention(
            cfg.d_model,
            cfg.heads,
            dropout=cfg.dropout,
            batch_first=True,
        )
        # Per-dimension blend factor computed from [token, recalled].
        self.gate = nn.Sequential(nn.Linear(2 * cfg.d_model, cfg.d_model), nn.Sigmoid())
        self.out_ln = nn.LayerNorm(cfg.d_model)

    def forward(
        self,
        x: torch.Tensor,
        memory: torch.Tensor | None = None,
        *,
        detach_memory: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(output, bank)``.

        Args:
            x: token stream, shape (batch, seq, d_model).
            memory: anchor bank carried over from the previous segment, or
                ``None`` to start a fresh bank.
            detach_memory: detach the newly compressed anchors so gradients do
                not flow into the compressor through the bank (e.g. truncated
                BPTT across segments).
        """
        new_anchors = self.compress(x)
        if detach_memory:
            new_anchors = new_anchors.detach()
        if memory is None:
            bank = new_anchors
        else:
            bank = torch.cat([memory, new_anchors], dim=1)
        if bank.size(1) > self.cfg.max_anchors:
            # FIFO eviction: keep only the most recent anchors.
            bank = bank[:, -self.cfg.max_anchors :]

        # Hoisted: the normalized bank was previously computed twice
        # (once for keys and once for values).
        mem = self.mem_ln(bank)
        recalled, _ = self.read(self.q_ln(x), mem, mem, need_weights=False)
        gate = self.gate(torch.cat([x, recalled], dim=-1))
        mixed = x + gate * recalled
        return self.out_ln(mixed), bank
|
|
|
|
def smoke_test() -> None:
    """Run AnchorMemoryLayer end-to-end and verify output/bank shapes."""
    cfg = AnchorMemoryConfig(d_model=128, heads=8, anchor_stride=32, max_anchors=64)
    layer = AnchorMemoryLayer(cfg)
    inputs = torch.randn(2, 256, 128)

    out, bank = layer(inputs)
    assert out.shape == inputs.shape
    assert bank.shape == (2, 8, 128)  # 256 tokens / stride 32 -> 8 anchors

    out_next, bank_next = layer(inputs, bank)
    assert out_next.shape == inputs.shape
    assert bank_next.shape == (2, 16, 128)  # carried bank + 8 fresh anchors

    print("anchor_memory smoke OK", out.shape, bank_next.shape)
|
|
|
|
# Script entry point: run the shape-level smoke test when executed directly.
if __name__ == "__main__":
    smoke_test()
|
|