import torch
import torch.nn as nn
import torch.nn.functional as F


class MambaBlock(nn.Module):
    """Selective state-space (Mamba-style) block.

    Pipeline: input projection into an SSM branch and a gating branch,
    depthwise causal convolution, a sequential selective scan, SiLU gating,
    and an output projection back to ``d_model``.

    Args:
        d_model: input/output feature dimension.
        d_state: SSM state size N per inner channel.
        d_conv: depthwise convolution kernel width.
        expand: inner-width expansion factor; ``d_inner = int(expand * d_model)``.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, d_model, d_state=8, d_conv=4, expand=1.5, dropout=0.1):
        super().__init__()
        self.d_inner = int(expand * d_model)
        # FIX: clamp to at least 1. The original `d_model // 16` becomes 0
        # for d_model < 16, producing a degenerate nn.Linear with zero input
        # features. Identical to the original for every d_model >= 16.
        self.dt_rank = max(1, d_model // 16)
        self.d_state = d_state

        # One projection feeds both the SSM path and the gate/residual path.
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        # Depthwise (groups == channels) conv; padding=d_conv-1 pads the left
        # enough for causality, and forward() trims the right-side overhang.
        self.conv1d = nn.Conv1d(
            self.d_inner,
            self.d_inner,
            d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            bias=True,
        )
        # Produces the input-dependent, rank-reduced dt plus B and C per step.
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # S4D-real style init: A_n = -(n+1); the log of the magnitude is the
        # learnable parameter, negated via exp in forward() to keep A < 0.
        A = torch.arange(1, self.d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        # Per-channel skip-connection scale.
        self.D = nn.Parameter(torch.ones(self.d_inner))

        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
        self.act = nn.SiLU()
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the block.

        Args:
            x: tensor of shape (batch, length, d_model).

        Returns:
            Tensor of shape (batch, length, d_model).
        """
        b, l, _ = x.shape

        x_and_res = self.in_proj(x)
        x_val, res = x_and_res.split([self.d_inner, self.d_inner], dim=-1)

        # Causal depthwise conv over time; slice [..., :l] drops the
        # right-side padding overhang so the output length stays l.
        x_val = self.act(self.conv1d(x_val.transpose(1, 2))[..., :l].transpose(1, 2))

        x_dbl = self.x_proj(x_val)
        dt, B, C = x_dbl.split([self.dt_rank, self.d_state, self.d_state], dim=-1)
        dt = F.softplus(self.dt_proj(dt) + 1e-6)  # strictly positive step sizes
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state), all negative

        # FIX: match the activation dtype explicitly; the original defaulted
        # to float32, which mismatches x_val under autocast/half precision.
        h = torch.zeros(b, self.d_inner, self.d_state, device=x.device, dtype=x_val.dtype)
        y_list = []
        # Sequential selective scan (ZOH discretization):
        #   h_t = exp(dt_t * A) * h_{t-1} + (dt_t * B_t) * x_t
        #   y_t = <h_t, C_t>
        for t in range(l):
            dt_t = dt[:, t, :].unsqueeze(-1)             # (b, d_inner, 1)
            dA = torch.exp(dt_t * A)                     # (b, d_inner, d_state)
            dB = dt_t * B[:, t, :].unsqueeze(1)          # (b, d_inner, d_state)
            h = dA * h + dB * x_val[:, t, :].unsqueeze(-1)
            y_list.append(torch.sum(h * C[:, t, :].unsqueeze(1), dim=-1))  # (b, d_inner)

        # D adds a direct input->output path alongside the scan output.
        y = torch.stack(y_list, dim=1) + x_val * self.D
        # SiLU-gate with the residual branch, then project back to d_model.
        return self.drop(self.out_proj(y * self.act(res)))


class DroneMambaClassifier(nn.Module):
    """Sequence classifier built from a pre-norm residual stack of MambaBlocks.

    Pipeline: linear embedding -> depth x (x + Block(LN(x))) -> final
    LayerNorm -> mean pooling over the time axis -> linear head.

    Args:
        input_dim: per-timestep feature dimension of the raw input.
        num_classes: number of output classes.
        d_model: model width fed to each MambaBlock.
        depth: number of stacked blocks.
    """

    def __init__(self, input_dim=1, num_classes=10, d_model=256, depth=6):
        super().__init__()
        self.embedding = nn.Linear(input_dim, d_model)
        self.layers = nn.ModuleList([
            nn.ModuleDict({'norm': nn.LayerNorm(d_model), 'block': MambaBlock(d_model)})
            for _ in range(depth)
        ])
        self.norm_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Classify a batch of sequences.

        Args:
            x: tensor of shape (batch, length, input_dim).

        Returns:
            Logits of shape (batch, num_classes).
        """
        x = self.embedding(x)
        for layer in self.layers:
            # Pre-norm residual: normalize, transform, add back.
            x = x + layer['block'](layer['norm'](x))
        return self.head(self.norm_f(x).mean(dim=1))