# Vortex-7b-V1 / models/ssm_layer.py
# Uploaded by Zandy-Wandy: "Upload Vortex model" (commit bf64b03, verified)
"""
VortexSSM: Selective State-Space Layer
Simplified Mamba-style SSM with input-dependent selection.
Provides O(n) complexity for long sequences, ideal for scientific documents.
"""
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
class VortexSSM(nn.Module):
    """
    Selective state-space layer with linear cost in sequence length
    (vs. attention's quadratic cost).

    Simplified Mamba-style layer:
      1. ``in_proj`` splits the input into a value path ``u`` and a gate ``z``.
      2. ``u`` goes through a causal depthwise 1-D convolution.
      3. Input-dependent (selective) parameters delta, B, C are computed from ``u``
         and drive a diagonal state recurrence over the sequence.
      4. The recurrence output is gated by ``silu(z)`` and projected back to
         ``d_model`` by ``out_proj``.
    """

    def __init__(
        self,
        d_model: int,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
        dt_rank: Optional[int] = None,
    ):
        """
        Initialize VortexSSM.

        Args:
            d_model: Model dimension.
            d_state: State dimension per inner channel (default 16 for 7B, 32 for 13B).
            d_conv: Kernel size of the causal depthwise convolution.
            expand: Expansion factor; the inner dimension is ``d_model * expand``.
            dt_rank: Rank of the low-rank delta projection. If None, uses
                ``max(1, d_model // 16)`` (floor division).
        """
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = d_model * expand
        self.dt_rank = max(1, d_model // 16) if dt_rank is None else dt_rank

        # Input projection: produces the value path and the gate path.
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)

        # Depthwise (groups == channels) convolution for local context.
        # Padding d_conv - 1 on both sides; forward() keeps only the first
        # seq_len outputs, which makes the convolution causal.
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner,
            bias=False,
        )

        # Input-dependent SSM parameters: x_proj yields [delta_lowrank | B | C],
        # dt_proj lifts the low-rank delta back to d_inner channels.
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + 2 * self.d_state, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # State matrix A (d_inner, d_state) and skip coefficient D (d_inner,).
        # NOTE(review): despite its name, A_log is used directly as A in the
        # discretization exp(A_log * delta) (not as log(-A) as in Mamba), so the
        # recurrence only decays while its entries stay negative. Kept as-is
        # for checkpoint compatibility.
        self.A_log = nn.Parameter(torch.randn(self.d_inner, self.d_state))
        self.D = nn.Parameter(torch.randn(self.d_inner))

        # Output projection back to the model dimension.
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialize A_log, D and the projection/convolution weights."""
        # Negative mean keeps exp(A_log * delta) < 1 at initialization.
        nn.init.normal_(self.A_log, mean=-4.0, std=0.5)
        nn.init.normal_(self.D, mean=0.0, std=0.1)
        # Small weights for every projection and the convolution
        # (dt_proj's bias keeps PyTorch's default initialization).
        for module in [self.in_proj, self.x_proj, self.dt_proj, self.conv1d, self.out_proj]:
            if hasattr(module, 'weight'):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _ssm_params(self, u: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute the selective parameters from the (convolved) value path.

        Args:
            u: Tensor whose last dimension is d_inner, e.g. (batch, seq_len, d_inner)
               or (batch, d_inner).

        Returns:
            (delta, B, C): delta has last dim d_inner (softplus-activated, so
            positive); B and C have last dim d_state. Leading dims follow ``u``.
        """
        x_dbl = self.x_proj(u)
        delta, B, C = torch.split(
            x_dbl,
            [self.dt_rank, self.d_state, self.d_state],
            dim=-1,
        )
        delta = F.softplus(self.dt_proj(delta))
        return delta, B, C

    def _recurrence_step(
        self,
        u_t: torch.Tensor,
        delta_t: torch.Tensor,
        B_t: torch.Tensor,
        C_t: torch.Tensor,
        state: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Advance the diagonal selective recurrence by one position.

        Args:
            u_t: Value input (batch, d_inner).
            delta_t: Step sizes (batch, d_inner).
            B_t: Input coefficients (batch, d_state).
            C_t: Output coefficients (batch, d_state).
            state: Hidden state (batch, d_inner, d_state).

        Returns:
            (y_t, new_state) with y_t of shape (batch, d_inner).
        """
        # Discretized transition, broadcast to (batch, d_inner, d_state).
        A_bar = torch.exp(self.A_log * delta_t.unsqueeze(-1))
        # state <- A_bar * state + B_t (x) u_t  (outer product via broadcasting)
        new_state = A_bar * state + B_t.unsqueeze(1) * u_t.unsqueeze(-1)
        # y_t = <C_t, state> per channel, plus the D skip connection.
        y_t = (C_t.unsqueeze(1) * new_state).sum(dim=-1) + self.D * u_t
        return y_t, new_state

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[torch.Tensor] = None,
        return_state: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Run the layer over a full sequence.

        Args:
            x: Input tensor (batch, seq_len, d_model).
            state: Initial hidden state (batch, d_inner, d_state); zeros if None.
                Only the SSM state is carried; the convolution always starts
                from zero padding, so chunked calls match a single full call
                only when d_conv == 1.
            return_state: If True, also return the final hidden state.

        Returns:
            Output (batch, seq_len, d_model), or ``(output, final_state)`` when
            ``return_state`` is True. For an empty sequence the output has
            seq_len 0 and the state is returned unchanged.
        """
        batch, seq_len, _ = x.shape
        if state is None:
            state = torch.zeros(
                batch, self.d_inner, self.d_state, device=x.device, dtype=x.dtype
            )

        if seq_len == 0:
            # Nothing to scan (torch.stack would reject an empty list below).
            output = x.new_zeros(batch, 0, self.d_model)
            return (output, state) if return_state else output

        u, z = self.in_proj(x).chunk(2, dim=-1)  # each (batch, seq_len, d_inner)

        # conv1d expects (batch, channels, length); trim to seq_len for causality.
        u = self.conv1d(u.transpose(1, 2))[..., :seq_len].transpose(1, 2)

        delta, B, C = self._ssm_params(u)

        # Sequential scan (a fused CUDA kernel could replace this loop).
        ys = []
        for t in range(seq_len):
            y_t, state = self._recurrence_step(u[:, t], delta[:, t], B[:, t], C[:, t], state)
            ys.append(y_t)
        y = torch.stack(ys, dim=1)  # (batch, seq_len, d_inner)

        output = self.out_proj(y * F.silu(z))
        if return_state:
            return output, state
        return output

    def step(
        self,
        x: torch.Tensor,
        state: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Single-step inference for autoregressive decoding.

        The causal convolution applied in forward() is skipped here (no
        convolution cache is kept), so step() reproduces forward() only when
        that convolution is the identity.

        Args:
            x: Input at the current step (batch, d_model).
            state: Previous hidden state (batch, d_inner, d_state).

        Returns:
            output: (batch, d_model)
            new_state: Updated hidden state (batch, d_inner, d_state).
        """
        u, z = self.in_proj(x).chunk(2, dim=-1)  # each (batch, d_inner)
        delta, B, C = self._ssm_params(u)
        y, state = self._recurrence_step(u, delta, B, C, state)
        output = self.out_proj(y * F.silu(z))
        return output, state
def test_vortex_ssm():
    """Smoke-test VortexSSM shapes for full-sequence, stateful, and single-step calls."""
    batch_size = 2
    seq_len = 128
    d_model = 4096
    d_state = 16
    ssm = VortexSSM(d_model, d_state=d_state)
    x = torch.randn(batch_size, seq_len, d_model)
    expected_state_shape = (batch_size, ssm.d_inner, d_state)
    # Shape checks only: skip building an autograd graph over the
    # seq_len-step Python loop inside forward().
    with torch.no_grad():
        # Forward pass
        output = ssm(x)
        print(f"Input shape: {x.shape}")
        print(f"Output shape: {output.shape}")
        assert output.shape == x.shape, f"Expected {x.shape}, got {output.shape}"
        # Stateful forward
        state = torch.zeros(batch_size, ssm.d_inner, d_state)
        output2, new_state = ssm(x, state=state, return_state=True)
        print(f"Stateful output shape: {output2.shape}")
        print(f"State shape: {new_state.shape}")
        assert output2.shape == x.shape, f"Expected {x.shape}, got {output2.shape}"
        assert new_state.shape == expected_state_shape, (
            f"Expected {expected_state_shape}, got {new_state.shape}"
        )
        # Single step
        x_step = torch.randn(batch_size, d_model)
        output_step, state_step = ssm.step(x_step, state)
        print(f"Step output shape: {output_step.shape}")
        print(f"Step state shape: {state_step.shape}")
        assert output_step.shape == x_step.shape, (
            f"Expected {x_step.shape}, got {output_step.shape}"
        )
        assert state_step.shape == expected_state_shape, (
            f"Expected {expected_state_shape}, got {state_step.shape}"
        )
    print("VortexSSM test passed!")


if __name__ == "__main__":
    test_vortex_ssm()