|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
class sadModel(nn.Module):
    """Bidirectional-GRU sequence classifier.

    Consumes a 4-D tensor of shape ``(batch, 1, input_dim, seq_len)``
    (e.g. a single-channel feature map of ``input_dim`` features over
    ``seq_len`` time steps), runs it through a bidirectional GRU, flattens
    all time steps, and maps the result to ``output_dim`` logits with a
    single linear layer.

    NOTE(review): the input is presumably a spectrogram-like audio feature
    (40 filterbank/MFCC features) — confirm against the caller.

    Args:
        input_dim: Number of features per time step (GRU input size).
        hidden_dim: GRU hidden size per direction.
        num_layers: Number of stacked GRU layers.
        output_dim: Size of the final output vector.
        seq_len: Number of time steps expected in the input. Previously
            hard-coded to 400 inside the linear layer; parameterized here
            with the same default, so existing callers are unaffected.
    """

    def __init__(self, input_dim=40, hidden_dim=64, num_layers=1,
                 output_dim=800, seq_len=400):
        super().__init__()
        self.gru = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )
        # Bidirectional GRU emits 2 * hidden_dim features per time step;
        # all seq_len steps are flattened before the linear projection,
        # which is why the input length must be fixed at seq_len.
        self.fc = nn.Linear(hidden_dim * 2 * seq_len, output_dim)

    def forward(self, x):
        """Run the model on ``x``.

        Args:
            x: Tensor of shape ``(batch, 1, input_dim, seq_len)``.

        Returns:
            Tensor of shape ``(batch, output_dim)``.
        """
        # (B, 1, F, T) -> (B, F, T) -> (B, T, F): batch_first GRU expects
        # time as dim 1 and features as dim 2.
        x = x.squeeze(1).permute(0, 2, 1)
        # out: (B, T, 2 * hidden_dim) — forward and backward states concatenated.
        out, _ = self.gru(x)
        # Flatten every time step into one vector per batch element;
        # contiguous() is required because permute/GRU output strides
        # may not admit a plain view().
        out = out.contiguous().view(out.size(0), -1)
        return self.fc(out)
|
|
|