YanCotta's picture
Upload folder using huggingface_hub
4fffbd7 verified
import torch
import torch.nn as nn
class FusionWiSE(nn.Module):
    """Classify the concatenation of a CLIP embedding and an entity encoding.

    The two inputs are concatenated along dim 1 and passed through a small MLP
    encoder (Linear -> BatchNorm1d -> ReLU -> Dropout -> Linear). The resulting
    features are L2-normalised per row, and a single linear layer produces
    ``num_classes`` logits.

    Args:
        clip_dim: Feature size of the CLIP embedding input.
        entity_dim: Feature size of the entity input (presumably a one-hot
            entity vector, judging by the ``forward`` argument name -- confirm
            against the data pipeline).
        num_classes: Number of output logits.
        hidden_dim: Width of the encoder's hidden and output layers.
        dropout: Dropout probability applied after the first encoder layer.
    """

    def __init__(self, clip_dim=512, entity_dim=147, num_classes=115, hidden_dim=512, dropout=0.5):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(clip_dim + entity_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, clip_emb, entity_onehot):
        """Return class logits for a batch.

        Args:
            clip_emb: Tensor of shape ``(batch, clip_dim)``.
            entity_onehot: Tensor of shape ``(batch, entity_dim)``. It is
                concatenated after ``clip_emb`` along dim 1.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        x = torch.cat([clip_emb, entity_onehot], dim=1)
        feats = self.encoder(x)
        # Project the features onto the unit sphere. F.normalize clamps the norm
        # to a small eps, so an all-zero feature row maps to zeros instead of
        # producing 0/0 = NaN logits. Non-degenerate rows are unchanged.
        feats = nn.functional.normalize(feats, p=2, dim=-1)
        return self.classifier(feats)