# MRiabov's picture
# (format) format
# 6928e82
"""SegFormer-like multi-scale decoder heads for coarse and fine branches.
Fuse four feature maps from MiT encoder via 1x1 projections, upsample to the
highest spatial resolution (stage 0), concatenate, and predict 2-class logits.
"""
from typing import List, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
class _ConvBNReLU(nn.Module):
    """2D convolution followed by batch normalization and an in-place ReLU.

    Args:
        in_ch: Number of channels of the input tensor.
        out_ch: Number of channels produced by the convolution.
        k: Convolution kernel size.
        s: Convolution stride.
        p: Convolution padding.
    """

    def __init__(self, in_ch: int, out_ch: int, k: int, s: int = 1, p: int = 0):
        super().__init__()
        # The convolution carries no bias term; the BatchNorm2d that follows
        # is created with its default settings.
        conv_options = dict(kernel_size=k, stride=s, padding=p, bias=False)
        # NOTE: the attribute names below are also the state_dict key
        # prefixes ("conv.*", "bn.*"), so they must not be renamed.
        self.conv = nn.Conv2d(in_ch, out_ch, **conv_options)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(bn(conv(x)))``."""
        return self.relu(self.bn(self.conv(x)))
class _SegFormerHead(nn.Module):
    """Fuse four feature maps into per-pixel class logits.

    Each input feature map is projected to ``embed_dim`` channels with its own
    1x1 convolution, bilinearly resized to the spatial size of the first
    feature map when the sizes differ, concatenated along the channel axis,
    fused by a 3x3 conv + BN + ReLU block, and mapped to ``num_classes``
    channels by a final 1x1 convolution.

    Args:
        in_chs: Channel counts of the four input feature maps, in the same
            order the maps are passed to ``forward``. Must have length 4.
        embed_dim: Channel width of every projection and of the fused map.
        num_classes: Number of output channels (logits per pixel).
    """

    def __init__(self, in_chs: List[int], embed_dim: int = 128, num_classes: int = 2):
        super().__init__()
        assert len(in_chs) == 4
        # One 1x1 projection per input feature map, created in input order.
        self.proj = nn.ModuleList()
        for channels in in_chs:
            self.proj.append(nn.Conv2d(channels, embed_dim, kernel_size=1))
        # The fused tensor is the concatenation of the four projections.
        self.fuse = _ConvBNReLU(in_ch=embed_dim * 4, out_ch=embed_dim, k=3, p=1)
        self.cls = nn.Conv2d(embed_dim, num_classes, kernel_size=1)

    def forward(self, feats: List[torch.Tensor]) -> torch.Tensor:
        """Return logits at the spatial size of ``feats[0]``.

        Args:
            feats: Exactly four feature maps whose channel counts match the
                ``in_chs`` given at construction (assumed to be laid out as
                (N, C, H, W) -- dims 2 and 3 are read as height and width).

        Returns:
            A tensor with ``num_classes`` channels and the height/width of
            ``feats[0]``.
        """
        assert len(feats) == 4
        reference = feats[0]
        target_h, target_w = reference.shape[2], reference.shape[3]

        aligned = []
        for feat, projection in zip(feats, self.proj):
            projected = projection(feat)
            size_matches = (
                projected.shape[2] == target_h and projected.shape[3] == target_w
            )
            if not size_matches:
                # Bring lower-resolution maps up to the reference resolution.
                projected = F.interpolate(
                    projected,
                    size=(target_h, target_w),
                    mode="bilinear",
                    align_corners=False,
                )
            aligned.append(projected)

        fused = self.fuse(torch.cat(aligned, dim=1))
        return self.cls(fused)
class CoarseDecoder(_SegFormerHead):
    """Decoder head for the coarse branch.

    A thin specialization of ``_SegFormerHead``: it adds no layers or logic of
    its own and only supplies default constructor arguments.

    Args:
        in_chs: Channel counts of the four input feature maps, in the order
            the maps are passed to ``forward``. Any sequence of ints is
            accepted; it is copied into a list before reaching the base class.
        embed_dim: Forwarded unchanged to ``_SegFormerHead``.
        num_classes: Forwarded unchanged to ``_SegFormerHead``.
    """

    def __init__(
        self,
        # Annotated as Sequence (not List) so the immutable tuple default is
        # type-correct; a tuple is also safe as a shared default argument.
        in_chs: Sequence[int] = (64, 128, 320, 512),
        embed_dim: int = 128,
        num_classes: int = 2,
    ) -> None:
        super().__init__(list(in_chs), embed_dim, num_classes)
class FineDecoder(_SegFormerHead):
    """Decoder head for the fine branch.

    A thin specialization of ``_SegFormerHead``: it adds no layers or logic of
    its own and only supplies default constructor arguments.

    Args:
        in_chs: Channel counts of the four input feature maps, in the order
            the maps are passed to ``forward``. Any sequence of ints is
            accepted; it is copied into a list before reaching the base class.
        embed_dim: Forwarded unchanged to ``_SegFormerHead``.
        num_classes: Forwarded unchanged to ``_SegFormerHead``.
    """

    def __init__(
        self,
        # Annotated as Sequence (not List) so the immutable tuple default is
        # type-correct; a tuple is also safe as a shared default argument.
        in_chs: Sequence[int] = (64, 128, 320, 512),
        embed_dim: int = 128,
        num_classes: int = 2,
    ) -> None:
        super().__init__(list(in_chs), embed_dim, num_classes)