GenSeg-Baselines / code /framework /synth /generative_baselines.py
MaybeRichard's picture
code: complete eval pipeline (7 metrics + per-class + Wilcoxon) + Swin-UNet/TransUNet networks; remove backups/obsolete
1a18f22 verified
Raw
History Blame Contribute Delete
5.25 kB
"""Orchestration for the generative-augmentation SOTA baselines (category B).
These methods are compared against our SegGen method. Each runs in its OWN conda
env (see envs/) because their dependency stacks conflict with the main framework.
The shared contract: every generator must emit paired (image, mask) into
<data_root>/<dataset>/<protocol>/synth_<method>/{images,masks}/
which the unified trainer then merges into the train split via --synth_train_dir.
Kept baselines:
* SegGuidedDiff (diffusion, mask->image, medical, modern stack) -- best fit, USE-AS-IS
* SPADE (GAN, mask->image) -- ADAPT (needs sync_bn)
* ControlNet (diffusion, SD-finetune, mask->image) -- ADAPT (needs SD ckpt)
Dropped (out of scope): StyleGAN2-ADA (unconditional -- cannot emit paired masks), LDM (conflicting dependencies + requires training its own autoencoder first).
This module only BUILDS the commands + assembles the standard synth dir; it does
not import the repos (they live in separate envs). Run the printed commands in the
matching env, then call assemble_synth_dir() (env-agnostic) to lay out pairs.
"""
from __future__ import annotations
import os
import shlex
import shutil
from glob import glob
# Absolute path of the directory holding the third-party baseline repos:
# two levels above this file's directory, then "sota".
SOTA = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, "sota")
)
def assemble_synth_dir(generated_images_dir: str, masks_source_dir: str,
                       out_dir: str, strip_prefix: str = "condon_",
                       link: bool = True) -> int:
    """Pair each generated image with the real mask it was conditioned on.

    Mask-conditioned generators name outputs after the conditioning mask
    (SegGuidedDiff: 'condon_<maskname>.png'). For every regular file in
    ``generated_images_dir`` (processed in sorted path order), the mask stem is
    recovered by removing ``strip_prefix`` from the file stem when present.
    The real mask ``<mask_stem>.*`` is then looked up in ``masks_source_dir``,
    and both files are placed under ``out_dir/{images,masks}/`` as
    ``synth_<NNNNNN><original extension>``.

    Generated files with no matching mask (or whose recovered stem is empty)
    are skipped. If several masks share the stem (e.g. ``a.png`` and
    ``a.tif``), the lexicographically first path is used so the pairing is
    reproducible.

    Args:
        generated_images_dir: Directory containing the generator's outputs.
        masks_source_dir: Directory containing the real conditioning masks.
        out_dir: Destination; ``images/`` and ``masks/`` are created inside it.
        strip_prefix: Prefix the generator prepends to the mask stem
            (empty/falsy disables stripping).
        link: Symlink to the sources (True) or copy them with metadata (False).

    Returns:
        The number of (image, mask) pairs assembled.

    Note:
        Output numbering restarts at 0 on every call and existing outputs are
        only replaced when their names collide, so pairs left over from an
        earlier, larger run into the same ``out_dir`` are not removed.
    """
    # Local import: the module imports only the ``glob`` function by name.
    from glob import escape

    img_out = os.path.join(out_dir, "images")
    msk_out = os.path.join(out_dir, "masks")
    os.makedirs(img_out, exist_ok=True)
    os.makedirs(msk_out, exist_ok=True)

    n = 0
    # escape() makes '[', ']', '*', '?' in the directory name match literally.
    for gp in sorted(glob(os.path.join(escape(generated_images_dir), "*"))):
        # Sub-directories (or other non-files) are not generated images.
        if not os.path.isfile(gp):
            continue
        stem, img_ext = os.path.splitext(os.path.basename(gp))
        if strip_prefix and stem.startswith(strip_prefix):
            mask_stem = stem[len(strip_prefix):]
        else:
            mask_stem = stem
        # An empty stem would turn the pattern into ".*" and match hidden files.
        if not mask_stem:
            continue
        pattern = os.path.join(escape(masks_source_dir), escape(mask_stem) + ".*")
        # Sorted so the chosen mask does not depend on directory listing order.
        cands = sorted(p for p in glob(pattern) if os.path.isfile(p))
        if not cands:
            continue
        mask_path = cands[0]
        out_name = f"synth_{n:06d}"
        _place(gp, os.path.join(img_out, out_name + img_ext), link)
        _place(mask_path,
               os.path.join(msk_out, out_name + os.path.splitext(mask_path)[1]),
               link)
        n += 1
    return n


def _place(src, dst, link):
    """Materialize ``src`` at ``dst`` as a symlink (``link``) or a metadata-preserving copy.

    Any existing entry at ``dst`` is removed first. ``os.path.lexists`` is used
    instead of ``os.path.exists`` so that a *dangling* symlink left by an earlier
    run is removed too: ``exists`` reports False for it, after which
    ``os.symlink`` would raise ``FileExistsError`` and ``shutil.copy2`` would
    follow the stale link.
    """
    if os.path.lexists(dst):
        os.remove(dst)
    if link:
        # src may be relative to the cwd, but a relative symlink target is
        # resolved against the link's own directory -- store it absolute.
        os.symlink(os.path.abspath(src), dst)
    else:
        shutil.copy2(src, dst)
# ---- command builders (printed into run.sh; run in the matching conda env) ----
def segguideddiff_cmds(data_root, dataset, protocol, num_classes, in_channels,
                       img_size=256, epochs=400, sample_size=1000, *,
                       sota_root=None):
    """Build the SegGuidedDiff ``(train, synth)`` shell commands.

    Both commands ``cd`` into ``<sota_root>/segmentation-guided-diffusion`` and
    run its ``main.py``. The first uses ``--mode train`` on
    ``<data_root>/<dataset>/<protocol>/train/{images,masks}``. The second uses
    ``--mode eval_many`` to sample ``sample_size`` mask-guided images
    conditioned on the same train masks.

    Every interpolated value is shell-quoted with ``shlex.quote``, so paths
    containing spaces or shell metacharacters survive. Plain values (simple
    paths, ints) are emitted unchanged. Because quoting disables the shell's
    own ``~``/``$VAR`` expansion, ``data_root`` (and ``sota_root``) are
    expanded here instead.

    Args:
        data_root: Root directory of the prepared datasets.
        dataset: Dataset name; a path component and the ``--dataset`` value.
        protocol: Protocol sub-directory under the dataset.
        num_classes: Passed as ``--num_segmentation_classes``.
        in_channels: Passed as ``--num_img_channels``.
        img_size: Passed as ``--img_size``.
        epochs: Passed as ``--num_epochs`` (train command only).
        sample_size: Passed as ``--eval_sample_size`` (synth command only).
        sota_root: Directory holding the baseline repos; ``None`` means the
            module-level ``SOTA`` path.

    Returns:
        A ``(train, synth)`` tuple of command strings.
    """
    def q(value):
        # POSIX-shell quoting; strings with only safe characters come back as-is.
        return shlex.quote(str(value))

    base = SOTA if sota_root is None else os.path.expanduser(str(sota_root))
    repo = os.path.join(base, "segmentation-guided-diffusion")
    root = os.path.expandvars(os.path.expanduser(str(data_root)))
    img_dir = f"{root}/{dataset}/{protocol}/train/images"
    seg_dir = f"{root}/{dataset}/{protocol}/train/masks"
    # NOTE(review): upstream SegGuidedDiff documents split sub-folders under
    # --img_dir/--seg_dir (e.g. <seg_dir>/all/train); confirm this
    # train/{images,masks} layout matches what its loader expects.
    common = (f"--model_type DDIM --img_size {q(img_size)} "
              f"--num_img_channels {q(in_channels)} --dataset {q(dataset)}")
    guidance = (f"--seg_dir {q(seg_dir)} --segmentation_guided "
                f"--num_segmentation_classes {q(num_classes)}")
    train = (f"cd {q(repo)} && python main.py --mode train {common} "
             f"--img_dir {q(img_dir)} {guidance} --num_epochs {q(epochs)}")
    synth = (f"cd {q(repo)} && python main.py --mode eval_many {common} "
             f"{guidance} --eval_sample_size {q(sample_size)}")
    return train, synth
def spade_cmds(data_root, dataset, protocol, num_classes, img_size=256, niter=100,
               *, sota_root=None):
    """Build the SPADE ``(setup, train, synth)`` shell commands.

    * ``setup`` vendors ``sync_batchnorm`` into ``SPADE/models/networks``. The
      clone is skipped when the checkout already exists, so re-running
      ``run.sh`` does not fail on git's "destination path already exists".
    * ``train`` runs ``train.py`` with the ``custom`` dataset mode on
      ``<data_root>/<dataset>/<protocol>/train/{images,masks}``.
    * ``synth`` runs ``test.py`` on the same masks, writing to
      ``./synth_<dataset>`` inside the repo.

    ``synth`` passes ``--crop_size``/``--load_size`` equal to ``img_size``.
    Without them, test.py falls back to its own size defaults (256,
    NOTE(review): confirm against SPADE's options), so any ``img_size != 256``
    would yield images whose size no longer matches the conditioning masks.
    Setting load == crop also avoids a test-time crop.

    Every interpolated value is shell-quoted with ``shlex.quote``. Plain values
    are emitted unchanged. ``~``/``$VAR`` in ``data_root``/``sota_root`` are
    expanded here because quoting disables the shell's own expansion.

    Args:
        data_root: Root directory of the prepared datasets.
        dataset: Dataset name; a path component and part of the run name.
        protocol: Protocol sub-directory under the dataset.
        num_classes: Passed as ``--label_nc``.
        img_size: Crop size for training and generation.
        niter: Passed as ``--niter`` (train command only).
        sota_root: Directory holding the baseline repos; ``None`` means the
            module-level ``SOTA`` path.

    Returns:
        A ``(setup, train, synth)`` tuple of command strings.
    """
    def q(value):
        # POSIX-shell quoting; strings with only safe characters come back as-is.
        return shlex.quote(str(value))

    base = SOTA if sota_root is None else os.path.expanduser(str(sota_root))
    repo = os.path.join(base, "SPADE")
    networks_dir = os.path.join(repo, "models", "networks")
    root = os.path.expandvars(os.path.expanduser(str(data_root)))
    img_dir = f"{root}/{dataset}/{protocol}/train/images"
    lab_dir = f"{root}/{dataset}/{protocol}/train/masks"
    run_name = f"{dataset}_spade"
    results_dir = f"./synth_{dataset}"
    sbn = "Synchronized-BatchNorm-PyTorch"

    # The parentheses keep a failed `cd` from falling through to `git clone`
    # (`a && b || c` would run c whenever a fails).
    setup = (f"cd {q(networks_dir)} && "
             f"( [ -d {sbn} ] || git clone https://github.com/vacancy/{sbn} ) && "
             f"cp -r {sbn}/sync_batchnorm .")
    common = (f"--name {q(run_name)} --dataset_mode custom "
              f"--label_dir {q(lab_dir)} --image_dir {q(img_dir)} "
              f"--label_nc {q(num_classes)} --no_instance")
    # Training loads ~12% larger and crops back to img_size.
    train = (f"cd {q(repo)} && python train.py {common} "
             f"--crop_size {q(img_size)} --load_size {q(int(img_size * 1.12))} "
             f"--niter {q(niter)}")
    synth = (f"cd {q(repo)} && python test.py {common} "
             f"--crop_size {q(img_size)} --load_size {q(img_size)} "
             f"--results_dir {q(results_dir)}")
    return setup, train, synth
def controlnet_notes():
    """Return manual setup notes for the ControlNet baseline.

    Unlike the other baselines, no command strings are built here; the
    returned one-line checklist describes the manual steps instead.
    """
    steps = (
        "download SD v1.5 (~4GB)",
        "run tool_add_control.py",
        "write a MyDataset that colorizes integer masks to RGB hints "
        "+ triples grayscale images to 3ch",
        "then tutorial_train.py",
    )
    return "ControlNet: " + ", ".join(steps) + ". Run in env seggen-controlnet."