| | """
|
| | VortexSSM: Selective State-Space Layer
|
| | Simplified Mamba-style SSM with input-dependent selection.
|
| | Provides O(n) complexity for long sequences, ideal for scientific documents.
|
| | """
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | from typing import Optional, Tuple
|
| |
|
| |
|
class VortexSSM(nn.Module):
    """
    Selective state-space layer. Linear complexity O(n) vs attention's O(n²).
    Handles long scientific documents efficiently with input-dependent selection.

    Architecture based on Mamba but simplified for scientific reasoning tasks.
    """

    def __init__(
        self,
        d_model: int,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
        dt_rank: Optional[int] = None,
    ):
        """
        Initialize VortexSSM.

        Args:
            d_model: Model dimension.
            d_state: State dimension (default 16 for 7B, 32 for 13B).
            d_conv: Convolution kernel size for local context.
            expand: Expansion factor for inner dimension.
            dt_rank: Rank for the delta projection bottleneck. If None,
                uses max(1, d_model // 16).
        """
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = d_model * expand

        # Low-rank bottleneck for the input-dependent step size (delta).
        if dt_rank is None:
            self.dt_rank = max(1, d_model // 16)
        else:
            self.dt_rank = dt_rank

        # Projects input to the gated pair (x, z); split via chunk() in forward().
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)

        # Depthwise (groups=d_inner) conv for local context. padding=d_conv-1
        # plus trimming to seq_len in forward() makes it causal.
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner,
            bias=False,
        )

        # Produces the input-dependent SSM parameters (delta_low_rank, B, C).
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + 2 * self.d_state, bias=False)
        # Expands the low-rank delta back to d_inner.
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # Log-space state-transition parameter and per-channel skip connection.
        self.A_log = nn.Parameter(torch.randn(self.d_inner, self.d_state))
        self.D = nn.Parameter(torch.randn(self.d_inner))

        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize weights.

        A_log is centered at -4 so that exp(A_log * delta) starts well below
        1 (a decaying state transition).

        NOTE(review): negativity of A_log is only statistical (mean -4,
        std 0.5) and is not enforced after init; if an entry drifts positive
        during training, exp(A_log * delta) > 1 and the recurrence can
        diverge. Mamba avoids this with an A = -exp(A_log)
        parameterization — confirm whether that was intended here.
        """
        nn.init.normal_(self.A_log, mean=-4.0, std=0.5)
        nn.init.normal_(self.D, mean=0.0, std=0.1)

        # Small-std init for all projection/conv weights.
        for module in [self.in_proj, self.x_proj, self.dt_proj, self.conv1d, self.out_proj]:
            if hasattr(module, 'weight'):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[torch.Tensor] = None,
        return_state: bool = False,
    ) -> "torch.Tensor | Tuple[torch.Tensor, torch.Tensor]":
        """
        Forward pass through the SSM.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            state: Previous hidden state (batch, d_inner, d_state);
                zeros are used when None.
            return_state: If True, return (output, state)

        Returns:
            Output tensor (batch, seq_len, d_model), or a tuple
            (output, final_state) when return_state is True.
        """
        batch, seq_len, _ = x.shape
        device = x.device
        dtype = x.dtype

        d_inner = self.d_inner

        # Input projection -> gated pair (x, z).
        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)

        # Causal depthwise conv: left-pad by d_conv-1, then trim the
        # right-side overhang back to seq_len.
        x_conv = x.transpose(1, 2)
        x_conv = self.conv1d(x_conv)[..., :seq_len]
        x = x_conv.transpose(1, 2)

        # Input-dependent SSM parameters (the "selection" mechanism).
        x_dbl = self.x_proj(x)
        delta, B, C = torch.split(
            x_dbl,
            [self.dt_rank, self.d_state, self.d_state],
            dim=-1,
        )

        # Expand delta to d_inner; softplus keeps the step size positive.
        delta = self.dt_proj(delta)
        delta = F.softplus(delta)

        if state is None:
            state = torch.zeros(batch, d_inner, self.d_state, device=device, dtype=dtype)

        # Sequential selective scan. Correct but O(seq_len) in Python;
        # a parallel scan kernel would be the optimized path.
        output = []
        for t in range(seq_len):
            x_t = x[:, t]          # (batch, d_inner)
            delta_t = delta[:, t]  # (batch, d_inner)
            B_t = B[:, t]          # (batch, d_state)
            C_t = C[:, t]          # (batch, d_state)

            # Discretized transition, broadcast to (batch, d_inner, d_state).
            A_delta = torch.exp(self.A_log * delta_t.unsqueeze(-1))

            # State update: decay previous state, inject current input.
            state = A_delta * state + B_t.unsqueeze(1) * x_t.unsqueeze(-1)

            # Readout via C plus the per-channel skip connection D * x.
            y = (C_t.unsqueeze(1) * state).sum(dim=-1) + self.D * x_t
            output.append(y)

        output = torch.stack(output, dim=1)

        # SiLU gate from the z branch, then project back to d_model.
        output = output * F.silu(z)
        output = self.out_proj(output)

        if return_state:
            return output, state
        return output

    def step(
        self,
        x: torch.Tensor,
        state: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Single-step inference for autoregressive decoding.

        NOTE(review): unlike forward(), this path does NOT apply the causal
        conv1d, so step-by-step decoding will not reproduce forward()
        outputs. Parity would require carrying a rolling conv input buffer
        in addition to the SSM state — confirm whether this omission is
        intentional.

        Args:
            x: Input at current step (batch, d_model)
            state: Previous state (batch, d_inner, d_state)

        Returns:
            output: (batch, d_model)
            new_state: updated state
        """
        # nn.Linear accepts 2-D input directly; no unsqueeze needed.
        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)

        x_dbl = self.x_proj(x)
        delta, B, C = torch.split(
            x_dbl,
            [self.dt_rank, self.d_state, self.d_state],
            dim=-1,
        )
        delta = self.dt_proj(delta)
        delta = F.softplus(delta)

        # One recurrence step — mirrors the loop body in forward().
        A_delta = torch.exp(self.A_log * delta.unsqueeze(-1))
        state = A_delta * state + B.unsqueeze(1) * x.unsqueeze(-1)
        y = (C.unsqueeze(1) * state).sum(dim=-1) + self.D * x
        y = y * F.silu(z)
        output = self.out_proj(y)

        return output, state
|
| |
|
| |
|
def test_vortex_ssm():
    """Smoke-test the VortexSSM layer: forward, stateful forward, and step.

    Uses small dimensions — the original 4096-dim / 128-step configuration
    allocated a huge model and ran a long Python scan, which is wasteful
    for a shape-level smoke test.
    """
    batch_size = 2
    seq_len = 16
    d_model = 64
    d_state = 16

    ssm = VortexSSM(d_model, d_state=d_state)
    x = torch.randn(batch_size, seq_len, d_model)

    # No gradients needed for a shape check.
    with torch.no_grad():
        # Plain forward pass.
        output = ssm(x)
        print(f"Input shape: {x.shape}")
        print(f"Output shape: {output.shape}")
        assert output.shape == x.shape, f"Expected {x.shape}, got {output.shape}"

        # Forward pass with an explicit initial state.
        state = torch.zeros(batch_size, ssm.d_inner, d_state)
        output2, new_state = ssm(x, state=state, return_state=True)
        print(f"Stateful output shape: {output2.shape}")
        print(f"State shape: {new_state.shape}")

        # Single autoregressive step.
        x_step = torch.randn(batch_size, d_model)
        output_step, state_step = ssm.step(x_step, state)
        print(f"Step output shape: {output_step.shape}")
        print(f"Step state shape: {state_step.shape}")

    print("VortexSSM test passed!")
|
| |
|
| |
|
# Script entry point: run the smoke test only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    test_vortex_ssm()
|
| |
|