FactSight / models /models.py
DeepActionPotential's picture
Initial project upload via Python API for Flask Space
e0f2d0e verified
raw
history blame contribute delete
913 Bytes
import torch.nn as nn
class LSTMClassifier(nn.Module):
    """Binary sequence classifier: embedding -> LSTM -> dropout -> linear -> sigmoid.

    Args:
        vocab_size: Number of rows in the embedding table. Index 0 is the
            padding index (``padding_idx=0``), so its embedding vector
            receives no gradient updates.
        embedding_dim: Size of each token embedding.
        hidden_dim: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied to the LSTM features before the
            linear head; also used between LSTM layers when ``num_layers > 1``.
        bidirectional: Whether the LSTM runs in both directions. When True the
            linear head receives ``2 * hidden_dim`` features.
    """

    def __init__(self, vocab_size: int, embedding_dim: int = 100,
                 hidden_dim: int = 128, num_layers: int = 1,
                 dropout: float = 0.5, bidirectional: bool = True):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embedding_dim, hidden_dim, num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            # nn.LSTM applies dropout only between stacked layers; passing a
            # non-zero value with a single layer merely triggers a warning.
            dropout=dropout if num_layers > 1 else 0,
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim * (2 if bidirectional else 1), 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the positive-class probability for each sequence in ``x``.

        Args:
            x: Integer token-index tensor; presumably shaped
               ``(batch, seq_len)`` since the LSTM uses ``batch_first=True``
               (TODO confirm against callers).

        Returns:
            Tensor of shape ``(batch,)`` with values in ``[0, 1]``. The batch
            dimension is kept even when ``batch == 1``.
        """
        embedded = self.embedding(x)
        out, _ = self.lstm(embedded)
        # LSTM output at the final time step.
        # NOTE(review): for a bidirectional LSTM the reverse direction at this
        # position has only read the last token, and if inputs are
        # right-padded with index 0 this step may be padding. Left unchanged
        # because existing checkpoints were presumably trained with this
        # pooling; confirm before switching to the final hidden state (h_n).
        last = out[:, -1, :]
        logits = self.fc(self.dropout(last))
        # Squeeze only the trailing size-1 output dim: a bare .squeeze() would
        # also drop the batch dim when batch == 1 and return a 0-d tensor.
        return self.sigmoid(logits).squeeze(-1)