IhorIvanyshyn01 committed on
Commit
29e5ee8
·
1 Parent(s): ffa2c1d

Enhance wireframe prediction filtering and clustering

Browse files
Files changed (1) hide show
  1. script.py +317 -3
script.py CHANGED
@@ -11,16 +11,330 @@ import json
11
  import numpy as np
12
  from datasets import load_dataset
13
  from typing import Dict
14
- from hoho2025.example_solutions import predict_wireframe
15
  from joblib import Parallel, delayed
 
 
 
 
 
 
 
 
 
 
16
 
17
  def empty_solution(sample):
18
  '''Return a minimal valid solution, i.e. 2 vertices and 1 edge.'''
19
  return np.zeros((2,3)), [(0, 1)]
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def predict_wireframe_safely(sample):
22
  try:
23
- pred_vertices, pred_edges = predict_wireframe(sample)
24
  except Exception as e:
25
  print (f"Failed due to {e}, returning empty solution")
26
  pred_vertices, pred_edges = empty_solution(sample)
@@ -155,4 +469,4 @@ if __name__ == "__main__":
155
  with open("submission.json", "w") as f:
156
  json.dump(solution, f)
157
 
158
- print("------------ Done ------------ ")
 
11
  import numpy as np
12
  from datasets import load_dataset
13
  from typing import Dict
 
14
  from joblib import Parallel, delayed
15
+ from sklearn.cluster import DBSCAN
16
+ from hoho2025 import example_solutions as hoho_example
17
+
18
# Thresholds for the filtered wireframe pipeline. Distances are expressed in
# the units of the 3D vertex coordinates (presumably metres -- TODO confirm).
VERTEX_MERGE_EPS = 0.5  # default DBSCAN ``eps`` when merging nearby 3D vertices
EDGE_MIN_LENGTH = 0.5  # default minimum 3D edge length kept by filter_edges_by_geometry
EDGE_MIN_SUPPORT_IMAGES = 2  # default minimum number of distinct images that must contain an edge
EDGE_MAX_ANGLE_DEG = 25.0  # maximum angle between a per-image edge direction and the mean direction
VERTEX_MAX_COLMAP_DIST = 4.0  # default maximum distance from a vertex to its nearest COLMAP 3D point
VERTEX_MIN_EDGE_DEGREE = 2  # default minimum number of surviving edges a vertex must keep
VERTEX_MIN_VIEW_COUNT = 2  # default minimum number of distinct source images per merged vertex
25
 
26
  def empty_solution(sample):
27
  '''Return a minimal valid solution, i.e. 2 vertices and 1 edge.'''
28
  return np.zeros((2,3)), [(0, 1)]
29
 
30
+ def _unit_vector(vector):
31
+ norm = np.linalg.norm(vector)
32
+ if norm == 0:
33
+ return None
34
+ return vector / norm
35
+
36
def merge_vertices_dbscan(vertices_3d, edges, eps=VERTEX_MERGE_EPS):
    """Collapse nearby 3D vertices into centroids and rewire the edges.

    Vertices are grouped with DBSCAN (``min_samples=1``) and every group is
    replaced by the mean of its members. Edges are rewritten in terms of the
    new centroid indices as ``(low, high)`` tuples; edges with an
    out-of-range endpoint, edges whose endpoints land in the same group, and
    repeated edges are dropped.

    Args:
        vertices_3d: array-like expected to have shape (N, 3).
        edges: iterable of index pairs into ``vertices_3d``.
        eps: DBSCAN neighbourhood radius.

    Returns:
        ``(merged_vertices, remapped_edges)``. When the input is empty or
        not of shape (N, 3), the input reshaped to (-1, 3) and an empty edge
        list are returned without clustering.
    """
    points = np.asarray(vertices_3d, dtype=float)
    well_formed = points.ndim == 2 and points.shape[1] == 3 and len(points) > 0
    if not well_formed:
        return points.reshape((-1, 3)), []

    labels = DBSCAN(eps=eps, min_samples=1).fit(points).labels_
    cluster_ids = np.unique(labels)
    new_index_of = {cluster: pos for pos, cluster in enumerate(cluster_ids)}

    centroids = np.stack(
        [points[labels == cluster].mean(axis=0) for cluster in cluster_ids],
        axis=0,
    )

    n_points = len(labels)
    # A dict keeps first-seen order while also giving set-like de-duplication.
    unique_edges = {}
    for a, b in edges:
        i, j = int(a), int(b)
        if not (0 <= i < n_points and 0 <= j < n_points):
            continue

        u = new_index_of[labels[i]]
        v = new_index_of[labels[j]]
        if u != v:
            unique_edges.setdefault((min(u, v), max(u, v)), None)

    return centroids, list(unique_edges)
73
+
74
def merge_vertices_dbscan_with_edge_stats(vert_edge_per_image, eps=VERTEX_MERGE_EPS):
    """Fuse per-image 3D vertices into shared corners and gather edge statistics.

    Args:
        vert_edge_per_image: mapping from an image key to a
            ``(vertices, connections, vertices_3d)`` triple. ``vertices`` is
            iterated for mappings with a ``'type'`` entry, ``connections``
            holds index pairs local to that image, and ``vertices_3d`` is
            reshaped to (-1, 3). Images without 3D vertices are ignored.
        eps: DBSCAN neighbourhood radius.

    Returns:
        A 5-tuple ``(merged_vertices, edges, vertex_view_count,
        edge_vote_count, edge_angle_ok)``:

        * ``merged_vertices``: centroid of every (DBSCAN cluster, apex flag)
          group.
        * ``edges``: ``(low, high)`` merged-index pairs in first-seen order.
        * ``vertex_view_count``: number of distinct source images per merged
          vertex.
        * ``edge_vote_count``: edge -> number of distinct images containing it.
        * ``edge_angle_ok``: edge -> result of
          ``edge_directions_are_consistent`` on its per-image directions.

        With no 3D vertices at all the result is
        ``(np.empty((0, 3)), [], empty int array, {}, {})``.
    """
    point_blocks = []
    apex_flags = []
    source_images = []
    raw_edges = []
    offset = 0

    for img_idx, (vertices, connections, vertices_3d) in vert_edge_per_image.items():
        block = np.asarray(vertices_3d, dtype=float).reshape((-1, 3))
        n_block = len(block)
        if n_block == 0:
            continue

        point_blocks.append(block)
        apex_flags.extend([int(v.get('type') == 'apex') for v in vertices])
        source_images.extend([img_idx] * n_block)

        # Shift image-local indices into the concatenated vertex array.
        for a, b in connections:
            a, b = int(a), int(b)
            if 0 <= a < n_block and 0 <= b < n_block:
                raw_edges.append((offset + a, offset + b, img_idx))

        offset += n_block

    if not point_blocks:
        return np.empty((0, 3)), [], np.array([], dtype=int), {}, {}

    points = np.concatenate(point_blocks, axis=0)
    labels = DBSCAN(eps=eps, min_samples=1).fit(points).labels_

    # Pairing the cluster label with the apex flag keeps apex and non-apex
    # vertices from collapsing into the same final corner.
    apex_flags = np.asarray(apex_flags, dtype=int)
    keys = [(int(label), int(flag)) for label, flag in zip(labels, apex_flags)]
    index_of_key = {key: pos for pos, key in enumerate(sorted(set(keys)))}
    n_merged = len(index_of_key)
    old_to_new = np.array([index_of_key[key] for key in keys], dtype=int)

    member_masks = [old_to_new == idx for idx in range(n_merged)]
    merged = np.stack([points[mask].mean(axis=0) for mask in member_masks], axis=0)

    source_images = np.asarray(source_images)
    view_count = np.array(
        [len(set(source_images[mask])) for mask in member_masks],
        dtype=int,
    )

    images_per_edge = {}
    directions_per_edge = {}
    for old_a, old_b, img_idx in raw_edges:
        new_a, new_b = int(old_to_new[old_a]), int(old_to_new[old_b])
        if new_a == new_b:
            continue

        edge = (min(new_a, new_b), max(new_a, new_b))
        observed = _unit_vector(points[old_b] - points[old_a])
        if observed is None:
            continue

        reference = _unit_vector(merged[edge[1]] - merged[edge[0]])
        if reference is None:
            continue
        # Orient every observation along the merged low -> high direction.
        if np.dot(observed, reference) < 0:
            observed = -observed

        images_per_edge.setdefault(edge, set()).add(img_idx)
        directions_per_edge.setdefault(edge, []).append(observed)

    edge_vote_count = {edge: len(images) for edge, images in images_per_edge.items()}
    edge_angle_ok = {
        edge: edge_directions_are_consistent(dirs, EDGE_MAX_ANGLE_DEG)
        for edge, dirs in directions_per_edge.items()
    }

    return merged, list(images_per_edge), view_count, edge_vote_count, edge_angle_ok
159
+
160
def edge_directions_are_consistent(directions, max_angle_deg=EDGE_MAX_ANGLE_DEG):
    """Tell whether a set of edge direction vectors point the same way.

    Every direction is compared with the normalised mean of all directions;
    the set is consistent when each dot product with that mean is at least
    ``cos(max_angle_deg)``. No sign flipping happens here, so inputs are
    expected to be sign-aligned already.

    Args:
        directions: sequence of direction vectors.
        max_angle_deg: largest accepted angle to the mean direction, in degrees.

    Returns:
        ``True`` for zero or one direction, ``False`` when the mean direction
        cannot be normalised, otherwise the result of the angle test.
    """
    if len(directions) < 2:
        return True

    stacked = np.asarray(directions, dtype=float)
    axis = _unit_vector(stacked.mean(axis=0))
    if axis is None:
        return False

    cos_limit = np.cos(np.deg2rad(max_angle_deg))
    cosines = stacked @ axis
    return bool((cosines >= cos_limit).all())
172
+
173
def filter_edges_by_geometry(vertices_3d, edges, edge_vote_count=None, edge_angle_ok=None,
                             min_len=EDGE_MIN_LENGTH, min_support=EDGE_MIN_SUPPORT_IMAGES):
    """Keep only edges that are long enough, well supported and direction-consistent.

    Args:
        vertices_3d: array-like of vertex coordinates indexed by the edges.
        edges: iterable of vertex index pairs.
        edge_vote_count: optional mapping ``(low, high) -> support count``;
            edges missing from it count as 0. Ignored when ``None``.
        edge_angle_ok: optional mapping ``(low, high) -> bool``; edges missing
            from it count as consistent. Ignored when ``None``.
        min_len: minimum Euclidean edge length.
        min_support: minimum support count required when
            ``edge_vote_count`` is given.

    Returns:
        List of accepted ``(low, high)`` edges in input order, without
        duplicates, self-loops or out-of-range indices. Empty when there are
        no vertices.
    """
    points = np.asarray(vertices_3d, dtype=float)
    n_points = len(points)
    if n_points == 0:
        return []

    def _passes(edge):
        """Apply the length, support and direction tests to one edge."""
        lo, hi = edge
        if np.linalg.norm(points[hi] - points[lo]) < min_len:
            return False
        if edge_vote_count is not None and edge_vote_count.get(edge, 0) < min_support:
            return False
        if edge_angle_ok is not None and not edge_angle_ok.get(edge, True):
            return False
        return True

    kept = []
    accepted = set()
    for a, b in edges:
        a, b = int(a), int(b)
        in_range = 0 <= a < n_points and 0 <= b < n_points
        if not in_range or a == b:
            continue

        edge = (a, b) if a < b else (b, a)
        if edge in accepted or not _passes(edge):
            continue

        accepted.add(edge)
        kept.append(edge)

    return kept
206
+
207
def prune_bad_vertices(vertices_3d, edges, colmap_rec, vertex_view_count,
                       max_colmap_dist=VERTEX_MAX_COLMAP_DIST,
                       min_edge_degree=VERTEX_MIN_EDGE_DEGREE,
                       min_view_count=VERTEX_MIN_VIEW_COUNT):
    """Remove vertices with weak COLMAP, edge-degree, or multi-view support.

    A vertex survives only when all of the following hold:

    * its nearest point among ``colmap_rec.points3D`` lies within
      ``max_colmap_dist`` (this test is skipped when there are no points);
    * its ``vertex_view_count`` entry is at least ``min_view_count``;
    * it has at least ``min_edge_degree`` edges to other surviving vertices.
      This test is repeated until nothing changes, because removing a vertex
      lowers the degree of its neighbours.

    Args:
        vertices_3d: array-like of vertex coordinates, one row per vertex.
        edges: iterable of vertex index pairs; self-loops and out-of-range
            pairs are ignored.
        colmap_rec: object exposing ``points3D``, a mapping whose values have
            an ``xyz`` attribute.
        vertex_view_count: per-vertex view counts, aligned with ``vertices_3d``.
        max_colmap_dist: maximum distance to the nearest COLMAP point.
        min_edge_degree: minimum number of surviving incident edges.
        min_view_count: minimum view count.

    Returns:
        ``(vertices, edges)`` restricted to the surviving vertices and
        re-indexed; edges are sorted index tuples without duplicates. Returns
        ``(np.empty((0, 3)), [])`` when no vertex survives.
    """
    vertices_3d = np.asarray(vertices_3d, dtype=float)
    vertex_view_count = np.asarray(vertex_view_count, dtype=int)
    n_vertices = len(vertices_3d)
    if n_vertices == 0:
        return vertices_3d.reshape((-1, 3)), []

    valid_edges = []
    for a, b in edges:
        a, b = int(a), int(b)
        if 0 <= a < n_vertices and 0 <= b < n_vertices and a != b:
            valid_edges.append((a, b))

    colmap_mask = np.ones(n_vertices, dtype=bool)
    xyz_sfm = [point.xyz for point in colmap_rec.points3D.values()]
    if xyz_sfm:
        xyz_sfm = np.asarray(xyz_sfm, dtype=float)
        # Work through the vertices in chunks: a single (V, P, 3) broadcast
        # needs memory proportional to vertices x SfM points, which can be
        # very large for dense reconstructions. Each row's result is the
        # same as with the one-shot broadcast.
        max_pairs_per_chunk = 1_000_000
        chunk = max(1, max_pairs_per_chunk // len(xyz_sfm))
        min_dist = np.empty(n_vertices, dtype=float)
        for start in range(0, n_vertices, chunk):
            block = vertices_3d[start:start + chunk]
            diff = block[:, None, :] - xyz_sfm[None, :, :]
            min_dist[start:start + chunk] = np.sqrt((diff ** 2).sum(axis=-1)).min(axis=1)
        colmap_mask = min_dist <= max_colmap_dist

    view_mask = vertex_view_count >= min_view_count
    keep_mask = colmap_mask & view_mask

    # Iteratively peel off vertices whose degree among survivors is too low.
    edge_array = np.asarray(valid_edges, dtype=int).reshape((-1, 2))
    while True:
        alive = keep_mask[edge_array[:, 0]] & keep_mask[edge_array[:, 1]]
        degree = np.bincount(edge_array[alive].ravel(), minlength=n_vertices)

        next_keep_mask = keep_mask & (degree >= min_edge_degree)
        if np.array_equal(next_keep_mask, keep_mask):
            break
        keep_mask = next_keep_mask

    kept_indices = np.flatnonzero(keep_mask)
    if kept_indices.size == 0:
        return np.empty((0, 3)), []
    old_to_new = {int(old_idx): new_idx for new_idx, old_idx in enumerate(kept_indices)}

    new_edges = []
    seen_edges = set()
    for a, b in valid_edges:
        if a not in old_to_new or b not in old_to_new:
            continue
        edge = tuple(sorted((old_to_new[a], old_to_new[b])))
        if edge in seen_edges:
            continue
        seen_edges.add(edge)
        new_edges.append(edge)

    return vertices_3d[kept_indices], new_edges
268
+
269
def predict_wireframe_filtered(sample, verbose=False):
    """Baseline prediction with DBSCAN vertex merge plus stricter edge filtering.

    Steps, per image of the converted entry: extract 2D vertices and
    connections from the gestalt segmentation, filter them against the ADE
    segmentation, and lift them to 3D with the depth map and the COLMAP
    reconstruction. The per-image results are then merged with
    ``merge_vertices_dbscan_with_edge_stats``, filtered with
    ``filter_edges_by_geometry``, pruned with ``prune_bad_vertices`` and
    ``hoho_example.prune_not_connected``.

    Args:
        sample: dataset entry accepted by
            ``hoho_example.convert_entry_to_human_readable``.
        verbose: print diagnostics about skipped images and empty results.

    Returns:
        ``(vertices, edges)`` where ``edges`` is a list of ``(int, int)``
        tuples. Falls back to ``empty_solution(sample)`` when no COLMAP
        reconstruction was read or fewer than two vertices / one edge remain
        at either filtering stage.
    """
    good_entry = hoho_example.convert_entry_to_human_readable(sample)
    vert_edge_per_image = {}
    colmap_rec = None

    per_image_inputs = zip(
        good_entry['gestalt'],
        good_entry['depth'],
        good_entry['image_ids'],
        good_entry['ade'],
    )
    for i, (gest, depth, img_id, ade_seg) in enumerate(per_image_inputs):
        # The reconstruction is the same for every image, so fetch it once.
        # It is still fetched inside the loop so an entry without images
        # never touches these keys.
        if colmap_rec is None:
            if 'colmap' in good_entry:
                colmap_rec = good_entry['colmap']
            else:
                colmap_rec = good_entry['colmap_binary']

        # Convert the depth map once; only its (width, height) is needed here.
        depth_np = np.array(depth)
        depth_size = (depth_np.shape[1], depth_np.shape[0])

        gest_seg_np = np.array(gest.resize(depth_size)).astype(np.uint8)
        vertices, connections = hoho_example.get_vertices_and_edges_from_segmentation(
            gest_seg_np, edge_th=10.
        )

        ade_seg_np = np.array(ade_seg.resize(depth_size)).astype(np.uint8)
        vertices, connections = hoho_example.filter_vertices_by_background(
            vertices, connections, ade_seg_np
        )

        if (len(vertices) < 2) or (len(connections) < 1):
            if verbose:
                print(f'Not enough vertices or connections found in image {i}, skipping.')
            vert_edge_per_image[i] = [], [], np.empty((0, 3))
            continue

        vertices_3d = hoho_example.create_3d_wireframe_single_image(
            vertices, connections, depth, colmap_rec, img_id, ade_seg, verbose=verbose
        )
        vert_edge_per_image[i] = vertices, connections, vertices_3d

    all_3d_vertices, connections_3d, vertex_view_count, edge_vote_count, edge_angle_ok = (
        merge_vertices_dbscan_with_edge_stats(vert_edge_per_image, VERTEX_MERGE_EPS)
    )
    connections_3d = filter_edges_by_geometry(
        all_3d_vertices,
        connections_3d,
        edge_vote_count=edge_vote_count,
        edge_angle_ok=edge_angle_ok,
    )
    if colmap_rec is None or len(all_3d_vertices) < 2 or len(connections_3d) < 1:
        return empty_solution(sample)

    all_3d_vertices_clean, connections_3d_clean = prune_bad_vertices(
        all_3d_vertices, connections_3d, colmap_rec, vertex_view_count
    )
    all_3d_vertices_clean, connections_3d_clean = hoho_example.prune_not_connected(
        all_3d_vertices_clean, connections_3d_clean, keep_largest=False
    )

    if (len(all_3d_vertices_clean) < 2) or len(connections_3d_clean) < 1:
        if verbose:
            print('Not enough vertices or connections in the 3D vertices')
        return empty_solution(sample)

    # Plain Python ints keep the edge list JSON-serialisable.
    connections_3d_clean = [(int(a), int(b)) for a, b in connections_3d_clean]
    return all_3d_vertices_clean, connections_3d_clean
334
+
335
  def predict_wireframe_safely(sample):
336
  try:
337
+ pred_vertices, pred_edges = predict_wireframe_filtered(sample)
338
  except Exception as e:
339
  print (f"Failed due to {e}, returning empty solution")
340
  pred_vertices, pred_edges = empty_solution(sample)
 
469
  with open("submission.json", "w") as f:
470
  json.dump(solution, f)
471
 
472
+ print("------------ Done ------------ ")