| """Swin-Unet wrapper. |
| |
| Instantiates SwinTransformerSys from sota/Swin-Unet directly (bypassing the repo's |
| yacs config + .npz/.h5 Synapse pipeline). Handles grayscale by repeating 1->3. |
| |
| Constraint: the windowed attention requires img_size = 224 (patch resolutions |
| 56/28/14/7 are each divisible by window_size=7). Other sizes will assert in the |
| backbone, so this wrapper rejects them with a ValueError; we keep img_size=224. |
| """ |
| from __future__ import annotations |
|
|
| import os |
| import sys |
|
|
| import torch |
| import torch.nn as nn |
|
|
# Absolute, normalized path of the Swin-Unet repo directory: two levels above
# this file's directory, then sota/Swin-Unet.
_REPO = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "..", "sota", "Swin-Unet")
)
|
|
|
|
def _ensure_path():
    """Put ``_REPO`` at the front of ``sys.path`` unless it is already listed."""
    if _REPO in sys.path:
        # Already present: leave sys.path untouched.
        return
    sys.path.insert(0, _REPO)
|
|
|
|
class SwinUnetWrapper(nn.Module):
    """Thin ``nn.Module`` around the Swin-Unet ``SwinTransformerSys`` backbone.

    The backbone is always created with three input channels; ``forward``
    expands single-channel batches to three channels before calling it.

    Args:
        in_channels: Accepted for interface parity with the builder. It is not
            used: the backbone is built with ``in_chans=3`` regardless.
        num_classes: Passed to ``SwinTransformerSys`` as ``num_classes``.
        img_size: Must be 224.
        pretrained_ckpt: Optional checkpoint path. An empty string skips
            loading; a non-empty path that is not a file emits a warning and
            no weights are loaded.

    Raises:
        ValueError: If ``img_size`` is not 224.
    """

    def __init__(self, in_channels: int, num_classes: int, img_size: int = 224,
                 pretrained_ckpt: str = ""):
        super().__init__()
        # Make the Swin-Unet repo directory importable before importing from it.
        _ensure_path()
        from networks.swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys
        if img_size != 224:
            raise ValueError("Swin-Unet backbone requires img_size=224.")
        self.net = SwinTransformerSys(
            img_size=img_size, patch_size=4, in_chans=3, num_classes=num_classes,
            embed_dim=96, depths=[2, 2, 2, 2], num_heads=[3, 6, 12, 24],
            window_size=7, mlp_ratio=4.0, qkv_bias=True, drop_path_rate=0.1,
            ape=False, patch_norm=True, use_checkpoint=False,
        )
        if not pretrained_ckpt:
            return
        if os.path.isfile(pretrained_ckpt):
            self._load_pretrained(pretrained_ckpt)
        else:
            # A requested-but-missing checkpoint used to be skipped silently;
            # surface it so a mistyped path does not go unnoticed.
            import warnings
            warnings.warn(
                f"[swinunet] pretrained_ckpt {pretrained_ckpt!r} is not a file; "
                "no pretrained weights were loaded.",
                stacklevel=2,
            )

    def _load_pretrained(self, path: str) -> None:
        """Load checkpoint weights into ``self.net`` (non-strict).

        Port of Swin-Unet's load_from: load the ImageNet Swin-T encoder AND
        mirror its ``layers.X`` weights into the decoder ``layers_up.(3-X)``
        (the scheme the paper uses to initialize the symmetric decoder).

        If the checkpoint has no ``"model"`` entry it is passed to
        ``load_state_dict`` as-is. Otherwise every key that starts with
        ``layers.<X>`` is duplicated under ``layers_up.<3-X>``, and entries
        whose shape differs from the model's parameter of the same name are
        dropped before loading.

        Args:
            path: Checkpoint file readable by ``torch.load``.
        """
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # objects; only point this at checkpoints from a trusted source.
        ckpt = torch.load(path, map_location="cpu", weights_only=False)
        if "model" not in ckpt:
            result = self.net.load_state_dict(ckpt, strict=False)
        else:
            pretrained = ckpt["model"]
            model_dict = self.net.state_dict()
            # A shallow copy is enough: load_state_dict copies the values into
            # the parameters, and only the key set is edited here.
            full = dict(pretrained)
            prefix = "layers."
            for key, value in pretrained.items():
                # Only keys that *start* with "layers.<digits>" are encoder
                # stages; a substring match elsewhere in the key is ignored.
                if not key.startswith(prefix):
                    continue
                stage, sep, rest = key[len(prefix):].partition(".")
                if not stage.isdigit():
                    continue
                # Four stages (0..3): encoder stage X mirrors decoder stage 3 - X.
                full[f"layers_up.{3 - int(stage)}{sep}{rest}"] = value
            for key in list(full):
                # Shape mismatches raise even with strict=False, so drop them.
                if key in model_dict and full[key].shape != model_dict[key].shape:
                    del full[key]
            result = self.net.load_state_dict(full, strict=False)
        print(f"[swinunet] loaded pretrained {path}: "
              f"missing={len(result.missing_keys)} unexpected={len(result.unexpected_keys)}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone on ``x``, tiling a single channel (dim 1) to three."""
        if x.size(1) == 1:
            x = x.repeat(1, 3, 1, 1)
        return self.net(x)
|
|
|
|
def build_swinunet(in_channels: int, num_classes: int, img_size: int = 224,
                   pretrained_ckpt: str = "", **_):
    """Construct a ``SwinUnetWrapper`` from the named arguments.

    Any additional keyword arguments are accepted and ignored.
    """
    wrapper = SwinUnetWrapper(
        in_channels=in_channels,
        num_classes=num_classes,
        img_size=img_size,
        pretrained_ckpt=pretrained_ckpt,
    )
    return wrapper
|
|