Create README.md
Browse files
README.md
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Model Structure
|
| 2 |
+
|
| 3 |
+
```python
|
| 4 |
+
class MambaComp(nn.Module):
    """Stacked-Mamba encoder that returns the features of the last time step.

    The input is a 4-D tensor unpacked as ``[b, t, s, f]``. Each of the
    ``b * s`` slices is processed as an independent length-``t`` sequence:
    it is linearly embedded to ``d_model``, layer-normalised, passed through
    ``e_layers`` Mamba blocks in order, and only the final time step is kept.

    Args:
        enc_in: Size of the last input axis ``f``.
        c_out: Stored as ``output_size``; not otherwise used by this class.
        e_layers: Number of stacked Mamba blocks.
        noise_level: Scale of the Gaussian noise added to the input while
            training; ``0`` disables it.
        d_model: Embedding width, also the width of the returned features.
        d_ff: Accepted for signature compatibility; unused by this class.
        d_state: Forwarded to ``Mamba(d_state=...)``.
        d_conv: Forwarded to ``Mamba(d_conv=...)``.
        expand: Forwarded to ``Mamba(expand=...)``.
        dropout: Dropout probability applied to the raw input.
    """

    def __init__(
        self,
        enc_in: int,
        c_out: int,
        e_layers: int,
        noise_level: float,
        d_model: int,
        d_ff: int,
        d_state: int,
        d_conv: int,
        expand: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()

        self.input_drop = nn.Dropout(dropout)

        self.input_size = enc_in
        self.output_size = c_out
        self.num_layers = e_layers
        self.noise_level = noise_level

        self.mamba = nn.ModuleList(
            [
                Mamba(
                    d_model=d_model,  # Model dimension
                    d_state=d_state,  # SSM state expansion factor (e.g. 16)
                    d_conv=d_conv,  # Local convolution width (e.g. 4)
                    expand=expand,  # Block expansion factor (e.g. 2)
                )
                for _ in range(self.num_layers)
            ]
        )

        self.in_layer = nn.Linear(self.input_size, d_model)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` of shape ``[b, t, s, f]`` into ``[b, s, d_model]``."""
        b, t, s, f = x.shape
        x = self.input_drop(x)
        # Fold the ``s`` axis into the batch so every slice is an independent
        # sequence of length ``t``.
        x = x.permute(0, 2, 1, 3).reshape(b * s, t, f)

        # Noise augmentation is a training-only regulariser.
        if self.training and self.noise_level > 0:
            # randn_like already matches x's dtype and device.
            x = x + torch.randn_like(x) * self.noise_level

        x = self.in_layer(x)  # [b*s, t, d_model]
        x = self.layer_norm(x)

        for layer in self.mamba:
            x = layer(x)  # [b*s, t, d_model]

        # Keep only the final time step and restore the ``s`` axis.
        return x[:, -1, :].reshape(b, s, -1)  # [b, s, d_model]
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class Mambav1(nn.Module):
    """Mamba encoder followed by a linear projection head.

    Wraps ``MambaComp`` (which yields ``[b, s, d_model]`` features) and
    projects each feature vector to ``c_out`` values.

    Args:
        enc_in: Size of the last input axis ``f``.
        c_out: Number of outputs produced per ``[b, s]`` position.
        e_layers: Number of stacked Mamba blocks in the encoder.
        noise_level: Training-time input noise scale, passed to the encoder.
        d_model: Encoder feature width and input width of the projection.
        d_ff: Passed through to ``MambaComp``.
        d_state: Passed through to ``MambaComp``.
        d_conv: Passed through to ``MambaComp``.
        expand: Passed through to ``MambaComp``.
        dropout: Input dropout probability, passed to the encoder.
    """

    def __init__(
        self,
        enc_in: int,
        c_out: int,
        e_layers: int,
        noise_level: float,
        d_model: int,
        d_ff: int,
        d_state: int,
        d_conv: int,
        expand: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()

        self.input_drop = nn.Dropout(dropout)

        self.input_size = enc_in
        self.output_size = c_out
        self.num_layers = e_layers
        self.noise_level = noise_level

        self.mamba = MambaComp(
            enc_in=self.input_size,
            c_out=self.output_size,
            e_layers=self.num_layers,
            noise_level=self.noise_level,
            d_model=d_model,
            d_ff=d_ff,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            dropout=dropout,
        )
        self.projection = nn.Linear(d_model, c_out, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape ``[b, t, s, f]`` to per-position outputs.

        Returns:
            ``[b, s]`` when ``c_out == 1`` (the trailing singleton axis is
            dropped), otherwise ``[b, s, c_out]``.
        """
        mamba_out = self.mamba(x)  # [b, s, d_model]
        out = self.projection(mamba_out)  # [b, s, c_out]
        # squeeze(-1) only removes the last axis when it has size 1, so this
        # matches the single-output case without failing for c_out > 1.
        return out.squeeze(-1)
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
## Model Config
|
| 112 |
+
|
| 113 |
+
```yaml
|
| 114 |
+
e_layers: 1
|
| 115 |
+
enc_in: 8
|
| 116 |
+
c_out: 1
|
| 117 |
+
d_model: 64
|
| 118 |
+
d_ff: 64
|
| 119 |
+
d_state: 16
|
| 120 |
+
d_conv: 4
|
| 121 |
+
expand: 2
|
| 122 |
+
dropout: 0.1
|
| 123 |
+
noise_level: 0.0
|
| 124 |
+
```
|