#!/usr/bin/env python3
"""Evaluate Gaussian point clouds against T&T-style ground-truth scans.

Pipeline implemented by :func:`main`:

1. locate the reconstruction ``point_cloud.ply`` for a method/scene pair,
2. map the Gaussian centres into the ground-truth frame with the scene's
   4x4 ``*_trans.txt`` matrix,
3. crop reconstruction and ground truth with the scene's crop JSON,
4. compute precision / recall / F-score at a per-scene threshold plus
   bidirectional Chamfer distances, and optionally a normal angular error,
5. write the result to ``<outputs-root>/tnt_eval/<method>_<scene>/eval.json``
   and echo it on stdout.
"""
import argparse
import json
import math
import sys
import time
from pathlib import Path

import numpy as np
from scipy.spatial import cKDTree

try:
    from plyfile import PlyData
except ImportError:
    # Only PLY loading needs plyfile; keep ``--help`` and the numeric helpers
    # usable without it. ``load_vertex_data`` raises if it is missing.
    PlyData = None

# Lower-case CLI scene name -> directory / file stem used under the eval root.
SCENE_MAP = {
    "barn": "Barn",
    "caterpillar": "Caterpillar",
    "truck": "Truck",
    "church": "Church",
    "courthouse": "Courthouse",
    "ignatius": "Ignatius",
    "meetingroom": "Meetingroom",
}

# Per-scene distance threshold used for precision / recall.
TAU_DICT = {
    "barn": 0.01,
    "caterpillar": 0.005,
    "truck": 0.005,
    "church": 0.01,
    "courthouse": 0.01,
    "ignatius": 0.003,
    "meetingroom": 0.01,
}

# Default checkpoint iteration per method; any other method uses DEFAULT_ITER.
ITER_DICT = {
    "gaussian_surfel": 15000,
    "hogs": 50000,
}
DEFAULT_ITER = 30000


def vlog(enabled, *args):
    """Print ``args`` (flushed immediately) when ``enabled`` is truthy."""
    if enabled:
        print(*args, flush=True)


def as_float(x):
    """Convert NumPy scalars / arrays to plain Python values.

    NumPy scalars become Python numbers, arrays become nested lists, and
    anything else is returned unchanged.
    """
    if isinstance(x, (np.floating, np.integer)):
        return x.item()
    if isinstance(x, np.ndarray):
        return x.tolist()
    return x


def _iteration_sort_key(ply_path):
    """Sort key that puts the highest numeric ``iteration_<N>`` folder first."""
    suffix = ply_path.parent.name.rpartition("_")[2]
    if suffix.isdigit():
        return (0, -int(suffix), str(ply_path))
    # Non-numeric suffixes are tried last, in a stable lexical order.
    return (1, 0, str(ply_path))


def locate_recon_ply(outputs_root: Path, method: str, scene: str, iteration=None) -> Path:
    """Find the reconstruction PLY for ``method`` / ``scene``.

    Args:
        outputs_root: Directory holding ``<method>_<scene>[ _bak]`` run folders.
        method: Method name; also selects the default iteration via ITER_DICT.
        scene: Scene name (case-insensitive).
        iteration: Checkpoint iteration; ``None`` uses the method default.

    Returns:
        The first existing candidate path.

    Raises:
        FileNotFoundError: If no candidate exists; the message lists them all.

    The exact requested iteration is preferred. If it is absent, any
    ``iteration_*`` folder is accepted, highest iteration number first
    (a plain lexical sort would rank ``iteration_15000`` before
    ``iteration_30000`` but ``iteration_30000`` before ``iteration_7000``).
    """
    scene = scene.lower()
    if iteration is None:
        iteration = ITER_DICT.get(method, DEFAULT_ITER)

    backup_dir = outputs_root / f"{method}_{scene}_bak"
    primary_dir = outputs_root / f"{method}_{scene}"
    relative = Path("point_cloud") / f"iteration_{iteration}" / "point_cloud.ply"

    candidates = []
    # Important special case from user's training history.
    if method == "vanilla_3dgs":
        candidates.append(backup_dir / relative)
    candidates.append(primary_dir / relative)

    # Fallback: if requested iteration is absent, search any iteration folder.
    for base in (backup_dir, primary_dir):
        if base.exists():
            found = base.glob("point_cloud/iteration_*/point_cloud.ply")
            candidates.extend(sorted(found, key=_iteration_sort_key))

    for candidate in candidates:
        if candidate.exists():
            return candidate

    msg = "Could not find reconstruction PLY. Tried:\n" + "\n".join(str(p) for p in candidates)
    raise FileNotFoundError(msg)


def load_vertex_data(ply_path: Path):
    """Read the ``vertex`` element of a PLY file.

    Returns:
        ``(xyz, vertex, names)`` where ``xyz`` is an ``(N, 3)`` float32 array
        built from the ``x``/``y``/``z`` properties, ``vertex`` is the raw
        vertex record array and ``names`` its field names.

    Raises:
        ImportError: If the ``plyfile`` package is not installed.
        ValueError: If the vertex element or any of x/y/z is missing.
    """
    if PlyData is None:
        raise ImportError("The 'plyfile' package is required to read PLY files")
    ply = PlyData.read(str(ply_path))
    if "vertex" not in ply:
        raise ValueError(f"No vertex element in PLY: {ply_path}")
    v = ply["vertex"].data
    names = v.dtype.names
    for k in ["x", "y", "z"]:
        if k not in names:
            raise ValueError(f"PLY missing vertex property '{k}': {ply_path}")
    xyz = np.stack([v["x"], v["y"], v["z"]], axis=1).astype(np.float32, copy=False)
    return xyz, v, names


def load_xyz_only(ply_path: Path):
    """Return only the ``(N, 3)`` float32 positions of a PLY file."""
    xyz, _, _ = load_vertex_data(ply_path)
    return xyz


def read_transform(path: Path):
    """Load a whitespace-separated 4x4 matrix as float64.

    Raises:
        ValueError: If the parsed array is not 4x4.
    """
    mat = np.loadtxt(path).astype(np.float64)
    if mat.shape != (4, 4):
        raise ValueError(f"Transform must be 4x4, got {mat.shape}: {path}")
    return mat


def apply_transform(xyz: np.ndarray, mat: np.ndarray):
    """Apply the affine part of a 4x4 matrix to ``(N, 3)`` points.

    The bottom row of ``mat`` is ignored. Computation is done in float64 and
    the result is returned as float32.
    """
    # Homogeneous column convention: x_gt = M @ [x, y, z, 1].
    out = xyz.astype(np.float64) @ mat[:3, :3].T + mat[:3, 3][None, :]
    return out.astype(np.float32)


def axis_to_index(axis):
    """Map an axis spec (0/1/2, ``"x"``/``"y"``/``"z"``, ``"axis_x"``...) to 0, 1 or 2.

    String matching is case-insensitive.

    Raises:
        ValueError: For any other value, including out-of-range integers.
    """
    if isinstance(axis, int) and not isinstance(axis, bool):
        if axis in (0, 1, 2):
            return axis
        raise ValueError(f"Unknown orthogonal_axis: {axis}")
    lookup = {
        "0": 0, "x": 0, "axis_x": 0,
        "1": 1, "y": 1, "axis_y": 1,
        "2": 2, "z": 2, "axis_z": 2,
    }
    try:
        return lookup[str(axis).lower()]
    except KeyError:
        raise ValueError(f"Unknown orthogonal_axis: {axis}") from None


def _points_in_polygon_raycast(points2: np.ndarray, polygon2: np.ndarray):
    """Even-odd ray casting, vectorised over points, looping over edges."""
    x = points2[:, 0]
    y = points2[:, 1]
    inside = np.zeros(points2.shape[0], dtype=bool)
    # Pair every vertex with its predecessor (the first pairs with the last).
    for (xi, yi), (xj, yj) in zip(polygon2, np.roll(polygon2, 1, axis=0)):
        dy = yj - yi
        if abs(dy) < 1e-30:
            # A horizontal edge can never straddle the ray; skipping it also
            # avoids dividing by (near) zero.
            continue
        straddles = (yi > y) != (yj > y)
        crossing_x = (xj - xi) * (y - yi) / dy + xi
        inside ^= straddles & (x < crossing_x)
    return inside


def points_in_polygon(points2: np.ndarray, polygon2: np.ndarray):
    """Return a boolean mask of which 2-D points lie inside a polygon.

    Uses ``matplotlib.path.Path.contains_points`` when matplotlib can be
    imported and a pure-NumPy ray-casting test otherwise. Only a failed
    import triggers the fallback, so genuine errors from the matplotlib call
    are not hidden.
    """
    try:
        from matplotlib.path import Path as MplPath
    except ImportError:
        return _points_in_polygon_raycast(
            np.asarray(points2), np.asarray(polygon2)
        )
    return MplPath(polygon2).contains_points(points2)


def crop_mask_tnt(xyz: np.ndarray, crop_json: dict):
    """Build the crop mask described by a crop JSON dictionary.

    A point is kept when its coordinate along ``orthogonal_axis`` lies in
    ``[axis_min, axis_max]`` and its projection onto the other two axes lies
    inside ``bounding_polygon`` (given as 3-D or already-projected 2-D
    vertices).

    Returns:
        Boolean array of length ``len(xyz)``.

    Raises:
        ValueError: If ``bounding_polygon`` is not an ``(M, 2)`` or ``(M, 3)`` array.
    """
    axis = axis_to_index(crop_json["orthogonal_axis"])
    axis_min = float(crop_json["axis_min"])
    axis_max = float(crop_json["axis_max"])
    other_axes = [i for i in (0, 1, 2) if i != axis]

    poly = np.asarray(crop_json["bounding_polygon"], dtype=np.float64)
    if poly.ndim != 2 or poly.shape[1] not in (2, 3):
        raise ValueError(f"Invalid bounding_polygon shape: {poly.shape}")
    poly2 = poly[:, other_axes] if poly.shape[1] == 3 else poly

    axis_mask = (xyz[:, axis] >= axis_min) & (xyz[:, axis] <= axis_max)
    final_mask = np.zeros(xyz.shape[0], dtype=bool)
    # Only points that pass the cheap slab test go through the polygon test.
    idx = np.flatnonzero(axis_mask)
    if idx.size == 0:
        return final_mask
    final_mask[idx] = points_in_polygon(xyz[idx][:, other_axes], poly2)
    return final_mask


def load_crop(path: Path):
    """Parse the crop JSON file at ``path``."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def build_tree(points: np.ndarray, name: str, verbose=False):
    """Build a ``cKDTree`` over ``points``.

    Raises:
        ValueError: If ``points`` is empty. ``name`` is used in the message.
    """
    if len(points) == 0:
        raise ValueError(f"Cannot build KDTree for empty point set: {name}")
    vlog(verbose, f"[KDTree] building {name}: {len(points):,} points")
    return cKDTree(points)


def _nearest(tree, chunk):
    """Query one nearest neighbour, using all cores when SciPy supports it."""
    try:
        return tree.query(chunk, k=1, workers=-1)
    except TypeError:
        # Fallback for a ``query`` that does not accept ``workers``.
        return tree.query(chunk, k=1)


def query_tree(tree, points, batch_size=250000):
    """Nearest-neighbour query in batches.

    Returns:
        ``(dists, inds)``: float32 distances and int64 indices into the tree
        data, one entry per query point.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1. A negative value
            would otherwise skip the loop and return uninitialised buffers.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    n = len(points)
    dists = np.empty(n, dtype=np.float32)
    inds = np.empty(n, dtype=np.int64)
    for start in range(0, n, batch_size):
        stop = min(start + batch_size, n)
        d, idx = _nearest(tree, points[start:stop])
        dists[start:stop] = d.astype(np.float32)
        inds[start:stop] = idx.astype(np.int64)
    return dists, inds


def compute_metrics(recon: np.ndarray, gt: np.ndarray, tau: float, batch_size=250000, verbose=False):
    """Compute precision, recall, F-score and Chamfer distances.

    * precision: fraction of ``recon`` points whose nearest ``gt`` point is
      closer than ``tau``;
    * recall: fraction of ``gt`` points whose nearest ``recon`` point is
      closer than ``tau``;
    * f_score: harmonic mean of the two (0.0 when both are 0);
    * chamfer_*: mean nearest-neighbour distance per direction, plus the
      mean and the sum of the two directions.

    Raises:
        ValueError: If either point set is empty (from :func:`build_tree`).
    """
    gt_tree = build_tree(gt, "GT", verbose)
    d_recon_to_gt, _ = query_tree(gt_tree, recon, batch_size=batch_size)
    precision = float(np.mean(d_recon_to_gt < tau))
    chamfer_r2g = float(np.mean(d_recon_to_gt))

    recon_tree = build_tree(recon, "reconstruction", verbose)
    d_gt_to_recon, _ = query_tree(recon_tree, gt, batch_size=batch_size)
    recall = float(np.mean(d_gt_to_recon < tau))
    chamfer_g2r = float(np.mean(d_gt_to_recon))

    if precision + recall > 0:
        f_score = float(2.0 * precision * recall / (precision + recall))
    else:
        f_score = 0.0

    # User requested "mean of bidirectional"; keep sum too for transparency.
    chamfer_mean = float(0.5 * (chamfer_r2g + chamfer_g2r))
    chamfer_sum = float(chamfer_r2g + chamfer_g2r)
    return {
        "precision": precision,
        "recall": recall,
        "f_score": f_score,
        "chamfer_distance": chamfer_mean,
        "chamfer_distance_sum": chamfer_sum,
        "chamfer_recon_to_gt": chamfer_r2g,
        "chamfer_gt_to_recon": chamfer_g2r,
    }


def choose_eval_indices(n: int, mode: str, n_sample: int, seed: int):
    """Pick which of ``n`` points are evaluated.

    ``"all"`` returns every index. ``"subsample"`` returns
    ``min(n, n_sample)`` distinct indices drawn with a generator seeded by
    ``seed``, sorted ascending.

    Raises:
        ValueError: For any other ``mode``.
    """
    if mode == "all":
        return np.arange(n, dtype=np.int64)
    if mode == "subsample":
        k = min(n, n_sample)
        rng = np.random.default_rng(seed)
        return np.sort(rng.choice(n, size=k, replace=False)).astype(np.int64)
    raise ValueError(f"Unsupported mode: {mode}")


def get_normals_from_gt_vertex(v, names):
    """Return unit normals from ``nx/ny/nz`` or ``normal_x/y/z`` fields.

    Returns ``None`` when neither naming scheme is fully present.
    """
    candidates = [
        ("nx", "ny", "nz"),
        ("normal_x", "normal_y", "normal_z"),
    ]
    for a, b, c in candidates:
        if a in names and b in names and c in names:
            normals = np.stack([v[a], v[b], v[c]], axis=1).astype(np.float32)
            # The epsilon keeps zero-length normals from dividing by zero.
            denom = np.linalg.norm(normals, axis=1, keepdims=True) + 1e-12
            return normals / denom
    return None


def gaussian_normals_from_vertex(v, names, indices, transform_mat):
    """Estimate one normal per selected Gaussian.

    The normal is the column of the Gaussian's rotation matrix (built from
    the ``rot_0..rot_3`` quaternion, read as ``[w, x, y, z]``) that
    corresponds to the smallest of ``scale_0..scale_2``. It is then mapped
    into the target frame of ``transform_mat`` and normalised.

    Args:
        v: Vertex record array (or mapping of field name -> array).
        names: Available field names.
        indices: Indices of the Gaussians to process.
        transform_mat: 4x4 matrix; only its upper-left 3x3 block is used.

    Returns:
        ``(len(indices), 3)`` float32 unit normals, or ``None`` when any
        ``scale_*`` / ``rot_*`` field is missing.
    """
    required = ["scale_0", "scale_1", "scale_2", "rot_0", "rot_1", "rot_2", "rot_3"]
    if not all(k in names for k in required):
        return None

    idx = np.asarray(indices)
    scales = np.stack([v["scale_0"][idx], v["scale_1"][idx], v["scale_2"][idx]], axis=1)
    min_axis = np.argmin(scales, axis=1)

    q = np.stack(
        [v["rot_0"][idx], v["rot_1"][idx], v["rot_2"][idx], v["rot_3"][idx]], axis=1
    ).astype(np.float64)
    q = q / (np.linalg.norm(q, axis=1, keepdims=True) + 1e-12)
    w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]

    # Rotation matrix for quaternion [w, x, y, z].
    rot = np.empty((len(idx), 3, 3), dtype=np.float64)
    rot[:, 0, 0] = 1 - 2 * (y * y + z * z)
    rot[:, 0, 1] = 2 * (x * y - w * z)
    rot[:, 0, 2] = 2 * (x * z + w * y)
    rot[:, 1, 0] = 2 * (x * y + w * z)
    rot[:, 1, 1] = 1 - 2 * (x * x + z * z)
    rot[:, 1, 2] = 2 * (y * z - w * x)
    rot[:, 2, 0] = 2 * (x * z - w * y)
    rot[:, 2, 1] = 2 * (y * z + w * x)
    rot[:, 2, 2] = 1 - 2 * (x * x + y * y)

    # Pick, per Gaussian, the rotation column of its thinnest axis.
    normals = np.take_along_axis(rot, min_axis[:, None, None], axis=2)[:, :, 0]

    # Normals transform with the inverse-transpose of the linear part:
    # n' = A^{-T} n, i.e. row-vector form ``n @ inv(A)``. For similarity
    # transforms this matches mapping with A itself (up to scale), and it
    # stays correct when A contains non-uniform scale or shear.
    linear = transform_mat[:3, :3].astype(np.float64)
    try:
        normal_map = np.linalg.inv(linear)
    except np.linalg.LinAlgError:
        # Degenerate alignment: fall back to mapping normals like directions.
        normal_map = linear.T
    normals = normals @ normal_map
    normals = normals / (np.linalg.norm(normals, axis=1, keepdims=True) + 1e-12)
    return normals.astype(np.float32)


def compute_normal_error(
    recon_eval_points,
    recon_eval_raw_indices,
    recon_vertex,
    recon_names,
    gt_crop_points,
    gt_crop_normals,
    transform_mat,
    batch_size,
):
    """Angular error between Gaussian normals and nearest ground-truth normals.

    Each evaluated reconstruction point is matched to its nearest cropped GT
    point; the error is the angle between the two normals, ignoring sign
    (so it lies in [0, 90] degrees).

    Returns:
        A dict with ``normal_error_available``. When ``False`` it also has
        ``normal_error_reason``; when ``True`` it has the median, 25% / 75%
        quantiles, their difference (IQR) and the number of points.
    """
    if gt_crop_normals is None:
        return {"normal_error_available": False, "normal_error_reason": "GT PLY has no normal fields"}

    recon_normals = gaussian_normals_from_vertex(
        recon_vertex, recon_names, recon_eval_raw_indices, transform_mat
    )
    if recon_normals is None:
        return {
            "normal_error_available": False,
            "normal_error_reason": "reconstruction PLY has no scale_*/rot_* fields",
        }

    gt_tree = cKDTree(gt_crop_points)
    _, nn_idx = query_tree(gt_tree, recon_eval_points, batch_size=batch_size)
    nn_gt_normals = gt_crop_normals[nn_idx]

    dots = np.sum(recon_normals * nn_gt_normals, axis=1)
    # abs(): normal orientation is ambiguous; clip guards arccos against
    # values marginally above 1 from rounding.
    dots = np.clip(np.abs(dots), 0.0, 1.0)
    angles = np.degrees(np.arccos(dots))

    q25 = float(np.quantile(angles, 0.25))
    q75 = float(np.quantile(angles, 0.75))
    return {
        "normal_error_available": True,
        "normal_angular_error_median_deg": float(np.median(angles)),
        "normal_angular_error_q25_deg": q25,
        "normal_angular_error_q75_deg": q75,
        "normal_angular_error_iqr_deg": float(q75 - q25),
        "normal_error_n_points": int(len(angles)),
    }


def main():
    """Command-line entry point: run the evaluation and write ``eval.json``.

    Raises:
        FileNotFoundError: If the reconstruction PLY or any of the scene's
            GT PLY / crop JSON / transform files is missing.
        RuntimeError: If the reconstruction or GT crop ends up empty.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--method", required=True)
    parser.add_argument("--scene", required=True, choices=sorted(SCENE_MAP.keys()))
    parser.add_argument("--project-root", default="/root/autodl-tmp/SplatAtlas")
    parser.add_argument("--outputs-root", default=None)
    parser.add_argument("--tnt-eval-root", default=None)
    parser.add_argument("--iteration", type=int, default=None)
    parser.add_argument("--mode", choices=["all", "subsample"], default="subsample")
    parser.add_argument("--n-sample", type=int, default=200000)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, default=250000)
    parser.add_argument("--normal-error", action="store_true")
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()

    t0 = time.time()

    project_root = Path(args.project_root)
    outputs_root = Path(args.outputs_root) if args.outputs_root else project_root / "outputs"
    tnt_eval_root = Path(args.tnt_eval_root) if args.tnt_eval_root else project_root / "data" / "tnt_eval"

    scene_lower = args.scene.lower()
    official_scene = SCENE_MAP[scene_lower]
    tau = TAU_DICT[scene_lower]

    ply_path = locate_recon_ply(outputs_root, args.method, scene_lower, args.iteration)

    scene_eval_dir = tnt_eval_root / official_scene
    gt_ply_path = scene_eval_dir / f"{official_scene}.ply"
    crop_path = scene_eval_dir / f"{official_scene}.json"
    trans_path = scene_eval_dir / f"{official_scene}_trans.txt"
    for p in [gt_ply_path, crop_path, trans_path]:
        if not p.exists():
            raise FileNotFoundError(f"Missing required eval file: {p}")

    vlog(args.verbose, "=" * 80)
    vlog(args.verbose, f"method: {args.method}")
    vlog(args.verbose, f"scene: {scene_lower} -> {official_scene}")
    vlog(args.verbose, f"tau: {tau}")
    vlog(args.verbose, f"recon ply: {ply_path}")
    vlog(args.verbose, f"gt ply: {gt_ply_path}")
    vlog(args.verbose, f"crop json: {crop_path}")
    vlog(args.verbose, f"trans: {trans_path}")

    trans = read_transform(trans_path)
    crop_json = load_crop(crop_path)

    vlog(args.verbose, "\n[Load] reconstruction PLY")
    recon_xyz_raw, recon_vertex, recon_names = load_vertex_data(ply_path)
    n_gaussians_recon = int(len(recon_xyz_raw))
    vlog(args.verbose, f"n_gaussians_recon: {n_gaussians_recon:,}")

    vlog(args.verbose, "\n[Transform] applying T&T trans matrix")
    vlog(args.verbose, trans)
    recon_xyz_aligned = apply_transform(recon_xyz_raw, trans)

    vlog(args.verbose, "\n[Crop] reconstruction")
    recon_crop_mask = crop_mask_tnt(recon_xyz_aligned, crop_json)
    # Keep indices into the raw PLY so per-Gaussian attributes (scales,
    # rotations) can be looked up again for the normal-error step.
    recon_crop_indices_raw = np.where(recon_crop_mask)[0].astype(np.int64)
    recon_crop = recon_xyz_aligned[recon_crop_mask]
    vlog(args.verbose, f"recon before crop: {len(recon_xyz_aligned):,}")
    vlog(args.verbose, f"recon after crop: {len(recon_crop):,}")
    if len(recon_crop) == 0:
        raise RuntimeError("Reconstruction crop is empty. Likely coordinate-system/alignment mismatch.")

    vlog(args.verbose, "\n[Load] GT PLY")
    gt_xyz_raw, gt_vertex, gt_names = load_vertex_data(gt_ply_path)
    vlog(args.verbose, f"GT raw points: {len(gt_xyz_raw):,}")

    vlog(args.verbose, "\n[Crop] GT")
    gt_crop_mask = crop_mask_tnt(gt_xyz_raw, crop_json)
    gt_crop = gt_xyz_raw[gt_crop_mask]
    vlog(args.verbose, f"GT after crop: {len(gt_crop):,}")
    if len(gt_crop) == 0:
        raise RuntimeError("GT crop is empty. Crop parser is likely wrong.")

    eval_idx_in_crop = choose_eval_indices(
        len(recon_crop), mode=args.mode, n_sample=args.n_sample, seed=args.seed
    )
    recon_eval = recon_crop[eval_idx_in_crop]
    recon_eval_raw_indices = recon_crop_indices_raw[eval_idx_in_crop]

    vlog(args.verbose, "\n[Eval set]")
    vlog(args.verbose, f"mode: {args.mode}")
    vlog(args.verbose, f"recon eval points: {len(recon_eval):,}")
    vlog(args.verbose, f"GT eval points: {len(gt_crop):,}")

    metrics = compute_metrics(
        recon_eval,
        gt_crop,
        tau=tau,
        batch_size=args.batch_size,
        verbose=args.verbose,
    )

    normal_metrics = {}
    if args.normal_error:
        vlog(args.verbose, "\n[Normal error]")
        gt_normals_raw = get_normals_from_gt_vertex(gt_vertex, gt_names)
        gt_normals_crop = gt_normals_raw[gt_crop_mask] if gt_normals_raw is not None else None
        normal_metrics = compute_normal_error(
            recon_eval_points=recon_eval,
            recon_eval_raw_indices=recon_eval_raw_indices,
            recon_vertex=recon_vertex,
            recon_names=recon_names,
            gt_crop_points=gt_crop,
            gt_crop_normals=gt_normals_crop,
            transform_mat=trans,
            batch_size=args.batch_size,
        )
        vlog(args.verbose, normal_metrics)

    wall_time = time.time() - t0

    result = {
        "method": args.method,
        "scene": scene_lower,
        "official_scene": official_scene,
        "eval_protocol": "tnt_point_cloud_direct_gaussian_centers",
        "mode": args.mode,
        "n_sample_requested": int(args.n_sample),
        "seed": int(args.seed),
        "ply_path": str(ply_path),
        "gt_ply_path": str(gt_ply_path),
        "crop_path": str(crop_path),
        "trans_path": str(trans_path),
        "n_gaussians_recon": n_gaussians_recon,
        "n_points_after_crop": int(len(recon_crop)),
        "n_points_eval_recon": int(len(recon_eval)),
        "n_points_gt_raw": int(len(gt_xyz_raw)),
        "n_points_gt": int(len(gt_crop)),
        "tau": float(tau),
        **metrics,
        **normal_metrics,
        "wall_time_seconds": float(wall_time),
    }

    out_dir = outputs_root / "tnt_eval" / f"{args.method}_{scene_lower}"
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / "eval.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2, sort_keys=True)

    print(json.dumps(result, indent=2, sort_keys=True))
    print(f"\n[WROTE] {out_path}")


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        # Emit a one-line summary on stderr, then re-raise for the traceback
        # and a non-zero exit status.
        print(f"[ERROR] {type(e).__name__}: {e}", file=sys.stderr)
        raise