"""Orchestration for the generative-augmentation SOTA baselines (category B). These methods are compared against our SegGen method. Each runs in its OWN conda env (see envs/) because their dependency stacks conflict with the main framework. The shared contract: every generator must emit paired (image, mask) into ///synth_/{images,masks}/ which the unified trainer then merges into the train split via --synth_train_dir. Kept baselines: * SegGuidedDiff (diffusion, mask->image, medical, modern stack) -- best fit, USE-AS-IS * SPADE (GAN, mask->image) -- ADAPT (needs sync_bn) * ControlNet (diffusion, SD-finetune, mask->image) -- ADAPT (needs SD ckpt) Dropped (per scoping): StyleGAN2-ADA (no masks), LDM (dep hell + AE training). This module only BUILDS the commands + assembles the standard synth dir; it does not import the repos (they live in separate envs). Run the printed commands in the matching env, then call assemble_synth_dir() (env-agnostic) to lay out pairs. """ from __future__ import annotations import os import shutil from glob import glob SOTA = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "sota")) def assemble_synth_dir(generated_images_dir: str, masks_source_dir: str, out_dir: str, strip_prefix: str = "condon_", link: bool = True) -> int: """Pair each generated image with the real mask it was conditioned on. Mask-conditioned generators name outputs after the conditioning mask (SegGuidedDiff: 'condon_.png'). We recover the mask name, copy/link the matching real mask, and place both under out_dir/{images,masks}/. Returns the number of pairs assembled. """ img_out = os.path.join(out_dir, "images") msk_out = os.path.join(out_dir, "masks") os.makedirs(img_out, exist_ok=True) os.makedirs(msk_out, exist_ok=True) n = 0 for gp in sorted(glob(os.path.join(generated_images_dir, "*"))): base = os.path.basename(gp) stem = os.path.splitext(base)[0] if strip_prefix and stem.startswith(strip_prefix): mask_stem = stem[len(strip_prefix):] else: mask_stem = stem cands = glob(os.path.join(masks_source_dir, mask_stem + ".*")) if not cands: continue out_name = f"synth_{n:06d}" dst_img = os.path.join(img_out, out_name + os.path.splitext(base)[1]) dst_msk = os.path.join(msk_out, out_name + os.path.splitext(cands[0])[1]) _place(gp, dst_img, link) _place(cands[0], dst_msk, link) n += 1 return n def _place(src, dst, link): if os.path.exists(dst): os.remove(dst) if link: os.symlink(os.path.abspath(src), dst) else: shutil.copy2(src, dst) # ---- command builders (printed into run.sh; run in the matching conda env) ---- def segguideddiff_cmds(data_root, dataset, protocol, num_classes, in_channels, img_size=256, epochs=400, sample_size=1000): repo = os.path.join(SOTA, "segmentation-guided-diffusion") img_dir = f"{data_root}/{dataset}/{protocol}/train/images" seg_dir = f"{data_root}/{dataset}/{protocol}/train/masks" train = (f"cd {repo} && python main.py --mode train --model_type DDIM " f"--img_size {img_size} --num_img_channels {in_channels} --dataset {dataset} " f"--img_dir {img_dir} --seg_dir {seg_dir} --segmentation_guided " f"--num_segmentation_classes {num_classes} --num_epochs {epochs}") synth = (f"cd {repo} && python main.py --mode eval_many --model_type DDIM " f"--img_size {img_size} --num_img_channels {in_channels} --dataset {dataset} " f"--seg_dir {seg_dir} --segmentation_guided " f"--num_segmentation_classes {num_classes} --eval_sample_size {sample_size}") return train, synth def spade_cmds(data_root, dataset, protocol, num_classes, img_size=256, niter=100): repo = os.path.join(SOTA, "SPADE") img_dir = f"{data_root}/{dataset}/{protocol}/train/images" lab_dir = f"{data_root}/{dataset}/{protocol}/train/masks" setup = (f"cd {repo}/models/networks && " f"git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch && " f"cp -r Synchronized-BatchNorm-PyTorch/sync_batchnorm .") train = (f"cd {repo} && python train.py --name {dataset}_spade --dataset_mode custom " f"--label_dir {lab_dir} --image_dir {img_dir} --label_nc {num_classes} " f"--no_instance --crop_size {img_size} --load_size {int(img_size*1.12)} --niter {niter}") synth = (f"cd {repo} && python test.py --name {dataset}_spade --dataset_mode custom " f"--label_dir {lab_dir} --image_dir {img_dir} --label_nc {num_classes} " f"--no_instance --results_dir ./synth_{dataset}") return setup, train, synth def controlnet_notes(): return ("ControlNet: download SD v1.5 (~4GB), run tool_add_control.py, write a " "MyDataset that colorizes integer masks to RGB hints + triples grayscale " "images to 3ch, then tutorial_train.py. Run in env seggen-controlnet.")