s1ghhh's picture
Upload folder using huggingface_hub
91aec72 verified
import torch
import torch.nn as nn
class RNNClassifier(nn.Module):
    """Binary sequence classifier: recurrent encoder followed by an MLP head.

    Padded, batch-first sequences are packed with ``pack_padded_sequence`` so
    that only the first ``lengths[i]`` steps of each sample are processed.
    The final hidden state of the top recurrent layer goes through a
    two-layer MLP, and a sigmoid is applied to its single output.

    Args:
        hidden_dim: Hidden size of the recurrent layer and of the MLP head.
        rnn_type: ``'LSTM'`` (case-insensitive) selects ``nn.LSTM``; any other
            value selects ``nn.GRU``.
        input_size: Number of features per time step (default 1, i.e. a
            scalar-valued sequence).
        num_layers: Number of stacked recurrent layers (default 1).
    """

    def __init__(self, hidden_dim=256, rnn_type='GRU', input_size=1, num_layers=1):
        super().__init__()
        self.rnn_type = rnn_type.upper()
        # NOTE: any value other than 'LSTM' falls back to a GRU; this mirrors
        # the historical behaviour so existing callers keep working.
        rnn_cls = nn.LSTM if self.rnn_type == 'LSTM' else nn.GRU
        self.rnn = rnn_cls(
            input_size=input_size,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x, lengths):
        """Return one sigmoid score per sample.

        Args:
            x: Padded batch-first input of shape ``(batch, seq, input_size)``.
                When ``input_size == 1``, a 2-D ``(batch, seq)`` tensor is also
                accepted and a trailing feature axis is added.
            lengths: Number of valid steps in each sequence, either as a tensor
                (on any device) or as a Python sequence of ints. Presumably
                every length must be >= 1 (``pack_padded_sequence`` rejects
                zero lengths).

        Returns:
            Tensor of shape ``(batch,)`` with values in ``[0, 1]``.
        """
        if x.dim() == 2 and self.rnn.input_size == 1:
            # Scalar sequences given as (batch, seq): add the feature axis the
            # RNN expects; packed 1-D data would otherwise be rejected.
            x = x.unsqueeze(-1)
        # pack_padded_sequence wants lengths as a CPU int64 tensor; converting
        # here also lets callers pass a plain list.
        lengths = torch.as_tensor(lengths, dtype=torch.int64, device='cpu')
        packed_x = nn.utils.rnn.pack_padded_sequence(
            x, lengths, batch_first=True, enforce_sorted=False
        )
        # LSTM returns (h_n, c_n); GRU returns h_n alone. With
        # enforce_sorted=False, h_n comes back in the original batch order.
        _, hidden = self.rnn(packed_x)
        hn = hidden[0] if self.rnn_type == 'LSTM' else hidden
        # h_n is (num_layers, batch, hidden_dim): keep the top layer's state.
        last_hidden = hn[-1]
        out = self.classifier(last_hidden)
        return torch.sigmoid(out).squeeze(-1)