import torch
import torch.nn as nn
import torch.nn.functional as F

from src.config import (
    glove_dim, rnn_hidden, rnn_layers, dropout,
    num_numeric_features, max_vocab_size,
    cnn_filters, cnn_num_filters,
)
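# The names imported above are assumed to be plain module-level values in
# src.config (e.g. glove_dim=300, dropout=0.3); the actual numbers live in
# that file and are not shown here.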


class NumericNet(nn.Module):
    """Small MLP over the hand-engineered numeric features."""

    def __init__(self, num_features=num_numeric_features, hidden=128, dropout=dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(num_features, hidden),
            nn.ReLU(),
            nn.BatchNorm1d(hidden),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
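
# NumericNet uses BatchNorm1d, which requires more than one sample per batch
# while in training mode; switch the model to .eval() before scoring single
# examples.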


class BiGRU_LSTM(nn.Module):
    """Bi-GRU over GloVe embeddings, followed by a unidirectional LSTM; the
    text representation is fused with the NumericNet output via a learned gate."""

    def __init__(self, vocab_size=max_vocab_size, embed_dim=glove_dim,
                 rnn_hidden=rnn_hidden, rnn_layers=rnn_layers,
                 num_numeric=num_numeric_features, dropout=dropout,
                 pretrained_embeddings=None):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        if pretrained_embeddings is not None:
            # Load frozen pretrained (GloVe) vectors.
            self.embedding.weight = nn.Parameter(
                torch.tensor(pretrained_embeddings, dtype=torch.float32),
                requires_grad=False)
        self.bigru = nn.GRU(embed_dim, rnn_hidden, num_layers=rnn_layers,
                            batch_first=True, bidirectional=True,
                            dropout=dropout if rnn_layers > 1 else 0)
        self.lstm = nn.LSTM(rnn_hidden * 2, rnn_hidden, num_layers=rnn_layers,
                            batch_first=True,
                            dropout=dropout if rnn_layers > 1 else 0)
        self.numeric_net = NumericNet(num_numeric, hidden=128, dropout=dropout)
        fused_dim = 128
        self.text_proj = nn.Linear(rnn_hidden, fused_dim)
        self.gate = nn.Linear(fused_dim * 2, fused_dim)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(fused_dim, 1)

    def forward(self, input_ids, numeric):
        embedded = self.embedding(input_ids)      # (B, T, E)
        gru_out, _ = self.bigru(embedded)         # (B, T, 2*H)
        _, (h_n, _) = self.lstm(gru_out)          # h_n: (num_layers, B, H)
        # Use the last layer's final hidden state; indexing with [-1] stays
        # correct when rnn_layers > 1, where squeeze(0) would not.
        text_repr = self.text_proj(h_n[-1])       # (B, fused_dim)
        num_repr = self.numeric_net(numeric)      # (B, fused_dim)
        # Learned gate blends text and numeric representations element-wise.
        gate = torch.sigmoid(self.gate(torch.cat([text_repr, num_repr], dim=1)))
        fused = gate * text_repr + (1 - gate) * num_repr
        x = self.dropout(F.relu(fused))
        return self.classifier(x)


class CNN_BiLSTM(nn.Module):
    """Multi-width 1-D convolutions over GloVe embeddings, max-pooled into a
    single feature vector, passed through a Bi-LSTM and gated together with
    the NumericNet output."""

    def __init__(self, vocab_size=max_vocab_size, embed_dim=glove_dim,
                 filter_sizes=cnn_filters, num_filters=cnn_num_filters,
                 rnn_hidden=rnn_hidden, rnn_layers=rnn_layers,
                 num_numeric=num_numeric_features, dropout=dropout,
                 pretrained_embeddings=None):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        if pretrained_embeddings is not None:
            # Load frozen pretrained (GloVe) vectors.
            self.embedding.weight = nn.Parameter(
                torch.tensor(pretrained_embeddings, dtype=torch.float32),
                requires_grad=False)
        # One Conv1d per filter width, each producing num_filters channels.
        self.convs = nn.ModuleList(
            [nn.Conv1d(embed_dim, num_filters, fs) for fs in filter_sizes])
        cnn_out = num_filters * len(filter_sizes)
        self.bilstm = nn.LSTM(cnn_out, rnn_hidden, num_layers=rnn_layers,
                              batch_first=True, bidirectional=True,
                              dropout=dropout if rnn_layers > 1 else 0)
        self.numeric_net = NumericNet(num_numeric, hidden=128, dropout=dropout)
        fused_dim = 128
        self.text_proj = nn.Linear(rnn_hidden * 2, fused_dim)
        self.gate = nn.Linear(fused_dim * 2, fused_dim)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(fused_dim, 1)

    def forward(self, input_ids, numeric):
        embedded = self.embedding(input_ids).permute(0, 2, 1)  # (B, E, T)
        conv_outs = []
        for conv in self.convs:
            # Run each convolution once, then max-pool over the time axis.
            feat = F.relu(conv(embedded))                      # (B, F, T')
            conv_outs.append(F.max_pool1d(feat, feat.size(2)).squeeze(2))
        # Pooled CNN features form a length-1 "sequence" for the Bi-LSTM.
        cnn_out = torch.cat(conv_outs, dim=1).unsqueeze(1)     # (B, 1, F*len)
        _, (h_n, _) = self.bilstm(cnn_out)
        # Concatenate the last layer's forward and backward hidden states.
        text_repr = self.text_proj(torch.cat([h_n[-2], h_n[-1]], dim=1))
        num_repr = self.numeric_net(numeric)
        # Same gated fusion as in BiGRU_LSTM.
        gate = torch.sigmoid(self.gate(torch.cat([text_repr, num_repr], dim=1)))
        fused = gate * text_repr + (1 - gate) * num_repr
        x = self.dropout(F.relu(fused))
        return self.classifier(x)
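

if __name__ == "__main__":
    # Minimal smoke test: a sketch assuming the src.config defaults are valid
    # and that seq_len exceeds the widest convolution in cnn_filters. Builds
    # both models and pushes random inputs through them.
    batch, seq_len = 4, 64
    dummy_ids = torch.randint(1, max_vocab_size, (batch, seq_len))
    dummy_num = torch.randn(batch, num_numeric_features)

    for model in (BiGRU_LSTM(), CNN_BiLSTM()):
        model.eval()  # deterministic BatchNorm/Dropout behaviour
        with torch.no_grad():
            logits = model(dummy_ids, dummy_num)
        # Each model returns one raw logit per sample, presumably consumed by
        # a sigmoid or nn.BCEWithLogitsLoss downstream.
        assert logits.shape == (batch, 1), logits.shape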