File size: 756 Bytes
4fffbd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22

import torch
import torch.nn as nn

class FusionWiSE(nn.Module):
    """Late-fusion classifier over a CLIP embedding and an entity one-hot vector.

    The two inputs are concatenated along the feature dimension, encoded by a
    small MLP (Linear -> BatchNorm -> ReLU -> Dropout -> Linear), L2-normalized,
    and projected to class logits.

    Args:
        clip_dim: Dimensionality of the CLIP embedding input (default 512).
        entity_dim: Dimensionality of the entity one-hot input (default 147).
        num_classes: Number of output classes / logits (default 115).
        hidden_dim: Width of the MLP hidden layer and encoded feature (default 512).
        dropout: Dropout probability applied after the ReLU (default 0.5).

    Note:
        BatchNorm1d requires batch size > 1 in training mode; single-sample
        inference should be done with the module in eval() mode.
    """

    def __init__(self, clip_dim=512, entity_dim=147, num_classes=115, hidden_dim=512, dropout=0.5):
        super(FusionWiSE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(clip_dim + entity_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, clip_emb, entity_onehot):
        """Compute class logits for a batch of fused inputs.

        Args:
            clip_emb: Tensor of shape (batch, clip_dim).
            entity_onehot: Tensor of shape (batch, entity_dim).

        Returns:
            Tensor of shape (batch, num_classes) — unnormalized logits.
        """
        x = torch.cat([clip_emb, entity_onehot], dim=1)
        feats = self.encoder(x)
        # L2-normalize the encoded features so the classifier sees unit-norm
        # vectors. Clamp the norm by a tiny epsilon: an all-zero encoder output
        # would otherwise divide by zero and yield NaN logits (same guard as
        # torch.nn.functional.normalize).
        feats = feats / feats.norm(p=2, dim=-1, keepdim=True).clamp_min(1e-12)
        return self.classifier(feats)