"""build_backbone: instantiate one of {jit, pixelgen, deco, pixeldit} pixel-space denoisers for mask-concat conditioning. Each returns a net callable as net(x, t, y) -> (N, C>=img_ch, H, W); the caller slices [:, :img_ch] (backbone-agnostic decouple). in_channels = img_channels + cond_channels; a single dummy class (num_classes=1). Unified 'Base' tier (~130-150M) for the P1 backbone bake-off (P1 = native arch under a common flow-matching objective; perceptual/DCT/FD losses are P2 levers, not used here).""" import os import sys _SOTA = "/home/wzhang/LSC/Code/NPJ/sota" def _add(path): if path not in sys.path: sys.path.insert(0, path) def build_backbone(backbone: str, model_name: str, img_size: int, in_channels: int, num_classes: int = 1): bk = backbone.lower() if bk == "jit": _add(os.path.join(_SOTA, "JiT")) from model_jit import JiT_models return JiT_models[model_name](input_size=img_size, in_channels=in_channels, num_classes=num_classes) if bk == "pixelgen": _add(os.path.join(_SOTA, "PixelGen", "src", "models", "transformer")) import importlib jit = importlib.import_module("JiT") # PixelGen's self-contained JiT.py return jit.JiT_models[model_name](input_size=img_size, in_channels=in_channels, num_classes=num_classes) if bk == "deco": _add(os.path.join(_SOTA, "DeCo", "src", "models", "transformer")) from dit_c2i_DeCo import PixNerDiT return PixNerDiT(in_channels=in_channels, patch_size=16, num_groups=12, hidden_size=768, hidden_size_x=32, num_blocks=13, num_cond_blocks=12, num_classes=num_classes) if bk == "pixeldit": _add(os.path.join(_SOTA, "PixelDiT")) from pixdit_core.pixeldit_c2i import PixDiT return PixDiT(in_channels=in_channels, num_groups=10, hidden_size=640, pixel_hidden_size=16, patch_depth=9, pixel_depth=4, patch_size=16, num_classes=num_classes) raise ValueError(f"unknown backbone: {backbone}")