| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | from transformers import SegformerConfig, SegformerForSemanticSegmentation
|
| | from torch import Tensor
|
| |
|
| | import torchvision
|
| | from torchvision import models
|
| | import torchvision.transforms as T
|
| | from torchvision import transforms
|
| |
|
class MultimodalModel(nn.Module):
    """Late-fusion classifier over an image and a numeric feature vector.

    Image path: the pooled features of three torchvision backbones
    (ViT-B/16, Swin-B, Swin-V2-B) are concatenated and mapped to class logits
    by a small MLP (``image_fc``).

    Numeric path: the feature vector is treated as a single-channel 1-D signal
    and passed through a small Conv1d/MaxPool1d stack (``numeric_branch``) that
    also produces class logits.

    The two logit tensors are blended with fixed scalar weights.

    Args:
        num_numeric_features: Length of the numeric feature vector. Must be at
            least 4, because the numeric branch halves the sequence length
            twice with max-pooling.
        num_classes: Number of logits produced by each branch (and returned).
        pretrained: If True (default), load the ImageNet-1K V1 weights for all
            three backbones. These are the same weights the legacy torchvision
            ``pretrained=True`` flag selected. If False, the backbones are
            randomly initialised and nothing is downloaded.
        image_weight: Scale applied to the image-branch logits.
        numeric_weight: Scale applied to the numeric-branch logits.

    Raises:
        ValueError: If ``num_numeric_features`` is smaller than 4.
    """

    def __init__(
        self,
        num_numeric_features,
        num_classes,
        *,
        pretrained=True,
        image_weight=0.95,
        numeric_weight=0.05,
    ):
        super().__init__()
        # With fewer than 4 features the second MaxPool1d receives a length-0/1
        # sequence and forward() can never succeed; fail fast with a clear message.
        if num_numeric_features < 4:
            raise ValueError(
                "num_numeric_features must be >= 4 (the numeric branch "
                f"max-pools by 2 twice), got {num_numeric_features}"
            )

        self.image_weight = image_weight
        self.numeric_weight = numeric_weight

        # torchvision replaced the deprecated ``pretrained`` flag with ``weights``
        # enums; IMAGENET1K_V1 is what ``pretrained=True`` resolved to.
        self.vit = models.vit_b_16(
            weights=models.ViT_B_16_Weights.IMAGENET1K_V1 if pretrained else None
        )
        vit_dim = self.vit.hidden_dim
        # Drop the classification head so the backbone returns pooled features.
        self.vit.heads = nn.Identity()

        self.swin_b = models.swin_b(
            weights=models.Swin_B_Weights.IMAGENET1K_V1 if pretrained else None
        )
        # Read the feature width before the head (a Linear) is replaced.
        swin_b_dim = self.swin_b.head.in_features
        self.swin_b.head = nn.Identity()

        self.swinv2_b = models.swin_v2_b(
            weights=models.Swin_V2_B_Weights.IMAGENET1K_V1 if pretrained else None
        )
        swinv2_b_dim = self.swinv2_b.head.in_features
        self.swinv2_b.head = nn.Identity()

        self.numeric_branch = self._build_numeric_branch(
            num_numeric_features, num_classes
        )

        # Submodule names and Sequential indices match the original layout, so
        # existing state_dict checkpoints still load.
        self.image_fc = nn.Sequential(
            nn.Linear(vit_dim + swin_b_dim + swinv2_b_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    @staticmethod
    def _build_numeric_branch(num_numeric_features, num_classes):
        """Build the 1-D CNN mapping a (batch, 1, num_numeric_features) tensor to logits."""
        # Two MaxPool1d(kernel_size=2, stride=2) layers shrink the length to
        # floor(num_numeric_features / 4) before flattening 32 channels.
        pooled_len = num_numeric_features // 4
        return nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(pooled_len * 32, 64),
            # NOTE: BatchNorm1d on a (batch, 64) tensor needs batch > 1 in
            # training mode; use eval() for single-sample inference.
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    def forward(self, image, numeric_data):
        """Return blended class logits for a batch of images and numeric vectors.

        Args:
            image: Image batch fed unchanged to all three backbones. Presumably
                (batch, 3, 224, 224) normalised RGB (the ViT-B/16 backbone has a
                fixed input size). TODO confirm against the data pipeline.
            numeric_data: Numeric features, presumably of shape
                (batch, num_numeric_features). A channel axis is inserted at
                dim 1 for the Conv1d stack.

        Returns:
            ``image_weight * image_logits + numeric_weight * numeric_logits``.
        """
        combined_image_features = torch.cat(
            (self.vit(image), self.swin_b(image), self.swinv2_b(image)), dim=1
        )
        image_logits = self.image_fc(combined_image_features)

        numeric_logits = self.numeric_branch(numeric_data.unsqueeze(1))

        return self.image_weight * image_logits + self.numeric_weight * numeric_logits