File size: 8,447 Bytes
5c43f61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
"""

VortexSSM: Selective State-Space Layer

Simplified Mamba-style SSM with input-dependent selection.

Provides O(n) complexity for long sequences, ideal for scientific documents.

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Union


class VortexSSM(nn.Module):
    """Selective state-space (Mamba-style) sequence layer.

    Runs in O(n) time over sequence length n, versus attention's O(n^2),
    which makes it suitable for long scientific documents.  The SSM
    parameters (delta, B, C) are input-dependent ("selective"), following
    the Mamba architecture but simplified: the recurrence is a plain
    Python scan rather than a fused CUDA kernel.
    """

    def __init__(
        self,
        d_model: int,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
        dt_rank: Optional[int] = None,
    ):
        """Initialize VortexSSM.

        Args:
            d_model: Model (embedding) dimension.
            d_state: SSM state dimension per channel (e.g. 16 for 7B, 32 for 13B).
            d_conv: Depthwise convolution kernel size for local context.
            expand: Expansion factor; the inner dimension is d_model * expand.
            dt_rank: Rank of the low-rank delta projection.  Defaults to
                max(1, d_model // 16) when None.
        """
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = d_model * expand
        self.dt_rank = max(1, d_model // 16) if dt_rank is None else dt_rank

        # One matmul produces both the SSM pathway (x) and the gate pathway (z).
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)

        # Depthwise causal convolution for local context before the SSM.
        # padding = d_conv - 1 makes it causal once the tail is trimmed.
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner,
            bias=False,
        )

        # Input-dependent SSM parameters: delta (dt_rank), B and C (d_state each).
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + 2 * self.d_state, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # A is stored in log space and materialized as A = -exp(A_log), which
        # guarantees A < 0 (a decaying, stable recurrence) wherever training
        # drives A_log.  (Previously A_log was used directly as A: nothing
        # kept it negative, so exp(A * delta) could exceed 1 and the scan
        # could diverge.)  Shape: (d_inner, d_state).
        self.A_log = nn.Parameter(torch.randn(self.d_inner, self.d_state))
        # Per-channel skip ("D") connection from input to output.
        self.D = nn.Parameter(torch.randn(self.d_inner))

        # Output projection back to the model dimension.
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Initialize parameters for a stable start."""
        # S4D-real style init: A = -[1, 2, ..., d_state] per channel, i.e.
        # A_log = log(arange(1, d_state + 1)), the standard Mamba choice.
        with torch.no_grad():
            decay = torch.arange(1, self.d_state + 1, dtype=torch.float32)
            self.A_log.copy_(torch.log(decay).expand(self.d_inner, -1))
        nn.init.normal_(self.D, mean=0.0, std=0.1)

        # Small-variance init for every projection (linear and conv weights).
        for module in [self.in_proj, self.x_proj, self.dt_proj, self.conv1d, self.out_proj]:
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[torch.Tensor] = None,
        return_state: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run the selective SSM over a full sequence.

        Args:
            x: Input tensor of shape (batch, seq_len, d_model).
            state: Optional previous hidden state (batch, d_inner, d_state);
                zeros are used when None.
            return_state: If True, also return the final hidden state.

        Returns:
            Output tensor (batch, seq_len, d_model), or a tuple
            (output, final_state) when return_state is True.
        """
        batch, seq_len, _ = x.shape

        # Project to the inner dimension and split into SSM and gate paths.
        xz = self.in_proj(x)  # (batch, seq_len, 2 * d_inner)
        x, z = xz.chunk(2, dim=-1)

        # Depthwise causal conv; Conv1d wants (batch, channels, seq), and the
        # left-padding tail beyond seq_len is trimmed to keep causality.
        x = self.conv1d(x.transpose(1, 2))[..., :seq_len].transpose(1, 2)

        # Input-dependent parameters: delta (dt_rank), B (d_state), C (d_state).
        x_dbl = self.x_proj(x)  # (batch, seq_len, dt_rank + 2*d_state)
        delta, B, C = torch.split(
            x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
        )
        # softplus keeps the step size strictly positive.
        delta = F.softplus(self.dt_proj(delta))  # (batch, seq_len, d_inner)

        # Continuous-time A, strictly negative; hoisted out of the scan since
        # it does not depend on t.
        A = -torch.exp(self.A_log)  # (d_inner, d_state)

        if state is None:
            state = torch.zeros(
                batch, self.d_inner, self.d_state, device=x.device, dtype=x.dtype
            )

        # Sequential scan (a fused kernel would parallelize this).
        outputs = []
        for t in range(seq_len):
            x_t = x[:, t]          # (batch, d_inner)
            delta_t = delta[:, t]  # (batch, d_inner)
            B_t = B[:, t]          # (batch, d_state)
            C_t = C[:, t]          # (batch, d_state)

            # Discretize: A_bar = exp(delta * A) lies in (0, 1) since A < 0.
            A_bar = torch.exp(delta_t.unsqueeze(-1) * A)  # (batch, d_inner, d_state)

            # h_t = A_bar * h_{t-1} + B_t * x_t (B broadcast over d_inner).
            state = A_bar * state + B_t.unsqueeze(1) * x_t.unsqueeze(-1)

            # y_t = C_t . h_t + D * x_t (skip connection through D).
            y = (C_t.unsqueeze(1) * state).sum(dim=-1) + self.D * x_t
            outputs.append(y)

        output = torch.stack(outputs, dim=1)  # (batch, seq_len, d_inner)

        # Gate with the z pathway, then project back to the model dimension.
        output = self.out_proj(output * F.silu(z))

        if return_state:
            return output, state
        return output

    def step(
        self,
        x: torch.Tensor,
        state: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Single-step inference for autoregressive decoding.

        NOTE(review): the depthwise convolution is skipped here (it would
        need a rolling input cache), so step() is not numerically identical
        to forward() — confirm this is acceptable for decoding.

        Args:
            x: Input at the current step, shape (batch, d_model).
            state: Previous hidden state, shape (batch, d_inner, d_state).

        Returns:
            output: Tensor of shape (batch, d_model).
            new_state: Updated hidden state, shape (batch, d_inner, d_state).
        """
        # nn.Linear accepts 2-D input directly; no seq-dim gymnastics needed.
        xz = self.in_proj(x)  # (batch, 2 * d_inner)
        x, z = xz.chunk(2, dim=-1)

        # Input-dependent parameters for this single step.
        x_dbl = self.x_proj(x)
        delta, B, C = torch.split(
            x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
        )
        delta = F.softplus(self.dt_proj(delta))  # (batch, d_inner)

        # Same stable discretization as forward(): A = -exp(A_log) < 0.
        A = -torch.exp(self.A_log)
        A_bar = torch.exp(delta.unsqueeze(-1) * A)  # (batch, d_inner, d_state)
        state = A_bar * state + B.unsqueeze(1) * x.unsqueeze(-1)
        y = (C.unsqueeze(1) * state).sum(dim=-1) + self.D * x

        # Gate and project back to the model dimension.
        output = self.out_proj(y * F.silu(z))
        return output, state


def test_vortex_ssm():
    """Smoke-test VortexSSM: forward, stateful forward, and single step.

    Uses small dimensions so the pure-Python sequential scan finishes in
    well under a second on CPU — the previous d_model=4096 configuration
    made this "quick test" take a very long time.
    """
    batch_size = 2
    seq_len = 16
    d_model = 64
    d_state = 16

    ssm = VortexSSM(d_model, d_state=d_state)
    x = torch.randn(batch_size, seq_len, d_model)

    # Plain forward pass: output shape must match input shape.
    output = ssm(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    assert output.shape == x.shape, f"Expected {x.shape}, got {output.shape}"

    # Stateful forward returns (output, final_state).
    state = torch.zeros(batch_size, ssm.d_inner, d_state)
    output2, new_state = ssm(x, state=state, return_state=True)
    print(f"Stateful output shape: {output2.shape}")
    print(f"State shape: {new_state.shape}")
    assert output2.shape == x.shape
    assert new_state.shape == (batch_size, ssm.d_inner, d_state)

    # Single autoregressive step.
    x_step = torch.randn(batch_size, d_model)
    output_step, state_step = ssm.step(x_step, state)
    print(f"Step output shape: {output_step.shape}")
    print(f"Step state shape: {state_step.shape}")
    assert output_step.shape == (batch_size, d_model)
    assert state_step.shape == (batch_size, ssm.d_inner, d_state)

    print("VortexSSM test passed!")


# Script entry point: run the smoke test when executed directly.
if __name__ == "__main__":
    test_vortex_ssm()