"""Encoder wrappers with adjustable input channels.

Supports two backbone families:
- HuggingFace Transformers SegFormer (e.g., "mit_b2")
- TorchVision ResNet-50 (use backbone "resnet50" | "resnet-50" | "resnet_50")

Both return a list of 4 multi-scale feature maps [C1, C2, C3, C4] at strides
1/4, 1/8, 1/16, 1/32 respectively.
"""

from typing import List

import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights


class SegFormerEncoder(nn.Module):
    def __init__(
        self,
        backbone: str = "mit_b2",
        in_channels: int = 6,
        pretrained: bool = True,
    ):
        super().__init__()
        self.backbone_name = backbone
        self.in_channels = in_channels
        self.pretrained = pretrained

        self.hf = None
        self.resnet = None

        # SegFormer path
        if backbone.startswith("mit_") or backbone.startswith("segformer"):
            self.hf = _HFEncoderWrapper(in_channels, backbone, pretrained)
            self.feature_dims = self.hf.feature_dims
        # ResNet-50 path
        elif backbone in ("resnet50", "resnet-50", "resnet_50"):
            self.resnet = _ResNetEncoderWrapper(in_channels, pretrained)
            self.feature_dims = self.resnet.feature_dims
        else:
            raise ValueError(
                f"Unsupported backbone '{backbone}'. Use one of: mit_b[0-5], segformer*, resnet50."
            )

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        if self.hf is not None:
            return self.hf(x)
        if self.resnet is not None:
            return self.resnet(x)
        raise AssertionError("No encoder instantiated")

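# Usage sketch (illustrative only): both paths yield four feature maps at
# strides 1/4, 1/8, 1/16, 1/32. For a 6-channel 512x512 input with mit_b2
# (feature_dims [64, 128, 320, 512]):
#
#   encoder = SegFormerEncoder(backbone="mit_b2", in_channels=6, pretrained=True)
#   c1, c2, c3, c4 = encoder(torch.randn(1, 6, 512, 512))
#   # c1: [1, 64, 128, 128] ... c4: [1, 512, 16, 16]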

class _ResNetEncoderWrapper(nn.Module):
    def __init__(self, in_chans: int, pretrained: bool):
        super().__init__()
        # Build base ResNet-50
        if pretrained:
            self.model = resnet50(weights=ResNet50_Weights.DEFAULT)
        else:
            self.model = resnet50(weights=None)

        # Adjust input stem for arbitrary channel count
        if in_chans != 3:
            old_conv = self.model.conv1
            new_conv = nn.Conv2d(
                in_chans,
                old_conv.out_channels,
                kernel_size=old_conv.kernel_size,
                stride=old_conv.stride,
                padding=old_conv.padding,
                bias=False,
            )
            with torch.no_grad():
                if pretrained and old_conv.weight.shape[1] == 3:
                    # Inflate the pretrained RGB stem: extra channels reuse
                    # the mean RGB filter, so the pretrained low-level
                    # features carry over to all input channels.
                    w = old_conv.weight  # [64, 3, 7, 7]
                    if in_chans > 3:
                        w_mean = w.mean(dim=1, keepdim=True)
                        new_w = w_mean.repeat(1, in_chans, 1, 1)
                    else:
                        # Fewer than 3 inputs: keep the first in_chans slices.
                        new_w = w[:, :in_chans, :, :]
                    new_conv.weight.copy_(new_w)
                else:
                    nn.init.kaiming_normal_(new_conv.weight, mode="fan_out", nonlinearity="relu")
            self.model.conv1 = new_conv

        self.feature_dims = [256, 512, 1024, 2048]

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        # Stem
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)  # 1/4

        # Stages
        c1 = self.model.layer1(x)  # 1/4, 256
        c2 = self.model.layer2(c1)  # 1/8, 512
        c3 = self.model.layer3(c2)  # 1/16, 1024
        c4 = self.model.layer4(c3)  # 1/32, 2048
        return [c1, c2, c3, c4]


class _HFEncoderWrapper(nn.Module):
    def __init__(self, in_chans: int, backbone: str, pretrained: bool):
        super().__init__()
        # Lazy import to avoid hard dependency during tests if not used
        from transformers import SegformerModel, SegformerConfig

        name_map = {
            "mit_b0": "nvidia/mit-b0",
            "mit_b1": "nvidia/mit-b1",
            "mit_b2": "nvidia/mit-b2",
            "mit_b3": "nvidia/mit-b3",
            "mit_b4": "nvidia/mit-b4",
            "mit_b5": "nvidia/mit-b5",
        }
        # "segformer_bX" is accepted as an alias for "mit_bX" to match the
        # "segformer*" names the dispatcher allows; anything else raises a
        # clear error here instead of a bare KeyError.
        name_map.update({f"segformer_b{i}": f"nvidia/mit-b{i}" for i in range(6)})
        if backbone not in name_map:
            raise ValueError(
                f"Unknown SegFormer backbone '{backbone}'; expected one of {sorted(name_map)}"
            )
        model_id = name_map[backbone]

        if pretrained:
            base_cfg = SegformerConfig.from_pretrained(model_id)
            base_cfg.num_channels = in_chans
            self.model = SegformerModel.from_pretrained(
                model_id, config=base_cfg, ignore_mismatched_sizes=True
            )
        else:
            # Random init, but still load the architecture config for the
            # requested backbone (a tiny, cached download) so depths and
            # hidden_sizes match e.g. mit_b2 rather than the B0-like defaults.
            cfg = SegformerConfig.from_pretrained(model_id)
            cfg.num_channels = in_chans
            self.model = SegformerModel(cfg)

        # Expose channel dims per stage
        self.feature_dims = list(self.model.config.hidden_sizes)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        outputs = self.model(
            pixel_values=x, output_hidden_states=True, return_dict=True
        )
        # SegformerModel returns one reshaped feature map per encoder stage.
        feats = list(outputs.hidden_states)
        assert len(feats) == 4, f"expected 4 stage outputs, got {len(feats)}"
        return feats
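

if __name__ == "__main__":
    # Minimal smoke test (a sketch): exercises the ResNet-50 path with
    # pretrained=False so no weights are downloaded; shapes assume a
    # 224x224 input.
    enc = SegFormerEncoder(backbone="resnet50", in_channels=6, pretrained=False)
    feats = enc(torch.randn(2, 6, 224, 224))
    for i, f in enumerate(feats, start=1):
        # Expected: C1 (2, 256, 56, 56) ... C4 (2, 2048, 7, 7)
        print(f"C{i}: {tuple(f.shape)}")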