EAM-100M-Agentic-Kernel / model /memory_sparse_attention.py
saur7764's picture
Upload folder using huggingface_hub
7ff1e9e verified
"""
Memory Sparse Attention (MSA) β€” EAM-100M Edge Agentic Model
============================================================
Combines three complementary mechanisms into a single attention layer:
1. **Persistent Memory Tokens**
Learnable (K, V) parameter pairs prepended to every attention
computation. They are *never* causally or sparsely masked, so every
query position can always read from the model's working-memory
scratchpad. The memory K/V parameters are per-layer and per-head,
but shared across the batch dimension.
2. **IndexCache Sparse Attention** (sequence β†’ sequence only)
Alternating Full / Shared layer pattern:
β€’ Full layers (even layer_idx) – compute fresh Top-K indices
and cache them.
β€’ Shared layers (odd layer_idx) – reuse the cached indices from
the previous Full layer.
This reduces the O(TΒ²) attention cost to O(T Β· sparse_topk).
3. **Interleaved Head Attention** (sequence β†’ sequence only)
The first half of attention heads use a local sliding-window mask
(optimised KV-cache footprint for long sequences); the second half
retain unrestricted global access.
Attention layout (T sequence tokens, M memory tokens):
att (B, n_head, T, M+T)
β”œβ”€β”€ [:, :, :, :M] ← sequence β†’ memory (always dense)
└── [:, :, :, M:] ← sequence β†’ sequence (causal + sparse + interleaved)
"""
import torch
import torch.nn as nn
from torch.nn import functional as F
from model.bitnet import BitLinear
class MemorySparseAttention(nn.Module):
"""
Memory Sparse Attention.
Parameters
----------
config : Config
Model hyper-parameters. Expected fields (all have defaults):
n_embd – model width
n_head – number of attention heads
dropout – dropout probability
bias – whether to use bias in linear layers
sparse_topk – K for top-K sparse selection (default 128)
local_window_size – sliding-window size for local heads (default 256)
n_memory_tokens – number of persistent memory slots (default 32)
block_size – maximum sequence length for the causal mask
layer_idx : int
Zero-based depth index used to determine Full vs Shared role.
"""
def __init__(self, config, layer_idx: int):
super().__init__()
assert config.n_embd % config.n_head == 0, (
"n_embd must be divisible by n_head"
)
self.n_head = config.n_head
self.n_embd = config.n_embd
self.head_dim = config.n_embd // config.n_head
self.layer_idx = layer_idx
self.sparse_topk = getattr(config, "sparse_topk", 128)
self.local_window_size = getattr(config, "local_window_size", 256)
self.n_memory = getattr(config, "n_memory_tokens", 32)
# IndexCache role: Full layers compute fresh indices; Shared layers reuse.
self.is_shared = (layer_idx % 2 != 0)
# ── QKV + output projection (BitNet 1.58-bit ternary weights) ────────
self.c_attn = BitLinear(config.n_embd, 3 * config.n_embd, bias=config.bias)
self.c_proj = BitLinear(config.n_embd, config.n_embd, bias=config.bias)
self.attn_dropout = nn.Dropout(config.dropout)
self.resid_dropout = nn.Dropout(config.dropout)
# ── Persistent Memory K, V parameters ────────────────────────────────
# Shape: (1, n_head, n_memory, head_dim) β†’ broadcast over batch.
# Initialised with the same std as token embeddings (Οƒ = 0.02).
self.memory_k = nn.Parameter(
torch.empty(1, self.n_head, self.n_memory, self.head_dim)
)
self.memory_v = nn.Parameter(
torch.empty(1, self.n_head, self.n_memory, self.head_dim)
)
nn.init.normal_(self.memory_k, std=0.02)
nn.init.normal_(self.memory_v, std=0.02)
# ── Causal mask for the sequence Γ— sequence block ─────────────────────
# Registered as a buffer so it moves with the model's device automatically.
self.register_buffer(
"causal_bias",
torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size),
)
# ─────────────────────────────────────────────────────────────────────────
def forward(
self,
x: torch.Tensor,
cached_indices: "torch.Tensor | None" = None,
):
"""
Forward pass.
Args
----
x : (B, T, C) input token representations
cached_indices : top-K indices from the preceding Full layer
(only used when self.is_shared = True)
Returns
-------
y : (B, T, C) output representations
cached_indices : updated top-K indices (unchanged for Shared layers)
"""
B, T, C = x.size()
M = self.n_memory
# ── 1. Project Q, K, V from the input sequence ───────────────────────
q, seq_k, seq_v = self.c_attn(x).split(self.n_embd, dim=2)
# Reshape to (B, n_head, T, head_dim)
q = q .view(B, T, self.n_head, self.head_dim).transpose(1, 2)
seq_k = seq_k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
seq_v = seq_v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
# ── 2. Expand memory K, V over the batch dimension ───────────────────
mem_k = self.memory_k.expand(B, -1, -1, -1) # (B, n_head, M, head_dim)
mem_v = self.memory_v.expand(B, -1, -1, -1)
# Concatenate: memory first, then sequence
k = torch.cat([mem_k, seq_k], dim=2) # (B, n_head, M+T, head_dim)
v = torch.cat([mem_v, seq_v], dim=2) # (B, n_head, M+T, head_dim)
# ── 3. Scaled dot-product attention scores ────────────────────────────
scale = 1.0 / (self.head_dim ** 0.5)
att = (q @ k.transpose(-2, -1)) * scale # (B, n_head, T, M+T)
# Split into memory and sequence columns for selective masking
mem_att = att[:, :, :, :M] # (B, n_head, T, M) β€” kept as-is
seq_att = att[:, :, :, M:] # (B, n_head, T, T) β€” will be masked
# ── 4. Causal mask (sequence columns only) ────────────────────────────
causal: torch.Tensor = self.causal_bias[:, :, :T, :T]
seq_att = seq_att.masked_fill(causal == 0, float('-inf'))
# ── 5. Interleaved Head mask (sequence columns only) ──────────────────
# First n_local heads β†’ sliding window; remaining heads β†’ global
n_local = self.n_head // 2
i_idx = torch.arange(T, device=x.device).view(-1, 1)
j_idx = torch.arange(T, device=x.device).view(1, -1)
local_mask = (i_idx - j_idx) <= self.local_window_size # (T, T)
local_mask = local_mask.view(1, 1, T, T).expand(B, n_local, T, T)
global_mask = torch.ones(B, self.n_head - n_local, T, T,
dtype=torch.bool, device=x.device)
interleaved = torch.cat([local_mask, global_mask], dim=1) # (B, n_head, T, T)
seq_att = seq_att.masked_fill(~interleaved, float('-inf'))
# ── 6. IndexCache Sparse Top-K (sequence columns only) ────────────────
if self.sparse_topk < T:
if not self.is_shared:
# Full layer: derive fresh top-K indices and cache them
_, topk_indices = torch.topk(seq_att, k=self.sparse_topk, dim=-1)
cached_indices = topk_indices
else:
# Shared layer: reuse cached indices from the preceding Full layer
topk_indices = cached_indices
if topk_indices is not None:
sparse_mask = torch.zeros_like(seq_att, dtype=torch.bool)
sparse_mask.scatter_(-1, topk_indices, True)
seq_att = seq_att.masked_fill(~sparse_mask, float('-inf'))
# ── 7. Recombine memory + sequence scores β†’ softmax ───────────────────
# Memory slots are always part of the softmax denominator.
att = torch.cat([mem_att, seq_att], dim=-1) # (B, n_head, T, M+T)
att = F.softmax(att, dim=-1)
att = self.attn_dropout(att)
# ── 8. Weighted aggregation over V ────────────────────────────────────
y = att @ v # (B, n_head, T, head_dim)
y = y.transpose(1, 2).contiguous().view(B, T, C)
y = self.resid_dropout(self.c_proj(y))
return y, cached_indices
# ─────────────────────────────────────────────────────────────────────────
def extra_repr(self) -> str:
role = "Shared" if self.is_shared else "Full"
return (
f"layer={self.layer_idx} ({role}), "
f"n_head={self.n_head}, head_dim={self.head_dim}, "
f"n_memory={self.n_memory}, sparse_topk={self.sparse_topk}, "
f"local_window={self.local_window_size}"
)