# multimodal-glaucoma-classifier / segformer_model.py
# Author: Rahil Parikh
# Commit: 5e11c89 ("modularize code")
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import SegformerConfig, SegformerForSemanticSegmentation
from torch import Tensor
import torchvision
from torchvision import models
import torchvision.transforms as T
from torchvision import transforms
class SqueezeExcitation(nn.Module):
    """Channel-wise squeeze-and-excitation gate (Hu et al., 2018).

    Global-average-pools the input to a per-channel descriptor, passes it
    through a bottleneck MLP (implemented as 1x1 convs), and rescales the
    input feature map by the resulting sigmoid attention weights.
    """

    def __init__(self, channels: int, reduction: int = 16):
        """
        Args:
            channels: number of input/output channels.
            reduction: bottleneck ratio; hidden width is channels // reduction.
        """
        super().__init__()
        squeezed = channels // reduction
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),                       # squeeze: (B, C, 1, 1) global context
            nn.Conv2d(channels, squeezed, kernel_size=1),  # excitation bottleneck
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channels, kernel_size=1),
            nn.Sigmoid(),                                  # per-channel gate in (0, 1)
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return `x` rescaled by its learned per-channel attention weights."""
        gate = self.se(x)
        return x * gate
class SegFormer(nn.Module):
    """Semantic segmentation model built on a pretrained HuggingFace SegFormer.

    Three components are combined:
      1. the stock SegFormer encoder/decode head, with the final classifier
         replaced by a deeper conv + dropout stack;
      2. an FPN-style branch that projects each encoder stage to 128 channels
         (with squeeze-excitation), upsamples, concatenates, and fuses them
         through a residual conv block;
      3. a refinement head that merges the fused features with the decode-head
         logits into the final per-pixel class scores.

    Forward contract: (B, n_channels, H, W) float tensor ->
    (B, n_classes, H, W) logits at the input resolution.
    """

    def __init__(self, n_channels: int, n_classes: int, pretrained_model: str = "nvidia/mit-b5"):
        """
        Args:
            n_channels: number of input image channels; if != 3, the stage-0
                patch-embedding projection is re-initialized from scratch.
            n_classes: number of output segmentation classes.
            pretrained_model: HuggingFace model id to load config/weights from.
        """
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        config = SegformerConfig.from_pretrained(
            pretrained_model,
            num_channels=n_channels,
            num_labels=n_classes,
            hidden_dropout_prob=0.3,
            attention_probs_dropout_prob=0.3,
            drop_path_rate=0.1,
        )
        self.segformer = SegformerForSemanticSegmentation.from_pretrained(
            pretrained_model,
            config=config,
            # Head/channel sizes differ from the checkpoint, so shape
            # mismatches must be tolerated (mismatched layers are re-init'd).
            ignore_mismatched_sizes=True,
        )

        if n_channels != 3:
            # Re-initialize the stage-0 overlapping patch embedding for the
            # new channel count. padding=3 (= kernel_size // 2) matches
            # SegFormer's canonical overlap patch embedding geometry.
            self.segformer.segformer.encoder.patch_embeddings[0].proj = nn.Conv2d(
                n_channels, config.hidden_sizes[0], kernel_size=7, stride=4, padding=3
            )

        # Deeper replacement for the decode head's 1x1 classifier.
        self.segformer.decode_head.classifier = nn.Sequential(
            nn.Conv2d(config.decoder_hidden_size, config.decoder_hidden_size // 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(config.decoder_hidden_size // 2, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.4),
            nn.Conv2d(config.decoder_hidden_size // 2, n_classes, kernel_size=1),
        )

        # Per-stage 1x1 projections to a common 128-channel width, each gated
        # by squeeze-excitation. One entry per encoder stage.
        self.fpn = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(h, 128, kernel_size=1),
                nn.BatchNorm2d(128, momentum=0.05),
                nn.ReLU(inplace=True),
                SqueezeExcitation(128),
            ) for h in config.hidden_sizes
        ])

        # Fuses the concatenated multi-scale features down to 256 channels.
        self.fusion = nn.Sequential(
            nn.Conv2d(128 * len(config.hidden_sizes), 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256, momentum=0.05),
            nn.ReLU(inplace=True),
        )
        # 1x1 shortcut so the fusion block learns a residual.
        self.fusion_residual = nn.Conv2d(128 * len(config.hidden_sizes), 256, kernel_size=1)

        # Final head: fused features + decode-head logits -> class logits.
        self.refinement = nn.Sequential(
            nn.Conv2d(256 + n_classes, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.2),
            nn.Conv2d(128, n_classes, kernel_size=1),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Segment `x`, returning (B, n_classes, H, W) logits."""
        input_size = x.size()[2:]

        # Single encoder+decoder pass; requesting hidden states lets the FPN
        # branch reuse the per-stage encoder features instead of running the
        # encoder a second time (the original code did two full passes).
        outputs = self.segformer(pixel_values=x, output_hidden_states=True)
        logits = outputs.logits                # decode-head logits at 1/4 resolution
        hidden_states = outputs.hidden_states  # one (B, C_i, H_i, W_i) map per stage

        # Project every stage to 128 channels and align to the logits' size.
        fpn_feats = []
        for feat, proj in zip(hidden_states, self.fpn):
            f = proj(feat)
            f = F.interpolate(f, size=logits.shape[2:], mode="bilinear", align_corners=False)
            fpn_feats.append(f)

        fused = torch.cat(fpn_feats, dim=1)
        residual = self.fusion_residual(fused)
        fused = self.fusion(fused) + residual

        # Upsample both streams to the input resolution before refinement.
        logits = F.interpolate(logits, size=input_size, mode="bilinear", align_corners=False)
        fused = F.interpolate(fused, size=input_size, mode="bilinear", align_corners=False)

        return self.refinement(torch.cat([fused, logits], dim=1))