|
|
"""Encoder wrappers with adjustable input channels. |
|
|
|
|
|
Supports two backbone families: |
|
|
- HuggingFace Transformers SegFormer (e.g., "mit_b2") |
|
|
- TorchVision ResNet-50 (use backbone "resnet50" | "resnet-50" | "resnet_50") |
|
|
|
|
|
Both return a list of 4 multi-scale feature maps [C1, C2, C3, C4] at strides |
|
|
1/4, 1/8, 1/16, 1/32 respectively. |
|
|
""" |
|
|
|
|
|
from typing import List, Tuple |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision.models import resnet50, ResNet50_Weights |
|
|
|
|
|
|
|
|
class SegFormerEncoder(nn.Module):
    """Multi-scale feature encoder with a configurable number of input channels.

    Dispatches to one of two backbone families based on ``backbone``:

    - names starting with ``mit_`` or ``segformer`` -> HuggingFace SegFormer
    - ``resnet50`` / ``resnet-50`` / ``resnet_50``  -> TorchVision ResNet-50

    Either way, ``forward`` yields four feature maps ``[C1, C2, C3, C4]`` and
    ``self.feature_dims`` lists their channel widths.
    """

    def __init__(
        self,
        backbone: str = "mit_b2",
        in_channels: int = 6,
        pretrained: bool = True,
    ):
        super().__init__()
        self.backbone_name = backbone
        self.in_channels = in_channels
        self.pretrained = pretrained

        # Exactly one of these two is populated; the other stays None.
        self.hf = None
        self.resnet = None

        if backbone.startswith(("mit_", "segformer")):
            self.hf = _HFEncoderWrapper(in_channels, backbone, pretrained)
            self.feature_dims = self.hf.feature_dims
        elif backbone in ("resnet50", "resnet-50", "resnet_50"):
            self.resnet = _ResNetEncoderWrapper(in_channels, pretrained)
            self.feature_dims = self.resnet.feature_dims
        else:
            raise ValueError(
                f"Unsupported backbone '{backbone}'. Use one of: mit_b[0-5], segformer*, resnet50."
            )

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Return the active backbone's four multi-scale feature maps for ``x``."""
        encoder = self.hf if self.hf is not None else self.resnet
        if encoder is None:
            raise AssertionError("No encoder instantiated")
        return encoder(x)
|
|
|
|
|
|
|
|
class _ResNetEncoderWrapper(nn.Module):
    """TorchVision ResNet-50 exposed as a 4-stage feature extractor.

    ``forward`` returns ``[C1, C2, C3, C4]`` with channel widths
    ``[256, 512, 1024, 2048]`` (layer1..layer4 outputs). When ``in_chans``
    differs from RGB, the stem conv is rebuilt and, if pretrained weights
    are available, they are transferred rather than discarded:

    - first 3 input channels keep the original RGB kernels,
    - extra channels are filled with the RGB-mean kernel,
    - the result is rescaled by ``3 / in_chans`` so activation magnitudes
      stay comparable to the pretrained network.

    (The original code replaced *all* channels with the mean kernel when
    ``in_chans > 3``, throwing away the pretrained RGB weights.)
    """

    def __init__(self, in_chans: int, pretrained: bool):
        super().__init__()

        weights = ResNet50_Weights.DEFAULT if pretrained else None
        self.model = resnet50(weights=weights)

        if in_chans != 3:
            self.model.conv1 = self._adapt_stem_conv(
                self.model.conv1, in_chans, pretrained
            )

        # Channel widths of layer1..layer4 outputs for ResNet-50.
        self.feature_dims = [256, 512, 1024, 2048]

    @staticmethod
    def _adapt_stem_conv(
        old_conv: nn.Conv2d, in_chans: int, pretrained: bool
    ) -> nn.Conv2d:
        """Build a replacement stem conv accepting ``in_chans`` input channels."""
        new_conv = nn.Conv2d(
            in_chans,
            old_conv.out_channels,
            # Pass the full tuples instead of [0] so non-square geometry
            # (if ever present) is preserved exactly.
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            bias=False,
        )
        with torch.no_grad():
            if pretrained and old_conv.weight.shape[1] == 3:
                w = old_conv.weight
                if in_chans > 3:
                    # Keep RGB kernels, pad with the channel mean, rescale
                    # so the expected pre-activation magnitude is preserved.
                    new_w = torch.empty_like(new_conv.weight)
                    new_w[:, :3] = w
                    new_w[:, 3:] = w.mean(dim=1, keepdim=True)
                    new_w.mul_(3.0 / in_chans)
                else:
                    # Fewer channels than RGB: take a leading slice.
                    new_w = w[:, :in_chans, :, :]
                new_conv.weight.copy_(new_w)
            else:
                # No pretrained weights to transfer: standard He init.
                nn.init.kaiming_normal_(
                    new_conv.weight, mode="fan_out", nonlinearity="relu"
                )
        return new_conv

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Run stem + 4 residual stages; return per-stage feature maps."""
        m = self.model
        x = m.maxpool(m.relu(m.bn1(m.conv1(x))))

        c1 = m.layer1(x)   # stride 1/4
        c2 = m.layer2(c1)  # stride 1/8
        c3 = m.layer3(c2)  # stride 1/16
        c4 = m.layer4(c3)  # stride 1/32
        return [c1, c2, c3, c4]
|
|
|
|
|
|
|
|
class _HFEncoderWrapper(nn.Module):
    """HuggingFace SegFormer (MiT) backbone exposing the 4 stage outputs.

    Known short names (``mit_b0`` .. ``mit_b5``) are mapped to the official
    ``nvidia/mit-bX`` checkpoints; any other name is passed through verbatim
    as a HuggingFace model id. (The old ``name_map[backbone]`` lookup raised
    a bare ``KeyError`` for the ``segformer*`` names that ``SegFormerEncoder``
    routes here, even though it advertises accepting them.)
    """

    # Short aliases for the official NVIDIA MiT checkpoints.
    _NAME_MAP = {
        "mit_b0": "nvidia/mit-b0",
        "mit_b1": "nvidia/mit-b1",
        "mit_b2": "nvidia/mit-b2",
        "mit_b3": "nvidia/mit-b3",
        "mit_b4": "nvidia/mit-b4",
        "mit_b5": "nvidia/mit-b5",
    }

    def __init__(self, in_chans: int, backbone: str, pretrained: bool):
        super().__init__()

        # Imported lazily so `transformers` is only required for this family.
        from transformers import SegformerModel, SegformerConfig

        # Unknown names fall through as literal HF hub ids instead of KeyError.
        model_id = self._NAME_MAP.get(backbone, backbone)

        if pretrained:
            base_cfg = SegformerConfig.from_pretrained(model_id)
            base_cfg.num_channels = in_chans
            # ignore_mismatched_sizes lets the patch-embedding conv be
            # re-initialized for a non-RGB channel count while keeping
            # every other pretrained weight.
            self.model = SegformerModel.from_pretrained(
                model_id, config=base_cfg, ignore_mismatched_sizes=True
            )
        else:
            # NOTE(review): this builds a *default* (b0-sized) config
            # regardless of `backbone` — confirm that is intended for the
            # from-scratch path before relying on mit_b2+ without pretrained.
            cfg = SegformerConfig()
            cfg.num_channels = in_chans
            self.model = SegformerModel(cfg)

        # One entry per encoder stage; decoders use this to size their inputs.
        self.feature_dims = list(self.model.config.hidden_sizes)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Return the 4 per-stage hidden states [C1..C4] for ``x``."""
        outputs = self.model(
            pixel_values=x, output_hidden_states=True, return_dict=True
        )
        feats = list(outputs.hidden_states)
        # Explicit check: a bare `assert` would be stripped under `python -O`.
        if len(feats) != 4:
            raise RuntimeError(
                f"Expected 4 stage outputs from SegFormer, got {len(feats)}"
            )
        return feats
|
|
|