"""
Memory Sparse Attention (MSA) — EAM-100M Edge Agentic Model
============================================================

Combines three complementary mechanisms in a single attention layer:

1. **Persistent Memory Tokens**
   Learnable (K, V) parameter pairs prepended to every attention computation.
   They are *never* causally or sparsely masked, so every query position can
   always read from the model's working-memory scratchpad. The memory K/V
   parameters are per-layer and per-head, but shared across the batch
   dimension.

2. **IndexCache Sparse Attention** (sequence → sequence only)
   Alternating Full / Shared layer pattern:
     • Full layers   (even layer_idx) – compute fresh Top-K indices and cache
       them.
     • Shared layers (odd  layer_idx) – reuse the cached indices from the
       previous Full layer.
   Each query then attends to at most ``sparse_topk`` sequence keys. Note that
   this reference implementation still materialises the full T × T score
   matrix before masking, so compute and memory remain O(T²); the sparsity is
   in the attention *pattern*, not in the arithmetic cost.

3. **Interleaved Head Attention** (sequence → sequence only)
   The first half of the attention heads use a local sliding-window mask; the
   second half keep unrestricted (causal) global access. The local heads are
   presumably meant to bound the KV-cache footprint for long sequences at
   inference time — this module itself keeps no KV cache.

Attention layout (T sequence tokens, M memory tokens):

    att  (B, n_head, T, M+T)
    ├── [:, :, :, :M]  ← sequence → memory   (always dense)
    └── [:, :, :, M:]  ← sequence → sequence (causal + sparse + interleaved)
"""

import torch
import torch.nn as nn
from torch.nn import functional as F

from model.bitnet import BitLinear


class MemorySparseAttention(nn.Module):
    """
    Memory Sparse Attention.

    Parameters
    ----------
    config : Config
        Model hyper-parameters. Fields read by this module (only the three
        marked with a default are optional; the rest must be present):

            n_embd            – model width
            n_head            – number of attention heads
            dropout           – dropout probability
            bias              – whether to use bias in linear layers
            sparse_topk       – K for top-K sparse selection (default 128)
            local_window_size – sliding-window size for local heads
                                (default 256)
            n_memory_tokens   – number of persistent memory slots (default 32)
            block_size        – maximum sequence length for the causal mask

    layer_idx : int
        Zero-based depth index used to determine the Full vs Shared role.
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        assert config.n_embd % config.n_head == 0, (
            "n_embd must be divisible by n_head"
        )

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        self.layer_idx = layer_idx

        self.sparse_topk = getattr(config, "sparse_topk", 128)
        self.local_window_size = getattr(config, "local_window_size", 256)
        self.n_memory = getattr(config, "n_memory_tokens", 32)

        # IndexCache role: Full layers (even index) compute fresh indices;
        # Shared layers (odd index) reuse the preceding Full layer's indices.
        self.is_shared = (layer_idx % 2 != 0)

        # ── QKV + output projection (BitLinear from model.bitnet) ─────────
        # NOTE: creation order of parameters/buffers below is kept identical
        # to the original so state_dict keys and RNG-driven init match.
        self.c_attn = BitLinear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = BitLinear(config.n_embd, config.n_embd, bias=config.bias)

        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # ── Persistent memory K, V parameters ─────────────────────────────
        # Shape (1, n_head, n_memory, head_dim); expanded over the batch in
        # forward(). Initialised N(0, 0.02²).
        self.memory_k = nn.Parameter(
            torch.empty(1, self.n_head, self.n_memory, self.head_dim)
        )
        self.memory_v = nn.Parameter(
            torch.empty(1, self.n_head, self.n_memory, self.head_dim)
        )
        nn.init.normal_(self.memory_k, std=0.02)
        nn.init.normal_(self.memory_v, std=0.02)

        # ── Causal mask for the sequence × sequence block ─────────────────
        # Lower-triangular ones, (1, 1, block_size, block_size). Registered as
        # a (persistent, float) buffer so it follows the module's device and
        # existing checkpoints containing it still load.
        self.register_buffer(
            "causal_bias",
            torch.tril(torch.ones(config.block_size, config.block_size))
            .view(1, 1, config.block_size, config.block_size),
        )

    # ─────────────────────────────────────────────────────────────────────

    def _sequence_mask(self, T: int, device: torch.device) -> torch.Tensor:
        """Return a bool (1, n_head, T, T) mask; True where query i may see key j.

        Combines the causal mask with the interleaved-head mask. It has no
        batch dimension and relies on broadcasting, so it costs n_head·T²
        booleans instead of B·n_head·T².
        """
        causal = self.causal_bias[:, :, :T, :T] != 0           # (1, 1, T, T)

        pos = torch.arange(T, device=device)
        dist = pos.view(-1, 1) - pos.view(1, -1)                # i - j
        # Local heads admit keys with (i - j) <= local_window_size, i.e. the
        # query itself plus the previous local_window_size positions
        # (local_window_size + 1 keys once combined with the causal mask).
        # NOTE(review): if an exact window of local_window_size keys was
        # intended this should be `<`; changing it alters trained behaviour.
        in_window = (dist <= self.local_window_size).view(1, 1, T, T)

        # The first n_head // 2 heads are local; the remainder are global.
        n_local = self.n_head // 2
        is_local = torch.arange(self.n_head, device=device) < n_local
        interleaved = in_window | ~is_local.view(1, -1, 1, 1)  # (1, n_head, T, T)

        return causal & interleaved

    def _apply_index_cache(
        self,
        seq_att: torch.Tensor,
        cached_indices: "torch.Tensor | None",
    ):
        """Restrict each query row of ``seq_att`` to its top-K sequence keys.

        Full layers compute fresh indices from the (already masked) scores;
        Shared layers reuse ``cached_indices``. Returns the masked scores and
        the indices to hand on to the next layer.

        Raises
        ------
        ValueError
            If a Shared layer receives indices whose leading dimensions are
            not (B, n_head, T).
        """
        B, H, T, _ = seq_att.shape

        if self.sparse_topk >= T:
            # Top-K would keep every key: nothing to do, pass indices through.
            return seq_att, cached_indices

        if not self.is_shared:
            # Full layer: fresh indices. Rows with fewer than K admissible
            # keys also pick some -inf positions; those stay -inf below.
            _, topk_indices = torch.topk(seq_att, k=self.sparse_topk, dim=-1)
            cached_indices = topk_indices
        else:
            topk_indices = cached_indices
            if topk_indices is None:
                # No preceding Full layer indices: fall back to dense
                # (causal + interleaved) attention, as before.
                return seq_att, cached_indices
            if tuple(topk_indices.shape[:-1]) != (B, H, T):
                # A mismatched T would otherwise leave some query rows with
                # every sequence key masked, silently.
                raise ValueError(
                    f"cached_indices has shape {tuple(topk_indices.shape)}; "
                    f"expected ({B}, {H}, {T}, K)"
                )

        keep = torch.zeros_like(seq_att, dtype=torch.bool)
        keep.scatter_(-1, topk_indices, True)
        return seq_att.masked_fill(~keep, float('-inf')), cached_indices

    def forward(
        self,
        x: torch.Tensor,
        cached_indices: "torch.Tensor | None" = None,
    ):
        """
        Forward pass.

        Args
        ----
        x              : (B, T, C) input token representations
        cached_indices : top-K indices (B, n_head, T, K) from the preceding
                         Full layer (only read when self.is_shared is True)

        Returns
        -------
        y              : (B, T, C) output representations
        cached_indices : top-K indices for the next layer — freshly computed
                         by Full layers, passed through unchanged by Shared
                         layers or when sparse_topk >= T

        Raises
        ------
        ValueError
            If T exceeds the configured block_size, or a Shared layer gets
            cached_indices of the wrong shape.
        """
        B, T, C = x.size()
        M = self.n_memory

        block_size = self.causal_bias.size(-1)
        if T > block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {block_size}"
            )

        # ── 1. Project Q, K, V from the input sequence ────────────────────
        q, seq_k, seq_v = self.c_attn(x).split(self.n_embd, dim=2)
        # Reshape to (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        seq_k = seq_k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        seq_v = seq_v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # ── 2. Prepend memory K, V (broadcast over the batch) ─────────────
        mem_k = self.memory_k.expand(B, -1, -1, -1)   # (B, n_head, M, head_dim)
        mem_v = self.memory_v.expand(B, -1, -1, -1)
        k = torch.cat([mem_k, seq_k], dim=2)          # (B, n_head, M+T, head_dim)
        v = torch.cat([mem_v, seq_v], dim=2)          # (B, n_head, M+T, head_dim)

        # ── 3. Scaled dot-product scores ──────────────────────────────────
        scale = 1.0 / (self.head_dim ** 0.5)
        att = (q @ k.transpose(-2, -1)) * scale       # (B, n_head, T, M+T)

        mem_att = att[:, :, :, :M]   # (B, n_head, T, M) — never masked
        seq_att = att[:, :, :, M:]   # (B, n_head, T, T) — masked below

        # ── 4+5. Causal + interleaved-head mask (sequence columns only) ───
        allowed = self._sequence_mask(T, x.device)
        seq_att = seq_att.masked_fill(~allowed, float('-inf'))

        # ── 6. IndexCache sparse top-K (sequence columns only) ────────────
        seq_att, cached_indices = self._apply_index_cache(seq_att, cached_indices)

        # ── 7. Recombine memory + sequence scores → softmax ───────────────
        # Memory slots are always in the softmax denominator; together with
        # the diagonal (always admitted) no row is entirely -inf.
        att = torch.cat([mem_att, seq_att], dim=-1)   # (B, n_head, T, M+T)
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)

        # ── 8. Weighted aggregation over V ────────────────────────────────
        y = att @ v                                   # (B, n_head, T, head_dim)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))

        return y, cached_indices

    # ─────────────────────────────────────────────────────────────────────

    def extra_repr(self) -> str:
        role = "Shared" if self.is_shared else "Full"
        return (
            f"layer={self.layer_idx} ({role}), "
            f"n_head={self.n_head}, head_dim={self.head_dim}, "
            f"n_memory={self.n_memory}, sparse_topk={self.sparse_topk}, "
            f"local_window={self.local_window_size}"
        )