GenSeg-Baselines / code /scripts /p1 /p1_master.py
MaybeRichard's picture
code: complete eval pipeline (7 metrics + per-class + Wilcoxon) + Swin-UNet/TransUNet networks; remove backups/obsolete
1a18f22 verified
Raw
History Blame Contribute Delete
7.16 kB
"""P1 master orchestrator: DAG scheduler over GPU0-5 for the backbone bake-off.
Phases: A) 8 generators (4 backbones x 2 datasets, amortized, 50k steps)
B) 16 sampling jobs (per gen x N in {50,100}, mask_aug n_per_mask=4)
C) 60 downstream seg runs (real + 4 backbones) x 2 ds x 2 N x 3 seeds
Single GPU per job (no DDP needed: 84 independent jobs). Retry-once on failure.
Resumable (skips done outputs). Rolling aggregate -> logs/p1master/p1_results.json."""
import os, sys, time, json, subprocess, statistics as st
# --- Host-specific locations (edit these when moving to another machine) ---
ROOT = "/home/wzhang/LSC/Code/NPJ"                              # repo root; every job runs with cwd=ROOT
DR = "/home/wzhang/LSC/Dataset/Segmentation/processed_unified"  # value passed as --data_root to every job
PY = "/opt/anaconda3/envs/seggen/bin/python"                    # interpreter used to launch every job
GPUS = list(range(6))                                           # physical GPU ids 0..5, one job per GPU
LOGD = os.path.join(ROOT, "logs", "p1master")                   # status.md, per-job logs, p1_results.json
# Make the orchestrator's own relative paths resolve against the repo root.
os.chdir(ROOT)
os.makedirs(LOGD, exist_ok=True)
def log(m):
    """Timestamp ``m``, append it as one line to LOGD/status.md, and echo it to stdout."""
    stamp = time.strftime('%F %T')
    entry = f"[{stamp}] {m}"
    status_file = os.path.join(LOGD, "status.md")
    with open(status_file, "a") as fh:
        fh.write(f"{entry}\n")
    print(entry, flush=True)
# Dataset key -> (dataset directory name, split protocol, total count used as the
# denominator when converting N real images into --train_fraction = N / total).
DSETS = dict(
    isic=("medsegdb_isic2018", "holdout", 2582),
    kvasir=("kvasir_seg", "official", 800),
)
BKS = "jit pixelgen deco pixeldit".split()  # generator backbones being compared
NS = [50, 100]                              # numbers of real training images (N)
SEEDS = list(range(3))                      # downstream seeds 0, 1, 2
jobs = dict()  # job id -> job record (filled by add); insertion order is launch priority
def add(jid, cmd, deps=(), done_path=None, done_min=1):
    """Register job ``jid`` in the global ``jobs`` table, in the "pending" state.

    cmd       -- shell command line the scheduler launches
    deps      -- ids of jobs that must reach state "done" first (copied into a list)
    done_path -- file or directory whose presence marks the job complete;
                 None means the job is never considered already done
    done_min  -- when done_path is a directory, the minimum number of entries required
    """
    record = dict(cmd=cmd, deps=list(deps), done_path=done_path, done_min=done_min)
    record.update(state="pending", tries=0, gpu=None)
    jobs[jid] = record
# Phase A: one generator checkpoint per (backbone, dataset), trained on the full
# training split (--train_fraction 1.0) and capped at --max_steps 50000.
# A generator job counts as done once its checkpoint file exists.
for backbone in BKS:
    for key, (ds_name, protocol, _total) in DSETS.items():
        ckpt_rel = f"pretrained/pixdiff/p1_{backbone}_{key}.pt"
        train_args = [
            PY, "-m", "framework.synth.pixdiff.train",
            "--data_root", DR,
            "--dataset", ds_name,
            "--protocol", protocol,
            "--backbone", backbone,
            "--img_size", "256",
            "--batch_size", "16",
            "--epochs", "100000",
            "--max_steps", "50000",
            "--lr", "1e-4",
            "--amp", "bf16",
            "--train_fraction", "1.0",
            "--fraction_seed", "0",
            "--out_ckpt", ckpt_rel,
            "--log_interval", "500",
        ]
        add(f"gen_{backbone}_{key}", " ".join(train_args),
            done_path=os.path.join(ROOT, ckpt_rel))
# Phase B: for each trained generator and each N, sample a synthetic set with
# --n_per_mask 4 and --mask_aug. It uses the same --train_fraction (N / total) and
# --fraction_seed 0 as the downstream runs, presumably selecting the same N-image
# subset -- confirm against the sampler. The job is done once <out_dir>/images
# holds at least N * 4 entries (assumes one file per sample).
for backbone in BKS:
    for key, (ds_name, protocol, total) in DSETS.items():
        ckpt_rel = f"pretrained/pixdiff/p1_{backbone}_{key}.pt"
        for n_real in NS:
            out_dir = f"{DR}/{ds_name}/{protocol}/synth_p1_{backbone}_{key}_f{n_real}"
            sample_args = [
                PY, "-m", "framework.synth.pixdiff.sample",
                "--ckpt", ckpt_rel,
                "--data_root", DR,
                "--dataset", ds_name,
                "--protocol", protocol,
                "--train_fraction", str(n_real / total),
                "--fraction_seed", "0",
                "--n_per_mask", "4",
                "--mask_aug",
                "--num_steps", "50",
                "--out_dir", out_dir,
            ]
            add(f"samp_{backbone}_{key}_N{n_real}", " ".join(sample_args),
                deps=[f"gen_{backbone}_{key}"],
                done_path=os.path.join(out_dir, "images"), done_min=n_real * 4)
# Phase C: downstream
def mpath(exp, ds, proto, S):
    """Absolute path of the metrics.json expected for one downstream run.

    Its existence marks a seg job done. It is presumably written by
    framework/test.py under results/<exp>/<ds>_<proto>/unet/seed<S>/.
    """
    relative = f"results/{exp}/{ds}_{proto}/unet/seed{S}/metrics.json"
    return os.path.join(ROOT, relative)
def seg_cmd(ds, proto, f, exp, S, synth=None):
    """Build the shell command for one downstream U-Net run: train, then test.

    The two steps are joined with ``&&``, so evaluation only runs if training
    exits successfully. When ``synth`` is truthy it is passed to the training
    step only, as ``--synth_train_dir``.
    """
    shared = (f"--data_root {DR} --dataset {ds} --protocol {proto} "
              "--arch unet --encoder resnet50 --aug standard")
    train = (f"{PY} framework/train.py {shared} --epochs 400 --train_fraction {f} "
             f"--fraction_seed 0 --exp_name {exp} --amp bf16 --seed {S}")
    if synth:
        train = f"{train} --synth_train_dir {synth}"
    evaluate = f"{PY} framework/test.py {shared} --exp_name {exp} --seed {S}"
    return " && ".join((train, evaluate))
# Phase C: for every dataset, N and seed, one real-only baseline run plus one run
# per backbone that also consumes that backbone's Phase B synthetic set.
for key, (ds_name, protocol, total) in DSETS.items():
    for n_real in NS:
        frac = n_real / total
        for arm in ["real"] + BKS:
            tag = f"{arm}_{key}_N{n_real}"
            exp = f"p1_{tag}"
            if arm == "real":
                synth_dir, needs = None, []
            else:
                synth_dir = f"{DR}/{ds_name}/{protocol}/synth_p1_{arm}_{key}_f{n_real}"
                needs = [f"samp_{tag}"]  # wait for the matching sampling job
            for seed in SEEDS:
                add(f"seg_{tag}_s{seed}",
                    seg_cmd(ds_name, protocol, frac, exp, seed, synth=synth_dir),
                    deps=needs, done_path=mpath(exp, ds_name, protocol, seed))
def is_done(j):
    """Report whether job record ``j`` already has its output on disk.

    A falsy ``done_path`` never counts as done. A directory counts once it holds
    at least ``done_min`` entries; an unreadable directory counts as not done.
    Any other existing path counts as done.
    """
    target = j["done_path"]
    if not target:
        return False
    if os.path.isdir(target):
        try:
            n_entries = len(os.listdir(target))
        except OSError:
            return False
        return n_entries >= j["done_min"]
    return os.path.exists(target)
def aggregate():
    """Roll per-seed downstream metrics up into LOGD/p1_results.json.

    For every (dataset, N, arm) combination, where arm is "real" or a backbone in
    BKS, each seed's metrics.json (see mpath) is read. The entry records the mean
    IoU/Dice over the seeds that produced a readable result, plus the per-seed
    values. A seed whose metrics file does not exist yet is skipped silently. An
    unreadable or malformed file is skipped and logged. Combinations with no
    readable seed are omitted.

    This runs periodically while jobs are still in flight, so the JSON reflects
    partial progress. It is written to a temp file and then renamed into place,
    so a concurrent reader never sees a half-written file.
    """
    res = {}
    for dk, (ds, proto, _tot) in DSETS.items():
        for N in NS:
            for arm in ["real"] + BKS:
                exp = f"p1_{arm}_{dk}_N{N}"
                ious, dices = [], []
                for S in SEEDS:
                    mp = mpath(exp, ds, proto, S)
                    try:
                        with open(mp) as fh:
                            m = json.load(fh)["metrics"]
                        iou, dice = float(m["iou_mean"]), float(m["dice_mean"])
                    except FileNotFoundError:
                        continue  # this seed has not produced metrics yet
                    except (OSError, ValueError, KeyError, TypeError) as e:
                        log(f"AGGREGATE skip {mp}: {type(e).__name__}: {e}")
                        continue
                    # Append only after BOTH values were read, so the two lists
                    # always stay aligned (and dices is never empty when ious isn't).
                    ious.append(iou)
                    dices.append(dice)
                if ious:
                    res[f"{dk}_N{N}_{arm}"] = {
                        "iou_mean": sum(ious) / len(ious), "dice_mean": sum(dices) / len(dices),
                        "n_seeds": len(ious), "iou_seeds": ious, "dice_seeds": dices}
    out_path = os.path.join(LOGD, "p1_results.json")
    tmp_path = out_path + ".tmp"
    with open(tmp_path, "w") as fh:
        json.dump(res, fh, indent=2)
    os.replace(tmp_path, out_path)  # atomic on POSIX within the same directory
# Resume support: any job whose output already exists starts out as done.
for job in jobs.values():
    if is_done(job):
        job["state"] = "done"
def deps_done(j):
    """True when every dependency listed in job record ``j`` is in state "done"."""
    return not any(jobs[dep]["state"] != "done" for dep in j["deps"])
# Scheduler state: GPU id -> (job id, Popen handle, open log file) for live jobs,
# plus the pool of idle GPU ids.
running = {}
free = set(GPUS)
MAXTRIES = 2  # launch attempts per job, i.e. one retry after a failure
n_predone = sum(j["state"] == "done" for j in jobs.values())
log(f"START {len(jobs)} jobs on GPUs {GPUS} ({n_predone} pre-done)")
last_summary = 0  # time of the last SUMMARY line; 0 forces one on the first pass
# Main scheduling loop, polled every 10 s:
#   1. cascade failures to dependents that can now never run,
#   2. stop once every job is done or failed,
#   3. hand each idle GPU to the earliest-registered runnable pending job,
#   4. reap finished processes and decide done / retry / failed,
#   5. every 300 s log a state summary and refresh the aggregate JSON.
while True:
    # A pending job with a failed dependency can never start. Left "pending", it
    # would keep the termination check below from ever passing (infinite loop).
    # Jobs are registered in DAG order, so one pass cascades gen -> samp -> seg.
    for jid, j in jobs.items():
        if j["state"] != "pending":
            continue
        dead = [d for d in j["deps"] if jobs[d]["state"] == "failed"]
        if dead:
            j["state"] = "failed"
            log(f"BLOCKED {jid} (dependency {dead[0]} failed)")
    if all(j["state"] in ("done", "failed") for j in jobs.values()):
        break
    for jid, j in jobs.items():
        if not free:
            break
        if j["state"] != "pending" or not deps_done(j):
            continue
        if is_done(j):  # output already on disk: mark done without launching
            j["state"] = "done"
            continue
        g = free.pop()
        # Pin the job to one physical GPU (PCI bus order) and keep CPU threads modest.
        env = dict(os.environ, CUDA_DEVICE_ORDER="PCI_BUS_ID",
                   CUDA_VISIBLE_DEVICES=str(g), TORCHDYNAMO_DISABLE="1",
                   PYTHONPATH=".", OMP_NUM_THREADS="4")
        lf = open(os.path.join(LOGD, jid + ".log"), "a")
        p = subprocess.Popen(j["cmd"], shell=True, env=env, stdout=lf,
                             stderr=subprocess.STDOUT, cwd=ROOT)
        running[g] = (jid, p, lf)
        j["state"] = "running"
        j["gpu"] = g
        j["tries"] += 1
        log(f"LAUNCH {jid} GPU{g} try{j['tries']}")
    for g, (jid, p, lf) in list(running.items()):
        rc = p.poll()
        if rc is None:
            continue
        lf.close()
        del running[g]
        free.add(g)
        j = jobs[jid]
        # Success is judged by the output on disk, not by the exit code alone.
        if is_done(j):
            j["state"] = "done"; log(f"DONE {jid} rc={rc}")
        elif j["tries"] < MAXTRIES:
            j["state"] = "pending"; log(f"RETRY {jid} rc={rc}")
        else:
            j["state"] = "failed"; log(f"FAILED {jid} rc={rc}")
    if time.time() - last_summary > 300:
        cnt = {s: sum(1 for j in jobs.values() if j["state"] == s)
               for s in ("done", "running", "pending", "failed")}
        log(f"SUMMARY {cnt} | running={sorted(j['gpu'] for j in jobs.values() if j['state']=='running')}")
        aggregate(); last_summary = time.time()
    time.sleep(10)
# Final roll-up once every job has ended as done or failed. Failed job ids are
# listed explicitly; otherwise "ALL DONE" reads as full success even when some
# arms are missing from p1_results.json.
aggregate()
failed_ids = [jid for jid, j in jobs.items() if j["state"] == "failed"]
if failed_ids:
    log(f"FAILED_JOBS {len(failed_ids)}: {' '.join(failed_ids)}")
log("ALL DONE")
print("P1_MASTER_DONE", flush=True)