biptv3 / code /pointcept_framework /tools /summarize_and_plot_0920.py
YYYYYYUUU's picture
Add core reproduction code (binarization layers, PTv3, superpoint ops, min-repro pack)
7b95dc2 verified
Raw
History Blame Contribute Delete
12 kB
import os, re, csv, math, sys
from pathlib import Path
import argparse
import torch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def parse_bits_and_mode(run_dir):
    """Infer ``(w_bits, a_bits, mode)`` for one run directory.

    ``mode`` is one of ``"quant"``, ``"fp32"`` or ``"unknown"``. Rules, in order:

    1. A ``wXaY`` tag in the (lower-cased) directory name -> quantized, X/Y bits.
    2. ``"fp32"`` or ``"baseline"`` in the name (and no ``wXaY`` tag) -> FP32.
    3. Fallback to ``quant*.enable = True/False`` in ``train.log``. When enabled,
       the bit widths come from ``quant*.w_bits`` / ``quant*.a_bits`` and default
       to 2 and 8 respectively. When disabled -> FP32.
    4. Nothing matched -> ``(None, None, "unknown")``; the caller guesses later.

    Args:
        run_dir: ``pathlib.Path`` of the run directory.

    Returns:
        Tuple ``(w_bits, a_bits, mode)``.
    """
    dir_name = run_dir.name.lower()
    tag = re.search(r"w(\d+)a(\d+)", dir_name)
    if tag is not None:
        return int(tag.group(1)), int(tag.group(2)), "quant"

    log_file = run_dir / "train.log"
    log_text = log_file.read_text(errors="ignore").lower() if log_file.exists() else ""

    # No wXaY tag can be present here (we returned above), so the name alone decides.
    if "fp32" in dir_name or "baseline" in dir_name:
        return 32, 32, "fp32"

    enable = re.search(r"quant\d*\.?enable\s*=\s*(true|false)", log_text)
    if enable is None:
        # Undetermined; the caller guesses later (e.g. after looking at the checkpoint).
        return None, None, "unknown"
    if enable.group(1) != "true":
        return 32, 32, "fp32"

    w_hit = re.search(r"quant\d*\.?w_bits\s*=\s*(\d+)", log_text)
    a_hit = re.search(r"quant\d*\.?a_bits\s*=\s*(\d+)", log_text)
    w_bits = int(w_hit.group(1)) if w_hit else 2
    a_bits = int(a_hit.group(1)) if a_hit else 8
    return w_bits, a_bits, "quant"
def guess_dataset(run_dir):
    """Guess the dataset a run used.

    Looks for one of "scannet", "s3dis", "nuscenes", "modelnet" (checked in
    that order). The lower-cased directory name is searched first, then the
    lower-cased contents of ``train.log`` if it exists.

    Returns:
        The first dataset keyword found, or ``"unknown"``.
    """
    known = ("scannet", "s3dis", "nuscenes", "modelnet")
    dir_name = run_dir.name.lower()
    hit = next((k for k in known if k in dir_name), None)
    if hit is not None:
        return hit
    log_file = run_dir / "train.log"
    if not log_file.exists():
        return "unknown"
    log_text = log_file.read_text(errors="ignore").lower()
    return next((k for k in known if k in log_text), "unknown")
def find_ckpts(run_dir):
    """Return checkpoint files (``*.pth`` / ``*.pt``) under ``run_dir``, best first.

    Preference order:

    1. Name tier: "best" in the file name (+100), and "last"/"latest" (+50).
       The bonuses add, so a "best_last" name outranks both.
    2. Within a tier, the most recently modified file comes first.
    3. Ties are broken by path string (descending) so the order is deterministic.

    The tier is compared *before* mtime. The previous version added the
    bonus to the raw mtime (~1.7e9 seconds), where a +100 bonus is only
    worth 100 seconds. That let any newer file outrank ``*best*``.

    Args:
        run_dir: ``pathlib.Path`` searched recursively.

    Returns:
        List of ``pathlib.Path`` of regular files, most preferred first.
    """
    candidates = [
        p for pattern in ("*.pth", "*.pt") for p in run_dir.rglob(pattern) if p.is_file()
    ]

    def _rank(p):
        name = p.name.lower()
        tier = 0
        if "best" in name:
            tier += 100
        if "last" in name or "latest" in name:
            tier += 50
        return (tier, p.stat().st_mtime, str(p))

    return sorted(candidates, key=_rank, reverse=True)
def safe_load_state_dict(p):
    """Load checkpoint ``p`` on CPU and return its ``{name: tensor}`` entries.

    ``torch.load(..., weights_only=True)`` is tried first. On PyTorch versions
    that reject the keyword (``TypeError``), a plain ``torch.load`` is used.

    If the loaded object is a dict, the value of the first wrapper key among
    "state_dict", "model", "net", "module", "ema", "model_state", "model_ema"
    that is itself a dict is used. Otherwise, when every key is a string, the
    object itself is treated as the state dict. Non-tensor values are dropped.

    Returns:
        Dict of tensors; ``{}`` if the object is not a usable dict or if any
        exception occurs (best effort).
    """
    wrapper_keys = ("state_dict", "model", "net", "module", "ema", "model_state", "model_ema")
    try:
        try:
            ckpt = torch.load(p, map_location="cpu", weights_only=True)
        except TypeError:
            # Older PyTorch: no weights_only keyword.
            ckpt = torch.load(p, map_location="cpu")
        if not isinstance(ckpt, dict):
            return {}
        inner = next(
            (ckpt[k] for k in wrapper_keys if isinstance(ckpt.get(k, None), dict)),
            None,
        )
        if inner is None:
            # No wrapper key: accept the dict itself if it looks like a state dict.
            if not all(isinstance(k, str) for k in ckpt):
                return {}
            inner = ckpt
        return {name: value for name, value in inner.items() if torch.is_tensor(value)}
    except Exception:
        return {}
def collect_bits_stats(sd, w_bits_for_quant, exclude_hints, excl_norm_bias=True, fp32_force=False):
    """Estimate the average weight bit width of a state dict.

    Every tensor entry contributes ``numel()`` elements. An entry counts as 32-bit when:

    * its lower-cased name contains any of ``exclude_hints`` (e.g. head/stem), or
    * ``excl_norm_bias`` is true and the name contains ".norm" or "bn", or ends
      with ".bias", or
    * ``fp32_force`` is true (then everything is 32-bit).

    All remaining elements count as ``w_bits_for_quant`` bits. Non-tensor values are skipped.

    Returns:
        dict with keys ``total``, ``qcnt`` (elements counted as quantized),
        ``fp32`` (elements counted as 32-bit), ``avg`` (average bits; NaN when
        ``total`` is 0) and ``ratio`` (``qcnt / total``; 0.0 when empty).
    """
    n_total = n_quant = n_fp = 0
    for key, tensor in sd.items():
        if not torch.is_tensor(tensor):
            continue
        count = tensor.numel()
        lowered = key.lower()
        keep_fp = any(hint in lowered for hint in exclude_hints) or (
            excl_norm_bias
            and (".norm" in lowered or "bn" in lowered or lowered.endswith(".bias"))
        )
        n_total += count
        if keep_fp or fp32_force:
            n_fp += count
        else:
            n_quant += count

    if n_total == 0:
        return dict(total=0, qcnt=0, fp32=0, avg=float("nan"), ratio=0.0)

    quant_bits = 32 if fp32_force else w_bits_for_quant
    avg = (n_quant * quant_bits + n_fp * 32.0) / n_total
    return dict(total=n_total, qcnt=n_quant, fp32=n_fp, avg=avg, ratio=n_quant / n_total)
# Non-negative number token (integer or decimal) as written in log lines.
_METRIC_NUM = r"(?:\d+(?:\.\d+)?|\.\d+)"
# Slash-grouped "name1/name2/... v1/v2/..." lists,
# e.g. "miou/macc/allacc 0.7234/0.8012/0.9123".
_SLASH_GROUP_RE = re.compile(
    r"([a-z][a-z0-9_]*(?:/[a-z][a-z0-9_]*)+)[^\w\n]*("
    + _METRIC_NUM
    + r"(?:/"
    + _METRIC_NUM
    + r")+)"
)
# Lower-case metric-name alternatives. "macc" (mean class accuracy) is
# intentionally NOT an accuracy key; "allacc"/"oa" denote overall accuracy.
_MIOU_NAMES = r"miou"
_ACC_NAMES = r"allacc|overall[ _]acc(?:uracy)?|oa|accuracy|acc"
# Per-class result lines (e.g. "class_3-wall result: iou/accuracy ...") would
# otherwise inject per-class scores into the overall maxima.
_PER_CLASS_RE = re.compile(r"\bclass_?\d")


def _metric_after_name_re(names):
    """Compile a regex for a stand-alone metric name followed by its value.

    Matches forms like "miou: 0.72", "miou=72.1", "best miou 0.7" and
    "miou updated to: 0.73". The name must not be glued to other letters
    (so "macc" does not match "acc", and "load" does not match "oa").
    """
    return re.compile(
        r"(?<![a-z])(?:" + names + r")(?![a-z])[^\w\n]*(?:updated to[^\w\n]*)?("
        + _METRIC_NUM
        + r")"
    )


_MIOU_VALUE_RE = _metric_after_name_re(_MIOU_NAMES)
_ACC_VALUE_RE = _metric_after_name_re(_ACC_NAMES)
_MIOU_NAME_RE = re.compile(_MIOU_NAMES)
_ACC_NAME_RE = re.compile(_ACC_NAMES)


def parse_metrics(log_path):
    """Extract the best mIoU and accuracy reported in a training log.

    Only numbers that are explicitly attached to a metric name are used:

    * slash-grouped lists ("mIoU/mAcc/allAcc 0.72/0.80/0.91") are paired
      name-by-name (skipped when the counts differ), and
    * single "name <sep> value" occurrences (see _metric_after_name_re).

    Per-class result lines are ignored. Values outside [0, 100] are dropped.
    Units are kept as logged (fractions or percentages). Unrelated numbers on
    the same line (timestamps, epochs, line numbers) are no longer picked up.

    Args:
        log_path: ``pathlib.Path`` of the log file.

    Returns:
        dict with "mIoU_best" and/or "Acc_best" (max over the log); ``{}``
        when the file does not exist or nothing matched.
    """
    res = {}
    if not log_path.exists():
        return res
    txt = log_path.read_text(errors="ignore")

    mi, acc = [], []
    for line in txt.splitlines():
        low = line.lower()
        if _PER_CLASS_RE.search(low):
            continue
        for group in _SLASH_GROUP_RE.finditer(low):
            names = group.group(1).split("/")
            values = group.group(2).split("/")
            if len(names) != len(values):
                continue
            for name, value in zip(names, values):
                if _MIOU_NAME_RE.fullmatch(name):
                    mi.append(float(value))
                elif _ACC_NAME_RE.fullmatch(name):
                    acc.append(float(value))
        # Blank out grouped lists so their values cannot be re-read below
        # under the wrong name (e.g. "allacc 0.72" from "miou/macc/allacc 0.72/...").
        rest = _SLASH_GROUP_RE.sub(" ", low)
        mi.extend(float(m.group(1)) for m in _MIOU_VALUE_RE.finditer(rest))
        acc.extend(float(m.group(1)) for m in _ACC_VALUE_RE.finditer(rest))

    mi = [x for x in mi if 0 <= x <= 100]
    acc = [x for x in acc if 0 <= x <= 100]
    if mi:
        res["mIoU_best"] = max(mi)
    if acc:
        res["Acc_best"] = max(acc)
    return res
def _run_info(run):
    """Gather the name/log-derived facts for one run directory.

    Computed once per run. The FP32-baseline pass and the per-run summary
    share these values instead of re-reading ``train.log``.
    """
    w, a, mode = parse_bits_and_mode(run)
    return {
        "run": run,
        "ds": guess_dataset(run),
        "w": w,
        "a": a,
        "mode": mode,
        "metrics": parse_metrics(run / "train.log"),
    }


def _best_fp32_baselines(infos):
    """Return ``{dataset: {"mIoU_best": x, "Acc_best": y}}`` over FP32 runs.

    Only runs that parse_bits_and_mode classified as "fp32" contribute.
    Runs whose mode is "unknown" (which the summary later *assumes* to be
    FP32) are not used as baselines.
    """
    best = {}
    for info in infos:
        metrics = info["metrics"]
        if info["mode"] != "fp32" or not metrics:
            continue
        cur = best.setdefault(info["ds"], {})
        for key in ("mIoU_best", "Acc_best"):
            if key in metrics:
                cur[key] = max(metrics[key], cur.get(key, -1))
    return best


def _first_loadable_state_dict(run):
    """Return ``(state_dict, ckpt_path)`` for the most preferred checkpoint that loads.

    Candidates come from find_ckpts; the first one giving a non-empty state
    dict wins. If none does, the dict is ``{}``. The path is then the last
    candidate tried, or ``""`` when there were no candidates.
    """
    sd, ckpt_file = {}, ""
    for c in find_ckpts(run):
        sd = safe_load_state_dict(c)
        ckpt_file = str(c)
        if sd:
            break
    return sd, ckpt_file


def _summarize_run(info, base, exclude_hints, excl_norm_bias):
    """Build one CSV row (dict) for a run.

    Args:
        info: dict produced by _run_info.
        base: FP32 baseline metrics for the run's dataset (``{}`` if none).
        exclude_hints, excl_norm_bias: forwarded to collect_bits_stats.
    """
    run, ds, mode = info["run"], info["ds"], info["mode"]
    w, a = info["w"], info["a"]
    sd, ckpt_file = _first_loadable_state_dict(run)

    # Mode still undetermined (no wXaY tag, no fp32/baseline name, no quant
    # enable flag in the log): assume FP32, the common case.
    if mode == "unknown":
        mode = "fp32"
        if w is None:
            w = 32
        if a is None:
            a = 32

    if sd:
        is_fp32 = mode == "fp32"
        if not is_fp32 and w is None:
            w = 2
        stat = collect_bits_stats(
            sd,
            w_bits_for_quant=32 if is_fp32 else w,
            exclude_hints=exclude_hints,
            excl_norm_bias=excl_norm_bias,
            fp32_force=is_fp32,
        )
        avg_bit, qratio = stat["avg"], stat["ratio"]
        total, qcnt, fpcnt = stat["total"], stat["qcnt"], stat["fp32"]
    else:
        avg_bit, qratio, total, qcnt, fpcnt = float("nan"), 0.0, 0, 0, 0

    metrics = info["metrics"]
    miou = metrics.get("mIoU_best")
    acc = metrics.get("Acc_best")
    d_miou = miou - base["mIoU_best"] if (miou is not None and "mIoU_best" in base) else None
    d_acc = acc - base["Acc_best"] if (acc is not None and "Acc_best" in base) else None
    return {
        "run": run.name, "dataset": ds, "mode": mode,
        "w_bits": w, "a_bits": a,
        "avg_weight_bit": round(avg_bit, 3) if not math.isnan(avg_bit) else "",
        "quant_ratio(%)": round(qratio * 100, 2),
        "params(total)": total, "params_quant": qcnt, "params_fp32": fpcnt,
        "mIoU_best": miou if miou is not None else "",
        "ΔmIoU_vs_fp32": round(d_miou, 3) if d_miou is not None else "",
        "Acc_best": acc if acc is not None else "",
        "ΔAcc_vs_fp32": round(d_acc, 3) if d_acc is not None else "",
        "ckpt": ckpt_file,
    }


def _write_csv(rows, out_csv):
    """Write ``rows`` to ``out_csv``, with the header taken from the first row.

    Parent directories are created. The file is always created; it stays
    empty when there are no rows.
    """
    out_csv.parent.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8: the header contains "Δ" and run names may be non-ASCII,
    # which platform default encodings (e.g. cp1252) cannot always encode.
    with out_csv.open("w", newline="", encoding="utf-8") as f:
        if rows:
            writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
            writer.writeheader()
            writer.writerows(rows)


def _to_float(value):
    """Return ``float(value)``, or None when value is empty, non-numeric or NaN."""
    try:
        x = float(value)
    except (TypeError, ValueError):
        return None
    return None if math.isnan(x) else x


def _scatter_by_dataset(rows, metric_key, metric_tag, file_tag, ylabel, plots_dir):
    """Per dataset, scatter ``avg_weight_bit`` against ``rows[metric_key]``.

    Each point is annotated with its run name. Rows that lack either value are
    skipped, and datasets with no points produce no file.
    """
    by_ds = {}
    for r in rows:
        ab = _to_float(r["avg_weight_bit"])
        val = _to_float(r.get(metric_key, ""))
        if ab is None or val is None:
            continue
        by_ds.setdefault(r["dataset"], []).append((ab, val, r["run"]))
    for ds, points in by_ds.items():
        plt.figure(figsize=(6, 4))
        plt.scatter([pt[0] for pt in points], [pt[1] for pt in points])
        for x, y, label in points:
            plt.annotate(label, (x, y), fontsize=8, xytext=(3, 3), textcoords="offset points")
        plt.xlabel("Average Weight Bit")
        plt.ylabel(ylabel)
        plt.title(f"{ds.upper()} {metric_tag} vs AvgBit (0920)")
        p = plots_dir / f"{ds}_{file_tag}_vs_avgbit_0920.png"
        plt.tight_layout()
        plt.savefig(p, dpi=200)
        plt.close()
        print(f"[plot] {p}")


def _plot_avgbit_bars(rows, plots_dir):
    """Bar chart of ``avg_weight_bit`` per run, ordered by (dataset, avg_bit).

    Runs without a numeric avg bit are left out. Nothing is drawn if no run has one.
    """
    items = []
    for r in rows:
        ab = _to_float(r["avg_weight_bit"])
        if ab is not None:
            items.append((r["run"], ab, r["dataset"]))
    if not items:
        return
    items.sort(key=lambda it: (it[2], it[1]))
    labs = [f"{d}:{l}" for l, _, d in items]
    vals = [a for _, a, _ in items]
    plt.figure(figsize=(max(8, 0.2 * len(items) + 4), 6))
    plt.bar(range(len(items)), vals)
    plt.xticks(range(len(items)), labs, rotation=75, ha="right", fontsize=8)
    plt.ylabel("Average Weight Bit")
    plt.title("Avg Weight Bit by Run (0920)")
    p = plots_dir / "all_runs_avgbit_0920.png"
    plt.tight_layout()
    plt.savefig(p, dpi=200)
    plt.close()
    print(f"[plot] {p}")


def _is_summary_output_dir(run, output_dirs):
    """Return True if ``run`` is, or contains, one of this script's output dirs and has no train.log.

    With the default arguments the summary folder (``exp/summary_0920``)
    lives inside ``--exp-root``. Without this check it would be listed as
    a bogus "fp32" run on every rerun.
    """
    if (run / "train.log").exists():
        return False
    resolved = run.resolve()
    return any(resolved == d or resolved in d.parents for d in output_dirs)


def main():
    """CLI entry point: summarize every run directory under ``--exp-root``.

    For each immediate sub-directory (sorted by name, excluding this script's
    own output folder), the following are collected:

    * the quantization mode and bits,
    * the dataset,
    * the best logged mIoU/Acc and their delta vs. the best FP32 run of the
      same dataset,
    * an average-weight-bit estimate from the preferred checkpoint.

    The results go to ``--out-csv``. Scatter and bar plots are written to
    ``--plots-dir``.
    """
    ap = argparse.ArgumentParser(description="Summarize and plot quantization runs (0920).")
    ap.add_argument("--exp-root", default="exp")
    ap.add_argument("--out-csv", default="exp/summary_0920/summary_0920_fixed.csv")
    ap.add_argument("--plots-dir", default="exp/summary_0920/plots_0920")
    ap.add_argument("--exclude", default="cls_head,embedding.stem,stem,head",
                    help="comma-separated name substrings counted as 32-bit")
    ap.add_argument("--no-exclude-norm-bias", action="store_true",
                    help="count norm/bias parameters as quantized too")
    args = ap.parse_args()

    exp_root = Path(args.exp_root)
    if not exp_root.is_dir():
        ap.error(f"--exp-root is not a directory: {exp_root}")
    out_csv = Path(args.out_csv)
    plots_dir = Path(args.plots_dir)
    exclude_hints = [s.strip().lower() for s in args.exclude.split(",") if s.strip()]
    excl_norm_bias = not args.no_exclude_norm_bias

    output_dirs = [out_csv.parent.resolve(), plots_dir.resolve()]
    # Sorted for a deterministic CSV/plot order (iterdir order is unspecified).
    runs = sorted(
        (p for p in exp_root.iterdir()
         if p.is_dir() and not _is_summary_output_dir(p, output_dirs)),
        key=lambda p: p.name,
    )

    infos = [_run_info(run) for run in runs]
    best_fp32 = _best_fp32_baselines(infos)
    rows = [
        _summarize_run(info, best_fp32.get(info["ds"], {}), exclude_hints, excl_norm_bias)
        for info in infos
    ]

    _write_csv(rows, out_csv)
    print(f"[OK] CSV saved: {out_csv}")

    plots_dir.mkdir(parents=True, exist_ok=True)
    # 1) Per dataset: avg bit vs mIoU.
    _scatter_by_dataset(rows, "mIoU_best", "mIoU", "miou", "mIoU (%)", plots_dir)
    # 2) Per dataset: avg bit vs accuracy (classification / runs without mIoU).
    _scatter_by_dataset(rows, "Acc_best", "Acc", "acc", "Accuracy (%)", plots_dir)
    # 3) Overview bar chart of avg bit per run, grouped by dataset.
    _plot_avgbit_bars(rows, plots_dir)
if __name__=="__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()