GenSeg-Baselines / code /scripts /hf_upload_baselines.py
MaybeRichard's picture
Upload folder using huggingface_hub
057ec4b verified
Raw
History Blame Contribute Delete
7.97 kB
"""Upload the BASELINE snapshot (code + confirmed results + curated weights) to a
private HF model repo. Work-in-progress SegGen code is excluded. Run on the a100 host (where code/results/weights live).
# dry-run: print manifest, upload nothing
HF_TOKEN=... python scripts/hf_upload_baselines.py --dry-run
# real:
HF_TOKEN=... python scripts/hf_upload_baselines.py
Curation: code = framework (minus synth/__pycache__) + scripts + envs/{seggen,nnunet,umamba}.yml.
results = all metrics.json + summary.{html,csv,md,tex} + efficiency.md. weights = best seed per
(dataset,arch) for framework + best fold per dataset for nnU-Net/U-Mamba.
"""
import os, glob, json, argparse, sys
# Target Hugging Face model repository (main() creates it as private).
REPO = "MaybeRichard/GenSeg-Baselines"
# Project root: framework/, scripts/, envs/, results/ and nnunet_workspace/ are read from here.
ROOT = "/home/wzhang/LSC/Code/NPJ"
# Cell name under results/baselines/ -> nnU-Net "DatasetNNN" id.
# Only these cells get nnU-Net / U-Mamba checkpoints curated (iteration follows this order).
NNRAW_IDS = {
    "cvc_clinicdb_official": 1,
    "kvasir_seg_official": 2,
    "fives_official": 3,
    "refuge2_official": 4,
    "busi_fold01": 5,
    "idridd_segmentation_fold01": 6,
    "acdc_png_official": 7,
    "pannuke_semantic_fold01": 8,
    "medsegdb_isic2018_holdout": 9,
    "medsegdb_kits19_fold01": 10,
}
def sz(p):
    """Return the size of *p* in bytes, or 0 when it cannot be stat'ed (missing, no permission, ...)."""
    try:
        size = os.path.getsize(p)
    except OSError:
        size = 0
    return size
def human(n):
    """Format a byte count: "X.XX GB" from 1 GiB upwards, otherwise "X.X MB" (binary units)."""
    gib = 1024 ** 3
    mib = 1024 ** 2
    if n >= gib:
        return f"{n / gib:.2f} GB"
    return f"{n / mib:.1f} MB"
def _read_dice(metrics_json):
    """Return ``metrics.dice_mean`` from a metrics.json file as a float, or None.

    A missing ``dice_mean`` key counts as 0.0 (the historical default). None means
    the run cannot be ranked: the file is unreadable / not valid JSON, has an
    unexpected structure, or the value is not a finite number (a NaN would otherwise
    make every ``>`` / ``max`` comparison against it meaningless).
    """
    import math

    try:
        with open(metrics_json, encoding="utf-8") as fh:
            data = json.load(fh)
    except (OSError, ValueError):  # ValueError covers JSONDecodeError / UnicodeDecodeError
        return None
    metrics = data.get("metrics", {}) if isinstance(data, dict) else None
    if not isinstance(metrics, dict):
        return None
    try:
        dice = float(metrics.get("dice_mean", 0))
    except (TypeError, ValueError, OverflowError):  # e.g. null or a string value
        return None
    return dice if math.isfinite(dice) else None


def curate_framework(root=None):
    """Pick the best seed per (cell, arch) for the in-house framework models.

    Scans ``<root>/results/baselines/<cell>/<arch>/seed*/metrics.json`` and, for every
    seed that has a sibling ``best.pth``, keeps the one with the highest Dice.
    nnU-Net / U-Mamba runs are skipped; their checkpoints are curated by ``curate_nn``.

    Args:
        root: project root; defaults to the module-level ``ROOT``.

    Returns:
        dict mapping ``(cell, arch)`` -> ``(dice, local best.pth path, path_in_repo)``.
    """
    root = ROOT if root is None else root
    best = {}
    pattern = f"{glob.escape(root)}/results/baselines/*/*/seed*/metrics.json"
    # sorted(): glob order is arbitrary; sorting makes tie-breaking (first seed wins) reproducible.
    for mj in sorted(glob.glob(pattern)):
        parts = mj.split("/")
        cell, arch = parts[-4], parts[-3]
        if arch in ("nnunet", "umamba"):
            continue  # handled separately
        dice = _read_dice(mj)
        if dice is None:
            continue  # unreadable / non-numeric metrics: this seed cannot be ranked
        pth = os.path.join(os.path.dirname(mj), "best.pth")
        if not os.path.isfile(pth):
            continue
        key = (cell, arch)
        if key not in best or dice > best[key][0]:
            best[key] = (dice, pth, f"weights/framework/{cell}/{arch}.pth")
    return best


def curate_nn(method, results_dir, root=None, cell_ids=None):
    """Pick one ``checkpoint_best.pth`` per dataset for nnU-Net-style methods.

    Folds are ranked by the Dice recorded in
    ``<root>/results/baselines/<cell>/<method>/seed<fold>/metrics.json`` (ties -> higher
    fold). The best-ranked fold that actually has a checkpoint under
    ``<results_dir>/DatasetNNN_*/**/fold_<fold>/`` is used, so a missing checkpoint for
    the top fold falls back to the next-best fold instead of dropping the dataset.

    Args:
        method: method directory name under results/baselines (e.g. "nnunet").
        results_dir: nnU-Net results directory holding the ``DatasetNNN_*`` folders.
        root: project root; defaults to the module-level ``ROOT``.
        cell_ids: mapping cell name -> nnU-Net dataset id; defaults to ``NNRAW_IDS``.

    Returns:
        dict mapping cell -> ``(local checkpoint path, path_in_repo)``; cells without
        any scored fold or any checkpoint are omitted.
    """
    root = ROOT if root is None else root
    cell_ids = NNRAW_IDS if cell_ids is None else cell_ids
    out = {}
    for cell, did in cell_ids.items():
        folds = []
        pattern = (f"{glob.escape(root)}/results/baselines/"
                   f"{glob.escape(cell)}/{glob.escape(method)}/seed*/metrics.json")
        for mj in glob.glob(pattern):
            suffix = os.path.basename(os.path.dirname(mj))[len("seed"):]
            if not (suffix.isascii() and suffix.isdigit()):
                continue  # e.g. "seed0_old": not a plain fold index
            dice = _read_dice(mj)
            # Unreadable metrics rank last but remain a last-resort candidate.
            folds.append((float("-inf") if dice is None else dice, int(suffix)))
        for _, fold in sorted(folds, reverse=True):
            cks = glob.glob(
                f"{glob.escape(results_dir)}/Dataset{did:03d}_*/**/fold_{fold}/checkpoint_best.pth",
                recursive=True,
            )
            if cks:
                # Several trainer/plan dirs may match; keep the largest file (sorted for stable ties).
                ck = max(sorted(cks), key=sz)
                out[cell] = (ck, f"weights/{method}/{cell}_fold{fold}.pth")
                break
    return out
def list_code(root=None):
    """List the local files that make up the code part of the snapshot.

    - ``<root>/framework``: everything except directories whose name starts with
      "synth" (the generative SegGen work, excluded conservatively at any depth),
      ``__pycache__`` directories and ``*.pyc`` files.
    - ``<root>/scripts``: everything except ``__pycache__`` directories and ``*.pyc``.
    - ``<root>/envs/{seggen,nnunet,umamba}.yml``: only those that exist.

    Exclusions are matched on directory names below framework/ / scripts/, never on
    the absolute path, so a root that happens to contain "/synth" or "__pycache__"
    no longer hides the whole tree.

    Args:
        root: project root; defaults to the module-level ``ROOT``.

    Returns:
        list of absolute-or-root-relative file paths (prefixed by *root*), in sorted walk order.
    """
    root = ROOT if root is None else root
    inc = []
    for sub, drop_synth in (("framework", True), ("scripts", False)):
        for dirpath, dirnames, filenames in os.walk(f"{root}/{sub}"):
            # Prune in place so os.walk never descends into excluded directories.
            dirnames[:] = sorted(
                d for d in dirnames
                if "__pycache__" not in d and not (drop_synth and d.startswith("synth"))
            )
            inc.extend(os.path.join(dirpath, f) for f in sorted(filenames) if not f.endswith(".pyc"))
    for env in ("seggen", "nnunet", "umamba"):
        p = f"{root}/envs/{env}.yml"
        if os.path.isfile(p):
            inc.append(p)
    return inc
def list_results(root=None):
    """List the result files of the snapshot under ``<root>/results/baselines``.

    Every ``metrics.json`` at any depth (sorted, for a reproducible manifest), followed
    by whichever of summary.{html,csv,md,tex} and efficiency.md exist directly in
    that directory. The root is glob-escaped so paths containing ``[``, ``*`` or ``?``
    are matched literally.

    Args:
        root: project root; defaults to the module-level ``ROOT``.

    Returns:
        list of file paths.
    """
    base = f"{ROOT if root is None else root}/results/baselines"
    out = sorted(glob.glob(f"{glob.escape(base)}/**/metrics.json", recursive=True))
    for name in ("summary.html", "summary.csv", "summary.md", "summary.tex", "efficiency.md"):
        p = f"{base}/{name}"
        if os.path.isfile(p):
            out.append(p)
    return out
# Model card pushed as README.md (YAML front matter + markdown).
_README = """---
license: cc-by-nc-4.0
tags: [medical-imaging, segmentation, benchmark]
---
# GenSeg-Baselines
Baseline benchmark for 2D medical image segmentation: **8 methods x 10 datasets x 3 seeds/folds, 7 metrics**.
Companion to the [GenSegDataset](https://huggingface.co/datasets/MaybeRichard/GenSegDataset).
**Methods:** UNet, UNet++, DeepLabV3+ (ResNet-50/ImageNet), Attention-UNet (scratch),
TransUNet (R50-ViT-B/16), Swin-UNet (Swin-Tiny), nnU-Net v2 (250ep), U-Mamba (UMambaBot, 100ep).
**Datasets:** cvc_clinicdb, kvasir_seg, fives, busi, refuge2, acdc, idridd, pannuke, isic2018, kits19.
**Metrics:** Dice, IoU, HD95, ASSD, Sensitivity, Specificity, Precision (+ efficiency).
## Layout
- `code/` - baseline framework (train/test/aggregate), scripts, conda envs. *(Generative SegGen code excluded.)*
- `results/` - per-run `metrics.json` + aggregated `summary.{html,csv,md,tex}` + `efficiency.md`.
- `weights/` - curated checkpoints: best seed per (dataset, arch) for framework; best fold for nnU-Net / U-Mamba.
## Note
These are the **256-px baseline** (confirmed). A resolution-fair re-evaluation (conv methods retrained at a
higher per-dataset resolution; all methods scored at a common R so HD95 is comparable) is in progress and
will be added later.
"""


def _print_manifest(code, res, fw, nn, um):
    """Print per-category file counts and sizes, sample weight mappings and missing nn/um cells."""
    w_fw = sum(sz(src) for _, src, _ in fw.values())
    w_nn = sum(sz(src) for src, _ in nn.values())
    w_um = sum(sz(src) for src, _ in um.values())
    code_sz = sum(sz(p) for p in code)
    res_sz = sum(sz(p) for p in res)
    print("=" * 60)
    print(f"REPO: {REPO} (private)")
    print(f"CODE : {len(code):4d} files {human(code_sz)} (framework w/o synth + scripts + 3 envs)")
    print(f"RESULTS: {len(res):4d} files {human(res_sz)} (metrics.json + summary.* + efficiency.md)")
    print(f"WEIGHTS framework: {len(fw):3d} cells {human(w_fw)} (best seed per dataset x arch)")
    print(f"WEIGHTS nnU-Net : {len(nn):3d} dsets {human(w_nn)} (best fold)")
    print(f"WEIGHTS U-Mamba : {len(um):3d} dsets {human(w_um)} (best fold)")
    print(f"TOTAL : {human(code_sz + res_sz + w_fw + w_nn + w_um)}")
    print("=" * 60)
    print("sample framework weights:")
    for _, src, dest in list(fw.values())[:3]:
        print(" ", dest, "<-", os.path.relpath(src, ROOT))
    print("sample nnU-Net/U-Mamba weights:")
    for d in (nn, um):
        for src, dest in list(d.values())[:2]:
            print(" ", dest, "<-", os.path.relpath(src, ROOT))
    miss = [c for c in NNRAW_IDS if c not in nn] + [c for c in NNRAW_IDS if c not in um]
    if miss:
        print("NOTE missing nn/um ckpts for:", sorted(set(miss)))


def _repo_path(prefix, path, base):
    """Map a local *path* under *base* to ``<prefix>/<relative path>`` with '/' separators."""
    return f"{prefix}/" + os.path.relpath(path, base).replace(os.sep, "/")


def _upload(code, res, fw, nn, um):
    """Create the private repo (if needed) and push exactly the printed manifest.

    Commit 1: README.md, code files (ROOT/<rel> -> code/<rel>) and result files
    (ROOT/results/baselines/<rel> -> results/<rel>). Commit 2: curated weights.
    Batching replaces one upload_file call (= one Hub commit) per file, and uploading
    the listed files rather than folder+patterns keeps the dry-run manifest and the
    real upload identical (e.g. a missing env yml is simply not listed).
    """
    from huggingface_hub import CommitOperationAdd, HfApi, create_repo

    api = HfApi()  # presumably authenticates via HF_TOKEN, as in the module usage example
    create_repo(REPO, repo_type="model", private=True, exist_ok=True)
    print("repo ready; uploading code+results ...")
    results_base = os.path.join(ROOT, "results", "baselines")
    # README bytes are passed directly: no temp file, no unclosed handle.
    ops = [CommitOperationAdd(path_in_repo="README.md", path_or_fileobj=_README.encode("utf-8"))]
    ops += [CommitOperationAdd(path_in_repo=_repo_path("code", p, ROOT), path_or_fileobj=p)
            for p in code]
    ops += [CommitOperationAdd(path_in_repo=_repo_path("results", p, results_base), path_or_fileobj=p)
            for p in res]
    api.create_commit(repo_id=REPO, operations=ops,
                      commit_message="Upload baseline code, results and README")
    print("uploading weights ...")
    pairs = [(src, dest) for _, src, dest in fw.values()]
    pairs += [(src, dest) for d in (nn, um) for src, dest in d.values()]
    if pairs:  # avoid an empty commit when no checkpoint was curated
        api.create_commit(
            repo_id=REPO,
            operations=[CommitOperationAdd(path_in_repo=dest, path_or_fileobj=src) for src, dest in pairs],
            commit_message="Upload curated baseline weights",
        )


def main():
    """Entry point: print the upload manifest, then upload it unless --dry-run is given."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--dry-run", action="store_true")
    a = ap.parse_args()
    code, res = list_code(), list_results()
    fw = curate_framework()
    nn = curate_nn("nnunet", f"{ROOT}/nnunet_workspace/results_nnunet")
    um = curate_nn("umamba", f"{ROOT}/nnunet_workspace/results_umamba")
    _print_manifest(code, res, fw, nn, um)
    if a.dry_run:
        print("\n[dry-run] nothing uploaded.")
        return
    _upload(code, res, fw, nn, um)
    print("DONE:", f"https://huggingface.co/{REPO}")


if __name__ == "__main__":
    main()