GenSeg-Baselines / code /framework /models /transunet_wrap.py
MaybeRichard's picture
Upload folder using huggingface_hub
b8fae22 verified
Raw
History Blame Contribute Delete
1.55 kB
"""TransUNet wrapper.
Reuses ONLY the model definition from sota/TransUNet (networks/), not its
.npz/.h5 Synapse data pipeline. The model's forward already repeats 1->3 channels
for grayscale input, so it accepts our unified RGB or grayscale tensors.
Notes:
* img_size must be divisible by 16 (ViT patch grid). 224 is the canonical value.
* pretrained_ckpt should be the R50+ViT-B_16 .npz (ImageNet-21k); optional.
"""
from __future__ import annotations
import os
import sys
# Absolute path of the TransUNet sources: two directory levels above this
# file's directory, then under sota/TransUNet.
_REPO = os.path.abspath(
    os.path.join(
        os.path.dirname(__file__), os.pardir, os.pardir, "sota", "TransUNet"
    )
)
def _ensure_path():
    """Prepend ``_REPO`` to ``sys.path`` unless it is already listed.

    Makes the ``networks`` package under ``_REPO`` importable. Repeated
    calls do not add duplicate entries.
    """
    if _REPO in sys.path:
        return
    sys.path.insert(0, _REPO)
def build_transunet(in_channels: int, num_classes: int, img_size: int = 224,
                    encoder: str = "R50-ViT-B_16", pretrained_ckpt: str = "",
                    vit_patches_size: int = 16, **_):
    """Build a TransUNet ``VisionTransformer`` segmentation model.

    Args:
        in_channels: Not used by this builder. Per the module docstring the
            upstream forward pass repeats 1 -> 3 channels for grayscale input.
        num_classes: Number of segmentation classes; written to the config
            and passed to ``VisionTransformer``.
        img_size: Input size forwarded to ``VisionTransformer``. For "R50"
            encoders the patch grid is ``img_size // vit_patches_size`` per
            side (floor division), so it should be a multiple of
            ``vit_patches_size``.
        encoder: Key into the upstream ``CONFIGS`` table. An unknown name
            falls back to ``"R50-ViT-B_16"`` and emits a ``RuntimeWarning``.
        pretrained_ckpt: Optional path to a ``.npz`` checkpoint. Empty means
            no pretrained weights. A non-empty path that is not an existing
            file is skipped with a ``RuntimeWarning``.
        vit_patches_size: Patch size used to derive the patch grid for "R50"
            encoders.
        **_: Extra keyword arguments are accepted and ignored.

    Returns:
        The constructed ``VisionTransformer`` instance.
    """
    # The TransUNet checkout must be on sys.path before `networks` is imported.
    _ensure_path()
    import copy
    import warnings

    import numpy as np
    from networks.vit_seg_modeling import VisionTransformer, CONFIGS

    if encoder in CONFIGS:
        vit_name = encoder
    else:
        vit_name = "R50-ViT-B_16"
        warnings.warn(
            f"Unknown TransUNet encoder {encoder!r}; "
            f"falling back to {vit_name!r}.",
            RuntimeWarning,
            stacklevel=2,
        )

    # Work on a private copy: the CONFIGS entries are shared objects, and
    # mutating them in place would leak n_classes / grid settings into every
    # later call of this builder within the same process.
    config_vit = copy.deepcopy(CONFIGS[vit_name])
    config_vit.n_classes = num_classes
    config_vit.n_skip = 3
    if "R50" in vit_name:
        grid_side = img_size // vit_patches_size
        config_vit.patches.grid = (grid_side, grid_side)

    model = VisionTransformer(config_vit, img_size=img_size,
                              num_classes=num_classes)

    if pretrained_ckpt:
        if os.path.isfile(pretrained_ckpt):
            # Close the .npz archive once the weights have been copied in.
            with np.load(pretrained_ckpt) as weights:
                model.load_from(weights=weights)
        else:
            # Still best-effort (model keeps its fresh init), but no longer
            # silent about a mistyped or missing checkpoint path.
            warnings.warn(
                f"pretrained_ckpt {pretrained_ckpt!r} is not a file; "
                "continuing without pretrained weights.",
                RuntimeWarning,
                stacklevel=2,
            )
    return model