Xallt commited on
Commit
7058f30
·
1 Parent(s): 10d8e71

Update solution

Browse files
Files changed (2) hide show
  1. example_solutions_copy.py +333 -92
  2. process_sample.py +5 -11
example_solutions_copy.py CHANGED
@@ -2,19 +2,69 @@ import io
2
  import tempfile
3
  import zipfile
4
  from collections import defaultdict
 
5
  from typing import List, Tuple
6
 
7
  import cv2
8
  import numpy as np
9
  import pycolmap
10
  from hoho2025.color_mappings import ade20k_color_mapping, gestalt_color_mapping
11
- from PIL import Image as PImage
12
  from scipy.spatial.distance import cdist
13
 
14
 
15
- def empty_solution():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  """Return a minimal valid solution, i.e. 2 vertices and 1 edge."""
17
- return np.zeros((2, 3)), [(0, 1)]
 
 
 
 
 
 
18
 
19
 
20
  def read_colmap_rec(colmap_data):
@@ -59,23 +109,169 @@ def get_house_mask(ade20k_seg):
59
  return full_mask
60
 
61
 
62
- def point_to_segment_dist(pt, seg_p1, seg_p2):
63
- """
64
- Computes the Euclidean distance from pt to the line segment p1->p2.
65
- pt, seg_p1, seg_p2: (x, y) as np.ndarray
66
- """
67
  # If both endpoints are the same, just return distance to one of them
68
  if np.allclose(seg_p1, seg_p2):
69
  return np.linalg.norm(pt - seg_p1)
70
  seg_vec = seg_p2 - seg_p1
71
  pt_vec = pt - seg_p1
72
- seg_len2 = seg_vec.dot(seg_vec)
73
- t = max(0, min(1, pt_vec.dot(seg_vec) / seg_len2))
74
  proj = seg_p1 + t * seg_vec
 
 
 
 
 
 
 
 
 
75
  return np.linalg.norm(pt - proj)
76
 
77
 
78
- def get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th=25.0):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  """
80
  Identify apex and eave-end vertices, then detect lines for eave/ridge/rake/valley.
81
  For each connected component, we do a line fit with cv2.fitLine, then measure
@@ -190,11 +386,14 @@ def get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th=25.0):
190
  conn = tuple(sorted((vA, vB)))
191
  connections.append(conn)
192
 
193
- return vertices, connections
 
 
 
194
 
195
 
196
  def get_uv_depth(
197
- vertices: List[dict],
198
  depth_fitted: np.ndarray,
199
  sparse_depth: np.ndarray,
200
  search_radius: int = 10,
@@ -211,7 +410,7 @@ def get_uv_depth(
211
 
212
  Parameters
213
  ----------
214
- vertices : List[dict]
215
  Each dict must have "xy" at least, e.g. {"xy": (x, y), ...}
216
  depth_fitted : np.ndarray
217
  A 2D array (H, W), the dense (or corrected) depth for fallback.
@@ -229,7 +428,7 @@ def get_uv_depth(
229
  """
230
 
231
  # Collect each vertex's (x, y)
232
- uv = np.array([vert["xy"] for vert in vertices], dtype=np.float32)
233
 
234
  # Convert to integer pixel coordinates (round or floor)
235
  uv_int = np.round(uv).astype(np.int32)
@@ -277,7 +476,7 @@ def get_uv_depth(
277
 
278
 
279
  def project_vertices_to_3d(
280
- uv: np.ndarray, depth_vert: np.ndarray, col_img: pycolmap.Image
281
  ) -> np.ndarray:
282
  """
283
  Projects 2D vertex coordinates with associated depths to 3D world coordinates.
@@ -316,21 +515,21 @@ def project_vertices_to_3d(
316
 
317
 
318
  def create_3d_wireframe_single_image(
319
- vertices: List[dict],
320
- connections: List[Tuple[int, int]],
321
- depth: PImage,
322
  colmap_rec: pycolmap.Reconstruction,
323
  img_id: str,
324
- ade_seg: PImage,
325
- ) -> np.ndarray:
326
  """
327
  Processes a single image view to generate 3D vertex coordinates from existing 2D vertices/edges.
328
 
329
  Parameters
330
  ----------
331
- vertices : List[dict]
332
  List of 2D vertex dictionaries (e.g., {"xy": (x, y), "type": ...}).
333
- connections : List[Tuple[int, int]]
334
  List of 2D edge connections (indices into the vertices list).
335
  depth : PIL.Image
336
  Initial dense depth map as a PIL Image.
@@ -353,7 +552,7 @@ def create_3d_wireframe_single_image(
353
  print(
354
  f"Warning: create_3d_wireframe_single_image called with insufficient vertices/connections for image {img_id}"
355
  )
356
- return np.empty((0, 3))
357
 
358
  # Get fitted dense depth and sparse depth
359
  depth_fitted, depth_sparse, found_sparse, col_img = get_fitted_dense_depth(
@@ -366,35 +565,54 @@ def create_3d_wireframe_single_image(
366
  # Backproject to 3D
367
  vertices_3d = project_vertices_to_3d(uv, depth_vert, col_img)
368
 
369
- return vertices_3d
 
 
 
370
 
371
 
372
- def merge_vertices_3d(vert_edge_per_image, th=0.5):
 
 
373
  """Merge vertices that are close to each other in 3D space and are of same types"""
374
  # Initialize structures to collect vertices and connections from all images
375
- all_3d_vertices = []
376
- connections_3d = []
377
- all_indexes = []
378
  cur_start = 0
379
- types = []
 
 
 
 
 
 
380
 
 
381
  # Combine vertices and update connection indices across all images
382
- for cimg_idx, (vertices, connections, vertices_3d) in vert_edge_per_image.items():
383
- types += [int(v["type"] == "apex") for v in vertices]
384
- all_3d_vertices.append(vertices_3d)
385
- connections_3d += [(x + cur_start, y + cur_start) for (x, y) in connections]
 
 
 
 
 
 
 
 
386
  cur_start += len(vertices_3d)
387
- all_3d_vertices = np.concatenate(all_3d_vertices, axis=0)
 
388
 
389
  # Calculate distance matrix between all vertices
390
  distmat = cdist(all_3d_vertices, all_3d_vertices)
391
- types = np.array(types).reshape(-1, 1)
392
- same_types = cdist(types, types)
393
 
394
  # Create mask for vertices that should be merged (close in space and same type)
395
- mask_to_merge = (distmat <= th) & (same_types == 0)
396
- new_vertices = []
397
- new_connections = []
398
 
399
  # Extract vertex indices to merge based on the mask
400
  to_merge = sorted(
@@ -426,11 +644,15 @@ def merge_vertices_3d(vert_edge_per_image, th=0.5):
426
  old_idx_to_new = {}
427
  count = 0
428
  for idxs in merged:
429
- new_vertices.append(all_3d_vertices[idxs].mean(axis=0))
 
 
 
 
 
430
  for idx in idxs:
431
  old_idx_to_new[idx] = count
432
  count += 1
433
- new_vertices = np.array(new_vertices)
434
 
435
  # Update connections to use new vertex indices
436
  for conn in connections_3d:
@@ -438,29 +660,29 @@ def merge_vertices_3d(vert_edge_per_image, th=0.5):
438
  if new_con[0] == new_con[1]:
439
  continue
440
  if new_con not in new_connections:
441
- new_connections.append(new_con)
442
- return new_vertices, new_connections
443
 
444
 
445
- def prune_not_connected(all_3d_vertices, connections_3d, keep_largest=True):
446
  """
447
  Prune vertices not connected to anything. If keep_largest=True, also
448
  keep only the largest connected component in the graph.
449
  """
450
- if len(all_3d_vertices) == 0:
451
  return np.array([]), []
452
 
453
  # adjacency
454
  adj = defaultdict(set)
455
- for i, j in connections_3d:
456
- adj[i].add(j)
457
- adj[j].add(i)
458
 
459
  # keep only vertices that appear in at least one edge
460
  used_idxs = set()
461
- for i, j in connections_3d:
462
- used_idxs.add(i)
463
- used_idxs.add(j)
464
 
465
  if not used_idxs:
466
  return np.empty((0, 3)), []
@@ -471,12 +693,13 @@ def prune_not_connected(all_3d_vertices, connections_3d, keep_largest=True):
471
  used_list = sorted(list(used_idxs))
472
  for new_id, old_id in enumerate(used_list):
473
  new_map[old_id] = new_id
474
- new_vertices = np.array([all_3d_vertices[old_id] for old_id in used_list])
 
475
  new_conns = []
476
- for i, j in connections_3d:
477
- if i in used_idxs and j in used_idxs:
478
- new_conns.append((new_map[i], new_map[j]))
479
- return new_vertices, new_conns
480
 
481
  # Otherwise find the largest connected component:
482
  visited = set()
@@ -510,15 +733,16 @@ def prune_not_connected(all_3d_vertices, connections_3d, keep_largest=True):
510
  for new_id, old_id in enumerate(largest):
511
  new_map[old_id] = new_id
512
 
513
- new_vertices = np.array([all_3d_vertices[old_id] for old_id in largest])
514
- new_conns = []
515
- for i, j in connections_3d:
516
- if i in largest and j in largest:
517
- new_conns.append((new_map[i], new_map[j]))
 
518
 
519
  # remove duplicates
520
- new_conns = list(set([tuple(sorted(c)) for c in new_conns]))
521
- return new_vertices, new_conns
522
 
523
 
524
  def get_sparse_depth(colmap_rec, img_id_substring, depth_shape):
@@ -530,7 +754,7 @@ def get_sparse_depth(colmap_rec, img_id_substring, depth_shape):
530
  H, W = depth_shape
531
 
532
  # 1) Find the matching COLMAP image
533
- found_img: pycolmap.Image = None
534
  for img_id_c, col_img in colmap_rec.images.items():
535
  if img_id_substring in col_img.name:
536
  found_img = col_img
@@ -555,7 +779,11 @@ def get_sparse_depth(colmap_rec, img_id_substring, depth_shape):
555
  z_vals = []
556
  for xyz in points_xyz:
557
  proj = found_img.project_point(xyz) # returns (u, v) in image coords or None
558
- cur_res = np.array([found_img.camera.height, found_img.camera.width])
 
 
 
 
559
  exp_res = np.array([H, W])
560
  proj = proj * exp_res / cur_res
561
 
@@ -569,6 +797,8 @@ def get_sparse_depth(colmap_rec, img_id_substring, depth_shape):
569
  # We'll compute depth as Z in camera coords
570
  # from the world->cam transform col_img holds
571
  mat4x4 = np.eye(4)
 
 
572
  mat4x4[:3, :4] = found_img.cam_from_world.matrix()
573
  p_cam = mat4x4 @ np.array([xyz[0], xyz[1], xyz[2], 1.0])
574
  z_vals.append(p_cam[2] / p_cam[3])
@@ -661,7 +891,7 @@ def get_fitted_dense_depth(depth, colmap_rec, img_id, ade20k_seg):
661
  return depth_fitted, depth_sparse, True, col_img
662
 
663
 
664
- def prune_too_far(all_3d_vertices, connections_3d, colmap_rec, th=3.0):
665
  """
666
  Prune vertices that are too far from sparse point cloud
667
 
@@ -669,28 +899,36 @@ def prune_too_far(all_3d_vertices, connections_3d, colmap_rec, th=3.0):
669
  xyz_sfm = []
670
  for k, v in colmap_rec.points3D.items():
671
  xyz_sfm.append(v.xyz)
672
- xyz_sfm = np.array(xyz_sfm)
673
- distmat = cdist(all_3d_vertices, xyz_sfm)
674
- mindist = distmat.min(axis=1)
 
 
 
675
  mask = mindist <= th
676
- all_3d_vertices_new = all_3d_vertices[mask]
677
- old_idx_survived = np.arange(len(all_3d_vertices))[mask]
678
- new_idxs = np.arange(len(all_3d_vertices_new))
679
- old_to_new_idx = dict(zip(old_idx_survived, new_idxs))
 
 
680
  connections_3d_new = [
681
- (old_to_new_idx[conn[0]], old_to_new_idx[conn[1]])
682
- for conn in connections_3d
683
- if mask[conn[0]] and mask[conn[1]]
684
  ]
685
- return all_3d_vertices_new, connections_3d_new
 
 
 
686
 
687
 
688
- def predict_wireframe(entry) -> Tuple[np.ndarray, List[int]]:
689
  """
690
  Predict 3D wireframe from a dataset entry.
691
  """
692
  good_entry = convert_entry_to_human_readable(entry)
693
- vert_edge_per_image = {}
694
  for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(
695
  zip(
696
  good_entry["gestalt"],
@@ -712,14 +950,18 @@ def predict_wireframe(entry) -> Tuple[np.ndarray, List[int]]:
712
  gest_seg_np = np.array(gest_seg).astype(np.uint8)
713
 
714
  # Get 2D vertices and edges first
715
- vertices, connections = get_vertices_and_edges_from_segmentation(
716
  gest_seg_np, edge_th=10.0
717
  )
 
 
718
 
719
  # Check if we have enough to proceed
720
  if (len(vertices) < 2) or (len(connections) < 1):
721
  print(f"Not enough vertices or connections found in image {i}, skipping.")
722
- vert_edge_per_image[i] = [], [], np.empty((0, 3))
 
 
723
  continue
724
 
725
  # Call the refactored function to get 3D points
@@ -727,19 +969,18 @@ def predict_wireframe(entry) -> Tuple[np.ndarray, List[int]]:
727
  vertices, connections, depth, colmap_rec, img_id, ade_seg
728
  )
729
  # Store original 2D vertices, connections, and computed 3D points
730
- vert_edge_per_image[i] = vertices, connections, vertices_3d
 
 
731
 
732
  # Merge vertices from all images
733
- all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 0.5)
734
- all_3d_vertices_clean, connections_3d_clean = prune_not_connected(
735
- all_3d_vertices, connections_3d, keep_largest=False
736
- )
737
- all_3d_vertices_clean, connections_3d_clean = prune_too_far(
738
- all_3d_vertices_clean, connections_3d_clean, colmap_rec, th=4.0
739
- )
740
 
741
- if (len(all_3d_vertices_clean) < 2) or len(connections_3d_clean) < 1:
742
  print("Not enough vertices or connections in the 3D vertices")
743
  return empty_solution()
744
 
745
- return all_3d_vertices_clean, connections_3d_clean
 
2
  import tempfile
3
  import zipfile
4
  from collections import defaultdict
5
+ from dataclasses import dataclass
6
  from typing import List, Tuple
7
 
8
  import cv2
9
  import numpy as np
10
  import pycolmap
11
  from hoho2025.color_mappings import ade20k_color_mapping, gestalt_color_mapping
 
12
  from scipy.spatial.distance import cdist
13
 
14
 
15
+ @dataclass
16
+ class WireframePoint2D:
17
+ xy: np.ndarray
18
+ type: str
19
+
20
+
21
+ @dataclass
22
+ class WireframeEdge:
23
+ i1: int
24
+ i2: int
25
+
26
+
27
+ @dataclass
28
+ class Wireframe2D:
29
+ vertices: List[WireframePoint2D]
30
+ edges: List[WireframeEdge]
31
+
32
+
33
+ @dataclass
34
+ class WireframePoint3D:
35
+ xyz: np.ndarray
36
+ type: str
37
+
38
+
39
+ @dataclass
40
+ class Wireframe2DWith3D:
41
+ wireframe2d: Wireframe2D
42
+ vertices_3d: List[WireframePoint3D]
43
+
44
+
45
+ @dataclass
46
+ class Wireframe3D:
47
+ vertices: List[WireframePoint3D]
48
+ edges: List[WireframeEdge]
49
+
50
+ @property
51
+ def vertices_np(self) -> np.ndarray:
52
+ return np.array([v.xyz for v in self.vertices])
53
+
54
+ @property
55
+ def edges_np(self) -> np.ndarray:
56
+ return np.array([[e.i1, e.i2] for e in self.edges])
57
+
58
+
59
+ def empty_solution() -> Wireframe3D:
60
  """Return a minimal valid solution, i.e. 2 vertices and 1 edge."""
61
+ return Wireframe3D(
62
+ vertices=[
63
+ WireframePoint3D(xyz=np.zeros((3,)), type=""),
64
+ WireframePoint3D(xyz=np.zeros((3,)), type=""),
65
+ ],
66
+ edges=[WireframeEdge(i1=0, i2=1)],
67
+ )
68
 
69
 
70
  def read_colmap_rec(colmap_data):
 
109
  return full_mask
110
 
111
 
112
+ def point_to_segment_proj(pt, seg_p1, seg_p2):
 
 
 
 
113
  # If both endpoints are the same, just return distance to one of them
114
  if np.allclose(seg_p1, seg_p2):
115
  return np.linalg.norm(pt - seg_p1)
116
  seg_vec = seg_p2 - seg_p1
117
  pt_vec = pt - seg_p1
118
+ seg_len2 = np.linalg.norm(seg_vec) ** 2
119
+ t = max(0, min(1, np.dot(pt_vec, seg_vec) / seg_len2))
120
  proj = seg_p1 + t * seg_vec
121
+ return proj
122
+
123
+
124
+ def point_to_segment_dist(pt, seg_p1, seg_p2):
125
+ """
126
+ Computes the Euclidean distance from pt to the line segment p1->p2.
127
+ pt, seg_p1, seg_p2: (x, y) as np.ndarray
128
+ """
129
+ proj = point_to_segment_proj(pt, seg_p1, seg_p2)
130
  return np.linalg.norm(pt - proj)
131
 
132
 
133
+ def combine_segs(keys, gestalt_img) -> np.ndarray:
134
+ res = np.zeros(gestalt_img.shape[:2], dtype=bool)
135
+ for key in keys:
136
+ color = np.array(gestalt_color_mapping[key])
137
+ mask = cv2.inRange(gestalt_img, color - 0.5, color + 0.5)
138
+ res = res | mask.astype(bool)
139
+ return res
140
+
141
+
142
+ def get_turn_angles(contour):
143
+ angles = []
144
+ vcur = contour[:, 0] # (N, 2)
145
+ vprev = np.concatenate([vcur[-1, None], vcur[:-1]]) # (N, 2)
146
+ vnext = np.concatenate([vcur[1:], vcur[0, None]]) # (N, 2)
147
+
148
+ vecprev, vecnext = vcur - vprev, vnext - vcur
149
+ vecprev = vecprev / np.linalg.norm(vecprev, axis=1, keepdims=True)
150
+ vecnext = vecnext / np.linalg.norm(vecnext, axis=1, keepdims=True)
151
+
152
+ def dot(a, b):
153
+ return (a * b).sum(axis=-1)
154
+
155
+ angles = np.degrees(np.arctan2(np.cross(vecprev, vecnext), dot(vecprev, vecnext)))
156
+ return angles
157
+
158
+
159
+ def slice_arr(arr, i, j):
160
+ if i <= j:
161
+ if j <= len(arr):
162
+ return arr[i:j]
163
+ else:
164
+ return np.concatenate([arr[i:], arr[: j - len(arr)]])
165
+ else:
166
+ return np.concatenate([arr[i:], arr[:j]])
167
+
168
+
169
+ def group_segments(segments):
170
+ segments = sorted(segments, key=lambda x: x[0])
171
+ grouped = []
172
+ for i in range(len(segments)):
173
+ if i == 0:
174
+ grouped.append(segments[i])
175
+ else:
176
+ if segments[i][0] <= grouped[-1][1]:
177
+ grouped[-1] = (grouped[-1][0], max(grouped[-1][1], segments[i][1]))
178
+ else:
179
+ grouped.append(segments[i])
180
+ return grouped
181
+
182
+
183
+ def get_contour_interesting_points_indices(contour):
184
+ angles = get_turn_angles(contour)
185
+ angle_len = cv2.arcLength(contour, True) / 20
186
+
187
+ interesting_segments = []
188
+ interesting_points = []
189
+ for i in range(len(angles)):
190
+ j = i + 1
191
+ while True:
192
+ cur_len = cv2.arcLength(slice_arr(contour, i, j), False)
193
+ if cur_len > angle_len:
194
+ break
195
+ j += 1
196
+ # i:j is smaller than angle_len
197
+ turns = np.cumsum(slice_arr(angles, i, j))
198
+ k = 2
199
+ if len(turns) > k and np.abs(turns[k:]).max() > 70:
200
+ matching_i = np.where(np.abs(turns[k:]) > 70)[0][0] + k + i
201
+ interesting_segments.append((i, int(matching_i)))
202
+ interesting_points.append(i)
203
+
204
+ grouped_segments = group_segments(interesting_segments)
205
+ return [((i + j) // 2) % len(contour) for i, j in grouped_segments]
206
+ # return interesting_points
207
+
208
+
209
+ def get_contour_interesting_wireframe(contour) -> Tuple[np.ndarray, np.ndarray]:
210
+ indices = get_contour_interesting_points_indices(contour)
211
+ connections = []
212
+ for i in range(len(indices)):
213
+ i1, i2 = indices[i], indices[(i + 1) % len(indices)]
214
+ segment_len = np.linalg.norm(contour[i1, 0] - contour[i2, 0])
215
+ points_side1 = slice_arr(contour[:, 0], i1, i2)
216
+ points_side2 = slice_arr(contour[:, 0], i2, i1)
217
+ points_side1_distances = np.array(
218
+ [
219
+ point_to_segment_dist(p, contour[i1, 0], contour[i2, 0])
220
+ for p in points_side1
221
+ ]
222
+ )
223
+ points_side2_distances = np.array(
224
+ [
225
+ point_to_segment_dist(p, contour[i2, 0], contour[i1, 0])
226
+ for p in points_side2
227
+ ]
228
+ )
229
+ dist_side_1 = (
230
+ points_side1_distances.max() if len(points_side1_distances) > 0 else 0
231
+ )
232
+ dist_side_2 = (
233
+ points_side2_distances.max() if len(points_side2_distances) > 0 else 0
234
+ )
235
+ factor = 0.1
236
+ if dist_side_1 <= segment_len * factor or dist_side_2 <= segment_len * factor:
237
+ connections.append((i, (i + 1) % len(indices)))
238
+ return contour[indices, 0], np.array(connections)
239
+
240
+
241
+ def get_vertices_and_edges_from_segmentation_contours(
242
+ gest_seg_np, edge_th=25.0
243
+ ) -> Wireframe2D:
244
+ gest_seg_np = np.array(gest_seg_np)
245
+ keys_segments = ["eave", "ridge", "rake", "valley"]
246
+
247
+ all_contours = []
248
+ for key in keys_segments:
249
+ mask = combine_segs([key], gest_seg_np)
250
+ contours, _ = cv2.findContours(
251
+ mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_KCOS
252
+ )
253
+ all_contours.extend(contours)
254
+ # contours = contours[::-1]
255
+ all_vertices: list[WireframePoint2D] = []
256
+ all_connections: list[WireframeEdge] = []
257
+ for contour in all_contours:
258
+ area = cv2.contourArea(contour, oriented=True)
259
+ if area < 0:
260
+ contour = contour[::-1]
261
+
262
+ interesting_points, interesting_connections = get_contour_interesting_wireframe(
263
+ contour
264
+ )
265
+ all_vertices.extend(
266
+ WireframePoint2D(xy=p, type=key) for p in interesting_points
267
+ )
268
+ all_connections.extend(
269
+ WireframeEdge(i1=i1, i2=i2) for i1, i2 in interesting_connections
270
+ )
271
+ return Wireframe2D(all_vertices, all_connections)
272
+
273
+
274
+ def get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th=25.0) -> Wireframe2D:
275
  """
276
  Identify apex and eave-end vertices, then detect lines for eave/ridge/rake/valley.
277
  For each connected component, we do a line fit with cv2.fitLine, then measure
 
386
  conn = tuple(sorted((vA, vB)))
387
  connections.append(conn)
388
 
389
+ vertices = [WireframePoint2D(xy=v["xy"], type=v["type"]) for v in vertices]
390
+ connections = [WireframeEdge(i1=c[0], i2=c[1]) for c in connections]
391
+
392
+ return Wireframe2D(vertices, connections)
393
 
394
 
395
  def get_uv_depth(
396
+ vertices: List[WireframePoint2D],
397
  depth_fitted: np.ndarray,
398
  sparse_depth: np.ndarray,
399
  search_radius: int = 10,
 
410
 
411
  Parameters
412
  ----------
413
+ vertices : List[WireframePoint2D]
414
  Each dict must have "xy" at least, e.g. {"xy": (x, y), ...}
415
  depth_fitted : np.ndarray
416
  A 2D array (H, W), the dense (or corrected) depth for fallback.
 
428
  """
429
 
430
  # Collect each vertex's (x, y)
431
+ uv = np.array([vert.xy for vert in vertices], dtype=np.float32)
432
 
433
  # Convert to integer pixel coordinates (round or floor)
434
  uv_int = np.round(uv).astype(np.int32)
 
476
 
477
 
478
  def project_vertices_to_3d(
479
+ uv: np.ndarray, depth_vert: np.ndarray, col_img
480
  ) -> np.ndarray:
481
  """
482
  Projects 2D vertex coordinates with associated depths to 3D world coordinates.
 
515
 
516
 
517
  def create_3d_wireframe_single_image(
518
+ vertices: List[WireframePoint2D],
519
+ connections: List[WireframeEdge],
520
+ depth,
521
  colmap_rec: pycolmap.Reconstruction,
522
  img_id: str,
523
+ ade_seg,
524
+ ) -> List[WireframePoint3D]:
525
  """
526
  Processes a single image view to generate 3D vertex coordinates from existing 2D vertices/edges.
527
 
528
  Parameters
529
  ----------
530
+ vertices : List[WireframePoint2D]
531
  List of 2D vertex dictionaries (e.g., {"xy": (x, y), "type": ...}).
532
+ connections : List[WireframeEdge]
533
  List of 2D edge connections (indices into the vertices list).
534
  depth : PIL.Image
535
  Initial dense depth map as a PIL Image.
 
552
  print(
553
  f"Warning: create_3d_wireframe_single_image called with insufficient vertices/connections for image {img_id}"
554
  )
555
+ return []
556
 
557
  # Get fitted dense depth and sparse depth
558
  depth_fitted, depth_sparse, found_sparse, col_img = get_fitted_dense_depth(
 
565
  # Backproject to 3D
566
  vertices_3d = project_vertices_to_3d(uv, depth_vert, col_img)
567
 
568
+ return [
569
+ WireframePoint3D(xyz=v, type=vertices[i].type)
570
+ for i, v in enumerate(vertices_3d)
571
+ ]
572
 
573
 
574
+ def merge_vertices_3d(
575
+ vert_edge_per_image: dict[int, Wireframe2DWith3D], th=0.5
576
+ ) -> Wireframe3D:
577
  """Merge vertices that are close to each other in 3D space and are of same types"""
578
  # Initialize structures to collect vertices and connections from all images
579
+ all_3d_vertices_list: list[np.ndarray] = []
580
+ connections_3d: list[tuple[int, int]] = []
 
581
  cur_start = 0
582
+ types: list[int] = []
583
+
584
+ all_types_set: set[str] = set()
585
+ for _, wireframe2d_with_3d in vert_edge_per_image.items():
586
+ all_types_set.update([v.type for v in wireframe2d_with_3d.wireframe2d.vertices])
587
+ all_types = list(all_types_set)
588
+ type_idx_map = {t: i for i, t in enumerate(all_types)}
589
 
590
+ all_wireframe_points_3d: list[WireframePoint3D] = []
591
  # Combine vertices and update connection indices across all images
592
+ for cimg_idx, wireframe2d_with_3d in vert_edge_per_image.items():
593
+ vertices = wireframe2d_with_3d.wireframe2d.vertices
594
+ connections = wireframe2d_with_3d.wireframe2d.edges
595
+ vertices_3d: np.ndarray = np.array(
596
+ [v.xyz for v in wireframe2d_with_3d.vertices_3d]
597
+ )
598
+ types += [type_idx_map[v.type] for v in vertices]
599
+ all_wireframe_points_3d.extend(wireframe2d_with_3d.vertices_3d)
600
+ all_3d_vertices_list.append(vertices_3d)
601
+ connections_3d += [
602
+ (con.i1 + cur_start, con.i2 + cur_start) for con in connections
603
+ ]
604
  cur_start += len(vertices_3d)
605
+ all_3d_vertices = np.concatenate(all_3d_vertices_list, axis=0)
606
+ types_np = np.array(types)
607
 
608
  # Calculate distance matrix between all vertices
609
  distmat = cdist(all_3d_vertices, all_3d_vertices)
610
+ same_types = types_np[:, None] == types_np[None, :]
 
611
 
612
  # Create mask for vertices that should be merged (close in space and same type)
613
+ mask_to_merge = (distmat <= th) & same_types
614
+ new_vertices: list[WireframePoint3D] = []
615
+ new_connections: list[WireframeEdge] = []
616
 
617
  # Extract vertex indices to merge based on the mask
618
  to_merge = sorted(
 
644
  old_idx_to_new = {}
645
  count = 0
646
  for idxs in merged:
647
+ types_cur = [all_wireframe_points_3d[i].type for i in idxs]
648
+ assert len(set(types_cur)) == 1
649
+
650
+ new_vertices.append(
651
+ WireframePoint3D(xyz=all_3d_vertices[idxs].mean(axis=0), type=types_cur[0])
652
+ )
653
  for idx in idxs:
654
  old_idx_to_new[idx] = count
655
  count += 1
 
656
 
657
  # Update connections to use new vertex indices
658
  for conn in connections_3d:
 
660
  if new_con[0] == new_con[1]:
661
  continue
662
  if new_con not in new_connections:
663
+ new_connections.append(WireframeEdge(i1=new_con[0], i2=new_con[1]))
664
+ return Wireframe3D(new_vertices, new_connections)
665
 
666
 
667
+ def prune_not_connected(wireframe_3d: Wireframe3D, keep_largest=True):
668
  """
669
  Prune vertices not connected to anything. If keep_largest=True, also
670
  keep only the largest connected component in the graph.
671
  """
672
+ if len(wireframe_3d.vertices) == 0:
673
  return np.array([]), []
674
 
675
  # adjacency
676
  adj = defaultdict(set)
677
+ for edge in wireframe_3d.edges:
678
+ adj[edge.i1].add(edge.i2)
679
+ adj[edge.i2].add(edge.i1)
680
 
681
  # keep only vertices that appear in at least one edge
682
  used_idxs = set()
683
+ for edge in wireframe_3d.edges:
684
+ used_idxs.add(edge.i1)
685
+ used_idxs.add(edge.i2)
686
 
687
  if not used_idxs:
688
  return np.empty((0, 3)), []
 
693
  used_list = sorted(list(used_idxs))
694
  for new_id, old_id in enumerate(used_list):
695
  new_map[old_id] = new_id
696
+
697
+ new_vertices = [wireframe_3d.vertices[i] for i in used_list]
698
  new_conns = []
699
+ for edge in wireframe_3d.edges:
700
+ if edge.i1 in used_idxs and edge.i2 in used_idxs:
701
+ new_conns.append(edge)
702
+ return Wireframe3D(new_vertices, new_conns)
703
 
704
  # Otherwise find the largest connected component:
705
  visited = set()
 
733
  for new_id, old_id in enumerate(largest):
734
  new_map[old_id] = new_id
735
 
736
+ new_vertices = [wireframe_3d.vertices[i] for i in largest]
737
+ new_conns = [
738
+ WireframeEdge(i1=new_map[edge.i1], i2=new_map[edge.i2])
739
+ for edge in wireframe_3d.edges
740
+ if edge.i1 in largest and edge.i2 in largest
741
+ ]
742
 
743
  # remove duplicates
744
+ new_conns = list(set(new_conns))
745
+ return Wireframe3D(new_vertices, new_conns)
746
 
747
 
748
  def get_sparse_depth(colmap_rec, img_id_substring, depth_shape):
 
754
  H, W = depth_shape
755
 
756
  # 1) Find the matching COLMAP image
757
+ found_img: pycolmap.Image | None = None
758
  for img_id_c, col_img in colmap_rec.images.items():
759
  if img_id_substring in col_img.name:
760
  found_img = col_img
 
779
  z_vals = []
780
  for xyz in points_xyz:
781
  proj = found_img.project_point(xyz) # returns (u, v) in image coords or None
782
+ found_camera = found_img.camera
783
+ if found_camera is None:
784
+ print(f"Camera for {found_img.name} is None.")
785
+ return np.zeros((H, W), dtype=np.float32), False, found_img
786
+ cur_res = np.array([found_camera.height, found_camera.width])
787
  exp_res = np.array([H, W])
788
  proj = proj * exp_res / cur_res
789
 
 
797
  # We'll compute depth as Z in camera coords
798
  # from the world->cam transform col_img holds
799
  mat4x4 = np.eye(4)
800
+ if found_img.cam_from_world is None:
801
+ raise ValueError(f"Camera for {found_img.name} is None.")
802
  mat4x4[:3, :4] = found_img.cam_from_world.matrix()
803
  p_cam = mat4x4 @ np.array([xyz[0], xyz[1], xyz[2], 1.0])
804
  z_vals.append(p_cam[2] / p_cam[3])
 
891
  return depth_fitted, depth_sparse, True, col_img
892
 
893
 
894
+ def prune_too_far(wireframe_3d, colmap_rec, th=3.0):
895
  """
896
  Prune vertices that are too far from sparse point cloud
897
 
 
899
  xyz_sfm = []
900
  for k, v in colmap_rec.points3D.items():
901
  xyz_sfm.append(v.xyz)
902
+ xyz_sfm = np.array(xyz_sfm) # (M, 3)
903
+
904
+ vertices_np = np.array([v.xyz for v in wireframe_3d.vertices]) # (N, 3)
905
+
906
+ distmat = cdist(vertices_np, xyz_sfm) # (N, M)
907
+ mindist = distmat.min(axis=1) # (N,)
908
  mask = mindist <= th
909
+ vertices_new: list[WireframePoint3D] = [
910
+ v for v, m in zip(wireframe_3d.vertices, mask) if m
911
+ ]
912
+ old_idx_survived = np.arange(len(wireframe_3d.vertices))[mask]
913
+
914
+ old_to_new_idx = {old_idx_survived[i]: i for i in range(len(old_idx_survived))}
915
  connections_3d_new = [
916
+ WireframeEdge(i1=int(old_to_new_idx[conn.i1]), i2=int(old_to_new_idx[conn.i2]))
917
+ for conn in wireframe_3d.edges
918
+ if conn.i1 in old_to_new_idx and conn.i2 in old_to_new_idx
919
  ]
920
+ return Wireframe3D(
921
+ vertices_new,
922
+ connections_3d_new,
923
+ )
924
 
925
 
926
+ def predict_wireframe(entry) -> Wireframe3D:
927
  """
928
  Predict 3D wireframe from a dataset entry.
929
  """
930
  good_entry = convert_entry_to_human_readable(entry)
931
+ vert_edge_per_image: dict[int, Wireframe2DWith3D] = {}
932
  for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(
933
  zip(
934
  good_entry["gestalt"],
 
950
  gest_seg_np = np.array(gest_seg).astype(np.uint8)
951
 
952
  # Get 2D vertices and edges first
953
+ wireframe2d = get_vertices_and_edges_from_segmentation_contours(
954
  gest_seg_np, edge_th=10.0
955
  )
956
+ vertices = wireframe2d.vertices
957
+ connections = wireframe2d.edges
958
 
959
  # Check if we have enough to proceed
960
  if (len(vertices) < 2) or (len(connections) < 1):
961
  print(f"Not enough vertices or connections found in image {i}, skipping.")
962
+ vert_edge_per_image[i] = Wireframe2DWith3D(
963
+ wireframe2d=wireframe2d, vertices_3d=[]
964
+ )
965
  continue
966
 
967
  # Call the refactored function to get 3D points
 
969
  vertices, connections, depth, colmap_rec, img_id, ade_seg
970
  )
971
  # Store original 2D vertices, connections, and computed 3D points
972
+ vert_edge_per_image[i] = Wireframe2DWith3D(
973
+ wireframe2d=wireframe2d, vertices_3d=vertices_3d
974
+ )
975
 
976
  # Merge vertices from all images
977
+ wireframe_3d = merge_vertices_3d(vert_edge_per_image, 0.5)
978
+ # wireframe_3d_clean = prune_not_connected(wireframe_3d, keep_largest=False)
979
+ wireframe_3d_clean = wireframe_3d
980
+ wireframe_3d_clean = prune_too_far(wireframe_3d_clean, colmap_rec, th=4.0)
 
 
 
981
 
982
+ if (len(wireframe_3d_clean.vertices) < 2) or len(wireframe_3d_clean.edges) < 1:
983
  print("Not enough vertices or connections in the 3D vertices")
984
  return empty_solution()
985
 
986
+ return wireframe_3d_clean
process_sample.py CHANGED
@@ -2,10 +2,9 @@ import io
2
  import tempfile
3
  import zipfile
4
 
5
- import numpy as np
6
  import pycolmap
7
 
8
- from example_solutions_copy import predict_wireframe
9
 
10
 
11
  def read_colmap_rec(colmap_data):
@@ -17,21 +16,16 @@ def read_colmap_rec(colmap_data):
17
  return rec
18
 
19
 
20
- def empty_solution():
21
- """Return a minimal valid solution, i.e. 2 vertices and 1 edge."""
22
- return np.zeros((2, 3)), [(0, 1)]
23
-
24
-
25
  def process_sample(sample, handle_error=True):
26
  try:
27
- pred_vertices, pred_edges = predict_wireframe(sample)
28
  except Exception:
29
  if handle_error:
30
- pred_vertices, pred_edges = empty_solution()
31
  else:
32
  raise
33
  return {
34
  "order_id": sample["order_id"],
35
- "wf_vertices": pred_vertices.tolist(),
36
- "wf_edges": pred_edges,
37
  }
 
2
  import tempfile
3
  import zipfile
4
 
 
5
  import pycolmap
6
 
7
+ from example_solutions_copy import empty_solution, predict_wireframe
8
 
9
 
10
  def read_colmap_rec(colmap_data):
 
16
  return rec
17
 
18
 
 
 
 
 
 
19
  def process_sample(sample, handle_error=True):
20
  try:
21
+ pred_wireframe = predict_wireframe(sample)
22
  except Exception:
23
  if handle_error:
24
+ pred_wireframe = empty_solution()
25
  else:
26
  raise
27
  return {
28
  "order_id": sample["order_id"],
29
+ "wf_vertices": pred_wireframe.vertices_np,
30
+ "wf_edges": pred_wireframe.edges_np,
31
  }