File size: 793 Bytes
3c27def
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce367e1
3c27def
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch.nn as nn
import torch
import torch.nn.functional as F
from .dynamic_conv import DynamicConv1d

class DynamicTextCNN(nn.Module):
    def __init__(self, input_dim, num_filters, filter_sizes, K=4, dropout=0.1):
        super().__init__()
        self.convs = nn.ModuleList([
            DynamicConv1d(input_dim, num_filters, k, K)
            for k in filter_sizes
        ])
        self.layer_norm = nn.LayerNorm(len(filter_sizes) * num_filters)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):

        convs = [F.relu(conv(x)) for conv in self.convs]

        pools = [F.adaptive_max_pool1d(c, 1).squeeze(-1) for c in convs]
      
        features = torch.cat(pools, dim=1)
        features = self.layer_norm(features)

        return self.dropout(features)