| """TransUNet wrapper. |
| |
| Reuses ONLY the model definition from sota/TransUNet (networks/), not its |
| .npz/.h5 Synapse data pipeline. The model's forward already repeats 1->3 channels |
| for grayscale input, so it accepts our unified RGB or grayscale tensors. |
| |
| Notes: |
| * img_size must be divisible by 16 (ViT patch grid). 224 is the canonical value. |
| * pretrained_ckpt should be the R50+ViT-B_16 .npz (ImageNet-21k); optional. |
| """ |
| from __future__ import annotations |
|
|
| import os |
| import sys |
|
|
# Absolute location of the vendored TransUNet checkout
# (<this file's dir>/../../sota/TransUNet). It is normalised with abspath so
# it can be compared directly against entries already on sys.path.
_REPO = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "..", "sota", "TransUNet")
)
|
|
|
|
def _ensure_path():
    """Make the vendored TransUNet checkout importable.

    Prepends the module-level ``_REPO`` directory to ``sys.path`` unless it is
    already present, so calling this repeatedly is a no-op after the first time.
    """
    search_path = sys.path
    if _REPO in search_path:
        return
    # Insert at the front so TransUNet's top-level ``networks`` package wins
    # over any other module of the same name on the path.
    search_path.insert(0, _REPO)
|
|
|
|
def build_transunet(in_channels: int, num_classes: int, img_size: int = 224,
                    encoder: str = "R50-ViT-B_16", pretrained_ckpt: str = "",
                    vit_patches_size: int = 16, **_):
    """Build a TransUNet ``VisionTransformer`` segmentation model.

    Args:
        in_channels: Accepted for interface parity with the other model
            builders; it is not used here. Per the module notes, the model's
            forward repeats a 1-channel input to 3 channels, so grayscale or
            RGB tensors are both accepted.
        num_classes: Number of output segmentation classes.
        img_size: Square input side length. Must be a positive multiple of
            ``vit_patches_size``.
        encoder: Key into TransUNet's ``CONFIGS``. An unknown name falls back
            to ``"R50-ViT-B_16"`` and emits a ``UserWarning``.
        pretrained_ckpt: Optional path to an ``.npz`` checkpoint, which is
            loaded via ``model.load_from``. An empty string means no
            pretrained weights. A non-empty path that does not exist is
            skipped with a ``UserWarning``.
        vit_patches_size: Patch size used to derive the patch grid for hybrid
            (R50) encoders.
        **_: Extra keyword arguments are accepted and ignored. This lets a
            shared factory pass a uniform kwargs set.

    Returns:
        The constructed ``VisionTransformer`` instance.

    Raises:
        ValueError: If ``img_size`` or ``vit_patches_size`` is not positive,
            or if ``img_size`` is not a multiple of ``vit_patches_size``.
    """
    import copy
    import warnings

    # Fail early; an indivisible size otherwise surfaces later as a
    # position-embedding / feature-map shape mismatch.
    if vit_patches_size <= 0 or img_size <= 0 or img_size % vit_patches_size:
        raise ValueError(
            f"img_size ({img_size}) must be a positive multiple of "
            f"vit_patches_size ({vit_patches_size})"
        )

    _ensure_path()
    import numpy as np
    from networks.vit_seg_modeling import VisionTransformer, CONFIGS

    if encoder in CONFIGS:
        vit_name = encoder
    else:
        warnings.warn(
            f"Unknown TransUNet encoder {encoder!r}; falling back to "
            f"'R50-ViT-B_16'.",
            stacklevel=2,
        )
        vit_name = "R50-ViT-B_16"
    is_hybrid = "R50" in vit_name

    # CONFIGS is a module-level mapping shared by every caller. Copy the
    # entry so the per-call settings below do not leak into other builds.
    # NOTE(review): TransUNet's model appears to keep a reference to its
    # config, so mutating the shared object could alter earlier models.
    config_vit = copy.deepcopy(CONFIGS[vit_name])
    config_vit.n_classes = num_classes
    # Assumption (verify against networks/vit_seg_modeling.py): skip features
    # come only from the R50 backbone, so pure-ViT configs need n_skip = 0.
    config_vit.n_skip = 3 if is_hybrid else 0
    if is_hybrid:
        grid = img_size // vit_patches_size
        config_vit.patches.grid = (grid, grid)

    model = VisionTransformer(config_vit, img_size=img_size,
                              num_classes=num_classes)

    if pretrained_ckpt:
        if os.path.isfile(pretrained_ckpt):
            # np.load on an .npz returns an NpzFile holding an open file
            # handle; the context manager closes it after the weights are copied.
            with np.load(pretrained_ckpt) as weights:
                model.load_from(weights=weights)
        else:
            warnings.warn(
                f"pretrained_ckpt {pretrained_ckpt!r} not found; "
                f"using randomly initialised weights.",
                stacklevel=2,
            )
    return model
|
|