# Source: SplatAtlas / scripts / region_psnr_analysis.py
# Uploaded by KCBtheone: "Upload SplatAtlas benchmark pipeline code" (commit 23e73f9, verified; 3.37 kB)
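"""Region-level PSNR (edge vs. flat pixels) -- expansion of Table 7.

Pixels are classified from the ground-truth image's Sobel gradient
magnitude: the top quartile is "edge", the bottom quartile is "flat".
PSNR is reported for each region and for the full image.
"""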
import os, argparse, glob
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
def sobel_mag(img_chw):
    """Return the Sobel gradient magnitude of an RGB image's luma channel.

    Args:
        img_chw: float tensor indexed as ``img_chw[c]`` for c = 0, 1, 2
            (R, G, B). Assumed shape (3, H, W) -- TODO confirm for callers
            other than this script.

    Returns:
        Tensor of shape (H, W) holding ``sqrt(gx**2 + gy**2 + 1e-8)``, on
        the same device and with the same dtype as the luma of the input.
    """
    # BT.601 luma weights; [None, None] adds the (N=1, C=1) dims for conv2d.
    luma = (0.299 * img_chw[0] + 0.587 * img_chw[1] + 0.114 * img_chw[2])[None, None]
    # Build the kernels on the input's device/dtype so GPU or float64 inputs work.
    kw = dict(dtype=luma.dtype, device=luma.device)
    wx = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]], **kw).view(1, 1, 3, 3)
    wy = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], **kw).view(1, 1, 3, 3)
    # Replicate-pad rather than zero-pad: zero padding fabricates a strong
    # gradient along the image border (image vs. an implicit black frame),
    # which would wrongly place border pixels in the "edge" quartile.
    padded = F.pad(luma, (1, 1, 1, 1), mode="replicate")
    gx = F.conv2d(padded, wx)
    gy = F.conv2d(padded, wy)
    # Index out N and C explicitly; .squeeze() would also drop H or W == 1.
    # The 1e-8 presumably keeps sqrt's gradient finite at zero magnitude.
    return torch.sqrt(gx ** 2 + gy ** 2 + 1e-8)[0, 0]
def _load_rgb(path):
    """Load an image file as a float32 tensor of shape (3, H, W) scaled to [0, 1]."""
    # The context manager closes the file handle now instead of at GC time.
    with Image.open(path) as im:
        arr = np.array(im.convert("RGB")).astype(np.float32) / 255.0
    return torch.from_numpy(arr).permute(2, 0, 1)


def _region_mse(pred, gt):
    """Return the (edge, flat, full) mean squared error for one image pair.

    Edge pixels have a GT Sobel magnitude >= its 75th percentile; flat
    pixels have one <= its 25th percentile. The squared error is averaged
    over the colour channels per pixel before masking.
    """
    sob = sobel_mag(gt)
    # np.quantile instead of Tensor.quantile: torch.quantile rejects inputs
    # with more than 2**24 elements (~16.7 MP images). Both default to
    # linear interpolation.
    q25, q75 = (float(q) for q in np.quantile(sob.cpu().numpy(), [0.25, 0.75]))
    edge, flat = sob >= q75, sob <= q25
    per_pixel = ((pred - gt) ** 2).mean(0)
    # For a finite map each mask contains at least the max/min pixel; the
    # 0 fallback only covers degenerate (e.g. NaN) maps.
    edge_mse = float(per_pixel[edge].mean()) if edge.any() else 0
    flat_mse = float(per_pixel[flat].mean()) if flat.any() else 0
    return edge_mse, flat_mse, float(per_pixel.mean())


def _psnr(mses):
    """Return the PSNR in dB (peak 1.0) of the mean of ``mses``; 1e-12 guards log10(0)."""
    return -10 * np.log10(np.mean(mses) + 1e-12)


def _scene_mses(render_dir, gt_dir):
    """Collect per-image region MSEs for one method/scene run.

    Renders (``*.png``, ``*.jpg``) are matched to GT files by basename.
    Renders with no GT file, or whose shape differs from the GT, are skipped.

    Returns:
        ``(edge_mses, flat_mses, full_mses, n_skipped)``
    """
    files = sorted(glob.glob(os.path.join(render_dir, "*.png"))
                   + glob.glob(os.path.join(render_dir, "*.jpg")))
    edge_mses, flat_mses, full_mses = [], [], []
    skipped = 0
    for pf in files:
        gf = os.path.join(gt_dir, os.path.basename(pf))
        if not os.path.exists(gf):
            skipped += 1
            continue
        pred, gt = _load_rgb(pf), _load_rgb(gf)
        if pred.shape != gt.shape:
            skipped += 1
            continue
        e, f, t = _region_mse(pred, gt)
        edge_mses.append(e)
        flat_mses.append(f)
        full_mses.append(t)
    return edge_mses, flat_mses, full_mses, skipped


def main(argv=None):
    """Print edge / flat / full PSNR per method and scene, plus a per-method average.

    Reads ``<outputs_root>/<method>_<scene>/renders_test_<iter>`` and the
    sibling ``gt_test_<iter>`` directory, pairing images by filename.

    Each PSNR is computed from the mean MSE over a scene's images (not the
    mean of per-image PSNRs, which can differ). The AVERAGE row pools every
    image of every scene, so scenes with more test views weigh more.

    Args:
        argv: optional argument list for argparse; ``None`` reads sys.argv.
    """
    import sys  # only needed for the skip warning on stderr

    ap = argparse.ArgumentParser()
    ap.add_argument("--methods", nargs="+", default=["vanilla_logged", "vanillasgf_logged",
                                                      "vanillasgf_nosg", "vanillasgf_hook"])
    ap.add_argument("--scenes", nargs="+", default=["francis", "m60", "panther", "bonsai", "bicycle", "garden"])
    ap.add_argument("--outputs_root", default="/root/autodl-tmp/SplatAtlas/outputs")
    ap.add_argument("--iter", type=int, default=30000)
    args = ap.parse_args(argv)

    print(f"{'Method':<22} {'Scene':<10} {'Edge PSNR':>10} {'Flat PSNR':>10} {'Full PSNR':>10}")
    print("-" * 68)
    for method in args.methods:
        all_e, all_f, all_t = [], [], []
        for scene in args.scenes:
            run_dir = os.path.join(args.outputs_root, f"{method}_{scene}")
            rd = os.path.join(run_dir, f"renders_test_{args.iter}")
            gd = os.path.join(run_dir, f"gt_test_{args.iter}")
            if not (os.path.isdir(rd) and os.path.isdir(gd)):
                print(f"{method:<22} {scene:<10} (renders/gt missing)")
                continue
            se, sf, st, skipped = _scene_mses(rd, gd)
            if skipped:
                # Surface skipped renders instead of silently dropping them.
                print(f"[warn] {method}/{scene}: skipped {skipped} render(s) "
                      f"with missing or shape-mismatched GT", file=sys.stderr)
            if not st:
                print(f"{method:<22} {scene:<10} (no images)")
                continue
            print(f"{method:<22} {scene:<10} {_psnr(se):>10.2f} {_psnr(sf):>10.2f} {_psnr(st):>10.2f}")
            all_e.extend(se)
            all_f.extend(sf)
            all_t.extend(st)
        if all_t:
            print(f"{method:<22} {'AVERAGE':<10} "
                  f"{_psnr(all_e):>10.2f} "
                  f"{_psnr(all_f):>10.2f} "
                  f"{_psnr(all_t):>10.2f}")
        print("-" * 68)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()