#!/usr/bin/env python3
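"""AGILLM-4 anchor memory scaffold (anchor_memory.py).

Compresses local token spans into trainable anchor vectors and lets the token
stream read from a bounded bank of those anchors through gated cross-attention.
"""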
from __future__ import annotations
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class AnchorMemoryConfig:
d_model: int
heads: int
anchor_stride: int = 256
max_anchors: int = 2048
dropout: float = 0.0
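    # With the defaults above, a 1,000,000-token stream yields about
    # 1_000_000 / 256 ~= 3_906 candidate anchors, but the layer clips the bank
    # to the most recent max_anchors, so reads attend over at most 2048 keys.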


class AnchorCompressor(nn.Module):
"""Compress local token spans into trainable anchor vectors."""
def __init__(self, d_model: int, anchor_stride: int):
super().__init__()
self.anchor_stride = anchor_stride
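        # Per-token logits deciding how much each position contributes to its
        # span's pooled anchor.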
self.score = nn.Linear(d_model, 1)
self.mix = nn.Sequential(
nn.LayerNorm(d_model),
nn.Linear(d_model, 4 * d_model),
nn.GELU(),
nn.Linear(4 * d_model, d_model),
)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seq, dim = x.shape
        # Zero-pad to a multiple of anchor_stride, then fold into spans of
        # shape (bsz, n_chunks, anchor_stride, dim).
        pad = (-seq) % self.anchor_stride
        if pad:
            x = F.pad(x, (0, 0, 0, pad))
        chunks = x.view(bsz, -1, self.anchor_stride, dim)
        scores = self.score(chunks)
        if pad:
            # Keep the zero-padded tail of the final chunk out of the pool so
            # it takes no softmax weight away from real tokens.
            scores[:, -1, -pad:] = float("-inf")
        weights = scores.softmax(dim=2)
        pooled = (chunks * weights).sum(dim=2)
        return pooled + self.mix(pooled)
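

# Illustrative shape check (example values chosen here, not fixed by the
# scaffold): a 100-token input with anchor_stride=32 is zero-padded to 128
# positions and pooled into ceil(100 / 32) = 4 anchors.
def compressor_shape_example() -> None:
    comp = AnchorCompressor(d_model=64, anchor_stride=32)
    out = comp(torch.randn(2, 100, 64))
    assert out.shape == (2, 4, 64)

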
class AnchorMemoryLayer(nn.Module):
"""Local-token stream reads from a bounded bank of learned anchors."""
def __init__(self, cfg: AnchorMemoryConfig):
super().__init__()
self.cfg = cfg
self.compress = AnchorCompressor(cfg.d_model, cfg.anchor_stride)
self.q_ln = nn.LayerNorm(cfg.d_model)
self.mem_ln = nn.LayerNorm(cfg.d_model)
self.read = nn.MultiheadAttention(
cfg.d_model,
cfg.heads,
dropout=cfg.dropout,
batch_first=True,
)
self.gate = nn.Sequential(nn.Linear(2 * cfg.d_model, cfg.d_model), nn.Sigmoid())
self.out_ln = nn.LayerNorm(cfg.d_model)

    def forward(
self,
x: torch.Tensor,
memory: torch.Tensor | None = None,
*,
detach_memory: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
        # Summarize this segment into fresh anchors and append them to the bank.
        new_anchors = self.compress(x)
        if detach_memory:
            # Optionally block gradients from flowing back into the compressor,
            # e.g. when the bank is carried across long streams.
            new_anchors = new_anchors.detach()
        if memory is None:
            bank = new_anchors
        else:
            bank = torch.cat([memory, new_anchors], dim=1)
        if bank.size(1) > self.cfg.max_anchors:
            # Evict the oldest anchors so the bank stays bounded.
            bank = bank[:, -self.cfg.max_anchors :]
        # Tokens cross-attend into the anchor bank, and a sigmoid gate controls
        # how much recalled context is mixed back into the stream.
        mem = self.mem_ln(bank)
        recalled, _ = self.read(self.q_ln(x), mem, mem, need_weights=False)
        gate = self.gate(torch.cat([x, recalled], dim=-1))
        mixed = x + gate * recalled
        return self.out_ln(mixed), bank
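

# Sketch of segment-wise streaming, an assumed usage pattern inferred from the
# forward signature rather than prescribed by the scaffold: the anchor bank is
# carried across segments while gradients stay inside the current segment,
# truncated-BPTT style.
def streaming_example() -> None:
    cfg = AnchorMemoryConfig(d_model=64, heads=4, anchor_stride=16, max_anchors=32)
    layer = AnchorMemoryLayer(cfg)
    memory = None
    for _ in range(10):
        segment = torch.randn(1, 64, cfg.d_model)  # 64 tokens -> 4 new anchors
        _, memory = layer(segment, memory, detach_memory=True)
        memory = memory.detach()  # cut gradients at the segment boundary
    # 10 segments x 4 anchors = 40 candidates, clipped to the newest 32.
    assert memory.size(1) == cfg.max_anchors

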
def smoke_test() -> None:
cfg = AnchorMemoryConfig(d_model=128, heads=8, anchor_stride=32, max_anchors=64)
layer = AnchorMemoryLayer(cfg)
x = torch.randn(2, 256, 128)
y, memory = layer(x)
assert y.shape == x.shape
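    # 256 tokens / anchor_stride 32 -> 8 anchors from the first pass.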
assert memory.shape == (2, 8, 128)
y2, memory2 = layer(x, memory)
assert y2.shape == x.shape
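    # The second pass appends 8 more anchors: 16 total, within max_anchors=64.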
assert memory2.shape == (2, 16, 128)
print("anchor_memory smoke OK", y.shape, memory2.shape)

if __name__ == "__main__":
smoke_test()