## Model Structure

```python
import torch
import torch.nn as nn
from mamba_ssm import Mamba  # assumed source of the Mamba block; the d_model/d_state/d_conv/expand signature matches


class MambaTransformerSimple(nn.Module):
    def __init__(
        self,
        d_feat: int = 8,
        hidden_size: int = 64,
        num_layers: int = 1,
        dropout: float = 0.0,
        noise_level: float = 0.0,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
        mask_type: str = "none",
    ) -> None:
        super().__init__()
        self.mask_type = mask_type
        self.transformer_encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=4,
            dim_feedforward=hidden_size * 4,
            dropout=dropout,
            activation="relu",
            batch_first=False,
        )
        # self.transformer_encoder = nn.TransformerEncoder(
        #     self.transformer_encoder_layer, num_layers=num_layers
        # )
        # NOTE: num_layers and noise_level are accepted but unused; the encoder
        # depth is hardcoded to 2 below.
        self.transformer_encoder = nn.TransformerEncoder(
            self.transformer_encoder_layer, num_layers=2
        )
        self.input_proj = nn.Linear(d_feat, hidden_size)
        self.mamba = Mamba(
            d_model=hidden_size, d_state=d_state, d_conv=d_conv, expand=expand
        )
        self.mid_norm = nn.LayerNorm(hidden_size)
        self.out = nn.Sequential(
            nn.Linear(hidden_size, hidden_size), nn.GELU(), nn.Linear(hidden_size, 1)
        )

    def _generate_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Generate causal attention mask (strict upper triangle set to -inf)."""
        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=device) * float("-inf"), diagonal=1
        )
        return mask

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, time, stocks, features]
        b, t, s, f = x.shape
        # Fold the stock dimension into the batch so each series is processed independently.
        x = x.permute(0, 2, 1, 3).reshape(b * s, t, f)
        x = self.input_proj(x)  # [b * s, t, h]
        mamba_out = self.mamba(x)  # [b * s, t, h]
        # The encoder was built with batch_first=False, so move time to the front.
        mamba_out = mamba_out.permute(1, 0, 2).contiguous()  # [t, b * s, h]
        mamba_out = self.mid_norm(mamba_out)

        if self.mask_type == "causal":
            mask = self._generate_causal_mask(t, x.device)
        else:
            mask = None

        tfm_out = self.transformer_encoder(mamba_out, mask=mask)  # [t, b * s, h]
        # Keep only the last time step and predict one score per stock.
        tfm_out = tfm_out[-1].reshape(b, s, -1)
        final_out = self.out(tfm_out).squeeze(-1)  # [b, s]

        return final_out
```
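
For a quick sanity check, here is a minimal forward-pass sketch. The input layout `[batch, time, stocks, features]` is inferred from the `forward` method above, the tensor values are random placeholders, and the official `mamba_ssm` kernels assume a CUDA device:

```python
device = torch.device("cuda")  # mamba_ssm's fused selective-scan kernels are CUDA-only
model = MambaTransformerSimple(d_feat=8, hidden_size=64, mask_type="none").to(device)

# Dummy batch: 2 samples, 30 time steps, 100 stocks, 8 features each.
x = torch.randn(2, 30, 100, 8, device=device)
scores = model(x)
print(scores.shape)  # torch.Size([2, 100]) -- one score per stock
```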

## Model Config

```yaml
num_layers: 1
d_feat: 8
hidden_size: 64
d_state: 16
d_conv: 4
expand: 2
dropout: 0.1
noise_level: 0.0
mask_type: "none"
```
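
Since the keys above match the `__init__` arguments one-to-one, the config can be unpacked straight into the constructor. A minimal sketch, assuming the config is saved to a hypothetical `config.yaml`:

```python
import yaml

with open("config.yaml") as f:  # hypothetical path to the config above
    cfg = yaml.safe_load(f)

model = MambaTransformerSimple(**cfg)
```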