| import os, re, csv, math, sys |
| from pathlib import Path |
| import argparse |
| import torch |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
|
|
def parse_bits_and_mode(run_dir):
    """Infer ``(w_bits, a_bits, mode)`` for an experiment run directory.

    ``mode`` is one of "quant", "fp32" or "unknown". Rules, in priority order:

    1. The directory name (lower-cased) contains ``w<X>a<Y>``: quantized run
       with X weight bits and Y activation bits.
    2. The directory name contains "fp32" or "baseline": FP32 run ``(32, 32)``.
    3. Fallback to ``train.log``. The first ``quant<digits>[.]enable = true/false``
       entry decides. When it is true, the bit widths come from the first
       ``quant<digits>[.]w_bits = N`` / ``...a_bits = N`` entries, defaulting to
       W2 / A8 when absent. When it is false, the run is FP32.
    4. Nothing matched: ``(None, None, "unknown")``, left for the caller to resolve.
    """
    name = run_dir.name.lower()
    m = re.search(r"w(\d+)a(\d+)", name)
    if m:
        return int(m.group(1)), int(m.group(2)), "quant"

    # Any wXaY name has already returned above, so only the keywords need checking.
    if "fp32" in name or "baseline" in name:
        return 32, 32, "fp32"

    # Only now is the (possibly large) log actually needed.
    log = run_dir / "train.log"
    text = log.read_text(errors="ignore").lower() if log.exists() else ""

    m_enable = re.search(r"quant\d*\.?enable\s*=\s*(true|false)", text)
    if not m_enable:
        return None, None, "unknown"
    if m_enable.group(1) != "true":
        return 32, 32, "fp32"

    mw = re.search(r"quant\d*\.?w_bits\s*=\s*(\d+)", text)
    ma = re.search(r"quant\d*\.?a_bits\s*=\s*(\d+)", text)
    # Defaults used when quantization is enabled but the bit widths are not logged.
    w_bits = int(mw.group(1)) if mw else 2
    a_bits = int(ma.group(1)) if ma else 8
    return w_bits, a_bits, "quant"
|
|
def guess_dataset(run_dir):
    """Guess the dataset of a run from its directory name, else from ``train.log``.

    Candidates are checked in the fixed order scannet, s3dis, nuscenes,
    modelnet. The first one found as a substring of the lower-cased text wins,
    regardless of where it appears. The directory name is consulted first. The
    log is read only when the name gives no match and the log file exists.
    Returns "unknown" if neither source mentions a candidate.
    """
    known = ("scannet", "s3dis", "nuscenes", "modelnet")

    def first_hit(haystack):
        return next((cand for cand in known if cand in haystack), None)

    hit = first_hit(run_dir.name.lower())
    if hit is None:
        log_file = run_dir / "train.log"
        if log_file.exists():
            hit = first_hit(log_file.read_text(errors="ignore").lower())
    return "unknown" if hit is None else hit
|
|
def find_ckpts(run_dir):
    """List checkpoint files under ``run_dir``, most preferred first.

    Candidates are all ``*.pth`` / ``*.pt`` files found recursively.

    Ranking:
    - Names containing "best" (+100) and/or "last"/"latest" (+50) form a
      priority tier.
    - Within a tier, newer files (larger mtime) come first.
    - Remaining ties are broken by path string, so the order is deterministic.

    The tier is compared *before* the mtime rather than added to it. Because
    mtimes are ~1e9 seconds, an additive 100/50 bonus would be swamped by any
    file modified a couple of minutes later.

    Returns an empty list when ``run_dir`` is not a directory.
    """
    if not run_dir.is_dir():
        return []

    def rank(path):
        lowered = path.name.lower()
        tier = 0
        if "best" in lowered:
            tier += 100
        if "last" in lowered or "latest" in lowered:
            tier += 50
        return tier, path.stat().st_mtime, str(path)

    cands = [p for pattern in ("*.pth", "*.pt") for p in run_dir.rglob(pattern) if p.is_file()]
    return sorted(cands, key=rank, reverse=True)
|
|
def safe_load_state_dict(p):
    """Best-effort extraction of a ``{name: tensor}`` mapping from checkpoint ``p``.

    Lookup order:

    1. A pickled ``torch.nn.Module`` is replaced by its ``state_dict()``.
    2. Wrapper keys are tried in order: "state_dict", "model", "net", "module",
       "ema", "model_state", "model_ema". The first whose value is a dict
       containing at least one tensor is used. A wrapper holding only
       non-tensors (e.g. a config dict under "model") is skipped so it cannot
       hide a later wrapper.
    3. Otherwise, the top-level dict itself is used if all its keys are strings.

    Non-tensor values are dropped from the result. Returns ``{}`` when nothing
    matches or the file cannot be loaded. Load failures are reported on stderr
    instead of being silently ignored; the caller is expected to try the next
    checkpoint.
    """
    try:
        try:
            obj = torch.load(p, map_location="cpu", weights_only=True)
        except TypeError:
            # Older torch versions without the ``weights_only`` argument raise TypeError.
            # NOTE(security): this fallback fully unpickles the file -- trusted checkpoints only.
            obj = torch.load(p, map_location="cpu")
    except Exception as exc:  # deliberate best-effort boundary
        print(f"[warn] could not load {p}: {exc}", file=sys.stderr)
        return {}

    def tensors_of(mapping):
        return {k: v for k, v in mapping.items() if torch.is_tensor(v)}

    if isinstance(obj, torch.nn.Module):
        obj = obj.state_dict()
    if not isinstance(obj, dict):
        return {}

    for key in ("state_dict", "model", "net", "module", "ema", "model_state", "model_ema"):
        sub = obj.get(key, None)
        if isinstance(sub, dict):
            found = tensors_of(sub)
            if found:
                return found

    if all(isinstance(k, str) for k in obj.keys()):
        return tensors_of(obj)
    return {}
|
|
def collect_bits_stats(sd, w_bits_for_quant, exclude_hints, excl_norm_bias=True, fp32_force=False):
    """Estimate the element-weighted average weight bit-width of a state dict.

    Only tensor values are counted. Each one is classified by its lower-cased key:

    - It counts as 32-bit if any entry of ``exclude_hints`` is a substring of
      the key.
    - When ``excl_norm_bias`` is set, it also counts as 32-bit if the key
      contains ".norm" or "bn" or ends with ".bias".
    - Everything else counts as ``w_bits_for_quant`` bits, unless
      ``fp32_force`` is set, in which case every tensor counts as 32-bit.

    Returns a dict with:
    - ``total`` / ``qcnt`` / ``fp32``: element counts.
    - ``avg``: the average bit-width, NaN when no elements were counted.
    - ``ratio``: ``qcnt / total``, 0.0 when empty.
    """
    quantized = 0
    full_precision = 0
    for key, tensor in sd.items():
        if not torch.is_tensor(tensor):
            continue
        lowered = key.lower()
        keep_fp = fp32_force or any(hint in lowered for hint in exclude_hints)
        if excl_norm_bias and (".norm" in lowered or "bn" in lowered or lowered.endswith(".bias")):
            keep_fp = True
        if keep_fp:
            full_precision += tensor.numel()
        else:
            quantized += tensor.numel()

    total = quantized + full_precision
    if total == 0:
        return dict(total=0, qcnt=0, fp32=0, avg=float("nan"), ratio=0.0)

    quant_bits = 32 if fp32_force else w_bits_for_quant
    avg = (quantized * quant_bits + full_precision * 32.0) / total
    return dict(total=total, qcnt=quantized, fp32=full_precision, avg=avg, ratio=quantized / total)
|
|
def parse_metrics(log_path):
    """Extract the best mIoU and overall accuracy reported in a training log.

    Recognized forms (case-insensitive; ``v`` is a number):

    - Slash groups such as ``mIoU/mAcc/allAcc 0.71/0.80/0.90``. Each name is
      paired with the value at the same position; groups whose name and value
      counts differ are ignored.
    - ``<name> v``, ``<name>: v``, ``<name> = v``, ``<name> updated to: v``.
      Up to 16 non-digit characters may separate the name from the first
      following number. The name must be a whole word, so "oa" inside "load"
      or "acc" inside "mAcc" do not count.

    Only numbers bound to a metric name are used. Other numbers on the same
    line (epoch counters, losses, other metrics) are ignored, and values
    outside [0, 100] are discarded.

    Returns a dict that may contain:

    - "mIoU_best": the max value named ``miou``.
    - "Acc_best": the max value named ``allacc`` / ``oa`` / ``overall acc`` /
      ``overall_acc``. If none of those appear, the max of ``accuracy`` /
      ``acc`` is used instead (these also match e.g. per-class accuracy lines).

    Returns ``{}`` if the log does not exist.
    """
    res = {}
    if not log_path.exists():
        return res

    num = r"[-+]?\d*\.?\d+"
    slash_re = re.compile(rf"([a-z][a-z_]*(?:/[a-z][a-z_]*)+)\s*[:=]?\s*({num}(?:/{num})+)")
    single_re = re.compile(
        rf"(?<![a-z_])(miou|allacc|overall acc|overall_acc|oa|accuracy|acc)(?![a-z_/])"
        rf"[^\d/\n]{{0,16}}?({num})"
    )
    primary_acc = {"allacc", "oa", "overall acc", "overall_acc"}
    fallback_acc = {"accuracy", "acc"}

    miou_vals, acc_primary, acc_fallback = [], [], []
    for line in log_path.read_text(errors="ignore").lower().splitlines():
        pairs = []
        for m in slash_re.finditer(line):
            names, vals = m.group(1).split("/"), m.group(2).split("/")
            if len(names) == len(vals):
                pairs.extend(zip(names, vals))
        # Blank out slash groups so their names are not re-read (mis-paired) by single_re.
        pairs.extend(single_re.findall(slash_re.sub(" ", line)))

        for name, raw in pairs:
            value = float(raw)
            if not 0 <= value <= 100:
                continue
            if name == "miou":
                miou_vals.append(value)
            elif name in primary_acc:
                acc_primary.append(value)
            elif name in fallback_acc:
                acc_fallback.append(value)

    if miou_vals:
        res["mIoU_best"] = max(miou_vals)
    acc_vals = acc_primary or acc_fallback
    if acc_vals:
        res["Acc_best"] = max(acc_vals)
    return res
|
|
def _safe_float(value):
    """Convert a CSV cell to float; return None for blanks, non-numbers and NaN."""
    try:
        out = float(value)
    except (TypeError, ValueError):
        return None
    return None if math.isnan(out) else out


def _scan_run(run):
    """Collect dataset, bit configuration and log metrics for one run (each parsed once)."""
    w, a, mode = parse_bits_and_mode(run)
    return {
        "run": run,
        "dataset": guess_dataset(run),
        "w_bits": w,
        "a_bits": a,
        "mode": mode,
        "metrics": parse_metrics(run / "train.log"),
    }


def _best_fp32_by_dataset(infos):
    """Best mIoU/Acc per dataset over runs *detected* as FP32 (not the unknown->fp32 guesses)."""
    best = {}
    for info in infos:
        metrics = info["metrics"]
        if info["mode"] != "fp32" or not metrics:
            continue
        cur = best.setdefault(info["dataset"], {})
        for key in ("mIoU_best", "Acc_best"):
            if key in metrics:
                cur[key] = max(metrics[key], cur.get(key, -1))
    return best


def _load_first_state_dict(run):
    """Return (state_dict, ckpt_path) from the highest-ranked checkpoint that yields tensors."""
    sd, ckpt_file = {}, ""
    for ckpt in find_ckpts(run):
        sd = safe_load_state_dict(ckpt)
        ckpt_file = str(ckpt)  # if every candidate fails, this ends as the last one tried
        if sd:
            break
    return sd, ckpt_file


def _build_row(info, best_fp32, exclude_hints, excl_norm_bias):
    """Load the run's checkpoint, estimate its average weight bit and assemble its CSV row."""
    ds, metrics = info["dataset"], info["metrics"]
    w, a, mode = info["w_bits"], info["a_bits"], info["mode"]
    sd, ckpt_file = _load_first_state_dict(info["run"])

    # Runs whose mode could not be determined are reported as FP32 (32/32), but they
    # were not used as the FP32 baseline because that guess may be wrong.
    if mode == "unknown":
        mode = "fp32"
        if w is None:
            w = 32
        if a is None:
            a = 32

    if sd:
        is_fp32 = mode == "fp32"
        if not is_fp32 and w is None:
            w = 2
        stat = collect_bits_stats(
            sd,
            w_bits_for_quant=32 if is_fp32 else w,
            exclude_hints=exclude_hints,
            excl_norm_bias=excl_norm_bias,
            fp32_force=is_fp32,
        )
        avg_bit, qratio = stat["avg"], stat["ratio"]
        total, qcnt, fpcnt = stat["total"], stat["qcnt"], stat["fp32"]
    else:
        avg_bit, qratio, total, qcnt, fpcnt = float("nan"), 0.0, 0, 0, 0

    miou, acc = metrics.get("mIoU_best"), metrics.get("Acc_best")
    base = best_fp32.get(ds, {})
    d_miou = miou - base["mIoU_best"] if (miou is not None and "mIoU_best" in base) else None
    d_acc = acc - base["Acc_best"] if (acc is not None and "Acc_best" in base) else None

    # Column names and order are the CSV output format; keep them stable.
    return {
        "run": info["run"].name, "dataset": ds, "mode": mode,
        "w_bits": w, "a_bits": a,
        "avg_weight_bit": round(avg_bit, 3) if not math.isnan(avg_bit) else "",
        "quant_ratio(%)": round(qratio * 100, 2),
        "params(total)": total, "params_quant": qcnt, "params_fp32": fpcnt,
        "mIoU_best": miou if miou is not None else "",
        "ΔmIoU_vs_fp32": round(d_miou, 3) if d_miou is not None else "",
        "Acc_best": acc if acc is not None else "",
        "ΔAcc_vs_fp32": round(d_acc, 3) if d_acc is not None else "",
        "ckpt": ckpt_file,
    }


def _write_csv(rows, out_csv):
    """Write rows to out_csv as UTF-8 (headers contain "Δ"); no rows -> empty file."""
    out_csv.parent.mkdir(parents=True, exist_ok=True)
    with out_csv.open("w", newline="", encoding="utf-8") as f:
        if rows:
            writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
            writer.writeheader()
            writer.writerows(rows)


def _plot_metric_vs_avgbit(rows, plots_dir, metric_key, ylabel, title_name, file_tag):
    """Save one scatter plot per dataset of rows[metric_key] against avg_weight_bit."""
    by_ds = {}
    for r in rows:
        avg_bit = _safe_float(r["avg_weight_bit"])
        value = _safe_float(r.get(metric_key, ""))
        if avg_bit is None or value is None:
            continue
        by_ds.setdefault(r["dataset"], []).append((avg_bit, value, r["run"]))

    for ds, points in by_ds.items():
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.scatter([x for x, _, _ in points], [y for _, y, _ in points])
        for x, y, label in points:
            ax.annotate(label, (x, y), fontsize=8, xytext=(3, 3), textcoords="offset points")
        ax.set_xlabel("Average Weight Bit")
        ax.set_ylabel(ylabel)
        ax.set_title(f"{ds.upper()} {title_name} vs AvgBit (0920)")
        out = plots_dir / f"{ds}_{file_tag}_vs_avgbit_0920.png"
        fig.tight_layout()
        fig.savefig(out, dpi=200)
        plt.close(fig)
        print(f"[plot] {out}")


def _plot_avgbit_bars(rows, plots_dir):
    """Save a bar chart of every run's average weight bit, sorted by (dataset, avg bit)."""
    items = []
    for r in rows:
        bit = _safe_float(r["avg_weight_bit"])
        if bit is not None:
            items.append((r["run"], bit, r["dataset"]))
    if not items:
        return

    items.sort(key=lambda it: (it[2], it[1]))
    ticks = list(range(len(items)))
    fig, ax = plt.subplots(figsize=(max(8, 0.2 * len(items) + 4), 6))
    ax.bar(ticks, [bit for _, bit, _ in items])
    ax.set_xticks(ticks)
    ax.set_xticklabels([f"{ds}:{run}" for run, _, ds in items], rotation=75, ha="right", fontsize=8)
    ax.set_ylabel("Average Weight Bit")
    ax.set_title("Avg Weight Bit by Run (0920)")
    out = plots_dir / "all_runs_avgbit_0920.png"
    fig.tight_layout()
    fig.savefig(out, dpi=200)
    plt.close(fig)
    print(f"[plot] {out}")


def main():
    """Summarize every run directory under --exp-root into a CSV plus plots.

    Each run becomes one CSV row with:
    - its bit configuration;
    - the average weight bit estimated from its best-ranked loadable checkpoint;
    - the best mIoU / accuracy parsed from ``train.log``;
    - the delta of those metrics to the best detected FP32 run of the same dataset.

    Rows are ordered by run directory name. Directories that contain the
    summary outputs themselves are skipped.

    Plots saved to --plots-dir:
    - per-dataset scatter plots of mIoU and of accuracy vs average bit;
    - a bar chart of every run's average bit.
    """
    ap = argparse.ArgumentParser(description="Summarize quantization experiment runs.")
    ap.add_argument("--exp-root", default="exp", help="directory whose sub-directories are runs")
    ap.add_argument("--out-csv", default="exp/summary_0920/summary_0920_fixed.csv")
    ap.add_argument("--plots-dir", default="exp/summary_0920/plots_0920")
    ap.add_argument("--exclude", default="cls_head,embedding.stem,stem,head",
                    help="comma-separated parameter-name substrings counted as FP32")
    ap.add_argument("--no-exclude-norm-bias", action="store_true",
                    help="count norm/bn/bias parameters as quantized too")
    args = ap.parse_args()

    exp_root = Path(args.exp_root)
    if not exp_root.is_dir():
        ap.error(f"--exp-root is not a directory: {exp_root}")
    out_csv = Path(args.out_csv)
    plots_dir = Path(args.plots_dir)
    exclude_hints = [s.strip().lower() for s in args.exclude.split(",") if s.strip()]
    excl_norm_bias = not args.no_exclude_norm_bias

    # With the defaults, the outputs live inside exp_root (exp/summary_0920). A
    # directory holding our own outputs is not an experiment and must not become a row.
    output_dirs = {*out_csv.resolve().parents, plots_dir.resolve(), *plots_dir.resolve().parents}
    runs = []
    for p in sorted(exp_root.iterdir(), key=lambda q: q.name):
        if not p.is_dir():
            continue
        if p.resolve() in output_dirs:
            print(f"[skip] {p} (holds the summary outputs)")
            continue
        runs.append(p)

    infos = [_scan_run(run) for run in runs]
    best_fp32 = _best_fp32_by_dataset(infos)
    rows = [_build_row(info, best_fp32, exclude_hints, excl_norm_bias) for info in infos]

    _write_csv(rows, out_csv)
    print(f"[OK] CSV saved: {out_csv}")

    plots_dir.mkdir(parents=True, exist_ok=True)
    _plot_metric_vs_avgbit(rows, plots_dir, metric_key="mIoU_best", ylabel="mIoU (%)",
                           title_name="mIoU", file_tag="miou")
    _plot_metric_vs_avgbit(rows, plots_dir, metric_key="Acc_best", ylabel="Accuracy (%)",
                           title_name="Acc", file_tag="acc")
    _plot_avgbit_bars(rows, plots_dir)
|
|
if __name__ == "__main__":
    # Entry point: run the summary only when executed as a script, not on import.
    main()