| |
| """Dump <20dB render_vs_render cam pairs for manual inspection.""" |
| import os, sys, inspect, math |
| import numpy as np |
| import torch |
| from PIL import Image, ImageOps |
| from plyfile import PlyData |
|
|
| sys.path.insert(0, '/root/autodl-tmp/3dgsAtlas_official') |
| import gsplat |
|
|
# Locations and settings for one diagnostic run.  Every value keeps its
# original default but can be overridden through an environment variable, so
# the script can be pointed at another machine layout or scene without editing
# the file (e.g. ``TNT_DIAG_SCENE=truck python <this script>``).
DATASET_ROOT = os.environ.get("TNT_DIAG_DATASET_ROOT", "/root/autodl-tmp/dataset/tnt")
OUTPUT_ROOT = os.environ.get("TNT_DIAG_OUTPUT_ROOT", "/root/autodl-tmp/SplatAtlas/outputs")
DUMP_ROOT = os.environ.get("TNT_DIAG_DUMP_ROOT", "/tmp/tnt_diagnose")
SCENE = os.environ.get("TNT_DIAG_SCENE", "lighthouse")
# Cameras whose render-vs-render PSNR (in dB) falls below this value are dumped.
THRESHOLD = float(os.environ.get("TNT_DIAG_THRESHOLD", "20.0"))
|
|
|
|
def build_args(source_path, img_dir):
    """Build the positional argument list for ``readColmapSceneInfo``.

    The reader's signature is inspected at call time instead of being
    hard-coded, so the list always has one entry per declared parameter.

    Args:
        source_path: Value passed as the reader's first parameter.
        img_dir: Value passed as the reader's second parameter.

    Returns:
        A list with one positional value per parameter of
        ``readColmapSceneInfo``: the two arguments above, ``True`` for a
        parameter named ``eval``, ``False`` for ``train_test_exp``, the
        declared default for any other parameter that has one, and ``""``
        for any other parameter without a default.
    """
    from scene.dataset_readers import readColmapSceneInfo

    no_default = inspect.Parameter.empty
    parameters = inspect.signature(readColmapSceneInfo).parameters

    args = []
    for index, (name, param) in enumerate(parameters.items()):
        if index == 0:
            args.append(source_path)
        elif index == 1:
            args.append(img_dir)
        elif name == "eval":
            args.append(True)
        elif name == "train_test_exp":
            args.append(False)
        elif param.default is not no_default:
            # Identity check on the sentinel: ``!=`` would invoke the default
            # object's own comparison, which can raise (array-like defaults)
            # or return a misleading answer (custom ``__ne__``).
            args.append(param.default)
        else:
            args.append("")
    return args
|
|
|
|
def to_u8(arr):
    """Convert an image array to ``uint8``.

    Args:
        arr: NumPy array. ``uint8`` input is returned unchanged (the same
            object, no copy). Any other dtype is treated as intensities on a
            0..1 scale.

    Returns:
        A ``uint8`` array: values are scaled by 255, rounded to the nearest
        integer and clipped to 0..255.
    """
    if arr.dtype == np.uint8:
        return arr
    # Round before casting: astype() truncates toward zero, which would bias
    # every pixel down by up to one level (e.g. 254.7 -> 254).
    return np.clip(np.rint(arr * 255), 0, 255).astype(np.uint8)
|
|
|
|
def hstack_panels(arrs, gap=20):
    """Lay image panels out left to right on a light-grey canvas.

    Args:
        arrs: Non-empty sequence of image arrays, each either ``(H, W, 3)`` or
            a 2-D grayscale ``(H, W)``. Every panel is passed through
            ``to_u8`` first.
        gap: Width in pixels of the background strip between adjacent panels.

    Returns:
        A ``uint8`` array of shape ``(max_height, total_width, 3)``. Panels
        are top-aligned; any area not covered by a panel keeps the background
        value 240.

    Raises:
        ValueError: If ``arrs`` is empty.
    """
    panels = []
    for arr in arrs:
        img = to_u8(np.asarray(arr))
        if img.ndim == 2:
            # Replicate a grayscale panel across the three colour channels.
            img = np.repeat(img[:, :, None], 3, axis=2)
        panels.append(img)
    if not panels:
        raise ValueError("hstack_panels() needs at least one panel")

    # Size the canvas for the tallest panel so mixed heights do not fail on
    # assignment; shorter panels are padded below with the background colour.
    height = max(p.shape[0] for p in panels)
    width = sum(p.shape[1] for p in panels) + gap * (len(panels) - 1)
    canvas = np.full((height, width, 3), 240, dtype=np.uint8)

    x = 0
    for panel in panels:
        panel_h, panel_w = panel.shape[:2]
        canvas[:panel_h, x:x + panel_w] = panel
        x += panel_w + gap
    return canvas
|
|
|
|
def _t32(x, device):
    """Return *x* as a float32 tensor on *device*."""
    return torch.tensor(x, dtype=torch.float32, device=device)


def _list_images(directory, exts):
    """Return the sorted file names in *directory* ending (case-insensitively) in *exts*."""
    return sorted(f for f in os.listdir(directory) if f.lower().endswith(exts))


def _load_rgb(path, size=None):
    """Load *path* as a float32 RGB array scaled to 0..1.

    *size*, when given, is a ``(width, height)`` tuple; the image is resized
    to it with Lanczos resampling before conversion.
    """
    with Image.open(path) as src:
        rgb = src.convert('RGB')
    if size is not None:
        rgb = rgb.resize(size, Image.LANCZOS)
    return np.asarray(rgb, dtype=np.float32) / 255.0


def _resize_rgb(arr, width, height):
    """Resize a 0..1 float RGB array to ``width`` x ``height`` via an 8-bit image."""
    resized = Image.fromarray(to_u8(arr)).resize((width, height), Image.LANCZOS)
    return np.asarray(resized, dtype=np.float32) / 255.0


def _psnr(a, b):
    """Return the PSNR in dB between two 0..1 arrays (MSE floored at 1e-10)."""
    mse = float(((a - b) ** 2).mean())
    return 10 * math.log10(1.0 / max(mse, 1e-10))


def _match_native_to_source(source_dir, native_gt_dir, native_gts, thumb=128):
    """Map source image file names to the index of the closest native GT.

    Every source image and every native GT is reduced to a ``thumb`` x
    ``thumb`` thumbnail; each native GT is assigned to the source image with
    the lowest thumbnail MSE. If two GTs pick the same source image, the later
    index wins.
    """
    source_files = _list_images(source_dir, ('.png', '.jpg', '.jpeg'))
    thumbs = np.stack([_load_rgb(os.path.join(source_dir, f), (thumb, thumb))
                       for f in source_files])
    mapping = {}
    for nidx, gt_name in enumerate(native_gts):
        gt_thumb = _load_rgb(os.path.join(native_gt_dir, gt_name), (thumb, thumb))
        mse = ((thumbs - gt_thumb) ** 2).mean(axis=(1, 2, 3))
        mapping[source_files[int(np.argmin(mse))]] = nidx
    return mapping


def _load_gaussians(ply_path, device):
    """Read the Gaussian parameters from *ply_path* as float32 tensors on *device*.

    Returns ``(means, quats, scales, opacities, shs)``. Scales are passed
    through ``exp`` and opacities through ``sigmoid``; ``shs`` concatenates
    the DC term with the 45 ``f_rest_*`` columns reshaped to ``(N, 15, 3)``.
    """
    vertex = PlyData.read(ply_path)['vertex']

    def columns(*names):
        return _t32(np.stack([vertex[n] for n in names], -1), device)

    means = columns('x', 'y', 'z')
    quats = columns('rot_0', 'rot_1', 'rot_2', 'rot_3')
    scales = torch.exp(columns('scale_0', 'scale_1', 'scale_2'))
    opacities = torch.sigmoid(_t32(np.asarray(vertex['opacity']), device))
    f_dc = columns('f_dc_0', 'f_dc_1', 'f_dc_2').unsqueeze(1)
    # 45 rest coefficients = 15 per colour channel; this matches the
    # sh_degree=3 passed to the rasterizer in main().
    f_rest = columns(*[f'f_rest_{i}' for i in range(45)]).view(-1, 3, 15).transpose(1, 2)
    shs = torch.cat([f_dc, f_rest], dim=1)
    return means, quats, scales, opacities, shs


def main():
    """Render every test camera with gsplat and dump the ones that disagree with the native renders.

    For each test camera that can be matched to a saved native GT image, the
    scene's iteration-30000 point cloud is rendered with ``gsplat`` and
    compared against the saved native render. Cameras whose render-vs-render
    PSNR is below ``THRESHOLD`` get their GT, both renders, two amplified
    difference images and a side-by-side panel written under
    ``DUMP_ROOT/SCENE``, together with a ``summary.txt``.
    """
    # Project-local modules; presumably found through the sys.path entry added
    # at the top of the file -- confirm against the repository layout.
    from scene.dataset_readers import readColmapSceneInfo
    from utils.graphics_utils import getWorld2View2

    source_path = os.path.join(DATASET_ROOT, SCENE)
    cell = os.path.join(OUTPUT_ROOT, f"vanilla_3dgs_{SCENE}")
    dump = os.path.join(DUMP_ROOT, SCENE)
    os.makedirs(dump, exist_ok=True)

    # Clear files left by a previous run. Real subdirectories are left alone:
    # os.remove() cannot delete them and would abort the whole run.
    for name in os.listdir(dump):
        path = os.path.join(dump, name)
        if os.path.isdir(path) and not os.path.islink(path):
            continue
        os.remove(path)

    img_dir = "images_2"
    resolution = 2

    scene_info = readColmapSceneInfo(*build_args(source_path, img_dir))
    test_cams = scene_info.test_cameras
    print(f"test cams: {len(test_cams)}")

    # Native renders may live directly in the folder or in a "renders" subfolder.
    native_render_dir = os.path.join(cell, "renders_test_30000")
    nested_render_dir = os.path.join(native_render_dir, "renders")
    if os.path.isdir(nested_render_dir):
        native_render_dir = nested_render_dir
    native_gt_dir = os.path.join(cell, "gt_test_30000")
    native_renders = _list_images(native_render_dir, ('.png', '.jpg'))
    native_gts = _list_images(native_gt_dir, ('.png', '.jpg'))
    if len(native_renders) != len(native_gts):
        # Renders and GTs are paired purely by sorted index below.
        print(f"WARNING: {len(native_renders)} native renders vs {len(native_gts)} native GTs; "
              f"index pairing may be misaligned")

    # Native GT files are matched to source images by content, not by name.
    real_to_native_idx = _match_native_to_source(
        os.path.join(source_path, img_dir), native_gt_dir, native_gts)

    device = torch.device("cuda")
    means, quats, scales, opacities, shs = _load_gaussians(
        os.path.join(cell, "point_cloud", "iteration_30000", "point_cloud.ply"), device)
    bg_color = torch.tensor([0., 0., 0.], device=device)

    summary = []
    unmatched = 0
    for i, c in enumerate(test_cams):
        # Check for a native counterpart first so unmatched cameras do not
        # cost a GPU render.
        nidx = real_to_native_idx.get(c.image_name)
        if nidx is None:
            unmatched += 1
            continue

        # Camera size is divided by `resolution`; presumably the camera
        # records describe the full-size images while img_dir holds the
        # half-size ones -- confirm.
        w = int(round(c.width / resolution))
        h = int(round(c.height / resolution))
        viewmat = _t32(getWorld2View2(np.array(c.R), np.array(c.T)), device).unsqueeze(0)
        fx = w / (2 * math.tan(c.FovX / 2))
        fy = h / (2 * math.tan(c.FovY / 2))
        K = _t32(np.array([[fx, 0, w / 2], [0, fy, h / 2], [0, 0, 1]]), device).unsqueeze(0)

        with torch.no_grad():
            colors, _, _ = gsplat.rasterization(
                means=means, quats=quats, scales=scales, opacities=opacities,
                colors=shs, viewmats=viewmat, Ks=K, width=w, height=h,
                sh_degree=3, packed=True, render_mode='RGB',
                backgrounds=bg_color.unsqueeze(0))
        render = colors[0].clamp(0, 1).cpu().numpy()

        nr = _load_rgb(os.path.join(native_render_dir, native_renders[nidx]))
        ng = _load_rgb(os.path.join(native_gt_dir, native_gts[nidx]))

        # Bring the native images to our render size before comparing.
        if nr.shape != render.shape:
            th, tw = render.shape[:2]
            nr = _resize_rgb(nr, tw, th)
            ng = _resize_rgb(ng, tw, th)

        vs_psnr = _psnr(render, nr)
        ours_p = _psnr(render, ng)
        nat_p = _psnr(nr, ng)

        should_dump = vs_psnr < THRESHOLD
        mark = "DUMP" if should_dump else "skip"
        print(f" {i:>3} {c.image_name:>12} vs={vs_psnr:6.2f} ours={ours_p:6.2f} nat={nat_p:6.2f} {mark}")

        if should_dump:
            base = os.path.splitext(c.image_name)[0]
            Image.fromarray(to_u8(ng)).save(os.path.join(dump, f"{base}_1_gt.png"))
            Image.fromarray(to_u8(render)).save(os.path.join(dump, f"{base}_2_ours.png"))
            Image.fromarray(to_u8(nr)).save(os.path.join(dump, f"{base}_3_native.png"))
            abs_diff = np.abs(render - nr)
            diff5 = np.clip(abs_diff * 5, 0, 1)
            diff20 = np.clip(abs_diff * 20, 0, 1)
            Image.fromarray(to_u8(diff5)).save(os.path.join(dump, f"{base}_4_diff5x.png"))
            Image.fromarray(to_u8(diff20)).save(os.path.join(dump, f"{base}_5_diff20x.png"))

            panel = hstack_panels([ng, render, nr, diff5])
            Image.fromarray(panel).save(os.path.join(dump, f"{base}_panel.png"))
            summary.append((i, c.image_name, vs_psnr, ours_p, nat_p))

    if unmatched:
        print(f"WARNING: {unmatched} test cams had no matching native GT and were not compared")

    sp = os.path.join(dump, "summary.txt")
    # The summary contains non-ASCII text, so the encoding is pinned instead
    # of relying on the locale default.
    with open(sp, "w", encoding="utf-8") as f:
        f.write(f"Scene: {SCENE}\n")
        f.write(f"Threshold: render_vs_render < {THRESHOLD} dB\n")
        f.write(f"Dumped : {len(summary)} cams\n\n")
        f.write(f"{'idx':>3} {'image_name':>14} {'vs_native':>10} {'ours_psnr':>10} {'nat_psnr':>9}\n")
        f.write("-" * 60 + "\n")
        # Worst agreement first.
        for s in sorted(summary, key=lambda x: x[2]):
            f.write(f"{s[0]:>3} {s[1]:>14} {s[2]:>10.2f} {s[3]:>10.2f} {s[4]:>9.2f}\n")
        f.write("\nPer-cam files:\n")
        f.write(" <name>_1_gt.png GT (Native 存盘)\n")
        f.write(" <name>_2_ours.png gsplat render\n")
        f.write(" <name>_3_native.png diff-gaussian-rasterization render\n")
        f.write(" <name>_4_diff5x.png |ours-native| * 5 (彩色)\n")
        f.write(" <name>_5_diff20x.png |ours-native| * 20 (极端放大)\n")
        f.write(" <name>_panel.png GT | OURS | NATIVE | DIFF*5 横排\n")

    print(f"\nDumped {len(summary)} cams to:")
    print(f" {dump}/")
    print(f"Summary: {sp}")
    print(f"\n拉回本地查看:")
    print(f" scp -P <port> -r root@<host>:{dump} ./tnt_{SCENE}_diag")


if __name__ == "__main__":
    main()
|
|