Abner0803 commited on
Commit
b5fde4b
·
verified ·
1 Parent(s): 35fbe61

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +55 -0
README.md ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
## Model Structure

```python
class GRUTransformerSimple(nn.Module):
    """GRU sequence encoder followed by a Transformer encoder.

    Consumes a [batch, time, stock, feature] tensor and produces one
    scalar score per (batch, stock) pair. Each stock's time series is
    encoded independently by folding the stock axis into the batch axis.

    Args:
        d_feat: number of input features per time step.
        hidden_size: GRU/Transformer hidden width. Must be divisible
            by 4, since the encoder layer uses nhead=4.
        num_layers: number of GRU layers and Transformer encoder layers.
        dropout: dropout probability used by both the GRU (only active
            when num_layers > 1) and the Transformer encoder layer.
    """

    def __init__(
        self,
        d_feat: int = 8,
        hidden_size: int = 64,
        num_layers: int = 1,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.transformer_encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=4,
            dim_feedforward=hidden_size * 4,
            dropout=dropout,
            activation="relu",
            batch_first=False,  # expects [seq, batch, hidden] input
        )
        self.transformer_encoder = nn.TransformerEncoder(
            self.transformer_encoder_layer, num_layers=num_layers
        )
        self.gru = nn.GRU(
            input_size=d_feat,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,  # no effect (PyTorch warns) when num_layers == 1
        )
        self.out = nn.Sequential(
            nn.Linear(hidden_size, hidden_size), nn.GELU(), nn.Linear(hidden_size, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score each stock.

        Args:
            x: tensor of shape [batch, time, stock, feature].

        Returns:
            Tensor of shape [batch, stock] with one score per stock.
        """
        b, t, s, f = x.shape
        # Fold stocks into the batch dim so each series is encoded independently.
        x = x.permute(0, 2, 1, 3).reshape(b * s, t, f)
        gru_out, _ = self.gru(x)  # [b * s, t, h]
        # Transformer layer was built with batch_first=False -> seq dim first.
        gru_out = gru_out.permute(1, 0, 2).contiguous()  # [t, b * s, h]
        tfm_out = self.transformer_encoder(gru_out)  # [t, b * s, h]
        # Keep only the final time step, then restore the stock axis.
        tfm_out = tfm_out[-1].reshape(b, s, -1)  # [b, s, h]
        final_out = self.out(tfm_out).squeeze(-1)  # [b, s]

        return final_out

```

## Model Config

```yaml
d_feat: 8
hidden_size: 64
num_layers: 1
dropout: 0.0
```