import torch.nn as nn
import torch
class BiLSTM(nn.Module):
    """Bidirectional-LSTM text classifier over frozen pretrained embeddings.

    Pipeline: embed token ids -> single-layer BiLSTM -> concatenate
    mean-pool and max-pool over the time axis -> small MLP head that
    produces ``num_labels`` logits.
    """

    def __init__(self, embedding_matrix, embed_size, num_labels, hidden_size=64, dropout=0.1):
        super(BiLSTM, self).__init__()
        self.hidden_size = hidden_size
        # Embedding table initialized from pretrained vectors and kept frozen
        # (freeze=True => weight.requires_grad is False).
        self.embedding = nn.Embedding.from_pretrained(
            torch.tensor(embedding_matrix, dtype=torch.float32), freeze=True)
        self.lstm = nn.LSTM(embed_size, self.hidden_size, bidirectional=True, batch_first=True)
        # Each pooling yields 2*hidden_size (bidirectional), and two poolings
        # are concatenated, hence 4*hidden_size input features.
        self.linear = nn.Linear(self.hidden_size * 4, 64)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(64, num_labels)

    def forward(self, x):
        """Return logits of shape (batch, num_labels) for a batch of token ids ``x``."""
        embedded = self.embedding(x)
        lstm_out, _ = self.lstm(embedded)
        # Summarize the sequence two ways along the time dimension (dim 1).
        pooled_avg = torch.mean(lstm_out, 1)
        pooled_max, _ = torch.max(lstm_out, 1)
        features = torch.cat((pooled_avg, pooled_max), 1)
        # Classifier head: linear -> ReLU -> dropout -> output projection.
        hidden = self.dropout(self.relu(self.linear(features)))
        return self.out(hidden)