# s23-model / test_one_sample.py
# Last commit b1c3ec5 by xsponenta:
#   "Orphan-vertex cleanup + apex snap + local eval harness"
# (5.68 kB at that commit; hosting-page navigation text removed.)
"""Minimal local sanity check: run ONE sample through the pipeline.
Step-by-step instrumentation. If anything crashes, we know exactly where.
Designed for local Mac M4 debugging, not eval correctness.
"""
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import sys
import time
import traceback
from pathlib import Path
SCRIPT_DIR = Path(__file__).resolve().parent
sys.path.insert(0, str(SCRIPT_DIR))
def step(name):
    """Announce the start of a named step and return its start timestamp.

    Args:
        name: Human-readable label printed after the ``>>>`` marker.

    Returns:
        A ``time.perf_counter()`` reading. It is only meaningful as the
        ``t0`` argument of :func:`done`; it is not a wall-clock time.
    """
    print(f"\n>>> {name}")
    # perf_counter() is monotonic, so the measured duration cannot go
    # negative or jump if the system clock is adjusted mid-step.
    return time.perf_counter()


def done(name, t0):
    """Print how long a named step took.

    Args:
        name: Label printed after the ``<<<`` marker.
        t0: The value returned by the matching :func:`step` call.
    """
    dt = time.perf_counter() - t0
    print(f"<<< {name}: {dt:.2f}s")
# -- Step 1: import torch + numpy ----------------------------------------------
# The imports sit between step() and done() on purpose: the duration printed
# by done() covers importing numpy and torch plus the MPS availability query.
t0 = step("Import torch/numpy")
import numpy as np
import torch
# Report the installed torch version and whether the MPS backend is usable.
print(f" torch {torch.__version__}, MPS available: {torch.backends.mps.is_available()}")
done("torch/numpy", t0)
# -- Step 2: import pipeline modules ------------------------------------------
# Timed import of the s23dr_2026_example package modules. An ImportError here
# is not caught, so it stops the script at this step.
# NOTE(review): of these names only FuserConfig is used in the code visible
# below; the rest are presumably imported just to prove the modules load.
t0 = step("Import pipeline modules (point_fusion, model, tokenizer, varifold)")
from s23dr_2026_example.point_fusion import build_compact_scene, FuserConfig
from s23dr_2026_example.cache_scenes import (
_compute_group_and_class, _compute_smart_center_scale,
)
from s23dr_2026_example.make_sampled_cache import _priority_sample
from s23dr_2026_example.tokenizer import EdgeDepthSequenceConfig
from s23dr_2026_example.model import EdgeDepthSegmentsModel
from s23dr_2026_example.segment_postprocess import merge_vertices_iterative
from s23dr_2026_example.varifold import segments_to_vertices_edges
from s23dr_2026_example.postprocess_v2 import snap_to_point_cloud, snap_horizontal
done("pipeline imports", t0)
# -- Step 3: import script.py functions ---------------------------------------
# Later steps call script.fuse_and_sample, script.load_model and
# script.predict_sample. SCRIPT_DIR was inserted at the front of sys.path
# above, so this presumably resolves to the script.py next to this file.
t0 = step("Import script.py")
import script
done("script.py", t0)
# -- Step 4: try loading the dataset via streaming ----------------------------
# Opens the 'train' split with streaming=True. On any exception the traceback
# is printed and the script exits with status 1.
# NOTE(review): trust_remote_code=True lets the dataset repository run its own
# loading code on this machine; confirm that is intended.
t0 = step("Load dataset (streaming)")
from datasets import load_dataset

try:
    ds = load_dataset(
        'usm3d/hoho22k_2026_trainval',
        split='train', streaming=True, trust_remote_code=True,
    )
    print(f" Got streaming dataset: {ds}")
except Exception:
    print("Dataset load failed:")
    traceback.print_exc()
    sys.exit(1)

done("dataset load", t0)
# -- Step 5: pull one sample ---------------------------------------------------
# Takes the first item from the dataset iterator and prints up to ten of its
# keys (sorted) plus its 'order_id'. Any exception ends the run with status 1.
t0 = step("Get first sample (this triggers data download if cold)")

try:
    sample_iter = iter(ds)
    sample = next(sample_iter)
    key_preview = sorted(sample.keys())[:10]
    print(f" Got sample. Keys: {key_preview}...")
    print(f" order_id: {sample.get('order_id')}")
except Exception:
    print("Sample iteration failed:")
    traceback.print_exc()
    sys.exit(1)

done("first sample", t0)
# -- Step 6: try point fusion --------------------------------------------------
# Runs script.fuse_and_sample on the sample with a default FuserConfig and a
# fixed-seed RandomState. A None result is reported but is not fatal; an
# exception prints the traceback and exits with status 1.
t0 = step("Fuse + sample (script.fuse_and_sample)")
cfg = FuserConfig()
rng = np.random.RandomState(2718)

try:
    fused = script.fuse_and_sample(sample, cfg, rng)
    if fused is not None:
        print(f" xyz_norm shape: {fused['xyz_norm'].shape}")
        print(f" center: {fused['center']}, scale: {fused['scale']}")
    else:
        print(" fuse_and_sample returned None")
except Exception:
    print("fuse_and_sample crashed:")
    traceback.print_exc()
    sys.exit(1)

done("fuse_and_sample", t0)
# -- Step 7: load model checkpoint --------------------------------------------
# Picks MPS when available (else CPU), makes sure checkpoint.pt exists next to
# this script (downloading it if absent or tiny), then loads the model via
# script.load_model. Any failure prints a traceback and exits with status 1.
t0 = step("Load model checkpoint")
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f" Using device: {device}")
ckpt_path = SCRIPT_DIR / "checkpoint.pt"

# Read the size once, treating a missing file as 0 bytes. Calling stat() on a
# path that does not exist raises FileNotFoundError, so the size must not be
# queried unconditionally inside the "missing" message.
try:
    ckpt_size = ckpt_path.stat().st_size
except FileNotFoundError:
    ckpt_size = 0

# Files under 1000 bytes are treated as a pointer stub rather than real
# weights (presumably a Git LFS pointer file -- confirm).
if ckpt_size < 1000:
    print(f" checkpoint.pt missing or pointer-stub ({ckpt_size} bytes), downloading...")
    import urllib.request
    url = "https://huggingface.co/jacklangerman/s23dr-2026-submission/resolve/main/checkpoint.pt"
    try:
        urllib.request.urlretrieve(url, str(ckpt_path))
    except Exception:
        # Same reporting style as the other fatal steps in this script.
        print("checkpoint download crashed:")
        traceback.print_exc()
        sys.exit(1)
    print(f" downloaded ({ckpt_path.stat().st_size} bytes)")

try:
    model = script.load_model(ckpt_path, device)
    print(f" Model loaded: {sum(p.numel() for p in model.parameters()):,} params")
except Exception:
    print("load_model crashed:")
    traceback.print_exc()
    sys.exit(1)
done("model load", t0)
# -- Step 8: run prediction ----------------------------------------------------
# Only runs when Step 6 produced a fused sample. Calls script.predict_sample
# and prints the vertex and edge counts; an exception exits with status 1.
if fused is not None:
    t0 = step("Run predict_sample (model forward + post-process)")

    try:
        pred_v, pred_e = script.predict_sample(fused, model, device)
        n_pred_v, n_pred_e = len(pred_v), len(pred_e)
        print(f" Pred: {n_pred_v} vertices, {n_pred_e} edges")
    except Exception:
        print("predict_sample crashed:")
        traceback.print_exc()
        sys.exit(1)

    done("predict_sample", t0)
# -- Step 9: triangulation tracks ---------------------------------------------
# Best-effort step: the import and the call are both inside the try, so a
# failure prints a traceback and the script carries on.
t0 = step("Run triangulation predict_wireframe_tracks")

try:
    from triangulation import predict_wireframe_tracks

    track_v, track_e = predict_wireframe_tracks(sample, min_views=3)
    n_track_v, n_track_e = len(track_v), len(track_e)
    print(f" Tracks: {n_track_v} vertices, {n_track_e} edges")
except Exception:
    print("triangulation crashed:")
    traceback.print_exc()

done("triangulation", t0)
# -- Step 10: 2D edge filter --------------------------------------------------
# Best-effort step: filters the predicted edges with
# filter_edges_by_2d_support and prints the edge count before and after.
# A failure prints a traceback and the script carries on.
t0 = step("Run edge_2d_filter")
if fused is None:
    # The prediction step only runs when fused is not None, so in this case
    # pred_v / pred_e were never assigned. Skip explicitly instead of letting
    # the resulting NameError be reported as an edge_2d_filter crash.
    print(" skipped: no prediction available (fuse_and_sample returned None)")
else:
    try:
        from edge_2d_filter import filter_edges_by_2d_support

        pred_v2, pred_e2 = filter_edges_by_2d_support(
            pred_v, pred_e, sample,
            min_views_support=2, min_pixel_frac=0.25, dilate_px=4, sample_steps=20,
        )
        print(f" Before: {len(pred_e)} edges, after: {len(pred_e2)} edges")
    except Exception:
        print("edge_2d_filter crashed:")
        traceback.print_exc()
done("edge_2d_filter", t0)
print("\n=== ALL STEPS COMPLETED ===")