File size: 7,522 Bytes
5b6e3d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
"""Post-hoc bundle adjustment of merged 3D wireframe vertices.

For each vertex in ``merged_v``, we:
1. Project its current 3D position into every available view.
2. Find the nearest gestalt corner (from ``get_vertices_and_edges_improved``)
   in each view within ``match_px`` pixels.
3. If observations are found in ≥ ``min_views`` views, refine the 3D
   position to minimise the sum of squared reprojection errors via
   ``scipy.optimize.least_squares`` with a Huber loss.

Cameras are fixed (COLMAP cameras are assumed accurate); only vertex
positions are optimised. A few pixel-space thresholds (``match_px``,
``min_views``, ``max_reproj_px``) gate which matches feed the optimiser
and whether its result is accepted.

Entry point: ``refine_vertices_ba(merged_v, entry)``.
"""

from __future__ import annotations

import numpy as np
import cv2
from scipy.optimize import least_squares

from hoho2025.example_solutions import (
    convert_entry_to_human_readable,
    filter_vertices_by_background,
)
from hoho2025.color_mappings import gestalt_color_mapping

try:
    from mvs_utils import collect_views, project_world_to_image
except ImportError:
    from submission.mvs_utils import collect_views, project_world_to_image


VERTEX_CLASSES = ['apex', 'eave_end_point', 'flashing_end_point']


def _detect_2d_corners(gest_np):
    """Find 2D gestalt vertex corners in one segmentation image.

    For each vertex class colour, build a binary mask, extract its
    connected components, and collect the component centroids.

    Returns
    -------
    (N, 2) float32 array of pixel coordinates (empty if none found).
    """
    found = []
    for cls in VERTEX_CLASSES:
        rgb = np.array(gestalt_color_mapping[cls])
        cls_mask = cv2.inRange(gest_np, rgb - 0.5, rgb + 0.5)
        if not cls_mask.any():
            continue
        # Component 0 is the background — keep only the real blobs.
        centroids = cv2.connectedComponentsWithStats(cls_mask, 8, cv2.CV_32S)[3]
        found.extend(centroids[1:])
    if found:
        return np.array(found, dtype=np.float32)
    return np.empty((0, 2), dtype=np.float32)


def _collect_observations(
    merged_v: np.ndarray,
    views: dict,
    corners_per_view: dict[str, np.ndarray],
    match_px: float = 8.0,
) -> list[list[tuple[str, np.ndarray]]]:
    """Match each projected vertex to its nearest 2D corner in every view.

    Parameters
    ----------
    merged_v : (N, 3) vertex positions in world coordinates.
    views : mapping of view id to camera info (``P``, ``height``, ``width``).
    corners_per_view : mapping of view id to (M, 2) detected corner pixels.
    match_px : maximum pixel distance for a projection/corner match.

    Returns
    -------
    One list per vertex, each holding ``(view_id, uv_observed)`` tuples.
    """
    matches: list[list[tuple[str, np.ndarray]]] = [[] for _ in range(len(merged_v))]

    for view_id, cam in views.items():
        corners = corners_per_view.get(view_id)
        if corners is None or len(corners) == 0:
            continue
        # Project every vertex into this view at once.
        uv, depth = project_world_to_image(cam['P'], merged_v)
        height, width = cam['height'], cam['width']
        for idx in range(len(merged_v)):
            # Ignore points behind the camera.
            if depth[idx] <= 0:
                continue
            u, v_px = uv[idx]
            # Allow a 50 px margin around the image bounds.
            if u < -50 or u > width + 50 or v_px < -50 or v_px > height + 50:
                continue
            # Nearest detected corner to this projection.
            dists = np.linalg.norm(corners - uv[idx], axis=1)
            best = int(np.argmin(dists))
            if dists[best] <= match_px:
                matches[idx].append((view_id, corners[best].copy()))

    return matches


def _ba_residuals(params, Ps, obs_2d):
    """Reprojection residuals for a single 3D point.

    params: (3,) — x, y, z of the 3D point.
    Ps: list of (3, 4) projection matrices.
    obs_2d: list of (2,) observed 2D points.

    Returns: (2*N,) residual vector.
    """
    X = params
    res = []
    homog = np.array([X[0], X[1], X[2], 1.0])
    for P, uv_obs in zip(Ps, obs_2d):
        proj = P @ homog
        if proj[2] <= 1e-6:
            res.extend([100.0, 100.0])  # large penalty
            continue
        u = proj[0] / proj[2]
        v = proj[1] / proj[2]
        res.extend([u - uv_obs[0], v - uv_obs[1]])
    return np.array(res, dtype=np.float64)


def refine_vertices_ba(
    merged_v: np.ndarray,
    entry,
    match_px: float = 8.0,
    min_views: int = 2,
    max_reproj_px: float = 5.0,
    min_initial_err_px: float = 3.0,
    max_displacement: float = 2.0,
) -> np.ndarray:
    """Refine 3D vertex positions via per-vertex bundle adjustment.

    Only vertices observed in ≥ ``min_views`` views AND whose initial
    mean reprojection error exceeds ``min_initial_err_px`` are optimised;
    the rest keep their original positions. An optimised position is
    accepted only if it improves the reprojection error, its final mean
    error is within ``max_reproj_px``, and it moved no further than
    ``max_displacement`` from the original (otherwise the optimiser is
    assumed to have diverged and the original position is kept).

    Parameters
    ----------
    merged_v : (N, 3) array of vertex positions.
    entry : the raw dataset sample (passed to ``convert_entry_to_human_readable``).
    match_px : maximum pixel distance to match a projected vertex to a
        gestalt corner in a view.
    min_views : minimum number of views with a matching observation for
        BA to fire.
    max_reproj_px : if post-BA mean reprojection error exceeds this,
        revert to the original position.
    min_initial_err_px : vertices whose initial mean reprojection error
        is at or below this are left untouched (already well-localised).
    max_displacement : maximum allowed world-space move of a vertex;
        larger moves are treated as divergence and reverted.

    Returns
    -------
    refined_v : (N, 3) array with refined positions.
    """
    merged_v = np.asarray(merged_v, dtype=np.float64)
    refined = merged_v.copy()

    if len(merged_v) == 0:
        return refined

    good = convert_entry_to_human_readable(entry)
    colmap_rec = good.get('colmap') or good.get('colmap_binary')
    if colmap_rec is None:
        # No camera reconstruction available — nothing to optimise against.
        return refined

    views = collect_views(colmap_rec, good['image_ids'])
    if len(views) < 2:
        return refined

    # Detect 2D corners in each view. The gestalt image is resized to the
    # depth map's resolution so pixel coordinates match the camera model.
    corners_per_view: dict[str, np.ndarray] = {}
    for gest, depth, img_id in zip(good['gestalt'], good['depth'], good['image_ids']):
        if img_id not in views:
            continue
        depth_np = np.array(depth)
        H, W = depth_np.shape[:2]
        gest_np = np.array(gest.resize((W, H))).astype(np.uint8)
        corners_per_view[img_id] = _detect_2d_corners(gest_np)

    # Collect multi-view observations for each vertex.
    observations = _collect_observations(merged_v, views, corners_per_view, match_px)

    # Run BA on each vertex independently. Only vertices whose INITIAL
    # reprojection error is high (> min_initial_err_px) are refined; this
    # targets depth-estimation failures without disturbing good vertices.
    for i in range(len(merged_v)):
        obs = observations[i]
        if len(obs) < min_views:
            continue

        Ps = [views[vid]['P'] for vid, _ in obs]
        pts_2d = [uv for _, uv in obs]

        x0 = merged_v[i].copy()

        # Skip vertices that already reproject well.
        res0 = _ba_residuals(x0, Ps, pts_2d)
        res0_pairs = res0.reshape(-1, 2)
        initial_err = float(np.sqrt((res0_pairs ** 2).sum(axis=1)).mean())
        if initial_err <= min_initial_err_px:
            continue

        try:
            result = least_squares(
                _ba_residuals, x0,
                args=(Ps, pts_2d),
                method='trf',
                loss='huber',  # robust to an occasional wrong 2D match
                f_scale=2.0,
                max_nfev=50,
            )
        except Exception:
            # A failure on one vertex must not abort the remaining ones.
            continue

        X_opt = result.x
        # Sanity: check post-BA reprojection error and displacement.
        res = _ba_residuals(X_opt, Ps, pts_2d)
        res_pairs = res.reshape(-1, 2)
        final_err = float(np.sqrt((res_pairs ** 2).sum(axis=1)).mean())
        displacement = float(np.linalg.norm(X_opt - x0))

        # Accept only if: (a) reproj improved, (b) error acceptable,
        # (c) the vertex did not move implausibly far.
        if (final_err < initial_err
                and final_err <= max_reproj_px
                and displacement <= max_displacement):
            refined[i] = X_opt

    return refined