import torch
import torch.nn as nn


class FusionWiSE(nn.Module):
    """Classifier over the concatenation of two input feature vectors.

    The two inputs are concatenated along dim 1, passed through a two-layer
    MLP encoder (Linear -> BatchNorm1d -> ReLU -> Dropout -> Linear),
    L2-normalized along the last dimension, and fed to a linear classifier.

    Args:
        clip_dim: Feature size of the first input (``clip_emb``).
        entity_dim: Feature size of the second input (``entity_onehot``).
        num_classes: Number of output logits.
        hidden_dim: Width of the encoder's hidden and output layers.
        dropout: Dropout probability applied inside the encoder.
    """

    def __init__(
        self,
        clip_dim: int = 512,
        entity_dim: int = 147,
        num_classes: int = 115,
        hidden_dim: int = 512,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        # Layer order (and therefore the Sequential indices used as
        # state_dict keys) is kept stable so existing checkpoints still load.
        self.encoder = nn.Sequential(
            nn.Linear(clip_dim + entity_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(
        self, clip_emb: torch.Tensor, entity_onehot: torch.Tensor
    ) -> torch.Tensor:
        """Return class logits for a batch of paired inputs.

        Args:
            clip_emb: First feature tensor; presumably ``(batch, clip_dim)``
                (TODO confirm against caller).
            entity_onehot: Second feature tensor; presumably
                ``(batch, entity_dim)`` (TODO confirm against caller).

        Returns:
            Logits produced by the classifier from the L2-normalized encoder
            features, with ``num_classes`` entries along the last dimension.
        """
        x = torch.cat([clip_emb, entity_onehot], dim=1)
        feats = self.encoder(x)
        # normalize() divides by max(norm, eps), so an all-zero feature row
        # stays zero instead of becoming NaN as a raw `feats / norm` would.
        feats = nn.functional.normalize(feats, p=2, dim=-1)
        return self.classifier(feats)