SplatAtlas / scripts /eval_tnt_wrapper.py
KCBtheone's picture
Upload SplatAtlas benchmark pipeline code
23e73f9 verified
Raw
History Blame Contribute Delete
17.2 kB
#!/usr/bin/env python3
import argparse
import json
import math
import sys
import time
from pathlib import Path
import numpy as np
from plyfile import PlyData
from scipy.spatial import cKDTree
# Lower-case CLI scene name -> official Tanks and Temples scene name; the official
# name is also the directory and file stem of the scene's GT evaluation files.
SCENE_MAP = {
    "barn": "Barn",
    "caterpillar": "Caterpillar",
    "truck": "Truck",
    "church": "Church",
    "courthouse": "Courthouse",
    "ignatius": "Ignatius",
    "meetingroom": "Meetingroom",
}
# Per-scene distance threshold tau (GT-frame units) for precision / recall / F-score.
# These must match the official Tanks and Temples evaluation toolbox
# (``scenes_tau_dict``) for results to be comparable with published numbers;
# Church and Courthouse use 0.025 there (previously 0.01 here).
TAU_DICT = {
    "barn": 0.01,
    "caterpillar": 0.005,
    "truck": 0.005,
    "church": 0.025,
    "courthouse": 0.025,
    "ignatius": 0.003,
    "meetingroom": 0.01,
}
# Default checkpoint iteration per training method; unlisted methods use DEFAULT_ITER.
ITER_DICT = {
    "gaussian_surfel": 15000,
    "hogs": 50000,
}
DEFAULT_ITER = 30000
def vlog(enabled, *args):
    """Print ``args`` (space-separated, flushed immediately) only if ``enabled`` is truthy."""
    if not enabled:
        return
    print(*args, flush=True)
def as_float(x):
    """Convert NumPy values to plain Python objects for JSON output.

    NumPy arrays become nested lists, NumPy floating/integer scalars become
    Python ``float``/``int``; anything else is returned unchanged.
    """
    if isinstance(x, np.ndarray):
        return x.tolist()
    is_numpy_number = isinstance(x, (np.floating, np.integer))
    return x.item() if is_numpy_number else x
def locate_recon_ply(outputs_root: Path, method: str, scene: str, iteration=None) -> Path:
    """Return the path of the reconstruction PLY for ``method`` on ``scene``.

    Search order (first existing path wins):
      1. ``<method>_<scene>_bak/point_cloud/iteration_<it>/point_cloud.ply``
         (only when ``method == "vanilla_3dgs"``),
      2. ``<method>_<scene>/point_cloud/iteration_<it>/point_cloud.ply``,
      3. fallback: any ``iteration_*`` checkpoint under the ``_bak`` run
         directory, then under the plain run directory, highest iteration first.

    ``it`` is ``iteration`` if given, else ``ITER_DICT.get(method, DEFAULT_ITER)``.
    ``scene`` is lower-cased before building the directory names.

    Raises:
        FileNotFoundError: if no candidate exists; the message lists every path tried.
    """
    scene = scene.lower()
    if iteration is None:
        iteration = ITER_DICT.get(method, DEFAULT_ITER)
    bak_dir = outputs_root / f"{method}_{scene}_bak"
    run_dir = outputs_root / f"{method}_{scene}"
    requested = Path("point_cloud") / f"iteration_{iteration}" / "point_cloud.ply"
    candidates = []
    # Important special case from user's training history.
    if method == "vanilla_3dgs":
        candidates.append(bak_dir / requested)
    candidates.append(run_dir / requested)
    # Fallback: if the requested iteration is absent, use any iteration folder.
    # Sort numerically, newest checkpoint first. A lexicographic sort places
    # "iteration_15000" before "iteration_7000" but "iteration_2000" before
    # "iteration_7000", so the previous pick followed no consistent rule.
    for base in (bak_dir, run_dir):
        if base.exists():
            found = base.glob("point_cloud/iteration_*/point_cloud.ply")
            candidates.extend(sorted(found, key=_iteration_number, reverse=True))
    for p in candidates:
        if p.exists():
            return p
    msg = "Could not find reconstruction PLY. Tried:\n" + "\n".join(str(p) for p in candidates)
    raise FileNotFoundError(msg)


def _iteration_number(ply_path: Path) -> int:
    """Parse N from ``.../iteration_N/point_cloud.ply``; return -1 if N is not an integer."""
    _, _, suffix = ply_path.parent.name.partition("iteration_")
    try:
        return int(suffix)
    except ValueError:
        return -1
def load_vertex_data(ply_path: Path):
    """Read a PLY file and return its vertex positions plus the raw vertex table.

    Returns:
        ``(xyz, vertex_data, field_names)`` where ``xyz`` is an ``(N, 3)``
        float32 array stacked from the ``x``/``y``/``z`` properties,
        ``vertex_data`` is the PLY ``vertex`` element's structured array and
        ``field_names`` is its dtype field-name tuple.

    Raises:
        ValueError: if there is no ``vertex`` element or ``x``/``y``/``z`` is missing.
    """
    ply = PlyData.read(str(ply_path))
    if "vertex" not in ply:
        raise ValueError(f"No vertex element in PLY: {ply_path}")
    vertex = ply["vertex"].data
    fields = vertex.dtype.names
    coord_keys = ("x", "y", "z")
    missing = [k for k in coord_keys if k not in fields]
    if missing:
        raise ValueError(f"PLY missing vertex property '{missing[0]}': {ply_path}")
    columns = [vertex[k] for k in coord_keys]
    xyz = np.stack(columns, axis=1).astype(np.float32, copy=False)
    return xyz, vertex, fields
def load_xyz_only(ply_path: Path):
    """Return only the ``(N, 3)`` float32 vertex positions of the PLY at ``ply_path``."""
    return load_vertex_data(ply_path)[0]
def read_transform(path: Path):
    """Load a whitespace-separated 4x4 matrix from ``path`` as float64.

    Raises:
        ValueError: if the parsed array is not exactly 4x4.
    """
    matrix = np.loadtxt(path).astype(np.float64)
    if matrix.shape == (4, 4):
        return matrix
    raise ValueError(f"Transform must be 4x4, got {matrix.shape}: {path}")
def apply_transform(xyz: np.ndarray, mat: np.ndarray):
    """Apply the affine part of a 4x4 matrix to ``(N, 3)`` points.

    Uses the column-vector convention ``x_out = M @ [x, y, z, 1]``; only the
    top three rows of ``mat`` are used. Math is done in float64 and the result
    is returned as float32.
    """
    rotation = mat[:3, :3]
    translation = mat[:3, 3]
    points = xyz.astype(np.float64)
    aligned = points @ rotation.T + translation
    return aligned.astype(np.float32)
def axis_to_index(axis):
    """Map a crop-JSON ``orthogonal_axis`` value to a coordinate index 0, 1 or 2.

    Accepts Python or NumPy integers and case-insensitive strings such as
    ``"Z"``, ``"y"``, ``"axis_x"`` or ``"2"`` (surrounding whitespace ignored).

    Raises:
        ValueError: for anything that does not name axis 0, 1 or 2. Out-of-range
            integers are rejected here. Passing them through unchecked made the
            crop fail later with an unrelated IndexError or shape error.
    """
    if isinstance(axis, (int, np.integer)):
        index = int(axis)
        if index in (0, 1, 2):
            return index
        raise ValueError(f"Unknown orthogonal_axis: {axis}")
    aliases = {
        "0": 0, "x": 0, "axis_x": 0,
        "1": 1, "y": 1, "axis_y": 1,
        "2": 2, "z": 2, "axis_z": 2,
    }
    key = str(axis).strip().lower()
    if key in aliases:
        return aliases[key]
    raise ValueError(f"Unknown orthogonal_axis: {axis}")
def points_in_polygon(points2: np.ndarray, polygon2: np.ndarray):
    """Return a boolean mask of the 2-D points that lie inside a polygon.

    Args:
        points2: ``(N, 2)`` array of query points.
        polygon2: ``(M, 2)`` array of polygon vertices in order; the closing
            edge from the last vertex back to the first is implied.

    Returns:
        ``(N,)`` boolean array. Uses matplotlib's ``Path.contains_points`` when
        matplotlib is importable, otherwise a vectorized even-odd ray-casting
        test. Points lying exactly on an edge may be classified differently by
        the two implementations.
    """
    try:
        from matplotlib.path import Path as MplPath
    except ImportError:
        # matplotlib is optional. Only its absence triggers the fallback, so
        # genuine errors (e.g. a malformed polygon) are not silently masked.
        return _points_in_polygon_raycast(points2, polygon2)
    return np.asarray(MplPath(polygon2).contains_points(points2), dtype=bool)


def _points_in_polygon_raycast(points2, polygon2):
    """Vectorized even-odd (PNPOLY) point-in-polygon test over all points at once."""
    x = points2[:, 0]
    y = points2[:, 1]
    inside = np.zeros(points2.shape[0], dtype=bool)
    for i in range(len(polygon2)):
        xi, yi = polygon2[i]
        xj, yj = polygon2[i - 1]  # previous vertex; i == 0 wraps to the last one
        if yi == yj:
            # A horizontal edge never straddles a point's horizontal ray, so it
            # cannot toggle parity; skipping it also avoids dividing by zero.
            continue
        straddles = (yi > y) != (yj > y)
        x_cross = (xj - xi) * (y - yi) / (yj - yi) + xi
        inside ^= straddles & (x < x_cross)
    return inside
def crop_mask_tnt(xyz: np.ndarray, crop_json: dict):
    """Return a boolean mask of the ``(N, 3)`` points inside a T&T crop volume.

    A point is kept when its coordinate along ``orthogonal_axis`` lies in
    ``[axis_min, axis_max]`` (inclusive), and its projection onto the two
    remaining axes falls inside ``bounding_polygon``. The polygon may be given
    with 3 columns (the orthogonal one is dropped) or with 2.

    Raises:
        ValueError: if ``bounding_polygon`` is not a 2-D array with 2 or 3 columns.
    """
    axis = axis_to_index(crop_json["orthogonal_axis"])
    lo = float(crop_json["axis_min"])
    hi = float(crop_json["axis_max"])
    height = xyz[:, axis]
    in_slab = (height >= lo) & (height <= hi)
    polygon = np.asarray(crop_json["bounding_polygon"], dtype=np.float64)
    planar = [i for i in range(3) if i != axis]
    if polygon.ndim != 2 or polygon.shape[1] not in (2, 3):
        raise ValueError(f"Invalid bounding_polygon shape: {polygon.shape}")
    poly2 = polygon[:, planar] if polygon.shape[1] == 3 else polygon
    mask = np.zeros(xyz.shape[0], dtype=bool)
    slab_idx = np.flatnonzero(in_slab)
    # Only points inside the height slab need the (costlier) polygon test.
    if slab_idx.size:
        mask[slab_idx] = points_in_polygon(xyz[slab_idx][:, planar], poly2)
    return mask
def load_crop(path: Path):
    """Load and return the parsed crop JSON at ``path``.

    The file is decoded as UTF-8 explicitly, as required by RFC 8259, rather
    than with the platform's locale-dependent default encoding.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
def build_tree(points: np.ndarray, name: str, verbose=False):
    """Build a ``cKDTree`` over ``points``; ``name`` labels the log line and error.

    Raises:
        ValueError: if ``points`` is empty.
    """
    n_pts = len(points)
    if not n_pts:
        raise ValueError(f"Cannot build KDTree for empty point set: {name}")
    vlog(verbose, f"[KDTree] building {name}: {n_pts:,} points")
    return cKDTree(points)
def query_tree(tree, points, batch_size=250000):
    """Nearest-neighbour query of ``points`` against ``tree`` in fixed-size batches.

    Returns:
        ``(dists, inds)``: float32 distances and int64 indices of the nearest
        tree point for every query point.
    """
    total = len(points)
    out_d = np.empty(total, dtype=np.float32)
    out_i = np.empty(total, dtype=np.int64)
    for start in range(0, total, batch_size):
        stop = min(start + batch_size, total)
        chunk = points[start:stop]
        try:
            d, i = tree.query(chunk, k=1, workers=-1)
        except TypeError:
            # Older SciPy releases do not accept the ``workers`` keyword.
            d, i = tree.query(chunk, k=1)
        out_d[start:stop] = d.astype(np.float32)
        out_i[start:stop] = i.astype(np.int64)
    return out_d, out_i
def compute_metrics(recon: np.ndarray, gt: np.ndarray, tau: float, batch_size=250000, verbose=False):
    """Compute distance-threshold and Chamfer metrics between two point sets.

    precision = fraction of recon points whose nearest GT point is closer than
    ``tau``; recall = fraction of GT points whose nearest recon point is closer
    than ``tau``; f_score is their harmonic mean (0 when both are 0).
    ``chamfer_distance`` is the mean of the two directional mean NN distances,
    and ``chamfer_distance_sum`` is their sum.
    """
    # Accuracy direction: reconstruction -> GT.
    gt_index = build_tree(gt, "GT", verbose)
    d_r2g, _ = query_tree(gt_index, recon, batch_size=batch_size)
    precision = float(np.mean(d_r2g < tau))
    chamfer_r2g = float(np.mean(d_r2g))
    # Completeness direction: GT -> reconstruction.
    recon_index = build_tree(recon, "reconstruction", verbose)
    d_g2r, _ = query_tree(recon_index, gt, batch_size=batch_size)
    recall = float(np.mean(d_g2r < tau))
    chamfer_g2r = float(np.mean(d_g2r))
    denom = precision + recall
    f_score = float(2.0 * precision * recall / denom) if denom > 0 else 0.0
    # "Mean of bidirectional" is the headline number; the sum is kept for transparency.
    return {
        "precision": precision,
        "recall": recall,
        "f_score": f_score,
        "chamfer_distance": float(0.5 * (chamfer_r2g + chamfer_g2r)),
        "chamfer_distance_sum": float(chamfer_r2g + chamfer_g2r),
        "chamfer_recon_to_gt": chamfer_r2g,
        "chamfer_gt_to_recon": chamfer_g2r,
    }
def choose_eval_indices(n: int, mode: str, n_sample: int, seed: int):
    """Pick which of ``n`` points to evaluate, as a sorted int64 index array.

    ``"all"`` returns every index. ``"subsample"`` draws ``min(n, n_sample)``
    distinct indices with ``numpy.random.default_rng(seed)``.

    Raises:
        ValueError: for any other ``mode``.
    """
    if mode == "subsample":
        rng = np.random.default_rng(seed)
        picked = rng.choice(n, size=min(n, n_sample), replace=False)
        return np.sort(picked).astype(np.int64)
    if mode == "all":
        return np.arange(n, dtype=np.int64)
    raise ValueError(f"Unsupported mode: {mode}")
def get_normals_from_gt_vertex(v, names):
    """Return unit normals (float32, one row per vertex) from a PLY vertex table.

    Looks for ``nx/ny/nz`` first, then ``normal_x/normal_y/normal_z``. Returns
    ``None`` if neither triple is present. Zero-length normals stay zero,
    because of the 1e-12 guard in the denominator.
    """
    for fields in (("nx", "ny", "nz"), ("normal_x", "normal_y", "normal_z")):
        if not all(f in names for f in fields):
            continue
        raw = np.stack([v[f] for f in fields], axis=1).astype(np.float32)
        length = np.linalg.norm(raw, axis=1, keepdims=True) + 1e-12
        return raw / length
    return None
def gaussian_normals_from_vertex(v, names, indices, transform_mat):
    """Estimate per-Gaussian normals and map them into the transformed frame.

    For each selected vertex, the normal is the rotation-matrix column of the
    axis with the smallest ``scale_*`` value. The quaternion is ``rot_0..rot_3``
    read as ``[w, x, y, z]`` and normalized first. Normals are then mapped by
    the inverse-transpose of ``transform_mat[:3, :3]``, which is the correct rule
    for normals under any invertible linear map. The previous code used the
    matrix itself, which only matches for rotation + uniform-scale transforms.

    Args:
        v: structured vertex array (as from the PLY ``vertex`` element).
        names: field names of ``v``.
        indices: vertex indices to evaluate.
        transform_mat: 4x4 matrix; only its upper-left 3x3 block is used.

    Returns:
        ``(len(indices), 3)`` float32 unit normals, or ``None`` if any of the
        ``scale_*``/``rot_*`` fields is missing.

    Raises:
        numpy.linalg.LinAlgError: if the 3x3 block is singular.
    """
    required = ["scale_0", "scale_1", "scale_2", "rot_0", "rot_1", "rot_2", "rot_3"]
    if not all(k in names for k in required):
        return None
    idx = indices
    scales = np.stack([v[f"scale_{k}"][idx] for k in range(3)], axis=1)
    # argmin is unaffected by a monotonic (e.g. log) scale encoding.
    min_axis = np.argmin(scales, axis=1)
    q = np.stack([v[f"rot_{k}"][idx] for k in range(4)], axis=1).astype(np.float64)
    q = q / (np.linalg.norm(q, axis=1, keepdims=True) + 1e-12)
    w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
    # Rotation matrices for quaternions [w, x, y, z]; rot[n, row, col].
    rot = np.stack(
        [
            np.stack([1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)], axis=-1),
            np.stack([2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)], axis=-1),
            np.stack([2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)], axis=-1),
        ],
        axis=1,
    )
    # Column `min_axis[n]` of rot[n] is the direction of that Gaussian's thinnest axis.
    normals = np.take_along_axis(rot, min_axis[:, None, None], axis=2)[:, :, 0]
    # Row-vector form of n' = inv(A).T @ n is n'_row = n_row @ inv(A).
    A = np.asarray(transform_mat, dtype=np.float64)[:3, :3]
    normals = normals @ np.linalg.inv(A)
    normals = normals / (np.linalg.norm(normals, axis=1, keepdims=True) + 1e-12)
    return normals.astype(np.float32)
def compute_normal_error(
    recon_eval_points,
    recon_eval_raw_indices,
    recon_vertex,
    recon_names,
    gt_crop_points,
    gt_crop_normals,
    transform_mat,
    batch_size,
):
    """Summarize the unsigned angle between reconstruction and GT normals.

    Each evaluated reconstruction point is paired with its nearest cropped GT
    point. The angle is ``arccos(|n_recon . n_gt|)`` in degrees, so flipped
    normals count as aligned. Returns a dict with
    ``normal_error_available=False`` and a reason string when either side has
    no normals; otherwise it holds the median, quartiles, IQR and point count.
    """
    if gt_crop_normals is None:
        return {"normal_error_available": False, "normal_error_reason": "GT PLY has no normal fields"}
    recon_normals = gaussian_normals_from_vertex(
        recon_vertex, recon_names, recon_eval_raw_indices, transform_mat
    )
    if recon_normals is None:
        return {
            "normal_error_available": False,
            "normal_error_reason": "reconstruction PLY has no scale_*/rot_* fields",
        }
    _, nearest = query_tree(cKDTree(gt_crop_points), recon_eval_points, batch_size=batch_size)
    cosines = np.sum(recon_normals * gt_crop_normals[nearest], axis=1)
    angles = np.degrees(np.arccos(np.clip(np.abs(cosines), 0.0, 1.0)))
    q25 = np.quantile(angles, 0.25)
    q75 = np.quantile(angles, 0.75)
    return {
        "normal_error_available": True,
        "normal_angular_error_median_deg": float(np.median(angles)),
        "normal_angular_error_q25_deg": float(q25),
        "normal_angular_error_q75_deg": float(q75),
        "normal_angular_error_iqr_deg": float(q75 - q25),
        "normal_error_n_points": int(len(angles)),
    }
def main():
    """Command-line entry point: evaluate one reconstruction on one T&T scene.

    Steps:
      1. Parse arguments.
      2. Resolve the reconstruction PLY via ``locate_recon_ply``, and the
         scene's GT files (``<Scene>.ply``, the ``<Scene>.json`` crop and
         ``<Scene>_trans.txt``).
      3. Apply the transform to the reconstruction's vertex positions.
      4. Crop reconstruction and GT with the same crop definition.
      5. Optionally subsample the cropped reconstruction.
      6. Compute precision / recall / F-score / Chamfer, plus normal-angle
         statistics with ``--normal-error``.
      7. Write the result dict to
         ``<outputs_root>/tnt_eval/<method>_<scene>/eval.json`` and print it.

    Raises:
        FileNotFoundError: if the reconstruction PLY or a GT eval file is missing.
        RuntimeError: if either crop leaves zero points.
        Errors raised by the loaders and helpers (e.g. ValueError) propagate unchanged.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--method", required=True)
    parser.add_argument("--scene", required=True, choices=sorted(SCENE_MAP.keys()))
    parser.add_argument("--project-root", default="/root/autodl-tmp/SplatAtlas")
    parser.add_argument("--outputs-root", default=None)
    parser.add_argument("--tnt-eval-root", default=None)
    parser.add_argument("--iteration", type=int, default=None)
    parser.add_argument("--mode", choices=["all", "subsample"], default="subsample")
    parser.add_argument("--n-sample", type=int, default=200000)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, default=250000)
    parser.add_argument("--normal-error", action="store_true")
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()
    t0 = time.time()
    # --outputs-root and --tnt-eval-root default to subdirectories of --project-root.
    project_root = Path(args.project_root)
    outputs_root = Path(args.outputs_root) if args.outputs_root else project_root / "outputs"
    tnt_eval_root = Path(args.tnt_eval_root) if args.tnt_eval_root else project_root / "data" / "tnt_eval"
    scene_lower = args.scene.lower()
    official_scene = SCENE_MAP[scene_lower]
    tau = TAU_DICT[scene_lower]
    ply_path = locate_recon_ply(outputs_root, args.method, scene_lower, args.iteration)
    # GT files live in <tnt_eval_root>/<Scene>/ and use the official scene name as stem.
    scene_eval_dir = tnt_eval_root / official_scene
    gt_ply_path = scene_eval_dir / f"{official_scene}.ply"
    crop_path = scene_eval_dir / f"{official_scene}.json"
    trans_path = scene_eval_dir / f"{official_scene}_trans.txt"
    for p in [gt_ply_path, crop_path, trans_path]:
        if not p.exists():
            raise FileNotFoundError(f"Missing required eval file: {p}")
    vlog(args.verbose, "=" * 80)
    vlog(args.verbose, f"method: {args.method}")
    vlog(args.verbose, f"scene: {scene_lower} -> {official_scene}")
    vlog(args.verbose, f"tau: {tau}")
    vlog(args.verbose, f"recon ply: {ply_path}")
    vlog(args.verbose, f"gt ply: {gt_ply_path}")
    vlog(args.verbose, f"crop json: {crop_path}")
    vlog(args.verbose, f"trans: {trans_path}")
    trans = read_transform(trans_path)
    crop_json = load_crop(crop_path)
    vlog(args.verbose, "\n[Load] reconstruction PLY")
    recon_xyz_raw, recon_vertex, recon_names = load_vertex_data(ply_path)
    n_gaussians_recon = int(len(recon_xyz_raw))
    vlog(args.verbose, f"n_gaussians_recon: {n_gaussians_recon:,}")
    vlog(args.verbose, "\n[Transform] applying T&T trans matrix")
    vlog(args.verbose, trans)
    # Only the reconstruction is transformed; the GT is cropped in its own frame below.
    recon_xyz_aligned = apply_transform(recon_xyz_raw, trans)
    vlog(args.verbose, "\n[Crop] reconstruction")
    recon_crop_mask = crop_mask_tnt(recon_xyz_aligned, crop_json)
    # Indices into the full vertex table, kept so per-Gaussian fields can be read later.
    recon_crop_indices_raw = np.where(recon_crop_mask)[0].astype(np.int64)
    recon_crop = recon_xyz_aligned[recon_crop_mask]
    vlog(args.verbose, f"recon before crop: {len(recon_xyz_aligned):,}")
    vlog(args.verbose, f"recon after crop: {len(recon_crop):,}")
    if len(recon_crop) == 0:
        raise RuntimeError("Reconstruction crop is empty. Likely coordinate-system/alignment mismatch.")
    vlog(args.verbose, "\n[Load] GT PLY")
    gt_xyz_raw, gt_vertex, gt_names = load_vertex_data(gt_ply_path)
    vlog(args.verbose, f"GT raw points: {len(gt_xyz_raw):,}")
    vlog(args.verbose, "\n[Crop] GT")
    gt_crop_mask = crop_mask_tnt(gt_xyz_raw, crop_json)
    gt_crop = gt_xyz_raw[gt_crop_mask]
    vlog(args.verbose, f"GT after crop: {len(gt_crop):,}")
    if len(gt_crop) == 0:
        raise RuntimeError("GT crop is empty. Crop parser is likely wrong.")
    # Subsampling (--mode subsample) applies to the cropped reconstruction only;
    # the cropped GT is always used in full.
    eval_idx_in_crop = choose_eval_indices(
        len(recon_crop), mode=args.mode, n_sample=args.n_sample, seed=args.seed
    )
    recon_eval = recon_crop[eval_idx_in_crop]
    # Map sampled crop positions back to raw vertex indices (for the scale_*/rot_* lookup).
    recon_eval_raw_indices = recon_crop_indices_raw[eval_idx_in_crop]
    vlog(args.verbose, "\n[Eval set]")
    vlog(args.verbose, f"mode: {args.mode}")
    vlog(args.verbose, f"recon eval points: {len(recon_eval):,}")
    vlog(args.verbose, f"GT eval points: {len(gt_crop):,}")
    metrics = compute_metrics(
        recon_eval,
        gt_crop,
        tau=tau,
        batch_size=args.batch_size,
        verbose=args.verbose,
    )
    normal_metrics = {}
    if args.normal_error:
        vlog(args.verbose, "\n[Normal error]")
        gt_normals_raw = get_normals_from_gt_vertex(gt_vertex, gt_names)
        # Reuse the GT crop mask so normals stay row-aligned with gt_crop.
        gt_normals_crop = gt_normals_raw[gt_crop_mask] if gt_normals_raw is not None else None
        normal_metrics = compute_normal_error(
            recon_eval_points=recon_eval,
            recon_eval_raw_indices=recon_eval_raw_indices,
            recon_vertex=recon_vertex,
            recon_names=recon_names,
            gt_crop_points=gt_crop,
            gt_crop_normals=gt_normals_crop,
            transform_mat=trans,
            batch_size=args.batch_size,
        )
        vlog(args.verbose, normal_metrics)
    wall_time = time.time() - t0
    result = {
        "method": args.method,
        "scene": scene_lower,
        "official_scene": official_scene,
        "eval_protocol": "tnt_point_cloud_direct_gaussian_centers",
        "mode": args.mode,
        "n_sample_requested": int(args.n_sample),
        "seed": int(args.seed),
        "ply_path": str(ply_path),
        "gt_ply_path": str(gt_ply_path),
        "crop_path": str(crop_path),
        "trans_path": str(trans_path),
        "n_gaussians_recon": n_gaussians_recon,
        "n_points_after_crop": int(len(recon_crop)),
        "n_points_eval_recon": int(len(recon_eval)),
        "n_points_gt_raw": int(len(gt_xyz_raw)),
        "n_points_gt": int(len(gt_crop)),
        "tau": float(tau),
        **metrics,
        **normal_metrics,
        "wall_time_seconds": float(wall_time),
    }
    # NOTE(review): the output path depends only on method and scene, so runs with a
    # different --mode/--seed/--iteration overwrite the same eval.json.
    out_dir = outputs_root / "tnt_eval" / f"{args.method}_{scene_lower}"
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / "eval.json"
    with open(out_path, "w") as f:
        json.dump(result, f, indent=2, sort_keys=True)
    print(json.dumps(result, indent=2, sort_keys=True))
    print(f"\n[WROTE] {out_path}")
if __name__ == "__main__":
    try:
        main()
    except Exception as exc:
        # One-line summary on stderr, then re-raise for the full traceback and non-zero exit.
        print(f"[ERROR] {exc.__class__.__name__}: {exc}", file=sys.stderr)
        raise