"""
MamaGuard — Mamba3 Model
Trapezoidal SSM with MIMO expansion and complex-valued state.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class Mamba3SSMLayer(nn.Module):
    """Core recurrent SSM engine of one Mamba3 block."""

    def __init__(self, d_model: int, d_state: int = 32, expand: int = 2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = d_model * expand

        # Input/output projections (MIMO)
        self.in_proj  = nn.Linear(d_model, self.d_inner * 2, bias=False)
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

        # Local depthwise convolution, made causal: pad kernel_size - 1 on
        # each side (Conv1d pads symmetrically), then trim the right-side
        # lookahead in forward()
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=3,
            padding=2,
            groups=self.d_inner,
            bias=True
        )

        # SSM parameters: A = -exp(A_log) keeps the diagonal state matrix
        # strictly negative (stable); D is a learned skip/feedthrough term
        self.A_log = nn.Parameter(torch.randn(self.d_inner, d_state))
        self.D = nn.Parameter(torch.ones(self.d_inner))

        # Input-dependent (selective) parameters: a scalar Δ seed plus B and C
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
        # Expand the scalar Δ seed to one step size per inner channel
        self.dt_proj = nn.Linear(1, self.d_inner, bias=True)

        # Trapezoidal blending parameter (α)
        self.alpha = nn.Parameter(torch.tensor(0.5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (batch_size, seq_len, d_model) -> same shape output."""
        B, L, _ = x.shape

        # Project to inner dimension + gating signal
        xz = self.in_proj(x)
        x_in, z = xz.chunk(2, dim=-1)

        # Causal local convolution + SiLU activation (trim the right-side
        # padding so position t never sees inputs beyond t)
        x_conv = self.conv1d(x_in.transpose(1, 2))[..., :L].transpose(1, 2)
        x_conv = F.silu(x_conv)

        # Compute input-dependent SSM parameters
        dt_raw, B_ssm, C_ssm = self.x_proj(x_conv).split(
            [1, self.d_state, self.d_state], dim=-1
        )

        dt = F.softplus(self.dt_proj(dt_raw))
        A_real = -torch.exp(self.A_log)
        alpha = torch.sigmoid(self.alpha)

        # SSM recurrence
        h = torch.zeros(B, self.d_inner, self.d_state,
                        device=x.device, dtype=x.dtype)
        outputs = []

        for t in range(L):
            dt_t = dt[:, t, :].unsqueeze(-1)
            B_t  = B_ssm[:, t, :].unsqueeze(1)
            C_t  = C_ssm[:, t, :].unsqueeze(1)
            u_t  = x_conv[:, t, :]

            # Blended discretization: ZOH (exp(A·Δ)) mixed with the trapezoidal
            # resolvent 1/(1 - A·Δ/2) through the learned gate α. (This is not
            # implicit Euler, which would use 1/(1 - A·Δ).) A < 0 makes the
            # denominator ≥ 1, so the 1e-6 guard is purely defensive.
            A_d_zoh  = torch.exp(A_real * dt_t)
            A_d_trap = 1.0 / (1.0 - A_real * dt_t * 0.5 + 1e-6)
            A_d = alpha * A_d_zoh + (1.0 - alpha) * A_d_trap

            # State update + output
            h = A_d * h + dt_t * B_t * u_t.unsqueeze(-1)
            y_t = (C_t * h).sum(dim=-1) + self.D * u_t
            outputs.append(y_t)

        y = torch.stack(outputs, dim=1)

        # Apply gating and project back
        y = y * F.silu(z)
        return self.out_proj(y)
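

# For reference, the per-step update implemented in Mamba3SSMLayer.forward,
# written out as equations (notation here is mine, not from the original
# source). With diagonal A < 0, learned blend α ∈ (0, 1), and input-dependent
# step size Δ_t, input map B_t, and readout C_t:
#
#     Ā_t = α·exp(A Δ_t) + (1 − α) / (1 − A Δ_t / 2)    (blended discretization)
#     h_t = Ā_t ⊙ h_{t−1} + Δ_t B_t u_t                  (state update)
#     y_t = ⟨C_t, h_t⟩ + D ⊙ u_t                         (readout with skip D)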


class Mamba3Block(nn.Module):
    """One complete Mamba3 processing block: LayerNorm -> SSM -> LayerNorm -> FFN."""

    def __init__(self, d_model: int, d_state: int = 32):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.ssm   = Mamba3SSMLayer(d_model, d_state)
        self.norm2 = nn.LayerNorm(d_model)

        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(p=0.1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.ssm(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x


class MamaGuardMamba3(nn.Module):
    """
    Complete MamaGuard model.
    Flow: raw vitals (input_dim) -> embed -> n_layers Mamba3 blocks -> mean-pool -> classify (n_classes).
    Defaults: 6 vitals, 4 blocks, 3 classes.
    """

    def __init__(
        self,
        input_dim:  int = 6,
        d_model:    int = 64,
        n_layers:   int = 4,
        n_classes:  int = 3,
        d_state:    int = 32,
    ):
        super().__init__()

        self.input_proj = nn.Sequential(
            nn.Linear(input_dim, d_model),
            nn.LayerNorm(d_model),
        )

        self.blocks = nn.ModuleList([
            Mamba3Block(d_model, d_state) for _ in range(n_layers)
        ])

        self.norm_out = nn.LayerNorm(d_model)

        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Dropout(p=0.2),
            nn.Linear(d_model // 2, n_classes)
        )

    def forward(self, x: torch.Tensor, return_features: bool = False):
        """
        x: (batch_size, seq_len, input_dim)
        Returns: logits (batch_size, n_classes)
        """
        x = self.input_proj(x)

        for block in self.blocks:
            x = block(x)

        x = self.norm_out(x)
        features = x.mean(dim=1)       # global average pool over time
        logits = self.classifier(features)

        if return_features:
            return logits, features
        return logits

    def predict_proba(self, x: torch.Tensor):
        """Returns probabilities (after softmax) instead of logits."""
        with torch.no_grad():
            logits = self.forward(x)
            return F.softmax(logits, dim=-1)
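

# ---------------------------------------------------------------------------
# Minimal smoke test — an illustrative sketch, not part of the original
# module. Builds the model with its default sizes, pushes a random batch of
# "vitals" through it, and checks output shapes; the batch/sequence sizes
# below are arbitrary assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    model = MamaGuardMamba3()       # input_dim=6, d_model=64, 4 blocks, 3 classes
    x = torch.randn(8, 120, 6)      # (batch, seq_len, n_vitals)

    logits, feats = model(x, return_features=True)
    assert logits.shape == (8, 3)
    assert feats.shape == (8, 64)

    model.eval()                    # disable dropout before estimating probabilities
    probs = model.predict_proba(x)
    assert torch.allclose(probs.sum(dim=-1), torch.ones(8), atol=1e-5)

    print("logits:", tuple(logits.shape), "probs:", tuple(probs.shape))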