import torch
import torch.nn as nn
import torch.nn.functional as F

class MambaBlock(nn.Module):
    """Simplified Mamba (selective state-space) block with a sequential scan."""

    def __init__(self, d_model, d_state=8, d_conv=4, expand=1.5, dropout=0.1):
        super().__init__()
        self.d_inner = int(expand * d_model)
        self.dt_rank = d_model // 16  # rank of the low-rank Δ (timestep) projection
        self.d_state = d_state
        # One matmul produces both the SSM path and the gating residual.
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        # Depthwise conv over time; padding=d_conv-1 pads both sides, and the
        # crop to length l in forward() makes the convolution causal.
        self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, d_conv,
                                groups=self.d_inner, padding=d_conv - 1, bias=True)
        # Projects features to the input-dependent Δ, B, C ("selective" parameters).
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
        # A is initialized to -(1..d_state) per channel; stored as log to keep it negative.
        A = torch.arange(1, self.d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))  # learned skip-path scale
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
        self.act = nn.SiLU()
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        b, l, _ = x.shape
        # Split the input projection into the SSM branch and the gate branch.
        x_and_res = self.in_proj(x)
        x_val, res = x_and_res.split([self.d_inner, self.d_inner], dim=-1)
        # Depthwise conv over time, cropped back to length l for causality.
        x_val = self.act(self.conv1d(x_val.transpose(1, 2))[..., :l].transpose(1, 2))
        # Input-dependent SSM parameters: Δ (via low-rank dt_proj), B, C.
        x_dbl = self.x_proj(x_val)
        dt, B, C = x_dbl.split([self.dt_rank, self.d_state, self.d_state], dim=-1)
        dt = F.softplus(self.dt_proj(dt) + 1e-6)  # softplus keeps Δ > 0; 1e-6 is a small stabilizer
        A = -torch.exp(self.A_log.float())        # (d_inner, d_state), strictly negative
        # Sequential selective scan over the hidden state h.
        h = torch.zeros(b, self.d_inner, self.d_state, device=x.device, dtype=x_val.dtype)
        y_list = []
        for t in range(l):
            dt_t = dt[:, t, :].unsqueeze(-1)                # (b, d_inner, 1)
            dA = torch.exp(dt_t * A)                        # ZOH-discretized A
            dB = dt_t * B[:, t, :].unsqueeze(1)             # Euler-discretized B
            h = dA * h + dB * x_val[:, t, :].unsqueeze(-1)  # state update
            y_list.append(torch.sum(h * C[:, t, :].unsqueeze(1), dim=-1))  # readout via C
        y = torch.stack(y_list, dim=1) + x_val * self.D  # D adds a learned skip path
        return self.drop(self.out_proj(y * self.act(res)))  # SiLU-gated output projection
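
# NOTE: the per-timestep Python loop above is an O(L) reference implementation of the
# selective scan; the official Mamba kernel replaces it with a fused, hardware-aware
# parallel scan. The recurrence it computes per channel is
#     h_t = exp(Δ_t · A) · h_{t-1} + (Δ_t · B_t) · x_t,    y_t = ⟨C_t, h_t⟩,
# with the D · x skip term added after the scan.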


class DroneMambaClassifier(nn.Module):
    """Stack of pre-norm residual Mamba blocks with mean pooling for sequence classification."""

    def __init__(self, input_dim=1, num_classes=10, d_model=256, depth=6):
        super().__init__()
        self.embedding = nn.Linear(input_dim, d_model)
        self.layers = nn.ModuleList([
            nn.ModuleDict({'norm': nn.LayerNorm(d_model), 'block': MambaBlock(d_model)})
            for _ in range(depth)
        ])
        self.norm_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, num_classes)
    def forward(self, x):
        x = self.embedding(x)  # (b, l, input_dim) -> (b, l, d_model)
        for layer in self.layers:
            x = x + layer['block'](layer['norm'](x))  # pre-norm residual connection
        return self.head(self.norm_f(x).mean(dim=1))  # mean-pool over time, then classify
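

# Minimal usage sketch. Assumptions: a batch of 2 sequences of length 1024 with
# input_dim=1 (e.g. a 1-D RF/acoustic signature); these shapes are illustrative
# and not taken from the original file.
if __name__ == "__main__":
    model = DroneMambaClassifier(input_dim=1, num_classes=10)
    x = torch.randn(2, 1024, 1)  # (batch, seq_len, input_dim)
    logits = model(x)            # (batch, num_classes)
    print(logits.shape)          # expected: torch.Size([2, 10])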