"""Inference adapter for the winner-2025 pipeline.

Loads:
  - DGCNN vertex classifier (3 heads: cls/offset/conf)
  - DGCNN edge classifier   (1 head)

And exposes:
  - refine_winner_candidates(candidates, sample, model, device, cls_threshold)
      For each candidate, build the 4×4×4 m cubic patch with 11D point
      features (winner spec), run the model, and return only the candidates
      that pass the classification threshold, shifted by the model's
      predicted offset.
  - score_edges(vertices, sample, model, device, threshold)
      For each pair of vertices within ``max_pair_dist``, build the 6D
      cylindrical patch and ask the model whether the edge exists.

Both functions degrade gracefully if torch is missing or the checkpoint is
not found — they return None and the caller falls back to the heuristic
pipeline.
"""
from __future__ import annotations

import itertools
import logging
import os
from pathlib import Path

import numpy as np

_logger = logging.getLogger(__name__)

# Lazy torch import — only required at training/inference time, not at
# submission package import time.
_torch = None
_DGCNNVertexClassifier = None
_DGCNNEdgeClassifier = None

# Import paths tried, in order, for the DGCNN classes:
#   1. Full package (local development)
#   2. Submission-directory copy (HF container)
_DGCNN_MODULE_PATHS = (
    "s23dr.models.dgcnn",
    "dgcnn",
    "submission.dgcnn",
)


def _ensure_torch() -> bool:
    """Lazily import torch and the DGCNN classifier classes.

    Returns:
        True only when torch AND both classifier classes are available.
        The answer is consistent across calls: a previous partial success
        (torch imported, DGCNN classes missing) is never reported as True.
    """
    global _torch, _DGCNNVertexClassifier, _DGCNNEdgeClassifier

    if _torch is None:
        try:
            import torch as _t
        except Exception:
            # Broad on purpose: a broken torch install can raise more than
            # ImportError, and every such case means "no torch".
            return False
        _torch = _t

    if _DGCNNVertexClassifier is not None and _DGCNNEdgeClassifier is not None:
        return True

    for module_path in _DGCNN_MODULE_PATHS:
        try:
            mod = __import__(
                module_path,
                fromlist=["DGCNNVertexClassifier", "DGCNNEdgeClassifier"],
            )
            vertex_cls = mod.DGCNNVertexClassifier
            edge_cls = mod.DGCNNEdgeClassifier
        except Exception:
            continue
        # Bind both globals together so a module that lacks one of the two
        # classes can never leave a half-initialised state behind.
        _DGCNNVertexClassifier = vertex_cls
        _DGCNNEdgeClassifier = edge_cls
        return True
    return False


def _resolve_model_path(path: str) -> str | None:
    """Try multiple locations for a model checkpoint.

    Checked in order: ``path`` as given, its basename next to this file,
    ``path`` relative to this file's directory, and its basename relative to
    the current working directory.

    Returns:
        The first location that exists, or None if none does.
    """
    here = os.path.dirname(__file__)
    name = os.path.basename(path)
    locations = (
        path,
        os.path.join(here, name),
        os.path.join(here, path),
        name,
    )
    return next((loc for loc in locations if os.path.exists(loc)), None)


def _load_checkpoint(model_cls, in_channels: int, path: str, device):
    """Build ``model_cls(in_channels=...)`` and load its weights from ``path``.

    Must only be called after ``_ensure_torch()`` returned True.

    Returns:
        The model in eval mode on ``device``, or None when the checkpoint
        cannot be located or fails to load (the failure is logged).
    """
    resolved = _resolve_model_path(path)
    if resolved is None:
        return None
    try:
        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects. Only load checkpoints that come from a trusted source.
        ckpt = _torch.load(resolved, map_location=device, weights_only=False)
        # Accept both a {'model': state_dict, ...} wrapper and a bare
        # state dict.
        if isinstance(ckpt, dict) and 'model' in ckpt:
            state = ckpt['model']
        else:
            state = ckpt
        model = model_cls(in_channels=in_channels).to(device)
        model.load_state_dict(state)
        model.eval()
    except Exception:
        # Best-effort by design: the caller falls back to the heuristic
        # pipeline. Log it so a corrupt or mismatched checkpoint is visible.
        _logger.warning(
            "Failed to load checkpoint %s", resolved, exc_info=True
        )
        return None
    return model


def load_vertex_model(path="checkpoints/vertex_model_dgcnn.pt", device="cuda"):
    """Load the DGCNN vertex classifier (11 input channels).

    Returns:
        The model in eval mode, or None if torch / the DGCNN classes / the
        checkpoint are unavailable or loading fails.
    """
    if not _ensure_torch():
        return None
    return _load_checkpoint(_DGCNNVertexClassifier, 11, path, device)


def load_edge_model(path="checkpoints/edge_model_dgcnn.pt", device="cuda"):
    """Load the DGCNN edge classifier (6 input channels).

    Returns:
        The model in eval mode, or None if torch / the DGCNN classes / the
        checkpoint are unavailable or loading fails.
    """
    if not _ensure_torch():
        return None
    return _load_checkpoint(_DGCNNEdgeClassifier, 6, path, device)


def _load_sample_points(sample):
    """Convert a raw sample and pull the COLMAP point cloud out of it.

    Returns:
        ``(good, colmap_rec, all_xyz, all_rgb, all_pids)``, or None when the
        project helpers cannot be imported, the sample carries no COLMAP
        reconstruction, or the reconstruction has no points.
    """
    try:
        from hoho2025.example_solutions import convert_entry_to_human_readable
        from s23dr.data_prep.patch_extraction import (
            _get_all_points_with_features,
        )
    except Exception:
        return None

    good = convert_entry_to_human_readable(sample)
    colmap_rec = good.get('colmap') or good.get('colmap_binary')
    if colmap_rec is None:
        return None

    all_xyz, all_rgb, all_pids = _get_all_points_with_features(colmap_rec)
    if len(all_xyz) == 0:
        return None
    return good, colmap_rec, all_xyz, all_rgb, all_pids


def _iter_batches(patches, batch_size: int, device):
    """Yield ``(start, tensor)`` for consecutive batches of stacked patches."""
    step = max(1, int(batch_size))  # a non-positive size would break range()
    for start in range(0, len(patches), step):
        batch = np.stack(patches[start:start + step], axis=0)
        yield start, _torch.from_numpy(batch).to(device)


def _sigmoid(logits):
    """Element-wise logistic function on a numpy array."""
    # exp() overflows to inf for very negative logits; the quotient is still
    # the correct 0.0, so only the warning is suppressed.
    with np.errstate(over='ignore'):
        return 1.0 / (1.0 + np.exp(-logits))


def refine_winner_candidates(
    candidates,
    sample,
    model,
    device="cuda",
    cls_threshold: float = 0.5,
    apply_offset: bool = True,
    batch_size: int = 64,
    max_points: int = 1024,
    patch_size: float = 4.0,
):
    """Run DGCNN vertex refinement on Stage 1 winner candidates.

    Args:
        candidates: list of dicts from generate_vertex_candidates (each must
            have 'xyz'; 'point_ids' is optional and defaults to an empty set).
        sample: raw HF dataset entry.
        model: loaded DGCNNVertexClassifier (or compatible).
        device: torch device.
        cls_threshold: keep candidate if sigmoid(cls_logit) ≥ threshold.
        apply_offset: shift accepted candidates by predicted offset.
        batch_size: number of patches per forward pass.
        max_points: forwarded to ``extract_vertex_patch``.
        patch_size: forwarded to ``extract_vertex_patch``.

    Returns:
        list of (xyz, score) for accepted candidates, where xyz is a float64
        array and score the classification probability. An empty list when
        no patch could be extracted. None on failure (missing model or
        candidates, missing dependencies, or no usable point cloud).
    """
    if model is None or not candidates:
        return None
    if not _ensure_torch():
        return None
    try:
        from s23dr.data_prep.patch_extraction import (
            _project_and_get_gestalt_labels,
            extract_vertex_patch,
        )
    except Exception:
        return None

    loaded = _load_sample_points(sample)
    if loaded is None:
        return None
    good, colmap_rec, all_xyz, all_rgb, all_pids = loaded

    depth_shapes = []
    for depth in good['depth']:
        shape = np.asarray(depth).shape
        depth_shapes.append((shape[0], shape[1]))
    all_gestalt = _project_and_get_gestalt_labels(
        all_xyz, colmap_rec, good['gestalt'], good['image_ids'], depth_shapes,
    )

    patches = []
    cand_idx = []  # cand_idx[p] is the candidate index behind patches[p]
    for i, cand in enumerate(candidates):
        patch = extract_vertex_patch(
            cand['xyz'], all_xyz, all_rgb, all_gestalt,
            cand.get('point_ids', set()), all_pids,
            patch_size=patch_size, max_points=max_points,
        )
        if patch is None:
            continue
        patches.append(patch)
        cand_idx.append(i)

    if not patches:
        return []

    accepted = []
    with _torch.no_grad():
        for start, x in _iter_batches(patches, batch_size, device):  # (B, 11, N)
            # The third (confidence) head is not used for filtering, so it
            # is not copied back from the device.
            cls_logits, pred_offset, _ = model(x)
            probs = _sigmoid(cls_logits.squeeze(-1).cpu().numpy())
            offsets = pred_offset.cpu().numpy()
            for k in range(len(offsets)):
                if probs[k] < cls_threshold:
                    continue
                xyz = np.asarray(candidates[cand_idx[start + k]]['xyz'])
                if apply_offset:
                    xyz = xyz + offsets[k]
                # astype always copies, so the candidate's own array is
                # never aliased or mutated.
                accepted.append((xyz.astype(np.float64), float(probs[k])))
    return accepted


def score_edges(
    vertices: np.ndarray,
    sample,
    model,
    device: str = "cuda",
    threshold: float = 0.5,
    max_pair_dist: float = 8.0,
    batch_size: int = 64,
    max_points: int = 1024,
):
    """Run DGCNN edge classifier over all vertex pairs within max_pair_dist.

    Args:
        vertices: vertex positions, one row per vertex.
        sample: raw HF dataset entry.
        model: loaded DGCNNEdgeClassifier (or compatible).
        device: torch device.
        threshold: keep a pair if sigmoid(logit) ≥ threshold.
        max_pair_dist: pairs farther apart than this are never scored.
        batch_size: number of patches per forward pass.
        max_points: forwarded to ``extract_edge_patch``.

    Returns:
        list of (i, j, prob) with i < j for pairs where the model says
        "edge". An empty list when no patch could be extracted. None on
        failure (missing model, fewer than two vertices, missing
        dependencies, or no usable point cloud).
    """
    if model is None or vertices is None or len(vertices) < 2:
        return None
    if not _ensure_torch():
        return None
    try:
        from s23dr.data_prep.patch_extraction import extract_edge_patch
    except Exception:
        return None

    loaded = _load_sample_points(sample)
    if loaded is None:
        return None
    _, _, all_xyz, all_rgb, _ = loaded

    # No-op for ndarray input; lets list-like vertices work as well.
    vertices = np.asarray(vertices)

    pairs = []
    patches = []
    for i, j in itertools.combinations(range(len(vertices)), 2):
        dist = float(np.linalg.norm(vertices[i] - vertices[j]))
        if dist > max_pair_dist:
            continue
        patch = extract_edge_patch(
            vertices[i], vertices[j], all_xyz, all_rgb,
            max_points=max_points,
        )
        if patch is None:
            continue
        pairs.append((i, j))
        patches.append(patch)

    if not patches:
        return []

    out = []
    with _torch.no_grad():
        for start, x in _iter_batches(patches, batch_size, device):
            probs = _sigmoid(model(x).squeeze(-1).cpu().numpy())
            for k in range(x.shape[0]):
                if probs[k] >= threshold:
                    i, j = pairs[start + k]
                    out.append((int(i), int(j), float(probs[k])))
    return out