SplatAtlas / scripts /phase1_validation /dump_tnt_worst.py
KCBtheone's picture
Upload SplatAtlas benchmark pipeline code
23e73f9 verified
Raw
History Blame Contribute Delete
8.53 kB
#!/usr/bin/env python
"""Dump <20dB render_vs_render cam pairs for manual inspection."""
import os, sys, inspect, math
import numpy as np
import torch
from PIL import Image, ImageOps
from plyfile import PlyData
sys.path.insert(0, '/root/autodl-tmp/3dgsAtlas_official')
import gsplat
# Dataset root; main() joins it with SCENE to form the path handed to readColmapSceneInfo.
DATASET_ROOT = "/root/autodl-tmp/dataset/tnt"
# Root of training outputs; main() reads from the "vanilla_3dgs_<SCENE>" sub-folder.
OUTPUT_ROOT = "/root/autodl-tmp/SplatAtlas/outputs"
# Diagnostic images and summary.txt are written to DUMP_ROOT/<SCENE>.
DUMP_ROOT = "/tmp/tnt_diagnose"
# Name of the scene to diagnose.
SCENE = "lighthouse"
THRESHOLD = 20.0 # cams whose render_vs_render PSNR is below THRESHOLD (dB) get dumped
def build_args(source_path, img_dir):
    """Build the positional argument list for ``readColmapSceneInfo``.

    The reader's signature is inspected at call time, so the argument list
    adapts to whatever parameters the installed reader declares.

    Args:
        source_path: Value passed as the reader's first positional argument.
        img_dir: Value passed as the reader's second positional argument.

    Returns:
        A list with one entry per reader parameter, in signature order.
        ``eval`` is forced to ``True`` and ``train_test_exp`` to ``False``;
        every other parameter receives its declared default, or ``""`` when
        it declares none.
    """
    from scene.dataset_readers import readColmapSceneInfo

    no_default = inspect.Parameter.empty
    forced = {"eval": True, "train_test_exp": False}
    params = inspect.signature(readColmapSceneInfo).parameters

    args = []
    for position, (name, param) in enumerate(params.items()):
        if position == 0:
            args.append(source_path)
        elif position == 1:
            args.append(img_dir)
        elif name in forced:
            args.append(forced[name])
        elif param.default is not no_default:
            # Identity check on purpose: ``!=`` would dispatch to the default
            # value's own __ne__, which may return a non-bool (arrays) or a
            # misleading answer and silently drop a real default.
            args.append(param.default)
        else:
            args.append("")
    return args
def to_u8(arr):
    """Return *arr* as a ``uint8`` array.

    An array that is already ``uint8`` is handed back unchanged (the same
    object). Any other dtype is multiplied by 255, clipped to ``[0, 255]``
    and cast to ``uint8``; the cast drops the fractional part.
    """
    if arr.dtype == np.uint8:
        return arr
    scaled = arr * 255
    return np.clip(scaled, 0, 255).astype(np.uint8)
def hstack_panels(arrs, gap=20):
    """Lay images side by side on a light-grey canvas.

    Args:
        arrs: Non-empty sequence of image arrays. Each one is converted with
            ``to_u8``. 2-D (single-channel) arrays are broadcast across the
            three colour channels.
        gap: Width in pixels of the spacer left between neighbouring panels.

    Returns:
        A ``uint8`` array of shape ``(H, W, 3)``, where ``H`` is the height of
        the tallest panel and ``W`` is the sum of the panel widths plus the
        gaps. Panels are top-aligned; pixels not covered by a panel keep the
        background value 240.

    Raises:
        ValueError: If ``arrs`` is empty.
    """
    imgs = [to_u8(a) for a in arrs]
    if not imgs:
        raise ValueError("hstack_panels() needs at least one image")
    # Give grayscale panels a trailing channel axis so they broadcast to RGB.
    imgs = [img[:, :, None] if img.ndim == 2 else img for img in imgs]

    # Size the canvas from the tallest panel, not just the first one, so
    # panels of different heights no longer raise a broadcast error.
    canvas_h = max(img.shape[0] for img in imgs)
    canvas_w = sum(img.shape[1] for img in imgs) + gap * (len(imgs) - 1)
    out = np.full((canvas_h, canvas_w, 3), 240, dtype=np.uint8)

    x = 0
    for img in imgs:
        h, w = img.shape[:2]
        out[:h, x:x + w] = img
        x += w + gap
    return out
def _list_images(folder, exts):
    """Return the sorted file names in *folder* whose lower-cased name ends with one of *exts*."""
    return sorted(f for f in os.listdir(folder) if f.lower().endswith(exts))


def _load_rgb(path, size=None):
    """Read an image as a float32 RGB array scaled by 1/255, optionally LANCZOS-resized to *size* (W, H)."""
    # The context manager closes the file handle as soon as the pixels are decoded.
    with Image.open(path) as src:
        rgb = src.convert('RGB')
    if size is not None:
        rgb = rgb.resize(size, Image.LANCZOS)
    return np.asarray(rgb, dtype=np.float32) / 255.0


def _resize_like(img, ref):
    """Resample float image *img* to the height/width of *ref* (through 8-bit, LANCZOS)."""
    th, tw = ref.shape[:2]
    resized = Image.fromarray(to_u8(img)).resize((tw, th), Image.LANCZOS)
    return np.asarray(resized, dtype=np.float32) / 255.0


def _psnr(a, b):
    """Return 10*log10(1/MSE) between two arrays, with the MSE floored at 1e-10."""
    mse = ((a - b) ** 2).mean()
    return 10 * math.log10(1.0 / max(mse, 1e-10))


def main():
    """Re-render the test cameras with gsplat and dump the ones that disagree with the stored renders.

    Steps:
      1. Load the scene's test cameras through ``readColmapSceneInfo``.
      2. Map every stored native GT image back to a dataset image by picking
         the dataset thumbnail with the lowest MSE against it.
      3. Load the Gaussians from the iteration-30000 PLY and rasterize each
         matched test camera with ``gsplat.rasterization``.
      4. Compute PSNR of our render vs. the stored render ("vs"), our render
         vs. GT ("ours") and stored render vs. GT ("nat").
      5. For cameras whose "vs" PSNR is below ``THRESHOLD``, write GT / ours /
         native / amplified-difference images and a side-by-side panel, then
         write ``summary.txt``.

    All paths come from the module-level constants; the function takes no
    arguments and returns nothing.
    """
    source_path = os.path.join(DATASET_ROOT, SCENE)
    cell = os.path.join(OUTPUT_ROOT, f"vanilla_3dgs_{SCENE}")
    dump = os.path.join(DUMP_ROOT, SCENE)
    os.makedirs(dump, exist_ok=True)
    # Clear output left over from a previous run. Sub-directories are left
    # alone because os.remove() cannot delete them and would abort the script.
    for entry in os.listdir(dump):
        entry_path = os.path.join(dump, entry)
        if os.path.isdir(entry_path) and not os.path.islink(entry_path):
            continue
        os.remove(entry_path)

    img_dir = "images_2"
    resolution = 2
    from scene.dataset_readers import readColmapSceneInfo
    from utils.graphics_utils import getWorld2View2

    scene_info = readColmapSceneInfo(*build_args(source_path, img_dir))
    test_cams = scene_info.test_cameras
    print(f"test cams: {len(test_cams)}")

    # Reverse lookup: real image_name -> index into the stored native files.
    native_render_dir = os.path.join(cell, "renders_test_30000")
    nested_render_dir = os.path.join(native_render_dir, "renders")
    if os.path.isdir(nested_render_dir):
        native_render_dir = nested_render_dir
    native_gt_dir = os.path.join(cell, "gt_test_30000")
    native_renders = _list_images(native_render_dir, ('.png', '.jpg'))
    native_gts = _list_images(native_gt_dir, ('.png', '.jpg'))
    if len(native_renders) != len(native_gts):
        # The two sorted lists are paired by index below, so a count mismatch
        # means some pairs may be misaligned.
        print(f"WARNING: {len(native_renders)} native renders vs "
              f"{len(native_gts)} native GTs; index pairing may be off")

    THUMB = 128
    thumb_size = (THUMB, THUMB)
    i2_dir = os.path.join(source_path, img_dir)
    i2_files = _list_images(i2_dir, ('.png', '.jpg', '.jpeg'))
    i2_thumbs = np.stack([_load_rgb(os.path.join(i2_dir, f), thumb_size)
                          for f in i2_files])
    real_to_native_idx = {}
    for nidx, gt_name in enumerate(native_gts):
        gt_thumb = _load_rgb(os.path.join(native_gt_dir, gt_name), thumb_size)
        mse = ((i2_thumbs - gt_thumb) ** 2).mean(axis=(1, 2, 3))
        real_to_native_idx[i2_files[int(np.argmin(mse))]] = nidx
    # Fallback keyed by extension-less name, used when a camera's image_name
    # and the on-disk file name differ only in the extension.
    stem_to_native_idx = {os.path.splitext(name)[0]: idx
                          for name, idx in real_to_native_idx.items()}

    # PLY
    ply = PlyData.read(os.path.join(cell, "point_cloud", "iteration_30000", "point_cloud.ply"))
    vv = ply['vertex']
    device = torch.device("cuda")

    def t32(x):
        return torch.tensor(x, dtype=torch.float32, device=device)

    means = t32(np.stack((vv['x'], vv['y'], vv['z']), -1))
    quats = t32(np.stack((vv['rot_0'], vv['rot_1'], vv['rot_2'], vv['rot_3']), -1))
    # Scales and opacities are passed through exp / sigmoid before rasterizing.
    scales = torch.exp(t32(np.stack((vv['scale_0'], vv['scale_1'], vv['scale_2']), -1)))
    opacities = torch.sigmoid(t32(np.asarray(vv['opacity'])))
    f_dc = t32(np.stack((vv['f_dc_0'], vv['f_dc_1'], vv['f_dc_2']), -1)).unsqueeze(1)
    # The 45 f_rest values are viewed as (N, 3, 15) and transposed to (N, 15, 3)
    # so they can be concatenated after the DC term.
    f_rest = t32(np.stack([vv[f'f_rest_{i}'] for i in range(45)], -1)).view(-1, 3, 15).transpose(1, 2)
    shs = torch.cat([f_dc, f_rest], dim=1)
    bg_color = torch.tensor([0., 0., 0.], device=device)

    summary = []
    unmatched = 0
    for i, c in enumerate(test_cams):
        # Resolve the native pair first so unmatched cameras skip the GPU render.
        nidx = real_to_native_idx.get(c.image_name)
        if nidx is None:
            nidx = stem_to_native_idx.get(os.path.splitext(c.image_name)[0])
        if nidx is None or nidx >= len(native_renders):
            unmatched += 1
            continue

        w = int(round(c.width / resolution))
        h = int(round(c.height / resolution))
        viewmat = t32(getWorld2View2(np.array(c.R), np.array(c.T))).unsqueeze(0)
        fx = w / (2 * math.tan(c.FovX / 2))
        fy = h / (2 * math.tan(c.FovY / 2))
        # Intrinsics with the principal point at the image centre.
        K = t32(np.array([[fx, 0, w/2], [0, fy, h/2], [0, 0, 1]])).unsqueeze(0)
        with torch.no_grad():
            colors, _, _ = gsplat.rasterization(
                means=means, quats=quats, scales=scales, opacities=opacities,
                colors=shs, viewmats=viewmat, Ks=K, width=w, height=h,
                sh_degree=3, packed=True, render_mode='RGB',
                backgrounds=bg_color.unsqueeze(0))
        render = colors[0].clamp(0, 1).cpu().numpy()  # (H, W, 3) float

        nr = _load_rgb(os.path.join(native_render_dir, native_renders[nidx]))
        ng = _load_rgb(os.path.join(native_gt_dir, native_gts[nidx]))
        # Bring both native images to the render size. Each is checked on its
        # own: the GT may differ in size even when the native render does not.
        if nr.shape != render.shape:
            nr = _resize_like(nr, render)
        if ng.shape != render.shape:
            ng = _resize_like(ng, render)

        vs_psnr = _psnr(render, nr)
        ours_p = _psnr(render, ng)
        nat_p = _psnr(nr, ng)
        mark = "DUMP" if vs_psnr < THRESHOLD else "skip"
        print(f" {i:>3} {c.image_name:>12} vs={vs_psnr:6.2f} ours={ours_p:6.2f} nat={nat_p:6.2f} {mark}")
        if vs_psnr < THRESHOLD:
            base = os.path.splitext(c.image_name)[0]
            Image.fromarray(to_u8(ng)).save(os.path.join(dump, f"{base}_1_gt.png"))
            Image.fromarray(to_u8(render)).save(os.path.join(dump, f"{base}_2_ours.png"))
            Image.fromarray(to_u8(nr)).save(os.path.join(dump, f"{base}_3_native.png"))
            abs_diff = np.abs(render - nr)
            diff5 = np.clip(abs_diff * 5, 0, 1)
            diff20 = np.clip(abs_diff * 20, 0, 1)
            Image.fromarray(to_u8(diff5)).save(os.path.join(dump, f"{base}_4_diff5x.png"))
            Image.fromarray(to_u8(diff20)).save(os.path.join(dump, f"{base}_5_diff20x.png"))
            # 4-panel: GT | OURS | NATIVE | DIFF*5
            panel = hstack_panels([ng, render, nr, diff5])
            Image.fromarray(panel).save(os.path.join(dump, f"{base}_panel.png"))
            summary.append((i, c.image_name, vs_psnr, ours_p, nat_p))

    if unmatched:
        print(f"Skipped {unmatched} test cams with no native render/GT match")

    # summary
    sp = os.path.join(dump, "summary.txt")
    # Explicit UTF-8: the legend below contains non-ASCII text, which would
    # raise UnicodeEncodeError under a non-UTF-8 default locale.
    with open(sp, "w", encoding="utf-8") as f:
        f.write(f"Scene: {SCENE}\n")
        f.write(f"Threshold: render_vs_render < {THRESHOLD} dB\n")
        f.write(f"Dumped : {len(summary)} cams\n\n")
        f.write(f"{'idx':>3} {'image_name':>14} {'vs_native':>10} {'ours_psnr':>10} {'nat_psnr':>9}\n")
        f.write("-" * 60 + "\n")
        # Worst (lowest render-vs-render PSNR) first.
        for s in sorted(summary, key=lambda x: x[2]):
            f.write(f"{s[0]:>3} {s[1]:>14} {s[2]:>10.2f} {s[3]:>10.2f} {s[4]:>9.2f}\n")
        f.write("\nPer-cam files:\n")
        f.write(" <name>_1_gt.png GT (Native 存盘)\n")
        f.write(" <name>_2_ours.png gsplat render\n")
        f.write(" <name>_3_native.png diff-gaussian-rasterization render\n")
        f.write(" <name>_4_diff5x.png |ours-native| * 5 (彩色)\n")
        f.write(" <name>_5_diff20x.png |ours-native| * 20 (极端放大)\n")
        f.write(" <name>_panel.png GT | OURS | NATIVE | DIFF*5 横排\n")
    print(f"\nDumped {len(summary)} cams to:")
    print(f" {dump}/")
    print(f"Summary: {sp}")
    print("\n拉回本地查看:")
    print(f" scp -P <port> -r root@<host>:{dump} ./tnt_{SCENE}_diag")
# Run the diagnostic dump only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()