"""SegFormer-like multi-scale decoder heads for coarse and fine branches.

Each encoder feature map (e.g. the four stages of a MiT encoder) is projected
to a shared embedding width with a 1x1 convolution, resized to the spatial
size of the first (stage 0) feature map, concatenated along the channel axis,
fused by a 3x3 Conv-BN-ReLU block and finally mapped to per-class logits
(2 classes by default) with a 1x1 convolution.
"""

from typing import List, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F


class _ConvBNReLU(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` and an in-place ``ReLU``.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        k: Convolution kernel size.
        s: Convolution stride.
        p: Convolution padding.
    """

    def __init__(self, in_ch: int, out_ch: int, k: int, s: int = 1, p: int = 0):
        super().__init__()
        # The convolution carries no bias; BatchNorm2d follows it directly.
        self.conv = nn.Conv2d(
            in_ch, out_ch, kernel_size=k, stride=s, padding=p, bias=False
        )
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply convolution, batch normalisation and ReLU in that order."""
        return self.relu(self.bn(self.conv(x)))


class _SegFormerHead(nn.Module):
    """Project, upsample, concatenate and classify multi-scale features.

    Args:
        in_chs: Channel count of every incoming feature map, in the same
            order the maps are passed to :meth:`forward`. One 1x1 projection
            is created per entry.
        embed_dim: Width every feature map is projected to, and the width of
            the fused representation.
        num_classes: Number of output logit channels.

    Raises:
        ValueError: If ``in_chs`` is empty.
    """

    def __init__(
        self, in_chs: Sequence[int], embed_dim: int = 128, num_classes: int = 2
    ):
        super().__init__()
        in_chs = list(in_chs)
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not in_chs:
            raise ValueError("in_chs must contain at least one channel count")
        self.proj = nn.ModuleList(
            [nn.Conv2d(c, embed_dim, kernel_size=1) for c in in_chs]
        )
        # The fuse block consumes the channel-wise concatenation of all
        # projected maps, hence ``embed_dim`` channels per input stage.
        self.fuse = _ConvBNReLU(embed_dim * len(in_chs), embed_dim, k=3, p=1)
        self.cls = nn.Conv2d(embed_dim, num_classes, kernel_size=1)

    def forward(self, feats: Sequence[torch.Tensor]) -> torch.Tensor:
        """Fuse ``feats`` and return logits at the resolution of ``feats[0]``.

        Args:
            feats: One feature map per configured input stage, ordered like
                ``in_chs``. ``feats[0]`` defines the output height and width
                (dimensions 2 and 3 of its shape).

        Returns:
            Logits with ``num_classes`` channels and the spatial size of
            ``feats[0]``.

        Raises:
            ValueError: If the number of feature maps differs from the number
                of projection layers.
        """
        # Without this check ``zip`` below would silently drop surplus maps.
        if len(feats) != len(self.proj):
            raise ValueError(
                f"expected {len(self.proj)} feature maps, got {len(feats)}"
            )
        h, w = feats[0].shape[2], feats[0].shape[3]
        xs: List[torch.Tensor] = []
        for feat, proj in zip(feats, self.proj):
            x = proj(feat)
            # Only resample maps whose spatial size differs from the target.
            if x.shape[2] != h or x.shape[3] != w:
                x = F.interpolate(
                    x, size=(h, w), mode="bilinear", align_corners=False
                )
            xs.append(x)
        fused = self.fuse(torch.cat(xs, dim=1))
        return self.cls(fused)


class CoarseDecoder(_SegFormerHead):
    """Decoder head for the coarse branch.

    Adds no behaviour of its own: it only supplies default arguments to
    :class:`_SegFormerHead`.

    Args:
        in_chs: Channel count of every incoming feature map.
        embed_dim: Shared projection / fusion width.
        num_classes: Number of output logit channels.
    """

    def __init__(
        self,
        in_chs: Sequence[int] = (64, 128, 320, 512),
        embed_dim: int = 128,
        num_classes: int = 2,
    ):
        super().__init__(list(in_chs), embed_dim, num_classes)


class FineDecoder(_SegFormerHead):
    """Decoder head for the fine branch.

    Adds no behaviour of its own: it only supplies default arguments to
    :class:`_SegFormerHead`.

    Args:
        in_chs: Channel count of every incoming feature map.
        embed_dim: Shared projection / fusion width.
        num_classes: Number of output logit channels.
    """

    def __init__(
        self,
        in_chs: Sequence[int] = (64, 128, 320, 512),
        embed_dim: int = 128,
        num_classes: int = 2,
    ):
        super().__init__(list(in_chs), embed_dim, num_classes)