File size: 2,283 Bytes
5e11c89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import SegformerConfig, SegformerForSemanticSegmentation
from torch import Tensor

import torchvision
from torchvision import models
import torchvision.transforms as T
from torchvision import transforms

class MultimodalModel(nn.Module):
    """Multimodal classifier fusing three pretrained image backbones with a
    1-D CNN branch over numeric features.

    Image path: pooled features from ViT-B/16 (768-d), Swin-B (1024-d) and
    SwinV2-B (1024-d) are concatenated into a 2816-d vector and passed through
    a small MLP head producing ``num_classes`` logits. Numeric path: two
    Conv1d + MaxPool stages followed by an MLP, also producing ``num_classes``
    logits. The final output is a fixed weighted sum (0.95 image / 0.05
    numeric) of the two logit vectors.
    """

    def __init__(self, num_numeric_features, num_classes):
        """
        Args:
            num_numeric_features: length of the per-sample 1-D numeric feature
                vector. Must be >= 4 so that two stride-2 poolings leave a
                non-empty sequence (flattened size is
                ``(num_numeric_features // 4) * 32``).
            num_classes: number of output logits.
        """
        super().__init__()

        # Pretrained backbones with their classification heads replaced by
        # Identity so each returns its pooled feature vector.
        # NOTE: `weights=...DEFAULT` replaces the deprecated `pretrained=True`
        # flag (removed in torchvision >= 0.15) and loads the same ImageNet
        # weights.
        self.vit = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)
        self.vit.heads = nn.Identity()  # -> (N, 768)

        self.swin_b = models.swin_b(weights=models.Swin_B_Weights.DEFAULT)
        self.swin_b.head = nn.Identity()  # -> (N, 1024)

        self.swinv2_b = models.swin_v2_b(weights=models.Swin_V2_B_Weights.DEFAULT)
        self.swinv2_b.head = nn.Identity()  # -> (N, 1024)

        # Numeric branch: input is (N, 1, num_numeric_features); each MaxPool1d
        # halves the length, so after two poolings the flattened size is
        # (num_numeric_features // 4) * 32.
        # NOTE(review): BatchNorm1d layers require batch size > 1 in training
        # mode — confirm callers never train with a batch of one.
        self.numeric_branch = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear((num_numeric_features // 4) * 32, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )

        # Head over the concatenated image features (768 + 1024 + 1024 = 2816).
        self.image_fc = nn.Sequential(
            nn.Linear(768 + 1024 + 1024, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, image, numeric_data):
        """Fuse image and numeric branches into final logits.

        Args:
            image: image batch fed to all three backbones — assumes the
                standard (N, 3, 224, 224) ImageNet input; TODO confirm with
                the data pipeline.
            numeric_data: (N, num_numeric_features) float tensor; a channel
                dim is inserted before the Conv1d branch.

        Returns:
            (N, num_classes) tensor: 0.95 * image logits + 0.05 * numeric logits.
        """
        vit_features = self.vit(image)          # (N, 768)
        swin_b_features = self.swin_b(image)    # (N, 1024)
        swinv2_b_features = self.swinv2_b(image)  # (N, 1024)

        # Concatenate along the feature dim -> (N, 2816), then project to logits.
        combined_image_features = torch.cat(
            (vit_features, swin_b_features, swinv2_b_features), dim=1
        )
        combined_image_output = self.image_fc(combined_image_features)

        # (N, F) -> (N, 1, F) so Conv1d sees a single input channel.
        numeric_data = numeric_data.unsqueeze(1)
        numeric_output = self.numeric_branch(numeric_data)

        # Fixed late-fusion weighting; the image branch dominates by design.
        final_output = 0.95 * combined_image_output + 0.05 * numeric_output

        return final_output