File size: 2,597 Bytes
47a8bb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import torch.nn as nn
import torch.nn.functional as F

class MambaBlock(nn.Module):
    """Selective state-space (Mamba-style) mixing block with a sequential scan.

    The input is projected into an inner branch and a gate branch (``res``).
    The inner branch goes through a causal depthwise convolution and SiLU, then
    drives a diagonal linear recurrence. The step size ``dt`` and the
    input/output vectors ``B``/``C`` of that recurrence are computed from the
    input itself. The scan output plus a ``D``-weighted skip term is gated by
    ``SiLU(res)``, projected back to ``d_model`` and passed through dropout.

    Args:
        d_model: width of the input and output features.
        d_state: size of the per-channel recurrent state.
        d_conv: kernel size of the causal depthwise convolution.
        expand: ratio of the inner width to ``d_model``.
        dropout: dropout probability applied to the block output.
    """

    def __init__(self, d_model, d_state=8, d_conv=4, expand=1.5, dropout=0.1):
        super().__init__()
        self.d_inner = int(expand * d_model)
        # Low-rank width of the dt projection. Clamped to at least 1: for
        # d_model < 16, plain floor division gives 0. dt_proj would then be a
        # zero-input Linear whose output is just its bias, making dt constant.
        self.dt_rank = max(1, d_model // 16)
        self.d_state = d_state
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        # Depthwise conv padded by d_conv - 1 on both sides; forward() keeps
        # only the first l outputs so each position sees itself and the past.
        self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, d_conv, groups=self.d_inner, padding=d_conv-1, bias=True)
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
        # A is stored as a log so that A = -exp(A_log) is always negative.
        # Every inner channel starts from the row [1, 2, ..., d_state].
        A = torch.arange(1, self.d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
        self.act = nn.SiLU()
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, seq_len, d_model).

        Returns a tensor of the same shape. Output position t depends only on
        input positions <= t.
        """
        b, l, _ = x.shape
        x_and_res = self.in_proj(x)
        x_val, res = x_and_res.split([self.d_inner, self.d_inner], dim=-1)
        # Conv1d expects (batch, channels, time). Keeping the first l outputs
        # drops the ones produced by the right-hand padding (causal conv).
        x_val = self.act(self.conv1d(x_val.transpose(1, 2))[..., :l].transpose(1, 2))
        x_dbl = self.x_proj(x_val)
        dt, B, C = x_dbl.split([self.dt_rank, self.d_state, self.d_state], dim=-1)
        dt = F.softplus(self.dt_proj(dt) + 1e-6)  # (b, l, d_inner), input-dependent step
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state), always float32
        h = torch.zeros(b, self.d_inner, self.d_state, device=x.device)
        y_list = []
        # Sequential scan, one step per time position:
        #   h_t = exp(dt_t * A) * h_{t-1} + (dt_t * B_t) * x_t
        #   y_t = sum over the state axis of h_t * C_t
        for t in range(l):
            dt_t = dt[:, t, :].unsqueeze(-1)
            dA, dB = torch.exp(dt_t * A), dt_t * B[:, t, :].unsqueeze(1)
            h = dA * h + dB * x_val[:, t, :].unsqueeze(-1)
            y_list.append(torch.sum(h * C[:, t, :].unsqueeze(1), dim=-1))
        # D is a learned per-channel skip connection around the scan.
        y = torch.stack(y_list, dim=1) + x_val * self.D
        gated = y * self.act(res)
        # A is float32, so the scan promotes y to at least float32. Cast to
        # out_proj's dtype so the block still runs after .half() or
        # .to(torch.bfloat16); the cast is a no-op for float32/float64 modules.
        return self.drop(self.out_proj(gated.to(self.out_proj.weight.dtype)))

class DroneMambaClassifier(nn.Module):
    """Sequence classifier built from a stack of residual Mamba blocks.

    Pipeline: a linear embedding of each time step, then ``depth`` pre-norm
    residual layers (LayerNorm followed by MambaBlock, added back to the
    stream), then a final LayerNorm. The result is averaged over the time
    axis and mapped to class logits by a linear head.

    Args:
        input_dim: size of the last axis of the input features.
        num_classes: number of output logits.
        d_model: hidden width shared by every layer.
        depth: number of residual Mamba layers.
    """

    def __init__(self, input_dim=1, num_classes=10, d_model=256, depth=6):
        super().__init__()
        self.embedding = nn.Linear(input_dim, d_model)
        stages = []
        for _ in range(depth):
            # The 'norm' / 'block' keys are part of the state_dict layout.
            stages.append(nn.ModuleDict({
                'norm': nn.LayerNorm(d_model),
                'block': MambaBlock(d_model),
            }))
        self.layers = nn.ModuleList(stages)
        self.norm_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes).

        ``x`` is (batch, seq_len, input_dim); MambaBlock unpacks exactly three
        dimensions after the embedding.
        """
        stream = self.embedding(x)
        for stage in self.layers:
            normed = stage['norm'](stream)
            stream = stream + stage['block'](normed)
        pooled = self.norm_f(stream).mean(dim=1)  # average over time steps
        return self.head(pooled)