File size: 3,114 Bytes
9c63689
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#!/usr/bin/env python3
from __future__ import annotations

from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class AnchorMemoryConfig:
    """Hyper-parameters for the anchor-memory layer.

    Raises:
        ValueError: if ``heads`` is not a positive divisor of ``d_model``, or if
            ``anchor_stride`` / ``max_anchors`` is not a positive integer.
    """

    d_model: int  # width of token states and anchors
    heads: int  # attention heads used when tokens read from the anchor bank
    anchor_stride: int = 256  # number of tokens pooled into one anchor
    max_anchors: int = 2048  # bank capacity; the most recent anchors are kept
    dropout: float = 0.0  # attention dropout for the memory read

    def __post_init__(self) -> None:
        if self.heads < 1:
            raise ValueError(f"heads must be >= 1, got {self.heads}")
        if self.d_model % self.heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by heads ({self.heads})"
            )
        if self.anchor_stride < 1:
            raise ValueError(f"anchor_stride must be >= 1, got {self.anchor_stride}")
        # max_anchors is used as a negative slice start (bank[:, -max_anchors:]);
        # 0 would silently keep the whole bank and negatives would drop the
        # oldest entries instead of capping, so reject both up front.
        if self.max_anchors < 1:
            raise ValueError(f"max_anchors must be >= 1, got {self.max_anchors}")


class AnchorCompressor(nn.Module):
    """Compress local token spans into trainable anchor vectors.

    The sequence is cut into consecutive chunks of ``anchor_stride`` tokens
    (the last chunk is zero-padded when the length is not a multiple of the
    stride). Each chunk is reduced to a single vector by a learned
    softmax-weighted pooling over its real tokens, then refined by a residual
    pre-LayerNorm MLP.
    """

    def __init__(self, d_model: int, anchor_stride: int):
        super().__init__()
        self.anchor_stride = anchor_stride
        self.score = nn.Linear(d_model, 1)  # per-token pooling logit
        self.mix = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` into one anchor per ``anchor_stride`` tokens.

        Args:
            x: tensor of shape ``(batch, seq, dim)`` with ``dim == d_model``.

        Returns:
            Tensor of shape ``(batch, ceil(seq / anchor_stride), dim)``.
        """
        bsz, seq, dim = x.shape
        stride = self.anchor_stride
        pad = (-seq) % stride
        if pad:
            x = F.pad(x, (0, 0, 0, pad))
        n_chunks = (seq + pad) // stride
        chunks = x.reshape(bsz, n_chunks, stride, dim)
        scores = self.score(chunks)  # (bsz, n_chunks, stride, 1)
        if pad:
            # Padding tokens are not real input: give them zero pooling weight
            # so the final partial anchor is not diluted by zero vectors.
            # Every chunk keeps at least one real token because pad < stride,
            # so the softmax never sees an all -inf row.
            valid = torch.arange(n_chunks * stride, device=x.device) < seq
            scores = scores.masked_fill(
                ~valid.view(1, n_chunks, stride, 1), float("-inf")
            )
        weights = scores.softmax(dim=2)
        pooled = (chunks * weights).sum(dim=2)
        return pooled + self.mix(pooled)


class AnchorMemoryLayer(nn.Module):
    """Local-token stream reads from a bounded bank of learned anchors.

    Each call compresses ``x`` into anchors (one per ``cfg.anchor_stride``
    tokens), appends them to the incoming ``memory`` bank, keeps only the most
    recent ``cfg.max_anchors`` entries, and lets every token attend over that
    bank. The attention read-out is blended back into ``x`` through a learned
    sigmoid gate, and the sum is layer-normalised.
    """

    def __init__(self, cfg: AnchorMemoryConfig):
        super().__init__()
        self.cfg = cfg
        self.compress = AnchorCompressor(cfg.d_model, cfg.anchor_stride)
        self.q_ln = nn.LayerNorm(cfg.d_model)  # normalises token queries
        self.mem_ln = nn.LayerNorm(cfg.d_model)  # normalises anchor keys/values
        self.read = nn.MultiheadAttention(
            cfg.d_model,
            cfg.heads,
            dropout=cfg.dropout,
            batch_first=True,
        )
        # Per-feature gate computed from the token and what it recalled.
        self.gate = nn.Sequential(nn.Linear(2 * cfg.d_model, cfg.d_model), nn.Sigmoid())
        self.out_ln = nn.LayerNorm(cfg.d_model)

    def forward(
        self,
        x: torch.Tensor,
        memory: torch.Tensor | None = None,
        *,
        detach_memory: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Update the anchor bank with ``x`` and read from it.

        Args:
            x: token states, ``(batch, seq, d_model)`` (the attention is
                ``batch_first``).
            memory: anchor bank returned by a previous call, or ``None`` to
                start a fresh bank.
            detach_memory: cut the autograd graph at the memory boundary. The
                incoming ``memory`` is detached before use and the returned
                bank is detached, so no gradient flows into earlier segments.
                Anchors built from the current ``x`` still receive gradients
                through this call's read, so the compressor keeps training.

        Returns:
            ``(out, bank)``. ``out`` has the shape of ``x``. ``bank`` is the
            updated anchor bank with at most ``cfg.max_anchors`` entries along
            dim 1, oldest first.
        """
        new_anchors = self.compress(x)
        if memory is None:
            bank = new_anchors
        else:
            if detach_memory:
                memory = memory.detach()
            bank = torch.cat([memory, new_anchors], dim=1)
        if bank.size(1) > self.cfg.max_anchors:
            # Evict the oldest anchors; new ones were appended at the end.
            bank = bank[:, -self.cfg.max_anchors :]

        # NOTE(review): no attention mask is applied, and the bank includes
        # anchors compressed from the whole of x. Every token can therefore see
        # summaries of later tokens in the same call. If this layer is used
        # autoregressively, that leaks future information; confirm intended.
        mem = self.mem_ln(bank)
        recalled, _ = self.read(self.q_ln(x), mem, mem, need_weights=False)
        gate = self.gate(torch.cat([x, recalled], dim=-1))
        mixed = x + gate * recalled
        if detach_memory:
            # Detach only the carried state, after the read has used the
            # attached anchors.
            bank = bank.detach()
        return self.out_ln(mixed), bank


def smoke_test() -> None:
    """Run quick shape checks on ``AnchorMemoryLayer``.

    Covers:
    - a fresh bank;
    - appending to an existing bank;
    - a sequence length that is not a multiple of ``anchor_stride`` (the
      padding path);
    - eviction once the bank exceeds ``max_anchors``.

    Raises ``AssertionError`` on any shape mismatch.
    """
    cfg = AnchorMemoryConfig(d_model=128, heads=8, anchor_stride=32, max_anchors=64)
    layer = AnchorMemoryLayer(cfg)
    x = torch.randn(2, 256, 128)
    y, memory = layer(x)
    assert y.shape == x.shape
    assert memory.shape == (2, 8, 128)
    y2, memory2 = layer(x, memory)
    assert y2.shape == x.shape
    assert memory2.shape == (2, 16, 128)

    with torch.no_grad():
        # Ragged length: 250 tokens -> ceil(250 / 32) = 8 new anchors.
        ragged = torch.randn(2, 250, 128)
        y3, memory3 = layer(ragged, memory2)
        assert y3.shape == ragged.shape
        assert memory3.shape == (2, 24, 128)

        # 24 + 6 * 8 = 72 > 64: the bank must be capped at max_anchors.
        bank = memory3
        for _ in range(6):
            _, bank = layer(x, bank)
        assert bank.shape == (2, cfg.max_anchors, 128)

    print("anchor_memory smoke OK", y.shape, memory2.shape)


if __name__ == "__main__":
    smoke_test()