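# NOTE: the three lines below were page residue from a web copy of this file
# (author avatar caption, commit message, commit hash); kept as a comment so
# the module parses.
# mulasagg / final / a9640f8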
import torch
import torch.nn as nn
import torch.nn.functional as F
class CNNBiGRU(nn.Module):
    """Text encoder: embedding -> Conv1d -> max-pool -> bidirectional RNN -> linear.

    Despite the class name, the recurrent layer is an ``nn.LSTM``. It is kept
    as an LSTM (and the attribute is still called ``bigru``) so that existing
    ``state_dict`` checkpoints keep loading; swapping in ``nn.GRU`` would change
    the recurrent weight shapes.

    Args:
        vocab_size: Number of token ids; id 0 is treated as padding.
        embedding_dim: Size of each token embedding (Conv1d input channels).
        input_length: Accepted for API compatibility; not used by this module.
        num_filters: Conv1d output channels, also the RNN input size.
        kernel_size: Conv1d kernel width, reused as the max-pool window.
        num_gated_units: Hidden size of each RNN direction; the RNN emits
            ``2 * num_gated_units`` features per time step.
        hidden_neurons: Width of ``fc1``/``fc2`` and input width of ``output``.
            When ``use_fc_head`` is False, ``output`` is fed the RNN features
            directly, so this must equal ``2 * num_gated_units``.
        dropout_cnn: Dropout probability applied after the convolution.
        dropout_fc: Dropout probability used inside the optional FC head.
        output_dim: Width of the final projection (default 128, the value
            previously hard-coded).
        use_fc_head: If True, pass the RNN features through
            ``fc1 -> ReLU -> dropout -> fc2 -> ReLU -> dropout`` before
            ``output``. Defaults to False, which reproduces the original
            forward pass exactly (``fc1``/``fc2`` are then created but unused),
            so models trained with the original code behave identically.
    """

    def __init__(self, vocab_size, embedding_dim, input_length, num_filters, kernel_size,
                 num_gated_units, hidden_neurons, dropout_cnn, dropout_fc,
                 *, output_dim=128, use_fc_head=False):
        super().__init__()
        # Module creation order matches the original so state_dict keys and
        # seeded parameter initialisation are unchanged.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.conv1d = nn.Conv1d(in_channels=embedding_dim, out_channels=num_filters,
                                kernel_size=kernel_size)
        self.dropout_cnn = nn.Dropout(dropout_cnn)
        # stride=1 pooling shrinks the sequence by kernel_size - 1 steps.
        self.maxpool = nn.MaxPool1d(kernel_size=kernel_size, stride=1)
        self.bigru = nn.LSTM(input_size=num_filters, hidden_size=num_gated_units, num_layers=1,
                             batch_first=True, bidirectional=True)
        self.fc1 = nn.Linear(num_gated_units * 2, hidden_neurons)
        self.fc2 = nn.Linear(hidden_neurons, hidden_neurons)
        self.dropout_fc = nn.Dropout(dropout_fc)
        self.output = nn.Linear(hidden_neurons, output_dim)
        self.use_fc_head = use_fc_head

    def forward(self, x):
        """Encode a batch of token ids into fixed-size vectors.

        Args:
            x: Integer token ids; assumed shape (batch, seq_len) given the
                permutes below (TODO confirm with callers). ``seq_len`` must be
                at least ``2 * kernel_size - 1`` so the convolution and the
                stride-1 pooling leave a non-empty sequence.

        Returns:
            Tensor of shape (batch, output_dim).
        """
        x = self.embedding(x)
        x = x.permute(0, 2, 1)  # Conv1d expects channels in dim 1
        x = F.relu(self.conv1d(x))
        x = self.dropout_cnn(x)
        x = self.maxpool(x)
        x = x.permute(0, 2, 1)  # back to (batch, seq, channels) for the batch_first RNN
        x, _ = self.bigru(x)
        # NOTE(review): for a bidirectional RNN, the backward half of the last
        # time step has only seen the final position; the final hidden state
        # (h_n) is the usual sequence summary. Kept as-is to preserve the
        # behaviour of already-trained weights.
        x = x[:, -1, :]
        if self.use_fc_head:
            x = self.dropout_fc(F.relu(self.fc1(x)))
            x = self.dropout_fc(F.relu(self.fc2(x)))
        return self.output(x)