"""Multimodal classifier: three pretrained vision backbones fused with a 1-D CNN
over a numeric feature vector."""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

import torchvision
import torchvision.transforms as T
from torchvision import models, transforms

from transformers import SegformerConfig, SegformerForSemanticSegmentation


class MultimodalModel(nn.Module):
    """Fuses ViT-B/16, Swin-B and SwinV2-B image features with a small 1-D CNN
    over numeric features.

    The three backbones have their classification heads replaced with
    ``nn.Identity`` so they emit pooled feature vectors (768, 1024 and 1024
    dims respectively), which are concatenated to a 2816-dim vector and mapped
    to ``num_classes`` logits. The numeric branch independently produces
    ``num_classes`` logits; the two are blended 0.95 / 0.05 in ``forward``.
    """

    def __init__(self, num_numeric_features: int, num_classes: int) -> None:
        """
        Args:
            num_numeric_features: length of the per-sample numeric vector.
                Must be >= 4 so the two stride-2 pools leave a non-empty
                sequence (flattened length is ``(num_numeric_features // 4) * 32``).
            num_classes: number of output classes (logit dimension).
        """
        super().__init__()

        # `pretrained=True` is deprecated since torchvision 0.13; the explicit
        # IMAGENET1K_V1 enums load exactly the weights the legacy flag did.
        self.vit = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
        self.vit.heads = nn.Identity()  # strip classifier head -> (N, 768)

        self.swin_b = models.swin_b(weights=models.Swin_B_Weights.IMAGENET1K_V1)
        self.swin_b.head = nn.Identity()  # -> (N, 1024)

        self.swinv2_b = models.swin_v2_b(weights=models.Swin_V2_B_Weights.IMAGENET1K_V1)
        self.swinv2_b.head = nn.Identity()  # -> (N, 1024)

        # 1-D CNN over the numeric vector. Two stride-2 max-pools reduce the
        # sequence length to num_numeric_features // 4 before flattening.
        # NOTE(review): BatchNorm1d layers require batch size > 1 in training
        # mode — confirm the training loop never feeds single-sample batches.
        self.numeric_branch = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear((num_numeric_features // 4) * 32, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        )

        # Head over the concatenated backbone features: 768 + 1024 + 1024 = 2816.
        self.image_fc = nn.Sequential(
            nn.Linear(768 + 1024 + 1024, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, image: Tensor, numeric_data: Tensor) -> Tensor:
        """Blend image-branch and numeric-branch logits.

        Args:
            image: image batch fed to all three backbones — presumably
                (N, 3, 224, 224) as the torchvision defaults expect; confirm
                against the data pipeline.
            numeric_data: (N, num_numeric_features) float tensor; a channel
                dim is inserted before the Conv1d stack.

        Returns:
            (N, num_classes) logits: 0.95 * image logits + 0.05 * numeric logits.
        """
        vit_features = self.vit(image)
        swin_b_features = self.swin_b(image)
        swinv2_b_features = self.swinv2_b(image)

        # (N, 768) ++ (N, 1024) ++ (N, 1024) -> (N, 2816)
        combined_image_features = torch.cat(
            (vit_features, swin_b_features, swinv2_b_features), dim=1
        )
        combined_image_output = self.image_fc(combined_image_features)

        numeric_data = numeric_data.unsqueeze(1)  # (N, F) -> (N, 1, F) for Conv1d
        numeric_output = self.numeric_branch(numeric_data)

        # Fixed weighted fusion — image branch dominates by design.
        final_output = 0.95 * combined_image_output + 0.05 * numeric_output
        return final_output