| |
| import argparse |
| import json |
| import math |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import numpy as np |
| from plyfile import PlyData |
| from scipy.spatial import cKDTree |
|
|
|
|
# Lower-case CLI scene name -> official Tanks-and-Temples scene name. The
# official name is also used as the directory and file stem under the T&T
# evaluation root (see main()).
SCENE_MAP = {
    official.lower(): official
    for official in (
        "Barn",
        "Caterpillar",
        "Truck",
        "Church",
        "Courthouse",
        "Ignatius",
        "Meetingroom",
    )
}


# Per-scene distance threshold tau used for precision / recall / F-score.
# NOTE(review): the official T&T python_toolbox evaluation appears to use
# 0.025 (not 0.01) for Church and Courthouse -- confirm before comparing
# numbers against published results.
TAU_DICT = dict(
    barn=0.01,
    caterpillar=0.005,
    truck=0.005,
    church=0.01,
    courthouse=0.01,
    ignatius=0.003,
    meetingroom=0.01,
)


# Method-specific default training iteration whose checkpoint is evaluated.
ITER_DICT = dict(
    gaussian_surfel=15000,
    hogs=50000,
)


# Iteration used for any method not listed in ITER_DICT.
DEFAULT_ITER = 30_000
|
|
|
|
def vlog(enabled, *args):
    """Print ``args`` (space separated, flushed immediately) only if ``enabled``."""
    if not enabled:
        return
    print(*args, flush=True)
|
|
|
|
def as_float(x):
    """Convert NumPy scalars/arrays into plain Python objects.

    NumPy float/int scalars become Python ``float``/``int`` via ``.item()``,
    arrays become (nested) lists via ``.tolist()``, and anything else is
    returned unchanged.
    """
    if isinstance(x, np.ndarray):
        return x.tolist()
    is_numpy_scalar = isinstance(x, (np.floating, np.integer))
    return x.item() if is_numpy_scalar else x
|
|
|
|
def _iteration_number(ply_path: Path) -> int:
    """Return N for ``.../iteration_N/point_cloud.ply``, or -1 if N is not an integer."""
    _, _, suffix = ply_path.parent.name.rpartition("_")
    return int(suffix) if suffix.isdigit() else -1


def locate_recon_ply(outputs_root: Path, method: str, scene: str, iteration=None) -> Path:
    """Find the reconstruction point-cloud PLY for ``method`` on ``scene``.

    Candidates are tried in this order and the first existing file wins:

    1. ``vanilla_3dgs`` only:
       ``<method>_<scene>_bak/point_cloud/iteration_<it>/point_cloud.ply``
    2. ``<method>_<scene>/point_cloud/iteration_<it>/point_cloud.ply``
    3. any other ``iteration_*`` checkpoint under ``<method>_<scene>_bak`` and
       then ``<method>_<scene>``, highest iteration number first.

    Args:
        outputs_root: directory containing the per-run output folders.
        method: method name; also selects the default iteration via ``ITER_DICT``.
        scene: scene name (lower-cased before building paths).
        iteration: checkpoint iteration; defaults to ``ITER_DICT[method]``, or
            ``DEFAULT_ITER`` for methods not listed there.

    Returns:
        Path of the first existing candidate PLY.

    Raises:
        FileNotFoundError: if no candidate exists; the message lists the
            explicit paths that were tried.
    """
    scene = scene.lower()
    if iteration is None:
        iteration = ITER_DICT.get(method, DEFAULT_ITER)

    bak_dir = outputs_root / f"{method}_{scene}_bak"
    run_dir = outputs_root / f"{method}_{scene}"
    rel_ply = Path("point_cloud") / f"iteration_{iteration}" / "point_cloud.ply"

    candidates = []
    if method == "vanilla_3dgs":
        candidates.append(bak_dir / rel_ply)
    candidates.append(run_dir / rel_ply)

    # Fallback: any checkpoint that exists. Sort by the *numeric* iteration
    # (latest first); a plain lexicographic sort would rank "iteration_10000"
    # ahead of "iteration_30000" and "iteration_7000".
    for base in (bak_dir, run_dir):
        if base.exists():
            found = base.glob("point_cloud/iteration_*/point_cloud.ply")
            candidates.extend(
                sorted(found, key=lambda p: (_iteration_number(p), p.parent.name), reverse=True)
            )

    for p in candidates:
        if p.exists():
            return p

    msg = "Could not find reconstruction PLY. Tried:\n" + "\n".join(str(p) for p in candidates)
    raise FileNotFoundError(msg)
|
|
|
|
def load_vertex_data(ply_path: Path):
    """Read a PLY file and return its vertex positions plus the raw vertex table.

    Returns:
        ``(xyz, vertex, names)`` where ``xyz`` is an (N, 3) float32 array of the
        ``x``/``y``/``z`` properties, ``vertex`` is the element's record array
        (``ply["vertex"].data``), and ``names`` is its tuple of property names.

    Raises:
        ValueError: if the file has no ``vertex`` element or lacks any of
            ``x``, ``y``, ``z``.
    """
    ply = PlyData.read(str(ply_path))
    if "vertex" not in ply:
        raise ValueError(f"No vertex element in PLY: {ply_path}")

    vertex = ply["vertex"].data
    field_names = vertex.dtype.names
    coord_keys = ("x", "y", "z")

    missing = next((key for key in coord_keys if key not in field_names), None)
    if missing is not None:
        raise ValueError(f"PLY missing vertex property '{missing}': {ply_path}")

    columns = [vertex[key] for key in coord_keys]
    positions = np.stack(columns, axis=1).astype(np.float32, copy=False)
    return positions, vertex, field_names
|
|
|
|
def load_xyz_only(ply_path: Path):
    """Return only the (N, 3) float32 vertex positions of ``ply_path``."""
    positions = load_vertex_data(ply_path)[0]
    return positions
|
|
|
|
def read_transform(path: Path):
    """Load a whitespace-separated 4x4 matrix (e.g. a T&T ``*_trans.txt``) as float64.

    Raises:
        ValueError: if the parsed array is not exactly 4x4.
    """
    matrix = np.loadtxt(path, dtype=np.float64)
    if matrix.shape == (4, 4):
        return matrix
    raise ValueError(f"Transform must be 4x4, got {matrix.shape}: {path}")
|
|
|
|
def apply_transform(xyz: np.ndarray, mat: np.ndarray):
    """Apply the affine 4x4 ``mat`` to (N, 3) points and return float32 results.

    The product is computed in float64 (``p @ A.T + t`` for row vectors) and
    only cast back to float32 at the end.
    """
    linear = mat[:3, :3]
    offset = mat[:3, 3]
    points64 = xyz.astype(np.float64)
    moved = points64 @ linear.T + offset
    return moved.astype(np.float32)
|
|
|
|
def axis_to_index(axis):
    """Map a crop-volume ``orthogonal_axis`` value to a coordinate index 0/1/2.

    Accepts the integers 0, 1, 2 or (case-insensitive, surrounding whitespace
    ignored) the strings ``"0"/"x"/"axis_x"``, ``"1"/"y"/"axis_y"``,
    ``"2"/"z"/"axis_z"``. Other values, including out-of-range integers
    such as 3 or -1 and booleans, raise ``ValueError``. Previously an
    out-of-range int was returned as-is and broke cropping downstream
    (-1 would even silently keep all three axes as "polygon" axes).

    Raises:
        ValueError: if ``axis`` is not a recognised axis specification.
    """
    # bool is a subclass of int; route it through the string path so it is rejected.
    if isinstance(axis, int) and not isinstance(axis, bool):
        if axis in (0, 1, 2):
            return axis
        raise ValueError(f"Unknown orthogonal_axis: {axis}")

    key = str(axis).strip().lower()
    aliases = (("0", "x", "axis_x"), ("1", "y", "axis_y"), ("2", "z", "axis_z"))
    for index, names in enumerate(aliases):
        if key in names:
            return index
    raise ValueError(f"Unknown orthogonal_axis: {axis}")
|
|
|
|
| def points_in_polygon(points2: np.ndarray, polygon2: np.ndarray): |
| |
| try: |
| from matplotlib.path import Path as MplPath |
| return MplPath(polygon2).contains_points(points2) |
| except Exception: |
| x = points2[:, 0] |
| y = points2[:, 1] |
| poly = polygon2 |
| inside = np.zeros(points2.shape[0], dtype=bool) |
| j = len(poly) - 1 |
| for i in range(len(poly)): |
| xi, yi = poly[i] |
| xj, yj = poly[j] |
| denom = (yj - yi) |
| if abs(denom) < 1e-30: |
| denom = 1e-30 |
| intersect = ((yi > y) != (yj > y)) & ( |
| x < (xj - xi) * (y - yi) / denom + xi |
| ) |
| inside ^= intersect |
| j = i |
| return inside |
|
|
|
|
def crop_mask_tnt(xyz: np.ndarray, crop_json: dict):
    """Boolean mask of the points of ``xyz`` inside a T&T crop volume.

    The volume is a slab ``axis_min <= xyz[:, axis] <= axis_max`` along
    ``orthogonal_axis`` intersected with ``bounding_polygon`` projected onto
    the two remaining axes. The polygon may be given with 3 columns (the
    orthogonal one is dropped) or already as 2 columns.

    Raises:
        ValueError: for an unknown axis or a polygon that is not (M, 2) / (M, 3).
    """
    axis = axis_to_index(crop_json["orthogonal_axis"])
    lo = float(crop_json["axis_min"])
    hi = float(crop_json["axis_max"])

    coord = xyz[:, axis]
    in_slab = (coord >= lo) & (coord <= hi)

    polygon = np.asarray(crop_json["bounding_polygon"], dtype=np.float64)
    plane_axes = [a for a in range(3) if a != axis]

    # `or` short-circuits, so shape[1] is only read once ndim == 2 is known.
    if polygon.ndim != 2 or polygon.shape[1] not in (2, 3):
        raise ValueError(f"Invalid bounding_polygon shape: {polygon.shape}")
    polygon_2d = polygon[:, plane_axes] if polygon.shape[1] == 3 else polygon

    mask = np.zeros(xyz.shape[0], dtype=bool)
    slab_idx = np.nonzero(in_slab)[0]
    if slab_idx.size == 0:
        return mask

    # Only points inside the slab need the (more expensive) polygon test.
    mask[slab_idx] = points_in_polygon(xyz[slab_idx][:, plane_axes], polygon_2d)
    return mask
|
|
|
|
def load_crop(path: Path):
    """Load a crop-volume JSON file and return the parsed object.

    The file is decoded explicitly as UTF-8, the encoding JSON requires
    (RFC 8259). The previous default was the platform's locale encoding,
    so parsing no longer depends on where the script runs.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
def build_tree(points: np.ndarray, name: str, verbose=False):
    """Build a SciPy ``cKDTree`` over ``points``, logging its size when ``verbose``.

    Args:
        points: point set to index.
        name: label used in the log line and in the error message.
        verbose: forwarded to ``vlog``.

    Raises:
        ValueError: if ``points`` is empty.
    """
    count = len(points)
    if not count:
        raise ValueError(f"Cannot build KDTree for empty point set: {name}")
    vlog(verbose, f"[KDTree] building {name}: {count:,} points")
    return cKDTree(points)
|
|
|
|
def query_tree(tree, points, batch_size=250000):
    """Nearest-neighbour lookup of every row of ``points`` in ``tree``, batch by batch.

    Returns:
        ``(distances, indices)``: float32 and int64 arrays of length ``len(points)``.

    ``tree.query`` is first called with ``workers=-1`` (all cores). If that
    raises ``TypeError``, as with SciPy versions lacking the keyword, the
    batch is retried without it.
    """
    total = len(points)
    out_dist = np.empty(total, dtype=np.float32)
    out_idx = np.empty(total, dtype=np.int64)

    for start in range(0, total, batch_size):
        window = slice(start, min(start + batch_size, total))
        chunk = points[window]
        try:
            dist, nn = tree.query(chunk, k=1, workers=-1)
        except TypeError:
            dist, nn = tree.query(chunk, k=1)
        out_dist[window] = dist.astype(np.float32)
        out_idx[window] = nn.astype(np.int64)

    return out_dist, out_idx
|
|
|
|
def compute_metrics(recon: np.ndarray, gt: np.ndarray, tau: float, batch_size=250000, verbose=False):
    """Point-cloud precision / recall / F-score / Chamfer between ``recon`` and ``gt``.

    * precision: fraction of reconstruction points whose nearest GT point is
      strictly closer than ``tau``.
    * recall: fraction of GT points whose nearest reconstruction point is
      strictly closer than ``tau``.
    * f_score: harmonic mean of the two (0.0 when both are 0).
    * chamfer_recon_to_gt / chamfer_gt_to_recon: mean nearest-neighbour
      distances in each direction; ``chamfer_distance`` is their average and
      ``chamfer_distance_sum`` their sum.

    Returns:
        dict of Python floats keyed by the metric names above.
    """
    tree_gt = build_tree(gt, "GT", verbose)
    r2g, _ = query_tree(tree_gt, recon, batch_size=batch_size)
    precision = float((r2g < tau).mean())
    mean_r2g = float(r2g.mean())

    tree_recon = build_tree(recon, "reconstruction", verbose)
    g2r, _ = query_tree(tree_recon, gt, batch_size=batch_size)
    recall = float((g2r < tau).mean())
    mean_g2r = float(g2r.mean())

    denom = precision + recall
    f_score = float(2.0 * precision * recall / denom) if denom > 0 else 0.0

    return {
        "precision": precision,
        "recall": recall,
        "f_score": f_score,
        "chamfer_distance": float(0.5 * (mean_r2g + mean_g2r)),
        "chamfer_distance_sum": float(mean_r2g + mean_g2r),
        "chamfer_recon_to_gt": mean_r2g,
        "chamfer_gt_to_recon": mean_g2r,
    }
|
|
|
|
def choose_eval_indices(n: int, mode: str, n_sample: int, seed: int):
    """Pick which of ``n`` points to evaluate.

    ``"all"`` returns ``0..n-1``. ``"subsample"`` draws ``min(n, n_sample)``
    distinct indices without replacement using ``np.random.default_rng(seed)``
    and returns them sorted. Both return int64 arrays.

    Raises:
        ValueError: for any other ``mode``.
    """
    if mode not in ("all", "subsample"):
        raise ValueError(f"Unsupported mode: {mode}")
    if mode == "all":
        return np.arange(n, dtype=np.int64)

    generator = np.random.default_rng(seed)
    picked = generator.choice(n, size=min(n, n_sample), replace=False)
    picked.sort()
    return picked.astype(np.int64)
|
|
|
|
def get_normals_from_gt_vertex(v, names):
    """Return unit float32 normals from a PLY vertex table, or None if absent.

    Looks for ``nx/ny/nz`` first, then ``normal_x/normal_y/normal_z``. A tiny
    epsilon in the denominator keeps zero-length normals from producing NaNs.
    """
    field_sets = (
        ("nx", "ny", "nz"),
        ("normal_x", "normal_y", "normal_z"),
    )
    for fields in field_sets:
        if not all(field in names for field in fields):
            continue
        raw = np.stack([v[field] for field in fields], axis=1).astype(np.float32)
        lengths = np.linalg.norm(raw, axis=1, keepdims=True) + 1e-12
        return raw / lengths
    return None
|
|
|
|
def gaussian_normals_from_vertex(v, names, indices, transform_mat):
    """Per-Gaussian unit normals, expressed in the transformed (evaluation) frame.

    Each normal is the rotation-matrix column of the Gaussian's thinnest axis
    (argmin over ``scale_0..scale_2``). Argmin is unchanged by any monotonic
    encoding of the scales, such as log. The quaternion is read as
    ``(w, x, y, z) = (rot_0, rot_1, rot_2, rot_3)`` and normalised first.

    Normals are mapped by the inverse-transpose of the linear part ``A`` of
    ``transform_mat``, i.e. ``n @ inv(A)`` for row vectors. This is the correct
    rule for surface normals under any invertible linear map. For a
    similarity transform it gives the same direction as ``A`` itself; the
    previous code used ``A`` and was wrong otherwise.

    Args:
        v: PLY vertex record array.
        names: property names of ``v``.
        indices: rows of ``v`` to compute normals for.
        transform_mat: 4x4 affine transform (only the 3x3 linear part is used).

    Returns:
        (len(indices), 3) float32 unit normals, or None if any of
        ``scale_0..2`` / ``rot_0..3`` is missing.

    Raises:
        numpy.linalg.LinAlgError: if the 3x3 linear part is singular.
    """
    required = ("scale_0", "scale_1", "scale_2", "rot_0", "rot_1", "rot_2", "rot_3")
    if not all(k in names for k in required):
        return None

    idx = np.asarray(indices)
    scales = np.stack([v[f"scale_{i}"][idx] for i in range(3)], axis=1)
    thin_axis = np.argmin(scales, axis=1)

    q = np.stack([v[f"rot_{i}"][idx] for i in range(4)], axis=1).astype(np.float64)
    q = q / (np.linalg.norm(q, axis=1, keepdims=True) + 1e-12)
    w, x, y, z = q.T

    # Standard unit-quaternion -> rotation matrix, stacked as (N, 3, 3).
    n = len(q)
    rot = np.empty((n, 3, 3), dtype=np.float64)
    rot[:, 0, 0] = 1 - 2 * (y * y + z * z)
    rot[:, 0, 1] = 2 * (x * y - w * z)
    rot[:, 0, 2] = 2 * (x * z + w * y)
    rot[:, 1, 0] = 2 * (x * y + w * z)
    rot[:, 1, 1] = 1 - 2 * (x * x + z * z)
    rot[:, 1, 2] = 2 * (y * z - w * x)
    rot[:, 2, 0] = 2 * (x * z - w * y)
    rot[:, 2, 1] = 2 * (y * z + w * x)
    rot[:, 2, 2] = 1 - 2 * (x * x + y * y)

    # Column `thin_axis[k]` of rot[k]; result shape is (N, 3).
    normals = rot[np.arange(n), :, thin_axis]

    linear = np.asarray(transform_mat, dtype=np.float64)[:3, :3]
    normals = normals @ np.linalg.inv(linear)
    normals = normals / (np.linalg.norm(normals, axis=1, keepdims=True) + 1e-12)
    return normals.astype(np.float32)
|
|
|
|
def compute_normal_error(
    recon_eval_points,
    recon_eval_raw_indices,
    recon_vertex,
    recon_names,
    gt_crop_points,
    gt_crop_normals,
    transform_mat,
    batch_size,
):
    """Angular error between Gaussian normals and the normals of their nearest GT points.

    Each evaluated reconstruction point is matched to its nearest cropped GT
    point. The angle between the two normals is measured orientation-free
    (via ``|cos|``), in degrees.

    Returns:
        On success, ``normal_error_available=True`` with the median, 25th and
        75th percentile, IQR of the angles and the number of points. If GT
        normals or the Gaussian scale/rotation fields are missing,
        ``normal_error_available=False`` with a ``normal_error_reason``.
    """
    if gt_crop_normals is None:
        return {"normal_error_available": False, "normal_error_reason": "GT PLY has no normal fields"}

    recon_normals = gaussian_normals_from_vertex(
        recon_vertex, recon_names, recon_eval_raw_indices, transform_mat
    )
    if recon_normals is None:
        return {
            "normal_error_available": False,
            "normal_error_reason": "reconstruction PLY has no scale_*/rot_* fields",
        }

    _, nearest = query_tree(cKDTree(gt_crop_points), recon_eval_points, batch_size=batch_size)
    matched_gt_normals = gt_crop_normals[nearest]

    # |cos| ignores normal sign flips; the clip keeps arccos in its domain
    # despite rounding.
    cosines = np.clip(np.abs(np.sum(recon_normals * matched_gt_normals, axis=1)), 0.0, 1.0)
    angles_deg = np.degrees(np.arccos(cosines))

    q25 = np.quantile(angles_deg, 0.25)
    q75 = np.quantile(angles_deg, 0.75)
    return {
        "normal_error_available": True,
        "normal_angular_error_median_deg": float(np.median(angles_deg)),
        "normal_angular_error_q25_deg": float(q25),
        "normal_angular_error_q75_deg": float(q75),
        "normal_angular_error_iqr_deg": float(q75 - q25),
        "normal_error_n_points": int(len(angles_deg)),
    }
|
|
|
|
def main():
    """CLI entry point: evaluate one method on one T&T scene and write ``eval.json``.

    Steps:
      1. Locate the reconstruction PLY.
      2. Align it with the scene's ``*_trans.txt``.
      3. Crop reconstruction and GT with the scene's crop JSON.
      4. Optionally subsample the cropped reconstruction.
      5. Compute precision/recall/F-score/Chamfer at the scene's ``tau``, and
         optionally a normal-angle error.
      6. Write the result to ``<outputs-root>/tnt_eval/<method>_<scene>/eval.json``
         and print it.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--method", required=True)
    parser.add_argument("--scene", required=True, choices=sorted(SCENE_MAP.keys()))
    parser.add_argument("--project-root", default="/root/autodl-tmp/SplatAtlas")
    parser.add_argument("--outputs-root", default=None)
    parser.add_argument("--tnt-eval-root", default=None)
    parser.add_argument("--iteration", type=int, default=None)
    parser.add_argument("--mode", choices=["all", "subsample"], default="subsample")
    parser.add_argument("--n-sample", type=int, default=200000)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, default=250000)
    parser.add_argument("--normal-error", action="store_true")
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()

    # Fail fast with a usage message. Otherwise batch size 0 dies with an
    # opaque range() error (and a negative one leaves the distance arrays
    # unfilled), and n_sample < 1 produces an empty or invalid evaluation set
    # deep in the pipeline.
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")
    if args.mode == "subsample" and args.n_sample < 1:
        parser.error("--n-sample must be a positive integer when --mode=subsample")

    t0 = time.time()

    project_root = Path(args.project_root)
    outputs_root = Path(args.outputs_root) if args.outputs_root else project_root / "outputs"
    tnt_eval_root = Path(args.tnt_eval_root) if args.tnt_eval_root else project_root / "data" / "tnt_eval"

    scene_lower = args.scene.lower()
    official_scene = SCENE_MAP[scene_lower]
    tau = TAU_DICT[scene_lower]

    ply_path = locate_recon_ply(outputs_root, args.method, scene_lower, args.iteration)

    scene_eval_dir = tnt_eval_root / official_scene
    gt_ply_path = scene_eval_dir / f"{official_scene}.ply"
    crop_path = scene_eval_dir / f"{official_scene}.json"
    trans_path = scene_eval_dir / f"{official_scene}_trans.txt"

    for p in [gt_ply_path, crop_path, trans_path]:
        if not p.exists():
            raise FileNotFoundError(f"Missing required eval file: {p}")

    vlog(args.verbose, "=" * 80)
    vlog(args.verbose, f"method: {args.method}")
    vlog(args.verbose, f"scene: {scene_lower} -> {official_scene}")
    vlog(args.verbose, f"tau: {tau}")
    vlog(args.verbose, f"recon ply: {ply_path}")
    vlog(args.verbose, f"gt ply: {gt_ply_path}")
    vlog(args.verbose, f"crop json: {crop_path}")
    vlog(args.verbose, f"trans: {trans_path}")

    trans = read_transform(trans_path)
    crop_json = load_crop(crop_path)

    vlog(args.verbose, "\n[Load] reconstruction PLY")
    recon_xyz_raw, recon_vertex, recon_names = load_vertex_data(ply_path)
    n_gaussians_recon = int(len(recon_xyz_raw))
    vlog(args.verbose, f"n_gaussians_recon: {n_gaussians_recon:,}")

    # Only the reconstruction is transformed; the GT is cropped in its own frame.
    vlog(args.verbose, "\n[Transform] applying T&T trans matrix")
    vlog(args.verbose, trans)
    recon_xyz_aligned = apply_transform(recon_xyz_raw, trans)

    vlog(args.verbose, "\n[Crop] reconstruction")
    recon_crop_mask = crop_mask_tnt(recon_xyz_aligned, crop_json)
    # Row indices into the original PLY, needed later to read per-Gaussian scale/rot fields.
    recon_crop_indices_raw = np.where(recon_crop_mask)[0].astype(np.int64)
    recon_crop = recon_xyz_aligned[recon_crop_mask]
    vlog(args.verbose, f"recon before crop: {len(recon_xyz_aligned):,}")
    vlog(args.verbose, f"recon after crop: {len(recon_crop):,}")

    if len(recon_crop) == 0:
        raise RuntimeError("Reconstruction crop is empty. Likely coordinate-system/alignment mismatch.")

    vlog(args.verbose, "\n[Load] GT PLY")
    gt_xyz_raw, gt_vertex, gt_names = load_vertex_data(gt_ply_path)
    vlog(args.verbose, f"GT raw points: {len(gt_xyz_raw):,}")

    vlog(args.verbose, "\n[Crop] GT")
    gt_crop_mask = crop_mask_tnt(gt_xyz_raw, crop_json)
    gt_crop = gt_xyz_raw[gt_crop_mask]
    vlog(args.verbose, f"GT after crop: {len(gt_crop):,}")

    if len(gt_crop) == 0:
        raise RuntimeError("GT crop is empty. Crop parser is likely wrong.")

    eval_idx_in_crop = choose_eval_indices(
        len(recon_crop), mode=args.mode, n_sample=args.n_sample, seed=args.seed
    )
    recon_eval = recon_crop[eval_idx_in_crop]
    recon_eval_raw_indices = recon_crop_indices_raw[eval_idx_in_crop]

    vlog(args.verbose, "\n[Eval set]")
    vlog(args.verbose, f"mode: {args.mode}")
    vlog(args.verbose, f"recon eval points: {len(recon_eval):,}")
    vlog(args.verbose, f"GT eval points: {len(gt_crop):,}")

    metrics = compute_metrics(
        recon_eval,
        gt_crop,
        tau=tau,
        batch_size=args.batch_size,
        verbose=args.verbose,
    )

    normal_metrics = {}
    if args.normal_error:
        vlog(args.verbose, "\n[Normal error]")
        gt_normals_raw = get_normals_from_gt_vertex(gt_vertex, gt_names)
        gt_normals_crop = gt_normals_raw[gt_crop_mask] if gt_normals_raw is not None else None
        normal_metrics = compute_normal_error(
            recon_eval_points=recon_eval,
            recon_eval_raw_indices=recon_eval_raw_indices,
            recon_vertex=recon_vertex,
            recon_names=recon_names,
            gt_crop_points=gt_crop,
            gt_crop_normals=gt_normals_crop,
            transform_mat=trans,
            batch_size=args.batch_size,
        )
        vlog(args.verbose, normal_metrics)

    wall_time = time.time() - t0

    result = {
        "method": args.method,
        "scene": scene_lower,
        "official_scene": official_scene,
        "eval_protocol": "tnt_point_cloud_direct_gaussian_centers",
        "mode": args.mode,
        "n_sample_requested": int(args.n_sample),
        "seed": int(args.seed),
        "ply_path": str(ply_path),
        "gt_ply_path": str(gt_ply_path),
        "crop_path": str(crop_path),
        "trans_path": str(trans_path),
        "n_gaussians_recon": n_gaussians_recon,
        "n_points_after_crop": int(len(recon_crop)),
        "n_points_eval_recon": int(len(recon_eval)),
        "n_points_gt_raw": int(len(gt_xyz_raw)),
        "n_points_gt": int(len(gt_crop)),
        "tau": float(tau),
        **metrics,
        **normal_metrics,
        "wall_time_seconds": float(wall_time),
    }

    out_dir = outputs_root / "tnt_eval" / f"{args.method}_{scene_lower}"
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / "eval.json"

    with open(out_path, "w") as f:
        json.dump(result, f, indent=2, sort_keys=True)

    print(json.dumps(result, indent=2, sort_keys=True))
    print(f"\n[WROTE] {out_path}")
|
|
|
|
if __name__ == "__main__":
    try:
        main()
    except Exception as err:
        # Print a one-line summary to stderr, then re-raise so the full
        # traceback and the non-zero exit status are kept.
        sys.stderr.write(f"[ERROR] {type(err).__name__}: {err}" + "\n")
        raise
|
|