"""Post-hoc bundle adjustment of merged 3D wireframe vertices. For each vertex in ``merged_v``, we: 1. Project its current 3D position into every available view. 2. Find the nearest gestalt corner (from ``get_vertices_and_edges_improved``) in each view within ``match_px`` pixels. 3. If observations are found in ≥ ``min_views`` views, refine the 3D position to minimise the sum of squared reprojection errors via ``scipy.optimize.least_squares`` with a Huber loss. Cameras are fixed (COLMAP cameras are accurate). Only vertex positions are optimised. No thresholds are tuned — just pure geometric optimisation that converges to the correct answer given the cameras. Entry point: ``refine_vertices_ba(merged_v, entry)``. """ from __future__ import annotations import numpy as np import cv2 from scipy.optimize import least_squares from hoho2025.example_solutions import ( convert_entry_to_human_readable, filter_vertices_by_background, ) from hoho2025.color_mappings import gestalt_color_mapping try: from mvs_utils import collect_views, project_world_to_image except ImportError: from submission.mvs_utils import collect_views, project_world_to_image VERTEX_CLASSES = ['apex', 'eave_end_point', 'flashing_end_point'] def _detect_2d_corners(gest_np): """Detect 2D gestalt corners in a single view (same as pipeline). Returns (N, 2) float32 array of pixel coordinates. """ corners = [] for v_class in VERTEX_CLASSES: color = np.array(gestalt_color_mapping[v_class]) mask = cv2.inRange(gest_np, color - 0.5, color + 0.5) if mask.sum() == 0: continue _, _, _, centroids = cv2.connectedComponentsWithStats(mask, 8, cv2.CV_32S) for c in centroids[1:]: corners.append(c) if not corners: return np.empty((0, 2), dtype=np.float32) return np.array(corners, dtype=np.float32) def _collect_observations( merged_v: np.ndarray, views: dict, corners_per_view: dict[str, np.ndarray], match_px: float = 8.0, ) -> list[list[tuple[str, np.ndarray]]]: """For each vertex, find its 2D observation in each view. Returns a list (one per vertex) of lists of ``(view_id, uv_observed)``. """ n = len(merged_v) observations: list[list[tuple[str, np.ndarray]]] = [[] for _ in range(n)] for vid, info in views.items(): corners_2d = corners_per_view.get(vid) if corners_2d is None or len(corners_2d) == 0: continue P = info['P'] # Project all merged_v into this view uv, z = project_world_to_image(P, merged_v) H, W = info['height'], info['width'] for i in range(n): if z[i] <= 0: continue u, v_px = uv[i] if u < -50 or u > W + 50 or v_px < -50 or v_px > H + 50: continue # Find nearest 2D corner d = np.linalg.norm(corners_2d - uv[i], axis=1) j = int(np.argmin(d)) if d[j] <= match_px: observations[i].append((vid, corners_2d[j].copy())) return observations def _ba_residuals(params, Ps, obs_2d): """Reprojection residuals for a single 3D point. params: (3,) — x, y, z of the 3D point. Ps: list of (3, 4) projection matrices. obs_2d: list of (2,) observed 2D points. Returns: (2*N,) residual vector. """ X = params res = [] homog = np.array([X[0], X[1], X[2], 1.0]) for P, uv_obs in zip(Ps, obs_2d): proj = P @ homog if proj[2] <= 1e-6: res.extend([100.0, 100.0]) # large penalty continue u = proj[0] / proj[2] v = proj[1] / proj[2] res.extend([u - uv_obs[0], v - uv_obs[1]]) return np.array(res, dtype=np.float64) def refine_vertices_ba( merged_v: np.ndarray, entry, match_px: float = 8.0, min_views: int = 2, max_reproj_px: float = 5.0, min_initial_err_px: float = 3.0, ) -> np.ndarray: """Refine 3D vertex positions via bundle adjustment. Only vertices with observations in ≥ ``min_views`` views are refined; the rest keep their original positions. If the optimised position has a mean reprojection error > ``max_reproj_px``, the original position is kept (optimiser diverged). Parameters ---------- merged_v : (N, 3) array of vertex positions. entry : the raw dataset sample (passed to ``convert_entry_to_human_readable``). match_px : maximum pixel distance to match a projected vertex to a gestalt corner in a view. min_views : minimum number of views with a matching observation for BA to fire. max_reproj_px : if post-BA mean reprojection error exceeds this, revert to the original position. Returns ------- refined_v : (N, 3) array with refined positions. """ merged_v = np.asarray(merged_v, dtype=np.float64) refined = merged_v.copy() if len(merged_v) == 0: return refined good = convert_entry_to_human_readable(entry) colmap_rec = good.get('colmap') or good.get('colmap_binary') if colmap_rec is None: return refined views = collect_views(colmap_rec, good['image_ids']) if len(views) < 2: return refined # Detect 2D corners in each view corners_per_view: dict[str, np.ndarray] = {} for gest, depth, img_id in zip(good['gestalt'], good['depth'], good['image_ids']): if img_id not in views: continue depth_np = np.array(depth) H, W = depth_np.shape[:2] gest_np = np.array(gest.resize((W, H))).astype(np.uint8) corners_per_view[img_id] = _detect_2d_corners(gest_np) # Collect multi-view observations for each vertex observations = _collect_observations(merged_v, views, corners_per_view, match_px) # Run BA on each vertex independently. # Key: only refine vertices whose INITIAL reprojection error is high # (> min_initial_err_px). This targets the depth-estimation failures # without disturbing already-good vertices. n_refined = 0 for i in range(len(merged_v)): obs = observations[i] if len(obs) < min_views: continue Ps = [views[vid]['P'] for vid, _ in obs] pts_2d = [uv for _, uv in obs] x0 = merged_v[i].copy() # Check initial reprojection error — skip if already low. res0 = _ba_residuals(x0, Ps, pts_2d) res0_pairs = res0.reshape(-1, 2) initial_err = float(np.sqrt((res0_pairs ** 2).sum(axis=1)).mean()) if initial_err <= min_initial_err_px: continue # already well-localised, leave it alone try: result = least_squares( _ba_residuals, x0, args=(Ps, pts_2d), method='trf', loss='huber', f_scale=2.0, max_nfev=50, ) except Exception: continue X_opt = result.x # Sanity: check post-BA reprojection error and displacement. res = _ba_residuals(X_opt, Ps, pts_2d) res_pairs = res.reshape(-1, 2) final_err = float(np.sqrt((res_pairs ** 2).sum(axis=1)).mean()) displacement = float(np.linalg.norm(X_opt - x0)) # Accept only if: (a) reproj improved, (b) didn't move too far. if final_err < initial_err and final_err <= max_reproj_px and displacement <= 2.0: refined[i] = X_opt n_refined += 1 return refined