Abner0803 commited on
Commit
cdfb1ae
·
verified ·
1 Parent(s): db2bbec

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +124 -0
README.md ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
## Model Structure

```python
class MambaComp(nn.Module):
    """Stacked Mamba encoder over per-series windows.

    Input of shape [b, t, s, f] (batch, time, series, features) is folded to
    [b*s, t, f] so each series is encoded independently, projected to
    d_model, passed through e_layers Mamba blocks, and the last time step is
    returned as a [b, s, d_model] embedding.
    """

    def __init__(
        self,
        enc_in: int,
        c_out: int,
        e_layers: int,
        noise_level: float,
        d_model: int,
        d_ff: int,  # NOTE(review): unused here; kept for config compatibility
        d_state: int,
        d_conv: int,
        expand: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()

        self.input_drop = nn.Dropout(dropout)

        self.input_size = enc_in
        # Stored but unused in this module; the output projection lives in Mambav1.
        self.output_size = c_out
        self.num_layers = e_layers
        self.noise_level = noise_level

        self.mamba = nn.ModuleList(
            [
                Mamba(
                    d_model=d_model,  # model dimension
                    d_state=d_state,  # SSM state expansion factor (e.g. 16)
                    d_conv=d_conv,    # local convolution width (e.g. 4)
                    expand=expand,    # block expansion factor (e.g. 2)
                )
                for _ in range(self.num_layers)
            ]
        )

        self.in_layer = nn.Linear(self.input_size, d_model)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode x of shape [b, t, s, f] into [b, s, d_model]."""
        b, t, s, f = x.shape
        x = self.input_drop(x)
        # Fold the series axis into the batch dimension: [b*s, t, f].
        x = x.permute(0, 2, 1, 3).reshape(b * s, t, f)

        # Additive Gaussian noise regularization (training only).
        # randn_like already matches x's dtype/device, so no extra .to(x) is needed.
        if self.training and self.noise_level > 0:
            x = x + torch.randn_like(x) * self.noise_level

        x = self.in_layer(x)  # [b*s, t, d_model]
        x = self.layer_norm(x)

        for layer in self.mamba:
            x = layer(x)  # [b*s, t, d_model]

        # Keep only the final time step as the per-series summary.
        out = x[:, -1, :].reshape(b, s, -1)  # [b, s, d_model]

        return out  # [b, s, d_model]


class Mambav1(nn.Module):
    """Mamba-based forecaster: MambaComp encoder followed by a linear head.

    Accepts x of shape [b, t, s, f] and returns [b, s] predictions
    (one scalar per series; the config pins c_out == 1).
    """

    def __init__(
        self,
        enc_in: int,
        c_out: int,
        e_layers: int,
        noise_level: float,
        d_model: int,
        d_ff: int,
        d_state: int,
        d_conv: int,
        expand: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()

        self.input_drop = nn.Dropout(dropout)

        self.input_size = enc_in
        self.output_size = c_out
        self.num_layers = e_layers
        self.noise_level = noise_level

        self.mamba = MambaComp(
            enc_in=self.input_size,
            c_out=self.output_size,
            e_layers=self.num_layers,
            noise_level=self.noise_level,
            d_model=d_model,
            d_ff=d_ff,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            dropout=dropout,
        )
        self.projection = nn.Linear(d_model, c_out, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x.shape [b, t, s, f]
        """
        b, _, s, _ = x.shape
        mamba_out = self.mamba(x)  # [b, s, d_model]
        # projection maps the encoder embedding to c_out channels.
        out = self.projection(mamba_out)  # [b, s, c_out] (was mislabeled [b*s, 1])
        # The reshape enforces c_out == 1 (it errors otherwise), then squeeze
        # drops the singleton channel axis.
        out = out.reshape(b, s, 1).squeeze(-1)  # [b, s]

        return out
```

## Model Config

```yaml
e_layers: 1
enc_in: 8
c_out: 1
d_model: 64
d_ff: 64
d_state: 16
d_conv: 4
expand: 2
dropout: 0.1
noise_level: 0.0
```