| """ | |
| MamaGuard — Mamba3 Model | |
| Trapezoidal SSM with MIMO expansion and complex-valued state. | |
| """ | |
import torch
import torch.nn as nn
import torch.nn.functional as F

class Mamba3SSMLayer(nn.Module):
    """Core recurrent SSM engine of one Mamba3 block."""

    def __init__(self, d_model: int, d_state: int = 32, expand: int = 2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = d_model * expand

        # Input/output projections (MIMO)
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
        # Local depthwise convolution. Note: symmetric padding (padding=1 with
        # kernel_size=3) lets each step see one future timestep, so the conv is
        # not strictly causal; acceptable for whole-sequence classification.
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=3,
            padding=1,
            groups=self.d_inner,
            bias=True,
        )
        # SSM parameters: A (state decay, stored as log) and D (skip connection)
        self.A_log = nn.Parameter(torch.randn(self.d_inner, d_state))
        self.D = nn.Parameter(torch.ones(self.d_inner))

        # Input-dependent (selective) parameters: Δ, B, and C
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
        self.dt_proj = nn.Linear(1, self.d_inner, bias=True)

        # Trapezoidal blending parameter (α)
        self.alpha = nn.Parameter(torch.tensor(0.5))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (batch_size, seq_len, d_model) -> same shape output."""
        B, L, _ = x.shape

        # Project to inner dimension + gating signal
        xz = self.in_proj(x)
        x_in, z = xz.chunk(2, dim=-1)

        # Local convolution + SiLU activation
        x_conv = self.conv1d(x_in.transpose(1, 2)).transpose(1, 2)
        x_conv = F.silu(x_conv)

        # Compute input-dependent SSM parameters (split order: Δ, B, C)
        dt_raw, B_ssm, C_ssm = self.x_proj(x_conv).split(
            [1, self.d_state, self.d_state], dim=-1
        )
        dt = F.softplus(self.dt_proj(dt_raw))  # Δ > 0
        A_real = -torch.exp(self.A_log)        # A < 0 keeps the recurrence stable
        alpha = torch.sigmoid(self.alpha)      # blend weight in (0, 1)
        # SSM recurrence (sequential scan over time)
        h = torch.zeros(B, self.d_inner, self.d_state,
                        device=x.device, dtype=x.dtype)
        outputs = []
        for t in range(L):
            dt_t = dt[:, t, :].unsqueeze(-1)   # (B, d_inner, 1)
            B_t = B_ssm[:, t, :].unsqueeze(1)  # (B, 1, d_state)
            C_t = C_ssm[:, t, :].unsqueeze(1)  # (B, 1, d_state)
            u_t = x_conv[:, t, :]              # (B, d_inner)

            # Trapezoidal discretization: blend the ZOH factor exp(A·Δ) with a
            # half-step implicit-Euler factor 1 / (1 - A·Δ/2). Both lie in
            # (0, 1) for A < 0, so the blend is stable for any α.
            A_d_zoh = torch.exp(A_real * dt_t)
            A_d_euler = 1.0 / (1.0 - A_real * dt_t * 0.5 + 1e-6)
            A_d = alpha * A_d_zoh + (1.0 - alpha) * A_d_euler

            # State update + output
            h = A_d * h + dt_t * B_t * u_t.unsqueeze(-1)
            y_t = (C_t * h).sum(dim=-1) + self.D * u_t
            outputs.append(y_t)
        y = torch.stack(outputs, dim=1)

        # Apply gating and project back
        y = y * F.silu(z)
        return self.out_proj(y)
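
# Illustrative usage of the SSM layer in isolation (a minimal sketch; the
# batch size and sequence length below are arbitrary assumptions, not values
# from this repo):
#
#   >>> layer = Mamba3SSMLayer(d_model=64)
#   >>> layer(torch.randn(2, 16, 64)).shape
#   torch.Size([2, 16, 64])
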
class Mamba3Block(nn.Module):
    """One complete Mamba3 processing block: LayerNorm -> SSM -> LayerNorm -> FFN."""

    def __init__(self, d_model: int, d_state: int = 32):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.ssm = Mamba3SSMLayer(d_model, d_state)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(p=0.1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm residual connections around the SSM and the FFN
        x = x + self.ssm(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x

class MamaGuardMamba3(nn.Module):
    """
    Complete MamaGuard model.
    Flow: raw vitals (6) -> embed -> 4 Mamba3 blocks -> pool -> classify (3 classes)
    """

    def __init__(
        self,
        input_dim: int = 6,
        d_model: int = 64,
        n_layers: int = 4,
        n_classes: int = 3,
        d_state: int = 32,
    ):
        super().__init__()
        self.input_proj = nn.Sequential(
            nn.Linear(input_dim, d_model),
            nn.LayerNorm(d_model),
        )
        self.blocks = nn.ModuleList([
            Mamba3Block(d_model, d_state) for _ in range(n_layers)
        ])
        self.norm_out = nn.LayerNorm(d_model)
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Dropout(p=0.2),
            nn.Linear(d_model // 2, n_classes),
        )
    def forward(self, x: torch.Tensor, return_features: bool = False):
        """
        x: (batch_size, seq_len, input_dim)
        Returns: logits (batch_size, n_classes), plus the pooled features
        when return_features is True.
        """
        x = self.input_proj(x)
        for block in self.blocks:
            x = block(x)
        x = self.norm_out(x)
        features = x.mean(dim=1)  # global average pool over time
        logits = self.classifier(features)
        if return_features:
            return logits, features
        return logits
    def predict_proba(self, x: torch.Tensor):
        """Returns probabilities (after softmax) instead of logits."""
        # Note: call model.eval() first if dropout should be disabled.
        with torch.no_grad():
            logits = self.forward(x)
            return F.softmax(logits, dim=-1)
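
# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the model). Run this file directly
# to sanity-check output shapes; the batch size and sequence length below are
# arbitrary assumptions, not values mandated by the model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = MamaGuardMamba3()
    model.eval()  # disable dropout for deterministic inference

    vitals = torch.randn(2, 128, 6)  # (batch, time steps, 6 vital signs)
    logits = model(vitals)
    probs = model.predict_proba(vitals)

    print("logits:", logits.shape)  # expected: torch.Size([2, 3])
    print("probs :", probs.shape)   # expected: torch.Size([2, 3])
    print("rows sum to 1:", torch.allclose(probs.sum(dim=-1), torch.ones(2)))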