"""Region-level PSNR (edge/flat by Sobel quartile of GT) → Table 7 expansion.

For every ``<method>_<scene>`` output directory, each rendered test image is
paired with the ground-truth image of the same file name. GT pixels whose
Sobel magnitude is >= the upper quantile form the "edge" region; pixels whose
magnitude is <= the lower quantile form the "flat" region. PSNR is reported on
both regions and on the full image.

PSNR is ``-10 * log10(mean MSE)``, where the mean is taken over the per-image
MSEs and images are read as 8-bit and scaled to [0, 1]. This is the PSNR of
the pooled MSE, not the mean of per-image PSNRs, so the "Full PSNR" column may
differ slightly from tables that average per-image PSNR.
"""
import argparse
import glob
import os

import numpy as np
import torch
import torch.nn.functional as F

# Sobel kernels. F.conv2d performs cross-correlation; the resulting sign flip
# does not affect the gradient magnitude.
_SOBEL_X = torch.tensor(
    [[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=torch.float32
).view(1, 1, 3, 3)
_SOBEL_Y = torch.tensor(
    [[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=torch.float32
).view(1, 1, 3, 3)


def sobel_mag(img_chw):
    """Return the Sobel gradient magnitude of an image's luma channel.

    Args:
        img_chw: float tensor indexed channels-first; channels 0, 1, 2 are
            combined with the BT.601 luma weights (0.299, 0.587, 0.114).

    Returns:
        2-D tensor of shape (H, W) with per-pixel gradient magnitudes.
    """
    luma = (0.299 * img_chw[0] + 0.587 * img_chw[1] + 0.114 * img_chw[2])[None, None]
    # Replicate-pad rather than zero-pad: zero padding fabricates an intensity
    # step at the image border, which would mark border pixels as "edges".
    luma = F.pad(luma, (1, 1, 1, 1), mode="replicate")
    gx = F.conv2d(luma, _SOBEL_X)
    gy = F.conv2d(luma, _SOBEL_Y)
    # The epsilon gives a ~1e-4 floor and avoids sqrt(0). Index [0, 0] rather
    # than squeeze() so that 1-pixel-high/wide images keep both spatial dims.
    return torch.sqrt(gx ** 2 + gy ** 2 + 1e-8)[0, 0]


def _load_rgb(path):
    """Load an image file as a float32 CHW RGB tensor scaled to [0, 1]."""
    # Imported lazily so the numeric helpers are usable without Pillow.
    from PIL import Image

    with Image.open(path) as im:
        arr = np.asarray(im.convert("RGB"), dtype=np.float32) / 255.0
    return torch.from_numpy(arr).permute(2, 0, 1)


def _region_mse(pred, gt, edge_q=0.75, flat_q=0.25):
    """Return (edge MSE, flat MSE, full MSE) of ``pred`` vs ``gt``.

    Regions come from quantiles of the GT Sobel magnitude. Per-pixel error is
    the squared difference averaged over channels.
    """
    sob = sobel_mag(gt)
    # np.quantile instead of torch.quantile: the torch version rejects inputs
    # above roughly 2**24 elements, which high-resolution renders can exceed.
    q_flat, q_edge = np.quantile(sob.numpy(), [flat_q, edge_q])
    edge_mask = sob >= float(q_edge)
    flat_mask = sob <= float(q_flat)
    err = ((pred - gt) ** 2).mean(dim=0)
    # Quantiles lie within [min(sob), max(sob)], so neither mask can be empty.
    return (
        float(err[edge_mask].mean()),
        float(err[flat_mask].mean()),
        float(err.mean()),
    )


def _mse_to_psnr(mses):
    """Convert a sequence of MSE values to the PSNR (dB) of their mean."""
    return -10 * np.log10(np.mean(mses) + 1e-12)


def _eval_scene(render_dir, gt_dir, edge_q, flat_q):
    """Evaluate all renders in ``render_dir`` against same-named GT images.

    Returns:
        (edge_mses, flat_mses, full_mses, n_skipped), where n_skipped counts
        renders with no matching GT file or with a mismatched image shape.
    """
    files = sorted(
        glob.glob(os.path.join(render_dir, "*.png"))
        + glob.glob(os.path.join(render_dir, "*.jpg"))
    )
    edge_mses, flat_mses, full_mses = [], [], []
    skipped = 0
    for pred_path in files:
        gt_path = os.path.join(gt_dir, os.path.basename(pred_path))
        if not os.path.exists(gt_path):
            skipped += 1
            continue
        pred, gt = _load_rgb(pred_path), _load_rgb(gt_path)
        if pred.shape != gt.shape:
            skipped += 1
            continue
        e, f, t = _region_mse(pred, gt, edge_q=edge_q, flat_q=flat_q)
        edge_mses.append(e)
        flat_mses.append(f)
        full_mses.append(t)
    return edge_mses, flat_mses, full_mses, skipped


def main():
    """Parse CLI arguments and print the per-scene and average region PSNR table."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--methods", nargs="+", default=["vanilla_logged", "vanillasgf_logged", "vanillasgf_nosg", "vanillasgf_hook"])
    ap.add_argument("--scenes", nargs="+", default=["francis", "m60", "panther", "bonsai", "bicycle", "garden"])
    ap.add_argument("--outputs_root", default="/root/autodl-tmp/SplatAtlas/outputs")
    ap.add_argument("--iter", type=int, default=30000)
    ap.add_argument("--edge_q", type=float, default=0.75,
                    help="GT Sobel quantile at/above which a pixel is 'edge'")
    ap.add_argument("--flat_q", type=float, default=0.25,
                    help="GT Sobel quantile at/below which a pixel is 'flat'")
    args = ap.parse_args()
    if not 0.0 <= args.flat_q <= args.edge_q <= 1.0:
        ap.error("require 0 <= --flat_q <= --edge_q <= 1")

    print(f"{'Method':<22} {'Scene':<10} {'Edge PSNR':>10} {'Flat PSNR':>10} {'Full PSNR':>10}")
    print("-" * 68)
    for method in args.methods:
        all_e, all_f, all_t = [], [], []
        for scene in args.scenes:
            run_dir = os.path.join(args.outputs_root, f"{method}_{scene}")
            rd = os.path.join(run_dir, f"renders_test_{args.iter}")
            gd = os.path.join(run_dir, f"gt_test_{args.iter}")
            if not (os.path.exists(rd) and os.path.exists(gd)):
                print(f"{method:<22} {scene:<10} (renders/gt missing)")
                continue
            se, sf, st, skipped = _eval_scene(rd, gd, args.edge_q, args.flat_q)
            if skipped:
                print(f"{method:<22} {scene:<10} (skipped {skipped} images: no GT match or shape mismatch)")
            if not st:
                print(f"{method:<22} {scene:<10} (no images)")
                continue
            ep, fp, tp = _mse_to_psnr(se), _mse_to_psnr(sf), _mse_to_psnr(st)
            print(f"{method:<22} {scene:<10} {ep:>10.2f} {fp:>10.2f} {tp:>10.2f}")
            all_e.extend(se)
            all_f.extend(sf)
            all_t.extend(st)
        if all_t:
            # Pools per-image MSEs across scenes, so scenes with more test
            # images carry more weight in the AVERAGE row.
            print(f"{method:<22} {'AVERAGE':<10} "
                  f"{_mse_to_psnr(all_e):>10.2f} "
                  f"{_mse_to_psnr(all_f):>10.2f} "
                  f"{_mse_to_psnr(all_t):>10.2f}")
        print("-" * 68)


if __name__ == "__main__":
    main()