"""
Memory Sparse Attention (MSA): EAM-100M Edge Agentic Model
============================================================

Combines three complementary mechanisms into a single attention layer:

  1. **Persistent Memory Tokens**
     Learnable (K, V) parameter pairs prepended to every attention
     computation. They are *never* causally or sparsely masked, so every
     query position can always read from the model's working-memory
     scratchpad. The memory K/V parameters are per-layer and per-head,
     but shared across the batch dimension.

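  2. **IndexCache Sparse Attention** (sequence → sequence only)
     Alternating Full / Shared layer pattern:
       • Full  layers  (even layer_idx) – compute fresh Top-K indices
                                          and cache them.
       • Shared layers (odd layer_idx)  – reuse the cached indices from
                                          the previous Full layer.
     When sparse_topk < T, each query keeps at most sparse_topk sequence
     keys, so only O(T · sparse_topk) attention weights are non-zero
     instead of O(T²). This reference implementation still materialises
     the full T × T score matrix before masking, so its compute and memory
     remain O(T²); realising the savings would need a gather-based sparse
     kernel.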

  3. **Interleaved Head Attention** (sequence → sequence only)
     The first n_head // 2 heads use a causal sliding-window mask (each
     query sees itself and the previous local_window_size tokens), which
     would bound the KV cache those heads need during incremental
     decoding; the remaining heads retain unrestricted causal access.
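
A framework-free sketch of the sequence → sequence masking rules in
items 2 and 3, for one head and one batch element. It assumes the
window convention used in forward() (query i keeps key j when
0 <= i - j <= local_window_size) and breaks score ties by the lower key
index, which is a choice of this sketch only. The helper names are
hypothetical; the real layer works on (B, n_head, T, M+T) torch tensors:

```python
# Hypothetical helpers illustrating the masks; not part of this module's API.


def layer_role(layer_idx):
    # IndexCache pattern: even layers are Full, odd layers are Shared.
    return "Shared" if layer_idx % 2 else "Full"


def base_mask(T, window, is_local_head):
    # mask[i][j] is True when query i may attend to sequence key j.
    # Causal: j <= i.  Local heads additionally require i - j <= window.
    return [
        [j <= i and (not is_local_head or i - j <= window) for j in range(T)]
        for i in range(T)
    ]


def topk_indices(scores, mask, k):
    # Full layer: per query row, keep the k highest-scoring allowed keys
    # (ties broken by the lower key index in this sketch).
    out = []
    for row_scores, row_mask in zip(scores, mask):
        allowed = [j for j, ok in enumerate(row_mask) if ok]
        allowed.sort(key=lambda j: (-row_scores[j], j))
        out.append(allowed[:k])
    return out


def apply_topk(mask, indices):
    # Drop every key that is not among the (fresh or cached) top-K indices.
    out = []
    for row_mask, row_idx in zip(mask, indices):
        keep = set(row_idx)
        out.append([ok and j in keep for j, ok in enumerate(row_mask)])
    return out


# Usage: T = 6 tokens, window = 2, top-2 selection on a local head.
T, window, k = 6, 2, 2
scores = [[(7 * i + 3 * j) % 5 for j in range(T)] for i in range(T)]
mask = base_mask(T, window, is_local_head=True)
cached = topk_indices(scores, mask, k)   # computed on layer 0 (Full)
shared_mask = apply_topk(mask, cached)   # reused on layer 1 (Shared)
```

Because the causal and window masks are identical in every layer, a
Shared layer can reuse the Full layer's indices without re-admitting any
key that its own masks would exclude.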

Attention layout (T sequence tokens, M memory tokens):

    att  (B, n_head, T, M+T)
         ├── [:, :, :, :M]   ← sequence → memory   (always dense)
         └── [:, :, :, M:]   ← sequence → sequence  (causal + sparse + interleaved)
"""

import torch
import torch.nn as nn
from torch.nn import functional as F
from model.bitnet import BitLinear


class MemorySparseAttention(nn.Module):
    """
    Memory Sparse Attention.

    Parameters
    ----------
    config : Config
        Model hyper-parameters.  Expected fields (only sparse_topk,
        local_window_size and n_memory_tokens are optional):
          n_embd            – model width
          n_head            – number of attention heads (must divide n_embd)
          dropout           – dropout probability
          bias              – whether to use bias in linear layers
          sparse_topk       – K for top-K sparse selection (default 128)
          local_window_size – sliding-window size for local heads (default 256)
          n_memory_tokens   – number of persistent memory slots (default 32)
          block_size        – maximum sequence length for the causal mask
    layer_idx : int
        Zero-based layer index: even indices are Full layers, odd are Shared.
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        assert config.n_embd % config.n_head == 0, (
            "n_embd must be divisible by n_head"
        )

        self.n_head     = config.n_head
        self.n_embd     = config.n_embd
        self.head_dim   = config.n_embd // config.n_head
        self.layer_idx  = layer_idx

        self.sparse_topk        = getattr(config, "sparse_topk", 128)
        self.local_window_size  = getattr(config, "local_window_size", 256)
        self.n_memory           = getattr(config, "n_memory_tokens", 32)

        # IndexCache role: Full layers compute fresh indices; Shared layers reuse.
        self.is_shared = (layer_idx % 2 != 0)

        # ── QKV + output projection (BitNet 1.58-bit ternary weights) ────────
        self.c_attn = BitLinear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = BitLinear(config.n_embd, config.n_embd,     bias=config.bias)

        self.attn_dropout  = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # ── Persistent Memory K, V parameters ────────────────────────────────
        # Shape: (1, n_head, n_memory, head_dim) → broadcast over batch.
        # Initialised with the same std as token embeddings (σ = 0.02).
        self.memory_k = nn.Parameter(
            torch.empty(1, self.n_head, self.n_memory, self.head_dim)
        )
        self.memory_v = nn.Parameter(
            torch.empty(1, self.n_head, self.n_memory, self.head_dim)
        )
        nn.init.normal_(self.memory_k, std=0.02)
        nn.init.normal_(self.memory_v, std=0.02)

        # ── Causal mask for the sequence × sequence block ─────────────────────
        # Registered as a buffer so it moves with the model's device automatically.
        self.register_buffer(
            "causal_bias",
            torch.tril(torch.ones(config.block_size, config.block_size))
                  .view(1, 1, config.block_size, config.block_size),
        )

    # ─────────────────────────────────────────────────────────────────────────
    def forward(
        self,
        x: torch.Tensor,
        cached_indices: "torch.Tensor | None" = None,
    ):
        """
        Forward pass.

        Args
        ----
        x              : (B, T, C)  input token representations
        cached_indices : (B, n_head, T, sparse_topk) top-K indices from the
                         preceding Full layer, or None (only read when
                         self.is_shared is True)

        Returns
        -------
        y              : (B, T, C)  output representations
        cached_indices : fresh top-K indices on a Full layer when
                         sparse_topk < T; otherwise the input unchanged
        """
        B, T, C = x.size()
        M = self.n_memory

        # ── 1. Project Q, K, V from the input sequence ───────────────────────
        q, seq_k, seq_v = self.c_attn(x).split(self.n_embd, dim=2)

        # Reshape to (B, n_head, T, head_dim)
        q     = q    .view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        seq_k = seq_k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        seq_v = seq_v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # ── 2. Expand memory K, V over the batch dimension ───────────────────
        mem_k = self.memory_k.expand(B, -1, -1, -1)   # (B, n_head, M, head_dim)
        mem_v = self.memory_v.expand(B, -1, -1, -1)

        # Concatenate: memory first, then sequence
        k = torch.cat([mem_k, seq_k], dim=2)           # (B, n_head, M+T, head_dim)
        v = torch.cat([mem_v, seq_v], dim=2)           # (B, n_head, M+T, head_dim)

        # ── 3. Scaled dot-product attention scores ────────────────────────────
        scale = 1.0 / (self.head_dim ** 0.5)
        att   = (q @ k.transpose(-2, -1)) * scale      # (B, n_head, T, M+T)

        # Split into memory and sequence columns for selective masking
        mem_att = att[:, :, :, :M]                     # (B, n_head, T, M): kept as-is
        seq_att = att[:, :, :, M:]                     # (B, n_head, T, T): will be masked

        # ── 4. Causal mask (sequence columns only) ────────────────────────────
        assert T <= self.causal_bias.size(-1), "sequence length exceeds block_size"
        causal = self.causal_bias[:, :, :T, :T]
        seq_att = seq_att.masked_fill(causal == 0, float('-inf'))

        # ── 5. Interleaved Head mask (sequence columns only) ──────────────────
        # First n_local heads → causal sliding window; remaining heads → global.
        n_local = self.n_head // 2
        pos = torch.arange(T, device=x.device)
        # Local heads keep key j when 0 <= i - j <= local_window_size (future keys
        # are already removed by the causal mask above).
        in_window = (pos.view(-1, 1) - pos.view(1, -1)) <= self.local_window_size  # (T, T)
        is_local_head = torch.arange(self.n_head, device=x.device) < n_local       # (n_head,)
        # Global heads keep every key.  Broadcasting over the batch dimension avoids
        # materialising a (B, n_head, T, T) boolean mask.
        interleaved = in_window.view(1, 1, T, T) | ~is_local_head.view(1, -1, 1, 1)  # (1, n_head, T, T)
        seq_att = seq_att.masked_fill(~interleaved, float('-inf'))

        # ── 6. IndexCache Sparse Top-K (sequence columns only) ────────────────
        if self.sparse_topk < T:
            if not self.is_shared:
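                # Full layer: derive fresh top-K indices and cache them.  Rows with
                # fewer than sparse_topk unmasked keys also pick -inf positions;
                # those stay -inf, so no masked key is re-admitted.
                _, topk_indices = torch.topk(seq_att, k=self.sparse_topk, dim=-1)
                cached_indices  = topk_indices
            else:
                # Shared layer: reuse cached indices from the preceding Full layer.
                # If none were passed in, the sparse mask is skipped below.
                topk_indices = cached_indices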

            if topk_indices is not None:
                sparse_mask = torch.zeros_like(seq_att, dtype=torch.bool)
                sparse_mask.scatter_(-1, topk_indices, True)
                seq_att = seq_att.masked_fill(~sparse_mask, float('-inf'))

        # ── 7. Recombine memory + sequence scores → softmax ───────────────────
        # Memory slots are always part of the softmax denominator.
        att = torch.cat([mem_att, seq_att], dim=-1)    # (B, n_head, T, M+T)
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)

        # ── 8. Weighted aggregation over V ────────────────────────────────────
        y = att @ v                                     # (B, n_head, T, head_dim)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))

        return y, cached_indices

    # ─────────────────────────────────────────────────────────────────────────
    def extra_repr(self) -> str:
        role = "Shared" if self.is_shared else "Full"
        return (
            f"layer={self.layer_idx} ({role}), "
            f"n_head={self.n_head}, head_dim={self.head_dim}, "
            f"n_memory={self.n_memory}, sparse_topk={self.sparse_topk}, "
            f"local_window={self.local_window_size}"
        )