## Model Structure

```python
class MambaTransformerSimple(nn.Module):
    """Hybrid sequence model: a Mamba SSM block followed by a Transformer encoder.

    Expects 4-D input of shape [b, t, s, f]; each of the b*s length-t
    sequences is encoded independently, and the representation at the final
    timestep is projected to a single scalar, giving a [b, s] output.
    (Presumably b=batch, t=time, s=instruments, f=features — TODO confirm
    against the caller.)
    """

    def __init__(
        self,
        d_feat: int = 8,
        hidden_size: int = 64,
        num_layers: int = 1,
        dropout: float = 0.0,
        noise_level: float = 0.0,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
        mask_type: str = "none",
    ) -> None:
        """Build the Mamba + Transformer stack.

        Args:
            d_feat: Number of input features per timestep (last input dim).
            hidden_size: Shared model width (d_model) for both the Mamba
                block and the Transformer encoder.
            num_layers: Number of Transformer encoder layers.
            dropout: Dropout probability inside the Transformer layers.
            noise_level: NOTE(review): accepted but never stored or read in
                this class — presumably consumed by an external training
                loop; confirm before relying on it.
            d_state: Mamba SSM state dimension.
            d_conv: Mamba local convolution width.
            expand: Mamba block expansion factor.
            mask_type: "causal" applies a causal attention mask in
                forward(); any other value disables masking.
        """
        super().__init__()
        self.mask_type = mask_type
        # Encoder runs in time-major [t, batch, h] layout (batch_first=False).
        self.transformer_encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=4,  # fixed head count; hidden_size must be divisible by 4
            dim_feedforward=hidden_size * 4,
            dropout=dropout,
            activation="relu",
            batch_first=False,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            self.transformer_encoder_layer, num_layers=num_layers
        )
        # Project raw features up to the shared model width.
        self.input_proj = nn.Linear(d_feat, hidden_size)
        self.mamba = Mamba(
            d_model=hidden_size, d_state=d_state, d_conv=d_conv, expand=expand
        )
        # LayerNorm between the Mamba block and the Transformer encoder.
        self.mid_norm = nn.LayerNorm(hidden_size)
        # Two-layer MLP head projecting to one scalar per sequence.
        self.out = nn.Sequential(
            nn.Linear(hidden_size, hidden_size), nn.GELU(), nn.Linear(hidden_size, 1)
        )

    def _generate_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Generate causal attention mask.

        Returns an additive [seq_len, seq_len] mask: -inf strictly above
        the diagonal (future positions), 0 on and below it, suitable for
        the encoder's ``mask=`` argument.
        """
        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=device) * float("-inf"), diagonal=1
        )
        return mask

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score each of the s series: [b, t, s, f] -> [b, s]."""
        b, t, s, f = x.shape
        # Fold the s axis into the batch so each series is encoded alone.
        x = x.permute(0, 2, 1, 3).reshape(b * s, t, f)
        x = self.input_proj(x)  # [b * s, t, h]
        mamba_out = self.mamba(x)  # [b * s, t, h]
        # Transformer expects time-major input (batch_first=False above).
        mamba_out = mamba_out.permute(1, 0, 2).contiguous()  # [t, b * s, h]
        mamba_out = self.mid_norm(mamba_out)
        if self.mask_type == "causal":
            mask = self._generate_causal_mask(t, x.device)
        else:
            mask = None
        tfm_out = self.transformer_encoder(mamba_out, mask=mask)  # [t, b * s, h]
        # Keep only the final timestep's representation per series.
        tfm_out = tfm_out[-1].reshape(b, s, -1)
        final_out = self.out(tfm_out).squeeze(-1)  # [b, s]
        return final_out
```

## Model Config

```yaml
num_layers: 1
d_feat: 8
hidden_size: 64
d_state: 16
d_conv: 4
expand: 2
dropout: 0.1
noise_level: 0.0
mask_type: "none"
```