"""Summarize quantization experiment runs found under an experiment root.

For every run directory this script infers the weight/activation bit-widths,
estimates the average stored weight bit-width from a checkpoint, extracts the
best mIoU / accuracy from ``train.log``, compares them against the best FP32
baseline of the same dataset, writes a CSV and renders summary plots.
"""
import argparse
import csv
import math
import os
import re
import sys
from pathlib import Path

import torch

# Dataset keywords, checked in this order against the run name, then train.log.
_DATASETS = ("scannet", "s3dis", "nuscenes", "modelnet")

# Top-level checkpoint keys that may hold the model state dict, in lookup order.
_STATE_DICT_KEYS = ("state_dict", "model", "net", "module", "ema", "model_state", "model_ema")

# Bit-width tag in a run directory name, e.g. "ptv3_w4a8" -> W4 / A8.
_BITS_TAG_RE = re.compile(r"w(\d+)a(\d+)")
_QUANT_ENABLE_RE = re.compile(r"quant\d*\.?enable\s*=\s*(true|false)")
_QUANT_WBITS_RE = re.compile(r"quant\d*\.?w_bits\s*=\s*(\d+)")
_QUANT_ABITS_RE = re.compile(r"quant\d*\.?a_bits\s*=\s*(\d+)")

# A plain decimal / scientific number as it appears in a lower-cased log line.
_NUM = r"[-+]?(?:\d+\.\d*|\.\d+|\d+)(?:e[-+]?\d+)?"
# Slash-grouped metrics such as "miou/macc/allacc 0.65/0.73/0.89".
_GROUPED_METRICS_RE = re.compile(rf"((?:[a-z_]+/)+[a-z_]+)\s*[:=]?\s*((?:{_NUM}/)+{_NUM})")
# Heuristic for per-class breakdown lines ("class_3-chair result: ..."), which
# must not feed the overall metrics.
_PER_CLASS_RE = re.compile(r"\bclass_\d+")
# Result key -> metric names (lower-case) that report it. Mean class accuracy
# ("macc") is intentionally not treated as overall accuracy.
_METRIC_ALIASES = {
    "mIoU_best": ("miou",),
    "Acc_best": ("overall acc", "overall_acc", "allacc", "oa", "accuracy", "acc"),
}
_ALIAS_TO_METRIC = {
    alias: key for key, aliases in _METRIC_ALIASES.items() for alias in aliases
}
# "name[:=] value" patterns; word boundaries stop "oa" matching "load", "acc" matching "macc".
_SINGLE_METRIC_RES = [
    (alias, re.compile(rf"\b{re.escape(alias)}\b\s*[:=]?\s*({_NUM})"))
    for alias in _ALIAS_TO_METRIC
]


def _read_log(run_dir):
    """Return the lower-cased text of ``run_dir/train.log`` ("" if there is none)."""
    log = run_dir / "train.log"
    if not log.is_file():
        return ""
    return log.read_text(encoding="utf-8", errors="ignore").lower()


def parse_bits_and_mode(run_dir):
    """Infer ``(w_bits, a_bits, mode)`` for a run directory.

    ``mode`` is "quant", "fp32" or "unknown". Rules, in priority order:
      1) the directory name contains ``wXaY``            -> (X, Y, "quant")
      2) the name contains "fp32" or "baseline"          -> (32, 32, "fp32")
      3) train.log sets ``quant*.enable=true/false``     -> quant / fp32; for
         quant, bits come from ``quant*.w_bits`` / ``quant*.a_bits`` in the
         log, defaulting to W2 / A8 when absent
      4) nothing matched                                 -> (None, None, "unknown");
         the caller decides how to treat such runs.
    """
    name = run_dir.name.lower()
    tag = _BITS_TAG_RE.search(name)
    if tag:
        return int(tag.group(1)), int(tag.group(2)), "quant"
    if "fp32" in name or "baseline" in name:
        return 32, 32, "fp32"

    text = _read_log(run_dir)
    enable = _QUANT_ENABLE_RE.search(text)
    if enable is None:
        return None, None, "unknown"
    if enable.group(1) == "false":
        return 32, 32, "fp32"
    mw = _QUANT_WBITS_RE.search(text)
    ma = _QUANT_ABITS_RE.search(text)
    w_bits = int(mw.group(1)) if mw else 2
    a_bits = int(ma.group(1)) if ma else 8
    return w_bits, a_bits, "quant"


def _find_dataset(text):
    """Return the first keyword of ``_DATASETS`` contained in *text*, or None."""
    return next((k for k in _DATASETS if k in text), None)


def guess_dataset(run_dir):
    """Guess the dataset from the run name, falling back to train.log; "unknown" otherwise."""
    # ``or`` short-circuits, so the log is only read when the name is inconclusive.
    return _find_dataset(run_dir.name.lower()) or _find_dataset(_read_log(run_dir)) or "unknown"


def find_ckpts(run_dir):
    """List checkpoint files (``*.pth`` / ``*.pt``) under *run_dir*, most preferred first.

    Names containing "best" rank above names containing "last"/"latest", which
    rank above everything else (a name with both gets both bonuses). Within the
    same rank, the most recently modified file comes first. Ranking on a tuple
    keeps the name preference from being swamped by modification timestamps.
    """
    def rank(path):
        n = path.name.lower()
        priority = (100 if "best" in n else 0) + (50 if ("last" in n or "latest" in n) else 0)
        return priority, path.stat().st_mtime

    cands = [p for ext in ("*.pth", "*.pt") for p in run_dir.rglob(ext) if p.is_file()]
    return sorted(cands, key=rank, reverse=True)


def safe_load_state_dict(p):
    """Best-effort extraction of a tensor state dict from checkpoint file *p*.

    Looks for a nested dict under one of ``_STATE_DICT_KEYS`` (first match
    wins); otherwise treats a dict whose keys are all strings as the state dict
    itself. Only tensor values are kept. Returns {} when nothing usable is found
    or the file cannot be loaded (in which case a warning goes to stderr).
    """
    try:
        try:
            # Prefer the restricted unpickler where this PyTorch supports it.
            obj = torch.load(p, map_location="cpu", weights_only=True)
        except TypeError:
            # Older PyTorch without ``weights_only``: this performs full
            # unpickling, so only run this script on checkpoints you trust.
            obj = torch.load(p, map_location="cpu")
    except Exception as exc:  # best effort: an unreadable checkpoint is just skipped
        print(f"[warn] could not load checkpoint {p}: {exc}", file=sys.stderr)
        return {}

    if not isinstance(obj, dict):
        return {}
    for key in _STATE_DICT_KEYS:
        nested = obj.get(key)
        if isinstance(nested, dict):
            return {k: v for k, v in nested.items() if torch.is_tensor(v)}
    if all(isinstance(k, str) for k in obj):
        return {k: v for k, v in obj.items() if torch.is_tensor(v)}
    return {}


def collect_bits_stats(sd, w_bits_for_quant, exclude_hints, excl_norm_bias=True, fp32_force=False):
    """Estimate the average stored weight bit-width of a state dict.

    Args:
        sd: mapping of parameter name -> tensor; non-tensor values are ignored.
        w_bits_for_quant: bit-width assigned to every quantizable element.
        exclude_hints: substrings; a parameter whose lower-cased name contains
            any of them is counted as 32-bit.
        excl_norm_bias: also count as 32-bit any name containing ".norm" or
            "bn" (plain substring match) or ending in ".bias".
        fp32_force: count every element as 32-bit.

    Returns:
        dict with ``total`` / ``qcnt`` / ``fp32`` element counts, ``avg`` (the
        element-weighted average bit-width, NaN when there are no tensors) and
        ``ratio`` (``qcnt / total``, 0.0 when empty).
    """
    total = qcnt = fpcnt = 0
    for key, tensor in sd.items():
        if not torch.is_tensor(tensor):
            continue
        n = tensor.numel()
        name = key.lower()
        excluded = any(h in name for h in exclude_hints) or (
            excl_norm_bias and (".norm" in name or "bn" in name or name.endswith(".bias"))
        )
        total += n
        if excluded or fp32_force:
            fpcnt += n
        else:
            qcnt += n

    if total == 0:
        return dict(total=0, qcnt=0, fp32=0, avg=float("nan"), ratio=0.0)
    q_bits = 32 if fp32_force else w_bits_for_quant
    avg = (qcnt * q_bits + fpcnt * 32.0) / total
    return dict(total=total, qcnt=qcnt, fp32=fpcnt, avg=avg, ratio=qcnt / total)


def _metric_values(line):
    """Return ``(metric_name, value)`` pairs found in one lower-cased log line."""
    found = []
    for m in _GROUPED_METRICS_RE.finditer(line):
        for name, value in zip(m.group(1).split("/"), m.group(2).split("/")):
            found.append((name, float(value)))
    # Remove grouped spans so e.g. "allacc" is not re-read as "allacc <miou value>".
    rest = _GROUPED_METRICS_RE.sub(" ", line)
    for alias, pattern in _SINGLE_METRIC_RES:
        for m in pattern.finditer(rest):
            found.append((alias, float(m.group(1))))
    return found


def parse_metrics(log_path):
    """Return the best mIoU / overall accuracy reported in a training log.

    The result holds "mIoU_best" and/or "Acc_best", each only when found. A
    number is used only if it is tied to a metric name. It must either directly
    follow the name ("miou: 65.3", "oa = 0.91") or sit at the matching slot of a
    slash-grouped line ("miou/macc/allacc 0.65/0.73/0.89"). Timestamps, epoch
    counters and neighbouring metrics on the same line are therefore ignored.
    Per-class lines (containing ``class_<N>``) are skipped. Values outside
    [0, 100] are dropped. No unit conversion is done, so fractions stay fractions.
    """
    if not log_path.is_file():
        return {}
    best = {}
    text = log_path.read_text(encoding="utf-8", errors="ignore").lower()
    for line in text.splitlines():
        if not any(alias in line for alias in _ALIAS_TO_METRIC) or _PER_CLASS_RE.search(line):
            continue
        for name, value in _metric_values(line):
            key = _ALIAS_TO_METRIC.get(name)
            if key is None or not 0 <= value <= 100:
                continue
            best[key] = max(value, best.get(key, value))
    return {key: best[key] for key in _METRIC_ALIASES if key in best}


def _as_float(value):
    """Convert a CSV cell to float; empty / invalid cells become NaN."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return float("nan")


def _is_run_dir(path, output_dirs):
    """Return True if *path* should be summarized as a training run.

    This script's own output directories (the defaults live inside the default
    ``exp`` root) are skipped unless they also contain a ``train.log``.
    """
    if not path.is_dir():
        return False
    if (path / "train.log").exists():
        return True
    resolved = path.resolve()
    return not any(resolved == d or resolved in d.parents for d in output_dirs)


def _inspect_run(run):
    """Collect the name/log-derived facts about one run (each log parsed once)."""
    w, a, mode = parse_bits_and_mode(run)
    return {
        "run": run,
        "dataset": guess_dataset(run),
        "w": w,
        "a": a,
        "mode": mode,
        "metrics": parse_metrics(run / "train.log"),
    }


def _best_fp32_baselines(infos):
    """Best mIoU / Acc per dataset over runs identified as FP32 by name or log.

    Runs whose mode is "unknown" (later reported as FP32 by assumption) are
    deliberately not used as baselines.
    """
    best = {}
    for info in infos:
        if info["mode"] != "fp32" or not info["metrics"]:
            continue
        cur = best.setdefault(info["dataset"], {})
        for key in ("mIoU_best", "Acc_best"):
            if key in info["metrics"]:
                cur[key] = max(info["metrics"][key], cur.get(key, -1))
    return best


def _load_first_state_dict(run):
    """Return ``(state_dict, ckpt_path)`` for the most preferred loadable checkpoint, or ``({}, "")``."""
    for ckpt in find_ckpts(run):
        sd = safe_load_state_dict(ckpt)
        if sd:
            return sd, str(ckpt)
    return {}, ""


def _summarize_run(info, best_fp32, exclude_hints, excl_norm_bias):
    """Build the CSV row for one run."""
    run, ds, mode = info["run"], info["dataset"], info["mode"]
    w, a = info["w"], info["a"]
    sd, ckpt_file = _load_first_state_dict(run)

    if mode == "unknown":
        # Neither the name nor the log identifies the run: assume FP32 (the common case).
        mode = "fp32"
        if w is None:
            w = 32
        if a is None:
            a = 32

    if sd:
        if mode == "fp32":
            stat = collect_bits_stats(sd, w_bits_for_quant=32, exclude_hints=exclude_hints,
                                      excl_norm_bias=excl_norm_bias, fp32_force=True)
        else:
            if w is None:
                w = 2
            stat = collect_bits_stats(sd, w_bits_for_quant=w, exclude_hints=exclude_hints,
                                      excl_norm_bias=excl_norm_bias, fp32_force=False)
        avg_bit, qratio = stat["avg"], stat["ratio"]
        total, qcnt, fpcnt = stat["total"], stat["qcnt"], stat["fp32"]
    else:
        avg_bit, qratio, total, qcnt, fpcnt = float("nan"), 0.0, 0, 0, 0

    miou = info["metrics"].get("mIoU_best")
    acc = info["metrics"].get("Acc_best")
    base = best_fp32.get(ds, {})
    d_miou = miou - base["mIoU_best"] if (miou is not None and "mIoU_best" in base) else None
    d_acc = acc - base["Acc_best"] if (acc is not None and "Acc_best" in base) else None

    return {
        "run": run.name,
        "dataset": ds,
        "mode": mode,
        "w_bits": w,
        "a_bits": a,
        "avg_weight_bit": round(avg_bit, 3) if not math.isnan(avg_bit) else "",
        "quant_ratio(%)": round(qratio * 100, 2),
        "params(total)": total,
        "params_quant": qcnt,
        "params_fp32": fpcnt,
        "mIoU_best": miou if miou is not None else "",
        "ΔmIoU_vs_fp32": round(d_miou, 3) if d_miou is not None else "",
        "Acc_best": acc if acc is not None else "",
        "ΔAcc_vs_fp32": round(d_acc, 3) if d_acc is not None else "",
        "ckpt": ckpt_file,
    }


def _write_csv(rows, out_csv):
    """Write *rows* to *out_csv* as UTF-8 (the header contains non-ASCII 'Δ')."""
    out_csv.parent.mkdir(parents=True, exist_ok=True)
    with out_csv.open("w", newline="", encoding="utf-8") as f:
        if rows:
            writer = csv.DictWriter(f, fieldnames=list(rows[0]))
            writer.writeheader()
            writer.writerows(rows)
    print(f"[OK] CSV saved: {out_csv}")


def _pyplot():
    """Import pyplot with the non-interactive Agg backend (safe on headless machines)."""
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    return plt


def _save_figure(plt, path):
    """Lay out, save (200 dpi) and close the current figure, then report the path."""
    plt.tight_layout()
    plt.savefig(path, dpi=200)
    plt.close()
    print(f"[plot] {path}")


def _plot_metric_vs_avgbit(plt, rows, plots_dir, metric_col, ylabel, title_name, file_tag):
    """Per dataset: scatter of avg weight bit vs *metric_col*, each point labelled with its run."""
    points = {}
    for r in rows:
        ab = _as_float(r["avg_weight_bit"])
        val = _as_float(r.get(metric_col, ""))
        if math.isnan(ab) or math.isnan(val):
            continue
        points.setdefault(r["dataset"], []).append((ab, val, r["run"]))

    for ds, pts in points.items():
        xs, ys, _ = zip(*pts)
        plt.figure(figsize=(6, 4))
        plt.scatter(xs, ys)
        for x, y, label in pts:
            plt.annotate(label, (x, y), fontsize=8, xytext=(3, 3), textcoords="offset points")
        plt.xlabel("Average Weight Bit")
        plt.ylabel(ylabel)
        plt.title(f"{ds.upper()} {title_name} vs AvgBit (0920)")
        _save_figure(plt, plots_dir / f"{ds}_{file_tag}_vs_avgbit_0920.png")


def _plot_avgbit_bars(plt, rows, plots_dir):
    """Bar chart of every run's avg weight bit, sorted by dataset then avg bit."""
    items = [(r["run"], _as_float(r["avg_weight_bit"]), r["dataset"]) for r in rows]
    items = [it for it in items if not math.isnan(it[1])]
    if not items:
        return
    items.sort(key=lambda it: (it[2], it[1]))
    labels = [f"{ds}:{run}" for run, _, ds in items]
    values = [ab for _, ab, _ in items]
    plt.figure(figsize=(max(8, 0.2 * len(items) + 4), 6))
    plt.bar(range(len(items)), values)
    plt.xticks(range(len(items)), labels, rotation=75, ha="right", fontsize=8)
    plt.ylabel("Average Weight Bit")
    plt.title("Avg Weight Bit by Run (0920)")
    _save_figure(plt, plots_dir / "all_runs_avgbit_0920.png")


def main():
    """CLI entry point: summarize every run under --exp-root into a CSV and plots."""
    ap = argparse.ArgumentParser(description="Summarize quantization runs into a CSV and plots.")
    ap.add_argument("--exp-root", default="exp")
    ap.add_argument("--out-csv", default="exp/summary_0920/summary_0920_fixed.csv")
    ap.add_argument("--plots-dir", default="exp/summary_0920/plots_0920")
    ap.add_argument("--exclude", default="cls_head,embedding.stem,stem,head")
    ap.add_argument("--no-exclude-norm-bias", action="store_true")
    args = ap.parse_args()

    exp_root = Path(args.exp_root)
    if not exp_root.is_dir():
        ap.error(f"--exp-root is not a directory: {exp_root}")
    out_csv = Path(args.out_csv)
    plots_dir = Path(args.plots_dir)
    exclude_hints = [s.strip().lower() for s in args.exclude.split(",") if s.strip()]
    excl_norm_bias = not args.no_exclude_norm_bias

    output_dirs = (out_csv.resolve().parent, plots_dir.resolve())
    runs = sorted(p for p in exp_root.iterdir() if _is_run_dir(p, output_dirs))
    if not runs:
        print(f"[warn] no run directories found under {exp_root}", file=sys.stderr)

    infos = [_inspect_run(run) for run in runs]
    best_fp32 = _best_fp32_baselines(infos)
    rows = [_summarize_run(info, best_fp32, exclude_hints, excl_norm_bias) for info in infos]

    _write_csv(rows, out_csv)

    plots_dir.mkdir(parents=True, exist_ok=True)
    plt = _pyplot()
    _plot_metric_vs_avgbit(plt, rows, plots_dir, "mIoU_best", "mIoU (%)", "mIoU", "miou")
    _plot_metric_vs_avgbit(plt, rows, plots_dir, "Acc_best", "Accuracy (%)", "Acc", "acc")
    _plot_avgbit_bars(plt, rows, plots_dir)


if __name__ == "__main__":
    main()