File size: 4,447 Bytes
08b3b18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.config import (
    glove_dim, rnn_hidden, rnn_layers, dropout,
    num_numeric_features, max_vocab_size,
    cnn_filters, cnn_num_filters,
)

class NumericNet(nn.Module):
    """Small MLP encoder for the numeric feature vector.

    Two Linear+ReLU stages map ``num_features`` inputs to a ``hidden``-wide
    representation. The first stage is followed by BatchNorm1d and dropout,
    the second by dropout only.

    Args:
        num_features: Width of the input feature vector.
        hidden: Width of both hidden layers and of the output.
        dropout: Dropout probability applied after each stage.
    """

    def __init__(self, num_features=num_numeric_features, hidden=128, dropout=dropout):
        super().__init__()
        # Modules are listed in the exact order nn.Sequential indexes them,
        # so parameter names (net.0, net.2, net.4, ...) stay stable.
        first_stage = [
            nn.Linear(num_features, hidden),
            nn.ReLU(),
            nn.BatchNorm1d(hidden),
            nn.Dropout(dropout),
        ]
        second_stage = [
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*(first_stage + second_stage))

    def forward(self, x):
        """Encode ``x``; the last dimension must equal ``num_features``."""
        encoded = self.net(x)
        return encoded


class BiGRU_LSTM(nn.Module):
    """Text + numeric classifier: embedding -> BiGRU -> LSTM, gated fusion.

    The text branch embeds token ids and runs them through a bidirectional
    GRU. It then runs a unidirectional LSTM over the GRU outputs and uses the
    LSTM's final hidden state from its top layer as the text representation.
    That vector is projected to 128 dims and blended with the 128-dim
    ``NumericNet`` output by a learned sigmoid gate. The result passes through
    ReLU and dropout before a single-unit linear head.

    Args:
        vocab_size: Rows in the embedding table (index 0 is the padding index).
        embed_dim: Embedding width.
        rnn_hidden: Hidden size of the GRU (per direction) and of the LSTM.
        rnn_layers: Number of stacked layers in both the GRU and the LSTM.
        num_numeric: Width of the numeric feature vector.
        dropout: Dropout probability. It is used between stacked RNN layers
            (only when ``rnn_layers > 1``), inside ``NumericNet``, and before
            the classifier.
        pretrained_embeddings: Optional embedding matrix (array-like or
            tensor). When given it replaces the embedding weights and is
            frozen (``requires_grad=False``).
    """

    def __init__(self, vocab_size=max_vocab_size, embed_dim=glove_dim,
                 rnn_hidden=rnn_hidden, rnn_layers=rnn_layers,
                 num_numeric=num_numeric_features, dropout=dropout,
                 pretrained_embeddings=None):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        if pretrained_embeddings is not None:
            # as_tensor + detach/clone accepts arrays and tensors alike.
            # It always yields an independent float32 copy, without the
            # warning torch.tensor() emits when handed an existing tensor.
            weights = torch.as_tensor(pretrained_embeddings, dtype=torch.float32)
            self.embedding.weight = nn.Parameter(weights.detach().clone(), requires_grad=False)

        # PyTorch only applies inter-layer dropout for stacked RNNs; passing a
        # non-zero value with a single layer just triggers a warning.
        rnn_dropout = dropout if rnn_layers > 1 else 0
        self.bigru = nn.GRU(embed_dim, rnn_hidden, num_layers=rnn_layers,
                            batch_first=True, bidirectional=True,
                            dropout=rnn_dropout)
        # The bidirectional GRU emits 2 * rnn_hidden features per time step.
        self.lstm = nn.LSTM(rnn_hidden * 2, rnn_hidden, num_layers=rnn_layers,
                            batch_first=True, dropout=rnn_dropout)

        self.numeric_net = NumericNet(num_numeric, hidden=128, dropout=dropout)

        # Must match NumericNet's hidden width so the two branches can be gated.
        fused_dim = 128
        self.text_proj = nn.Linear(rnn_hidden, fused_dim)
        self.gate = nn.Linear(fused_dim * 2, fused_dim)

        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(fused_dim, 1)

    def forward(self, input_ids, numeric):
        """Return one unnormalized score per example.

        Args:
            input_ids: Token ids, batch-first: ``(batch, seq_len)``.
            numeric: Numeric features, ``(batch, num_numeric)``.

        Returns:
            Tensor of shape ``(batch, 1)``. No sigmoid is applied.
        """
        embedded = self.embedding(input_ids)
        gru_out, _ = self.bigru(embedded)
        _, (h_n, _) = self.lstm(gru_out)
        # h_n is (num_layers, batch, rnn_hidden). Take the top layer's final
        # state. The previous h_n.squeeze(0) only worked for one layer: with
        # rnn_layers > 1 it left a 3-D tensor that broke the concat below.
        text_repr = self.text_proj(h_n[-1])
        num_repr = self.numeric_net(numeric)

        # Per-feature convex combination: gate -> 1 favours text, -> 0 numeric.
        gate = torch.sigmoid(self.gate(torch.cat([text_repr, num_repr], dim=1)))
        fused = gate * text_repr + (1 - gate) * num_repr

        x = self.dropout(F.relu(fused))
        return self.classifier(x)


class CNN_BiLSTM(nn.Module):
    """Text + numeric classifier: embedding -> multi-width CNN -> BiLSTM, gated fusion.

    The text branch embeds token ids and applies one Conv1d per filter width.
    Each convolution goes through ReLU and max-pooling over the full sequence
    length. The pooled vectors are concatenated and fed to a bidirectional
    LSTM. The final forward and backward hidden states of its top layer are
    concatenated and projected to 128 dims. That vector is blended with the
    128-dim ``NumericNet`` output by a learned sigmoid gate, then passed
    through ReLU and dropout before a single-unit linear head.

    Args:
        vocab_size: Rows in the embedding table (index 0 is the padding index).
        embed_dim: Embedding width (Conv1d input channels).
        filter_sizes: Iterable of convolution kernel widths, one Conv1d each.
        num_filters: Output channels of every Conv1d.
        rnn_hidden: Hidden size per direction of the BiLSTM.
        rnn_layers: Number of stacked BiLSTM layers.
        num_numeric: Width of the numeric feature vector.
        dropout: Dropout probability. It is used between stacked LSTM layers
            (only when ``rnn_layers > 1``), inside ``NumericNet``, and before
            the classifier.
        pretrained_embeddings: Optional embedding matrix (array-like or
            tensor). When given it replaces the embedding weights and is
            frozen (``requires_grad=False``).
    """

    def __init__(self, vocab_size=max_vocab_size, embed_dim=glove_dim,
                 filter_sizes=cnn_filters, num_filters=cnn_num_filters,
                 rnn_hidden=rnn_hidden, rnn_layers=rnn_layers,
                 num_numeric=num_numeric_features, dropout=dropout,
                 pretrained_embeddings=None):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        if pretrained_embeddings is not None:
            # as_tensor + detach/clone accepts arrays and tensors alike.
            # It always yields an independent float32 copy, without the
            # warning torch.tensor() emits when handed an existing tensor.
            weights = torch.as_tensor(pretrained_embeddings, dtype=torch.float32)
            self.embedding.weight = nn.Parameter(weights.detach().clone(), requires_grad=False)

        self.convs = nn.ModuleList([nn.Conv1d(embed_dim, num_filters, fs) for fs in filter_sizes])
        cnn_out = num_filters * len(filter_sizes)

        self.bilstm = nn.LSTM(cnn_out, rnn_hidden, num_layers=rnn_layers,
                              batch_first=True, bidirectional=True,
                              dropout=dropout if rnn_layers > 1 else 0)

        self.numeric_net = NumericNet(num_numeric, hidden=128, dropout=dropout)

        # Must match NumericNet's hidden width so the two branches can be gated.
        fused_dim = 128
        # Forward and backward final states are concatenated -> 2 * rnn_hidden.
        self.text_proj = nn.Linear(rnn_hidden * 2, fused_dim)
        self.gate = nn.Linear(fused_dim * 2, fused_dim)

        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(fused_dim, 1)

    def forward(self, input_ids, numeric):
        """Return one unnormalized score per example.

        Args:
            input_ids: Token ids, batch-first: ``(batch, seq_len)``. Without
                padding, Conv1d needs ``seq_len`` >= the largest filter width.
            numeric: Numeric features, ``(batch, num_numeric)``.

        Returns:
            Tensor of shape ``(batch, 1)``. No sigmoid is applied.
        """
        # Conv1d expects channels first: (batch, embed_dim, seq_len).
        embedded = self.embedding(input_ids).permute(0, 2, 1)

        pooled = []
        for conv in self.convs:
            # Run each convolution once. Previously conv(embedded) was
            # evaluated a second time just to read the output length.
            feature_map = F.relu(conv(embedded))
            # Max over the whole time axis -> (batch, num_filters).
            pooled.append(F.max_pool1d(feature_map, feature_map.size(2)).squeeze(2))

        # NOTE(review): the pooled CNN features form a single time step, so
        # the BiLSTM below sees a sequence of length 1: (batch, 1, cnn_out).
        cnn_out = torch.cat(pooled, dim=1).unsqueeze(1)

        _, (h_n, _) = self.bilstm(cnn_out)
        # h_n is (num_layers * 2, batch, rnn_hidden). The last two entries are
        # the top layer's forward and backward final states.
        text_repr = self.text_proj(torch.cat([h_n[-2], h_n[-1]], dim=1))
        num_repr = self.numeric_net(numeric)

        # Per-feature convex combination: gate -> 1 favours text, -> 0 numeric.
        gate = torch.sigmoid(self.gate(torch.cat([text_repr, num_repr], dim=1)))
        fused = gate * text_repr + (1 - gate) * num_repr

        x = self.dropout(F.relu(fused))
        return self.classifier(x)