| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import SegformerConfig, SegformerForSemanticSegmentation | |
| from torch import Tensor | |
| import torchvision | |
| from torchvision import models | |
| import torchvision.transforms as T | |
| from torchvision import transforms | |
class SqueezeExcitation(nn.Module):
    """Channel-attention (squeeze-and-excitation) gate.

    Each channel is global-average-pooled to a scalar, passed through a
    two-layer 1x1-conv bottleneck (reduction-fold narrower) with ReLU,
    and squashed by a sigmoid into a per-channel weight in [0, 1]; the
    input is then rescaled channel-wise by those weights.
    """

    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        squeezed = channels // reduction  # bottleneck width
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, squeezed, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` scaled channel-wise by its attention weights."""
        gate = self.se(x)
        return x.mul(gate)
class SegFormer(nn.Module):
    """SegFormer (MiT encoder + all-MLP decode head) extended with an
    FPN-style multi-scale fusion branch and a final refinement head.

    Pipeline:
      1. Pretrained ``SegformerForSemanticSegmentation`` produces coarse
         class logits and the encoder's four stage feature maps.
      2. Each stage feature is projected to 128 channels, gated by
         squeeze-and-excitation, and upsampled to the logits' resolution.
      3. The concatenated features are fused (with a 1x1 residual
         projection) into a 256-channel map.
      4. Fused features and upsampled logits are concatenated and refined
         into the final per-pixel class logits at input resolution.

    Args:
        n_channels: number of input image channels (stem is rebuilt when != 3).
        n_classes: number of output segmentation classes.
        pretrained_model: HuggingFace checkpoint id to load weights from.
    """

    def __init__(self, n_channels: int, n_classes: int, pretrained_model: str = "nvidia/mit-b5"):
        super(SegFormer, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        config = SegformerConfig.from_pretrained(
            pretrained_model,
            num_channels=n_channels,
            num_labels=n_classes,
            hidden_dropout_prob=0.3,
            attention_probs_dropout_prob=0.3,
            drop_path_rate=0.1,
        )
        self.segformer = SegformerForSemanticSegmentation.from_pretrained(
            pretrained_model,
            config=config,
            ignore_mismatched_sizes=True,
        )
        if n_channels != 3:
            # Rebuild the stage-1 stem projection for non-RGB inputs.
            # FIX: padding must be 3 (was 2). The pretrained MiT overlap
            # patch embedding uses kernel=7, stride=4, padding=3, which
            # yields an exact H/4 x W/4 stage-1 map for every input size;
            # padding=2 deviates from that geometry.
            self.segformer.segformer.encoder.patch_embeddings[0].proj = nn.Conv2d(
                n_channels, config.hidden_sizes[0], kernel_size=7, stride=4, padding=3
            )
        # Replace the stock 1x1 classifier with a deeper, regularized head.
        self.segformer.decode_head.classifier = nn.Sequential(
            nn.Conv2d(config.decoder_hidden_size, config.decoder_hidden_size // 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(config.decoder_hidden_size // 2, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.4),
            nn.Conv2d(config.decoder_hidden_size // 2, n_classes, kernel_size=1),
        )
        # One 128-channel projection (+ SE gate) per encoder stage.
        self.fpn = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(h, 128, kernel_size=1),
                nn.BatchNorm2d(128, momentum=0.05),
                nn.ReLU(inplace=True),
                SqueezeExcitation(128),
            ) for h in config.hidden_sizes
        ])
        # Fuse the concatenated stage features; a parallel 1x1 projection
        # supplies the residual path.
        self.fusion = nn.Sequential(
            nn.Conv2d(128 * len(config.hidden_sizes), 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256, momentum=0.05),
            nn.ReLU(inplace=True),
        )
        self.fusion_residual = nn.Conv2d(128 * len(config.hidden_sizes), 256, kernel_size=1)
        # Final head over fused features + coarse logits.
        self.refinement = nn.Sequential(
            nn.Conv2d(256 + n_classes, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.2),
            nn.Conv2d(128, n_classes, kernel_size=1),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return refined class logits of shape (B, n_classes, H, W)."""
        input_size = x.size()[2:]
        # FIX: single encoder pass. The original ran self.segformer(...) and
        # then called the encoder AGAIN just to obtain hidden states,
        # doubling the backbone cost. output_hidden_states=True returns the
        # same four stage feature maps from the one forward call.
        outputs = self.segformer(pixel_values=x, output_hidden_states=True)
        logits = outputs.logits
        hidden_states = outputs.hidden_states
        # Project each stage to 128ch and upsample to the logits' grid.
        fpn_feats = []
        for feat, layer in zip(hidden_states, self.fpn):
            f = layer(feat)
            f = F.interpolate(f, size=logits.shape[2:], mode="bilinear", align_corners=False)
            fpn_feats.append(f)
        fused = torch.cat(fpn_feats, dim=1)
        residual = self.fusion_residual(fused)
        fused = self.fusion(fused) + residual
        # Bring both streams to the input resolution before refinement.
        logits = F.interpolate(logits, size=input_size, mode="bilinear", align_corners=False)
        fused = F.interpolate(fused, size=input_size, mode="bilinear", align_corners=False)
        concat = torch.cat([fused, logits], dim=1)
        return self.refinement(concat)