```python
import torch
import torch.nn as nn


class BiLSTM(nn.Module):
    def __init__(self, embedding_matrix, embed_size, num_labels, hidden_size=64, dropout=0.1):
        super().__init__()
        self.hidden_size = hidden_size
        # Frozen pretrained word embeddings
        self.embedding = nn.Embedding.from_pretrained(
            torch.tensor(embedding_matrix, dtype=torch.float32), freeze=True)
        # Bidirectional LSTM: each timestep outputs hidden_size * 2 features
        self.lstm = nn.LSTM(embed_size, self.hidden_size, bidirectional=True, batch_first=True)
        # Avg-pool and max-pool outputs are concatenated, hence hidden_size * 4 input features
        self.linear = nn.Linear(self.hidden_size * 4, 64)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(64, num_labels)

    def forward(self, x):
        # Embedding layer
        h_embedding = self.embedding(x)
        # LSTM layer
        h_lstm, _ = self.lstm(h_embedding)
        # Pooling layers over the time dimension
        avg_pool = torch.mean(h_lstm, 1)
        max_pool, _ = torch.max(h_lstm, 1)
        # Concatenate pooling outputs
        conc = torch.cat((avg_pool, max_pool), 1)
        # Fully connected layers
        conc = self.relu(self.linear(conc))
        conc = self.dropout(conc)
        out = self.out(conc)
        return out
```
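
A minimal usage sketch, assuming a pretrained embedding matrix is available; the vocabulary size, embedding dimension, and label count below are made-up placeholders, and the random matrix stands in for real pretrained vectors:

```python
import numpy as np
import torch

# Hypothetical sizes for illustration only.
vocab_size, embed_size, num_labels = 1000, 300, 5
embedding_matrix = np.random.rand(vocab_size, embed_size)  # stand-in for pretrained vectors

model = BiLSTM(embedding_matrix, embed_size, num_labels)

# Batch of 8 sequences, each 50 token ids long.
x = torch.randint(0, vocab_size, (8, 50))
logits = model(x)
print(logits.shape)  # torch.Size([8, 5])
```

Note how the shapes line up: the bidirectional LSTM yields `(8, 50, 128)` with `hidden_size=64`, the two pooling operations each reduce that to `(8, 128)`, and their concatenation gives the `(8, 256)` tensor expected by the `hidden_size * 4` linear layer.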