File size: 2,188 Bytes
7cb85d8
 
 
 
 
 
 
 
 
 
 
 
 
257d6f1
7cb85d8
 
 
 
 
 
 
 
257d6f1
7cb85d8
 
 
 
 
 
 
 
257d6f1
7cb85d8
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from pytorch_tabnet.tab_network import TabNetNoEmbeddings

class AttentionFusion(nn.Module):
    """Fuse image and tabular feature vectors with learned scalar attention.

    Each modality is scored by a tanh-squashed linear projection to a single
    scalar; the two scores are softmax-normalized into convex weights. The
    tabular features are projected up to the image dimension so the weighted
    sum is well-defined.

    Args:
        img_dim: Dimensionality of the image feature vector (default 64).
        tab_dim: Dimensionality of the tabular feature vector (default 16).
        dropout: Dropout probability applied to the fused vector. Default 0.2
            preserves the previously hard-coded value.
    """

    def __init__(self, img_dim=64, tab_dim=16, dropout=0.2):
        super(AttentionFusion, self).__init__()
        self.img_attn = nn.Linear(img_dim, 1)
        self.tab_attn = nn.Linear(tab_dim, 1)
        # Projects tabular features into the image feature space so both
        # modalities can be combined by a weighted sum.
        self.tab_project = nn.Linear(tab_dim, img_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, img_feat, tab_feat):
        """Compute the attention-weighted fusion of the two modalities.

        Args:
            img_feat: ``(batch, img_dim)`` image features.
            tab_feat: ``(batch, tab_dim)`` tabular features.

        Returns:
            Tuple of the fused ``(batch, img_dim)`` tensor (after dropout)
            and the ``(batch, 2)`` softmax attention weights, ordered
            ``[image, tabular]``.
        """
        w_img = torch.tanh(self.img_attn(img_feat))
        w_tab = torch.tanh(self.tab_attn(tab_feat))
        # Softmax over the two per-modality scores yields a convex combination.
        weights = F.softmax(torch.cat([w_img, w_tab], dim=1), dim=1)
        tab_feat_proj = self.tab_project(tab_feat)
        # Slicing with 0:1 / 1:2 keeps the (batch, 1) shape for broadcasting
        # (equivalent to the former unsqueeze(1) form, but clearer).
        fused = weights[:, 0:1] * img_feat + weights[:, 1:2] * tab_feat_proj
        return self.dropout(fused), weights

class MultiModalNet(nn.Module):
    """Two-branch multimodal classifier.

    An EfficientNet-B3 image branch and a TabNet tabular branch each produce
    an embedding; the two are merged by :class:`AttentionFusion` and passed
    through a small MLP head to produce class logits.
    """

    def __init__(self, num_classes=6, num_tab_features=9):
        super(MultiModalNet, self).__init__()
        # Image branch: headless EfficientNet-B3 (num_classes=0 removes the
        # classifier, yielding 1536-d pooled features) compressed to 64 dims.
        self.img_backbone = timm.create_model('efficientnet_b3', pretrained=False, num_classes=0)
        self.img_fc = nn.Sequential(nn.Linear(1536, 64), nn.ReLU(), nn.Dropout(0.2))
        # Tabular branch: TabNet encoder followed by a dense refinement layer.
        self.tab_backbone = TabNetNoEmbeddings(input_dim=num_tab_features, output_dim=16, n_d=16, n_a=16, n_steps=3)
        self.tab_dense = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
        self.fusion = AttentionFusion(img_dim=64, tab_dim=16)
        self.classifier = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Dropout(0.2), nn.Linear(64, num_classes))

    def _sync_tabnet_device(self, device):
        # Workaround: TabNet's group_attention_matrix does not always follow
        # the module when it is moved between devices, so relocate it here
        # before the tabular forward pass if it is on the wrong device.
        if hasattr(self.tab_backbone, 'encoder'):
            encoder = self.tab_backbone.encoder
            if hasattr(encoder, 'group_attention_matrix'):
                matrix = encoder.group_attention_matrix
                if matrix.device != device:
                    encoder.group_attention_matrix = matrix.to(device)

    def forward(self, img, tab):
        """Return ``(logits, attention_weights)`` for an image/tabular batch."""
        self._sync_tabnet_device(tab.device)
        image_embedding = self.img_fc(self.img_backbone(img))
        tab_out, _ = self.tab_backbone(tab)
        tab_embedding = self.tab_dense(tab_out)
        fused, attn_weights = self.fusion(image_embedding, tab_embedding)
        return self.classifier(fused), attn_weights