File size: 9,955 Bytes
1973cf0
d7d1235
1973cf0
d7d1235
 
1973cf0
d7d1235
1973cf0
d7d1235
 
1973cf0
d7d1235
 
 
 
 
 
 
1973cf0
 
 
 
 
 
 
 
 
 
d7d1235
1973cf0
 
 
d7d1235
 
 
 
1973cf0
 
d7d1235
1973cf0
 
 
 
be1bcbb
1973cf0
d7d1235
be1bcbb
1973cf0
d7d1235
1973cf0
be1bcbb
1973cf0
be1bcbb
1973cf0
 
be1bcbb
 
 
 
1973cf0
d7d1235
be1bcbb
d7d1235
1973cf0
d7d1235
be1bcbb
1973cf0
d7d1235
be1bcbb
 
1973cf0
be1bcbb
 
 
 
 
 
1973cf0
be1bcbb
d7d1235
 
be1bcbb
d7d1235
be1bcbb
 
 
d7d1235
 
be1bcbb
d7d1235
 
 
1973cf0
d7d1235
be1bcbb
d7d1235
 
 
1973cf0
d7d1235
 
1973cf0
d7d1235
be1bcbb
d7d1235
1973cf0
be1bcbb
1973cf0
d7d1235
1973cf0
d7d1235
be1bcbb
d7d1235
 
1973cf0
d7d1235
 
1973cf0
d7d1235
1973cf0
be1bcbb
 
d7d1235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be1bcbb
 
d7d1235
 
 
 
 
be1bcbb
d7d1235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1973cf0
d7d1235
 
 
1973cf0
d7d1235
 
 
 
 
1973cf0
d7d1235
 
 
 
1973cf0
d7d1235
 
be1bcbb
1973cf0
 
 
 
be1bcbb
d7d1235
1973cf0
 
 
 
 
 
 
 
 
be1bcbb
1973cf0
 
 
d7d1235
 
1973cf0
 
 
d7d1235
be1bcbb
 
 
d7d1235
1973cf0
 
 
be1bcbb
d7d1235
 
 
be1bcbb
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
"""
Mamba-2 SSD — OPTIMIZED: intra-chunk parallelism via matrix multiply.

The key Mamba-2 insight (State Space Duality):
Within each chunk of size T, the SSM can be computed as a MATRIX MULTIPLY:

    Y_chunk = (L ⊙ (C B^T)) @ (Δ ⊙ X)

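Where L (the decay mask, not the sequence length) is lower-triangular:
L[t, s] = exp(sum_{k=s+1}^{t} Δ_k A) for t >= s, and 0 otherwise.
In this implementation A holds one rate per state dimension and the decay
uses Δ averaged over channels, so the mask is applied per state dimension
before C and B are contracted.
This replaces the T sequential steps within a chunk with one T×T matmul.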

For L=256, T=16, num_chunks=16:
- Within chunk: parallel matmul (T×T = 16×16) 
- Across chunks: 16 sequential state carries (unavoidable, but trivial)

Total: 16 sequential state carries + 16 parallel matmuls = FAST.

NO in-place ops. Fully autograd safe. Works on CPU and GPU.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class Mamba2SSD(nn.Module):
    """
    Mamba-2 SSD with intra-chunk matrix-multiply parallelism.

    Every position t follows the selective SSM recurrence (with h_{-1} = 0)

        h_t = exp(mean_d(dt_t) * A) * h_{t-1} + B_t ⊗ (dt_t * u_t)
        y_t = C_t · h_t

    evaluated chunk by chunk. Inside a chunk the outputs come from the
    quadratic dual form ``(mask ⊙ C Bᵀ) @ (dt ⊙ u)``, so no hidden state is
    stored per position. Only one ``[d_state, inner_dim]`` state per chunk is
    carried sequentially across chunks.

    Args:
        dim: Input/output dimension
        d_state: SSM state dimension (default 16)
        d_conv: Conv1d kernel size (default 4)
        expand: Inner dimension expansion (default 2)
        chunk_size: Chunk size for scan (default 64 — larger = more parallel)

    Raises:
        ValueError: if ``chunk_size`` is smaller than 1.
    """

    def __init__(self, dim, d_state=16, d_conv=4, expand=2, chunk_size=64):
        super().__init__()
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        self.dim = dim
        self.d_state = d_state
        self.chunk_size = chunk_size
        self.inner_dim = dim * expand

        # Input projection: x and gate
        self.in_proj = nn.Linear(dim, self.inner_dim * 2, bias=False)

        # Short causal depthwise conv for local context
        self.conv1d = nn.Conv1d(
            self.inner_dim, self.inner_dim,
            kernel_size=d_conv, padding=d_conv - 1,
            groups=self.inner_dim, bias=True
        )

        # SSM parameter projections
        self.dt_proj = nn.Linear(self.inner_dim, self.inner_dim, bias=True)
        self.B_proj = nn.Linear(self.inner_dim, d_state, bias=False)
        self.C_proj = nn.Linear(self.inner_dim, d_state, bias=False)

        # A: learnable per-state decay rates stored in log space. The forward
        # pass uses A = -exp(A_log), so the decay is always negative (stable).
        A = torch.arange(1, d_state + 1, dtype=torch.float32)
        self.A_log = nn.Parameter(torch.log(A))

        # D: residual skip
        self.D = nn.Parameter(torch.ones(self.inner_dim))

        # Output
        self.norm = nn.LayerNorm(self.inner_dim)
        self.out_proj = nn.Linear(self.inner_dim, dim, bias=False)

        self._init_weights()

    def _init_weights(self):
        nn.init.constant_(self.dt_proj.bias, -4.0)  # softplus(-4) ≈ 0.018
        nn.init.xavier_uniform_(self.in_proj.weight, gain=0.1)
        nn.init.xavier_uniform_(self.out_proj.weight, gain=0.1)

    def forward(self, x):
        """x: [B, L, dim] → [B, L, dim]"""
        return self._process(x)

    def _process(self, x):
        """Run projection → causal conv → SSM → skip/norm/gate → output projection."""
        seq_len = x.shape[1]

        # Input projection, split into the SSM branch and the gate branch.
        x_inner, z = self.in_proj(x).chunk(2, dim=-1)

        # The conv pads d_conv-1 on both sides. Keeping the first seq_len
        # outputs makes it causal (output t only sees inputs <= t).
        x_conv = self.conv1d(x_inner.transpose(1, 2))[:, :, :seq_len].transpose(1, 2)
        x_conv = F.silu(x_conv)

        # SSM params
        dt = F.softplus(self.dt_proj(x_conv))  # [B, L, inner_dim], positive
        B_mat = self.B_proj(x_conv)             # [B, L, d_state]
        C_mat = self.C_proj(x_conv)             # [B, L, d_state]
        A = -torch.exp(self.A_log)              # [d_state], negative

        # Chunk-parallel scan
        y = self._chunk_ssm(x_conv, dt, A, B_mat, C_mat)

        # D skip (broadcast over batch and sequence), then norm and SiLU gate.
        y = y + x_conv * self.D
        y = self.norm(y) * F.silu(z)

        return self.out_proj(y)

    def _chunk_ssm(self, u, dt, A, B, C):
        """
        Chunk-parallel selective scan.

        Computes, for every position t (with h_{-1} = 0):

            h_t = exp(mean_d(dt_t) * A) * h_{t-1} + B_t ⊗ (dt_t * u_t)
            y_t = C_t · h_t

        The decay uses dt averaged over channels, a simplification that lets
        all channels share one decay per state dimension.

        Within a chunk:
            y_intra[t] = sum_{s<=t} (sum_n C[t,n] decay[t,s,n] B[s,n]) * (dt*u)[s]
        Across chunks only the state at each chunk boundary is propagated.

        Args:
            u:  [batch, L, d_inner] input sequence.
            dt: [batch, L, d_inner] step sizes (positive in this module).
            A:  [d_state] decay rates (negative in this module).
            B:  [batch, L, d_state] input matrix.
            C:  [batch, L, d_state] output matrix.

        Returns:
            [batch, L, d_inner] outputs. All intermediate tensors are created
            in the inputs' dtype/device, so low-precision inputs stay
            consistent.
        """
        batch, seq_len, d_inner = u.shape
        d_state = A.shape[0]
        if seq_len == 0:
            return u.new_zeros(batch, 0, d_inner)
        T = min(self.chunk_size, seq_len)

        # Right-pad to a multiple of T. Padded steps only affect positions
        # after seq_len, which are sliced off at the end.
        pad = (-seq_len) % T
        if pad:
            u = F.pad(u, (0, 0, 0, pad))
            dt = F.pad(dt, (0, 0, 0, pad))
            B = F.pad(B, (0, 0, 0, pad))
            C = F.pad(C, (0, 0, 0, pad))
        n_chunks = u.shape[1] // T

        # Reshape: [batch, n_chunks, T, ...]
        u_c = u.reshape(batch, n_chunks, T, d_inner)
        dt_c = dt.reshape(batch, n_chunks, T, d_inner)
        B_c = B.reshape(batch, n_chunks, T, d_state)
        C_c = C.reshape(batch, n_chunks, T, d_state)

        # Δ ⊙ x: the per-channel input injected at every step.
        x_dt = dt_c * u_c  # [b, nc, T, d_inner]

        # log(dA) per position and state dim, using the channel-mean dt.
        log_dA = dt_c.mean(dim=-1, keepdim=True) * A  # [b, nc, T, d_state]
        log_dA_cumsum = torch.cumsum(log_dA, dim=2)   # [b, nc, T, d_state]

        # ---- Intra-chunk: quadratic (dual) form ----
        # seg[..., t, s, n] = sum_{k=s+1}^{t} log_dA_k[n] for t >= s.
        seg = log_dA_cumsum.unsqueeze(3) - log_dA_cumsum.unsqueeze(2)
        causal = torch.ones(T, T, dtype=torch.bool, device=u.device).tril()
        # Fill the non-causal half with -inf *before* exp so the (positive,
        # possibly huge) upper-triangle values never overflow; exp(-inf) = 0.
        decay = seg.masked_fill(~causal.unsqueeze(-1), float("-inf")).exp()
        # scores[t, s] = sum_n C[t,n] * decay[t,s,n] * B[s,n]
        scores = torch.einsum("bctsn,bcsn->bcts", decay * C_c.unsqueeze(3), B_c)
        y_intra = torch.einsum("bcts,bcsd->bctd", scores, x_dt)

        # ---- Inter-chunk: propagate one state per chunk boundary ----
        # State each chunk builds by its last position, starting from zero.
        decay_to_end = torch.exp(log_dA_cumsum[:, :, -1:, :] - log_dA_cumsum)
        chunk_states = torch.einsum("bcsn,bcsd->bcnd", decay_to_end * B_c, x_dt)
        # Total decay across a whole chunk.
        chunk_decay = torch.exp(log_dA_cumsum[:, :, -1, :])  # [b, nc, d_state]

        # states[c] is the state entering chunk c (zero for the first chunk).
        # A promoted dtype keeps every stacked state dtype-consistent.
        state_dtype = torch.promote_types(chunk_decay.dtype, chunk_states.dtype)
        states = [torch.zeros(batch, d_state, d_inner, dtype=state_dtype, device=u.device)]
        for c_idx in range(n_chunks - 1):
            states.append(
                chunk_decay[:, c_idx].unsqueeze(-1) * states[-1] + chunk_states[:, c_idx]
            )
        h_in = torch.stack(states, dim=1)  # [b, nc, d_state, d_inner]

        # Carried state decays by exp(cumsum[t]) to reach position t, then is read out by C_t.
        y_inter = torch.einsum("bctn,bcnd->bctd", C_c * torch.exp(log_dA_cumsum), h_in)

        y = (y_intra + y_inter).reshape(batch, n_chunks * T, d_inner)
        return y[:, :seq_len, :]


class Mamba2Block(nn.Module):
    """
    Mamba-2 block with bidirectional scanning for 2D images.
    Forward + backward raster scan, merged via learned projection.

    Layout: a pre-norm bidirectional SSD mixer with a residual connection,
    followed by a pre-norm feed-forward MLP with a residual connection.

    Args:
        dim: Channel / embedding dimension.
        d_state: SSM state dimension, forwarded to both scanners.
        d_conv: Conv kernel size, forwarded to both scanners.
        expand: Inner expansion of both scanners. It also sets the MLP
            hidden width (``dim * expand``).
        dropout: Dropout probability inside the MLP.
        chunk_size: Scan chunk size forwarded to both scanners (default 64,
            the same default ``Mamba2SSD`` uses).
    """

    def __init__(self, dim, d_state=16, d_conv=4, expand=2, dropout=0.0, chunk_size=64):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        # Two independent scanners: one over the sequence, one over its reverse.
        self.ssd_fwd = Mamba2SSD(dim, d_state, d_conv, expand, chunk_size=chunk_size)
        self.ssd_bwd = Mamba2SSD(dim, d_state, d_conv, expand, chunk_size=chunk_size)
        self.merge = nn.Linear(dim * 2, dim, bias=False)

        ff_dim = dim * expand
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ff_dim, dim), nn.Dropout(dropout),
        )

    def forward(self, x):
        """x: [B, C, H, W] or [B, L, C] → same shape as the input."""
        is_2d = x.dim() == 4
        if is_2d:
            B, C, H, W = x.shape
            # Raster-scan the image: [B, C, H, W] -> [B, H*W, C].
            x = x.flatten(2).transpose(1, 2)

        residual = x
        x_norm = self.norm1(x)

        # The backward branch scans the reversed sequence and flips the result
        # back, so position t of `bwd` summarises positions >= t.
        fwd = self.ssd_fwd(x_norm)
        bwd = torch.flip(self.ssd_bwd(torch.flip(x_norm, [1])), [1])
        merged = self.merge(torch.cat([fwd, bwd], dim=-1))

        x = residual + merged
        x = x + self.ff(self.norm2(x))

        if is_2d:
            x = x.transpose(1, 2).reshape(B, C, H, W)
        return x