code: complete eval pipeline (7 metrics + per-class + Wilcoxon) + Swin-UNet/TransUNet networks; remove backups/obsolete
1a18f22 verified | """build_backbone: instantiate one of {jit, pixelgen, deco, pixeldit} pixel-space | |
| denoisers for mask-concat conditioning. Each returns a net callable as net(x, t, y) | |
| -> (N, C>=img_ch, H, W); the caller slices [:, :img_ch] (backbone-agnostic decouple). | |
| in_channels = img_channels + cond_channels; a single dummy class (num_classes=1). | |
| Unified 'Base' tier (~130-150M) for the P1 backbone bake-off (P1 = native arch under a | |
| common flow-matching objective; perceptual/DCT/FD losses are P2 levers, not used here).""" | |
| import os | |
| import sys | |
# Root directory holding the third-party backbone checkouts (JiT, PixelGen,
# DeCo, PixelDiT) that build_backbone puts on sys.path. Set the NPJ_SOTA_DIR
# environment variable to relocate it; an unset or empty variable falls back
# to the original hard-coded location.
_SOTA = os.environ.get("NPJ_SOTA_DIR") or "/home/wzhang/LSC/Code/NPJ/sota"
def _add(path):
    """Put *path* at the front of ``sys.path`` unless it is already listed."""
    if path in sys.path:
        # Already importable from here; leave the existing search order alone.
        return
    sys.path.insert(0, path)
def build_backbone(backbone: str, model_name: str, img_size: int,
                   in_channels: int, num_classes: int = 1):
    """Instantiate the pixel-space denoiser selected by *backbone*.

    The name is matched case-insensitively against ``jit``, ``pixelgen``,
    ``deco`` and ``pixeldit``. The matching source directory under ``_SOTA``
    is placed on ``sys.path`` and the model class is imported lazily, so only
    the requested backbone's code is loaded.

    Args:
        backbone: Which backbone family to build.
        model_name: Key into the ``JiT_models`` registry. Only the ``jit`` and
            ``pixelgen`` branches read it.
        img_size: Passed as ``input_size``. Only the ``jit`` and ``pixelgen``
            branches read it.
        in_channels: Number of input channels handed to the network.
        num_classes: Number of classes handed to the network.

    Returns:
        The constructed network object.

    Raises:
        ValueError: If *backbone* is not one of the four supported names.
    """
    kind = backbone.lower()
    if kind not in ("jit", "pixelgen", "deco", "pixeldit"):
        # Reject unknown names before touching sys.path or importing anything.
        raise ValueError(f"unknown backbone: {backbone}")

    if kind == "jit":
        _add(os.path.join(_SOTA, "JiT"))
        from model_jit import JiT_models
        factory = JiT_models[model_name]
        return factory(input_size=img_size, in_channels=in_channels,
                       num_classes=num_classes)

    if kind == "pixelgen":
        _add(os.path.join(_SOTA, "PixelGen", "src", "models", "transformer"))
        import importlib
        # PixelGen ships its own self-contained JiT.py; import it by module name.
        pixelgen_jit = importlib.import_module("JiT")
        factory = pixelgen_jit.JiT_models[model_name]
        return factory(input_size=img_size, in_channels=in_channels,
                       num_classes=num_classes)

    if kind == "deco":
        _add(os.path.join(_SOTA, "DeCo", "src", "models", "transformer"))
        from dit_c2i_DeCo import PixNerDiT
        # Fixed architecture hyperparameters for this backbone.
        deco_cfg = {
            "patch_size": 16,
            "num_groups": 12,
            "hidden_size": 768,
            "hidden_size_x": 32,
            "num_blocks": 13,
            "num_cond_blocks": 12,
        }
        return PixNerDiT(in_channels=in_channels, num_classes=num_classes,
                         **deco_cfg)

    # Only "pixeldit" remains after the guard above.
    _add(os.path.join(_SOTA, "PixelDiT"))
    from pixdit_core.pixeldit_c2i import PixDiT
    # Fixed architecture hyperparameters for this backbone.
    pixeldit_cfg = {
        "num_groups": 10,
        "hidden_size": 640,
        "pixel_hidden_size": 16,
        "patch_depth": 9,
        "pixel_depth": 4,
        "patch_size": 16,
    }
    return PixDiT(in_channels=in_channels, num_classes=num_classes,
                  **pixeldit_cfg)