# GenSeg-Baselines / code/framework/models/swinunet_wrap.py
# Uploaded by MaybeRichard ("Upload folder using huggingface_hub"),
# commit b8fae22 (verified), 2.96 kB.
"""Swin-Unet wrapper.
Instantiates SwinTransformerSys from sota/Swin-Unet directly (bypassing the repo's
yacs config and its .npz/.h5 Synapse data pipeline). Handles grayscale input by
repeating the single channel to 3 channels.
Constraint: the windowed attention requires img_size=224 (the patch resolutions
56/28/14/7 are each divisible by window_size=7). Other sizes will trip the
backbone's assertions, so we keep img_size=224 for this backbone.
"""
from __future__ import annotations
import os
import sys
import torch
import torch.nn as nn
# Vendored Swin-Unet checkout, located at <this file's dir>/../../sota/Swin-Unet
# and normalised to an absolute path so it does not depend on the working dir.
_REPO = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "..", "sota", "Swin-Unet")
)
def _ensure_path():
    """Make the vendored Swin-Unet ``networks`` package importable.

    Prepends ``_REPO`` to ``sys.path`` unless it is already present, so
    repeated calls leave ``sys.path`` unchanged.
    """
    if _REPO in sys.path:
        return
    sys.path.insert(0, _REPO)
class SwinUnetWrapper(nn.Module):
    """Thin adapter around Swin-Unet's ``SwinTransformerSys``.

    The backbone is always built for 3-channel input with a Swin-T layout
    (patch_size=4, embed_dim=96, depths=[2, 2, 2, 2],
    num_heads=[3, 6, 12, 24], window_size=7). A single-channel input is
    repeated to 3 channels in :meth:`forward`.

    Args:
        in_channels: Channel count of the input images. Only 1 (repeated to
            3 in ``forward``) and 3 are supported.
        num_classes: Number of output classes passed to ``SwinTransformerSys``.
        img_size: Input resolution; only 224 is accepted.
        pretrained_ckpt: Optional checkpoint path. An empty string skips
            loading. A non-empty path that is not an existing file is skipped
            with a ``UserWarning``.

    Raises:
        ValueError: if ``in_channels`` is not 1 or 3, or ``img_size`` != 224.
    """

    def __init__(self, in_channels: int, num_classes: int, img_size: int = 224,
                 pretrained_ckpt: str = ""):
        super().__init__()
        # Validate before importing the vendored repo so bad arguments fail
        # fast with a clear message. The backbone is hard-wired to in_chans=3
        # and forward() only adapts 1-channel input, so any other channel
        # count would otherwise fail later with an opaque conv shape error.
        if in_channels not in (1, 3):
            raise ValueError(
                f"Swin-Unet backbone supports in_channels of 1 or 3, got {in_channels}."
            )
        if img_size != 224:
            raise ValueError("Swin-Unet backbone requires img_size=224.")
        _ensure_path()
        from networks.swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys
        self.net = SwinTransformerSys(
            img_size=img_size, patch_size=4, in_chans=3, num_classes=num_classes,
            embed_dim=96, depths=[2, 2, 2, 2], num_heads=[3, 6, 12, 24],
            window_size=7, mlp_ratio=4.0, qkv_bias=True, drop_path_rate=0.1,
            ape=False, patch_norm=True, use_checkpoint=False,
        )
        if pretrained_ckpt:
            if os.path.isfile(pretrained_ckpt):
                self._load_pretrained(pretrained_ckpt)
            else:
                # Loading stays optional, but a mistyped path should be
                # visible rather than silently ignored.
                import warnings
                warnings.warn(
                    f"[swinunet] pretrained checkpoint not found, skipping: {pretrained_ckpt}",
                    stacklevel=2,
                )

    def _load_pretrained(self, path):
        """Load pretrained weights from ``path`` into ``self.net`` (non-strict).

        This is a port of Swin-Unet's ``load_from``. If the checkpoint has a
        ``"model"`` entry, its weights are loaded as-is. Each encoder key
        ``layers.<X>.<rest>`` is also mirrored into the decoder as
        ``layers_up.<3-X>.<rest>``, which is the scheme the paper uses to
        initialise the symmetric decoder. Entries whose shape disagrees with
        the model are dropped first. Any other checkpoint is treated as a
        plain state dict. In both cases a missing/unexpected key summary is
        printed.
        """
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load checkpoints from trusted sources.
        ckpt = torch.load(path, map_location="cpu", weights_only=False)
        if "model" not in ckpt:
            msg = self.net.load_state_dict(ckpt, strict=False)
        else:
            pretrained = ckpt["model"]
            # A shallow copy suffices: keys are only added or removed, tensors
            # are never mutated, and load_state_dict copies values into params.
            full = dict(pretrained)
            prefix = "layers."
            for k, v in pretrained.items():
                # Only top-level encoder stages ("layers.<digits>..."). A plain
                # substring test would also match nested names such as
                # "encoder.layers.0.weight" and then crash parsing the index.
                if not k.startswith(prefix):
                    continue
                stage = k[len(prefix):].split(".", 1)[0]
                if not stage.isdecimal():
                    continue
                suffix = k[len(prefix) + len(stage):]
                full["layers_up." + str(3 - int(stage)) + suffix] = v
            model_dict = self.net.state_dict()
            # Drop entries the model has under a different shape; non-strict
            # loading still raises on size mismatches.
            full = {
                k: v for k, v in full.items()
                if k not in model_dict or v.shape == model_dict[k].shape
            }
            msg = self.net.load_state_dict(full, strict=False)
        print(f"[swinunet] loaded pretrained {path}: "
              f"missing={len(msg.missing_keys)} unexpected={len(msg.unexpected_keys)}")

    def forward(self, x):
        """Run the backbone, repeating a 1-channel input to 3 channels.

        ``x`` is presumably laid out (B, C, H, W) with H = W = 224 (TODO:
        confirm against callers). Dim 1 is the one checked and repeated.
        """
        if x.size(1) == 1:
            x = x.repeat(1, 3, 1, 1)
        return self.net(x)
def build_swinunet(in_channels: int, num_classes: int, img_size: int = 224,
                   pretrained_ckpt: str = "", **_):
    """Factory entry point returning a ``SwinUnetWrapper``.

    Extra keyword arguments are accepted and ignored, so a generic model
    registry can pass its full option set to every builder.
    """
    return SwinUnetWrapper(
        in_channels=in_channels,
        num_classes=num_classes,
        img_size=img_size,
        pretrained_ckpt=pretrained_ckpt,
    )