code: complete eval pipeline (7 metrics + per-class + Wilcoxon) + Swin-UNet/TransUNet networks; remove backups/obsolete
1a18f22 verified | """Orchestration for the generative-augmentation SOTA baselines (category B). | |
| These methods are compared against our SegGen method. Each runs in its OWN conda | |
| env (see envs/) because their dependency stacks conflict with the main framework. | |
| The shared contract: every generator must emit paired (image, mask) into | |
| <data_root>/<dataset>/<protocol>/synth_<method>/{images,masks}/ | |
| which the unified trainer then merges into the train split via --synth_train_dir. | |
| Kept baselines: | |
| * SegGuidedDiff (diffusion, mask->image, medical, modern stack) -- best fit, USE-AS-IS | |
| * SPADE (GAN, mask->image) -- ADAPT (needs sync_bn) | |
| * ControlNet (diffusion, SD-finetune, mask->image) -- ADAPT (needs SD ckpt) | |
| Dropped (per scoping): StyleGAN2-ADA (no masks), LDM (dep hell + AE training). | |
| This module only BUILDS the commands + assembles the standard synth dir; it does | |
| not import the repos (they live in separate envs). Run the printed commands in the | |
| matching env, then call assemble_synth_dir() (env-agnostic) to lay out pairs. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import shutil | |
| from glob import glob | |
# Absolute path of the `sota/` directory two levels above this file's folder;
# the command builders below resolve each baseline repo relative to it.
SOTA = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, "sota")
)
def assemble_synth_dir(generated_images_dir: str, masks_source_dir: str,
                       out_dir: str, strip_prefix: str = "condon_",
                       link: bool = True) -> int:
    """Pair each generated image with the real mask it was conditioned on.

    Mask-conditioned generators name outputs after the conditioning mask
    (SegGuidedDiff: 'condon_<maskname>.png'). The mask name is recovered by
    dropping ``strip_prefix`` from the image stem, the matching real mask is
    looked up in ``masks_source_dir`` (any extension), and both files are
    placed under ``out_dir/{images,masks}/`` as ``synth_<6-digit index>``
    with their original extensions.

    Args:
        generated_images_dir: Folder holding the generator's output images.
        masks_source_dir: Folder holding the real conditioning masks.
        out_dir: Destination root; ``images/`` and ``masks/`` are created
            below it if missing.
        strip_prefix: Prefix removed from an image stem to obtain the mask
            stem. Stems without the prefix (or an empty prefix) are used
            unchanged.
        link: Passed through to ``_place``; True symlinks, False copies.

    Returns:
        The number of (image, mask) pairs assembled. Images with no matching
        mask are skipped and do not consume an index.
    """
    # Local import: only `glob.glob` is imported at module level.
    from glob import escape as glob_escape

    img_out = os.path.join(out_dir, "images")
    msk_out = os.path.join(out_dir, "masks")
    os.makedirs(img_out, exist_ok=True)
    os.makedirs(msk_out, exist_ok=True)

    n = 0
    for gp in sorted(glob(os.path.join(generated_images_dir, "*"))):
        # Sub-directories (e.g. per-run folders) are not images.
        if not os.path.isfile(gp):
            continue

        stem, img_ext = os.path.splitext(os.path.basename(gp))
        if strip_prefix and stem.startswith(strip_prefix):
            mask_stem = stem[len(strip_prefix):]
        else:
            mask_stem = stem

        # Escape the stem so glob metacharacters in a file name ('[', '*',
        # '?') are matched literally instead of pairing an unrelated mask.
        pattern = os.path.join(masks_source_dir, glob_escape(mask_stem) + ".*")
        # Sorted so the chosen mask is deterministic when several extensions
        # share one stem; glob itself returns entries in arbitrary order.
        cands = sorted(c for c in glob(pattern) if os.path.isfile(c))
        if not cands:
            continue

        mask_path = cands[0]
        out_name = f"synth_{n:06d}"
        dst_img = os.path.join(img_out, out_name + img_ext)
        dst_msk = os.path.join(msk_out, out_name + os.path.splitext(mask_path)[1])
        _place(gp, dst_img, link)
        _place(mask_path, dst_msk, link)
        n += 1
    return n
| def _place(src, dst, link): | |
| if os.path.exists(dst): | |
| os.remove(dst) | |
| if link: | |
| os.symlink(os.path.abspath(src), dst) | |
| else: | |
| shutil.copy2(src, dst) | |
| # ---- command builders (printed into run.sh; run in the matching conda env) ---- | |
def segguideddiff_cmds(data_root, dataset, protocol, num_classes, in_channels,
                       img_size=256, epochs=400, sample_size=1000):
    """Build the shell commands for the segmentation-guided-diffusion baseline.

    Nothing is executed here; the strings are meant to be run in the matching
    conda env. Both commands `cd` into the repo under ``SOTA`` and read the
    train split at ``<data_root>/<dataset>/<protocol>/train/{images,masks}``.

    Returns:
        A ``(train, synth)`` tuple of command strings: the first trains the
        DDIM model, the second runs ``eval_many`` conditioned on the train
        masks to produce ``sample_size`` samples.
    """
    checkout = os.path.join(SOTA, "segmentation-guided-diffusion")
    split_dir = f"{data_root}/{dataset}/{protocol}/train"
    images = f"{split_dir}/images"
    masks = f"{split_dir}/masks"

    # Pieces shared by the train and sampling invocations.
    launcher = f"cd {checkout} && python main.py"
    model_flags = (f"--model_type DDIM --img_size {img_size} "
                   f"--num_img_channels {in_channels} --dataset {dataset}")
    guidance_flags = (f"--seg_dir {masks} --segmentation_guided "
                      f"--num_segmentation_classes {num_classes}")

    train_cmd = " ".join([
        launcher,
        "--mode train",
        model_flags,
        f"--img_dir {images}",
        guidance_flags,
        f"--num_epochs {epochs}",
    ])
    # Sampling is conditioned on masks only, so no --img_dir is passed.
    sample_cmd = " ".join([
        launcher,
        "--mode eval_many",
        model_flags,
        guidance_flags,
        f"--eval_sample_size {sample_size}",
    ])
    return train_cmd, sample_cmd
def spade_cmds(data_root, dataset, protocol, num_classes, img_size=256, niter=100):
    """Build the shell commands for the SPADE baseline.

    Nothing is executed here; the strings are meant to be run in the matching
    conda env. Labels and images come from
    ``<data_root>/<dataset>/<protocol>/train/{masks,images}``.

    Returns:
        A ``(setup, train, synth)`` tuple of command strings: ``setup`` clones
        Synchronized-BatchNorm-PyTorch and copies ``sync_batchnorm`` into the
        repo's ``models/networks``; ``train`` runs ``train.py``; ``synth`` runs
        ``test.py`` writing to ``./synth_<dataset>``.
    """
    checkout = os.path.join(SOTA, "SPADE")
    split_dir = f"{data_root}/{dataset}/{protocol}/train"
    images = f"{split_dir}/images"
    labels = f"{split_dir}/masks"

    # Flags common to train.py and test.py.
    shared_flags = (f"--name {dataset}_spade --dataset_mode custom "
                    f"--label_dir {labels} --image_dir {images} "
                    f"--label_nc {num_classes} --no_instance")
    # Images are loaded at 1.12x the crop size (truncated to an integer).
    load_size = int(img_size * 1.12)

    setup_cmd = " && ".join([
        f"cd {checkout}/models/networks",
        "git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch",
        "cp -r Synchronized-BatchNorm-PyTorch/sync_batchnorm .",
    ])
    train_cmd = (f"cd {checkout} && python train.py {shared_flags} "
                 f"--crop_size {img_size} --load_size {load_size} --niter {niter}")
    sample_cmd = (f"cd {checkout} && python test.py {shared_flags} "
                  f"--results_dir ./synth_{dataset}")
    return setup_cmd, train_cmd, sample_cmd
def controlnet_notes():
    """Return the manual setup instructions for the ControlNet baseline.

    ControlNet has no command builder in this module; the returned text lists
    the steps to carry out by hand and the conda env to use.
    """
    steps = (
        "ControlNet: download SD v1.5 (~4GB), run tool_add_control.py, "
        "write a MyDataset that colorizes integer masks to RGB hints "
        "+ triples grayscale images to 3ch, then tutorial_train.py. "
        "Run in env seggen-controlnet."
    )
    return steps