KCBtheone's picture
Upload SplatAtlas benchmark pipeline code
23e73f9 verified
Raw
History Blame Contribute Delete
8.92 kB
#!/usr/bin/env python
"""Round-5 — exposure compensation hypothesis (decisive)."""
import os, sys, json, math, inspect
import numpy as np
import torch
from PIL import Image, ImageOps
from plyfile import PlyData
sys.path.insert(0, '/root/autodl-tmp/3dgsAtlas_official')
import gsplat
# Filesystem layout used by the probes. Each root can be overridden through an
# environment variable so the script can run on another machine; the defaults
# are the original hard-coded locations.
# Dataset root; a scene's COLMAP inputs are read from DATASET_ROOT/<scene>.
DATASET_ROOT = os.environ.get("SPLATATLAS_DATASET_ROOT", "/root/autodl-tmp/dataset/tnt")
# Training outputs; a scene's run is read from OUTPUT_ROOT/vanilla_3dgs_<scene>.
OUTPUT_ROOT = os.environ.get("SPLATATLAS_OUTPUT_ROOT", "/root/autodl-tmp/SplatAtlas/outputs")
# (scene name, label) pairs to probe; the label is only echoed in the scene banner.
SCENES = [("truck", "PASS"), ("lighthouse", "FAIL")]
def sec(t):
    """Print *t* as a section title framed by two 70-character '=' rules.

    A blank line is emitted before the first rule, and the title is indented
    by one space.
    """
    rule = "=" * 70
    print(f"\n{rule}\n {t}\n{rule}")
def build_args(source_path, img_dir, reader=None):
    """Build a positional argument list for ``readColmapSceneInfo``.

    The reader's signature is introspected so the probe adapts to forks whose
    parameter lists differ. Filling rules, in signature order:

    * parameter 0 receives *source_path*, parameter 1 receives *img_dir*;
    * a parameter named ``eval`` receives ``True``;
    * a parameter named ``train_test_exp`` receives ``False``;
    * any other parameter receives its declared default, or ``""`` if it
      has none.

    Collection stops at the first ``*args``, keyword-only or ``**kwargs``
    parameter. The caller unpacks the result with ``*``, so such parameters
    cannot be filled positionally.

    Args:
        source_path: Scene directory, passed as the first argument.
        img_dir: Image sub-folder name, passed as the second argument.
        reader: Callable whose signature is introspected. Defaults to
            ``scene.dataset_readers.readColmapSceneInfo``, imported inside
            the function as before.

    Returns:
        list: Positional arguments in signature order.
    """
    if reader is None:
        from scene.dataset_readers import readColmapSceneInfo as reader
    positional_kinds = (inspect.Parameter.POSITIONAL_ONLY,
                        inspect.Parameter.POSITIONAL_OR_KEYWORD)
    args = []
    for i, (name, param) in enumerate(inspect.signature(reader).parameters.items()):
        if param.kind not in positional_kinds:
            break
        if i == 0:
            args.append(source_path)
        elif i == 1:
            args.append(img_dir)
        elif name == "eval":
            args.append(True)
        elif name == "train_test_exp":
            args.append(False)
        elif param.default is not inspect.Parameter.empty:
            # Identity check: ``empty`` is a sentinel, and ``!=`` is ambiguous
            # for array-like defaults.
            args.append(param.default)
        else:
            # No default available: a placeholder keeps later positions aligned.
            args.append("")
    return args
# ---------- Probe 17: exposure.json structure ----------
def probe_exposure_structure(scene):
    """Probe 17: summarize the per-image exposure matrices of a vanilla 3DGS run.

    Reads ``OUTPUT_ROOT/vanilla_3dgs_<scene>/exposure.json``, expected to be a
    JSON object whose values are exposure matrices (keys are presumably image
    names; Probe 18 matches them against camera image names). When the
    matrices stack to shape (N, 3, K) with K >= 3, the left 3x3 block is
    compared against the identity. When K == 4, column 3 is treated as an
    additive bias. Statistics and a verdict are printed.

    Args:
        scene: Scene name, used to locate the run directory.

    Returns:
        The parsed dict, or None when the file is missing, is not a JSON
        object, or is empty.
    """
    sec(f"PROBE 17 — exposure.json structure [{scene}]")
    p = os.path.join(OUTPUT_ROOT, f"vanilla_3dgs_{scene}", "exposure.json")
    if not os.path.exists(p):
        print(f"[!] not found: {p}")
        return None
    with open(p) as fh:
        data = json.load(fh)
    if not isinstance(data, dict):
        print(f"Not a dict: {type(data)}")
        return None
    keys = list(data.keys())
    print(f"Entries : {len(keys)}")
    if not keys:
        # Nothing to analyse; callers treat None as "no exposure data".
        print("[!] exposure.json is empty, nothing to analyse")
        return None
    print(f"Sample keys : {keys[:5]}")
    sample = np.asarray(data[keys[0]])
    print(f"Sample shape: {sample.shape}")
    print(f"Sample[0]:\n{sample}")
    try:
        all_m = np.stack([np.asarray(data[k]) for k in keys])
    except ValueError as err:
        # np.stack needs identical shapes; report instead of aborting the run.
        print(f"[!] exposure entries have mismatched shapes: {err}")
        return data
    print(f"\nAll exposures shape: {all_m.shape}")
    if not (all_m.ndim == 3 and all_m.shape[1] == 3 and all_m.shape[2] >= 3):
        print("[!] expected shape (N, 3, K) with K >= 3, skipping deviation stats")
        return data

    n = len(keys)
    linear = all_m[:, :3, :3]
    # Column 3 is read as a bias only for 3x4 matrices; otherwise bias is zero.
    bias = all_m[:, :3, 3] if all_m.shape[2] == 4 else np.zeros((n, 3))
    # Largest absolute element-wise deviation of each linear block from I.
    lin_dev = np.abs(linear - np.eye(3)[None]).reshape(n, -1).max(axis=1)
    bias_mag = np.linalg.norm(bias, axis=1)
    print("\nLinear max |dev from I|:")
    print(f" min={lin_dev.min():.5f} p50={np.percentile(lin_dev,50):.5f} "
          f"p95={np.percentile(lin_dev,95):.5f} max={lin_dev.max():.5f}")
    print("Bias magnitude ||b||:")
    print(f" min={bias_mag.min():.5f} p50={np.percentile(bias_mag,50):.5f} "
          f"p95={np.percentile(bias_mag,95):.5f} max={bias_mag.max():.5f}")

    trivial = ((lin_dev < 0.01) & (bias_mag < 0.01)).sum()
    large = ((lin_dev >= 0.05) | (bias_mag >= 0.05)).sum()
    print(f"\nNear-identity (lin_dev<0.01 & bias<0.01): {trivial} / {n}")
    print(f"Large deviation (>=0.05) : {large} / {n}")
    if trivial > n * 0.95:
        # "Exposures are almost all identity -> exposure compensation is not the root cause"
        print("[VERDICT] 曝光几乎都是单位阵 → 曝光补偿不是根因")
    elif large > 0:
        # "Significant exposure offsets exist -> exposure compensation is very likely the root cause"
        print("[VERDICT] 存在显著曝光偏移 → 曝光补偿极可能是根因")
    else:
        # Remaining case: more than 5% non-trivial, but every deviation < 0.05.
        print("[VERDICT] inconclusive: small but non-trivial exposure deviations (all < 0.05)")
    return data
# ---------- Probe 18: recompute PSNR after applying exposure ----------
def _to_f32(x, device):
    """Convert array-like *x* to a float32 tensor on *device*."""
    return torch.tensor(x, dtype=torch.float32, device=device)


def _psnr(pred, target):
    """Return PSNR in dB for images in [0, 1]; MSE is floored at 1e-10 to avoid log(0)."""
    mse = ((pred - target) ** 2).mean().item()
    return 10 * math.log10(1.0 / max(mse, 1e-10))


def _load_gaussians(ply_path, device):
    """Load a 3DGS ``point_cloud.ply`` and return activated gsplat inputs.

    Returns ``(means, quats, scales, opacities, shs)``:
    * scales are ``exp``-activated;
    * opacities are ``sigmoid``-activated;
    * ``shs`` has shape (N, 16, 3), the DC term followed by the 45 ``f_rest``
      coefficients regrouped as 15 x 3.
    """
    v = PlyData.read(ply_path)['vertex']
    means = _to_f32(np.stack((v['x'], v['y'], v['z']), -1), device)
    quats = _to_f32(np.stack((v['rot_0'], v['rot_1'], v['rot_2'], v['rot_3']), -1), device)
    scales = torch.exp(_to_f32(np.stack((v['scale_0'], v['scale_1'], v['scale_2']), -1), device))
    opacities = torch.sigmoid(_to_f32(np.asarray(v['opacity']), device))
    f_dc = _to_f32(np.stack((v['f_dc_0'], v['f_dc_1'], v['f_dc_2']), -1), device).unsqueeze(1)
    # f_rest is laid out channel-major (3 x 15); transpose to (15, 3) per Gaussian.
    f_rest = (_to_f32(np.stack([v[f'f_rest_{i}'] for i in range(45)], -1), device)
              .view(-1, 3, 15).transpose(1, 2))
    shs = torch.cat([f_dc, f_rest], dim=1)
    return means, quats, scales, opacities, shs


def _apply_exposure(rgb, exposure):
    """Apply a per-image exposure matrix to an (H, W, 3) image.

    ``exposure`` is (3, K). The left 3x3 block ``A`` is a linear colour
    transform. When K >= 4, column 3 is an additive bias; otherwise the bias
    is zero.

    The image is right-multiplied by ``A`` (not ``A.T``). This mirrors the
    official gaussian-splatting renderer, which computes
    ``torch.matmul(image_hwc, exposure[:3, :3]) + exposure[:3, 3]`` before its
    final ``clamp(0, 1)``.
    NOTE(review): confirm against the renderer revision that produced the
    baseline metrics.
    """
    linear = exposure[:3, :3]
    if exposure.shape[1] >= 4:
        bias = exposure[:3, 3]
    else:
        bias = torch.zeros(3, dtype=rgb.dtype, device=rgb.device)
    return rgb @ linear + bias


def probe_exposure_corrected(scene, exposures):
    """Probe 18: re-render test views and compare PSNR with and without exposure.

    Steps:
    1. Load the iteration-30000 Gaussians of ``vanilla_3dgs_<scene>``.
    2. Render every test camera of the ``images_2`` split with gsplat.
    3. Report per-image and mean PSNR against ground truth, both raw and after
       applying the matching entry of *exposures*.
    4. Compare the means with the run's ``metrics_test_iter30000.json`` PSNR
       and print a verdict.

    Args:
        scene: Scene name, used to locate dataset and output directories.
        exposures: Mapping of image name -> exposure matrix (as returned by
            ``probe_exposure_structure``), or None to skip the probe.
    """
    sec(f"PROBE 18 — Apply exposure to render & recompute PSNR [{scene}]")
    if exposures is None:
        print("[!] no exposure data, skip")
        return
    from scene.dataset_readers import readColmapSceneInfo
    from utils.graphics_utils import getWorld2View2
    from utils.general_utils import PILtoTorch

    source_path = os.path.join(DATASET_ROOT, scene)
    cell = os.path.join(OUTPUT_ROOT, f"vanilla_3dgs_{scene}")
    img_dir = "images_2"
    # Camera width/height are divided by this to match img_dir.
    # Assumes the camera intrinsics are at full resolution; TODO confirm.
    resolution = 2
    scene_info = readColmapSceneInfo(*build_args(source_path, img_dir))
    test_cams = scene_info.test_cameras

    device = torch.device("cuda")
    means, quats, scales, opacities, shs = _load_gaussians(
        os.path.join(cell, "point_cloud", "iteration_30000", "point_cloud.ply"), device)
    bg_color = torch.tensor([0., 0., 0.], device=device)

    def get_exp(image_name):
        """Look up *image_name* under several key spellings; return a tensor or None."""
        # exposure.json keys may or may not carry the extension / original case.
        candidates = [image_name,
                      os.path.splitext(image_name)[0],
                      image_name.lower(),
                      os.path.splitext(image_name)[0].lower()]
        for key in candidates:
            if key in exposures:
                return _to_f32(np.asarray(exposures[key]), device)
        return None

    # Resolve every test camera's exposure once; reused by the report and the loop.
    cam_exposures = {c.image_name: get_exp(c.image_name) for c in test_cams}
    matched = sum(1 for c in test_cams if cam_exposures[c.image_name] is not None)
    print(f"Exposure key-match rate: {matched} / {len(test_cams)} test cams")
    if matched == 0:
        print(f" test cam names: {[c.image_name for c in test_cams[:3]]}")
        print(f" exposure keys : {list(exposures.keys())[:3]}")
        # "Cannot match, please check key naming"
        print("[!] 无法匹配,请检查 key 命名")
        return

    raw_psnrs, corr_psnrs = [], []
    print(f"\n {'i':>3} {'cam':>14} {'raw':>8} {'corrected':>10} {'gain':>8}")
    print(" " + "-" * 50)
    for i, c in enumerate(test_cams):
        w = int(round(c.width / resolution))
        h = int(round(c.height / resolution))
        viewmat = _to_f32(getWorld2View2(np.array(c.R), np.array(c.T)), device).unsqueeze(0)
        fx = w / (2 * math.tan(c.FovX / 2))
        fy = h / (2 * math.tan(c.FovY / 2))
        K = _to_f32(np.array([[fx, 0, w / 2], [0, fy, h / 2], [0, 0, 1]]), device).unsqueeze(0)
        with torch.no_grad():
            colors, _, _ = gsplat.rasterization(
                means=means, quats=quats, scales=scales, opacities=opacities,
                colors=shs, viewmats=viewmat, Ks=K, width=w, height=h,
                sh_degree=3, packed=True, render_mode='RGB',
                backgrounds=bg_color.unsqueeze(0))
        rgb = colors[0]             # (H, W, 3), not yet clamped
        render = rgb.clamp(0, 1)

        with Image.open(os.path.join(source_path, img_dir, c.image_name)) as pil:
            gt_chw = PILtoTorch(ImageOps.exif_transpose(pil), (w, h)).to(device)
        if gt_chw.shape[0] == 4:
            gt_chw = gt_chw[:3]     # drop alpha
        gt = gt_chw.permute(1, 2, 0)  # (H, W, 3)

        raw_p = _psnr(render, gt)
        raw_psnrs.append(raw_p)
        exposure = cam_exposures[c.image_name]
        if exposure is not None:
            # Exposure goes on the unclamped render, and the clamp comes after,
            # matching the official order.
            corr_p = _psnr(_apply_exposure(rgb, exposure).clamp(0, 1), gt)
        else:
            corr_p = raw_p
        corr_psnrs.append(corr_p)
        if i < 15 or (corr_p - raw_p) > 2:
            print(f" {i:>3} {c.image_name:>14} {raw_p:>8.2f} {corr_p:>10.2f} {corr_p - raw_p:>+8.2f}")

    raw_mean = float(np.mean(raw_psnrs))
    corr_mean = float(np.mean(corr_psnrs))
    print(f"\n Ours RAW mean PSNR: {raw_mean:.4f} dB")
    print(f" Ours CORRECTED mean PSNR: {corr_mean:.4f} dB")

    native_psnr = None
    bp = os.path.join(cell, "metrics_test_iter30000.json")
    if os.path.exists(bp):
        with open(bp) as fh:
            bd = json.load(fh)
        native_psnr = bd.get("photometric", {}).get("PSNR", bd.get("PSNR"))
    if native_psnr is None:
        # Formatting None with :.4f would raise; report and stop.
        print(f" Native baseline PSNR : n/a (missing file or no PSNR key in {bp})")
        return
    print(f" Native baseline PSNR : {native_psnr:.4f} dB")
    print(f" Δ raw : {raw_mean - native_psnr:+.4f} dB")
    print(f" Δ corrected : {corr_mean - native_psnr:+.4f} dB")
    corr_gap = abs(corr_mean - native_psnr)
    raw_gap = abs(raw_mean - native_psnr)
    if corr_gap < 0.1:
        # "Exposure compensation is the root cause."
        print("\n [VERDICT CONFIRMED] 曝光补偿就是根因。")
    elif corr_gap < raw_gap - 0.3:
        # "Exposure explains most of the gap, but residual > 0.3 dB; secondary factors likely."
        print("\n [PARTIAL] 曝光补偿解释了大部分差异,但残差>0.3dB,可能还有次级因素")
    else:
        # "Exposure compensation has limited effect; look for another root cause."
        print("\n [REJECTED] 曝光补偿影响有限,需另寻根因")
def main():
    """Run Probe 17 and then Probe 18 for every (scene, label) pair in SCENES."""
    hashes = "#" * 70
    for scene_name, tag in SCENES:
        print(f"\n\n{hashes}\n# SCENE: {scene_name} [{tag}]\n{hashes}")
        # Probe 18 consumes whatever Probe 17 parsed (None means skip).
        exposure_data = probe_exposure_structure(scene_name)
        probe_exposure_corrected(scene_name, exposure_data)


if __name__ == "__main__":
    main()