GenSeg-Baselines / code /framework /efficiency.py
MaybeRichard's picture
Upload folder using huggingface_hub
b8fae22 verified
Raw
History Blame Contribute Delete
3.66 kB
"""Per-architecture efficiency table: #params, FLOPs (GMac), inference throughput.
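Representative setting (in_channels=3, num_classes=2). FLOPs via thop, falling back
to fvcore (the first one that is installed and succeeds; null if neither works).
Params are always computed; throughput is null if the timed run fails. Run once on a
GPU (e.g. A100); falls back to CPU when CUDA is unavailable.
Output: <out_root>/<exp_name>/efficiency.{json,md} (default: results/baselines/).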
python framework/efficiency.py --img_size 256 --out_root results --exp_name baselines
"""
from __future__ import annotations
import argparse
import json
import math
import os
import sys
import time
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import torch
from framework.models.registry import build_model, required_img_size
# Architectures benchmarked by main(), in the order their rows appear in the report.
ARCHS = ["unet", "unetpp", "deeplabv3plus", "attention_unet", "transunet", "swinunet"]
def count_flops_gmac(model, x):
    """Return the multiply-accumulate count of ``model(x)`` in GMac, or NaN.

    Backends are tried in order: thop, then fvcore. A backend that is not
    installed (``ImportError``) is skipped silently, since that is expected.
    A backend that is installed but fails on this model is reported with a
    ``[warn]`` line before falling through to the next one, so a NaN in the
    report can be traced back to its cause.

    Args:
        model: module to profile; it is called with ``x`` as its only input.
        x: input tensor (``main`` passes a batch of a single image).

    Returns:
        A plain Python float in GMac, or ``float("nan")`` if no backend
        produced a count.
    """
    try:
        from thop import profile
    except ImportError:
        pass  # thop not installed -> try the next backend
    else:
        try:
            macs, _ = profile(model, inputs=(x,), verbose=False)
            # float(): NOTE(review) some thop versions may return a tensor
            # rather than a Python number; downstream round()/json.dump need a float.
            return float(macs) / 1e9
        except Exception as e:  # best-effort: report and fall back
            print(f"[warn] FLOPs via thop failed: {e}")
    try:
        from fvcore.nn import FlopCountAnalysis
    except ImportError:
        pass  # fvcore not installed either
    else:
        try:
            # fvcore's total is treated as MACs here (its docs count one fused
            # multiply-add as one flop -- confirm for the installed version).
            return float(FlopCountAnalysis(model, x).total()) / 1e9
        except Exception as e:  # best-effort: report and give up
            print(f"[warn] FLOPs via fvcore failed: {e}")
    return float("nan")
@torch.no_grad()
def throughput(model, x, iters=50, warmup=10):
    """Measure inference throughput of ``model`` on the batch ``x``.

    Runs ``warmup`` untimed forward passes, then times ``iters`` passes, all
    with autograd disabled. For CUDA inputs the device is synchronized after
    warm-up and again after the timed loop. Queued asynchronous kernels are
    therefore neither leaked into nor dropped from the measurement.

    Args:
        model: module called as ``model(x)``.
        x: input batch; ``x.size(0)`` is taken as the number of images per pass.
        iters: number of timed forward passes (must be >= 1).
        warmup: number of untimed forward passes run first.

    Returns:
        Images per second: ``iters * x.size(0) / elapsed_seconds``.

    Raises:
        ValueError: if ``iters`` < 1 (nothing would be timed).
    """
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")
    for _ in range(warmup):
        model(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock, which can be adjusted mid-measurement.
    t0 = time.perf_counter()
    for _ in range(iters):
        model(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - t0
    return iters * x.size(0) / elapsed  # images / sec
def encoder_for(arch):
    """Return the encoder/backbone name that main() passes to build_model for *arch*.

    TransUNet gets its hybrid R50-ViT-B_16 backbone, the UNet/UNet++/DeepLabV3+
    family gets resnet50, and every other architecture gets resnet34.
    """
    if arch == "transunet":
        return "R50-ViT-B_16"
    resnet50_family = ("unet", "unetpp", "deeplabv3plus")
    return "resnet50" if arch in resnet50_family else "resnet34"
def _round_or_none(value, ndigits):
    """Round a measurement for the report, mapping NaN (unavailable) to None."""
    value = float(value)
    return None if math.isnan(value) else round(value, ndigits)


def _write_reports(rows, base):
    """Write *rows* to <base>/efficiency.json and <base>/efficiency.md.

    Creates *base* if needed and returns the markdown table text.
    """
    os.makedirs(base, exist_ok=True)
    with open(os.path.join(base, "efficiency.json"), "w") as f:
        json.dump(rows, f, indent=2)
    lines = ["| Method | Img | Params(M) | GMac | Img/s |", "|---|---|---|---|---|"]
    lines += [
        f"| {r['arch']} | {r['img']} | {r['params_M']} | {r['gmac']} | {r['imgs_per_s']} |"
        for r in rows
    ]
    md = "\n".join(lines) + "\n"
    # Context manager so the file is closed (and flushed) deterministically.
    with open(os.path.join(base, "efficiency.md"), "w") as f:
        f.write(md)
    return md


def main():
    """Benchmark every architecture in ARCHS and write the efficiency report.

    For each architecture it builds the model with in_channels=3,
    num_classes=2 and ``encoder_weights="none"``. It counts parameters,
    profiles MACs on a single image, and times throughput on a random batch of
    ``--batch_size`` images. Unavailable measurements are stored as None/null.
    Results go to ``<out_root>/<exp_name>/efficiency.{json,md}``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--img_size", type=int, default=256)
    ap.add_argument("--batch_size", type=int, default=8)
    ap.add_argument("--out_root", default="results")
    ap.add_argument("--exp_name", default="baselines")
    args = ap.parse_args()
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    rows = []
    for arch in ARCHS:
        # Some architectures pin their own input size; otherwise use --img_size.
        sz = required_img_size(arch) or args.img_size
        model = build_model(arch, in_channels=3, num_classes=2, img_size=sz,
                            encoder=encoder_for(arch), encoder_weights="none").to(dev).eval()
        params_m = sum(p.numel() for p in model.parameters()) / 1e6
        x1 = torch.randn(1, 3, sz, sz, device=dev)
        gmac = float(count_flops_gmac(model, x1))
        xb = torch.randn(args.batch_size, 3, sz, sz, device=dev)
        try:
            ips = throughput(model, xb)
        except Exception as e:  # e.g. OOM: keep going with the other archs
            ips = float("nan")
            print(f"[warn] throughput {arch}: {e}")
        rows.append({"arch": arch, "img": sz, "params_M": round(params_m, 2),
                     "gmac": _round_or_none(gmac, 2),
                     "imgs_per_s": _round_or_none(ips, 1)})
        print(f"{arch:16s} img={sz} params={params_m:.2f}M GMac={gmac:.2f} {ips:.1f} img/s")
        # Free this model's memory before building the next one.
        del model, x1, xb
        if dev.type == "cuda":
            torch.cuda.empty_cache()
    base = os.path.join(args.out_root, args.exp_name)
    md = _write_reports(rows, base)
    print(md)
    print(f"written {base}/efficiency.{{json,md}}")
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()