jskvrna commited on
Commit
d2ea57b
·
1 Parent(s): f4eb848

Initial commit

Browse files
Files changed (1) hide show
  1. test.py +401 -0
test.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from hoho2025.vis import plot_all_modalities
3
+ from hoho2025.viz3d import *
4
+ import pycolmap
5
+ import tempfile,zipfile
6
+ import io
7
+ import open3d as o3d
8
+
9
+ from hoho2025.example_solutions import predict_wireframe
10
+
11
+ def read_colmap_rec(colmap_data):
12
+ with tempfile.TemporaryDirectory() as tmpdir:
13
+ with zipfile.ZipFile(io.BytesIO(colmap_data), "r") as zf:
14
+ zf.extractall(tmpdir) # unpacks cameras.txt, images.txt, etc. to tmpdir
15
+ # Now parse with pycolmap
16
+ rec = pycolmap.Reconstruction(tmpdir)
17
+ return rec
18
+
19
+ def _plotly_rgb_to_normalized_o3d_color(color_val) -> list[float]:
20
+ """
21
+ Converts Plotly-style color (str 'rgb(r,g,b)' or tuple (r,g,b))
22
+ to normalized [r/255, g/255, b/255] for Open3D.
23
+ """
24
+ if isinstance(color_val, str):
25
+ if color_val.startswith('rgba'): # e.g. 'rgba(255,0,0,0.5)'
26
+ parts = color_val[5:-1].split(',')
27
+ return [int(p.strip())/255.0 for p in parts[:3]] # Ignore alpha
28
+ elif color_val.startswith('rgb'): # e.g. 'rgb(255,0,0)'
29
+ parts = color_val[4:-1].split(',')
30
+ return [int(p.strip())/255.0 for p in parts]
31
+ else:
32
+ # Basic color names are not directly supported by this helper for Open3D.
33
+ # Plotly might resolve them, but Open3D needs explicit RGB.
34
+ # Consider adding a name-to-RGB mapping if needed.
35
+ raise ValueError(f"Unsupported color string format for Open3D conversion: {color_val}. Expected 'rgb(...)' or 'rgba(...)'.")
36
+ elif isinstance(color_val, (list, tuple)) and len(color_val) == 3:
37
+ # Assuming input tuple/list is in 0-255 range (e.g., from edge_color_mapping)
38
+ return [c/255.0 for c in color_val]
39
+ raise ValueError(f"Unsupported color type for Open3D conversion: {type(color_val)}. Expected string or 3-element tuple/list.")
40
+
41
+
42
+ def plot_reconstruction_local(
43
+ fig: go.Figure,
44
+ rec: pycolmap.Reconstruction, # Added type hint
45
+ color: str = 'rgb(0, 0, 255)',
46
+ name: Optional[str] = None,
47
+ points: bool = True,
48
+ cameras: bool = True,
49
+ cs: float = 1.0,
50
+ single_color_points=False,
51
+ camera_color='rgba(0, 255, 0, 0.5)',
52
+ crop_outliers: bool = False):
53
+ # rec is a pycolmap.Reconstruction object
54
+ # Filter outliers
55
+ xyzs = []
56
+ rgbs = []
57
+ # Iterate over rec.points3D
58
+ for k, p3D in rec.points3D.items():
59
+ #print (p3D)
60
+ xyzs.append(p3D.xyz)
61
+ rgbs.append(p3D.color)
62
+
63
+ xyzs = np.array(xyzs)
64
+ rgbs = np.array(rgbs)
65
+
66
+ # Crop outliers if requested
67
+ if crop_outliers and len(xyzs) > 0:
68
+ # Calculate distances from origin
69
+ distances = np.linalg.norm(xyzs, axis=1)
70
+ # Find threshold at 98th percentile (removing 2% furthest points)
71
+ threshold = np.percentile(distances, 98)
72
+ # Filter points
73
+ mask = distances <= threshold
74
+ xyzs = xyzs[mask]
75
+ rgbs = rgbs[mask]
76
+ print(f"Cropped outliers: removed {np.sum(~mask)} out of {len(mask)} points ({np.sum(~mask)/len(mask)*100:.2f}%)")
77
+
78
+ if points and len(xyzs) > 0:
79
+ pcd = o3d.geometry.PointCloud()
80
+ pcd.points = o3d.utility.Vector3dVector(xyzs)
81
+
82
+ # Normalize RGB colors from [0, 255] to [0, 1] for Open3D
83
+ # and ensure rgbs is not empty before normalization
84
+ if rgbs.size > 0:
85
+ normalized_rgbs = rgbs / 255.0
86
+ pcd.colors = o3d.utility.Vector3dVector(normalized_rgbs)
87
+
88
+ # Original Plotly plot_points call is replaced by Open3D visualization:
89
+ # plot_points(fig, xyzs, color=color if single_color_points else rgbs, ps=1, name=name)
90
+
91
+ # This code assumes it's placed within the plot_reconstruction_local function,
92
+ # after point cloud processing, and that a list `geometries` (List[o3d.geometry.Geometry])
93
+ # is defined in the function's scope to collect all geometries.
94
+ # It uses arguments `cameras`, `rec`, `cs`, `camera_color` from the function signature.
95
+ # The helper `_plotly_rgb_to_normalized_o3d_color` is assumed to be available.
96
+
97
+ if cameras: # Check if camera visualization is enabled
98
+ try:
99
+ # Convert Plotly-style camera_color string to normalized RGB for Open3D
100
+ cam_color_normalized = _plotly_rgb_to_normalized_o3d_color(camera_color)
101
+ except ValueError as e:
102
+ print(f"Warning: Invalid camera_color '{camera_color}'. Using default green. Error: {e}")
103
+ cam_color_normalized = [0.0, 1.0, 0.0] # Default to green
104
+
105
+ geometries = []
106
+
107
+ for image_id, image_data in rec.images.items():
108
+ # Get camera object and its intrinsic matrix K
109
+ cam = rec.cameras[image_data.camera_id]
110
+ K = cam.calibration_matrix()
111
+
112
+ # Validate intrinsics (e.g., focal length check from original code)
113
+ if K[0, 0] > 5000 or K[1, 1] > 5000:
114
+ print(f"Skipping camera for image {image_id} due to large focal length (fx={K[0,0]}, fy={K[1,1]}).")
115
+ continue
116
+
117
+ # Get camera pose (T_world_cam = T_cam_world.inverse())
118
+ # image_data.cam_from_world is T_cam_world (camera coordinates from world coordinates)
119
+ T_world_cam = image_data.cam_from_world.inverse()
120
+ R_wc = T_world_cam.rotation.matrix() # Rotation: camera frame to world frame
121
+ t_wc = T_world_cam.translation # Origin of camera frame in world coordinates (pyramid apex)
122
+
123
+ W, H = float(cam.width), float(cam.height)
124
+
125
+ # Skip if camera dimensions are invalid
126
+ if W <= 0 or H <= 0:
127
+ print(f"Skipping camera for image {image_id} due to invalid dimensions (W={W}, H={H}).")
128
+ continue
129
+
130
+ # Define image plane corners in pixel coordinates (top-left, top-right, bottom-right, bottom-left)
131
+ img_corners_px = np.array([
132
+ [0, 0], [W, 0], [W, H], [0, H]
133
+ ], dtype=float)
134
+
135
+ # Convert pixel corners to homogeneous coordinates
136
+ img_corners_h = np.hstack([img_corners_px, np.ones((4, 1))])
137
+
138
+ try:
139
+ K_inv = np.linalg.inv(K)
140
+ except np.linalg.LinAlgError:
141
+ print(f"Skipping camera for image {image_id} due to non-invertible K matrix.")
142
+ continue
143
+
144
+ # Unproject pixel corners to normalized camera coordinates (points on z=1 plane in camera frame)
145
+ cam_coords_norm = (K_inv @ img_corners_h.T).T
146
+
147
+ # Scale these points by 'cs' (camera scale factor, determines frustum size)
148
+ # These points are ( (x/z)*cs, (y/z)*cs, cs ) in the camera's coordinate system.
149
+ cam_coords_scaled = cam_coords_norm * cs
150
+
151
+ # Transform scaled base corners from camera coordinates to world coordinates
152
+ world_coords_base = (R_wc @ cam_coords_scaled.T).T + t_wc
153
+
154
+ # Define vertices for the camera pyramid visualization
155
+ # Vertex 0 is the apex (camera center), vertices 1-4 are the base corners
156
+ pyramid_vertices = np.vstack((t_wc, world_coords_base))
157
+ if not pyramid_vertices.flags['C_CONTIGUOUS']:
158
+ pyramid_vertices = np.ascontiguousarray(pyramid_vertices, dtype=np.float64)
159
+ elif pyramid_vertices.dtype != np.float64:
160
+ pyramid_vertices = np.asarray(pyramid_vertices, dtype=np.float64)
161
+
162
+ # Define lines for the pyramid: 4 from apex to base, 4 for the base rectangle
163
+ lines = np.array([
164
+ [0, 1], [0, 2], [0, 3], [0, 4], # Apex to base corners
165
+ [1, 2], [2, 3], [3, 4], [4, 1] # Base rectangle
166
+ ])
167
+
168
+ if not lines.flags['C_CONTIGUOUS']:
169
+ lines = np.ascontiguousarray(lines, dtype=np.int32)
170
+ elif lines.dtype != np.int32:
171
+ lines = np.asarray(lines, dtype=np.int32)
172
+
173
+ # Create Open3D LineSet object for the camera pyramid
174
+ camera_pyramid_lineset = o3d.geometry.LineSet()
175
+ camera_pyramid_lineset.points = o3d.utility.Vector3dVector(pyramid_vertices)
176
+ camera_pyramid_lineset.lines = o3d.utility.Vector2iVector(lines)
177
+
178
+ # Add the camera visualization to the list of geometries to be rendered
179
+ geometries.append(camera_pyramid_lineset)
180
+
181
+ else:
182
+ geometries = []
183
+
184
+ return pcd, geometries
185
+
186
+ def plot_wireframe_local(
187
+ fig: go.Figure, # This argument is no longer used for plotting by this function.
188
+ vertices: np.ndarray,
189
+ edges: np.ndarray,
190
+ classifications: np.ndarray = None,
191
+ color: str = 'rgb(0, 0, 255)', # Default color for vertices and unclassified/default edges.
192
+ name: Optional[str] = None, # No direct equivalent for Open3D geometry list's name/legend.
193
+ **kwargs) -> list: # Returns a list of o3d.geometry.Geometry objects.
194
+ """
195
+ Generates Open3D geometries for a wireframe: a PointCloud for vertices
196
+ and a LineSet for edges.
197
+
198
+ Args:
199
+ fig: Plotly figure object (no longer used for plotting by this function).
200
+ vertices: Numpy array of vertex coordinates (N, 3).
201
+ edges: Numpy array of edge connections (M, 2), indices into vertices.
202
+ classifications: Optional numpy array of classifications for edges.
203
+ color: Default color string (e.g., 'rgb(0,0,255)') for vertices and
204
+ for edges if classifications are not provided or don't match.
205
+ name: Optional name (unused in Open3D context here).
206
+ **kwargs: Additional keyword arguments (unused).
207
+
208
+ Returns:
209
+ A list of Open3D geometry objects. Empty if no vertices.
210
+ Note: Plotly-specific 'ps' (point size) and line width are not
211
+ directly translated. These are typically visualizer render options in Open3D.
212
+ """
213
+ open3d_geometries = []
214
+
215
+ # Ensure gt_vertices is numpy array, C-contiguous, and float64
216
+ # np.asarray avoids a copy if 'vertices' is already a suitable ndarray.
217
+ gt_vertices = np.asarray(vertices)
218
+ if not gt_vertices.flags['C_CONTIGUOUS'] or gt_vertices.dtype != np.float64:
219
+ gt_vertices = np.ascontiguousarray(gt_vertices, dtype=np.float64)
220
+
221
+ # Ensure gt_connections is numpy array, C-contiguous, and int32
222
+ gt_connections = np.asarray(edges)
223
+ if not gt_connections.flags['C_CONTIGUOUS'] or gt_connections.dtype != np.int32:
224
+ gt_connections = np.ascontiguousarray(gt_connections, dtype=np.int32)
225
+
226
+ if gt_vertices.size == 0:
227
+ return open3d_geometries
228
+
229
+ # 1. Create PointCloud for Vertices
230
+ try:
231
+ vertex_rgb_normalized = _plotly_rgb_to_normalized_o3d_color(color)
232
+ except ValueError as e:
233
+ print(f"Warning: Could not parse vertex color '{color}'. Using default black. Error: {e}")
234
+ vertex_rgb_normalized = [0.0, 0.0, 0.0] # Default to black on error
235
+
236
+ vertex_pcd = o3d.geometry.PointCloud()
237
+ # gt_vertices is now C-contiguous and float64
238
+ vertex_pcd.points = o3d.utility.Vector3dVector(gt_vertices)
239
+
240
+ num_vertices = len(gt_vertices)
241
+ if num_vertices > 0:
242
+ # Create vertex_colors_np with correct dtype
243
+ vertex_colors_np = np.array([vertex_rgb_normalized] * num_vertices, dtype=np.float64)
244
+ # Ensure C-contiguity (dtype is already float64 from np.array construction)
245
+ # This check is a safeguard, as np.array from a list of lists with specified dtype is usually contiguous.
246
+ if not vertex_colors_np.flags['C_CONTIGUOUS']:
247
+ vertex_colors_np = np.ascontiguousarray(vertex_colors_np) # Preserves dtype
248
+ vertex_pcd.colors = o3d.utility.Vector3dVector(vertex_colors_np)
249
+ open3d_geometries.append(vertex_pcd)
250
+
251
+ # 2. Create LineSet for Edges
252
+ if gt_connections.size > 0 and num_vertices > 0: # Edges need vertices
253
+ line_set = o3d.geometry.LineSet()
254
+ # gt_vertices is already C-contiguous and float64
255
+ line_set.points = o3d.utility.Vector3dVector(gt_vertices)
256
+ # gt_connections is already C-contiguous and int32
257
+ line_set.lines = o3d.utility.Vector2iVector(gt_connections)
258
+
259
+ line_colors_list_normalized = []
260
+ if classifications is not None and len(classifications) == len(gt_connections):
261
+ # Assuming EDGE_CLASSES_BY_ID and edge_color_mapping are defined in the global scope
262
+ # or imported, as implied by the original code structure.
263
+ for c_idx in classifications:
264
+ try:
265
+ color_tuple_255 = edge_color_mapping[EDGE_CLASSES_BY_ID[c_idx]]
266
+ line_colors_list_normalized.append(_plotly_rgb_to_normalized_o3d_color(color_tuple_255))
267
+ except KeyError: # Handle cases where classification ID or mapping is not found
268
+ print(f"Warning: Classification ID {c_idx} or its mapping not found. Using default color.")
269
+ line_colors_list_normalized.append(vertex_rgb_normalized) # Fallback to default vertex color
270
+ except Exception as e:
271
+ print(f"Warning: Error processing classification color for index {c_idx}. Using default. Error: {e}")
272
+ line_colors_list_normalized.append(vertex_rgb_normalized) # Fallback
273
+ else:
274
+ # Use the default 'color' for all lines if no classifications or mismatch
275
+ default_line_rgb_normalized = vertex_rgb_normalized # Same as vertex color by default
276
+ for _ in range(len(gt_connections)):
277
+ line_colors_list_normalized.append(default_line_rgb_normalized)
278
+
279
+ if line_colors_list_normalized: # Check if list is not empty
280
+ # Create line_colors_np with correct dtype
281
+ line_colors_np = np.array(line_colors_list_normalized, dtype=np.float64)
282
+ # Ensure C-contiguity (dtype is already float64)
283
+ # Safeguard, similar to vertex_colors_np.
284
+ if not line_colors_np.flags['C_CONTIGUOUS']:
285
+ line_colors_np = np.ascontiguousarray(line_colors_np) # Preserves dtype
286
+ line_set.colors = o3d.utility.Vector3dVector(line_colors_np)
287
+
288
+ open3d_geometries.append(line_set)
289
+
290
+ return open3d_geometries
291
+
292
+ def plot_bpo_cameras_from_entry_local(fig: go.Figure, entry: dict, idx = None, camera_scale_factor: float = 1.0):
293
+ def cam2world_to_world2cam(R, t):
294
+ rt = np.eye(4)
295
+ rt[:3,:3] = R
296
+ rt[:3,3] = t.reshape(-1)
297
+ rt = np.linalg.inv(rt)
298
+ return rt[:3,:3], rt[:3,3]
299
+ geometries = []
300
+ for i in range(len(entry['R'])):
301
+ if idx is not None and i != idx:
302
+ continue
303
+
304
+ # Parameters for this camera visualization
305
+ # current_cam_size = 1.0 # Original 'size = 1.' - Replaced by camera_scale_factor
306
+ current_cam_color_str = 'rgb(0, 0, 255)' # Original 'color = 'rgb(0, 0, 255)''
307
+
308
+ # Load camera parameters from entry
309
+ K_matrix = np.array(entry['K'][i])
310
+ R_orig = np.array(entry['R'][i])
311
+ t_orig = np.array(entry['t'][i])
312
+
313
+ # Apply cam2world_to_world2cam transformation as in original snippet
314
+ # This R_transformed, t_transformed will be used to place the camera geometry
315
+ R_transformed, t_transformed = cam2world_to_world2cam(R_orig, t_orig)
316
+
317
+ # Image dimensions from K matrix (cx, cy are K[0,2], K[1,2])
318
+ # Ensure W_img and H_img are derived correctly. Assuming K[0,2] and K[1,2] are principal points cx, cy.
319
+ # If K is [fx, 0, cx; 0, fy, cy; 0, 0, 1], then W_img and H_img might need to come from elsewhere
320
+ # or be estimated if not directly available. The original code used K[0,2]*2, K[1,2]*2.
321
+ # This implies cx = W/2, cy = H/2.
322
+ W_img, H_img = K_matrix[0, 2] * 2, K_matrix[1, 2] * 2
323
+ if W_img <= 0 or H_img <= 0:
324
+ # Attempt to get W, H from cam.width, cam.height if available in entry, like in colmap
325
+ # This part depends on the structure of 'entry'. For now, stick to original logic.
326
+ print(f"Warning: Camera {i} has invalid dimensions (W={W_img}, H={H_img}) based on K. Skipping.")
327
+ continue
328
+
329
+ # Define image plane corners in pixel coordinates (top-left, top-right, bottom-right, bottom-left)
330
+ corners_px = np.array([[0, 0], [W_img, 0], [W_img, H_img], [0, H_img]], dtype=float)
331
+
332
+ # Removed scale_val, image_extent, world_extent calculations.
333
+ # The scaling is now directly controlled by camera_scale_factor.
334
+
335
+ try:
336
+ K_inv = np.linalg.inv(K_matrix)
337
+ except np.linalg.LinAlgError:
338
+ print(f"Warning: K matrix for camera {i} is singular. Skipping this camera.")
339
+ continue
340
+
341
+ # Unproject pixel corners to homogeneous camera coordinates.
342
+ # Assuming to_homogeneous converts (N,2) pixel coords to (N,3) homogeneous coords [px, py, 1].
343
+ # These points are on the z=1 plane in camera coordinates.
344
+ corners_cam_homog = to_homogeneous(corners_px) @ K_inv.T
345
+
346
+ # Scale these points by camera_scale_factor.
347
+ # This makes the frustum base at z=camera_scale_factor in camera coordinates.
348
+ scaled_cam_points = corners_cam_homog * camera_scale_factor
349
+
350
+ # Transform scaled camera points to world coordinates using R_transformed, t_transformed
351
+ world_coords_base = scaled_cam_points @ R_transformed.T + t_transformed
352
+
353
+ # Apex of the pyramid is t_transformed
354
+ apex_world = t_transformed.reshape(1, 3)
355
+
356
+ # Vertices for Open3D LineSet: apex (vertex 0) + 4 base corners (vertices 1-4)
357
+ pyramid_vertices_np = np.vstack((apex_world, world_coords_base))
358
+ if not pyramid_vertices_np.flags['C_CONTIGUOUS'] or pyramid_vertices_np.dtype != np.float64:
359
+ pyramid_vertices_np = np.ascontiguousarray(pyramid_vertices_np, dtype=np.float64)
360
+
361
+ # Lines for the pyramid: 4 from apex to base, 4 for the base rectangle
362
+ lines_np = np.array([
363
+ [0, 1], [0, 2], [0, 3], [0, 4], # Apex to base corners
364
+ [1, 2], [2, 3], [3, 4], [4, 1] # Base rectangle (closed loop)
365
+ ], dtype=np.int32)
366
+
367
+ # Create Open3D LineSet object for the camera pyramid
368
+ camera_lineset = o3d.geometry.LineSet()
369
+ camera_lineset.points = o3d.utility.Vector3dVector(pyramid_vertices_np)
370
+ lines_np = np.ascontiguousarray(lines_np, dtype=np.int32)
371
+ camera_lineset.lines = o3d.utility.Vector2iVector(lines_np)
372
+
373
+ # Color the LineSet
374
+ try:
375
+ o3d_color = _plotly_rgb_to_normalized_o3d_color(current_cam_color_str)
376
+ except ValueError as e:
377
+ print(f"Warning: Invalid camera color string '{current_cam_color_str}' for camera {i}. Using default blue. Error: {e}")
378
+ o3d_color = [0.0, 0.0, 1.0] # Default to blue
379
+ camera_lineset.colors = o3d.utility.Vector3dVector(np.array([o3d_color] * len(lines_np), dtype=np.float64))
380
+
381
+ geometries.append(camera_lineset)
382
+
383
+ return geometries
384
+
385
+
386
+ ds = load_dataset("usm3d/hoho25k", streaming=True, trust_remote_code=True)
387
+ for a in ds['train']:
388
+ colmap = read_colmap_rec(a['colmap_binary'])
389
+
390
+ pred_vertices, pred_edges = predict_wireframe(a)
391
+
392
+ pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
393
+ wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
394
+ wireframe2 = plot_wireframe_local(None, pred_vertices, pred_edges, None, color='rgb(255, 0, 0)')
395
+ bpo_cams = plot_bpo_cameras_from_entry_local(None, a)
396
+
397
+ print(len(geometries), len(bpo_cams))
398
+
399
+ visu_all = [pcd] + geometries + wireframe + bpo_cams + wireframe2
400
+ o3d.visualization.draw_geometries(visu_all, window_name="3D Reconstruction")
401
+