# Scraped from a Hugging Face Space (page status: Sleeping).
import torch
import torch.nn as nn
import torch.nn.functional as F

import timm
from pytorch_tabnet.tab_network import TabNetNoEmbeddings
class AttentionFusion(nn.Module):
    """Attention-weighted fusion of an image embedding with a tabular embedding.

    Each modality is scored by a learned scalar (tanh-squashed); a softmax over
    the two scores gives per-sample mixing weights. The tabular vector is
    linearly projected up to the image dimension so the two can be summed.
    """

    def __init__(self, img_dim=64, tab_dim=16):
        super().__init__()
        # One scalar attention score per modality.
        self.img_attn = nn.Linear(img_dim, 1)
        self.tab_attn = nn.Linear(tab_dim, 1)
        # Lift tabular features into the image feature space before mixing.
        self.tab_project = nn.Linear(tab_dim, img_dim)
        self.dropout = nn.Dropout(0.2)

    def forward(self, img_feat, tab_feat):
        """Return ``(fused, weights)``.

        fused:   (B, img_dim) dropout-regularized convex mixture of the two modalities.
        weights: (B, 2) softmax mixing weights — column 0 image, column 1 tabular.
        """
        img_score = torch.tanh(self.img_attn(img_feat))
        tab_score = torch.tanh(self.tab_attn(tab_feat))
        weights = F.softmax(torch.cat([img_score, tab_score], dim=1), dim=1)
        projected_tab = self.tab_project(tab_feat)
        fused = weights[:, 0].unsqueeze(1) * img_feat \
            + weights[:, 1].unsqueeze(1) * projected_tab
        return self.dropout(fused), weights
class MultiModalNet(nn.Module):
    """Two-branch classifier: EfficientNet-B3 image features and a TabNet tabular
    encoder, merged via ``AttentionFusion`` and mapped to class logits.
    """

    def __init__(self, num_classes=6, num_tab_features=9):
        super().__init__()
        # Image branch: headless EfficientNet-B3 (num_classes=0 -> 1536-d pooled
        # features), then reduced to 64-d.
        self.img_backbone = timm.create_model('efficientnet_b3', pretrained=False, num_classes=0)
        self.img_fc = nn.Sequential(nn.Linear(1536, 64), nn.ReLU(), nn.Dropout(0.2))
        # Tabular branch: TabNet encoder to 16-d, then a small dense refinement.
        self.tab_backbone = TabNetNoEmbeddings(input_dim=num_tab_features, output_dim=16, n_d=16, n_a=16, n_steps=3)
        self.tab_dense = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
        self.fusion = AttentionFusion(img_dim=64, tab_dim=16)
        self.classifier = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Dropout(0.2), nn.Linear(64, num_classes))

    def _sync_tabnet_device(self, device):
        # Workaround: some pytorch_tabnet versions hold group_attention_matrix as a
        # plain tensor attribute (not a registered buffer), so model.to(device)
        # does not move it. Lazily relocate it to match the incoming batch.
        if not hasattr(self.tab_backbone, 'encoder'):
            return
        encoder = self.tab_backbone.encoder
        if not hasattr(encoder, 'group_attention_matrix'):
            return
        matrix = encoder.group_attention_matrix
        if matrix.device != device:
            encoder.group_attention_matrix = matrix.to(device)

    def forward(self, img, tab):
        """Return ``(logits, fusion_weights)`` for an image batch and tabular batch."""
        self._sync_tabnet_device(tab.device)
        image_features = self.img_fc(self.img_backbone(img))
        tabular_features, _ = self.tab_backbone(tab)
        tabular_features = self.tab_dense(tabular_features)
        fused, weights = self.fusion(image_features, tabular_features)
        return self.classifier(fused), weights