subm / winner_inference.py
Neritz's picture
Add handcrafted_submission_2026 contents (model-repo form for S23DR2026 submission)
31f43c9 verified
"""Inference adapter for the winner-2025 pipeline.
Loads:
- DGCNN vertex classifier (3 heads: cls/offset/conf)
- DGCNN edge classifier (1 head)
And exposes:
- refine_winner_candidates(candidates, sample, model, device, threshold)
For each candidate, build the 4×4×4 m cubic patch with 11D point
features (winner spec), run the model, return only candidates that
pass the classification threshold and were shifted to the model's
offset.
- score_edges(vertices, sample, model, device, threshold)
For each pair of vertices within MAX_PAIR_DIST, build the 6D
cylindrical patch and ask the model whether the edge exists.
Both functions degrade gracefully if torch is missing or the checkpoint
is not found — they return None and the caller falls back to the
heuristic pipeline.
"""
from __future__ import annotations
import os
import numpy as np
from pathlib import Path
# Lazy torch import — only required at training/inference time, not at
# submission package import time.
_torch = None
_DGCNNVertexClassifier = None
_DGCNNEdgeClassifier = None
def _ensure_torch():
global _torch, _DGCNNVertexClassifier, _DGCNNEdgeClassifier
if _torch is not None:
return True
try:
import torch as _t
_torch = _t
except Exception:
return False
# Try multiple import paths for DGCNN classes:
# 1. Full package (local development)
# 2. Submission-directory copy (HF container)
for _module_path in [
"s23dr.models.dgcnn",
"dgcnn",
"submission.dgcnn",
]:
try:
_mod = __import__(_module_path, fromlist=["DGCNNVertexClassifier", "DGCNNEdgeClassifier"])
_DGCNNVertexClassifier = _mod.DGCNNVertexClassifier
_DGCNNEdgeClassifier = _mod.DGCNNEdgeClassifier
break
except Exception:
continue
if _DGCNNVertexClassifier is None:
return False
return True
def _resolve_model_path(path: str) -> str | None:
"""Try multiple locations for a model checkpoint."""
candidates = [
path,
os.path.join(os.path.dirname(__file__), os.path.basename(path)),
os.path.join(os.path.dirname(__file__), path),
os.path.basename(path),
]
for c in candidates:
if os.path.exists(c):
return c
return None
def load_vertex_model(path="checkpoints/vertex_model_dgcnn.pt", device="cuda"):
    """Load the DGCNN vertex classifier from a checkpoint file.

    Accepts either a bare state dict or a checkpoint dict holding the
    state under a 'model' key. Returns the model in eval mode on
    *device*, or None if torch/the classes/the file are unavailable or
    loading fails for any reason.
    """
    if not _ensure_torch():
        return None
    resolved = _resolve_model_path(path)
    if resolved is None:
        return None
    try:
        # NOTE: weights_only=False deserializes arbitrary pickled objects —
        # only load checkpoints from trusted sources.
        checkpoint = _torch.load(resolved, map_location=device, weights_only=False)
        if isinstance(checkpoint, dict) and 'model' in checkpoint:
            state = checkpoint['model']
        else:
            state = checkpoint
        net = _DGCNNVertexClassifier(in_channels=11).to(device)
        net.load_state_dict(state)
        net.eval()
        return net
    except Exception:
        return None
def load_edge_model(path="checkpoints/edge_model_dgcnn.pt", device="cuda"):
    """Load the DGCNN edge classifier from a checkpoint file.

    Accepts either a bare state dict or a checkpoint dict holding the
    state under a 'model' key. Returns the model in eval mode on
    *device*, or None if torch/the classes/the file are unavailable or
    loading fails for any reason.
    """
    if not _ensure_torch():
        return None
    resolved = _resolve_model_path(path)
    if resolved is None:
        return None
    try:
        # NOTE: weights_only=False deserializes arbitrary pickled objects —
        # only load checkpoints from trusted sources.
        checkpoint = _torch.load(resolved, map_location=device, weights_only=False)
        if isinstance(checkpoint, dict) and 'model' in checkpoint:
            state = checkpoint['model']
        else:
            state = checkpoint
        net = _DGCNNEdgeClassifier(in_channels=6).to(device)
        net.load_state_dict(state)
        net.eval()
        return net
    except Exception:
        return None
def refine_winner_candidates(
    candidates,
    sample,
    model,
    device="cuda",
    cls_threshold: float = 0.5,
    apply_offset: bool = True,
    batch_size: int = 64,
    max_points: int = 1024,
    patch_size: float = 4.0,
):
    """Run DGCNN vertex refinement on Stage 1 winner candidates.

    Args:
        candidates: list of dicts from generate_vertex_candidates
            (each must have 'xyz' and 'point_ids').
        sample: raw HF dataset entry.
        model: loaded DGCNNVertexClassifier (or compatible).
        device: torch device.
        cls_threshold: keep candidate if sigmoid(cls_logit) ≥ threshold.
        apply_offset: shift accepted candidates by predicted offset.
        batch_size: patches per forward pass.
        max_points: points sampled per patch.
        patch_size: cubic patch side length in metres.

    Returns:
        list of (xyz, score) for accepted candidates, OR None on failure
        (missing model/candidates/torch/helpers/colmap data).
    """
    # Guard clauses: degrade gracefully so the caller can fall back.
    if model is None or not candidates:
        return None
    if not _ensure_torch():
        return None
    try:
        from hoho2025.example_solutions import convert_entry_to_human_readable
        from s23dr.data_prep.patch_extraction import (
            _get_all_points_with_features, _project_and_get_gestalt_labels,
            extract_vertex_patch,
        )
    except Exception:
        return None

    readable = convert_entry_to_human_readable(sample)
    reconstruction = readable.get('colmap') or readable.get('colmap_binary')
    if reconstruction is None:
        return None
    all_xyz, all_rgb, all_pids = _get_all_points_with_features(reconstruction)
    if len(all_xyz) == 0:
        return None

    # Per-view depth map (H, W) shapes feed the gestalt projection.
    depth_shapes = [tuple(np.asarray(d).shape[:2]) for d in readable['depth']]
    all_gestalt = _project_and_get_gestalt_labels(
        all_xyz, reconstruction, readable['gestalt'], readable['image_ids'], depth_shapes,
    )

    # Build one 11D feature patch per candidate; remember which candidate
    # each surviving patch came from.
    kept = []  # (candidate index, patch) pairs
    for idx, cand in enumerate(candidates):
        patch = extract_vertex_patch(
            cand['xyz'], all_xyz, all_rgb, all_gestalt,
            cand.get('point_ids', set()), all_pids,
            patch_size=patch_size, max_points=max_points,
        )
        if patch is not None:
            kept.append((idx, patch))
    if not kept:
        return []
    indices = [idx for idx, _ in kept]
    feats = [p for _, p in kept]

    accepted = []
    with _torch.no_grad():
        for lo in range(0, len(feats), batch_size):
            chunk = feats[lo:lo + batch_size]
            x = _torch.from_numpy(np.stack(chunk, axis=0)).to(device)  # (B, 11, N)
            cls_logits, pred_offset, pred_conf = model(x)
            logits_np = cls_logits.squeeze(-1).cpu().numpy()
            offsets_np = pred_offset.cpu().numpy()
            _ = pred_conf.squeeze(-1).cpu().numpy()  # confidence head unused here
            probs = 1.0 / (1.0 + np.exp(-logits_np))
            for k, prob in enumerate(probs):
                if prob < cls_threshold:
                    continue
                source = indices[lo + k]
                position = candidates[source]['xyz'].copy()
                if apply_offset:
                    position = position + offsets_np[k]
                accepted.append((position.astype(np.float64), float(prob)))
    return accepted
def score_edges(
    vertices: np.ndarray,
    sample,
    model,
    device: str = "cuda",
    threshold: float = 0.5,
    max_pair_dist: float = 8.0,
    batch_size: int = 64,
    max_points: int = 1024,
):
    """Run the DGCNN edge classifier over all vertex pairs within max_pair_dist.

    Args:
        vertices: (V, 3) array of vertex positions.
        sample: raw HF dataset entry.
        model: loaded DGCNNEdgeClassifier (or compatible).
        device: torch device.
        threshold: keep a pair if sigmoid(logit) ≥ threshold.
        max_pair_dist: skip pairs farther apart than this (metres).
        batch_size: patches per forward pass.
        max_points: points sampled per cylindrical patch.

    Returns:
        list of (i, j, prob) for pairs the model classifies as edges,
        OR None on failure (missing model/vertices/torch/helpers/colmap).
    """
    # Guard clauses: degrade gracefully so the caller can fall back.
    if model is None or vertices is None or len(vertices) < 2:
        return None
    if not _ensure_torch():
        return None
    try:
        from hoho2025.example_solutions import convert_entry_to_human_readable
        from s23dr.data_prep.patch_extraction import (
            _get_all_points_with_features, extract_edge_patch,
        )
    except Exception:
        return None

    readable = convert_entry_to_human_readable(sample)
    reconstruction = readable.get('colmap') or readable.get('colmap_binary')
    if reconstruction is None:
        return None
    all_xyz, all_rgb, _ = _get_all_points_with_features(reconstruction)
    if len(all_xyz) == 0:
        return None

    # Enumerate unordered vertex pairs within range and extract their
    # 6D cylindrical patches, keeping pair ids aligned with patches.
    count = len(vertices)
    pair_ids = []
    patch_list = []
    for a in range(count):
        for b in range(a + 1, count):
            if float(np.linalg.norm(vertices[a] - vertices[b])) > max_pair_dist:
                continue
            patch = extract_edge_patch(
                vertices[a], vertices[b], all_xyz, all_rgb, max_points=max_points,
            )
            if patch is not None:
                pair_ids.append((a, b))
                patch_list.append(patch)
    if not patch_list:
        return []

    edges = []
    with _torch.no_grad():
        for lo in range(0, len(patch_list), batch_size):
            chunk = patch_list[lo:lo + batch_size]
            x = _torch.from_numpy(np.stack(chunk, axis=0)).to(device)
            logits = model(x).squeeze(-1).cpu().numpy()
            probs = 1.0 / (1.0 + np.exp(-logits))
            for k, prob in enumerate(probs):
                if prob >= threshold:
                    a, b = pair_ids[lo + k]
                    edges.append((int(a), int(b), float(prob)))
    return edges