| |
| """Round-6 — rasterize_mode='antialiased' should kill the white haze.""" |
| import os, sys, inspect, math, json |
| import numpy as np |
| import torch |
| from PIL import Image, ImageOps |
| from plyfile import PlyData |
|
|
| sys.path.insert(0, '/root/autodl-tmp/3dgsAtlas_official') |
| import gsplat |
|
|
# Scene datasets ("tnt" — presumably Tanks and Temples), one folder per scene.
DATASET_ROOT = "/root/autodl-tmp/dataset/tnt"
# Trained models are read from OUTPUT_ROOT/vanilla_3dgs_<scene>.
OUTPUT_ROOT = "/root/autodl-tmp/SplatAtlas/outputs"
# Diagnostic images are written below this directory.
DUMP_ROOT = "/root/autodl-tmp/SplatAtlas/scripts/phase1_validation/diag_output"
# (scene, label) pairs to probe; the label is only echoed in printed headers.
SCENES_TEST = [
    ("truck", "PASS"),
    ("lighthouse", "FAIL"),
]
|
|
|
|
def sec(t):
    """Print ``t`` as a section banner framed above and below by 70 '=' characters."""
    rule = "=" * 70
    print(f"\n{rule}\n {t}\n{rule}")
|
|
|
|
def build_args(source_path, img_dir, reader=None):
    """Build a positional argument list for a COLMAP scene reader.

    The reader's parameter list is introspected rather than hard-coded. The
    first two parameters receive ``source_path`` and ``img_dir``, a parameter
    named ``eval`` gets ``True``, and one named ``train_test_exp`` gets
    ``False``. Every other parameter gets its declared default, or ``""`` when
    it has none.

    Args:
        source_path: Scene root, passed as the first positional argument.
        img_dir: Image sub-directory, passed as the second positional argument.
        reader: Callable to introspect. Defaults to
            ``scene.dataset_readers.readColmapSceneInfo``, which is imported lazily.

    Returns:
        A list with one entry per parameter of ``reader``, in declaration order.
    """
    if reader is None:
        from scene.dataset_readers import readColmapSceneInfo as reader

    args = []
    params = inspect.signature(reader).parameters.items()
    for i, (name, param) in enumerate(params):
        if i == 0:
            args.append(source_path)
        elif i == 1:
            args.append(img_dir)
        elif name == "eval":
            args.append(True)
        elif name == "train_test_exp":
            args.append(False)
        elif param.default is not inspect.Parameter.empty:
            # Identity check on the sentinel: `!=` would invoke the default's own
            # __ne__, which is ambiguous for array-valued defaults.
            args.append(param.default)
        else:
            args.append("")
    return args
|
|
|
|
def load_ply_tensors(ply_path, device):
    """Load a 3DGS point-cloud PLY into activated float32 tensors on ``device``.

    Log-scales go through ``exp`` and opacity logits through ``sigmoid``. The 45
    ``f_rest_*`` columns are read as an (N, 3, 15) block, transposed to
    (N, 15, 3), and appended after the DC term.

    Returns:
        (means (N, 3), quats (N, 4), scales (N, 3), opacities (N,), shs (N, 16, 3))
    """
    vert = PlyData.read(ply_path)['vertex']

    def as_f32(arr):
        return torch.tensor(arr, dtype=torch.float32, device=device)

    def stacked(*names):
        return as_f32(np.stack([vert[n] for n in names], axis=-1))

    means = stacked('x', 'y', 'z')
    quats = stacked('rot_0', 'rot_1', 'rot_2', 'rot_3')
    scales = stacked('scale_0', 'scale_1', 'scale_2').exp()
    opacities = as_f32(np.asarray(vert['opacity'])).sigmoid()

    sh_dc = stacked('f_dc_0', 'f_dc_1', 'f_dc_2')[:, None, :]
    sh_rest = stacked(*(f'f_rest_{k}' for k in range(45))).reshape(-1, 3, 15).permute(0, 2, 1)
    shs = torch.cat((sh_dc, sh_rest), dim=1)
    return means, quats, scales, opacities, shs
|
|
|
|
def render_one(means, quats, scales, opacities, shs, cam, resolution, bg, mode, device):
    """Render ``cam`` with gsplat at 1/``resolution`` of the camera's size.

    The view matrix comes from 3DGS's ``getWorld2View2``. The pinhole
    intrinsics are derived from the camera's FoV, with the principal point at
    the image centre. Rasterization uses SH degree 3, runs without gradients,
    and composites onto ``bg``.

    Args:
        means, quats, scales, opacities, shs: Gaussian parameters, as returned by
            ``load_ply_tensors``.
        cam: Camera record from ``readColmapSceneInfo``. Reads width, height, R,
            T, FovX and FovY.
        resolution: Divisor applied to ``cam.width`` and ``cam.height``.
        bg: Background colour tensor (a length-3 tensor in this script).
        mode: gsplat ``rasterize_mode``, e.g. "classic" or "antialiased".
        device: Torch device for the camera tensors.

    Returns:
        (image, w, h): the first rendered image clamped to [0, 1], plus its width
        and height in pixels.
    """
    from utils.graphics_utils import getWorld2View2

    width = int(round(cam.width / resolution))
    height = int(round(cam.height / resolution))
    focal_x = width / (2 * math.tan(cam.FovX / 2))
    focal_y = height / (2 * math.tan(cam.FovY / 2))

    w2c = getWorld2View2(np.array(cam.R), np.array(cam.T))
    viewmats = torch.tensor(w2c, dtype=torch.float32, device=device)[None]
    intrinsics = torch.tensor(
        [[focal_x, 0, width / 2],
         [0, focal_y, height / 2],
         [0, 0, 1]],
        dtype=torch.float32, device=device,
    )[None]

    with torch.no_grad():
        rendered, _alphas, _meta = gsplat.rasterization(
            means=means, quats=quats, scales=scales, opacities=opacities,
            colors=shs, viewmats=viewmats, Ks=intrinsics,
            width=width, height=height,
            sh_degree=3, packed=True, render_mode='RGB',
            backgrounds=bg[None],
            rasterize_mode=mode,
        )
    return rendered[0].clamp(0, 1), width, height
|
|
|
|
def load_gt(source_path, img_dir, cam, w, h, device):
    """Load the ground-truth image for ``cam`` resized to (w, h) on ``device``.

    The file read is ``source_path/img_dir/cam.image_name``. EXIF orientation
    is applied, the image is converted with 3DGS's ``PILtoTorch`` (assumed to
    return a channel-first tensor — confirm in utils.general_utils), a fourth
    (alpha) channel is dropped, and the result is returned channel-last.
    """
    from utils.general_utils import PILtoTorch

    path = os.path.join(source_path, img_dir, cam.image_name)
    upright = ImageOps.exif_transpose(Image.open(path))
    chw = PILtoTorch(upright, (w, h)).to(device)
    rgb = chw[:3] if chw.shape[0] == 4 else chw
    return rgb.permute(1, 2, 0)
|
|
|
|
| |
def _read_native_psnr(metrics_path):
    """Return the test PSNR recorded by the native run, or None if unavailable.

    Accepts either a nested ``{"photometric": {"PSNR": ...}}`` layout or a flat
    ``{"PSNR": ...}`` layout. Returns None if the file is missing or neither key
    is present.
    """
    if not os.path.isfile(metrics_path):
        return None
    with open(metrics_path) as f:
        metrics = json.load(f)
    return metrics.get("photometric", {}).get("PSNR", metrics.get("PSNR"))


def probe_modes(scene, label):
    """PROBE 19: compare gsplat rasterize modes against the native 3DGS baseline.

    Every test camera of ``scene`` is rendered from the iteration-30000
    vanilla_3dgs checkpoint in both "classic" and "antialiased" mode, on a
    black background. Per-image PSNR is computed against the ground truth, and
    the mean PSNR of each mode is compared with the PSNR stored in
    ``metrics_test_iter30000.json``.

    Args:
        scene: Scene name, used under DATASET_ROOT and OUTPUT_ROOT.
        label: Free-form tag (e.g. "PASS"/"FAIL") shown in the section header.

    Returns:
        (antialiased_psnrs, test_cams): an np.ndarray of per-camera PSNRs for
        the antialiased mode, in ``test_cams`` order, plus the test cameras
        themselves. The array is empty when the scene has no test cameras.
    """
    sec(f"PROBE 19 — rasterize_mode comparison [{scene}] [{label}]")
    from scene.dataset_readers import readColmapSceneInfo

    source_path = os.path.join(DATASET_ROOT, scene)
    cell = os.path.join(OUTPUT_ROOT, f"vanilla_3dgs_{scene}")
    img_dir = "images_2"
    resolution = 2
    scene_info = readColmapSceneInfo(*build_args(source_path, img_dir))
    test_cams = scene_info.test_cameras
    if not test_cams:
        # min()/max() below would raise on an empty array.
        print(" No test cameras found; skipping rasterize_mode comparison.")
        return np.array([]), test_cams

    device = torch.device("cuda")
    means, quats, scales, opacities, shs = load_ply_tensors(
        os.path.join(cell, "point_cloud", "iteration_30000", "point_cloud.ply"), device)
    bg = torch.tensor([0., 0., 0.], device=device)

    results = {}
    for mode in ["classic", "antialiased"]:
        psnrs = []
        for c in test_cams:
            render, w, h = render_one(means, quats, scales, opacities, shs,
                                      c, resolution, bg, mode, device)
            gt = load_gt(source_path, img_dir, c, w, h, device)
            mse = ((render - gt) ** 2).mean().item()
            # Floor the MSE so an exact match gives 100 dB instead of a domain error.
            psnrs.append(10 * math.log10(1.0 / max(mse, 1e-10)))
        results[mode] = np.array(psnrs)
        print(f" mode='{mode}': mean PSNR = {results[mode].mean():.4f} dB "
              f"(min={results[mode].min():.2f} max={results[mode].max():.2f})")

    aa_mean = results['antialiased'].mean()
    gain = aa_mean - results['classic'].mean()

    # A missing baseline no longer aborts the run (the original crashed
    # formatting None); only the baseline-dependent lines are skipped.
    native_psnr = _read_native_psnr(os.path.join(cell, "metrics_test_iter30000.json"))
    if native_psnr is not None:
        print(f" Native baseline = {native_psnr:.4f} dB")
        print()
        print(f" Δ classic : {results['classic'].mean() - native_psnr:+.4f} dB")
        print(f" Δ antialiased : {aa_mean - native_psnr:+.4f} dB")
    else:
        print(" Native baseline = n/a (metrics_test_iter30000.json missing or has no PSNR)")

    print(f" Gain from antialiased: {gain:+.4f} dB")

    if native_psnr is not None and abs(aa_mean - native_psnr) < 0.15:
        print(f"\n ======> 完美命中 baseline。antialiased 就是对齐 3DGS 官方的正确模式。")
    elif gain > 0.3:
        print(f"\n ======> antialiased 显著改善但未完全命中。还有次级因素。")
    else:
        print(f"\n ======> antialiased 没帮上忙。换假说。")

    return results['antialiased'], test_cams
|
|
|
|
| |
def _to_u8(a):
    """Convert a [0, 1] float image to uint8; uint8 input is returned unchanged.

    Rounds rather than truncates, so a value read as k/255 maps back to exactly
    k and saved images are not biased down by up to one level.
    """
    if a.dtype == np.uint8:
        return a
    return np.clip(np.rint(a * 255), 0, 255).astype(np.uint8)


def _hstack(arrs, gap=20):
    """Place equal-height HxWx3 images side by side, separated by grey (240) gutters."""
    imgs = [_to_u8(a) for a in arrs]
    height = imgs[0].shape[0]
    width = sum(img.shape[1] for img in imgs) + gap * (len(imgs) - 1)
    out = np.full((height, width, 3), 240, dtype=np.uint8)
    x = 0
    for img in imgs:
        out[:, x:x + img.shape[1]] = img
        x += img.shape[1] + gap
    return out


def _load_rgb(path, size=None):
    """Read an image as float32 RGB in [0, 1], LANCZOS-resized to ``size`` (w, h) if given."""
    with Image.open(path) as src:
        rgb = src.convert('RGB')
        if size is not None:
            rgb = rgb.resize(size, Image.LANCZOS)
        return np.asarray(rgb, dtype=np.float32) / 255.0


def _list_images(directory, exts):
    """Return the sorted file names in ``directory`` whose lower-cased name ends with one of ``exts``."""
    return sorted(f for f in os.listdir(directory) if f.lower().endswith(exts))


def _clear_files(directory):
    """Delete the files and symlinks directly inside ``directory``; sub-directories are kept."""
    for name in os.listdir(directory):
        path = os.path.join(directory, name)
        if os.path.isfile(path) or os.path.islink(path):
            os.remove(path)


def _match_native_to_source(native_gt_dir, native_gts, source_img_dir, thumb=128):
    """Map source image file names to the index of the native GT they best match.

    Each native GT is shrunk to a ``thumb`` x ``thumb`` thumbnail and paired
    with the source image whose thumbnail has the lowest mean squared error.
    If several native GTs pick the same source file, the last one wins.
    """
    src_files = _list_images(source_img_dir, ('.png', '.jpg', '.jpeg'))
    if not src_files:
        return {}
    src_thumbs = np.stack([_load_rgb(os.path.join(source_img_dir, f), size=(thumb, thumb))
                           for f in src_files])
    mapping = {}
    for nidx, name in enumerate(native_gts):
        t = _load_rgb(os.path.join(native_gt_dir, name), size=(thumb, thumb))
        best = int(np.argmin(((src_thumbs - t) ** 2).mean(axis=(1, 2, 3))))
        mapping[src_files[best]] = nidx
    return mapping


def probe_redump(scene, antialiased_psnrs, test_cams):
    """PROBE 20: re-render the worst test cameras in antialiased mode and dump comparisons.

    The existing files in ``DUMP_ROOT/<scene>_antialiased`` are deleted first.
    Then, for each of the (up to) three lowest-PSNR cameras that can be matched
    to a native 3DGS test dump, this function:

    - prints the PSNR of the new render against the native render;
    - writes ``<name>_AA_panel.png`` (native GT | new render | native render |
      5x absolute difference);
    - writes ``<name>_AA_ours.png`` (the new render alone).

    Args:
        scene: Scene name, used under DATASET_ROOT, OUTPUT_ROOT and DUMP_ROOT.
        antialiased_psnrs: Per-camera PSNRs, aligned with ``test_cams``.
        test_cams: Camera records from ``readColmapSceneInfo``.
    """
    sec(f"PROBE 20 — Re-render worst cams with antialiased [{scene}]")
    source_path = os.path.join(DATASET_ROOT, scene)
    cell = os.path.join(OUTPUT_ROOT, f"vanilla_3dgs_{scene}")
    dump = os.path.join(DUMP_ROOT, f"{scene}_antialiased")
    os.makedirs(dump, exist_ok=True)
    _clear_files(dump)

    # Indices into test_cams of the three lowest-PSNR cameras.
    worst = np.argsort(antialiased_psnrs)[:3]
    device = torch.device("cuda")
    means, quats, scales, opacities, shs = load_ply_tensors(
        os.path.join(cell, "point_cloud", "iteration_30000", "point_cloud.ply"), device)
    bg = torch.tensor([0., 0., 0.], device=device)

    native_render_dir = os.path.join(cell, "renders_test_30000")
    if os.path.isdir(os.path.join(native_render_dir, "renders")):
        native_render_dir = os.path.join(native_render_dir, "renders")
    native_gt_dir = os.path.join(cell, "gt_test_30000")
    native_renders = _list_images(native_render_dir, ('.png', '.jpg'))
    native_gts = _list_images(native_gt_dir, ('.png', '.jpg'))
    real_to_native = _match_native_to_source(
        native_gt_dir, native_gts, os.path.join(source_path, "images_2"))

    for i in worst:
        c = test_cams[i]
        # Look up the match before rendering so unmatched cameras cost nothing.
        nidx = real_to_native.get(c.image_name)
        if nidx is None or nidx >= len(native_renders):
            print(f" cam {c.image_name}: no matching native render, skipped")
            continue

        render, w, h = render_one(means, quats, scales, opacities, shs,
                                  c, 2, bg, "antialiased", device)
        render_np = render.cpu().numpy()
        th, tw = render_np.shape[:2]
        # Resize each native image independently to the new render's size
        # (a no-op when the sizes already agree).
        nr = _load_rgb(os.path.join(native_render_dir, native_renders[nidx]), size=(tw, th))
        ng = _load_rgb(os.path.join(native_gt_dir, native_gts[nidx]), size=(tw, th))

        vs_psnr = 10 * math.log10(1.0 / max(((render_np - nr) ** 2).mean(), 1e-10))
        print(f" cam {c.image_name}: new render_vs_native = {vs_psnr:.2f} dB")

        base = os.path.splitext(c.image_name)[0]
        diff5 = np.clip(np.abs(render_np - nr) * 5, 0, 1)
        panel = _hstack([ng, render_np, nr, diff5])
        Image.fromarray(panel).save(os.path.join(dump, f"{base}_AA_panel.png"))
        Image.fromarray(_to_u8(render_np)).save(os.path.join(dump, f"{base}_AA_ours.png"))

    print(f"\n dumped to: {dump}/")
|
|
|
|
def main():
    """Run PROBE 19 then PROBE 20 for every (scene, label) pair in SCENES_TEST."""
    banner = '#' * 70
    for scene, label in SCENES_TEST:
        print(f"\n\n{banner}\n# SCENE: {scene} [{label}]\n{banner}")
        aa_psnrs, cams = probe_modes(scene, label)
        probe_redump(scene, aa_psnrs, cams)


if __name__ == "__main__":
    main()
|
|