KCBtheone's picture
Upload SplatAtlas benchmark pipeline code
23e73f9 verified
Raw
History Blame Contribute Delete
9.9 kB
#!/usr/bin/env python
"""Round-6 — rasterize_mode='antialiased' should kill the white haze."""
import os, sys, inspect, math, json
import numpy as np
import torch
from PIL import Image, ImageOps
from plyfile import PlyData
sys.path.insert(0, '/root/autodl-tmp/3dgsAtlas_official')
import gsplat
# Root directory holding one sub-directory per scene (joined with the scene name below).
DATASET_ROOT = "/root/autodl-tmp/dataset/tnt"
# Root of training outputs; each scene's run is expected under "vanilla_3dgs_<scene>".
OUTPUT_ROOT = "/root/autodl-tmp/SplatAtlas/outputs"
# Where diagnostic image panels are written ("<scene>_antialiased" sub-directories).
DUMP_ROOT = "/root/autodl-tmp/SplatAtlas/scripts/phase1_validation/diag_output"
# (scene, label) pairs iterated by main(); the label is only echoed in the printed banners.
SCENES_TEST = [("truck", "PASS"), ("lighthouse", "FAIL")]
def sec(t):
    """Print a section banner: a blank line, a 70-char '=' rule, the title, a rule."""
    rule = "=" * 70
    for line in ("\n" + rule, f" {t}", rule):
        print(line)
def build_args(source_path, img_dir):
    """Build a positional argument list for ``readColmapSceneInfo``.

    The reader's signature is inspected at call time so the same script works
    with reader variants that take a different number of parameters.

    Args:
        source_path: value passed as the reader's first positional argument.
        img_dir: value passed as the reader's second positional argument.

    Returns:
        A list with one entry per parameter of ``readColmapSceneInfo``:
        ``source_path``, ``img_dir``, then ``True`` for a parameter named
        ``eval``, ``False`` for ``train_test_exp``, and for every other
        parameter its declared default (or ``""`` when it has none).
    """
    from scene.dataset_readers import readColmapSceneInfo

    parameters = inspect.signature(readColmapSceneInfo).parameters
    args = []
    for position, (name, param) in enumerate(parameters.items()):
        if position == 0:
            args.append(source_path)
        elif position == 1:
            args.append(img_dir)
        elif name == "eval":
            args.append(True)
        elif name == "train_test_exp":
            args.append(False)
        elif param.default is not inspect.Parameter.empty:
            # Identity check: ``Parameter.empty`` is a sentinel, and ``!=``
            # would invoke the default value's own ``__ne__`` (which may not
            # return a plain bool, e.g. for array-like defaults).
            args.append(param.default)
        else:
            # No declared default: fall back to an empty string placeholder.
            args.append("")
    return args
def load_ply_tensors(ply_path, device):
    """Read a Gaussian-splat PLY and return its attributes as float32 tensors.

    Args:
        ply_path: path of the PLY file; its ``vertex`` element is read.
        device: torch device the tensors are created on.

    Returns:
        ``(means, quats, scales, opacities, shs)`` where

        * ``means`` stacks ``x, y, z``,
        * ``quats`` stacks ``rot_0..rot_3``,
        * ``scales`` is ``exp`` of the stacked ``scale_0..scale_2``,
        * ``opacities`` is ``sigmoid`` of ``opacity``,
        * ``shs`` concatenates the DC term (1 coefficient) with the 45
          ``f_rest_*`` values regrouped as 15 coefficients x 3 channels,
          along dim 1.
    """
    vertex = PlyData.read(ply_path)['vertex']

    def as_tensor(array):
        return torch.tensor(array, dtype=torch.float32, device=device)

    def columns(*names):
        # Stack the named per-vertex properties along a new last axis.
        return as_tensor(np.stack([vertex[name] for name in names], -1))

    means = columns('x', 'y', 'z')
    quats = columns('rot_0', 'rot_1', 'rot_2', 'rot_3')
    # Scales are stored as logs and opacities as logits in the file.
    scales = torch.exp(columns('scale_0', 'scale_1', 'scale_2'))
    opacities = torch.sigmoid(as_tensor(np.asarray(vertex['opacity'])))

    sh_dc = columns('f_dc_0', 'f_dc_1', 'f_dc_2').unsqueeze(1)
    rest_names = [f'f_rest_{i}' for i in range(45)]
    # The 45 values are laid out channel-major (3 x 15); swap to (15 x 3)
    # so they line up with the DC term's trailing channel axis.
    sh_rest = columns(*rest_names).view(-1, 3, 15).transpose(1, 2)
    shs = torch.cat([sh_dc, sh_rest], dim=1)

    return means, quats, scales, opacities, shs
def render_one(means, quats, scales, opacities, shs, cam, resolution, bg, mode, device):
    """Rasterize the Gaussians from one camera with ``gsplat.rasterization``.

    Args:
        means, quats, scales, opacities, shs: Gaussian tensors, forwarded
            unchanged to gsplat (``shs`` is passed as ``colors`` with
            ``sh_degree=3``).
        cam: camera record; ``width``, ``height``, ``FovX``, ``FovY``, ``R``
            and ``T`` are read from it.
        resolution: divisor applied to the camera's width and height.
        bg: background colour tensor; a leading batch axis is added here.
        mode: value forwarded as gsplat's ``rasterize_mode``.
        device: torch device for the view and intrinsics matrices.

    Returns:
        ``(image, w, h)``: the first rendered image clamped to ``[0, 1]``,
        and the integer render width and height.
    """
    from utils.graphics_utils import getWorld2View2

    width = int(round(cam.width / resolution))
    height = int(round(cam.height / resolution))

    # Pinhole intrinsics derived from the fields of view at the render size;
    # the principal point is placed at the image centre.
    focal_x = width / (2 * math.tan(cam.FovX / 2))
    focal_y = height / (2 * math.tan(cam.FovY / 2))
    intrinsics = torch.tensor(
        [[focal_x, 0, width / 2], [0, focal_y, height / 2], [0, 0, 1]],
        dtype=torch.float32, device=device,
    ).unsqueeze(0)

    world_to_view = getWorld2View2(np.array(cam.R), np.array(cam.T))
    view_matrix = torch.tensor(
        world_to_view, dtype=torch.float32, device=device,
    ).unsqueeze(0)

    with torch.no_grad():
        rendered, _alphas, _meta = gsplat.rasterization(
            means=means, quats=quats, scales=scales, opacities=opacities,
            colors=shs, viewmats=view_matrix, Ks=intrinsics,
            width=width, height=height,
            sh_degree=3, packed=True, render_mode='RGB',
            backgrounds=bg.unsqueeze(0),
            rasterize_mode=mode)

    return rendered[0].clamp(0, 1), width, height
def load_gt(source_path, img_dir, cam, w, h, device):
    """Load the ground-truth image for ``cam`` as an (H, W, 3)-ordered tensor.

    Args:
        source_path: scene root directory.
        img_dir: image sub-directory under ``source_path``.
        cam: camera record; ``cam.image_name`` is used as the file name.
        w, h: target size, passed to ``PILtoTorch`` as ``(w, h)``.
        device: torch device the tensor is moved to.

    Returns:
        The image tensor with its first axis moved last. If the first axis
        has 4 entries, only the first 3 are kept.
        (Presumably ``PILtoTorch`` returns channels-first data in [0, 1] —
        TODO confirm against utils.general_utils.)
    """
    from utils.general_utils import PILtoTorch

    image_path = os.path.join(source_path, img_dir, cam.image_name)
    # ``exif_transpose`` returns a new, fully loaded image, so the file can be
    # closed deterministically as soon as it has been applied.
    with Image.open(image_path) as raw:
        upright = ImageOps.exif_transpose(raw)

    gt_chw = PILtoTorch(upright, (w, h)).to(device)
    if gt_chw.shape[0] == 4:
        # Drop the fourth (alpha) plane.
        gt_chw = gt_chw[:3]
    return gt_chw.permute(1, 2, 0)
# ---------- Probe 19: compare 'classic' vs 'antialiased' on all test cams ----------
def probe_modes(scene, label):
    """Probe 19: compare gsplat 'classic' vs 'antialiased' PSNR on all test cams.

    Renders every test camera of ``scene`` in both rasterize modes from the
    iteration-30000 point cloud, computes per-camera PSNR against the
    ``images_2`` ground truth, and prints both means next to the stored
    native baseline read from ``metrics_test_iter30000.json``.

    Args:
        scene: scene name (sub-directory of ``DATASET_ROOT``).
        label: free-form tag, only echoed in the section banner.

    Returns:
        ``(antialiased_psnrs, test_cams)``: a NumPy array of per-camera PSNR
        values for the 'antialiased' mode, and the test-camera list.

    Raises:
        ValueError: if the scene has no test cameras.
        KeyError: if the baseline JSON contains no PSNR entry.
    """
    sec(f"PROBE 19 — rasterize_mode comparison [{scene}] [{label}]")
    from scene.dataset_readers import readColmapSceneInfo

    source_path = os.path.join(DATASET_ROOT, scene)
    cell = os.path.join(OUTPUT_ROOT, f"vanilla_3dgs_{scene}")
    img_dir = "images_2"
    resolution = 2

    scene_info = readColmapSceneInfo(*build_args(source_path, img_dir))
    test_cams = scene_info.test_cameras
    if len(test_cams) == 0:
        # Fail early with a clear message instead of crashing inside .min().
        raise ValueError(f"scene '{scene}' has no test cameras")

    device = torch.device("cuda")
    means, quats, scales, opacities, shs = load_ply_tensors(
        os.path.join(cell, "point_cloud", "iteration_30000", "point_cloud.ply"), device)
    bg = torch.tensor([0., 0., 0.], device=device)

    results = {}
    for mode in ["classic", "antialiased"]:
        psnrs = []
        for c in test_cams:
            render, w, h = render_one(means, quats, scales, opacities, shs,
                                      c, resolution, bg, mode, device)
            gt = load_gt(source_path, img_dir, c, w, h, device)
            mse = ((render - gt) ** 2).mean().item()
            # Floor the MSE so a pixel-perfect render does not divide by zero.
            psnrs.append(10 * math.log10(1.0 / max(mse, 1e-10)))
        results[mode] = np.array(psnrs)
        print(f" mode='{mode}': mean PSNR = {results[mode].mean():.4f} dB "
              f"(min={results[mode].min():.2f} max={results[mode].max():.2f})")

    # Native baseline
    baseline_path = os.path.join(cell, "metrics_test_iter30000.json")
    with open(baseline_path) as fh:
        bd = json.load(fh)
    # Accept both the nested {"photometric": {"PSNR": ...}} and the flat layout.
    native_psnr = bd.get("photometric", {}).get("PSNR", bd.get("PSNR"))
    if native_psnr is None:
        raise KeyError(f"no PSNR entry found in {baseline_path}")

    mean_classic = results['classic'].mean()
    mean_aa = results['antialiased'].mean()
    gain = mean_aa - mean_classic

    print(f" Native baseline = {native_psnr:.4f} dB")
    print("")
    print(f" Δ classic : {mean_classic - native_psnr:+.4f} dB")
    print(f" Δ antialiased : {mean_aa - native_psnr:+.4f} dB")
    print(f" Gain from antialiased: {gain:+.4f} dB")

    # Verdict: within 0.15 dB of the baseline counts as a match; otherwise a
    # gain above 0.3 dB counts as a partial explanation.
    if abs(mean_aa - native_psnr) < 0.15:
        print(f"\n ======> 完美命中 baseline。antialiased 就是对齐 3DGS 官方的正确模式。")
    elif gain > 0.3:
        print(f"\n ======> antialiased 显著改善但未完全命中。还有次级因素。")
    else:
        print(f"\n ======> antialiased 没帮上忙。换假说。")
    return results['antialiased'], test_cams
# ---------- Probe 20: re-dump worst cams under antialiased mode ----------
def _to_u8(a):
    """Quantize a float image in [0, 1] to uint8 (round-to-nearest); uint8 passes through."""
    if a.dtype == np.uint8:
        return a
    # Round rather than truncate so that k/255 maps back to exactly k.
    return np.clip(np.rint(a * 255), 0, 255).astype(np.uint8)


def _hstack(arrs, gap=20):
    """Place equal-height H x W x 3 images side by side on a grey (240) canvas."""
    imgs = [_to_u8(a) for a in arrs]
    height = imgs[0].shape[0]
    width = sum(img.shape[1] for img in imgs) + gap * (len(imgs) - 1)
    out = np.full((height, width, 3), 240, dtype=np.uint8)
    x = 0
    for img in imgs:
        out[:, x:x + img.shape[1]] = img
        x += img.shape[1] + gap
    return out


def _load_rgb01(path, size=None):
    """Read an image as float32 RGB in [0, 1], optionally LANCZOS-resized to ``size`` (w, h)."""
    with Image.open(path) as im:
        rgb = im.convert('RGB')
    if size is not None:
        rgb = rgb.resize(size, Image.LANCZOS)
    return np.asarray(rgb, dtype=np.float32) / 255.0


def _resize01(img, width, height):
    """LANCZOS-resize a float [0, 1] image to (width, height) via a uint8 roundtrip."""
    resized = Image.fromarray(_to_u8(img)).resize((width, height), Image.LANCZOS)
    return np.asarray(resized, dtype=np.float32) / 255.0


def probe_redump(scene, antialiased_psnrs, test_cams):
    """Probe 20: re-render the 3 worst test cams in 'antialiased' mode and dump panels.

    For each of the three cameras with the lowest PSNR, renders the view,
    pairs it with the natively rendered image and native ground truth of the
    same view, prints the PSNR between the two renders, and writes
    ``<name>_AA_panel.png`` (GT | ours | native | 5x abs diff) and
    ``<name>_AA_ours.png`` into ``DUMP_ROOT/<scene>_antialiased``.

    Native ground-truth files are matched to ``images_2`` files by the
    smallest mean squared difference between 128x128 thumbnails.

    Args:
        scene: scene name (sub-directory of ``DATASET_ROOT``).
        antialiased_psnrs: per-camera PSNR values, index-aligned with
            ``test_cams``.
        test_cams: camera records; ``image_name`` is used for matching.
    """
    sec(f"PROBE 20 — Re-render worst cams with antialiased [{scene}]")
    source_path = os.path.join(DATASET_ROOT, scene)
    cell = os.path.join(OUTPUT_ROOT, f"vanilla_3dgs_{scene}")
    dump = os.path.join(DUMP_ROOT, f"{scene}_antialiased")
    os.makedirs(dump, exist_ok=True)
    # Clear stale outputs; leave sub-directories alone (os.remove cannot delete them).
    for f in os.listdir(dump):
        stale = os.path.join(dump, f)
        if os.path.isfile(stale) or os.path.islink(stale):
            os.remove(stale)

    # worst 3 cams
    idx_sorted = np.argsort(antialiased_psnrs)[:3]

    device = torch.device("cuda")
    means, quats, scales, opacities, shs = load_ply_tensors(
        os.path.join(cell, "point_cloud", "iteration_30000", "point_cloud.ply"), device)
    bg = torch.tensor([0., 0., 0.], device=device)

    native_render_dir = os.path.join(cell, "renders_test_30000")
    if os.path.isdir(os.path.join(native_render_dir, "renders")):
        native_render_dir = os.path.join(native_render_dir, "renders")
    native_gt_dir = os.path.join(cell, "gt_test_30000")
    native_renders = sorted(f for f in os.listdir(native_render_dir)
                            if f.lower().endswith(('.png', '.jpg')))
    native_gts = sorted(f for f in os.listdir(native_gt_dir)
                        if f.lower().endswith(('.png', '.jpg')))

    # Map each source image name to the index of the native GT it matches best.
    THUMB = 128
    i2_dir = os.path.join(source_path, "images_2")
    i2_files = sorted(f for f in os.listdir(i2_dir)
                      if f.lower().endswith(('.png', '.jpg', '.jpeg')))
    i2_thumbs = np.stack([_load_rgb01(os.path.join(i2_dir, f), (THUMB, THUMB))
                          for f in i2_files])
    real_to_native = {}
    for nidx, gt_name in enumerate(native_gts):
        t = _load_rgb01(os.path.join(native_gt_dir, gt_name), (THUMB, THUMB))
        best = int(np.argmin(((i2_thumbs - t) ** 2).mean(axis=(1, 2, 3))))
        real_to_native[i2_files[best]] = nidx

    for i in idx_sorted:
        c = test_cams[i]
        # Check for a native counterpart before spending a render on this cam.
        if c.image_name not in real_to_native:
            print(f" cam {c.image_name}: no matching native GT, skipped")
            continue
        nidx = real_to_native[c.image_name]

        render, w, h = render_one(means, quats, scales, opacities, shs,
                                  c, 2, bg, "antialiased", device)
        render_np = render.cpu().numpy()

        nr = _load_rgb01(os.path.join(native_render_dir, native_renders[nidx]))
        ng = _load_rgb01(os.path.join(native_gt_dir, native_gts[nidx]))
        # Bring each native image to the render size independently, so a
        # mismatch in only one of them cannot break the panel layout.
        th, tw = render_np.shape[:2]
        if nr.shape != render_np.shape:
            nr = _resize01(nr, tw, th)
        if ng.shape != render_np.shape:
            ng = _resize01(ng, tw, th)

        vs_psnr = 10 * math.log10(1.0 / max(((render_np - nr) ** 2).mean(), 1e-10))
        print(f" cam {c.image_name}: new render_vs_native = {vs_psnr:.2f} dB")

        base = os.path.splitext(c.image_name)[0]
        # Amplify the absolute difference 5x to make small deviations visible.
        diff5 = np.clip(np.abs(render_np - nr) * 5, 0, 1)
        panel = _hstack([ng, render_np, nr, diff5])
        Image.fromarray(panel).save(os.path.join(dump, f"{base}_AA_panel.png"))
        Image.fromarray(_to_u8(render_np)).save(os.path.join(dump, f"{base}_AA_ours.png"))
    print(f"\n dumped to: {dump}/")
def main():
    """Run probe 19 and then probe 20 for every (scene, label) in SCENES_TEST."""
    rule = '#' * 70
    for scene, label in SCENES_TEST:
        print(f"\n\n{rule}\n# SCENE: {scene} [{label}]\n{rule}")
        # Probe 19 yields the per-camera PSNRs that probe 20 ranks by.
        aa_psnrs, test_cams = probe_modes(scene, label)
        probe_redump(scene, aa_psnrs, test_cams)


if __name__ == "__main__":
    main()