import torch
import torch.nn as nn
import torch.nn.functional as F


class MambaBlock(nn.Module):
    """Selective state-space (S6) block with a depthwise causal conv front-end."""

    def __init__(self, d_model, d_state=8, d_conv=4, expand=1.5, dropout=0.1):
        super().__init__()
        self.d_inner = int(expand * d_model)
        # Low-rank width for the step-size projection; guarded so it never hits 0
        # for small d_model.
        self.dt_rank = max(1, d_model // 16)
        self.d_state = d_state
        # One projection produces both the SSM path and the gating residual.
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        # Depthwise conv; left padding of d_conv-1 keeps it causal once the
        # output is trimmed back to the input length in forward().
        self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, d_conv,
                                groups=self.d_inner, padding=d_conv - 1, bias=True)
        # Maps features to input-dependent dt, B, C (the "selective" part).
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
        # A is stored in log space so that -exp(A_log) is always negative,
        # giving a stable decaying state.
        A = torch.arange(1, self.d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))  # direct feedthrough scale
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
        self.act = nn.SiLU()
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        b, l, _ = x.shape
        # Split the doubled projection into the SSM input and the gating branch.
        x_val, res = self.in_proj(x).split([self.d_inner, self.d_inner], dim=-1)
        # Causal depthwise conv over time, trimmed back to the original length.
        x_val = self.act(self.conv1d(x_val.transpose(1, 2))[..., :l].transpose(1, 2))
        # Input-dependent SSM parameters.
        dt, B, C = self.x_proj(x_val).split([self.dt_rank, self.d_state, self.d_state], dim=-1)
        # Softplus keeps dt positive; the epsilon belongs outside the softplus
        # so it actually enforces a strict floor.
        dt = F.softplus(self.dt_proj(dt)) + 1e-6
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state), all negative
        # Sequential scan over the discretized SSM: h_t = dA * h_{t-1} + dB * x_t.
        h = torch.zeros(b, self.d_inner, self.d_state, device=x.device, dtype=x.dtype)
        y_list = []
        for t in range(l):
            dt_t = dt[:, t, :].unsqueeze(-1)            # (b, d_inner, 1)
            dA = torch.exp(dt_t * A)                    # zero-order-hold discretization
            dB = dt_t * B[:, t, :].unsqueeze(1)         # (b, d_inner, d_state)
            h = dA * h + dB * x_val[:, t, :].unsqueeze(-1)
            y_list.append(torch.sum(h * C[:, t, :].unsqueeze(1), dim=-1))
        y = torch.stack(y_list, dim=1) + x_val * self.D  # D: skip path
        # SiLU-gated output projection.
        return self.drop(self.out_proj(y * self.act(res)))
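
# Quick shape check: a minimal sketch with arbitrary illustrative sizes
# (batch=2, seq_len=32, d_model=64). MambaBlock is shape-preserving, mapping
# (b, l, d_model) -> (b, l, d_model), which is what lets the classifier below
# wrap it in a pre-norm residual connection.
if __name__ == "__main__":
    _blk = MambaBlock(d_model=64)
    _x = torch.randn(2, 32, 64)
    assert _blk(_x).shape == (2, 32, 64), "block should preserve (b, l, d_model)"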


class DroneMambaClassifier(nn.Module):
    """Pre-norm residual stack of MambaBlocks with mean pooling for classification."""

    def __init__(self, input_dim=1, num_classes=10, d_model=256, depth=6):
        super().__init__()
        self.embedding = nn.Linear(input_dim, d_model)
        self.layers = nn.ModuleList([
            nn.ModuleDict({'norm': nn.LayerNorm(d_model), 'block': MambaBlock(d_model)})
            for _ in range(depth)
        ])
        self.norm_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x):
        # x: (batch, seq_len, input_dim), e.g. a univariate signal with input_dim=1.
        x = self.embedding(x)
        for layer in self.layers:
            # Pre-norm residual: x + Block(LayerNorm(x)).
            x = x + layer['block'](layer['norm'](x))
        # Mean-pool over time, then classify.
        return self.head(self.norm_f(x).mean(dim=1))
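
# End-to-end usage sketch. The sizes here are illustrative assumptions, not
# from the original listing: a batch of 4 univariate signals (input_dim=1) of
# length 128, classified into 10 classes, plus one cross-entropy step to show
# that gradients flow through the sequential scan.
if __name__ == "__main__":
    model = DroneMambaClassifier(input_dim=1, num_classes=10, d_model=256, depth=6)
    signal = torch.randn(4, 128, 1)        # (batch, seq_len, input_dim)
    logits = model(signal)                 # -> (4, 10)
    labels = torch.randint(0, 10, (4,))    # dummy targets
    loss = F.cross_entropy(logits, labels)
    loss.backward()                        # backprop through the scan loop
    print(logits.shape, float(loss))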