SplatAtlas / scripts /eval_tnt_normals_v3_vs_pca.py
KCBtheone's picture
Upload SplatAtlas benchmark pipeline code
23e73f9 verified
Raw
History Blame Contribute Delete
13.3 kB
#!/usr/bin/env python3
import argparse
import json
import sys
import time
from pathlib import Path
import numpy as np
from scipy.spatial import cKDTree
sys.path.append(str(Path(__file__).resolve().parent))
import eval_tnt_wrapper as W
def query_tree(tree, points, k=1, batch_size=50000):
    """Query ``tree`` for the ``k`` nearest neighbours of ``points`` in batches.

    Batching bounds peak memory for large query sets. Multi-threaded querying
    (``workers=-1``) is requested first; if the installed SciPy rejects that
    keyword with ``TypeError``, the query is retried without it and the
    keyword is not passed again for the remaining batches.

    Args:
        tree: A ``scipy.spatial.cKDTree`` (or any object with a compatible
            ``query(x, k=...)`` method).
        points: Query points, array of shape [N, D].
        k: Number of neighbours per query point (int).
        batch_size: Maximum number of points passed to a single ``query`` call.

    Returns:
        ``(distances, indices)`` concatenated over all batches, shaped like
        ``tree.query`` output: [N] when ``k == 1``, [N, k] otherwise. An empty
        query set yields empty arrays of those shapes.
    """
    n_points = len(points)
    if n_points == 0:
        # np.concatenate([]) would raise; return correctly shaped empties instead.
        shape = (0,) if k == 1 else (0, k)
        return np.empty(shape, dtype=np.float64), np.empty(shape, dtype=np.intp)

    query_kwargs = {"workers": -1}
    ds, inds = [], []
    for s in range(0, n_points, batch_size):
        chunk = points[s:s + batch_size]
        try:
            d, i = tree.query(chunk, k=k, **query_kwargs)
        except TypeError:
            if not query_kwargs:
                raise
            # Older SciPy without ``workers``: fall back once, for all batches.
            query_kwargs = {}
            d, i = tree.query(chunk, k=k)
        ds.append(d)
        inds.append(i)
    return np.concatenate(ds, axis=0), np.concatenate(inds, axis=0)
def pca_normals_from_neighbors(points, neighbor_indices, batch_size=50000):
    """
    Estimate unit normals by local PCA over precomputed neighbourhoods.

    For each row of ``neighbor_indices`` the referenced points are
    mean-centred, their 3x3 covariance is formed (in float64), and the
    eigenvector of the smallest eigenvalue is returned as the normal. The
    sign of each normal is arbitrary (eigenvector sign is not fixed), so
    compare results with a sign-invariant metric.

    points: [N, 3]
    neighbor_indices: [M, K] integer indices into ``points``
    batch_size: number of rows processed per batch (bounds peak memory)
    return normals: [M, 3] float32, unit length

    Raises ValueError if ``neighbor_indices`` is not 2-D, has K == 0, or
    contains indices outside [0, N). The last case covers the index-N padding
    that ``cKDTree.query`` emits when fewer than K points exist.
    """
    neighbor_indices = np.asarray(neighbor_indices)
    if neighbor_indices.ndim != 2:
        raise ValueError("neighbor_indices must be [M, K]")
    points = np.asarray(points)
    M, K = neighbor_indices.shape
    if K < 1:
        raise ValueError("neighbor_indices must have at least one neighbour per row")
    if neighbor_indices.size and (
        neighbor_indices.min() < 0 or neighbor_indices.max() >= len(points)
    ):
        raise ValueError(
            f"neighbor_indices contains indices outside [0, {len(points)}); "
            "k may exceed the number of available points"
        )
    normals = np.empty((M, 3), dtype=np.float32)
    for s in range(0, M, batch_size):
        neigh = points[neighbor_indices[s:s + batch_size]].astype(np.float64)  # [B, K, 3]
        centered = neigh - neigh.mean(axis=1, keepdims=True)
        cov = np.einsum("bki,bkj->bij", centered, centered) / max(K - 1, 1)
        # eigh returns eigenvalues in ascending order, so column 0 is the
        # direction of least variance, i.e. the local surface normal.
        _, vecs = np.linalg.eigh(cov)
        n = vecs[:, :, 0]
        n = n / (np.linalg.norm(n, axis=1, keepdims=True) + 1e-12)
        normals[s:s + batch_size] = n
    return normals
def gaussian_v3_normals_from_vertex(v, names, indices, transform_mat):
    """Per-Gaussian normals taken from the rotation axis of the smallest scale.

    Each Gaussian's quaternion (3DGS convention ``[w, x, y, z]`` stored in
    ``rot_0..rot_3``) is normalised and converted to a rotation matrix. The
    matrix column paired with the smallest of ``scale_0..scale_2`` is taken
    as the normal. Normals are then mapped into the aligned frame with the
    inverse-transpose of the linear part of ``transform_mat`` and
    re-normalised.

    Args:
        v: Vertex record indexable by field name (``v["scale_0"]`` etc.),
            each field a 1-D array over all Gaussians.
        names: Collection of the field names available in ``v``.
        indices: Integer indices of the Gaussians to process.
        transform_mat: Transform matrix (top-left 3x3 is used) that was
            applied to the Gaussian positions.

    Returns:
        ``(normals, anisotropy)``: float32 arrays of shape [len(indices), 3]
        and [len(indices)]. ``anisotropy`` is ``exp(max_scale - min_scale)``.
        That equals the max/min axis-length ratio if the scales are stored as
        logs (assumed 3DGS PLY convention, not checked here).

    Raises:
        ValueError: if a required Gaussian field is missing.
        numpy.linalg.LinAlgError: if the linear part of ``transform_mat`` is
            singular.
    """
    required = ["scale_0", "scale_1", "scale_2", "rot_0", "rot_1", "rot_2", "rot_3"]
    missing = [k for k in required if k not in names]
    if missing:
        raise ValueError(f"reconstruction PLY missing Gaussian fields: {missing}")
    idx = indices
    scales = np.stack([v[f"scale_{j}"][idx] for j in range(3)], axis=1).astype(np.float64)
    q = np.stack([v[f"rot_{j}"][idx] for j in range(4)], axis=1).astype(np.float64)
    q = q / (np.linalg.norm(q, axis=1, keepdims=True) + 1e-12)
    # 3DGS convention: quaternion [w, x, y, z]
    w, x, y, z = q.T
    # Unit quaternion -> rotation matrix, laid out as [N, row, col].
    rot = np.stack(
        [
            np.stack([1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)], axis=-1),
            np.stack([2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)], axis=-1),
            np.stack([2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)], axis=-1),
        ],
        axis=1,
    )
    # Column j of R is the world direction of local axis j; pick the shortest axis.
    min_axis = np.argmin(scales, axis=1)
    normals = np.take_along_axis(rot, min_axis[:, None, None], axis=2)[:, :, 0]
    # Positions were mapped as p' = A p + t. Normals are covectors and transform
    # with A^{-T}; in row-vector form that is n'^T = n^T A^{-1}. For a similarity
    # transform (rotation + uniform scale) this equals A up to scale.
    A = np.asarray(transform_mat, dtype=np.float64)[:3, :3]
    normals = normals @ np.linalg.inv(A)
    normals = normals / (np.linalg.norm(normals, axis=1, keepdims=True) + 1e-12)
    anisotropy = np.exp(scales.max(axis=1) - scales.min(axis=1))
    return normals.astype(np.float32), anisotropy.astype(np.float32)
def angular_errors(pred_normals, ref_normals):
    """Row-wise unsigned angle, in degrees, between two [M, 3] normal arrays.

    The absolute cosine is used, so a normal and its negation count as
    identical (0 degrees). The cosine is clipped to [0, 1] before ``arccos``
    to absorb rounding.
    """
    cosines = np.abs((pred_normals * ref_normals).sum(axis=1))
    return np.degrees(np.arccos(np.clip(cosines, 0.0, 1.0)))
def stats(prefix, angles):
    """Summarise angular errors (degrees) into a flat dict of Python floats.

    Keys are ``f"{prefix}_<stat>_deg"``. The stats are mean, median, the 10th,
    25th, 75th and 90th percentiles, and the inter-quartile range (q75 - q25).
    """
    # One scalar-q call per percentile keeps np.quantile's dtype handling
    # identical to calling it inline for each key.
    pct = {p: np.quantile(angles, p) for p in (0.10, 0.25, 0.75, 0.90)}
    entries = (
        ("mean", np.mean(angles)),
        ("median", np.median(angles)),
        ("q10", pct[0.10]),
        ("q25", pct[0.25]),
        ("q75", pct[0.75]),
        ("q90", pct[0.90]),
        ("iqr", pct[0.75] - pct[0.25]),
    )
    return {f"{prefix}_{label}_deg": float(value) for label, value in entries}
def sigmoid(x):
    """Element-wise logistic function 1 / (1 + exp(-x))."""
    denominator = 1.0 + np.exp(-x)
    return 1.0 / denominator
def _parse_args(argv=None):
    """Parse the command line (``sys.argv`` when ``argv`` is None) and reject invalid values."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--method", default="gof")
    ap.add_argument("--scene", required=True, choices=sorted(W.SCENE_MAP.keys()))
    ap.add_argument("--project-root", default="/root/autodl-tmp/SplatAtlas")
    ap.add_argument("--outputs-root", default=None)
    ap.add_argument("--tnt-eval-root", default=None)
    ap.add_argument("--iteration", type=int, default=None)
    ap.add_argument("--mode", choices=["all", "subsample"], default="subsample")
    ap.add_argument("--n-sample", type=int, default=200000)
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--distance-multiplier", type=float, default=2.0)
    ap.add_argument("--gt-normal-k", type=int, default=30)
    ap.add_argument("--recon-pca-k", type=int, default=30)
    ap.add_argument("--max-normal-points", type=int, default=50000)
    ap.add_argument("--batch-size", type=int, default=50000)
    ap.add_argument("--verbose", action="store_true")
    args = ap.parse_args(argv)
    # Fewer than 3 neighbours cannot define a plane, so the PCA normal is arbitrary.
    if args.gt_normal_k < 3:
        ap.error("--gt-normal-k must be >= 3")
    if args.recon_pca_k < 3:
        ap.error("--recon-pca-k must be >= 3")
    # A zero batch size makes range() raise; zero points makes every statistic undefined.
    if args.batch_size < 1:
        ap.error("--batch-size must be >= 1")
    if args.max_normal_points < 1:
        ap.error("--max-normal-points must be >= 1")
    if args.distance_multiplier <= 0:
        ap.error("--distance-multiplier must be > 0")
    return args


def main():
    """Compare Gaussian minor-axis (V3) normals with Gaussian-position PCA normals.

    For one ``--method`` / ``--scene``:

    1. Load the reconstruction PLY, apply the scene transform, crop it, and
       choose the evaluated Gaussian centres (``--mode`` / ``--n-sample``).
    2. Load and crop the GT point cloud. Keep evaluated centres whose
       nearest-GT distance is below ``--distance-multiplier * tau``, randomly
       capped at ``--max-normal-points``.
    3. Reference normals: local PCA over the ``--gt-normal-k`` GT neighbours
       of each kept centre's nearest GT point.
    4. Score two estimates with an unsigned angular error:
       - the Gaussian minor-axis normal;
       - local PCA over the ``--recon-pca-k`` nearest cropped reconstruction
         centres.
    5. Print the summary and write it as JSON to
       ``<outputs_root>/tnt_eval_normals_v3_vs_pca/<method>_<scene>/v3_vs_pca_normal_eval.json``.

    Raises:
        RuntimeError: if a crop is empty or has fewer than 3 points, or if no
            Gaussian is surface-near.
    """
    args = _parse_args()
    t0 = time.time()
    project_root = Path(args.project_root)
    outputs_root = Path(args.outputs_root) if args.outputs_root else project_root / "outputs"
    tnt_eval_root = (
        Path(args.tnt_eval_root) if args.tnt_eval_root else project_root / "data" / "tnt_eval"
    )
    scene = args.scene.lower()
    official_scene = W.SCENE_MAP[scene]
    tau = W.TAU_DICT[scene]
    near_threshold = args.distance_multiplier * tau
    ply_path = W.locate_recon_ply(outputs_root, args.method, scene, args.iteration)
    scene_eval_dir = tnt_eval_root / official_scene
    gt_ply_path = scene_eval_dir / f"{official_scene}.ply"
    crop_path = scene_eval_dir / f"{official_scene}.json"
    trans_path = scene_eval_dir / f"{official_scene}_trans.txt"
    trans = W.read_transform(trans_path)
    crop = W.load_crop(crop_path)
    if args.verbose:
        print("=" * 80)
        print("V3 normal vs Gaussian-position PCA normal")
        print("method:", args.method)
        print("scene:", scene, "->", official_scene)
        print("tau:", tau)
        print("surface-near threshold:", near_threshold)
        print("recon:", ply_path)
        print("gt:", gt_ply_path)

    # Load reconstruction, transform it into the GT frame, then crop.
    recon_raw, recon_vertex, recon_names = W.load_vertex_data(ply_path)
    recon_aligned = W.apply_transform(recon_raw, trans)
    recon_crop_mask = W.crop_mask_tnt(recon_aligned, crop)
    recon_crop = recon_aligned[recon_crop_mask]
    # Raw PLY row of every cropped point, used later to read Gaussian attributes.
    recon_crop_raw_idx = np.where(recon_crop_mask)[0].astype(np.int64)
    if len(recon_crop) == 0:
        raise RuntimeError("Reconstruction crop is empty.")

    # Choose evaluated reconstruction centers.
    eval_idx_in_crop = W.choose_eval_indices(
        len(recon_crop),
        mode=args.mode,
        n_sample=args.n_sample,
        seed=args.seed,
    )
    recon_eval = recon_crop[eval_idx_in_crop]
    recon_eval_raw_idx = recon_crop_raw_idx[eval_idx_in_crop]

    # Load GT.
    gt_raw, _, _ = W.load_vertex_data(gt_ply_path)
    gt_crop_mask = W.crop_mask_tnt(gt_raw, crop)
    gt_crop = gt_raw[gt_crop_mask]
    if len(gt_crop) == 0:
        raise RuntimeError("GT crop is empty.")

    # cKDTree.query pads missing neighbours with index len(data) when k exceeds
    # the point count, which would index out of range; clamp k to what exists.
    gt_k = min(args.gt_normal_k, len(gt_crop))
    recon_k = min(args.recon_pca_k, len(recon_crop))
    if gt_k < 3 or recon_k < 3:
        raise RuntimeError("Need at least 3 points in both GT and reconstruction crops for PCA normals.")

    # Nearest GT for surface-near filter.
    gt_tree = cKDTree(gt_crop)
    d_r2g, nn_gt_idx = query_tree(
        gt_tree,
        recon_eval,
        k=1,
        batch_size=args.batch_size,
    )
    surface_mask = d_r2g < near_threshold
    surface_eval_idx = np.where(surface_mask)[0].astype(np.int64)
    if len(surface_eval_idx) == 0:
        raise RuntimeError("No surface-near Gaussians. Increase --distance-multiplier.")

    # Cap the number of normals evaluated; sorting keeps the original point order.
    rng = np.random.default_rng(args.seed)
    if len(surface_eval_idx) > args.max_normal_points:
        chosen_eval_idx = np.sort(
            rng.choice(surface_eval_idx, size=args.max_normal_points, replace=False)
        ).astype(np.int64)
    else:
        chosen_eval_idx = surface_eval_idx
    chosen_points = recon_eval[chosen_eval_idx]
    chosen_raw_idx = recon_eval_raw_idx[chosen_eval_idx]
    chosen_nearest_gt_points = gt_crop[nn_gt_idx[chosen_eval_idx]]
    if args.verbose:
        print("n_gaussians_recon:", len(recon_raw))
        print("n_recon_after_crop:", len(recon_crop))
        print("n_recon_eval:", len(recon_eval))
        print("n_gt_after_crop:", len(gt_crop))
        print("n_surface_near:", len(surface_eval_idx), "/", len(recon_eval))
        print("n_normal_eval:", len(chosen_eval_idx))
        print("distance median:", float(np.median(d_r2g)))
        print("distance q10:", float(np.quantile(d_r2g, 0.10)))
        print("distance q90:", float(np.quantile(d_r2g, 0.90)))

    # Reference: GT local PCA normals at nearest GT points.
    if args.verbose:
        print("[GT PCA normals] k =", gt_k)
    _, gt_neighbor_idx = query_tree(
        gt_tree,
        chosen_nearest_gt_points,
        k=gt_k,
        batch_size=args.batch_size,
    )
    gt_pca_normals = pca_normals_from_neighbors(
        gt_crop,
        gt_neighbor_idx,
        batch_size=args.batch_size,
    )

    # Method 1: V3 / Gaussian minor-axis normal.
    if args.verbose:
        print("[V3 normals] from Gaussian scale + rotation")
    v3_normals, anisotropy = gaussian_v3_normals_from_vertex(
        recon_vertex,
        recon_names,
        chosen_raw_idx,
        trans,
    )
    v3_angles = angular_errors(v3_normals, gt_pca_normals)

    # Method 2: Gaussian-position local PCA normal.
    if args.verbose:
        print("[Gaussian-position PCA normals] k =", recon_k)
    recon_tree = cKDTree(recon_crop)
    _, recon_neighbor_idx = query_tree(
        recon_tree,
        chosen_points,
        k=recon_k,
        batch_size=args.batch_size,
    )
    recon_pca_normals = pca_normals_from_neighbors(
        recon_crop,
        recon_neighbor_idx,
        batch_size=args.batch_size,
    )
    recon_pca_angles = angular_errors(recon_pca_normals, gt_pca_normals)

    result = {
        "method": args.method,
        "scene": scene,
        "official_scene": official_scene,
        "eval_protocol": "gof_v3_minor_axis_vs_gaussian_position_pca_normal",
        "is_official_tnt_metric": False,
        "ply_path": str(ply_path),
        "gt_ply_path": str(gt_ply_path),
        "crop_path": str(crop_path),
        "trans_path": str(trans_path),
        "tau": float(tau),
        "surface_near_threshold": float(near_threshold),
        "surface_near_rule": f"d_recon_to_gt < {args.distance_multiplier} * tau",
        "mode": args.mode,
        "n_sample_requested": int(args.n_sample),
        "seed": int(args.seed),
        "gt_normal_k": int(args.gt_normal_k),
        "gt_normal_k_effective": int(gt_k),
        "recon_pca_k": int(args.recon_pca_k),
        "recon_pca_k_effective": int(recon_k),
        "max_normal_points": int(args.max_normal_points),
        "n_gaussians_recon": int(len(recon_raw)),
        "n_recon_after_crop": int(len(recon_crop)),
        "n_recon_eval": int(len(recon_eval)),
        "n_gt_after_crop": int(len(gt_crop)),
        "n_surface_near": int(len(surface_eval_idx)),
        "surface_near_ratio": float(len(surface_eval_idx) / max(len(recon_eval), 1)),
        "n_normal_eval": int(len(chosen_eval_idx)),
        "distance_recon_to_gt_mean": float(np.mean(d_r2g)),
        "distance_recon_to_gt_median": float(np.median(d_r2g)),
        "distance_recon_to_gt_q10": float(np.quantile(d_r2g, 0.10)),
        "distance_recon_to_gt_q90": float(np.quantile(d_r2g, 0.90)),
        **stats("v3_normal_error", v3_angles),
        **stats("pca_normal_error", recon_pca_angles),
        "median_delta_v3_minus_pca_deg": float(np.median(v3_angles) - np.median(recon_pca_angles)),
        "mean_delta_v3_minus_pca_deg": float(np.mean(v3_angles) - np.mean(recon_pca_angles)),
        "pca_better_than_v3_by_median": bool(np.median(recon_pca_angles) < np.median(v3_angles)),
        "pca_within_3deg_of_v3_median": bool(np.median(recon_pca_angles) <= np.median(v3_angles) + 3.0),
        "pca_within_5deg_of_v3_median": bool(np.median(recon_pca_angles) <= np.median(v3_angles) + 5.0),
        "gaussian_anisotropy_exp_scale_range_median": float(np.median(anisotropy)),
        "gaussian_anisotropy_exp_scale_range_q25": float(np.quantile(anisotropy, 0.25)),
        "gaussian_anisotropy_exp_scale_range_q75": float(np.quantile(anisotropy, 0.75)),
        "wall_time_seconds": float(time.time() - t0),
    }
    if "opacity" in recon_names:
        # Raw opacity is passed through sigmoid before summarising.
        opacity_raw = recon_vertex["opacity"][chosen_raw_idx].astype(np.float64)
        opacity_sigmoid = sigmoid(opacity_raw)
        result.update({
            "opacity_sigmoid_median": float(np.median(opacity_sigmoid)),
            "opacity_sigmoid_q25": float(np.quantile(opacity_sigmoid, 0.25)),
            "opacity_sigmoid_q75": float(np.quantile(opacity_sigmoid, 0.75)),
        })

    out_dir = outputs_root / "tnt_eval_normals_v3_vs_pca" / f"{args.method}_{scene}"
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / "v3_vs_pca_normal_eval.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2, sort_keys=True)
    print(json.dumps(result, indent=2, sort_keys=True))
    print(f"\n[WROTE] {out_path}")
# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()