DPDRS-BILSTM / model_folder /modeling_bilstm.py
puneeth1's picture
Update model_folder/modeling_bilstm.py
8ea164e verified
import torch.nn as nn
import torch
class BiLSTM(nn.Module):
    """Bidirectional LSTM classifier on top of frozen pretrained embeddings.

    Pipeline: embedding lookup -> single-layer bidirectional LSTM -> mean and
    max pooling over the sequence dimension -> dense layer + ReLU -> dropout
    -> linear projection to ``num_labels`` scores.
    """

    def __init__(
        self,
        embedding_matrix,
        embed_size: int,
        num_labels: int,
        hidden_size: int = 64,
        dropout: float = 0.1,
        fc_size: int = 64,
    ) -> None:
        """Build the network layers.

        Args:
            embedding_matrix: 2-D pretrained embedding weights (tensor, array
                or nested sequence). The values are copied, cast to float32
                and frozen.
            embed_size: Width of each embedding vector; must equal the second
                dimension of ``embedding_matrix``.
            num_labels: Number of output scores per example.
            hidden_size: Hidden units per LSTM direction.
            dropout: Dropout probability applied after the dense layer.
            fc_size: Width of the dense layer between pooling and output.

        Raises:
            ValueError: If ``embed_size`` does not match the width of
                ``embedding_matrix``.
        """
        super().__init__()
        self.hidden_size = hidden_size

        # torch.tensor() warns when handed an existing tensor, so copy tensors
        # explicitly; either path yields an independent float32 copy.
        if isinstance(embedding_matrix, torch.Tensor):
            weights = embedding_matrix.detach().to(dtype=torch.float32, copy=True)
        else:
            weights = torch.tensor(embedding_matrix, dtype=torch.float32)
        self.embedding = nn.Embedding.from_pretrained(weights, freeze=True)

        # Fail at construction time rather than with an opaque LSTM shape
        # error on the first forward pass.
        if self.embedding.embedding_dim != embed_size:
            raise ValueError(
                f"embed_size ({embed_size}) does not match embedding_matrix "
                f"width ({self.embedding.embedding_dim})"
            )

        self.lstm = nn.LSTM(
            embed_size, self.hidden_size, bidirectional=True, batch_first=True
        )
        # 2 directions x 2 pooled views (mean + max) = 4 * hidden_size inputs.
        self.linear = nn.Linear(self.hidden_size * 4, fc_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(fc_size, num_labels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute label scores for a batch of token-index sequences.

        Args:
            x: Integer index tensor of shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, num_labels)`` holding the raw outputs of
            the final linear layer (no softmax/sigmoid is applied here).
        """
        # Embedding layer
        h_embedding = self.embedding(x)

        # LSTM layer: (batch, seq_len, 2 * hidden_size)
        h_lstm, _ = self.lstm(h_embedding)

        # Pooling layers collapse the sequence dimension.
        # NOTE(review): pooling runs over every time step, so any padding
        # positions in ``x`` contribute to both statistics.
        avg_pool = torch.mean(h_lstm, 1)
        max_pool, _ = torch.max(h_lstm, 1)

        # Concatenate pooling outputs: (batch, 4 * hidden_size)
        conc = torch.cat((avg_pool, max_pool), 1)

        # Fully connected layers
        conc = self.relu(self.linear(conc))
        conc = self.dropout(conc)
        return self.out(conc)