# AQIMultiModal / src/model_arch.py
# (commit 257d6f1, author rocky250 — "Update src/model_arch.py")
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from pytorch_tabnet.tab_network import TabNetNoEmbeddings
class AttentionFusion(nn.Module):
    """Fuse image and tabular feature vectors with learned scalar attention.

    Each modality gets a single tanh-squashed attention score; the two
    scores are softmax-normalized into per-sample mixing weights. The
    tabular vector is linearly projected up to the image dimension so the
    weighted sum is well-defined.

    Returns (fused_features, attention_weights) where the weights have
    shape (batch, 2) and sum to 1 per row.
    """

    def __init__(self, img_dim=64, tab_dim=16):
        super().__init__()
        # One scalar attention score per modality.
        self.img_attn = nn.Linear(img_dim, 1)
        self.tab_attn = nn.Linear(tab_dim, 1)
        # Lift tabular features into the image feature space before mixing.
        self.tab_project = nn.Linear(tab_dim, img_dim)
        self.dropout = nn.Dropout(0.2)

    def forward(self, img_feat, tab_feat):
        # Per-sample scalar scores in (-1, 1) for each modality.
        score_img = torch.tanh(self.img_attn(img_feat))
        score_tab = torch.tanh(self.tab_attn(tab_feat))
        # Normalize the pair of scores into weights that sum to 1.
        attn = F.softmax(torch.cat([score_img, score_tab], dim=1), dim=1)
        projected_tab = self.tab_project(tab_feat)
        # Slice-with-range keeps the singleton dim (equivalent to unsqueeze).
        w_img = attn[:, 0:1]
        w_tab = attn[:, 1:2]
        fused = w_img * img_feat + w_tab * projected_tab
        return self.dropout(fused), attn
class MultiModalNet(nn.Module):
    """Two-branch multi-modal classifier over an image and tabular features.

    Image branch: headless timm EfficientNet-B3 -> Linear(1536, 64) embedding.
    Tabular branch: TabNetNoEmbeddings encoder -> Linear(16, 16) embedding.
    The two embeddings are combined by AttentionFusion and classified into
    `num_classes` logits.
    """

    def __init__(self, num_classes=6, num_tab_features=9):
        super(MultiModalNet, self).__init__()
        # num_classes=0 strips timm's classification head so the backbone
        # emits pooled features. The Linear(1536, ...) below assumes the
        # efficientnet_b3 pooled feature dim is 1536 — TODO confirm against
        # the installed timm version.
        self.img_backbone = timm.create_model('efficientnet_b3', pretrained=False, num_classes=0)
        self.img_fc = nn.Sequential(nn.Linear(1536, 64), nn.ReLU(), nn.Dropout(0.2))
        self.tab_backbone = TabNetNoEmbeddings(input_dim=num_tab_features, output_dim=16, n_d=16, n_a=16, n_steps=3)
        self.tab_dense = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
        # Fusion dims must match img_fc output (64) and tab_dense output (16).
        self.fusion = AttentionFusion(img_dim=64, tab_dim=16)
        self.classifier = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Dropout(0.2), nn.Linear(64, num_classes))

    def forward(self, img, tab):
        # HACK: move the TabNet encoder's `group_attention_matrix` onto the
        # input's device before the tabular pass. NOTE(review): this assumes
        # pytorch_tabnet holds it as a plain tensor attribute that `.to(device)`
        # on the module does not move — verify against the installed
        # pytorch_tabnet version; hasattr guards keep this a no-op otherwise.
        if hasattr(self.tab_backbone, 'encoder'):
            if hasattr(self.tab_backbone.encoder, 'group_attention_matrix'):
                matrix = self.tab_backbone.encoder.group_attention_matrix
                if matrix.device != tab.device:
                    self.tab_backbone.encoder.group_attention_matrix = matrix.to(tab.device)
        i_feat = self.img_fc(self.img_backbone(img))
        # TabNetNoEmbeddings returns a pair; the second element (presumably
        # the sparsity/M loss) is discarded here — confirm if training should
        # incorporate it.
        t_feat, _ = self.tab_backbone(tab)
        t_feat = self.tab_dense(t_feat)
        fused, weights = self.fusion(i_feat, t_feat)
        # Expose the fusion attention weights alongside the logits so callers
        # can inspect modality contributions.
        return self.classifier(fused), weights