"""Multimodal classifier: EfficientNet-B3 image features fused with TabNet
tabular features via a learned two-way attention gate."""

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from pytorch_tabnet.tab_network import TabNetNoEmbeddings


class AttentionFusion(nn.Module):
    """Fuse one image feature vector and one tabular feature vector.

    Each modality gets a single scalar attention logit; a softmax over the
    two logits yields convex combination weights. The tabular vector is
    linearly projected to ``img_dim`` so the weighted sum is well-defined.

    Args:
        img_dim: Dimensionality of the image feature vector (default 64).
        tab_dim: Dimensionality of the tabular feature vector (default 16).
    """

    def __init__(self, img_dim=64, tab_dim=16):
        super().__init__()
        # One scalar attention logit per modality.
        self.img_attn = nn.Linear(img_dim, 1)
        self.tab_attn = nn.Linear(tab_dim, 1)
        # Map tabular features into the image feature space so both terms
        # of the weighted sum have shape (B, img_dim).
        self.tab_project = nn.Linear(tab_dim, img_dim)
        self.dropout = nn.Dropout(0.2)

    def forward(self, img_feat, tab_feat):
        """Return ``(fused, weights)``.

        Args:
            img_feat: Image features, shape (B, img_dim).
            tab_feat: Tabular features, shape (B, tab_dim).

        Returns:
            fused: Dropout-regularized fusion, shape (B, img_dim).
            weights: Per-sample modality weights, shape (B, 2); each row
                sums to 1 (softmax output).
        """
        w_img = torch.tanh(self.img_attn(img_feat))  # (B, 1) logit
        w_tab = torch.tanh(self.tab_attn(tab_feat))  # (B, 1) logit
        # Softmax across the two modality logits -> convex weights.
        weights = F.softmax(torch.cat([w_img, w_tab], dim=1), dim=1)
        tab_feat_proj = self.tab_project(tab_feat)
        # Slice with 0:1 / 1:2 to keep the (B, 1) shape for broadcasting.
        fused = weights[:, 0:1] * img_feat + weights[:, 1:2] * tab_feat_proj
        return self.dropout(fused), weights


class MultiModalNet(nn.Module):
    """Image + tabular classifier.

    Image branch: EfficientNet-B3 (head removed) -> 1536-d pooled features
    -> Linear to 64-d. Tabular branch: TabNet encoder -> 16-d -> dense.
    The two branches are combined by :class:`AttentionFusion` and fed to a
    small MLP classifier head.

    Args:
        num_classes: Number of output classes (default 6).
        num_tab_features: Number of tabular input columns (default 9).
    """

    def __init__(self, num_classes=6, num_tab_features=9):
        super().__init__()
        # num_classes=0 strips timm's classifier head, leaving the pooled
        # feature vector (1536-d for EfficientNet-B3).
        self.img_backbone = timm.create_model(
            'efficientnet_b3', pretrained=False, num_classes=0
        )
        self.img_fc = nn.Sequential(
            nn.Linear(1536, 64), nn.ReLU(), nn.Dropout(0.2)
        )
        self.tab_backbone = TabNetNoEmbeddings(
            input_dim=num_tab_features, output_dim=16, n_d=16, n_a=16, n_steps=3
        )
        self.tab_dense = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
        self.fusion = AttentionFusion(img_dim=64, tab_dim=16)
        self.classifier = nn.Sequential(
            nn.Linear(64, 64), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(64, num_classes),
        )

    def _sync_tabnet_device(self, tab):
        """Move TabNet's group attention matrix onto ``tab``'s device.

        Workaround: the matrix appears to be a plain tensor rather than a
        registered buffer, so ``model.to(device)`` does not move it —
        presumably a pytorch_tabnet quirk; TODO confirm against the
        installed pytorch_tabnet version. No-op once devices match.
        """
        if hasattr(self.tab_backbone, 'encoder'):
            if hasattr(self.tab_backbone.encoder, 'group_attention_matrix'):
                matrix = self.tab_backbone.encoder.group_attention_matrix
                if matrix.device != tab.device:
                    self.tab_backbone.encoder.group_attention_matrix = (
                        matrix.to(tab.device)
                    )

    def forward(self, img, tab):
        """Return ``(logits, weights)``.

        Args:
            img: Image batch accepted by EfficientNet-B3 (B, 3, H, W).
            tab: Tabular batch, shape (B, num_tab_features).

        Returns:
            logits: Class scores, shape (B, num_classes).
            weights: Modality attention weights from the fusion layer,
                shape (B, 2).
        """
        self._sync_tabnet_device(tab)
        i_feat = self.img_fc(self.img_backbone(img))
        # TabNetNoEmbeddings returns (features, sparsity_loss); the loss
        # term is intentionally discarded here.
        t_feat, _ = self.tab_backbone(tab)
        t_feat = self.tab_dense(t_feat)
        fused, weights = self.fusion(i_feat, t_feat)
        return self.classifier(fused), weights