# bundle_adjust.py — added in commit d40cee6 ("Add missing core Python files")
"""Post-hoc bundle adjustment of merged 3D wireframe vertices.
For each vertex in ``merged_v``, we:
1. Project its current 3D position into every available view.
2. Find the nearest gestalt corner (from ``get_vertices_and_edges_improved``)
in each view within ``match_px`` pixels.
3. If observations are found in ≥ ``min_views`` views, refine the 3D
position to minimise the sum of squared reprojection errors via
``scipy.optimize.least_squares`` with a Huber loss.
Cameras are fixed (COLMAP cameras are accurate). Only vertex positions
are optimised. No thresholds are tuned — just pure geometric
optimisation that converges to the correct answer given the cameras.
Entry point: ``refine_vertices_ba(merged_v, entry)``.
"""
from __future__ import annotations
import numpy as np
import cv2
from scipy.optimize import least_squares
from hoho2025.example_solutions import (
convert_entry_to_human_readable,
filter_vertices_by_background,
)
from hoho2025.color_mappings import gestalt_color_mapping
try:
from mvs_utils import collect_views, project_world_to_image
except ImportError:
from submission.mvs_utils import collect_views, project_world_to_image
VERTEX_CLASSES = ['apex', 'eave_end_point', 'flashing_end_point']
def _detect_2d_corners(gest_np):
    """Locate 2D gestalt corner pixels in one view (mirrors the pipeline).

    Parameters
    ----------
    gest_np : uint8 gestalt segmentation image (H, W, 3).

    Returns
    -------
    (N, 2) float32 array of corner pixel coordinates; empty when no
    corner class is present in the view.
    """
    found = []
    for cls_name in VERTEX_CLASSES:
        target = np.array(gestalt_color_mapping[cls_name])
        # Exact colour match with a ±0.5 band to absorb rounding.
        cls_mask = cv2.inRange(gest_np, target - 0.5, target + 0.5)
        if not cls_mask.any():
            continue
        stats = cv2.connectedComponentsWithStats(cls_mask, 8, cv2.CV_32S)
        centroids = stats[3]
        # Index 0 is the background component — skip it.
        found.extend(centroids[1:])
    if found:
        return np.array(found, dtype=np.float32)
    return np.empty((0, 2), dtype=np.float32)
def _collect_observations(
    merged_v: np.ndarray,
    views: dict,
    corners_per_view: dict[str, np.ndarray],
    match_px: float = 8.0,
) -> list[list[tuple[str, np.ndarray]]]:
    """Match each 3D vertex to its nearest 2D gestalt corner per view.

    Returns a list with one entry per vertex; each entry is a list of
    ``(view_id, uv_observed)`` pairs. A vertex gets no pair for a view
    when it projects behind the camera, far outside the image, or
    farther than ``match_px`` pixels from every detected corner.
    """
    n_pts = len(merged_v)
    obs: list[list[tuple[str, np.ndarray]]] = [[] for _ in range(n_pts)]
    for view_id, view in views.items():
        detected = corners_per_view.get(view_id)
        if detected is None or len(detected) == 0:
            continue
        # Cameras are fixed — project every merged vertex into this view.
        uv, depth = project_world_to_image(view['P'], merged_v)
        h, w = view['height'], view['width']
        for idx in range(n_pts):
            if depth[idx] <= 0:
                continue  # behind the camera
            px, py = uv[idx]
            # Tolerate a 50 px margin around the image bounds.
            if px < -50 or px > w + 50 or py < -50 or py > h + 50:
                continue
            dists = np.linalg.norm(detected - uv[idx], axis=1)
            best = int(np.argmin(dists))
            if dists[best] <= match_px:
                obs[idx].append((view_id, detected[best].copy()))
    return obs
def _ba_residuals(params, Ps, obs_2d):
"""Reprojection residuals for a single 3D point.
params: (3,) — x, y, z of the 3D point.
Ps: list of (3, 4) projection matrices.
obs_2d: list of (2,) observed 2D points.
Returns: (2*N,) residual vector.
"""
X = params
res = []
homog = np.array([X[0], X[1], X[2], 1.0])
for P, uv_obs in zip(Ps, obs_2d):
proj = P @ homog
if proj[2] <= 1e-6:
res.extend([100.0, 100.0]) # large penalty
continue
u = proj[0] / proj[2]
v = proj[1] / proj[2]
res.extend([u - uv_obs[0], v - uv_obs[1]])
return np.array(res, dtype=np.float64)
def refine_vertices_ba(
    merged_v: np.ndarray,
    entry,
    match_px: float = 8.0,
    min_views: int = 2,
    max_reproj_px: float = 5.0,
    min_initial_err_px: float = 3.0,
    max_displacement: float = 2.0,
) -> np.ndarray:
    """Refine 3D vertex positions via per-vertex bundle adjustment.

    Only vertices with observations in >= ``min_views`` views AND an
    initial mean reprojection error above ``min_initial_err_px`` are
    optimised; the rest keep their original positions. A refined
    position is accepted only if it reduces the reprojection error,
    stays below ``max_reproj_px`` mean error, and does not move farther
    than ``max_displacement`` world units from its start.

    Parameters
    ----------
    merged_v : (N, 3) array of vertex positions.
    entry : the raw dataset sample (passed to ``convert_entry_to_human_readable``).
    match_px : maximum pixel distance to match a projected vertex to a
        gestalt corner in a view.
    min_views : minimum number of views with a matching observation for
        BA to fire.
    max_reproj_px : if post-BA mean reprojection error exceeds this,
        revert to the original position.
    min_initial_err_px : vertices whose initial mean reprojection error
        is at or below this are considered well-localised and skipped.
    max_displacement : maximum allowed world-space move of a vertex;
        larger optimised displacements are rejected as divergence.

    Returns
    -------
    refined_v : (N, 3) array with refined positions.
    """
    merged_v = np.asarray(merged_v, dtype=np.float64)
    refined = merged_v.copy()
    if len(merged_v) == 0:
        return refined
    good = convert_entry_to_human_readable(entry)
    # The sample may carry the reconstruction under either key.
    colmap_rec = good.get('colmap') or good.get('colmap_binary')
    if colmap_rec is None:
        return refined
    views = collect_views(colmap_rec, good['image_ids'])
    if len(views) < 2:
        # BA needs at least two cameras to constrain a 3D point.
        return refined
    # Detect 2D corners in each view. The gestalt image is resized to the
    # depth map's resolution so pixel coordinates agree with the cameras.
    corners_per_view: dict[str, np.ndarray] = {}
    for gest, depth, img_id in zip(good['gestalt'], good['depth'], good['image_ids']):
        if img_id not in views:
            continue
        depth_np = np.array(depth)
        H, W = depth_np.shape[:2]
        gest_np = np.array(gest.resize((W, H))).astype(np.uint8)
        corners_per_view[img_id] = _detect_2d_corners(gest_np)
    # Collect multi-view observations for each vertex
    observations = _collect_observations(merged_v, views, corners_per_view, match_px)
    # Run BA on each vertex independently.
    # Key: only refine vertices whose INITIAL reprojection error is high
    # (> min_initial_err_px). This targets the depth-estimation failures
    # without disturbing already-good vertices.
    n_refined = 0
    for i in range(len(merged_v)):
        obs = observations[i]
        if len(obs) < min_views:
            continue
        Ps = [views[vid]['P'] for vid, _ in obs]
        pts_2d = [uv for _, uv in obs]
        x0 = merged_v[i].copy()
        # Check initial reprojection error — skip if already low.
        res0 = _ba_residuals(x0, Ps, pts_2d)
        res0_pairs = res0.reshape(-1, 2)
        initial_err = float(np.sqrt((res0_pairs ** 2).sum(axis=1)).mean())
        if initial_err <= min_initial_err_px:
            continue  # already well-localised, leave it alone
        try:
            # Huber loss damps the influence of a single bad 2D match.
            result = least_squares(
                _ba_residuals, x0,
                args=(Ps, pts_2d),
                method='trf',
                loss='huber',
                f_scale=2.0,
                max_nfev=50,
            )
        except Exception:
            # Best-effort: a failed solve just leaves the vertex untouched.
            continue
        X_opt = result.x
        # Sanity: check post-BA reprojection error and displacement.
        res = _ba_residuals(X_opt, Ps, pts_2d)
        res_pairs = res.reshape(-1, 2)
        final_err = float(np.sqrt((res_pairs ** 2).sum(axis=1)).mean())
        displacement = float(np.linalg.norm(X_opt - x0))
        # Accept only if: (a) reproj improved, (b) didn't move too far.
        if (final_err < initial_err
                and final_err <= max_reproj_px
                and displacement <= max_displacement):
            refined[i] = X_opt
            n_refined += 1
    return refined