## Model Structure

The model encodes each stock's feature sequence with a GRU, refines the hidden states with a Transformer encoder, and maps the final time step to a per-stock score:
```python
import torch
import torch.nn as nn


class GRUTransformerSimple(nn.Module):
    def __init__(
        self,
        d_feat: int = 8,
        hidden_size: int = 64,
        num_layers: int = 1,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.transformer_encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=4,
            dim_feedforward=hidden_size * 4,
            dropout=dropout,
            activation="relu",
            batch_first=False,  # the encoder expects [t, batch, h]
        )
        self.transformer_encoder = nn.TransformerEncoder(
            self.transformer_encoder_layer, num_layers=num_layers
        )
        self.gru = nn.GRU(
            input_size=d_feat,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,  # only applies between stacked layers (num_layers > 1)
        )
        self.out = nn.Sequential(
            nn.Linear(hidden_size, hidden_size), nn.GELU(), nn.Linear(hidden_size, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, time, stocks, features]
        b, t, s, f = x.shape
        # Fold stocks into the batch so each stock's sequence is encoded independently.
        x = x.permute(0, 2, 1, 3).reshape(b * s, t, f)
        gru_out, _ = self.gru(x)  # [b * s, t, h]
        # The encoder was built with batch_first=False, so move time to the front.
        gru_out = gru_out.permute(1, 0, 2).contiguous()  # [t, b * s, h]
        tfm_out = self.transformer_encoder(gru_out)  # [t, b * s, h]
        tfm_out = tfm_out[-1].reshape(b, s, -1)  # last time step -> [b, s, h]
        final_out = self.out(tfm_out).squeeze(-1)  # [b, s]
        return final_out
```
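For a quick sanity check, here is a minimal forward pass; the batch, window, and stock-universe sizes below are illustrative assumptions, not values from the original:

```python
import torch

model = GRUTransformerSimple()  # defaults: d_feat=8, hidden_size=64
x = torch.randn(32, 20, 300, 8)  # [batch=32, time=20, stocks=300, features=8], made-up sizes
scores = model(x)
print(scores.shape)  # torch.Size([32, 300]): one score per stock per sample
```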
## Model Config
```yaml
d_feat: 8
hidden_size: 64
num_layers: 1
dropout: 0.0
```
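A minimal sketch of wiring this config into the model; the file name `config.yaml` and the use of PyYAML's `yaml.safe_load` are assumptions, since the original does not show how the config is loaded:

```python
import yaml

with open("config.yaml") as fh:  # hypothetical file name
    cfg = yaml.safe_load(fh)

# The YAML keys match the constructor arguments one-to-one.
model = GRUTransformerSimple(**cfg)
```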