Spaces:
Sleeping
Sleeping
| """ | |
| Memory Sparse Attention (MSA) — EAM-100M Edge Agentic Model | |
| ============================================================ | |
| Combines three complementary mechanisms into a single attention layer: | |
| 1. **Persistent Memory Tokens** | |
| Learnable (K, V) parameter pairs prepended to every attention | |
| computation. They are *never* causally or sparsely masked, so every | |
| query position can always read from the model's working-memory | |
| scratchpad. The memory K/V parameters are per-layer and per-head, | |
| but shared across the batch dimension. | |
| 2. **IndexCache Sparse Attention** (sequence → sequence only) | |
| Alternating Full / Shared layer pattern: | |
| • Full layers (even layer_idx) — compute fresh Top-K indices | |
| and cache them. | |
| • Shared layers (odd layer_idx) — reuse the cached indices from | |
| the previous Full layer. | |
| This reduces the O(T²) attention cost to O(T · sparse_topk). | |
| 3. **Interleaved Head Attention** (sequence → sequence only) | |
| The first half of attention heads use a local sliding-window mask | |
| (optimised KV-cache footprint for long sequences); the second half | |
| retain unrestricted global access. | |
| Attention layout (T sequence tokens, M memory tokens): | |
| att (B, n_head, T, M+T) | |
| ├── [:, :, :, :M] — sequence → memory (always dense) | |
| └── [:, :, :, M:] — sequence → sequence (causal + sparse + interleaved) | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| from model.bitnet import BitLinear | |
class MemorySparseAttention(nn.Module):
    """
    Memory Sparse Attention.

    A single attention layer in which every query attends to

    * ``n_memory`` learnable persistent (K, V) slots, which are never masked,
      and
    * the sequence tokens, restricted by (a) a causal mask, (b) a sliding
      window on the first ``n_head // 2`` heads, and (c) an optional top-K
      selection.  Full layers (even ``layer_idx``) compute the top-K indices
      themselves; Shared layers (odd ``layer_idx``) reuse the indices handed
      in through ``cached_indices``.

    Attention scores are computed densely and then masked, so the sparsity
    affects which keys contribute, not the size of the score tensor.

    Parameters
    ----------
    config : Config
        Model hyper-parameters.
        Required fields (read directly):
            n_embd             - model width
            n_head             - number of attention heads
            dropout            - dropout probability
            bias               - whether to use bias in linear layers
            block_size         - maximum sequence length for the causal mask
        Optional fields (read with ``getattr``):
            sparse_topk        - K for top-K sparse selection (default 128)
            local_window_size  - number of most recent positions, the query
                                 itself included, that a local head may
                                 attend to (default 256)
            n_memory_tokens    - number of persistent memory slots (default 32)
    layer_idx : int
        Zero-based depth index used to determine Full vs Shared role.
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        assert config.n_embd % config.n_head == 0, (
            "n_embd must be divisible by n_head"
        )
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        self.layer_idx = layer_idx
        self.sparse_topk = getattr(config, "sparse_topk", 128)
        self.local_window_size = getattr(config, "local_window_size", 256)
        self.n_memory = getattr(config, "n_memory_tokens", 32)

        # IndexCache role: Full layers (even index) compute fresh indices;
        # Shared layers (odd index) reuse the ones passed to forward().
        self.is_shared = (layer_idx % 2 != 0)

        # -- QKV + output projection (BitLinear from model.bitnet) ----------
        self.c_attn = BitLinear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = BitLinear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # -- Persistent memory K, V parameters ------------------------------
        # Shape: (1, n_head, n_memory, head_dim); the leading 1 is expanded
        # over the batch in forward().  Normal init with std = 0.02.
        self.memory_k = nn.Parameter(
            torch.empty(1, self.n_head, self.n_memory, self.head_dim)
        )
        self.memory_v = nn.Parameter(
            torch.empty(1, self.n_head, self.n_memory, self.head_dim)
        )
        nn.init.normal_(self.memory_k, std=0.02)
        nn.init.normal_(self.memory_v, std=0.02)

        # -- Causal mask for the sequence x sequence block ------------------
        # Lower-triangular ones, registered as a buffer so it follows the
        # module across devices.
        self.register_buffer(
            "causal_bias",
            torch.tril(torch.ones(config.block_size, config.block_size))
            .view(1, 1, config.block_size, config.block_size),
        )

    # -----------------------------------------------------------------------
    def _sequence_mask(self, T: int, device: torch.device) -> torch.Tensor:
        """Return the boolean (1, n_head, T, T) mask of visible sequence keys.

        ``True`` marks a (query, key) pair that may attend.  All heads are
        causal; the first ``n_head // 2`` heads are additionally limited to
        the ``local_window_size`` most recent positions (query included).
        The batch dimension is left at 1 so the mask broadcasts instead of
        being materialised per sample.
        """
        causal = self.causal_bias[:, :, :T, :T] != 0            # (1, 1, T, T)

        pos = torch.arange(T, device=device)
        distance = pos.view(-1, 1) - pos.view(1, -1)            # i - j
        # A window of W positions covers distances 0 .. W-1.  Clamp to 1 so a
        # query can always see itself and no sequence row is fully masked.
        window = max(int(self.local_window_size), 1)
        in_window = (distance < window).view(1, 1, T, T)

        n_local = self.n_head // 2
        is_global_head = (
            torch.arange(self.n_head, device=device) >= n_local
        ).view(1, self.n_head, 1, 1)

        return causal & (in_window | is_global_head)            # (1, n_head, T, T)

    def forward(
        self,
        x: torch.Tensor,
        cached_indices: "torch.Tensor | None" = None,
    ) -> "tuple[torch.Tensor, torch.Tensor | None]":
        """
        Forward pass.

        Args
        ----
        x              : (B, T, C) input token representations
        cached_indices : top-K indices from the preceding Full layer
                         (only used when self.is_shared = True)

        Returns
        -------
        y              : (B, T, C) output representations
        cached_indices : top-K indices computed by this layer when it is a
                         Full layer and ``sparse_topk < T``; otherwise the
                         argument is returned unchanged

        Raises
        ------
        ValueError : if ``T`` exceeds the ``block_size`` the causal mask was
                     built for.
        """
        B, T, C = x.size()
        M = self.n_memory

        block_size = self.causal_bias.size(-1)
        if T > block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {block_size}"
            )

        # -- 1. Project Q, K, V and reshape to (B, n_head, T, head_dim) -----
        q, seq_k, seq_v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        seq_k = seq_k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        seq_v = seq_v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # -- 2. Prepend the memory slots to the keys and values -------------
        mem_k = self.memory_k.expand(B, -1, -1, -1)   # (B, n_head, M, head_dim)
        mem_v = self.memory_v.expand(B, -1, -1, -1)
        k = torch.cat([mem_k, seq_k], dim=2)          # (B, n_head, M+T, head_dim)
        v = torch.cat([mem_v, seq_v], dim=2)          # (B, n_head, M+T, head_dim)

        # -- 3. Scaled dot-product scores, split by key type ----------------
        scale = 1.0 / (self.head_dim ** 0.5)
        att = (q @ k.transpose(-2, -1)) * scale       # (B, n_head, T, M+T)
        mem_att = att[:, :, :, :M]                    # (B, n_head, T, M), never masked
        seq_att = att[:, :, :, M:]                    # (B, n_head, T, T), masked below

        # -- 4./5. Causal + interleaved-head mask (sequence columns only) ---
        visible = self._sequence_mask(T, x.device)
        seq_att = seq_att.masked_fill(~visible, float("-inf"))

        # -- 6. IndexCache sparse top-K (sequence columns only) -------------
        if self.sparse_topk < T:
            if self.is_shared:
                # Shared layer: reuse whatever the caller passed in.
                # NOTE(review): the indices are assumed to come from a Full
                # layer run on the same (B, n_head, T) - confirm at call site.
                topk_indices = cached_indices
            else:
                # Full layer: pick the K best-scoring keys per query.  Rows
                # with fewer than K visible keys also pick masked positions,
                # which is harmless because those scores are already -inf.
                _, topk_indices = torch.topk(seq_att, k=self.sparse_topk, dim=-1)
                cached_indices = topk_indices
            if topk_indices is not None:
                keep = torch.zeros_like(seq_att, dtype=torch.bool)
                keep.scatter_(-1, topk_indices, True)
                seq_att = seq_att.masked_fill(~keep, float("-inf"))

        # -- 7. Recombine memory + sequence scores, then softmax ------------
        # Memory columns always take part in the softmax denominator.
        att = torch.cat([mem_att, seq_att], dim=-1)   # (B, n_head, T, M+T)
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)

        # -- 8. Weighted sum over V and output projection -------------------
        y = att @ v                                   # (B, n_head, T, head_dim)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))
        return y, cached_indices

    # -----------------------------------------------------------------------
    def extra_repr(self) -> str:
        """Return the layer's role and key hyper-parameters for ``repr()``."""
        role = "Shared" if self.is_shared else "Full"
        return (
            f"layer={self.layer_idx} ({role}), "
            f"n_head={self.n_head}, head_dim={self.head_dim}, "
            f"n_memory={self.n_memory}, sparse_topk={self.sparse_topk}, "
            f"local_window={self.local_window_size}"
        )