"""FD-lever ablation on the recommended P2 base (JiT): refine p1_jit_{ds} with FD loss -> sample -> downstream. Compares +JiT-FD vs +JiT(native) vs real. 18 jobs on GPU0-5.""" import os, time, json, subprocess ROOT = "/home/wzhang/LSC/Code/NPJ" DR = "/home/wzhang/LSC/Dataset/Segmentation/processed_unified" PY = "/opt/anaconda3/envs/seggen/bin/python" GPUS = [0, 1, 2, 3, 4, 5] os.chdir(ROOT) LOGD = os.path.join(ROOT, "logs", "fdlever") os.makedirs(LOGD, exist_ok=True) def log(m): line = f"[{time.strftime('%F %T')}] {m}" open(os.path.join(LOGD, "status.md"), "a").write(line + "\n"); print(line, flush=True) DSETS = {"isic": ("medsegdb_isic2018", "holdout", 2582), "kvasir": ("kvasir_seg", "official", 800)} NS = [50, 100]; SEEDS = [0, 1, 2] jobs = {} def add(jid, cmd, deps=(), done_path=None, done_min=1): jobs[jid] = {"cmd": cmd, "deps": list(deps), "done_path": done_path, "done_min": done_min, "state": "pending", "tries": 0, "gpu": None} for dk, (ds, proto, tot) in DSETS.items(): base = f"pretrained/pixdiff/p1_jit_{dk}.pt" out = f"pretrained/pixdiff/p1_jitfd_{dk}.pt" cmd = (f"{PY} -m framework.synth.pixdiff.train_fd --base_ckpt {base} --data_root {DR} " f"--dataset {ds} --protocol {proto} --train_fraction 1.0 --epochs 150 --batch_size 16 " f"--amp bf16 --fd_weight 0.5 --out_ckpt {out} --log_interval 100") add(f"genfd_{dk}", cmd, done_path=os.path.join(ROOT, out)) for N in NS: f = N / tot sd = f"{DR}/{ds}/{proto}/synth_p1_jitfd_{dk}_f{N}" cmd = (f"{PY} -m framework.synth.pixdiff.sample --ckpt {out} --data_root {DR} --dataset {ds} " f"--protocol {proto} --train_fraction {f} --fraction_seed 0 --n_per_mask 4 --mask_aug " f"--num_steps 50 --out_dir {sd}") add(f"samp_jitfd_{dk}_N{N}", cmd, deps=[f"genfd_{dk}"], done_path=os.path.join(sd, "images"), done_min=N * 4) for S in SEEDS: exp = f"p1_jitfd_{dk}_N{N}" mp = os.path.join(ROOT, f"results/{exp}/{ds}_{proto}/unet/seed{S}/metrics.json") cmd = (f"{PY} framework/train.py --data_root {DR} --dataset {ds} --protocol {proto} --arch unet " f"--encoder resnet50 --aug standard --epochs 400 --train_fraction {f} --fraction_seed 0 " f"--synth_train_dir {sd} --exp_name {exp} --amp bf16 --seed {S} " f"&& {PY} framework/test.py --data_root {DR} --dataset {ds} --protocol {proto} --arch unet " f"--encoder resnet50 --aug standard --exp_name {exp} --seed {S}") add(f"seg_jitfd_{dk}_N{N}_s{S}", cmd, deps=[f"samp_jitfd_{dk}_N{N}"], done_path=mp) def is_done(j): p = j["done_path"] if not p or not os.path.exists(p): return False if os.path.isdir(p): try: return len(os.listdir(p)) >= j["done_min"] except OSError: return False return True def aggregate(): res = {} for dk, (ds, proto, tot) in DSETS.items(): for N in NS: exp = f"p1_jitfd_{dk}_N{N}"; ious = []; dices = [] for S in SEEDS: mp = f"results/{exp}/{ds}_{proto}/unet/seed{S}/metrics.json" if os.path.exists(mp): try: m = json.load(open(mp))["metrics"]; ious.append(m["iou_mean"]); dices.append(m["dice_mean"]) except Exception: pass if ious: res[f"{dk}_N{N}_jitfd"] = {"iou_mean": sum(ious) / len(ious), "dice_mean": sum(dices) / len(dices), "n_seeds": len(ious), "iou_seeds": ious} json.dump(res, open(os.path.join(LOGD, "fd_results.json"), "w"), indent=2) for jid, j in jobs.items(): if is_done(j): j["state"] = "done" def deps_done(j): return all(jobs[d]["state"] == "done" for d in j["deps"]) running = {}; free = set(GPUS); last = 0 log(f"START {len(jobs)} jobs on {GPUS} ({sum(1 for j in jobs.values() if j['state']=='done')} pre-done)") while True: if all(j["state"] in ("done", "failed") for j in jobs.values()): break for jid, j in jobs.items(): if not free: break if j["state"] == "pending" and deps_done(j): if is_done(j): j["state"] = "done"; continue g = free.pop() env = dict(os.environ, CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES=str(g), TORCHDYNAMO_DISABLE="1", PYTHONPATH=".", OMP_NUM_THREADS="4") lf = open(os.path.join(LOGD, jid + ".log"), "a") p = subprocess.Popen(j["cmd"], shell=True, env=env, stdout=lf, stderr=subprocess.STDOUT, cwd=ROOT) running[g] = (jid, p, lf); j["state"] = "running"; j["gpu"] = g; j["tries"] += 1 log(f"LAUNCH {jid} GPU{g} try{j['tries']}") for g, (jid, p, lf) in list(running.items()): rc = p.poll() if rc is None: continue lf.close(); del running[g]; free.add(g); j = jobs[jid] if is_done(j): j["state"] = "done"; log(f"DONE {jid}") elif j["tries"] < 2: j["state"] = "pending"; log(f"RETRY {jid} rc={rc}") else: j["state"] = "failed"; log(f"FAILED {jid} rc={rc}") if time.time() - last > 180: cnt = {s: sum(1 for j in jobs.values() if j["state"] == s) for s in ("done", "running", "pending", "failed")} log(f"SUMMARY {cnt}"); aggregate(); last = time.time() time.sleep(10) aggregate(); log("ALL DONE"); print("FD_LEVER_DONE", flush=True)