## Model Structure

```python
import torch
import torch.nn as nn


class GRUTransformerSimple(nn.Module):
    def __init__(
        self,
        d_feat: int = 8,
        hidden_size: int = 64,
        num_layers: int = 1,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.transformer_encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=4,
            dim_feedforward=hidden_size * 4,
            dropout=dropout,
            activation="relu",
            batch_first=False,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            self.transformer_encoder_layer, num_layers=num_layers
        )
        self.gru = nn.GRU(
            input_size=d_feat,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.out = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, time, stocks, features]
        b, t, s, f = x.shape
        # Fold the stock axis into the batch axis so each stock's
        # time series is encoded independently.
        x = x.permute(0, 2, 1, 3).reshape(b * s, t, f)
        gru_out, _ = self.gru(x)                         # [b * s, t, h]
        # TransformerEncoder with batch_first=False expects [seq, batch, dim].
        gru_out = gru_out.permute(1, 0, 2).contiguous()  # [t, b * s, h]
        tfm_out = self.transformer_encoder(gru_out)      # [t, b * s, h]
        # Keep only the last time step and restore the stock axis.
        tfm_out = tfm_out[-1].reshape(b, s, -1)          # [b, s, h]
        final_out = self.out(tfm_out).squeeze(-1)        # [b, s]
        return final_out
```

## Model Config

```yaml
d_feat: 8
hidden_size: 64
num_layers: 1
dropout: 0.0
```
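As a sanity check on the shape handling in `forward`, here is a minimal smoke test using the config above. The batch size, window length, and number of stocks are arbitrary illustrative assumptions, not values from the original document.

```python
import torch

# Hypothetical sizes chosen only for illustration.
batch, time_steps, stocks, features = 2, 30, 5, 8

model = GRUTransformerSimple(d_feat=8, hidden_size=64, num_layers=1, dropout=0.0)
x = torch.randn(batch, time_steps, stocks, features)  # [b, t, s, f]

with torch.no_grad():
    y = model(x)

print(y.shape)  # torch.Size([2, 5]): one score per (batch, stock) pair
```

Because the stock axis is folded into the batch axis before the GRU, the temporal encoding of each stock is independent; only the final `out` head and the per-stock reshape relate the `[b, s]` outputs back to their stocks.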