MRiabov's picture
(cleanup) removing timm from the codebase.
e3ab023
from typing import Tuple
import torch
import torch.nn as nn
from .encoder import SegFormerEncoder
from .decoder import CoarseDecoder, FineDecoder
from .condition import Conditioning1x1
class WireSegHR(nn.Module):
"""
Two-stage WireSegHR model wrapper with shared encoder.
Expects callers to prepare input channel stacks according to the plan:
- Coarse input: RGB + MinMax (and any extra channels per config), shape (B, Cc, Hc, Wc)
- Fine input: RGB + MinMax + cond_crop, shape (B, Cf, p, p)
Conditioning 1x1 is applied to coarse logits to produce a single-channel map.
"""
def __init__(
self, backbone: str = "mit_b2", in_channels: int = 6, pretrained: bool = True
):
super().__init__()
self.encoder = SegFormerEncoder(
backbone=backbone, in_channels=in_channels, pretrained=pretrained
)
# Use encoder-exposed feature dims for decoder projections
in_chs = tuple(self.encoder.feature_dims)
self.coarse_head = CoarseDecoder(in_chs=in_chs, embed_dim=128, num_classes=2)
self.fine_head = FineDecoder(in_chs=in_chs, embed_dim=128, num_classes=2)
self.cond1x1 = Conditioning1x1()
def forward_coarse(
self, x_coarse: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
assert x_coarse.dim() == 4
feats = self.encoder(x_coarse)
logits_coarse = self.coarse_head(feats)
cond_map = self.cond1x1(logits_coarse)
return logits_coarse, cond_map
def forward_fine(self, x_fine: torch.Tensor) -> torch.Tensor:
assert x_fine.dim() == 4
feats = self.encoder(x_fine)
logits_fine = self.fine_head(feats)
return logits_fine