"""Encoder wrappers with adjustable input channels.
Supports two backbone families:
- HuggingFace Transformers SegFormer (e.g., "mit_b2")
- TorchVision ResNet-50 (use backbone "resnet50" | "resnet-50" | "resnet_50")
Both return a list of 4 multi-scale feature maps [C1, C2, C3, C4] at strides
1/4, 1/8, 1/16, 1/32 respectively.
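
Minimal usage sketch (illustrative only; the 6-channel input is just an
example, not a requirement):

    enc = SegFormerEncoder(backbone="resnet50", in_channels=6, pretrained=False)
    feats = enc(torch.randn(1, 6, 224, 224))
    # len(feats) == 4; per-stage channel counts are exposed via enc.feature_dims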
"""

from typing import List

import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights


class SegFormerEncoder(nn.Module):
def __init__(
self,
backbone: str = "mit_b2",
in_channels: int = 6,
pretrained: bool = True,
):
super().__init__()
self.backbone_name = backbone
self.in_channels = in_channels
self.pretrained = pretrained
self.hf = None
self.resnet = None
# SegFormer path
if backbone.startswith("mit_") or backbone.startswith("segformer"):
self.hf = _HFEncoderWrapper(in_channels, backbone, pretrained)
self.feature_dims = self.hf.feature_dims
# ResNet-50 path
elif backbone in ("resnet50", "resnet-50", "resnet_50"):
self.resnet = _ResNetEncoderWrapper(in_channels, pretrained)
self.feature_dims = self.resnet.feature_dims
else:
raise ValueError(
f"Unsupported backbone '{backbone}'. Use one of: mit_b[0-5], segformer*, resnet50."
)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
if self.hf is not None:
return self.hf(x)
if self.resnet is not None:
return self.resnet(x)
        raise RuntimeError("No encoder was instantiated")


class _ResNetEncoderWrapper(nn.Module):
def __init__(self, in_chans: int, pretrained: bool):
super().__init__()
# Build base ResNet-50
if pretrained:
self.model = resnet50(weights=ResNet50_Weights.DEFAULT)
else:
self.model = resnet50(weights=None)
        # Adjust the input stem for an arbitrary channel count
        if in_chans != 3:
            old_conv = self.model.conv1
            new_conv = nn.Conv2d(
                in_chans,
                old_conv.out_channels,
                kernel_size=old_conv.kernel_size,
                stride=old_conv.stride,
                padding=old_conv.padding,
                bias=False,
            )
            with torch.no_grad():
                if pretrained and old_conv.weight.shape[1] == 3:
                    w = old_conv.weight  # [64, 3, 7, 7]
                    if in_chans > 3:
                        # Replicate the channel-mean filter. Note this scales the
                        # stem's response by roughly in_chans/3; some implementations
                        # rescale by 3/in_chans to preserve activation magnitude.
                        w_mean = w.mean(dim=1, keepdim=True)
                        new_w = w_mean.repeat(1, in_chans, 1, 1)
                    else:
                        new_w = w[:, :in_chans, :, :]
                    new_conv.weight.copy_(new_w)
                else:
                    nn.init.kaiming_normal_(
                        new_conv.weight, mode="fan_out", nonlinearity="relu"
                    )
            self.model.conv1 = new_conv

self.feature_dims = [256, 512, 1024, 2048]

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
# Stem
x = self.model.conv1(x)
x = self.model.bn1(x)
x = self.model.relu(x)
x = self.model.maxpool(x) # 1/4
# Stages
c1 = self.model.layer1(x) # 1/4, 256
c2 = self.model.layer2(c1) # 1/8, 512
c3 = self.model.layer3(c2) # 1/16, 1024
c4 = self.model.layer4(c3) # 1/32, 2048
return [c1, c2, c3, c4]


class _HFEncoderWrapper(nn.Module):
def __init__(self, in_chans: int, backbone: str, pretrained: bool):
super().__init__()
# Lazy import to avoid hard dependency during tests if not used
from transformers import SegformerModel, SegformerConfig
name_map = {
"mit_b0": "nvidia/mit-b0",
"mit_b1": "nvidia/mit-b1",
"mit_b2": "nvidia/mit-b2",
"mit_b3": "nvidia/mit-b3",
"mit_b4": "nvidia/mit-b4",
"mit_b5": "nvidia/mit-b5",
}
        # Known short names map to NVIDIA hub ids; any other accepted name
        # (e.g. a full "nvidia/segformer-*" hub id) is passed through
        # unchanged rather than raising a KeyError.
        model_id = name_map.get(backbone, backbone)
if pretrained:
base_cfg = SegformerConfig.from_pretrained(model_id)
base_cfg.num_channels = in_chans
self.model = SegformerModel.from_pretrained(
model_id, config=base_cfg, ignore_mismatched_sizes=True
)
        else:
            # Randomly initialized, but keep the named backbone's architecture;
            # a bare SegformerConfig() would silently be B0-sized even for
            # "mit_b2" and friends.
            cfg = SegformerConfig.from_pretrained(model_id)
            cfg.num_channels = in_chans
            self.model = SegformerModel(cfg)
# Expose channel dims per stage
self.feature_dims = list(self.model.config.hidden_sizes)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
outputs = self.model(
pixel_values=x, output_hidden_states=True, return_dict=True
)
feats = list(outputs.hidden_states)
        assert len(feats) == 4, f"Expected 4 stage features, got {len(feats)}"
return feats
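

if __name__ == "__main__":
    # Quick smoke test (a sketch, not part of the module's API). It only
    # exercises the ResNet path so it runs without `transformers` installed
    # and without downloading weights; the 224x224 input and batch size of 2
    # are arbitrary choices.
    enc = SegFormerEncoder(backbone="resnet50", in_channels=6, pretrained=False)
    with torch.no_grad():
        feats = enc(torch.randn(2, 6, 224, 224))
    for dim, f in zip(enc.feature_dims, feats):
        assert f.shape[1] == dim, (f.shape, dim)
    print([tuple(f.shape) for f in feats])
    # Expected: [(2, 256, 56, 56), (2, 512, 28, 28),
    #            (2, 1024, 14, 14), (2, 2048, 7, 7)]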