Spaces:
Paused
Paused
| import torch | |
| import torch.nn as nn | |
class BiGRU(nn.Module):
    """Bidirectional, batch-first, multi-layer GRU that yields only its output sequence.

    Args:
        input_features: Size of the feature dimension of each input step.
        hidden_features: Hidden size of each direction; the output feature
            size is ``2 * hidden_features`` because both directions are
            concatenated.
        num_layers: Number of stacked GRU layers.
    """

    def __init__(self, input_features: int, hidden_features: int, num_layers: int):
        super(BiGRU, self).__init__()
        # Keep the attribute name ``gru``: checkpoint state_dict keys depend on it.
        self.gru = nn.GRU(
            input_size=input_features,
            hidden_size=hidden_features,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )
        # Compact the weights into one contiguous buffer for cuDNN. Per the
        # PyTorch docs this only has an effect on GPU with cuDNN enabled, so
        # on freshly created (CPU) weights it is expected to be a no-op.
        self.gru.flatten_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the GRU over ``x`` and return the per-timestep outputs.

        Because ``batch_first=True``, batched input is laid out as
        ``(batch, seq, input_features)`` and the result as
        ``(batch, seq, 2 * hidden_features)``. No initial hidden state is
        passed, so PyTorch starts from zeros. The final hidden state is
        discarded.
        """
        sequence_out, _final_hidden = self.gru(x)
        return sequence_out