"""Swin-Unet wrapper.

Instantiates ``SwinTransformerSys`` from ``sota/Swin-Unet`` directly, bypassing
the upstream repo's yacs config and its .npz/.h5 Synapse data pipeline.

Input channels:
  * ``in_channels`` of 1 or 3 builds a 3-channel (RGB) network so an ImageNet
    Swin-T checkpoint can initialise the patch embedding; 1-channel input is
    repeated to 3 channels in ``forward``.
  * any other ``in_channels`` builds the patch embedding for that many
    channels; a pretrained RGB patch embedding is then skipped on load because
    its shape does not match.

Constraint: the windowed attention requires ``img_size=224`` (the stage
resolutions 56/28/14/7 are each divisible by ``window_size=7``). Other sizes
are rejected with ``ValueError``.
"""

import os
import sys
import warnings

import torch
import torch.nn as nn

_REPO = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "..", "sota", "Swin-Unet")
)

# Swin-T encoder depths; its length is also the number of encoder stages that
# get mirrored into the symmetric decoder when loading ImageNet weights.
_DEPTHS = (2, 2, 2, 2)

# Key prefixes a full-model checkpoint may carry in front of the bare
# SwinTransformerSys keys: "module." (DataParallel/DDP), "net." (this wrapper's
# own state_dict()), "swin_unet." (presumably upstream's wrapper attribute --
# TODO confirm against sota/Swin-Unet/networks). Each ends at its first ".".
_STRIP_PREFIXES = ("module.", "net.", "swin_unet.")


def _ensure_path() -> None:
    """Put the vendored Swin-Unet repo on ``sys.path`` (no-op if already there)."""
    if _REPO not in sys.path:
        sys.path.insert(0, _REPO)


def _strip_prefixes(state_dict):
    """Return a new dict whose keys have any leading ``_STRIP_PREFIXES`` removed."""
    stripped = {}
    for key, value in state_dict.items():
        # Prefixes can stack (e.g. "module.net."), so strip repeatedly.
        while key.startswith(_STRIP_PREFIXES):
            key = key.partition(".")[2]
        stripped[key] = value
    return stripped


def _mirror_encoder_to_decoder(state_dict, num_stages=len(_DEPTHS)):
    """Copy encoder ``layers.X.*`` tensors to decoder ``layers_up.(num_stages-1-X).*``.

    This is the scheme the Swin-Unet paper uses to initialise the symmetric
    decoder from an encoder-only checkpoint. Returns a new (shallow) dict; keys
    already present in *state_dict* are never overwritten, so a checkpoint that
    ships real decoder weights keeps them.
    """
    full = dict(state_dict)
    for key, value in state_dict.items():
        if not key.startswith("layers."):
            continue
        stage, _, rest = key[len("layers."):].partition(".")
        if not stage.isdigit() or int(stage) >= num_stages:
            continue
        full.setdefault(f"layers_up.{num_stages - 1 - int(stage)}.{rest}", value)
    return full


class SwinUnetWrapper(nn.Module):
    """Segmentation wrapper around upstream ``SwinTransformerSys`` (Swin-T config).

    Args:
        in_channels: channels of the input tensor. 1 and 3 share an RGB network
            (1-channel input is repeated to 3); other values build the patch
            embedding for that many channels.
        num_classes: number of output classes passed to ``SwinTransformerSys``.
        img_size: input spatial size; must be 224.
        pretrained_ckpt: optional checkpoint path. Empty skips loading; a
            non-empty path that is not a file emits a ``UserWarning``.

    Raises:
        ValueError: if ``img_size != 224`` or ``in_channels < 1``.
    """

    def __init__(self, in_channels: int, num_classes: int, img_size: int = 224,
                 pretrained_ckpt: str = ""):
        super().__init__()
        # Validate before touching sys.path / importing the vendored repo so bad
        # arguments fail fast with a clear error.
        if img_size != 224:
            raise ValueError("Swin-Unet backbone requires img_size=224.")
        if in_channels < 1:
            raise ValueError(f"in_channels must be >= 1, got {in_channels}.")

        _ensure_path()
        from networks.swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys

        self.in_channels = in_channels
        # 1- and 3-channel inputs share an RGB network so ImageNet weights fit.
        self._net_in_chans = 3 if in_channels in (1, 3) else in_channels
        self.net = SwinTransformerSys(
            img_size=img_size, patch_size=4, in_chans=self._net_in_chans,
            num_classes=num_classes, embed_dim=96, depths=list(_DEPTHS),
            num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4.0,
            qkv_bias=True, drop_path_rate=0.1, ape=False, patch_norm=True,
            use_checkpoint=False,
        )

        if pretrained_ckpt:
            if os.path.isfile(pretrained_ckpt):
                self._load_pretrained(pretrained_ckpt)
            else:
                warnings.warn(
                    f"[swinunet] pretrained checkpoint not found: {pretrained_ckpt!r}; "
                    "keeping default initialisation.",
                    stacklevel=2,
                )

    def _load_pretrained(self, path):
        """Load a checkpoint into ``self.net`` (port of upstream ``load_from``).

        Accepted layouts:

        * ``{"model": state_dict}`` -- e.g. an ImageNet Swin-T encoder
          checkpoint. Encoder ``layers.X.*`` tensors are also mirrored into the
          decoder as ``layers_up.(3-X).*`` unless those decoder keys exist.
        * anything else -- treated as a full-model state dict (optionally
          nested under ``"state_dict"``).

        In both cases known wrapper prefixes are stripped from keys, and tensors
        whose shape disagrees with the model (e.g. an output head trained for a
        different ``num_classes``) are skipped rather than making
        ``load_state_dict`` raise. Missing/unexpected keys are tolerated
        (``strict=False``) and their counts are printed.
        """
        # NOTE(security): weights_only=False unpickles arbitrary Python objects;
        # only point this at trusted checkpoint files.
        ckpt = torch.load(path, map_location="cpu", weights_only=False)
        if "model" in ckpt:
            state = _mirror_encoder_to_decoder(_strip_prefixes(ckpt["model"]))
        else:
            state = _strip_prefixes(ckpt.get("state_dict", ckpt))

        # strict=False does not tolerate size mismatches, so drop them up front.
        model_state = self.net.state_dict()
        mismatched = [k for k, v in state.items()
                      if k in model_state and v.shape != model_state[k].shape]
        for k in mismatched:
            del state[k]
        if mismatched:
            print(f"[swinunet] skipped {len(mismatched)} shape-mismatched keys: "
                  f"{mismatched}")

        msg = self.net.load_state_dict(state, strict=False)
        print(f"[swinunet] loaded pretrained {path}: "
              f"missing={len(msg.missing_keys)} unexpected={len(msg.unexpected_keys)}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on an NCHW batch.

        A 1-channel input is repeated to 3 channels when the underlying network
        expects RGB input.
        """
        if x.size(1) == 1 and self._net_in_chans == 3:
            x = x.repeat(1, 3, 1, 1)
        return self.net(x)


def build_swinunet(in_channels: int, num_classes: int, img_size: int = 224,
                   pretrained_ckpt: str = "", **_) -> SwinUnetWrapper:
    """Build a ``SwinUnetWrapper``; extra keyword arguments are accepted and ignored."""
    return SwinUnetWrapper(in_channels, num_classes, img_size, pretrained_ckpt)