""" VortexSSM: Selective State-Space Layer Simplified Mamba-style SSM with input-dependent selection. Provides O(n) complexity for long sequences, ideal for scientific documents. """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Tuple class VortexSSM(nn.Module): """ Selective state-space layer. Linear complexity O(n) vs attention's O(n²). Handles long scientific documents efficiently with input-dependent selection. Architecture based on Mamba but simplified for scientific reasoning tasks. """ def __init__( self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2, dt_rank: Optional[int] = None, ): """ Initialize VortexSSM. Args: d_model: Model dimension d_state: State dimension (default 16 for 7B, 32 for 13B) d_conv: Convolution kernel size for local context expand: Expansion factor for inner dimension dt_rank: Rank for delta projection (if None, uses ceil(d_model/16)) """ super().__init__() self.d_model = d_model self.d_state = d_state self.d_conv = d_conv self.expand = expand self.d_inner = d_model * expand if dt_rank is None: self.dt_rank = max(1, d_model // 16) else: self.dt_rank = dt_rank # Input projection: splits into x and z pathways self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False) # Convolution for local context before SSM # Depthwise convolution for efficiency self.conv1d = nn.Conv1d( in_channels=self.d_inner, out_channels=self.d_inner, kernel_size=d_conv, padding=d_conv - 1, groups=self.d_inner, bias=False, ) # SSM parameter projections (input-dependent) self.x_proj = nn.Linear(self.d_inner, self.dt_rank + 2 * self.d_state, bias=False) self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True) # State matrices (A is log-scale for stability) # A is (d_inner, d_state) self.A_log = nn.Parameter(torch.randn(self.d_inner, self.d_state)) self.D = nn.Parameter(torch.randn(self.d_inner)) # Output projection self.out_proj = nn.Linear(self.d_inner, d_model, bias=False) # Initialize weights self._initialize_weights() def _initialize_weights(self): """Initialize weights properly.""" # Initialize A_log with negative values for stable discretization nn.init.normal_(self.A_log, mean=-4.0, std=0.5) nn.init.normal_(self.D, mean=0.0, std=0.1) # Initialize projections with small values for module in [self.in_proj, self.x_proj, self.dt_proj, self.conv1d, self.out_proj]: if hasattr(module, 'weight'): nn.init.normal_(module.weight, mean=0.0, std=0.02) def forward( self, x: torch.Tensor, state: Optional[torch.Tensor] = None, return_state: bool = False, ) -> torch.Tensor: """ Forward pass through the SSM. Args: x: Input tensor (batch, seq_len, d_model) state: Previous hidden state (batch, d_inner, d_state) return_state: If True, return (output, state) Returns: Output tensor (batch, seq_len, d_model) or tuple with state """ batch, seq_len, _ = x.shape device = x.device dtype = x.dtype # Double-check d_inner matches A_log shape d_inner = self.d_inner # Project input to inner dimension xz = self.in_proj(x) # (batch, seq_len, 2 * d_inner) x, z = xz.chunk(2, dim=-1) # Apply 1D convolution for local context # Need to transpose for conv1d: (batch, d_inner, seq_len) x_conv = x.transpose(1, 2) x_conv = self.conv1d(x_conv)[..., :seq_len] # Trim padding x = x_conv.transpose(1, 2) # Discretization: compute delta, A, B parameters # x_proj produces: delta (dt_rank), B (d_state), C (d_state) x_dbl = self.x_proj(x) # (batch, seq_len, dt_rank + 2*d_state) (delta, B, C) = torch.split( x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1, ) # Project delta delta = self.dt_proj(delta) # (batch, seq_len, d_inner) delta = F.softplus(delta) # Compute discretized state recurrence # Use scan operation for efficient sequential processing if state is None: state = torch.zeros(batch, d_inner, self.d_state, device=device, dtype=dtype) # Sequential scan (can be optimized with CUDA kernel) output = [] for t in range(seq_len): x_t = x[:, t] # (batch, d_inner) delta_t = delta[:, t] # (batch, d_inner) B_t = B[:, t] # (batch, d_state) C_t = C[:, t] # (batch, d_state) # Discretize A A_delta = torch.exp(self.A_log * delta_t.unsqueeze(-1)) # (batch, d_inner, d_state) # State update: state = A_delta * state + B_t * x_t # B_t needs to be (batch, d_state) -> (batch, d_inner, d_state) via broadcasting state = A_delta * state + B_t.unsqueeze(1) * x_t.unsqueeze(-1) # Output: y = C_t * state + D * x_t y = (C_t.unsqueeze(1) * state).sum(dim=-1) + self.D * x_t output.append(y) output = torch.stack(output, dim=1) # (batch, seq_len, d_inner) # Apply gating with z output = output * F.silu(z) # Project back to model dimension output = self.out_proj(output) if return_state: return output, state return output def step( self, x: torch.Tensor, state: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Single-step inference for autoregressive decoding. Args: x: Input at current step (batch, d_model) state: Previous state (batch, d_inner, d_state) Returns: output: (batch, d_model) new_state: updated state """ batch, _ = x.shape # Project input xz = self.in_proj(x.unsqueeze(1)) # Add seq dim x, z = xz.chunk(2, dim=-1) x = x.squeeze(1) z = z.squeeze(1) # No convolution for single step (would need cache) # Compute parameters x_dbl = self.x_proj(x.unsqueeze(1)).squeeze(1) delta, B, C = torch.split( x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1, ) delta = self.dt_proj(delta) delta = F.softplus(delta) # Single step discretization A_delta = torch.exp(self.A_log * delta.unsqueeze(-1)) state = A_delta * state + B.unsqueeze(1) * x.unsqueeze(-1) y = (C.unsqueeze(1) * state).sum(dim=-1) + self.D * x y = y * F.silu(z) output = self.out_proj(y) return output, state def test_vortex_ssm(): """Test the VortexSSM layer.""" batch_size = 2 seq_len = 128 d_model = 4096 d_state = 16 ssm = VortexSSM(d_model, d_state=d_state) x = torch.randn(batch_size, seq_len, d_model) # Forward pass output = ssm(x) print(f"Input shape: {x.shape}") print(f"Output shape: {output.shape}") assert output.shape == x.shape, f"Expected {x.shape}, got {output.shape}" # Stateful forward state = torch.zeros(batch_size, ssm.d_inner, d_state) output2, new_state = ssm(x, state=state, return_state=True) print(f"Stateful output shape: {output2.shape}") print(f"State shape: {new_state.shape}") # Single step x_step = torch.randn(batch_size, d_model) output_step, state_step = ssm.step(x_step, state) print(f"Step output shape: {output_step.shape}") print(f"Step state shape: {state_step.shape}") print("VortexSSM test passed!") if __name__ == "__main__": test_vortex_ssm()