# multimodal-glaucoma-classifier / multimodal_model.py
# Author: Rahil Parikh
# (commit 5e11c89: "modularize code")
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import SegformerConfig, SegformerForSemanticSegmentation
from torch import Tensor
import torchvision
from torchvision import models
import torchvision.transforms as T
from torchvision import transforms
class MultimodalModel(nn.Module):
def __init__(self, num_numeric_features, num_classes):
super(MultimodalModel, self).__init__()
self.vit = models.vit_b_16(pretrained=True)
self.vit.heads = nn.Identity()
self.swin_b = models.swin_b(pretrained=True)
self.swin_b.head = nn.Identity()
self.swinv2_b = models.swin_v2_b(pretrained=True)
self.swinv2_b.head = nn.Identity()
self.numeric_branch = nn.Sequential(
nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),
nn.BatchNorm1d(16),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
nn.BatchNorm1d(32),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Flatten(),
nn.Linear((num_numeric_features // 4) * 32, 64),
nn.BatchNorm1d(64),
nn.ReLU(),
nn.Linear(64, num_classes)
)
self.image_fc = nn.Sequential(
nn.Linear(768 + 1024 + 1024, 128),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(128, num_classes)
)
def forward(self, image, numeric_data):
vit_features = self.vit(image)
swin_b_features = self.swin_b(image)
swinv2_b_features = self.swinv2_b(image)
combined_image_features = torch.cat((vit_features, swin_b_features, swinv2_b_features), dim=1) # Shape: (N, 2816)
combined_image_output = self.image_fc(combined_image_features)
numeric_data = numeric_data.unsqueeze(1)
numeric_output = self.numeric_branch(numeric_data)
final_output = 0.95 * combined_image_output + 0.05 * numeric_output
return final_output