## Model Structure

```python
import torch
import torch.nn as nn
from mamba_ssm import Mamba


class MambaTransformerSimple(nn.Module):
    def __init__(
        self,
        d_feat: int = 8,
        hidden_size: int = 64,
        num_layers: int = 1,
        dropout: float = 0.0,
        noise_level: float = 0.0,  # accepted for config compatibility; not used below
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
        mask_type: str = "none",
    ) -> None:
        super().__init__()
        self.mask_type = mask_type
        self.transformer_encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=4,
            dim_feedforward=hidden_size * 4,
            dropout=dropout,
            activation="relu",
            batch_first=False,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            self.transformer_encoder_layer, num_layers=num_layers
        )
        self.input_proj = nn.Linear(d_feat, hidden_size)
        self.mamba = Mamba(
            d_model=hidden_size, d_state=d_state, d_conv=d_conv, expand=expand
        )
        self.mid_norm = nn.LayerNorm(hidden_size)
        self.out = nn.Sequential(
            nn.Linear(hidden_size, hidden_size), nn.GELU(), nn.Linear(hidden_size, 1)
        )

    def _generate_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Generate a causal attention mask (-inf above the diagonal, 0 elsewhere)."""
        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=device) * float("-inf"), diagonal=1
        )
        return mask

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, time, stocks, features]
        b, t, s, f = x.shape
        x = x.permute(0, 2, 1, 3).reshape(b * s, t, f)  # fold stocks into the batch dim
        x = self.input_proj(x)  # [b * s, t, h]
        mamba_out = self.mamba(x)  # [b * s, t, h]
        mamba_out = mamba_out.permute(1, 0, 2).contiguous()  # [t, b * s, h]
        mamba_out = self.mid_norm(mamba_out)
        if self.mask_type == "causal":
            mask = self._generate_causal_mask(t, x.device)
        else:
            mask = None
        tfm_out = self.transformer_encoder(mamba_out, mask=mask)  # [t, b * s, h]
        tfm_out = tfm_out[-1].reshape(b, s, -1)  # take the last time step: [b, s, h]
        final_out = self.out(tfm_out).squeeze(-1)  # [b, s]
        return final_out
```

## Model Config

```yaml
num_layers: 1
d_feat: 8
hidden_size: 64
d_state: 16
d_conv: 4
expand: 2
dropout: 0.1
noise_level: 0.0
mask_type: "none"
```
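As a quick sanity check, the sketch below instantiates the model with the config values above and runs a single forward pass on random data shaped `[batch, time, stocks, features]`. The batch size, lookback length, and stock count are illustrative assumptions, not values from the original setup, and the `mamba_ssm` selective-scan kernels generally require a CUDA device.

```python
# Minimal smoke test, assuming mamba_ssm is installed and a CUDA device is available.
# The sizes b=4, t=20, s=300 are illustrative only.
import torch

device = torch.device("cuda")
model = MambaTransformerSimple(
    d_feat=8,
    hidden_size=64,
    num_layers=1,
    dropout=0.1,
    noise_level=0.0,
    d_state=16,
    d_conv=4,
    expand=2,
    mask_type="none",
).to(device)

x = torch.randn(4, 20, 300, 8, device=device)  # [batch, time, stocks, features]
with torch.no_grad():
    scores = model(x)
print(scores.shape)  # torch.Size([4, 300]) -- one score per stock per batch element
```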