File size: 19,000 Bytes
31f43c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
"""LC2WF-inspired 3D line cloud wireframe module.

Instead of lifting individual 2D corners to 3D via a single depth sample,
this module:

1. Extracts 2D line segments from gestalt edge masks (eave/ridge/rake/etc).
2. Samples many depth values along each 2D segment.
3. Fits a robust 3D line through the unprojected samples (RANSAC).
4. Merges similar 3D lines across views (direction + proximity).
5. Computes closest-point intersections of 3D line pairs → vertex candidates.

The resulting vertices average over many depth samples, cancelling noise
that single-pixel corner depth estimates cannot. The 3D line intersections
give overdetermined vertex positions.

Entry points:
    extract_3d_lines(entry) → (list[Line3D], good_entry)
    intersect_lines_to_vertices(lines, ...) → np.ndarray
    predict_wireframe_lines(entry) → (vertices, edges)
"""

from __future__ import annotations

import numpy as np
import cv2
from dataclasses import dataclass

from hoho2025.example_solutions import (
    convert_entry_to_human_readable,
    empty_solution,
    point_to_segment_dist,
)
from hoho2025.color_mappings import gestalt_color_mapping

try:
    from mvs_utils import collect_views, project_world_to_image
except ImportError:
    from submission.mvs_utils import collect_views, project_world_to_image


# Gestalt class names treated as roof edges. Each name is looked up in
# ``gestalt_color_mapping`` to build a per-class edge mask.
EDGE_CLASSES = ['eave', 'ridge', 'rake', 'valley', 'hip']
# Presumably the gestalt class names that mark roof vertices.
# NOTE(review): no function visible in this module reads this constant;
# confirm it is still used elsewhere before removing it.
VERTEX_CLASSES = ['apex', 'eave_end_point', 'flashing_end_point']


@dataclass
class Line3D:
    """A 3D line segment fitted from depth samples.

    The infinite supporting line is described by ``point`` + ``direction``;
    ``p1`` and ``p2`` bound the observed extent along that line.
    """
    point: np.ndarray       # (3,) — a point on the line (the inlier centroid when built in this module)
    direction: np.ndarray   # (3,) — unit direction vector
    p1: np.ndarray          # (3,) — endpoint 1 (minimum projection onto ``direction``)
    p2: np.ndarray          # (3,) — endpoint 2 (maximum projection onto ``direction``)
    length: float           # Euclidean distance between p1 and p2
    n_inliers: int          # number of supporting samples; summed over members for merged lines
    edge_class: str         # gestalt edge class the line was extracted from (see EDGE_CLASSES)
    view_id: str            # image id of the source view, or 'merged' for a merged cluster


# ---------------------------------------------------------------------------
# Step 1-2: Extract 2D segments, sample depth, fit 3D lines
# ---------------------------------------------------------------------------

def _unproject_pixel(u, v, depth, K_inv, R_t_inv, t_world):
    """Unproject a single pixel (u, v) at the given depth to world coords.

    K_inv : (3,3) β€” inverse intrinsics
    R_t_inv : (3,3) β€” R^T (inverse rotation)
    t_world : (3,) β€” camera centre in world = -R^T @ t
    """
    z = float(depth)
    if z <= 0.01 or z > 80.0:
        return None
    cam = K_inv @ np.array([u * z, v * z, z])
    world = R_t_inv @ cam + t_world
    return world


def _fit_3d_line_ransac(
    pts3d: np.ndarray,
    n_iter: int = 100,
    inlier_th: float = 0.3,
    min_inliers: int = 5,
) -> tuple[np.ndarray, np.ndarray, np.ndarray] | None:
    """RANSAC-fit a 3D line through a set of 3D points.

    Returns (point_on_line, unit_direction, inlier_pts) or None.
    """
    n = len(pts3d)
    if n < 2:
        return None

    best_inliers = None
    best_dir = None
    best_pt = None
    best_count = 0

    for _ in range(n_iter):
        idx = np.random.choice(n, 2, replace=False)
        p1, p2 = pts3d[idx[0]], pts3d[idx[1]]
        d = p2 - p1
        length = np.linalg.norm(d)
        if length < 0.05:
            continue
        d = d / length
        # Distance from each point to the line (p1, d)
        rel = pts3d - p1
        proj = rel @ d
        perp = rel - proj[:, None] * d
        dists = np.linalg.norm(perp, axis=1)
        inlier_mask = dists <= inlier_th
        count = int(inlier_mask.sum())
        if count > best_count:
            best_count = count
            best_inliers = inlier_mask
            best_dir = d
            best_pt = p1

    if best_count < min_inliers or best_inliers is None:
        return None

    # Refit on inliers using PCA
    inlier_pts = pts3d[best_inliers]
    centroid = inlier_pts.mean(axis=0)
    _, _, Vt = np.linalg.svd(inlier_pts - centroid)
    direction = Vt[0]
    if np.dot(direction, best_dir) < 0:
        direction = -direction

    return centroid, direction, inlier_pts


def extract_3d_lines_single_view(
    gest_np: np.ndarray,
    depth_np: np.ndarray,
    view_info: dict,
    n_samples: int = 30,
    min_line_px: int = 20,
) -> list[Line3D]:
    """Turn one view's gestalt edge masks and depth map into 3D lines.

    For every class in ``EDGE_CLASSES`` the gestalt image is thresholded on
    that class colour, split into connected components, and each component
    that is large enough gets a 2D line fit. Depth is sampled at
    ``n_samples`` evenly spaced positions along the fitted 2D line, the
    samples are unprojected to world space, and a 3D line is fit through
    them with RANSAC.

    Parameters
    ----------
    gest_np : gestalt segmentation image, compared against the colours in
        ``gestalt_color_mapping``.
    depth_np : depth map; its first two dimensions give the image size.
    view_info : dict providing ``'K'``, ``'R'``, ``'t'`` and ``'image_id'``.
    n_samples : number of depth samples taken along each 2D line.
    min_line_px : components with a smaller pixel area are ignored.

    Returns
    -------
    List of ``Line3D``; lines shorter than 0.3 (world units) are dropped.
    """
    img_h, img_w = depth_np.shape[:2]
    intrinsics = view_info['K']
    rotation = view_info['R']
    translation = view_info['t']
    intrinsics_inv = np.linalg.inv(intrinsics)
    rot_inv = rotation.T
    cam_center = -rot_inv @ translation
    view_id = view_info['image_id']

    found: list[Line3D] = []

    for edge_class in EDGE_CLASSES:
        class_color = np.array(gestalt_color_mapping[edge_class])
        class_mask = cv2.inRange(gest_np, class_color - 0.5, class_color + 0.5)
        # Close small gaps so one physical edge forms one component.
        class_mask = cv2.morphologyEx(
            class_mask, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8)
        )
        if class_mask.sum() == 0:
            continue

        _, label_img, comp_stats, _ = cv2.connectedComponentsWithStats(
            class_mask, 8, cv2.CV_32S
        )
        # Label 0 is the background, so start at 1.
        for comp_id in range(1, label_img.max() + 1):
            if comp_stats[comp_id, cv2.CC_STAT_AREA] < min_line_px:
                continue

            rows, cols = np.where(label_img == comp_id)
            if len(cols) < 3:
                continue

            # 2D line fit: unit direction (vx, vy) through (x0, y0).
            pixels = np.column_stack([cols, rows]).astype(np.float32)
            vx, vy, x0, y0 = cv2.fitLine(
                pixels, cv2.DIST_L2, 0, 0.01, 0.01
            ).ravel()
            along = (cols - x0) * vx + (rows - y0) * vy
            t_lo, t_hi = float(along.min()), float(along.max())

            # Walk the fitted 2D line and lift every valid depth sample.
            samples = []
            for t_val in np.linspace(t_lo, t_hi, n_samples):
                u = x0 + t_val * vx
                v_px = y0 + t_val * vy
                col, row = int(round(u)), int(round(v_px))
                if not (0 <= col < img_w and 0 <= row < img_h):
                    continue
                world_pt = _unproject_pixel(
                    u, v_px, depth_np[row, col],
                    intrinsics_inv, rot_inv, cam_center,
                )
                if world_pt is not None:
                    samples.append(world_pt)

            if len(samples) < 5:
                continue

            fit = _fit_3d_line_ransac(
                np.array(samples, dtype=np.float64),
                n_iter=50, inlier_th=0.3, min_inliers=5,
            )
            if fit is None:
                continue
            centroid, direction, inlier_pts = fit

            # Segment extent = extreme inlier projections on the direction.
            extents = (inlier_pts - centroid) @ direction
            start = centroid + float(extents.min()) * direction
            end = centroid + float(extents.max()) * direction
            seg_len = float(np.linalg.norm(end - start))
            if seg_len < 0.3:
                continue

            found.append(Line3D(
                point=centroid,
                direction=direction,
                p1=start, p2=end,
                length=seg_len,
                n_inliers=len(inlier_pts),
                edge_class=edge_class,
                view_id=view_id,
            ))

    return found


# ---------------------------------------------------------------------------
# Step 1-2 entry: all views
# ---------------------------------------------------------------------------

def extract_3d_lines(entry) -> tuple[list[Line3D], dict]:
    """Extract 3D lines from all views.

    The entry is converted with ``convert_entry_to_human_readable``; every
    view for which ``collect_views`` returns camera info contributes the
    lines found by ``extract_3d_lines_single_view``. Before extraction the
    depth map is, on a best-effort basis, affinely calibrated against the
    COLMAP sparse depth. If the calibration helpers cannot be imported or
    calibration fails for a view, the uncalibrated depth is used and the
    failure is logged rather than raised.

    Returns (all_lines, good_entry). ``all_lines`` is empty when the entry
    has no COLMAP reconstruction.
    """
    import logging
    logger = logging.getLogger(__name__)

    good = convert_entry_to_human_readable(entry)
    colmap_rec = good.get('colmap') or good.get('colmap_binary')
    if colmap_rec is None:
        return [], good

    views = collect_views(colmap_rec, good['image_ids'])
    all_lines: list[Line3D] = []

    # Resolve the optional calibration helpers once instead of per view.
    try:
        from hoho2025.example_solutions import get_sparse_depth, get_house_mask
        from sklearn_submission import fit_affine_ransac
        can_calibrate = True
    except Exception:  # best-effort: fall back to uncalibrated depth
        logger.debug("Depth calibration helpers unavailable; using uncalibrated depth",
                     exc_info=True)
        can_calibrate = False

    view_iter = zip(good['gestalt'], good['depth'], good['image_ids'])
    for view_idx, (gest, depth, img_id) in enumerate(view_iter):
        info = views.get(img_id)
        if info is None:
            continue
        depth_np = np.array(depth).astype(np.float64) / 1000.0
        H, W = depth_np.shape[:2]
        gest_np = np.array(gest.resize((W, H))).astype(np.uint8)

        # Affine depth calibration using COLMAP sparse depth (same as pipeline)
        if can_calibrate:
            try:
                depth_sparse, found, _, _ = get_sparse_depth(colmap_rec, img_id, depth_np)
                if found:
                    # ``view_idx`` is this view's position in the zipped
                    # lists, so the matching ADE image is picked even when
                    # image ids repeat.
                    _, _, depth_np = fit_affine_ransac(
                        depth_np, depth_sparse,
                        get_house_mask(good['ade'][view_idx]),
                    )
            except Exception:  # best-effort: keep the uncalibrated depth
                logger.warning("Depth calibration failed for view %s; using uncalibrated depth",
                               img_id, exc_info=True)

        view_lines = extract_3d_lines_single_view(gest_np, depth_np, info)
        all_lines.extend(view_lines)

    return all_lines, good


# ---------------------------------------------------------------------------
# Step 3: Merge similar 3D lines across views
# ---------------------------------------------------------------------------

def merge_3d_lines(
    lines: list[Line3D],
    direction_cos: float = 0.95,
    midpoint_dist: float = 1.0,
) -> list[Line3D]:
    """Greedily cluster 3D lines by direction and midpoint proximity.

    Lines are visited in input order. A line joins the first existing
    cluster whose representative is nearly parallel to it
    (``|cos| >= direction_cos``) and whose midpoint lies within
    ``midpoint_dist`` of the line's midpoint; otherwise it starts a new
    cluster. After a line joins, that cluster's representative is rebuilt
    from the endpoints of all members: centroid, principal direction
    (PCA), and extreme projections as the new endpoints.

    Returns one representative ``Line3D`` per cluster. With zero or one
    input line the input list itself is returned.
    """
    if len(lines) <= 1:
        return lines

    member_ids: list[list[int]] = []
    representatives: list[Line3D] = []

    for idx, candidate in enumerate(lines):
        cand_mid = (candidate.p1 + candidate.p2) / 2

        # Find the first compatible cluster, if any.
        target = None
        for cluster_idx, rep in enumerate(representatives):
            alignment = abs(float(np.dot(candidate.direction, rep.direction)))
            if alignment < direction_cos:
                continue
            gap = float(np.linalg.norm(cand_mid - (rep.p1 + rep.p2) / 2))
            if gap > midpoint_dist:
                continue
            target = cluster_idx
            break

        if target is None:
            # No match: the line seeds its own cluster.
            member_ids.append([idx])
            representatives.append(Line3D(
                point=candidate.point.copy(),
                direction=candidate.direction.copy(),
                p1=candidate.p1.copy(),
                p2=candidate.p2.copy(),
                length=candidate.length,
                n_inliers=candidate.n_inliers,
                edge_class=candidate.edge_class,
                view_id=candidate.view_id,
            ))
            continue

        # Match: absorb the line and rebuild the representative.
        previous = representatives[target]
        member_ids[target].append(idx)
        group = [lines[k] for k in member_ids[target]]
        endpoints = np.vstack([pt for m in group for pt in (m.p1, m.p2)])
        center = endpoints.mean(axis=0)
        _, _, Vt = np.linalg.svd(endpoints - center)
        axis = Vt[0]
        # Keep the orientation stable across successive rebuilds.
        if np.dot(axis, previous.direction) < 0:
            axis = -axis
        offsets = (endpoints - center) @ axis
        start = center + float(offsets.min()) * axis
        end = center + float(offsets.max()) * axis
        representatives[target] = Line3D(
            point=center,
            direction=axis,
            p1=start,
            p2=end,
            length=float(np.linalg.norm(end - start)),
            n_inliers=sum(m.n_inliers for m in group),
            edge_class=group[0].edge_class,
            view_id='merged',
        )

    return representatives


# ---------------------------------------------------------------------------
# Step 4: Intersect pairs of 3D lines → vertex candidates
# ---------------------------------------------------------------------------

def closest_point_on_two_lines(
    p1: np.ndarray, d1: np.ndarray,
    p2: np.ndarray, d2: np.ndarray,
) -> tuple[np.ndarray, float] | None:
    """Locate the closest approach of two infinite 3D lines.

    Line 1 is ``p1 + s * d1`` and line 2 is ``p2 + t * d2``. The pair
    ``(s, t)`` minimising the distance between the two lines is solved in
    closed form.

    Returns ``(midpoint, distance)``: the point halfway between the two
    closest points and the gap between them. Returns ``None`` when the
    lines are nearly parallel (the determinant of the 2x2 system is below
    1e-8 in magnitude).
    """
    offset = p1 - p2
    d1_sq = float(np.dot(d1, d1))
    d1_d2 = float(np.dot(d1, d2))
    d2_sq = float(np.dot(d2, d2))
    d1_off = float(np.dot(d1, offset))
    d2_off = float(np.dot(d2, offset))

    det = d1_sq * d2_sq - d1_d2 * d1_d2
    if abs(det) < 1e-8:
        return None  # parallel

    s_param = (d1_d2 * d2_off - d2_sq * d1_off) / det
    t_param = (d1_sq * d2_off - d1_d2 * d1_off) / det

    on_first = p1 + s_param * d1
    on_second = p2 + t_param * d2
    gap = float(np.linalg.norm(on_first - on_second))

    return (on_first + on_second) / 2.0, gap


def intersect_lines_to_vertices(
    lines: list[Line3D],
    max_dist: float = 0.5,
    parallel_cos: float = 0.95,
    segment_margin: float = 0.5,
) -> np.ndarray:
    """Generate vertex candidates from 3D line intersections.

    For each pair of non-parallel lines (``|cos| < parallel_cos``):
    - compute the closest approach point;
    - accept if the distance between the lines at that point is ≤ max_dist;
    - accept only if the closest point is within ``segment_margin`` of
      both line segments (not too far outside the actual edge extent).

    The segment extent of a line is the interval between the projections
    of ``p1`` and ``p2`` onto its direction. The two projections are put
    in order first, so a line whose endpoints are stored in the opposite
    order to its direction is handled the same way as in
    ``snap_vertices_to_lines``.

    Returns
    -------
    (M, 3) float64 array of candidate vertices, one per accepted pair;
    shape (0, 3) when nothing is accepted or fewer than two lines are given.
    """
    if len(lines) < 2:
        return np.empty((0, 3), dtype=np.float64)

    # Margin-extended extent of every segment along its own direction,
    # computed once instead of once per pair.
    spans: list[tuple[float, float]] = []
    for ln in lines:
        s_a = float(np.dot(ln.p1 - ln.point, ln.direction))
        s_b = float(np.dot(ln.p2 - ln.point, ln.direction))
        if s_a > s_b:
            s_a, s_b = s_b, s_a
        spans.append((s_a - segment_margin, s_b + segment_margin))

    vertices: list[np.ndarray] = []
    for i, first in enumerate(lines):
        for j in range(i + 1, len(lines)):
            second = lines[j]
            cos = abs(float(np.dot(first.direction, second.direction)))
            if cos >= parallel_cos:
                continue

            result = closest_point_on_two_lines(
                first.point, first.direction,
                second.point, second.direction,
            )
            if result is None:
                continue
            midpoint, dist = result
            if dist > max_dist:
                continue

            # Check that the intersection is near both line segments
            ok = True
            for line, (s_lo, s_hi) in ((first, spans[i]), (second, spans[j])):
                s = float(np.dot(midpoint - line.point, line.direction))
                if s < s_lo or s > s_hi:
                    ok = False
                    break
            if ok:
                vertices.append(midpoint)

    if not vertices:
        return np.empty((0, 3), dtype=np.float64)
    return np.array(vertices, dtype=np.float64)


# ---------------------------------------------------------------------------
# Step 5: Integration helper
# ---------------------------------------------------------------------------

def snap_vertices_to_lines(
    vertices: np.ndarray,
    lines: list[Line3D],
    snap_radius: float = 0.4,
    min_line_inliers: int = 10,
    segment_margin: float = 0.3,
    require_agree: int = 1,
) -> tuple[np.ndarray, np.ndarray]:
    """Snap each vertex onto a nearby trustworthy 3D line.

    For every trusted line the vertex is projected perpendicularly onto
    the line, and the projection is clamped to the segment ``[p1, p2]``
    extended by ``segment_margin`` at both ends (so a vertex is never slid
    far off the ends of the real edge). The clamped point is the snap
    target. A line is a candidate only if the snap target lies within
    ``snap_radius`` of the vertex, so a snap never moves a vertex by more
    than ``snap_radius`` — in particular a vertex that is almost collinear
    with a segment but far beyond its end is left alone.

    A line is considered "trustworthy" if it has ≥ ``min_line_inliers``
    depth samples (the more, the better the depth-noise averaging).

    With ``require_agree`` ≤ 1 the vertex is moved to the closest
    candidate target. When ``require_agree`` ≥ 2 the vertex is moved only
    if at least that many candidate targets lie within ``snap_radius`` of
    the closest one; it is then moved to the mean of those agreeing
    targets. This "consensus" mode avoids snapping to a single noisy line.

    Returns
    -------
    refined : (N, 3) float64 — refined vertex positions
    snapped : (N,)  bool    — which vertices were moved
    """
    verts = np.asarray(vertices, dtype=np.float64)
    refined = verts.copy()
    snapped = np.zeros(len(verts), dtype=bool)

    if len(verts) == 0 or not lines:
        return refined, snapped

    # Pre-filter trustworthy lines and precompute their margin-extended,
    # ordered extents once (they do not depend on the vertex).
    trusted = []
    for ln in lines:
        if ln.n_inliers < min_line_inliers:
            continue
        s_a = float(np.dot(ln.p1 - ln.point, ln.direction))
        s_b = float(np.dot(ln.p2 - ln.point, ln.direction))
        if s_a > s_b:
            s_a, s_b = s_b, s_a
        trusted.append((ln, s_a - segment_margin, s_b + segment_margin))
    if not trusted:
        return refined, snapped

    for i, v in enumerate(verts):
        # (distance from vertex to snap target, snap target) per line
        candidates: list[tuple[float, np.ndarray]] = []
        for ln, s_lo, s_hi in trusted:
            s = float(np.dot(v - ln.point, ln.direction))
            s_clamped = min(max(s, s_lo), s_hi)
            target = ln.point + s_clamped * ln.direction
            # Measured to the clamped target: equals the perpendicular
            # distance inside the extent and grows with the overshoot
            # beyond it.
            move = float(np.linalg.norm(v - target))
            if move > snap_radius:
                continue
            candidates.append((move, target))

        # ``not candidates`` also covers require_agree <= 0.
        if not candidates or len(candidates) < require_agree:
            continue

        # Closest target; ties resolve to the earliest line.
        best_target = min(candidates, key=lambda c: c[0])[1]

        if require_agree >= 2:
            # Consensus: keep only if enough candidates agree with the
            # closest target to within snap_radius.
            agreeing = [tgt for _, tgt in candidates
                        if np.linalg.norm(tgt - best_target) <= snap_radius]
            if len(agreeing) < require_agree:
                continue
            # Snap to the mean of agreeing projections
            refined[i] = np.mean(agreeing, axis=0)
        else:
            # Single-line snap: pick the closest
            refined[i] = best_target
        snapped[i] = True

    return refined, snapped


def line_based_vertices(
    entry,
    max_intersection_dist: float = 0.5,
    merge_radius: float = 0.4,
) -> np.ndarray:
    """Run the full line pipeline and return deduplicated vertices.

    Extracts 3D lines from every view of ``entry``, merges similar lines,
    intersects the merged lines pairwise (accepting pairs closer than
    ``max_intersection_dist``), and finally fuses intersection points that
    lie within ``merge_radius`` of each other into their mean.

    Returns a (K, 3) float64 array; shape (0, 3) when no lines, fewer than
    two merged lines, or no intersections are found.
    """
    lines, _good = extract_3d_lines(entry)
    if not lines:
        return np.empty((0, 3), dtype=np.float64)

    merged_lines = merge_3d_lines(lines)
    if len(merged_lines) < 2:
        return np.empty((0, 3), dtype=np.float64)

    raw_verts = intersect_lines_to_vertices(
        merged_lines, max_dist=max_intersection_dist,
    )
    if len(raw_verts) == 0:
        return np.empty((0, 3), dtype=np.float64)

    # Greedy radius merge: each not-yet-consumed point absorbs all of its
    # not-yet-consumed neighbours and is replaced by their mean.
    from scipy.spatial import cKDTree
    neighbourhoods = cKDTree(raw_verts).query_ball_point(raw_verts, merge_radius)
    consumed: set[int] = set()
    fused = []
    for seed, neighbours in enumerate(neighbourhoods):
        if seed in consumed:
            continue
        fresh = [j for j in neighbours if j not in consumed]
        if fresh:
            fused.append(raw_verts[fresh].mean(axis=0))
            consumed.update(fresh)

    if not fused:
        return np.empty((0, 3), dtype=np.float64)
    return np.array(fused, dtype=np.float64)