"""
MamaGuard — Mamba3 Model
Trapezoidal SSM with MIMO expansion and complex-valued state.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class Mamba3SSMLayer(nn.Module):
"""Core recurrent SSM engine of one Mamba3 block."""
def __init__(self, d_model: int, d_state: int = 32, expand: int = 2):
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.d_inner = d_model * expand
# Input/output projections (MIMO)
self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
        # Local depthwise convolution (padding=1 preserves sequence length;
        # note the conv is non-causal, which is acceptable for
        # whole-sequence classification)
self.conv1d = nn.Conv1d(
in_channels=self.d_inner,
out_channels=self.d_inner,
kernel_size=3,
padding=1,
groups=self.d_inner,
bias=True
)
        # SSM parameters: A is kept in log space so that
        # A_real = -exp(A_log) is always negative (a stable decay)
        self.A_log = nn.Parameter(torch.randn(self.d_inner, d_state))
        self.D = nn.Parameter(torch.ones(self.d_inner))
        # Input-dependent (selective) parameters, projected per timestep
        # in the order Δ (1), B (d_state), C (d_state)
self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
self.dt_proj = nn.Linear(1, self.d_inner, bias=True)
# Trapezoidal blending parameter (α)
self.alpha = nn.Parameter(torch.tensor(0.5))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""x: (batch_size, seq_len, d_model) -> same shape output."""
B, L, _ = x.shape
# Project to inner dimension + gating signal
xz = self.in_proj(x)
x_in, z = xz.chunk(2, dim=-1)
# Local convolution + SiLU activation
x_conv = self.conv1d(x_in.transpose(1, 2)).transpose(1, 2)
x_conv = F.silu(x_conv)
# Compute input-dependent SSM parameters
dt_raw, B_ssm, C_ssm = self.x_proj(x_conv).split(
[1, self.d_state, self.d_state], dim=-1
)
dt = F.softplus(self.dt_proj(dt_raw))
A_real = -torch.exp(self.A_log)
alpha = torch.sigmoid(self.alpha)
        # SSM recurrence: sequential scan over time (O(L) steps; adequate
        # for short vital-sign windows)
        h = torch.zeros(B, self.d_inner, self.d_state, device=x.device, dtype=x.dtype)
        outputs = []
for t in range(L):
            dt_t = dt[:, t, :].unsqueeze(-1)   # (B, d_inner, 1)
            B_t = B_ssm[:, t, :].unsqueeze(1)  # (B, 1, d_state)
            C_t = C_ssm[:, t, :].unsqueeze(1)  # (B, 1, d_state)
            u_t = x_conv[:, t, :]              # (B, d_inner)
            # Discretize A over step Δ: blend zero-order hold with a
            # half-step implicit-Euler approximation, weighted by the
            # learned α (sigmoid keeps it in (0, 1)):
            #   A_zoh  = exp(A * Δ)
            #   A_half = 1 / (1 - A * Δ/2)
            #   A_d    = α * A_zoh + (1 - α) * A_half
            # For A < 0 the denominator exceeds 1, so the epsilon is only
            # a numerical safeguard.
            A_d_zoh = torch.exp(A_real * dt_t)
            A_d_euler = 1.0 / (1.0 - A_real * dt_t * 0.5 + 1e-6)
            A_d = alpha * A_d_zoh + (1.0 - alpha) * A_d_euler
            # State update + per-timestep output
            h = A_d * h + dt_t * B_t * u_t.unsqueeze(-1)  # (B, d_inner, d_state)
            y_t = (C_t * h).sum(dim=-1) + self.D * u_t    # (B, d_inner)
outputs.append(y_t)
y = torch.stack(outputs, dim=1)
# Apply gating and project back
y = y * F.silu(z)
return self.out_proj(y)
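
# Illustrative shape check (an added sketch, not part of the original file):
# the layer maps (batch, seq_len, d_model) -> (batch, seq_len, d_model).
# Batch size 2 and window length 60 below are assumptions for the demo only.
def _demo_ssm_layer() -> None:
    """Run a random tensor through one SSM layer and verify its shape."""
    layer = Mamba3SSMLayer(d_model=64)
    out = layer(torch.randn(2, 60, 64))
    assert out.shape == (2, 60, 64)
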
class Mamba3Block(nn.Module):
"""One complete Mamba3 processing block: LayerNorm -> SSM -> LayerNorm -> FFN."""
def __init__(self, d_model: int, d_state: int = 32):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.ssm = Mamba3SSMLayer(d_model, d_state)
self.norm2 = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_model * 4),
nn.GELU(),
nn.Linear(d_model * 4, d_model),
nn.Dropout(p=0.1)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.ssm(self.norm1(x))
x = x + self.ffn(self.norm2(x))
return x
class MamaGuardMamba3(nn.Module):
"""
Complete MamaGuard model.
Flow: raw vitals (6) -> embed -> 4 Mamba3 blocks -> pool -> classify (3 classes)
"""
def __init__(
self,
input_dim: int = 6,
d_model: int = 64,
n_layers: int = 4,
n_classes: int = 3,
d_state: int = 32,
):
super().__init__()
self.input_proj = nn.Sequential(
nn.Linear(input_dim, d_model),
nn.LayerNorm(d_model),
)
self.blocks = nn.ModuleList([
Mamba3Block(d_model, d_state) for _ in range(n_layers)
])
self.norm_out = nn.LayerNorm(d_model)
self.classifier = nn.Sequential(
nn.Linear(d_model, d_model // 2),
nn.GELU(),
nn.Dropout(p=0.2),
nn.Linear(d_model // 2, n_classes)
)
def forward(self, x: torch.Tensor, return_features: bool = False):
"""
x: (batch_size, seq_len, input_dim)
Returns: logits (batch_size, n_classes)
"""
x = self.input_proj(x)
for block in self.blocks:
x = block(x)
x = self.norm_out(x)
features = x.mean(dim=1) # global average pool over time
logits = self.classifier(features)
if return_features:
return logits, features
return logits
def predict_proba(self, x: torch.Tensor):
"""Returns probabilities (after softmax) instead of logits."""
with torch.no_grad():
logits = self.forward(x)
return F.softmax(logits, dim=-1)
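
# Hedged smoke test (an added sketch): running this module directly exercises
# the full pipeline, raw vitals (batch, seq_len, 6) -> 3-class logits. The
# batch size and window length are illustrative assumptions, not SheGuard
# constants.
if __name__ == "__main__":
    _demo_ssm_layer()
    model = MamaGuardMamba3()
    model.eval()  # disable dropout so predict_proba is deterministic
    vitals = torch.randn(2, 60, 6)
    logits = model(vitals)
    probs = model.predict_proba(vitals)
    print("logits:", tuple(logits.shape))        # (2, 3)
    print("probs row sums:", probs.sum(dim=-1))  # each ~1.0 after softmax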