GenSeg-Baselines / code /scripts /p1 /backbones.py
MaybeRichard's picture
code: complete eval pipeline (7 metrics + per-class + Wilcoxon) + Swin-UNet/TransUNet networks; remove backups/obsolete
1a18f22 verified
Raw
History Blame Contribute Delete
2.19 kB
"""build_backbone: instantiate one of {jit, pixelgen, deco, pixeldit} pixel-space
denoisers for mask-concat conditioning.

Each backbone returns a net callable as net(x, t, y) -> (N, C >= img_ch, H, W);
the caller slices [:, :img_ch], which keeps the calling code backbone-agnostic.
in_channels = img_channels + cond_channels, and a single dummy class is used
(num_classes=1).

A unified 'Base' tier (~130-150M) is used for the P1 backbone bake-off.
P1 = native architecture under a common flow-matching objective;
perceptual/DCT/FD losses are P2 levers and are not used here.
"""
import os
import sys
# Root directory under which the backbone source trees are looked up (the
# sub-directories joined onto it elsewhere in this file are "JiT", "PixelGen",
# "DeCo" and "PixelDiT"). Defaults to the original development path; set the
# GENSEG_SOTA_ROOT environment variable to use a checkout located elsewhere.
# An unset or empty variable falls back to the default.
_SOTA = os.environ.get("GENSEG_SOTA_ROOT") or "/home/wzhang/LSC/Code/NPJ/sota"
def _add(path):
    """Put *path* at the front of ``sys.path`` unless it is already listed."""
    if path in sys.path:
        # Already importable from here; leave the existing search order alone.
        return
    sys.path.insert(0, path)
def build_backbone(backbone: str, model_name: str, img_size: int,
                   in_channels: int, num_classes: int = 1):
    """Construct the pixel-space denoiser selected by *backbone*.

    The backbone name is matched case-insensitively. The source directory of
    the requested backbone (under ``_SOTA``) is put on ``sys.path`` and its
    module is imported only when that backbone is actually requested.

    Args:
        backbone: One of ``"jit"``, ``"pixelgen"``, ``"deco"``, ``"pixeldit"``.
        model_name: Key into the ``JiT_models`` registry. Only the ``jit`` and
            ``pixelgen`` branches read it; ``deco`` and ``pixeldit`` use the
            fixed hyper-parameters written below.
        img_size: Forwarded as ``input_size``. Only the ``jit`` and
            ``pixelgen`` branches read it.
        in_channels: Input channel count forwarded to the network constructor.
        num_classes: Class count forwarded to the network constructor.

    Returns:
        The network object created by the selected backbone's constructor.

    Raises:
        ValueError: If *backbone* is not one of the supported names.
    """
    selected = backbone.lower()

    def _from_jit():
        _add(os.path.join(_SOTA, "JiT"))
        from model_jit import JiT_models
        factory = JiT_models[model_name]
        return factory(input_size=img_size, in_channels=in_channels,
                       num_classes=num_classes)

    def _from_pixelgen():
        _add(os.path.join(_SOTA, "PixelGen", "src", "models", "transformer"))
        import importlib
        # PixelGen ships its own self-contained JiT.py; load it by module name.
        pixelgen_jit = importlib.import_module("JiT")
        factory = pixelgen_jit.JiT_models[model_name]
        return factory(input_size=img_size, in_channels=in_channels,
                       num_classes=num_classes)

    def _from_deco():
        _add(os.path.join(_SOTA, "DeCo", "src", "models", "transformer"))
        from dit_c2i_DeCo import PixNerDiT
        hparams = dict(patch_size=16, num_groups=12, hidden_size=768,
                       hidden_size_x=32, num_blocks=13, num_cond_blocks=12)
        return PixNerDiT(in_channels=in_channels, num_classes=num_classes,
                         **hparams)

    def _from_pixeldit():
        _add(os.path.join(_SOTA, "PixelDiT"))
        from pixdit_core.pixeldit_c2i import PixDiT
        hparams = dict(num_groups=10, hidden_size=640, pixel_hidden_size=16,
                       patch_depth=9, pixel_depth=4, patch_size=16)
        return PixDiT(in_channels=in_channels, num_classes=num_classes,
                      **hparams)

    constructors = {
        "jit": _from_jit,
        "pixelgen": _from_pixelgen,
        "deco": _from_deco,
        "pixeldit": _from_pixeldit,
    }
    construct = constructors.get(selected)
    if construct is None:
        # Report the name exactly as the caller passed it, not the lowered form.
        raise ValueError(f"unknown backbone: {backbone}")
    return construct()