| import torch | |
| import torch.nn as nn | |
class FusionWiSE(nn.Module):
    """Late-fusion classifier over a CLIP embedding and an entity one-hot vector.

    The two inputs are concatenated, passed through a small MLP encoder
    (Linear -> BatchNorm -> ReLU -> Dropout -> Linear), L2-normalized, and
    mapped to class logits by a linear head.

    Args:
        clip_dim: Dimensionality of the CLIP embedding input.
        entity_dim: Dimensionality of the entity one-hot input.
        num_classes: Number of output classes (size of the logit vector).
        hidden_dim: Width of the encoder's hidden and output features.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(
        self,
        clip_dim: int = 512,
        entity_dim: int = 147,
        num_classes: int = 115,
        hidden_dim: int = 512,
        dropout: float = 0.5,
    ) -> None:
        # Python-3 zero-argument super() (was super(FusionWiSE, self)).
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(clip_dim + entity_dim, hidden_dim),
            # NOTE: BatchNorm1d requires batch size > 1 in training mode.
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, clip_emb: torch.Tensor, entity_onehot: torch.Tensor) -> torch.Tensor:
        """Compute class logits for a batch.

        Args:
            clip_emb: Float tensor of shape (batch, clip_dim).
            entity_onehot: Float tensor of shape (batch, entity_dim).

        Returns:
            Logits of shape (batch, num_classes).
        """
        x = torch.cat([clip_emb, entity_onehot], dim=1)
        feats = self.encoder(x)
        # L2-normalize the features; clamp the norm so an all-zero feature
        # vector cannot produce a divide-by-zero (NaN logits).
        norm = feats.norm(p=2, dim=-1, keepdim=True).clamp_min(1e-12)
        feats = feats / norm
        return self.classifier(feats)