File size: 4,842 Bytes
a2999cc 8ea2eff a2999cc 8ea2eff a2999cc 8ea2eff d46d294 a527623 8ea2eff 8e73ec9 a2999cc 8e73ec9 a2999cc 8e73ec9 e3ab023 8e73ec9 a2999cc 8e73ec9 a2999cc 8e73ec9 a2999cc 8ea2eff a2999cc 8e73ec9 a527623 8e73ec9 a527623 8e73ec9 6928e82 8e73ec9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
"""Encoder wrappers with adjustable input channels.
Supports two backbone families:
- HuggingFace Transformers SegFormer (e.g., "mit_b2")
- TorchVision ResNet-50 (use backbone "resnet50" | "resnet-50" | "resnet_50")
Both return a list of 4 multi-scale feature maps [C1, C2, C3, C4] at strides
1/4, 1/8, 1/16, 1/32 respectively.
"""
from typing import List, Tuple
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
class SegFormerEncoder(nn.Module):
    """Channel-flexible multi-scale encoder.

    Dispatches on ``backbone`` to either a HuggingFace SegFormer/MiT backbone
    (names starting with "mit_" or "segformer") or a TorchVision ResNet-50
    backbone ("resnet50" / "resnet-50" / "resnet_50").  In both cases
    ``forward`` yields four pyramid feature maps at strides 1/4 .. 1/32.
    """

    def __init__(
        self,
        backbone: str = "mit_b2",
        in_channels: int = 6,
        pretrained: bool = True,
    ):
        super().__init__()
        self.backbone_name = backbone
        self.in_channels = in_channels
        self.pretrained = pretrained
        # Exactly one of these two is populated, depending on `backbone`.
        self.hf = None
        self.resnet = None
        if backbone.startswith(("mit_", "segformer")):
            # HuggingFace SegFormer path.
            self.hf = _HFEncoderWrapper(in_channels, backbone, pretrained)
            self.feature_dims = self.hf.feature_dims
        elif backbone in ("resnet50", "resnet-50", "resnet_50"):
            # TorchVision ResNet-50 path.
            self.resnet = _ResNetEncoderWrapper(in_channels, pretrained)
            self.feature_dims = self.resnet.feature_dims
        else:
            raise ValueError(
                f"Unsupported backbone '{backbone}'. Use one of: mit_b[0-5], segformer*, resnet50."
            )

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Return the four pyramid features [C1, C2, C3, C4] for `x`."""
        encoder = self.hf if self.hf is not None else self.resnet
        if encoder is None:
            raise AssertionError("No encoder instantiated")
        return encoder(x)
class _ResNetEncoderWrapper(nn.Module):
    """ResNet-50 backbone exposing the four stage outputs as a feature pyramid.

    Args:
        in_chans: number of input channels; the stem conv is rebuilt when != 3.
        pretrained: load torchvision ImageNet weights when True.
    """

    def __init__(self, in_chans: int, pretrained: bool):
        super().__init__()
        # Build base ResNet-50.
        weights = ResNet50_Weights.DEFAULT if pretrained else None
        self.model = resnet50(weights=weights)
        # Adjust input stem for arbitrary channel count.
        if in_chans != 3:
            old_conv = self.model.conv1
            # Keep the original spatial hyperparameters (7x7, stride 2, pad 3)
            # by passing the tuples through unchanged.
            new_conv = nn.Conv2d(
                in_chans,
                old_conv.out_channels,
                kernel_size=old_conv.kernel_size,
                stride=old_conv.stride,
                padding=old_conv.padding,
                bias=False,
            )
            with torch.no_grad():
                if pretrained and old_conv.weight.shape[1] == 3:
                    w = old_conv.weight  # [64, 3, 7, 7]
                    if in_chans > 3:
                        # Average the RGB filters and tile across the new
                        # channels, then rescale by 3/in_chans so the summed
                        # stem response keeps the same magnitude as the
                        # pretrained 3-channel conv (timm's adapt_input_conv
                        # trick).  Without the rescale, activations are
                        # inflated by in_chans/3.
                        new_w = w.mean(dim=1, keepdim=True).repeat(1, in_chans, 1, 1)
                        new_w = new_w * (3.0 / float(in_chans))
                    else:
                        # Fewer channels than RGB: keep the first in_chans filters.
                        new_w = w[:, :in_chans, :, :]
                    new_conv.weight.copy_(new_w)
                else:
                    nn.init.kaiming_normal_(
                        new_conv.weight, mode="fan_out", nonlinearity="relu"
                    )
            self.model.conv1 = new_conv
        # Channel widths of layer1..layer4 outputs.
        self.feature_dims = [256, 512, 1024, 2048]

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Return [C1, C2, C3, C4] at strides 1/4, 1/8, 1/16, 1/32."""
        m = self.model
        # Stem: conv -> bn -> relu -> maxpool brings the input to 1/4 scale.
        x = m.maxpool(m.relu(m.bn1(m.conv1(x))))
        c1 = m.layer1(x)   # 1/4,  256 ch
        c2 = m.layer2(c1)  # 1/8,  512 ch
        c3 = m.layer3(c2)  # 1/16, 1024 ch
        c4 = m.layer4(c3)  # 1/32, 2048 ch
        return [c1, c2, c3, c4]
class _HFEncoderWrapper(nn.Module):
    """HuggingFace SegFormer (MiT) backbone returning per-stage feature maps.

    Args:
        in_chans: number of input channels (patched into the model config).
        backbone: "mit_b0".."mit_b5", or any HF hub model id as a fallback.
        pretrained: load hub weights when True; otherwise build from config.
    """

    # Encoder hyper-parameters per MiT variant as (hidden_sizes, depths),
    # matching the nvidia/mit-b* configs on the HF hub (SegFormer paper, Table 1).
    _VARIANT_PARAMS = {
        "mit_b0": ([32, 64, 160, 256], [2, 2, 2, 2]),
        "mit_b1": ([64, 128, 320, 512], [2, 2, 2, 2]),
        "mit_b2": ([64, 128, 320, 512], [3, 4, 6, 3]),
        "mit_b3": ([64, 128, 320, 512], [3, 4, 18, 3]),
        "mit_b4": ([64, 128, 320, 512], [3, 8, 27, 3]),
        "mit_b5": ([64, 128, 320, 512], [3, 6, 40, 3]),
    }

    def __init__(self, in_chans: int, backbone: str, pretrained: bool):
        super().__init__()
        # Lazy import to avoid hard dependency during tests if not used.
        from transformers import SegformerModel, SegformerConfig

        name_map = {
            "mit_b0": "nvidia/mit-b0",
            "mit_b1": "nvidia/mit-b1",
            "mit_b2": "nvidia/mit-b2",
            "mit_b3": "nvidia/mit-b3",
            "mit_b4": "nvidia/mit-b4",
            "mit_b5": "nvidia/mit-b5",
        }
        # BUG FIX: `name_map[backbone]` raised KeyError for the "segformer*"
        # names that SegFormerEncoder routes here.  Unknown names are now
        # passed through verbatim as HF hub ids
        # (e.g. "nvidia/segformer-b2-finetuned-ade-512-512").
        model_id = name_map.get(backbone, backbone)
        if pretrained:
            base_cfg = SegformerConfig.from_pretrained(model_id)
            base_cfg.num_channels = in_chans
            # ignore_mismatched_sizes lets the patch-embedding conv be
            # re-initialised for the new channel count while all remaining
            # weights still load from the checkpoint.
            self.model = SegformerModel.from_pretrained(
                model_id, config=base_cfg, ignore_mismatched_sizes=True
            )
        else:
            # BUG FIX: SegformerConfig() defaults to a B0-sized encoder, so a
            # non-pretrained "mit_b5" silently built a B0 model.  Honour the
            # requested variant when its hyper-parameters are known.
            cfg = SegformerConfig()
            if backbone in self._VARIANT_PARAMS:
                cfg.hidden_sizes, cfg.depths = self._VARIANT_PARAMS[backbone]
            cfg.num_channels = in_chans
            self.model = SegformerModel(cfg)
        # Expose channel dims per stage.
        self.feature_dims = list(self.model.config.hidden_sizes)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Return the four per-stage hidden states as a feature pyramid."""
        outputs = self.model(
            pixel_values=x, output_hidden_states=True, return_dict=True
        )
        feats = list(outputs.hidden_states)
        # Raise instead of assert so the check survives `python -O`.
        if len(feats) != 4:
            raise AssertionError(f"Expected 4 feature maps, got {len(feats)}")
        return feats
|