"""TransUNet wrapper.

Reuses ONLY the model definition from ``sota/TransUNet`` (``networks/``), not
its .npz/.h5 Synapse data pipeline. The model's forward already repeats 1->3
channels for grayscale input, so it accepts our unified RGB or grayscale
tensors.

Notes:
    * ``img_size`` must be divisible by ``vit_patches_size`` (16 by default,
      the ViT patch grid). 224 is the canonical value.
    * ``pretrained_ckpt`` should be the R50+ViT-B_16 .npz (ImageNet-21k);
      it is optional.
"""

from __future__ import annotations

import copy
import os
import sys
import warnings

# Absolute path of the vendored TransUNet checkout, resolved relative to this file.
_REPO = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "..", "sota", "TransUNet")
)

# Encoder used when the requested one is not a key of the upstream CONFIGS.
_DEFAULT_ENCODER = "R50-ViT-B_16"


def _ensure_path() -> None:
    """Put the TransUNet checkout at the front of ``sys.path`` (only once)."""
    if _REPO not in sys.path:
        sys.path.insert(0, _REPO)


def build_transunet(
    in_channels: int,
    num_classes: int,
    img_size: int = 224,
    encoder: str = "R50-ViT-B_16",
    pretrained_ckpt: str = "",
    vit_patches_size: int = 16,
    **_,
):
    """Build an upstream TransUNet ``VisionTransformer`` segmentation model.

    Args:
        in_channels: Accepted for interface parity with the other model
            builders; it is not forwarded to the model. Per the module note,
            the upstream forward repeats a single channel to three, so inputs
            are expected to have 1 or 3 channels.
        num_classes: Number of output segmentation classes.
        img_size: Square input resolution. Must be a positive multiple of
            ``vit_patches_size``.
        encoder: Key into the upstream ``CONFIGS`` table. Unknown names fall
            back to ``"R50-ViT-B_16"`` with a warning.
        pretrained_ckpt: Optional path to a ViT ``.npz`` checkpoint. An empty
            string skips loading. A non-empty path that does not exist is
            skipped with a warning.
        vit_patches_size: Effective patch size, used to derive the patch grid
            for R50 hybrid encoders.
        **_: Extra keyword arguments are accepted and ignored, so callers can
            pass a shared config dict.

    Returns:
        The constructed ``VisionTransformer`` module.

    Raises:
        ValueError: If ``vit_patches_size`` is not positive, or if ``img_size``
            is not a positive multiple of it.
    """
    # Validate before touching sys.path or importing the upstream code, so a
    # bad size fails fast with a clear message instead of a later shape error.
    if vit_patches_size <= 0:
        raise ValueError(
            f"vit_patches_size must be positive, got {vit_patches_size}"
        )
    if img_size <= 0 or img_size % vit_patches_size:
        raise ValueError(
            f"img_size ({img_size}) must be a positive multiple of "
            f"vit_patches_size ({vit_patches_size})"
        )

    _ensure_path()
    import numpy as np
    from networks.vit_seg_modeling import CONFIGS, VisionTransformer

    if encoder in CONFIGS:
        vit_name = encoder
    else:
        warnings.warn(
            f"Unknown TransUNet encoder {encoder!r}; "
            f"falling back to {_DEFAULT_ENCODER!r}.",
            stacklevel=2,
        )
        vit_name = _DEFAULT_ENCODER

    # CONFIGS is a module-level table shared by every caller. Copy the entry
    # before mutating it, so this build's n_classes / grid do not leak into
    # later builds, or into models built earlier that may still reference
    # the config object.
    config_vit = copy.deepcopy(CONFIGS[vit_name])
    config_vit.n_classes = num_classes
    # NOTE(review): n_skip=3 matches the R50 hybrid's skip features. Pure-ViT
    # configs may need n_skip=0; confirm against networks/vit_seg_modeling.py.
    config_vit.n_skip = 3
    if "R50" in vit_name:
        grid = img_size // vit_patches_size
        config_vit.patches.grid = (grid, grid)

    model = VisionTransformer(config_vit, img_size=img_size, num_classes=num_classes)

    if pretrained_ckpt:
        if os.path.isfile(pretrained_ckpt):
            # np.load on an .npz returns a lazily-read NpzFile that holds the
            # file open; the context manager closes it once weights are copied.
            with np.load(pretrained_ckpt) as weights:
                model.load_from(weights=weights)
        else:
            warnings.warn(
                f"Pretrained checkpoint {pretrained_ckpt!r} not found; "
                "building TransUNet from random initialization.",
                stacklevel=2,
            )
    return model