# This script is designed for 3D wireframe reconstruction, primarily focusing on # buildings, using multi-view imagery and associated 3D data. # It leverages COLMAP reconstructions, depth maps, and semantic segmentations # (ADE20k and Gestalt) to identify and predict structural elements. # Core tasks include: # - Processing and aligning 2D image data (segmentations, depth) with 3D COLMAP point clouds. # - Extracting initial 2D/3D vertex candidates from segmentation maps. # - Generating local point cloud patches around these candidates. # - Employing machine learning models (e.g., PointNet variants) to refine vertex locations # and classify potential edges between them. # - Optionally, generating datasets of these patches for training ML models. # - Merging information from multiple views to produce a final 3D wireframe. import numpy as np from typing import Tuple, List from hoho2025.example_solutions import empty_solution, read_colmap_rec, get_vertices_and_edges_from_segmentation, get_house_mask, fit_scale_robust_median, get_uv_depth, merge_vertices_3d, prune_not_connected, prune_too_far, point_to_segment_dist from hoho2025.color_mappings import ade20k_color_mapping, gestalt_color_mapping from PIL import Image, ImageDraw #from visu import save_gestalt_with_proj, draw_crosses_on_image import os import pycolmap from PIL import Image as PImage import cv2 #import open3d as o3d #from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local #import pyvista as pv #from fast_pointnet import save_patches_dataset, predict_vertex_from_patch from fast_pointnet_v2 import save_patches_dataset, predict_vertex_from_patch #from fast_voxel import predict_vertex_from_patch_voxel #import time from fast_pointnet_class import save_patches_dataset as save_patches_dataset_class from fast_pointnet_class import predict_class_from_patch #from fast_pointnet_class_10d import predict_class_from_patch as predict_class_from_patch_10d from scipy.spatial.distance import cdist from scipy.optimize import linear_sum_assignment import torch import time from collections import Counter GENERATE_DATASET = False #DATASET_DIR = '/path/to/your/hohocustom/' DATASET_DIR = '/path/to/your/hohocustom_v4/' GENERATE_DATASET_EDGES = False #EDGES_DATASET_DIR = '/path/to/your/hohocustom_edges/' EDGES_DATASET_DIR = '/path/to/your/hohocustom_edges_10d_v5/' def convert_entry_to_human_readable(entry): out = {} for k, v in entry.items(): if 'colmap' in k: out[k] = read_colmap_rec(v) elif k in ['wf_vertices', 'wf_edges', 'K', 'R', 't', 'depth']: out[k] = v else: out[k]=v out['__key__'] = entry['order_id'] return out def get_gt_vertices_and_edges(entry, i, depth, colmap_rec, k, r, t, img_id, ade_seg): depth_fitted, depth_sparse, found_sparse, col_img = get_fitted_dense_depth(depth, colmap_rec, img_id, ade_seg) #old_k, old_r, old_t = k.copy(), r.copy(), t.copy() #k = col_img.camera.calibration_matrix() #world_to_cam = np.eye(4) #world_to_cam = col_img.cam_from_world.matrix() #r = world_to_cam[:3, :3] #t = world_to_cam[:3, 3] wf_vertices = np.array(entry['wf_vertices']) wf_edges = entry['wf_edges'] # Project world frame vertices into the current image if wf_vertices.shape[0] > 0: # Transform vertices to camera coordinates wf_vertices_cam = (r @ wf_vertices.T) + t.reshape(3, 1) # Project to image plane wf_vertices_img_homogeneous = k @ wf_vertices_cam # Convert to 2D pixel coordinates wf_vertices_img = wf_vertices_img_homogeneous[:2, :] / wf_vertices_img_homogeneous[2, :] projected_gt_vertices_2d = wf_vertices_img.T # Initialize lists to store corresponding depth values from depth maps gt_projected_depth_fitted_values = [] gt_projected_depth_sparse_values = [] # Get dimensions of the depth maps for bounds checking # Assuming depth_fitted and depth_sparse have the same dimensions map_height, map_width = depth_fitted.shape for idx in range(projected_gt_vertices_2d.shape[0]): # Get the 2D projected coordinates (x, y) px, py = projected_gt_vertices_2d[idx] # Round to nearest integer to use as indices for the depth maps ix, iy = int(round(px)), int(round(py)) # Get corresponding depth_fitted value if 0 <= iy < map_height and 0 <= ix < map_width: gt_projected_depth_fitted_values.append(depth_fitted[iy, ix]) else: # Projected point is outside the depth map bounds gt_projected_depth_fitted_values.append(np.nan) # Get corresponding depth_sparse value if 0 <= iy < map_height and 0 <= ix < map_width: # Assuming same dimensions for depth_sparse gt_projected_depth_sparse_values.append(depth_sparse[iy, ix]) else: # Projected point is outside the depth map bounds gt_projected_depth_sparse_values.append(np.nan) # Determine occlusion status for each ground truth vertex occlusion_status = [] # True if occluded, False otherwise # This block executes only if there were ground truth vertices to begin with. # wf_vertices_cam and projected_gt_vertices_2d would have been computed. # gt_projected_depth_fitted_values list has one entry per vertex. if wf_vertices.shape[0] > 0: # These are the Z-coordinates (depths) of the original 3D wf_vertices # when transformed into the camera's coordinate system. # This is effectively the "true" depth of each vertex from the camera. gt_vertices_depth_in_camera_system = wf_vertices_cam[2, :] for idx in range(projected_gt_vertices_2d.shape[0]): true_depth_of_vertex = gt_vertices_depth_in_camera_system[idx] # This is the depth value read from the (dense) depth_fitted map # at the 2D projection of the current wf_vertex. depth_from_fitted_map = gt_projected_depth_fitted_values[idx] # A vertex is considered occluded if its true depth is greater than # the depth of the surface recorded in the depth_fitted map. # This means the vertex is behind the observed surface. # We also check if depth_from_fitted_map is a valid number (not NaN). # If depth_from_fitted_map is NaN, it means the vertex projected outside # the depth map's bounds, so we don't consider it occluded by the map. if np.isnan(true_depth_of_vertex) or true_depth_of_vertex > depth_from_fitted_map + 200.: occlusion_status.append(True) # Vertex is occluded else: occlusion_status.append(False) # Vertex is not occluded or out of map bounds if wf_vertices.shape[0] > 0: # Filter vertices based on occlusion status visible_vertices_indices = [idx for idx, occluded in enumerate(occlusion_status) if not occluded] # Create a mapping from old vertex indices to new (filtered) vertex indices old_to_new_indices_map = {old_idx: new_idx for new_idx, old_idx in enumerate(visible_vertices_indices)} # Filter the projected_gt_vertices_2d and transform to the new structure new_wf_vertices = [] if projected_gt_vertices_2d.shape[0] > 0: # Ensure projected_gt_vertices_2d is not empty for idx in visible_vertices_indices: xy_coords = projected_gt_vertices_2d[idx] new_wf_vertices.append({'xy': xy_coords, 'type': 'apex'}) wf_vertices = new_wf_vertices # Filter the edges # An edge is kept if both its vertices are in the visible_vertices_indices list visible_edges = [] for edge_start, edge_end in wf_edges: if edge_start in old_to_new_indices_map and edge_end in old_to_new_indices_map: # Remap to new indices visible_edges.append((old_to_new_indices_map[edge_start], old_to_new_indices_map[edge_end])) wf_edges = visible_edges else: # If there are no original vertices, wf_vertices should be an empty list wf_vertices = [] wf_edges = [] wf_vertices_3d_visible = np.empty((0, 3)) original_gt_3d_vertices = np.array(entry['wf_vertices']) # Check if there were original vertices and if occlusion_status was computed for them if original_gt_3d_vertices.shape[0] > 0 and len(occlusion_status) == original_gt_3d_vertices.shape[0]: # Determine indices of visible vertices based on occlusion_status # occlusion_status is True if occluded, False otherwise. We want not occluded. visible_indices = [idx for idx, occluded_flag in enumerate(occlusion_status) if not occluded_flag] if visible_indices: # If the list of visible_indices is not empty wf_vertices_3d_visible = original_gt_3d_vertices[visible_indices] # If no original_gt_3d_vertices, or if all are occluded (visible_indices is empty), # or if occlusion_status length doesn't match (which implies an issue earlier, but defensively handled), # wf_vertices_3d_visible will remain the initialized np.empty((0, 3)). return wf_vertices, wf_edges, wf_vertices_3d_visible def project_vertices_to_3d(uv: np.ndarray, depth_vert: np.ndarray, col_img: pycolmap.Image, K, R, t) -> np.ndarray: """ Projects 2D vertex coordinates with associated depths to 3D world coordinates. Parameters ---------- uv : np.ndarray (N, 2) array of 2D vertex coordinates (u, v). depth_vert : np.ndarray (N,) array of depth values for each vertex. col_img : pycolmap.Image Returns ------- vertices_3d : np.ndarray (N, 3) array of vertex coordinates in 3D world space. """ # Backproject to 3D local camera coordinates xy_local = np.ones((len(uv), 3)) #k = col_img.camera.calibration_matrix() k = K xy_local[:, 0] = (uv[:, 0] - k[0, 2]) / k[0, 0] xy_local[:, 1] = (uv[:, 1] - k[1, 2]) / k[1, 1] # Get the 3D vertices vertices_3d_local = xy_local * depth_vert[...,None] # Create camera-to-world transformation matrix world_to_cam = np.eye(4) world_to_cam[:3, :3] = R world_to_cam[:3, 3] = t.reshape(3) #world_to_cam[:3] = col_img.cam_from_world.matrix() cam_to_world = np.linalg.inv(world_to_cam) # Transform local 3D points to world coordinates vertices_3d_homogeneous = cv2.convertPointsToHomogeneous(vertices_3d_local) vertices_3d = cv2.transform(vertices_3d_homogeneous, cam_to_world) vertices_3d = cv2.convertPointsFromHomogeneous(vertices_3d).reshape(-1, 3) return vertices_3d def get_fitted_dense_depth(depth, colmap_rec, img_id, ade20k_seg, K, R, t): """ Gets sparse depth from COLMAP, computes a house mask, fits dense depth to sparse depth within the mask, and returns the fitted dense depth. Parameters ---------- depth : np.ndarray Initial dense depth map (H, W). colmap_rec : pycolmap.Reconstruction COLMAP reconstruction data. img_id : str Identifier for the current image within the COLMAP reconstruction. K : np.ndarray Camera intrinsic matrix (3x3). R : np.ndarray Camera rotation matrix (3x3). t : np.ndarray Camera translation vector (3,). ade20k_seg : PIL.Image ADE20k segmentation map for the image. Returns ------- depth_fitted : np.ndarray Dense depth map scaled and shifted to align with sparse depth within the house mask (H, W). depth_sparse : np.ndarray The sparse depth map obtained from COLMAP (H, W). found_sparse : bool True if sparse depth points were found for this image, False otherwise. """ depth_np = np.array(depth) / 1000. # Convert mm to meters if needed depth_sparse, found_sparse, col_img = get_sparse_depth_custom(colmap_rec, img_id, depth_np, K, R, t) #print(depth_sparse.sum()) #depth_sparse, found_sparse, col_img = get_sparse_depth(colmap_rec, img_id, depth_np) if not found_sparse: print(f'No sparse depth found for image {img_id}') # Return original (meter-scaled) depth if no sparse data return depth_np, np.zeros_like(depth_np), False, None # Get house mask to focus fitting on relevant areas house_mask = get_house_mask(ade20k_seg) # Fit dense depth to sparse depth (scale only), using only points within the house mask k, depth_fitted = fit_scale_robust_median(depth_np, depth_sparse, validity_mask=house_mask) print(f"Fitted depth scale k={k:.4f} for image {img_id}") #depth_fitted = depth_np# * house_mask.astype(np.float32) depth_sparse = depth_sparse# * house_mask.astype(np.float32) return depth_fitted, depth_sparse, True, col_img def get_sparse_depth_custom(colmap_rec, img_id_substring, depth, K, R, t): """ Return a sparse depth map for the COLMAP image whose name contains `img_id_substring`. The output is an array of shape `depth_shape` (H,W), where only the projected 3D points get a depth > 0, else 0. Uses provided K, R, t for projection instead of COLMAP's image projection. """ H, W = depth.shape # 1) Find the matching COLMAP image to get its associated 3D points # This part remains to identify which 3D points are relevant for this image view found_img = None for img_id_c, col_img_obj in colmap_rec.images.items(): # Renamed col_img to col_img_obj to avoid conflict if img_id_substring in col_img_obj.name: found_img = col_img_obj break if found_img is None: print(f"Image substring {img_id_substring} not found in COLMAP.") return np.zeros((H, W), dtype=np.float32), False, None # 2) Gather 3D points that this image sees (according to COLMAP) points_xyz_world = [] for pid, p3D in colmap_rec.points3D.items(): if found_img.has_point3D(pid): points_xyz_world.append(p3D.xyz) # world coords if not points_xyz_world: print(f"No 3D points associated with {found_img.name} in COLMAP.") return np.zeros((H, W), dtype=np.float32), False, found_img # Return found_img for consistency points_xyz_world = np.array(points_xyz_world) # (N, 3) # 3) Project points_xyz_world to camera coordinates using R, t # points_cam = R @ points_xyz_world.T + t.reshape(3,1) # points_cam = points_cam.T (N,3) # More robustly: points_xyz_world_h = np.hstack((points_xyz_world, np.ones((points_xyz_world.shape[0], 1)))) # (N, 4) # World to Camera transformation matrix world_to_cam_mat = np.eye(4) world_to_cam_mat[:3, :3] = R world_to_cam_mat[:3, 3] = t.flatten() points_cam_h = (world_to_cam_mat @ points_xyz_world_h.T).T # (N, 4) points_cam = points_cam_h[:, :3] / points_cam_h[:, 3, np.newaxis] # (N, 3) in camera coordinates uv = [] z_vals = [] for i in range(points_cam.shape[0]): p_cam = points_cam[i] # Project to image plane using K # p_img_h = K @ p_cam # u = p_img_h[0] / p_img_h[2] # v = p_img_h[1] / p_img_h[2] # z = p_cam[2] # Ensure p_cam[2] (depth) is positive if p_cam[2] <= 0: # Point is behind or on the camera plane continue # Project to image plane using K # K is [[fx, 0, cx], [0, fy, cy], [0, 0, 1]] u_i = (K[0, 0] * p_cam[0] / p_cam[2]) + K[0, 2] v_i = (K[1, 1] * p_cam[1] / p_cam[2]) + K[1, 2] u_i_int = int(round(u_i)) v_i_int = int(round(v_i)) # Check in-bounds if 0 <= u_i_int < W and 0 <= v_i_int < H: uv.append((u_i_int, v_i_int)) z_vals.append(p_cam[2]) # Depth is the Z coordinate in camera space if not uv: print(f"No points projected into image bounds for {img_id_substring} using K,R,t.") return np.zeros((H, W), dtype=np.float32), False, found_img uv = np.array(uv, dtype=int) # shape (M,2) z_vals = np.array(z_vals) # shape (M,) depth_out = np.zeros((H, W), dtype=np.float32) # Ensure z_vals are positive before assignment, though already checked valid_depth_mask = z_vals > 0 if np.any(valid_depth_mask): depth_out[uv[valid_depth_mask, 1], uv[valid_depth_mask, 0]] = z_vals[valid_depth_mask] return depth_out, True, found_img def create_3d_wireframe_single_image(vertices: List[dict], connections: List[Tuple[int, int]], depth: PImage, colmap_rec: pycolmap.Reconstruction, img_id: str, ade_seg: PImage, K, R, t) -> np.ndarray: """ Processes a single image view to generate 3D vertex coordinates from existing 2D vertices/edges. Parameters ---------- vertices : List[dict] List of 2D vertex dictionaries (e.g., {"xy": (x, y), "type": ...}). connections : List[Tuple[int, int]] List of 2D edge connections (indices into the vertices list). depth : PIL.Image Initial dense depth map as a PIL Image. colmap_rec : pycolmap.Reconstruction COLMAP reconstruction data. img_id : str Identifier for the current image within the COLMAP reconstruction. ade_seg : PIL.Image ADE20k segmentation map for the image. Returns ------- vertices_3d : np.ndarray (N, 3) array of vertex coordinates in 3D world space. Returns an empty array if processing fails (e.g., missing sparse depth). """ # Check if initial vertices/connections are valid if (len(vertices) < 2) or (len(connections) < 1): # This case should ideally be handled before calling, but good to double check. print(f'Warning: create_3d_wireframe_single_image called with insufficient vertices/connections for image {img_id}') return np.empty((0, 3)) # Get fitted dense depth and sparse depth depth_fitted, depth_sparse, found_sparse, col_img = get_fitted_dense_depth( depth, colmap_rec, img_id, ade_seg, K, R, t ) # Get UV coordinates and depth for each vertex uv, depth_vert = get_uv_depth(vertices, depth_fitted, depth_sparse, 10) # Backproject to 3D vertices_3d = project_vertices_to_3d(uv, depth_vert, col_img, K, R ,t) return vertices_3d def visu_patch_and_pred(patch, pred, pred_dist, pred_class): # Create plotter plotter = pv.Plotter() # Create point cloud for this patch offset = patch.get('cluster_center', None) # Offset if available patch_points_3d = np.array(patch['patch_7d'][:, :3]) patch_points_3d = patch_points_3d + offset patch_cloud = pv.PolyData(patch_points_3d) point_idxs = patch['cluster_point_ids'] # List of point indices that are filtered patch_point_ids = patch['cube_point_ids'] # Assuming the 7th column contains point IDs assigned_gt_vertex = patch.get('assigned_wf_vertex', None) # GT vertex if available initial_pred = None if assigned_gt_vertex is not None: assigned_gt_vertex = assigned_gt_vertex + offset # Color points: red for filtered points, blue for other points patch_point_colors = [] for i, pid in enumerate(patch_point_ids): if pid in point_idxs: patch_point_colors.append([255, 0, 0]) # Red for filtered points else: patch_point_colors.append([0, 0, 255]) # Blue for other points patch_cloud["colors"] = np.array(patch_point_colors) plotter.add_mesh(patch_cloud, scalars="colors", rgb=True, point_size=8, render_points_as_spheres=True) # Create sphere to visualize GT vertex if available if assigned_gt_vertex is not None: gt_sphere = pv.Sphere(radius=0.1, center=assigned_gt_vertex) plotter.add_mesh(gt_sphere, color="green", opacity=0.5) if initial_pred is not None: # Create sphere to visualize initial prediction pred_sphere = pv.Sphere(radius=0.1, center=initial_pred) plotter.add_mesh(pred_sphere, color="orange", opacity=0.5) if pred is not None: # Create sphere to visualize predicted vertex pred_sphere = pv.Sphere(radius=0.1, center=pred) plotter.add_mesh(pred_sphere, color="red", opacity=0.5) # Add text annotations for prediction values title_text = f"Patch x\nPred dist: {pred_dist:.4f}\nPred class: {pred_class}" plotter.show(title=title_text) def extract_vertices_from_whole_pcloud(colmap_rec, idxs_points, all_connections): # Filter COLMAP points and colors based on idxs_points filtered_colmap_points = [] filtered_colmap_colors = [] filtered_colmap_ids = [] all_filtered_ids_list = [] all_extracted_groups = [] all_flattened_connections = [] group_to_flattened_mapping = {} # Maps (group_idx, local_vertex_idx) to flattened_idx # Flatten all groups and create mapping for connections flattened_idx = 0 for group_idx, point_ids_group in enumerate(idxs_points): cur_connections = all_connections[group_idx] group_to_flattened_mapping[group_idx] = {} for local_idx, point_ids in enumerate(point_ids_group): all_extracted_groups.append(point_ids) group_to_flattened_mapping[group_idx][local_idx] = flattened_idx flattened_idx += 1 # Convert connections to flattened indices for conn in cur_connections: start_idx, end_idx = conn if start_idx in group_to_flattened_mapping[group_idx] and end_idx in group_to_flattened_mapping[group_idx]: flattened_start = group_to_flattened_mapping[group_idx][start_idx] flattened_end = group_to_flattened_mapping[group_idx][end_idx] all_flattened_connections.append((flattened_start, flattened_end)) # Collect all filtered point IDs from all images for group_idxs in idxs_points: for point_ids in group_idxs: all_filtered_ids_list.extend(point_ids) # Convert to set for faster lookup all_filtered_ids_set = set(all_filtered_ids_list) # Extract only the filtered points, their colors, and IDs all_colmap_points = [] all_colmap_colors = [] all_colmap_ids = [] for pid, p3D in colmap_rec.points3D.items(): all_colmap_points.append(p3D.xyz) all_colmap_colors.append(p3D.color / 255.0) # Normalize colors to [0,1] all_colmap_ids.append(pid) if pid in all_filtered_ids_set: filtered_colmap_points.append(p3D.xyz) filtered_colmap_colors.append(p3D.color / 255.0) # Normalize colors to [0,1] filtered_colmap_ids.append(pid) all_colmap_points = np.array(all_colmap_points) if all_colmap_points else np.empty((0, 3)) all_colmap_colors = np.array(all_colmap_colors) if all_colmap_colors else np.empty((0, 3)) all_colmap_ids = np.array(all_colmap_ids) if all_colmap_ids else np.empty((0,)) whole_pcloud = {'points': all_colmap_points, 'colors': all_colmap_colors, 'ids': all_colmap_ids} filtered_colmap_points = np.array(filtered_colmap_points) if filtered_colmap_points else np.empty((0, 3)) filtered_colmap_colors = np.array(filtered_colmap_colors) if filtered_colmap_colors else np.empty((0, 3)) filtered_colmap_ids = np.array(filtered_colmap_ids) if filtered_colmap_ids else np.empty((0,)) # Extract points within ball radius from each set of points in idxs_points ball_radius = 0.5 # meters extracted_points = [] extracted_colors = [] extracted_ids = [] for group_idx, point_ids_group in enumerate(all_extracted_groups): group_extracted_points = [] group_extracted_colors = [] group_extracted_ids = [] # Get 3D coordinates of points in this group group_points_3d = [] for pid in point_ids_group: if pid in [filtered_colmap_ids[i] for i in range(len(filtered_colmap_ids))]: idx = np.where(filtered_colmap_ids == pid)[0][0] group_points_3d.append(filtered_colmap_points[idx]) if not group_points_3d: continue group_points_3d = np.array(group_points_3d) center = np.mean(group_points_3d, axis=0) # Center of the group points # For each point in the filtered point cloud, check if it's within ball radius of any point in this group # Calculate distance from center to all filtered points if len(filtered_colmap_points) > 0: distances_to_center = np.linalg.norm(filtered_colmap_points - center, axis=1) within_radius_mask = distances_to_center <= ball_radius if np.any(within_radius_mask): group_extracted_points.extend(filtered_colmap_points[within_radius_mask]) group_extracted_colors.extend(filtered_colmap_colors[within_radius_mask]) group_extracted_ids.extend(filtered_colmap_ids[within_radius_mask]) extracted_points.append(np.array(group_extracted_points) if group_extracted_points else np.empty((0, 3))) extracted_colors.append(np.array(group_extracted_colors) if group_extracted_colors else np.empty((0, 3))) extracted_ids.append(np.array(group_extracted_ids) if group_extracted_ids else np.empty((0,))) # Filter extracted_points to merge groups that share more than 50% of their points # and update connections accordingly updated_connections = [] if extracted_points: #print(f"Merging groups based on point overlap... Processing {len(extracted_points)} groups") # Create a list to track which groups to keep groups_to_keep = [] merged_groups = set() # Track which groups have been merged old_to_new_mapping = {} # Maps old flattened index to new index for i, (points_i, colors_i, ids_i) in enumerate(zip(extracted_points, extracted_colors, extracted_ids)): if i in merged_groups or len(ids_i) == 0: continue # Start with the current group merged_points = points_i.copy() merged_colors = colors_i.copy() merged_ids = set(ids_i) merged_indices = [i] # Track which original indices are merged # Check all subsequent groups for overlap for j in range(i + 1, len(extracted_points)): if j in merged_groups or len(extracted_ids[j]) == 0: continue ids_j = set(extracted_ids[j]) # Calculate overlap percentage intersection = merged_ids.intersection(ids_j) smaller_group_size = min(len(merged_ids), len(ids_j)) if smaller_group_size > 0: overlap_percentage = len(intersection) / smaller_group_size # If more than 50% overlap, merge the groups if overlap_percentage > 0.5: merged_points = np.vstack([merged_points, extracted_points[j]]) if len(merged_points) > 0 else extracted_points[j] merged_colors = np.vstack([merged_colors, extracted_colors[j]]) if len(merged_colors) > 0 else extracted_colors[j] merged_ids.update(ids_j) merged_indices.append(j) merged_groups.add(j) # Add the merged group to the list of groups to keep if len(merged_points) > 0: new_group_idx = len(groups_to_keep) groups_to_keep.append((merged_points, merged_colors, np.array(list(merged_ids)))) # Update mapping for all merged indices for old_idx in merged_indices: old_to_new_mapping[old_idx] = new_group_idx # Update extracted_points, extracted_colors, and extracted_ids with filtered results extracted_points = [group[0] for group in groups_to_keep] extracted_colors = [group[1] for group in groups_to_keep] extracted_ids = [group[2] for group in groups_to_keep] # Update connections based on the new mapping for start_idx, end_idx in all_flattened_connections: if start_idx in old_to_new_mapping and end_idx in old_to_new_mapping: new_start = old_to_new_mapping[start_idx] new_end = old_to_new_mapping[end_idx] # Only add connection if vertices are still different after merging if new_start != new_end: connection = tuple(sorted((new_start, new_end))) if connection not in updated_connections: updated_connections.append(connection) #print(f"After merging, number of groups: {len(extracted_points)}") #print(f"Updated connections: {updated_connections}") # Create visualization showing extracted points for each group as balls within their mean if False: if extracted_points: plotter = pv.Plotter() # Add all COLMAP points in gray all_points = [] all_colors = [] for pid, p3D in colmap_rec.points3D.items(): all_points.append(p3D.xyz) all_colors.append([0.8, 0.8, 0.8]) # Gray color if all_points: all_points = np.array(all_points) all_colors = np.array(all_colors) point_cloud = pv.PolyData(all_points) point_cloud["colors"] = np.array(all_colors) plotter.add_mesh(point_cloud, scalars="colors", rgb=True, point_size=3, render_points_as_spheres=True) for group_idx, (group_points, group_colors) in enumerate(zip(extracted_points, extracted_colors)): if len(group_points) > 0: # Calculate mean position for this group group_mean = np.mean(group_points, axis=0) # Create a sphere at the mean position sphere = pv.Sphere(radius=0.2, center=group_mean) # Generate a random color for each group group_color = np.random.rand(3) plotter.add_mesh(sphere, color=group_color, opacity=0.7) # Add the extracted points for this group in the same color group_cloud = pv.PolyData(group_points) plotter.add_mesh(group_cloud, color=group_color, point_size=6, render_points_as_spheres=True) plotter.show(title=f"Extracted Points within {ball_radius}m radius - Spheres at group means") return extracted_points, extracted_colors, extracted_ids, whole_pcloud, updated_connections from collections import Counter # Ensure Counter is imported def extract_vertices_from_whole_pcloud_v2(colmap_pcloud, idxs_points, all_connections): # Extract initial data from colmap_pcloud # points_7d contains: x, y, z, r, g, b, pid (r,g,b are already normalized to [0,1]) all_colmap_points_xyz = colmap_pcloud['points_7d'][:, :3] all_colmap_rgb_colors = colmap_pcloud['points_7d'][:, 3:6] all_colmap_ids = colmap_pcloud['points_7d'][:, 6].astype(int) # ADE feature: 1.0 if ade_count > 0, else 0.0 # colmap_pcloud['ade'] stores the count of times a point was seen in an ADE house mask all_colmap_ade_feature = (np.array(colmap_pcloud['ade']) > 0).astype(float).reshape(-1, 1) # Gestalt feature: Fused Gestalt color by majority vote, normalized to [0,1] # colmap_pcloud['gestalt'] is a list of lists; each inner list contains uint8 RGB arrays from different views all_colmap_fused_gestalt_colors_normalized = np.zeros((len(all_colmap_points_xyz), 3)) for i, gestalt_obs_for_point_i in enumerate(colmap_pcloud['gestalt']): if gestalt_obs_for_point_i: # Convert list of np.arrays to list of tuples to make them hashable for Counter # Ensure gestalt_obs_for_point_i contains hashable items, e.g. tuples try: # If gestalt_obs_for_point_i contains numpy arrays: counts = Counter(map(tuple, gestalt_obs_for_point_i)) except TypeError: # If gestalt_obs_for_point_i already contains tuples or other hashables: counts = Counter(gestalt_obs_for_point_i) if counts: most_common_gestalt_tuple = counts.most_common(1)[0][0] fused_gestalt_rgb_uint8 = np.array(most_common_gestalt_tuple) all_colmap_fused_gestalt_colors_normalized[i] = fused_gestalt_rgb_uint8 / 255.0 else: all_colmap_fused_gestalt_colors_normalized[i] = np.array([0.0, 0.0, 0.0]) # Default if counts is empty else: all_colmap_fused_gestalt_colors_normalized[i] = np.array([0.0, 0.0, 0.0]) # Default if no observations # Combine into 7D colors [R, G, B, ADE, Gestalt_R, Gestalt_G, Gestalt_B] all_colmap_colors_7d = np.hstack(( all_colmap_rgb_colors, all_colmap_ade_feature, all_colmap_fused_gestalt_colors_normalized )) # Flatten all groups and create mapping for connections all_filtered_ids_list = [] all_extracted_groups = [] # List of lists of point_ids all_flattened_connections = [] group_to_flattened_mapping = {} # Maps (group_idx, local_vertex_idx) to flattened_idx flattened_idx = 0 for group_idx, point_ids_group in enumerate(idxs_points): # idxs_points is list of lists of point_ids cur_connections = all_connections[group_idx] group_to_flattened_mapping[group_idx] = {} for local_idx, point_ids in enumerate(point_ids_group): # point_ids is a list of pids for one vertex candidate all_extracted_groups.append(point_ids) # Store the list of pids all_filtered_ids_list.extend(point_ids) # Add all pids to a flat list group_to_flattened_mapping[group_idx][local_idx] = flattened_idx flattened_idx += 1 for conn in cur_connections: start_idx, end_idx = conn if start_idx in group_to_flattened_mapping[group_idx] and end_idx in group_to_flattened_mapping[group_idx]: flattened_start = group_to_flattened_mapping[group_idx][start_idx] flattened_end = group_to_flattened_mapping[group_idx][end_idx] all_flattened_connections.append((flattened_start, flattened_end)) all_filtered_ids_set = set(all_filtered_ids_list) # Extract only the points that are part of any initial group filtered_colmap_points_xyz_list = [] filtered_colmap_colors_7d_list = [] filtered_colmap_ids_list = [] for i, pid in enumerate(all_colmap_ids): if pid in all_filtered_ids_set: filtered_colmap_points_xyz_list.append(all_colmap_points_xyz[i]) filtered_colmap_colors_7d_list.append(all_colmap_colors_7d[i]) filtered_colmap_ids_list.append(pid) filtered_colmap_points_xyz_arr = np.array(filtered_colmap_points_xyz_list) if filtered_colmap_points_xyz_list else np.empty((0, 3)) filtered_colmap_colors_7d_arr = np.array(filtered_colmap_colors_7d_list) if filtered_colmap_colors_7d_list else np.empty((0, 7)) filtered_colmap_ids_arr = np.array(filtered_colmap_ids_list) if filtered_colmap_ids_list else np.empty((0,), dtype=int) # This whole_pcloud is created by this function, reflecting the full dataset with 7D colors whole_pcloud_internal = { 'points': all_colmap_points_xyz, 'colors': all_colmap_colors_7d, # Now 7D 'ids': all_colmap_ids } # Extract points within ball radius for each group ball_radius = 0.5 # meters extracted_points_groups = [] extracted_colors_7d_groups = [] extracted_ids_groups = [] for point_ids_in_one_group in all_extracted_groups: # point_ids_in_one_group is a list of PIDs current_group_points_xyz = [] # Get 3D coordinates of points in this specific initial group # These PIDs should be in all_colmap_ids indices_in_all_colmap = [np.where(all_colmap_ids == pid)[0][0] for pid in point_ids_in_one_group if pid in all_colmap_ids] if not indices_in_all_colmap: extracted_points_groups.append(np.empty((0,3))) extracted_colors_7d_groups.append(np.empty((0,7))) extracted_ids_groups.append(np.empty((0,), dtype=int)) continue current_group_points_xyz = all_colmap_points_xyz[indices_in_all_colmap] if current_group_points_xyz.shape[0] == 0: extracted_points_groups.append(np.empty((0,3))) extracted_colors_7d_groups.append(np.empty((0,7))) extracted_ids_groups.append(np.empty((0,), dtype=int)) continue center = np.mean(current_group_points_xyz, axis=0) # Find points from the *filtered_colmap_points_xyz_arr* (points belonging to *any* initial group) # that are within ball_radius of this group's center. group_extracted_points_list = [] group_extracted_colors_7d_list = [] group_extracted_ids_list = [] if len(filtered_colmap_points_xyz_arr) > 0: distances_to_center = np.linalg.norm(filtered_colmap_points_xyz_arr - center, axis=1) within_radius_mask = distances_to_center <= ball_radius if np.any(within_radius_mask): group_extracted_points_list.extend(filtered_colmap_points_xyz_arr[within_radius_mask]) group_extracted_colors_7d_list.extend(filtered_colmap_colors_7d_arr[within_radius_mask]) group_extracted_ids_list.extend(filtered_colmap_ids_arr[within_radius_mask]) extracted_points_groups.append(np.array(group_extracted_points_list) if group_extracted_points_list else np.empty((0, 3))) extracted_colors_7d_groups.append(np.array(group_extracted_colors_7d_list) if group_extracted_colors_7d_list else np.empty((0, 7))) extracted_ids_groups.append(np.array(group_extracted_ids_list) if group_extracted_ids_list else np.empty((0,), dtype=int)) # Filter extracted_points to merge groups that share more than 50% of their points updated_connections = [] final_extracted_points = [] final_extracted_colors_7d = [] final_extracted_ids = [] if extracted_points_groups: groups_to_keep_data = [] merged_groups_indices = set() old_to_new_mapping = {} for i in range(len(extracted_points_groups)): if i in merged_groups_indices or len(extracted_ids_groups[i]) == 0: continue current_merged_points = extracted_points_groups[i].copy() current_merged_colors_7d = extracted_colors_7d_groups[i].copy() current_merged_ids_set = set(extracted_ids_groups[i]) indices_in_this_merged_group = [i] for j in range(i + 1, len(extracted_points_groups)): if j in merged_groups_indices or len(extracted_ids_groups[j]) == 0: continue ids_j_set = set(extracted_ids_groups[j]) intersection = current_merged_ids_set.intersection(ids_j_set) smaller_group_size = min(len(current_merged_ids_set), len(ids_j_set)) if smaller_group_size > 0: overlap_percentage = len(intersection) / smaller_group_size if overlap_percentage > 0.5: current_merged_points = np.vstack([current_merged_points, extracted_points_groups[j]]) if len(current_merged_points) > 0 else extracted_points_groups[j] current_merged_colors_7d = np.vstack([current_merged_colors_7d, extracted_colors_7d_groups[j]]) if len(current_merged_colors_7d) > 0 else extracted_colors_7d_groups[j] current_merged_ids_set.update(ids_j_set) indices_in_this_merged_group.append(j) merged_groups_indices.add(j) if len(current_merged_points) > 0: new_group_idx = len(groups_to_keep_data) groups_to_keep_data.append((current_merged_points, current_merged_colors_7d, np.array(list(current_merged_ids_set)))) for old_idx in indices_in_this_merged_group: old_to_new_mapping[old_idx] = new_group_idx final_extracted_points = [group_data[0] for group_data in groups_to_keep_data] final_extracted_colors_7d = [group_data[1] for group_data in groups_to_keep_data] final_extracted_ids = [group_data[2] for group_data in groups_to_keep_data] for start_idx, end_idx in all_flattened_connections: if start_idx in old_to_new_mapping and end_idx in old_to_new_mapping: new_start = old_to_new_mapping[start_idx] new_end = old_to_new_mapping[end_idx] if new_start != new_end: connection = tuple(sorted((new_start, new_end))) if connection not in updated_connections: updated_connections.append(connection) # Visualization part (remains largely unchanged, uses random colors for spheres) if False: # Set to True to enable visualization if final_extracted_points: # Ensure pyvista is imported if this block is enabled # import pyvista as pv plotter = pv.Plotter() # Add all COLMAP points (from whole_pcloud_internal) in gray if len(whole_pcloud_internal['points']) > 0: # For visualization, use only RGB part of 7D colors or a fixed color # Here, using fixed gray color as in original vis_colors = np.full((len(whole_pcloud_internal['points']), 3), [0.8, 0.8, 0.8]) point_cloud = pv.PolyData(whole_pcloud_internal['points']) point_cloud["colors"] = vis_colors plotter.add_mesh(point_cloud, scalars="colors", rgb=True, point_size=3, render_points_as_spheres=True) for group_idx, (group_points_xyz, _) in enumerate(zip(final_extracted_points, final_extracted_colors_7d)): if len(group_points_xyz) > 0: group_mean = np.mean(group_points_xyz, axis=0) sphere = pv.Sphere(radius=0.2, center=group_mean) group_color_vis = np.random.rand(3) # Random color for sphere plotter.add_mesh(sphere, color=group_color_vis, opacity=0.7) group_cloud = pv.PolyData(group_points_xyz) # Use the same random color for points in this group for visualization consistency plotter.add_mesh(group_cloud, color=group_color_vis, point_size=6, render_points_as_spheres=True) plotter.show(title=f"Extracted Points within {ball_radius}m radius - Spheres at group means") return final_extracted_points, final_extracted_colors_7d, final_extracted_ids, whole_pcloud_internal, updated_connections def visu_pcloud_and_preds(colmap_rec, extracted_ids, extracted_points, extracted_colors, predicted_vertices, connections): if extracted_ids: plotter = pv.Plotter() # Add all COLMAP points in gray all_points = [] all_colors = [] for pid, p3D in colmap_rec.points3D.items(): all_points.append(p3D.xyz) all_colors.append([0.8, 0.8, 0.8]) # Gray color if all_points: all_points = np.array(all_points) all_colors = np.array(all_colors) point_cloud = pv.PolyData(all_points) point_cloud["colors"] = np.array(all_colors) plotter.add_mesh(point_cloud, scalars="colors", rgb=True, point_size=3, render_points_as_spheres=True) for group_idx, (group_points, group_colors) in enumerate(zip(extracted_points, extracted_colors)): if len(group_points) > 0: # Calculate mean position for this group group_mean = np.mean(group_points, axis=0) # Create a sphere at the mean position sphere = pv.Sphere(radius=0.2, center=group_mean) # Generate a random color for each group group_color = np.random.rand(3) plotter.add_mesh(sphere, color=group_color, opacity=0.5) # Add the extracted points for this group in the same color group_cloud = pv.PolyData(group_points) plotter.add_mesh(group_cloud, color=group_color, point_size=6, render_points_as_spheres=True) # Add predicted vertex as sphere if it exists and is valid if group_idx < len(predicted_vertices): pred_vertex = predicted_vertices[group_idx] if not np.allclose(pred_vertex, [0.0, 0.0, 0.0]): # Check if it's not a zero vertex pred_sphere = pv.Sphere(radius=0.15, center=pred_vertex) plotter.add_mesh(pred_sphere, color="black", opacity=1.) # Add connections between predicted vertices if len(predicted_vertices) > 0 and len(connections) > 0: valid_pred_vertices = [] valid_indices = [] for i, pred_vertex in enumerate(predicted_vertices): if not np.allclose(pred_vertex, [0.0, 0.0, 0.0]): valid_pred_vertices.append(pred_vertex) valid_indices.append(i) if len(valid_pred_vertices) > 1: valid_pred_vertices = np.array(valid_pred_vertices) # Create lines for connections for start_idx, end_idx in connections: if start_idx in valid_indices and end_idx in valid_indices: # Map to valid vertex indices valid_start = valid_indices.index(start_idx) valid_end = valid_indices.index(end_idx) # Create line between vertices line_points = np.array([valid_pred_vertices[valid_start], valid_pred_vertices[valid_end]]) line = pv.Line(line_points[0], line_points[1]) plotter.add_mesh(line, color="red", line_width=3) ball_radius = 1.0 # meters plotter.show(title=f"Extracted Points within {ball_radius}m radius - Spheres at group means") def generate_edge_patches(frame, pred_vertices, colmap_pcloud): gt_vertices = np.array(frame['wf_vertices']) if frame['wf_vertices'] else np.empty((0, 3)) gt_connections = frame['wf_edges'] vertices = np.array(pred_vertices) if pred_vertices is not None and len(pred_vertices) > 0 else np.empty((0, 3)) # Find closest GT vertex for each predicted vertex connections = [] if len(vertices) > 0 and len(gt_vertices) > 0: # For each GT vertex, find the closest predicted vertex gt_to_pred_mapping = {} for gt_idx, gt_vertex in enumerate(gt_vertices): # Calculate distances from this GT vertex to all predicted vertices distances = np.linalg.norm(vertices - gt_vertex, axis=1) # Find the closest predicted vertex closest_pred_idx = np.argmin(distances) closest_distance = distances[closest_pred_idx] # Only map if within distance threshold distance_threshold = 1.5 if closest_distance <= distance_threshold: gt_to_pred_mapping[gt_idx] = closest_pred_idx # Propagate GT connections to predicted vertices for gt_connection in gt_connections: gt_start, gt_end = gt_connection if gt_start in gt_to_pred_mapping and gt_end in gt_to_pred_mapping: pred_start = gt_to_pred_mapping[gt_start] pred_end = gt_to_pred_mapping[gt_end] connections.append((pred_start, pred_end)) print(f"Matched {len(gt_to_pred_mapping)} GT vertices to predicted vertices") print(f"Propagated {len(connections)} connections from GT to predicted vertices") positive_patches = [] negative_patches = [] cylinder_radius = 1.0 # meters points_6d = colmap_pcloud['points_7d'][:, :7] points_6d[:, 3:6] = points_6d[:, 3:6] * 2 - 1 # Normalize RGB colors to [0, 1] ade = colmap_pcloud['ade'] ade = np.where(ade, 1, -1) # Normalize to [-1, 1] gestalt = colmap_pcloud['gestalt'] # Fuse multiple gestalt values per point using majority voting fused_gestalt = [] for point_gestalt_list in gestalt: if len(point_gestalt_list) == 0: fused_gestalt.append(np.array([0, 0, 0])) elif len(point_gestalt_list) == 1: fused_gestalt.append(point_gestalt_list[0]) else: # Convert to tuples for hashable voting gestalt_tuples = [tuple(gestalt_val) for gestalt_val in point_gestalt_list] # Use Counter for majority voting counts = Counter(gestalt_tuples) most_common_tuple = counts.most_common(1)[0][0] fused_value = np.array(most_common_tuple, dtype=np.uint8) fused_gestalt.append(fused_value) gestalt = np.array(fused_gestalt) gestalt = (gestalt / 255) * 2 - 1 # Normalize to [-1, 1] # Extract 3D coordinates for faster vectorized operations colmap_points_3d = points_6d[:, :3] # Create combined 10D point cloud (xyz + rgb + ade + gestalt) colmap_points_10d = np.zeros((len(colmap_points_3d), 10)) colmap_points_10d[:, :3] = colmap_points_3d # xyz coordinates colmap_points_10d[:, 3:6] = points_6d[:, 3:6] # rgb colors (already normalized to [-1, 1]) colmap_points_10d[:, 6] = ade # ade values (normalized to [-1, 1]) colmap_points_10d[:, 7:10] = gestalt # gestalt values (normalized to [-1, 1], all 3 RGB channels) # For each connection, create a positive edge patch for connection in connections: start_idx, end_idx = connection # Get start and end vertices from the connections start_vertex = vertices[start_idx] end_vertex = vertices[end_idx] # Create line vector from start to end line_vector = end_vertex - start_vertex line_length = np.linalg.norm(line_vector) # Normalize line vector line_direction = line_vector / line_length # Extend the line by 25 cm (0.25 meters) on both ends for more context extension_length = 1 # 25 cm in meters extended_start = start_vertex - extension_length * line_direction extended_end = end_vertex + extension_length * line_direction extended_line_length = line_length + 2 * extension_length # Vectorized distance calculation # Vector from extended start to all points start_to_points = colmap_points_3d - extended_start[np.newaxis, :] # Project onto line direction to get distance along extended line projection_lengths = np.dot(start_to_points, line_direction) # Filter points within extended line segment bounds within_bounds = (projection_lengths >= 0) & (projection_lengths <= extended_line_length) # Find closest points on extended line segment for all points closest_points_on_line = extended_start[np.newaxis, :] + projection_lengths[:, np.newaxis] * line_direction[np.newaxis, :] # Calculate perpendicular distances from points to line perpendicular_distances = np.linalg.norm(colmap_points_3d - closest_points_on_line, axis=1) # Find points within cylinder within_cylinder = within_bounds & (perpendicular_distances <= cylinder_radius) if np.sum(within_cylinder) <= 5: continue points_in_cylinder = colmap_points_10d[within_cylinder] point_indices_in_cylinder = np.where(within_cylinder)[0] # Center the patch at the midpoint of the original line (not extended) line_midpoint = (start_vertex + end_vertex) / 2 # Shift points to center around origin points_centered = points_in_cylinder.copy() points_centered[:, :3] -= line_midpoint # Create positive edge patch positive_patch = { 'patch_10d': points_centered, 'connection': connection, 'line_start': start_vertex - line_midpoint, 'line_end': end_vertex - line_midpoint, 'cylinder_radius': cylinder_radius, 'point_indices': point_indices_in_cylinder, 'label': 1, # Positive label for edge 'center': line_midpoint } positive_patches.append(positive_patch) # Generate negative edge patches by sampling random unconnected vertex pairs num_negative_patches = len(positive_patches) if num_negative_patches > 0 and len(vertices) >= 2: # Create set of connected pairs for fast lookup connected_pairs = set(tuple(sorted(conn)) for conn in connections) # Generate all possible vertex pairs vertex_indices = np.arange(len(vertices)) all_pairs = np.array(np.meshgrid(vertex_indices, vertex_indices)).T.reshape(-1, 2) # Filter out pairs where both indices are the same all_pairs = all_pairs[all_pairs[:, 0] != all_pairs[:, 1]] # Sort pairs to match connected_pairs format all_pairs_sorted = np.sort(all_pairs, axis=1) # Find unconnected pairs unconnected_mask = np.array([tuple(pair) not in connected_pairs for pair in all_pairs_sorted]) unconnected_pairs = all_pairs[unconnected_mask] if len(unconnected_pairs) > 0: # Pre-compute positive patch cylinder info for overlap checks positive_cylinders = [] for pos_patch in positive_patches: start_world = pos_patch['line_start'] + pos_patch['center'] end_world = pos_patch['line_end'] + pos_patch['center'] positive_cylinders.append({ 'start': start_world, 'end': end_world, 'radius': pos_patch['cylinder_radius'] }) # Randomly sample negative pairs without replacement num_to_sample = min(num_negative_patches * 3, len(unconnected_pairs)) # Sample more to account for rejections sampled_indices = np.random.choice(len(unconnected_pairs), size=num_to_sample, replace=False) sampled_pairs = unconnected_pairs[sampled_indices] for idx1, idx2 in sampled_pairs: if len(negative_patches) >= num_negative_patches: break start_vertex = vertices[idx1] end_vertex = vertices[idx2] # Create line vector from start to end line_vector = end_vertex - start_vertex line_length = np.linalg.norm(line_vector) # Normalize line vector line_direction = line_vector / line_length # Extend the line by 25 cm (0.25 meters) on both ends for more context extension_length = 1 # 25 cm in meters extended_start = start_vertex - extension_length * line_direction extended_end = end_vertex + extension_length * line_direction extended_line_length = line_length + 2 * extension_length # Check cylinder overlap with positive patches current_cylinder = { 'start': extended_start, 'end': extended_end, 'radius': cylinder_radius } has_overlap = False for pos_cylinder in positive_cylinders: # Calculate cylinder-cylinder intersection volume overlap_volume = calculate_cylinder_overlap_volume(current_cylinder, pos_cylinder) # Calculate volumes of both cylinders current_volume = np.pi * cylinder_radius**2 * extended_line_length pos_height = np.linalg.norm(pos_cylinder['end'] - pos_cylinder['start']) pos_volume = np.pi * pos_cylinder['radius']**2 * pos_height # Calculate IoU union_volume = current_volume + pos_volume - overlap_volume if union_volume > 0: iou = overlap_volume / union_volume if iou > 0.25: # 0.2 IoU threshold has_overlap = True break if has_overlap: continue # Skip this negative patch due to cylinder overlap # Vectorized distance calculation # Vector from extended start to all points start_to_points = colmap_points_3d - extended_start[np.newaxis, :] # Project onto line direction to get distance along extended line projection_lengths = np.dot(start_to_points, line_direction) # Filter points within extended line segment bounds within_bounds = (projection_lengths >= 0) & (projection_lengths <= extended_line_length) # Find closest points on extended line segment for all points closest_points_on_line = extended_start[np.newaxis, :] + projection_lengths[:, np.newaxis] * line_direction[np.newaxis, :] # Calculate perpendicular distances from points to line perpendicular_distances = np.linalg.norm(colmap_points_3d - closest_points_on_line, axis=1) # Find points within cylinder within_cylinder = within_bounds & (perpendicular_distances <= cylinder_radius) if np.sum(within_cylinder) <= 10: continue points_in_cylinder = colmap_points_10d[within_cylinder] point_indices_in_cylinder = np.where(within_cylinder)[0] # Center the patch at the midpoint of the original line (not extended) line_midpoint = (start_vertex + end_vertex) / 2 # Shift points to center around origin points_centered = points_in_cylinder.copy() points_centered[:, :3] -= line_midpoint # Create negative edge patch negative_patch = { 'patch_10d': points_centered, 'connection': (idx1, idx2), 'line_start': start_vertex - line_midpoint, 'line_end': end_vertex - line_midpoint, 'cylinder_radius': cylinder_radius, 'point_indices': point_indices_in_cylinder, 'label': 0, # Negative label for non-edge 'center': line_midpoint # Center of the patch } negative_patches.append(negative_patch) print(f"Generated {len(positive_patches)} positive patches and {len(negative_patches)} negative patches") all_patches = positive_patches + negative_patches # Visualize edge patches if False: # Set to True to enable visualization # Create plotter plotter = pv.Plotter() # Add whole point cloud in gray if len(colmap_points_10d) > 0: whole_cloud = pv.PolyData(colmap_points_3d) gray_colors = np.full((len(colmap_points_3d), 3), [0.5, 0.5, 0.5]) whole_cloud["colors"] = gray_colors plotter.add_mesh(whole_cloud, scalars="colors", rgb=True, point_size=3, render_points_as_spheres=True) # Add GT vertices and connections in blue gt_vertices = np.array(frame['wf_vertices']) if frame['wf_vertices'] else np.empty((0, 3)) gt_connections = frame['wf_edges'] if len(gt_vertices) > 0: # Add GT vertices as blue spheres for gt_vertex in gt_vertices: gt_sphere = pv.Sphere(radius=0.15, center=gt_vertex) plotter.add_mesh(gt_sphere, color='blue', opacity=0.8) # Add GT connections as blue lines for gt_connection in gt_connections: gt_start_idx, gt_end_idx = gt_connection if gt_start_idx < len(gt_vertices) and gt_end_idx < len(gt_vertices): gt_line_points = np.array([gt_vertices[gt_start_idx], gt_vertices[gt_end_idx]]) gt_line = pv.Line(gt_line_points[0], gt_line_points[1]) plotter.add_mesh(gt_line, color='blue', line_width=8) # Visualize each patch for patch_idx, patch in enumerate(all_patches): # Use green for positive (edge), red for negative (non-edge) patch_color = 'green' if patch['label'] == 1 else 'red' # Get patch data points_in_cylinder = patch['patch_10d'][:, :3] # xyz coordinates line_start = patch['line_start'] line_end = patch['line_end'] center = patch['center'] # Use center instead of calculating midpoint # Shift points back to world coordinates for visualization points_world = points_in_cylinder + center # Add points inside cylinder with patch-specific color if len(points_world) > 0: cylinder_cloud = pv.PolyData(points_world) plotter.add_mesh(cylinder_cloud, color=patch_color, point_size=8, render_points_as_spheres=True) # Add start and end points as larger spheres start_sphere = pv.Sphere(radius=0.1, center=line_start + center) end_sphere = pv.Sphere(radius=0.1, center=line_end + center) plotter.add_mesh(start_sphere, color='black', opacity=0.8) plotter.add_mesh(end_sphere, color='white', opacity=0.8) # Add line between start and end line_points = np.array([line_start + center, line_end + center]) line = pv.Line(line_points[0], line_points[1]) plotter.add_mesh(line, color=patch_color, line_width=5) # Add cylinder wireframe to show extraction bounds cylinder_center = center cylinder_direction = (line_end - line_start) / np.linalg.norm(line_end - line_start) cylinder_height = np.linalg.norm(line_end - line_start) + 2 * 0.25 # Including extensions # Create cylinder mesh for visualization cylinder_mesh = pv.Cylinder(center=cylinder_center, direction=cylinder_direction, radius=patch['cylinder_radius'], height=cylinder_height) plotter.add_mesh(cylinder_mesh, color=patch_color, opacity=0.2, style='wireframe') # Set title based on label distribution positive_count = sum(1 for patch in all_patches if patch['label'] == 1) negative_count = sum(1 for patch in all_patches if patch['label'] == 0) title = f"Edge Patches - Positive (Green): {positive_count}, Negative (Red): {negative_count}, GT (Blue)" plotter.show(title=title) return all_patches def generate_edge_patches_forward(frame, pred_vertices): vertices = pred_vertices cylinder_radius = 0.5 colmap = frame['colmap_binary'] # Create 6D point cloud from COLMAP data colmap_points_6d = [] for pid, p3D in colmap.points3D.items(): # Combine xyz coordinates and RGB color point_6d = np.concatenate([p3D.xyz, p3D.color / 255.0]) # Normalize color to [0,1] colmap_points_6d.append(point_6d) colmap_points_6d = np.array(colmap_points_6d) if colmap_points_6d else np.empty((0, 6)) colmap_points_6d[:, 3:] = colmap_points_6d[:, 3:] * 2 - 1 # Extract 3D coordinates for faster vectorized operations colmap_points_3d = colmap_points_6d[:, :3] forward_patches = [] # For each vertex pair, create a patch without label for i in range(len(vertices)): for j in range(i + 1, len(vertices)): start_vertex = vertices[i] end_vertex = vertices[j] # Create line vector from start to end line_vector = end_vertex - start_vertex line_length = np.linalg.norm(line_vector) # Normalize line vector line_direction = line_vector / line_length # Extend the line by 25 cm (0.25 meters) on both ends for more context extension_length = 0.25 # 25 cm in meters extended_start = start_vertex - extension_length * line_direction extended_end = end_vertex + extension_length * line_direction extended_line_length = line_length + 2 * extension_length # Vectorized distance calculation # Vector from extended start to all points start_to_points = colmap_points_3d - extended_start[np.newaxis, :] # Project onto line direction to get distance along extended line projection_lengths = np.dot(start_to_points, line_direction) # Filter points within extended line segment bounds within_bounds = (projection_lengths >= 0) & (projection_lengths <= extended_line_length) # Find closest points on extended line segment for all points closest_points_on_line = extended_start[np.newaxis, :] + projection_lengths[:, np.newaxis] * line_direction[np.newaxis, :] # Calculate perpendicular distances from points to line perpendicular_distances = np.linalg.norm(colmap_points_3d - closest_points_on_line, axis=1) # Find points within cylinder within_cylinder = within_bounds & (perpendicular_distances <= cylinder_radius) if np.sum(within_cylinder) <= 10: continue points_in_cylinder = colmap_points_6d[within_cylinder] point_indices_in_cylinder = np.where(within_cylinder)[0] # Center the patch at the midpoint of the original line (not extended) line_midpoint = (start_vertex + end_vertex) / 2 # Shift points to center around origin points_centered = points_in_cylinder.copy() points_centered[:, :3] -= line_midpoint # Create edge patch without label edge_patch = { 'patch_6d': points_centered, 'connection': (i, j), 'line_start': start_vertex - line_midpoint, 'line_end': end_vertex - line_midpoint, 'cylinder_radius': cylinder_radius, 'point_indices': point_indices_in_cylinder, 'center': line_midpoint } forward_patches.append(edge_patch) return forward_patches def generate_edge_patches_forward_10d(frame, pred_vertices, colmap_pcloud): vertices = np.array(pred_vertices) if pred_vertices is not None and len(pred_vertices) > 0 else np.empty((0, 3)) forward_patches = [] cylinder_radius = 1.0 # meters # colmap_pcloud['points_7d'] is [x,y,z, r,g,b (0-1), pid] # Extract xyz and rgb (0-1) points_xyz_rgb_pid = colmap_pcloud['points_7d'] colmap_points_3d = points_xyz_rgb_pid[:, :3] colmap_rgb_colors_01 = points_xyz_rgb_pid[:, 3:6] # Normalize RGB colors to [-1, 1] colmap_rgb_colors_neg1_1 = colmap_rgb_colors_01 * 2.0 - 1.0 ade_counts = colmap_pcloud['ade'] # These are counts ade_feature_neg1_1 = np.where(ade_counts > 0, 1.0, -1.0).reshape(-1, 1) # Normalize to [-1, 1] gestalt_observations_per_point = colmap_pcloud['gestalt'] # List of lists of uint8 RGB arrays fused_gestalt_neg1_1 = np.zeros((len(colmap_points_3d), 3)) if len(colmap_points_3d) > 0: # Ensure there are points to process for i, point_gestalt_list in enumerate(gestalt_observations_per_point): if not point_gestalt_list: # Empty list fused_gestalt_neg1_1[i] = np.array([-1.0, -1.0, -1.0]) continue gestalt_tuples = [tuple(gestalt_val) for gestalt_val in point_gestalt_list] counts = Counter(gestalt_tuples) if counts: # Ensure counts is not empty most_common_tuple = counts.most_common(1)[0][0] fused_value_uint8 = np.array(most_common_tuple, dtype=np.uint8) fused_gestalt_neg1_1[i] = (fused_value_uint8 / 255.0) * 2.0 - 1.0 else: # Default if counts is empty (e.g. all gestalt_val were unhashable or list was empty after filtering) fused_gestalt_neg1_1[i] = np.array([-1.0, -1.0, -1.0]) else: # Handle case with no points fused_gestalt_neg1_1 = np.empty((0,3)) # Create combined 10D point cloud (xyz + rgb + ade + gestalt) if len(colmap_points_3d) > 0: colmap_points_10d = np.hstack(( colmap_points_3d, colmap_rgb_colors_neg1_1, ade_feature_neg1_1, fused_gestalt_neg1_1 )) else: colmap_points_10d = np.empty((0,10)) # For each unique pair of vertices, create a candidate edge patch if len(vertices) >= 2 and len(colmap_points_10d) > 0: for i in range(len(vertices)): for j in range(i + 1, len(vertices)): # Ensure unique pairs (j > i) start_vertex = vertices[i] end_vertex = vertices[j] line_vector = end_vertex - start_vertex line_length = np.linalg.norm(line_vector) if line_length < 1e-6: continue # Avoid division by zero or very short lines line_direction = line_vector / line_length extension_length = 1.0 # meters extended_start = start_vertex - extension_length * line_direction extended_end = end_vertex + extension_length * line_direction extended_line_length = line_length + 2 * extension_length start_to_points = colmap_points_3d - extended_start[np.newaxis, :] projection_lengths = np.dot(start_to_points, line_direction) within_bounds = (projection_lengths >= 0) & (projection_lengths <= extended_line_length) # Ensure closest_points_on_line has the same shape for subtraction closest_points_on_line = extended_start[np.newaxis, :] + projection_lengths[:, np.newaxis] * line_direction[np.newaxis, :] perpendicular_distances = np.linalg.norm(colmap_points_3d - closest_points_on_line, axis=1) within_cylinder = within_bounds & (perpendicular_distances <= cylinder_radius) if np.sum(within_cylinder) <= 5: # Minimum number of points to form a patch continue points_in_cylinder_10d = colmap_points_10d[within_cylinder] point_indices_in_cylinder = np.where(within_cylinder)[0] # Original indices from colmap_points_10d line_midpoint = (start_vertex + end_vertex) / 2 points_centered_10d = points_in_cylinder_10d.copy() points_centered_10d[:, :3] -= line_midpoint # Center XYZ coordinates candidate_patch = { 'patch_10d': points_centered_10d, 'connection': (i, j), # Indices refer to `pred_vertices` 'line_start': start_vertex - line_midpoint, # Relative to midpoint 'line_end': end_vertex - line_midpoint, # Relative to midpoint 'cylinder_radius': cylinder_radius, 'point_indices': point_indices_in_cylinder, # Indices from the full 10D point cloud 'center': line_midpoint # World coordinate of the patch center } forward_patches.append(candidate_patch) #print(f"Generated {len(forward_patches)} candidate edge patches for 10d_forward") # Visualization (optional, can be enabled for debugging) if False: # Ensure pyvista (pv) is imported if this block is enabled # import pyvista as pv plotter = pv.Plotter() if len(colmap_points_3d) > 0: whole_cloud = pv.PolyData(colmap_points_3d) # Use actual RGB colors from colmap_rgb_colors_01 for visualization whole_cloud["colors"] = colmap_rgb_colors_01 plotter.add_mesh(whole_cloud, scalars="colors", rgb=True, point_size=3, render_points_as_spheres=True) # Visualize predicted vertices for vert_idx, vert_pos in enumerate(vertices): vert_sphere = pv.Sphere(radius=0.1, center=vert_pos) plotter.add_mesh(vert_sphere, color='cyan', opacity=0.8) plotter.add_point_labels([vert_pos], [f"V{vert_idx}"], point_size=20, font_size=10) for patch_idx, patch in enumerate(forward_patches): patch_color = 'orange' # Color for candidate patches points_in_cylinder_xyz_local = patch['patch_10d'][:, :3] # Already centered line_start_local = patch['line_start'] line_end_local = patch['line_end'] patch_center_world = patch['center'] # Transform patch points back to world coordinates for visualization points_world = points_in_cylinder_xyz_local + patch_center_world if len(points_world) > 0: cylinder_cloud = pv.PolyData(points_world) # Use RGB from patch_10d (cols 3,4,5), denormalized for visualization patch_rgb_colors_neg1_1 = patch['patch_10d'][:, 3:6] patch_rgb_colors_01 = (patch_rgb_colors_neg1_1 + 1.0) / 2.0 cylinder_cloud["colors"] = patch_rgb_colors_01 plotter.add_mesh(cylinder_cloud, scalars="colors", rgb=True, point_size=8, render_points_as_spheres=True) # Visualize the line segment (connection) in world coordinates start_point_world = line_start_local + patch_center_world end_point_world = line_end_local + patch_center_world start_sphere_world = pv.Sphere(radius=0.05, center=start_point_world) end_sphere_world = pv.Sphere(radius=0.05, center=end_point_world) plotter.add_mesh(start_sphere_world, color='black', opacity=0.8) plotter.add_mesh(end_sphere_world, color='white', opacity=0.8) line_world = pv.Line(start_point_world, end_point_world) plotter.add_mesh(line_world, color=patch_color, line_width=3) # Visualize cylinder bounds cyl_direction_local = (line_end_local - line_start_local) cyl_height_local = np.linalg.norm(cyl_direction_local) if cyl_height_local > 1e-6: cyl_direction_unit_local = cyl_direction_local / cyl_height_local # Height for visualization should match the extended line used for point gathering cyl_height_world_vis = cyl_height_local + 2 * 1.0 # extension_length was 1.0 cylinder_mesh = pv.Cylinder(center=patch_center_world, direction=cyl_direction_unit_local, radius=patch['cylinder_radius'], height=cyl_height_world_vis) plotter.add_mesh(cylinder_mesh, color=patch_color, opacity=0.15, style='wireframe') title = f"Candidate Edge Patches (10d_forward): {len(forward_patches)}" plotter.show(title=title) return forward_patches def calculate_cylinder_overlap_volume(cyl1, cyl2): """ Calculate the intersection volume between two cylinders using numpy vectorization. Returns approximate overlap volume. """ # Get cylinder parameters p1_start, p1_end = cyl1['start'], cyl1['end'] p2_start, p2_end = cyl2['start'], cyl2['end'] r1, r2 = cyl1['radius'], cyl2['radius'] # Calculate cylinder axes axis1 = p1_end - p1_start axis2 = p2_end - p2_start len1 = np.linalg.norm(axis1) len2 = np.linalg.norm(axis2) if len1 == 0 or len2 == 0: return 0.0 axis1_norm = axis1 / len1 axis2_norm = axis2 / len2 # Calculate distance between cylinder axes using line-line distance formula w = p1_start - p2_start a = np.dot(axis1_norm, axis1_norm) b = np.dot(axis1_norm, axis2_norm) c = np.dot(axis2_norm, axis2_norm) d = np.dot(axis1_norm, w) e = np.dot(axis2_norm, w) denom = a * c - b * b if abs(denom) < 1e-10: # Lines are parallel # Calculate perpendicular distance between parallel lines cross_product = np.cross(axis1_norm, w) if axis1_norm.shape[0] == 3: # 3D case dist = np.linalg.norm(cross_product) else: # 2D case dist = abs(cross_product) else: # Calculate closest points on both lines t1 = (b * e - c * d) / denom t2 = (a * e - b * d) / denom # Clamp to cylinder bounds t1 = np.clip(t1, 0, len1) t2 = np.clip(t2, 0, len2) # Calculate distance between closest points point1 = p1_start + t1 * axis1_norm point2 = p2_start + t2 * axis2_norm dist = np.linalg.norm(point1 - point2) # If cylinders don't intersect radially, return 0 if dist >= (r1 + r2): return 0.0 # Calculate overlapping length along both axes # Project cylinder 2 endpoints onto cylinder 1 axis proj_start = np.dot(p2_start - p1_start, axis1_norm) proj_end = np.dot(p2_end - p1_start, axis1_norm) # Find overlap interval overlap_start = max(0, min(proj_start, proj_end)) overlap_end = min(len1, max(proj_start, proj_end)) overlap_length = max(0, overlap_end - overlap_start) if overlap_length <= 0: return 0.0 # Approximate volume calculation # For simplicity, assume uniform overlap along the length if dist < abs(r1 - r2): # One cylinder is inside the other smaller_radius = min(r1, r2) overlap_volume = np.pi * smaller_radius**2 * overlap_length else: # Partial overlap - use geometric approximation # This is a simplified calculation for the intersection area of two circles r_smaller = min(r1, r2) r_larger = max(r1, r2) if dist < (r1 + r2): # Calculate intersection area of two circles (approximate) # Using lens area formula d1 = (r1**2 - r2**2 + dist**2) / (2 * dist) if dist > 0 else 0 d2 = dist - d1 if d1 >= 0 and d1 <= r1 and d2 >= 0 and d2 <= r2: area1 = r1**2 * np.arccos(d1/r1) - d1 * np.sqrt(r1**2 - d1**2) area2 = r2**2 * np.arccos(d2/r2) - d2 * np.sqrt(r2**2 - d2**2) intersection_area = area1 + area2 else: intersection_area = np.pi * r_smaller**2 overlap_volume = intersection_area * overlap_length else: overlap_volume = 0.0 return max(0.0, overlap_volume) def create_pcloud(colmap_rec, frame): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') #print(f"create_pcloud using device: {device}") # 1. Preprocess image data from the frame and colmap (mostly on CPU) img_id_to_colmap_img_obj_map = { img_obj.name: img_obj for img_obj_name, img_obj in colmap_rec.images.items() } frame_img_data = {} ordered_frame_img_ids = [] for K_val, R_val, t_val, img_id_val, ade_val, gestalt_val, depth_val in zip( frame['K'], frame['R'], frame['t'], frame['image_ids'], frame['ade'], frame['gestalt'], frame['depth'] ): if img_id_val not in img_id_to_colmap_img_obj_map: continue ordered_frame_img_ids.append(img_id_val) depth_np = np.array(depth_val) depth_H, depth_W = depth_np.shape[0], depth_np.shape[1] ade_mask_np = get_house_mask(ade_val) gest_seg_pil = gestalt_val.resize((depth_W, depth_H), Image.Resampling.NEAREST) gest_seg_np = np.array(gest_seg_pil).astype(np.uint8) frame_img_data[img_id_val] = { 'K_np': np.array(K_val), 'R_np': np.array(R_val), 't_np': np.array(t_val).reshape(3,1), 'ade_mask_np': ade_mask_np, 'gestalt_seg_np': gest_seg_np, 'H': depth_H, 'W': depth_W } # 2. Process 3D points by iterating through images point_data_accumulator = {} # Key: pid, accumulates data on CPU # Pre-fetch all COLMAP point data to avoid repeated dictionary lookups colmap_points_data_cpu = { pid: {'xyz': p3D.xyz, 'color': p3D.color / 255.0} for pid, p3D in colmap_rec.points3D.items() } for img_id in ordered_frame_img_ids: if img_id not in frame_img_data: continue col_img_obj = img_id_to_colmap_img_obj_map[img_id] img_data = frame_img_data[img_id] K_np, R_np, t_np = img_data['K_np'], img_data['R_np'], img_data['t_np'] ade_mask_np, gestalt_seg_np = img_data['ade_mask_np'], img_data['gestalt_seg_np'] H, W = img_data['H'], img_data['W'] # Convert current image data to GPU tensors K_gpu = torch.from_numpy(K_np).float().to(device) R_gpu = torch.from_numpy(R_np).float().to(device) t_gpu = torch.from_numpy(t_np).float().to(device) ade_mask_gpu = torch.from_numpy(ade_mask_np).bool().to(device) gestalt_seg_gpu = torch.from_numpy(gestalt_seg_np).to(device) # uint8 is fine visible_pids_in_img = [] visible_xyz_coords_list = [] for pid, p3D_data in colmap_points_data_cpu.items(): if col_img_obj.has_point3D(pid): # This check remains CPU-bound visible_pids_in_img.append(pid) visible_xyz_coords_list.append(p3D_data['xyz']) if not visible_pids_in_img: continue num_visible_points = len(visible_pids_in_img) world_pts_np = np.array(visible_xyz_coords_list) world_pts_gpu = torch.from_numpy(world_pts_np).float().to(device) # Batch projection on GPU world_pts_h_gpu = torch.cat((world_pts_gpu, torch.ones(num_visible_points, 1, device=device)), dim=1) P_world_to_cam_gpu = torch.hstack((R_gpu, t_gpu)) cam_coords_proj_gpu = P_world_to_cam_gpu @ world_pts_h_gpu.T cam_coords_z_gpu = cam_coords_proj_gpu[2, :] in_front_mask_gpu = cam_coords_z_gpu > 1e-6 pixel_coords_h_gpu = K_gpu @ cam_coords_proj_gpu u_proj_gpu = torch.full_like(cam_coords_z_gpu, -1.0, dtype=torch.float32) v_proj_gpu = torch.full_like(cam_coords_z_gpu, -1.0, dtype=torch.float32) # Avoid division by zero/small numbers for points not truly in front or on optical center valid_depth_mask_gpu = in_front_mask_gpu & (torch.abs(cam_coords_z_gpu) > 1e-6) if torch.any(valid_depth_mask_gpu): u_proj_gpu[valid_depth_mask_gpu] = pixel_coords_h_gpu[0, valid_depth_mask_gpu] / cam_coords_z_gpu[valid_depth_mask_gpu] v_proj_gpu[valid_depth_mask_gpu] = pixel_coords_h_gpu[1, valid_depth_mask_gpu] / cam_coords_z_gpu[valid_depth_mask_gpu] u_rounded_gpu = torch.round(u_proj_gpu).long() v_rounded_gpu = torch.round(v_proj_gpu).long() is_in_bounds_gpu = (u_rounded_gpu >= 0) & (u_rounded_gpu < W) & \ (v_rounded_gpu >= 0) & (v_rounded_gpu < H) & \ in_front_mask_gpu # Re-check in_front_mask_gpu as rounding might affect edge cases slightly # Sample ADE and Gestalt on GPU for points in bounds # Initialize with default values for all points, then update for those in bounds sampled_ade_status_gpu = torch.zeros(num_visible_points, dtype=torch.bool, device=device) sampled_gestalt_values_gpu = torch.zeros(num_visible_points, 3, dtype=torch.uint8, device=device) # Create a mask for points that are valid for sampling (in_bounds and in_front) valid_for_sampling_mask_gpu = is_in_bounds_gpu if torch.any(valid_for_sampling_mask_gpu): u_sample_gpu = u_rounded_gpu[valid_for_sampling_mask_gpu] v_sample_gpu = v_rounded_gpu[valid_for_sampling_mask_gpu] sampled_ade_status_gpu[valid_for_sampling_mask_gpu] = ade_mask_gpu[v_sample_gpu, u_sample_gpu] sampled_gestalt_values_gpu[valid_for_sampling_mask_gpu] = gestalt_seg_gpu[v_sample_gpu, u_sample_gpu] # Transfer necessary results back to CPU for accumulation u_rounded_cpu = u_rounded_gpu.cpu().numpy() v_rounded_cpu = v_rounded_gpu.cpu().numpy() is_in_bounds_cpu = is_in_bounds_gpu.cpu().numpy() # Use the original is_in_bounds_gpu for logic sampled_ade_status_cpu = sampled_ade_status_gpu.cpu().numpy() sampled_gestalt_values_cpu = sampled_gestalt_values_gpu.cpu().numpy() # Update accumulator (on CPU) for i in range(num_visible_points): pid = visible_pids_in_img[i] if pid not in point_data_accumulator: point_data_accumulator[pid] = { 'xyz': colmap_points_data_cpu[pid]['xyz'], 'color': colmap_points_data_cpu[pid]['color'], 'imgs_seen_by': [], 'uv_projections': [], 'ade_count': 0, # Count of times seen in ADE segmentation 'gestalt_values': [] } acc = point_data_accumulator[pid] acc['imgs_seen_by'].append(img_id) acc['uv_projections'].append((u_rounded_cpu[i], v_rounded_cpu[i])) if is_in_bounds_cpu[i]: # This point was projected within bounds and in front if sampled_ade_status_cpu[i]: acc['ade_count'] += 1 acc['gestalt_values'].append(sampled_gestalt_values_cpu[i]) else: # Point projected out of bounds, behind, or failed depth check acc['gestalt_values'].append(np.array([0,0,0], dtype=np.uint8)) # Optional: clear GPU cache if memory is a concern for many images # if device.type == 'cuda': # torch.cuda.empty_cache() # 3. Final data assembly (on CPU) points_xyz_world_list = [] points_colors_list = [] points_idxs_list = [] points_imgs_seen_by_list = [] points_uv_projections_per_point_list = [] points_ade_count_final_list = [] points_gestalt_values_per_point_list = [] # Ensure consistent order if downstream code relies on it, though original didn't specify sorting for pids # Using sorted_pids for reproducibility if point_data_accumulator keys order changes. sorted_pids = sorted(point_data_accumulator.keys()) for pid in sorted_pids: data = point_data_accumulator[pid] points_xyz_world_list.append(data['xyz']) points_colors_list.append(data['color']) points_idxs_list.append(pid) points_imgs_seen_by_list.append(data['imgs_seen_by']) points_uv_projections_per_point_list.append(data['uv_projections']) points_ade_count_final_list.append(data['ade_count']) points_gestalt_values_per_point_list.append(data['gestalt_values']) points_xyz_world = np.array(points_xyz_world_list) if points_xyz_world_list else np.empty((0, 3)) points_colors = np.array(points_colors_list) if points_colors_list else np.empty((0, 3)) points_idxs = np.array(points_idxs_list, dtype=int) if points_idxs_list else np.empty((0,), dtype=int) # Ensure dtype for pids points_ade = np.array(points_ade_count_final_list, dtype=int) if points_ade_count_final_list else np.empty((0,), dtype=int) output_all_colmap_img_ids = [img_obj.name for img_obj_name, img_obj in colmap_rec.images.items()] output_frame_K, output_frame_R, output_frame_t = [], [], [] for img_id_val in frame['image_ids']: if img_id_val in frame_img_data: data = frame_img_data[img_id_val] output_frame_K.append(data['K_np']) output_frame_R.append(data['R_np']) output_frame_t.append(data['t_np']) if points_xyz_world.shape[0] > 0: colmap_points_7d = np.zeros((points_xyz_world.shape[0], 7)) colmap_points_7d[:, :3] = points_xyz_world colmap_points_7d[:, 3:6] = points_colors colmap_points_7d[:, 6] = points_idxs whole_pcloud = { 'points_7d': colmap_points_7d, 'imgs': points_imgs_seen_by_list, 'uv': points_uv_projections_per_point_list, 'all_imgs_ids': output_all_colmap_img_ids, 'all_imgs_K': output_frame_K, 'all_imgs_R': output_frame_R, 'all_imgs_t': output_frame_t, 'ade': points_ade, 'gestalt': points_gestalt_values_per_point_list } else: whole_pcloud = { 'points_7d': np.empty((0, 7)), 'imgs': [], 'uv': [], 'all_imgs_ids': output_all_colmap_img_ids, 'all_imgs_K': output_frame_K, 'all_imgs_R': output_frame_R, 'all_imgs_t': output_frame_t, 'ade': np.empty((0,), dtype=int), 'gestalt': [] } return whole_pcloud def predict_wireframe(entry, pnet_model, voxel_model, pnet_class_model, config) -> Tuple[np.ndarray, List[int]]: """ Predict 3D wireframe from a dataset entry. """ device = 'cuda' if torch.cuda.is_available() else 'cpu' good_entry = convert_entry_to_human_readable(entry) colmap_rec = good_entry['colmap_binary'] start_time = time.time() colmap_pcloud = create_pcloud(colmap_rec, good_entry) print(f"Time for create_pcloud: {time.time() - start_time:.4f} seconds") vertex_threshold = config.get('vertex_threshold', 0.5) edge_threshold = config.get('edge_threshold', 0.5) only_predicted_connections = config.get('only_predicted_connections', False) vert_edge_per_image = {} idxs_points = [] all_connections = [] our_get_vertices_time_total = 0 for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(zip(good_entry['gestalt'], good_entry['depth'], good_entry['K'], good_entry['R'], good_entry['t'], good_entry['image_ids'], good_entry['ade'] # Added ade20k segmentation )): # Visualize gestalt segmentation K = np.array(K) R = np.array(R) t = np.array(t) # Resize gestalt segmentation to match depth map size depth_size = (np.array(depth).shape[1], np.array(depth).shape[0]) # W, H gest_seg = gest.resize(depth_size) gest_seg_np = np.array(gest_seg).astype(np.uint8) start_time_loop = time.time() vertices_ours, connections_ours, vertices_3d_ours, patches, filtered_point_idxs = our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id, ade_seg, depth, K=K, R=R, t=t, frame=good_entry) our_get_vertices_time_total += (time.time() - start_time_loop) idxs_points.append(filtered_point_idxs) all_connections.append(connections_ours) vertices, connections, vertices_3d = vertices_ours, connections_ours, vertices_3d_ours vert_edge_per_image[i] = vertices, connections, vertices_3d print(f"Total time for our_get_vertices_and_edges loop: {our_get_vertices_time_total:.4f} seconds") start_time = time.time() extracted_points, extracted_colors, extracted_ids, whole_pcloud, connections = extract_vertices_from_whole_pcloud_v2(colmap_pcloud, idxs_points, all_connections) print(f"Time for extract_vertices_from_whole_pcloud_v2: {time.time() - start_time:.4f} seconds") wf_vertices = good_entry.get('wf_vertices', None) start_time = time.time() patches = generate_patches_v3(extracted_points, extracted_colors, extracted_ids, whole_pcloud, wf_vertices) print(f"Time for generate_patches_v3: {time.time() - start_time:.4f} seconds") if GENERATE_DATASET: start_time = time.time() save_patches_dataset(patches, DATASET_DIR, img_id) print(f"Time for save_patches_dataset: {time.time() - start_time:.4f} seconds") return empty_solution() predicted_vertices = [] predict_vertex_time_total = 0 for i, patch in enumerate(patches): start_time_loop = time.time() pred_vertex, pred_dist, pred_class = predict_vertex_from_patch(pnet_model, patch, device=device) predict_vertex_time_total += (time.time() - start_time_loop) if pred_class > vertex_threshold: predicted_vertices.append(pred_vertex) else: predicted_vertices.append(np.array([0.0, 0.0, 0.0])) # Append a zero vertex if not predicted print(f"Total time for predict_vertex_from_patch loop: {predict_vertex_time_total:.4f} seconds") predicted_vertices = np.array(predicted_vertices) if predicted_vertices else np.empty((0, 3)) # Filter out zero vertices and update connections accordingly non_zero_mask = ~np.all(np.isclose(predicted_vertices, [0.0, 0.0, 0.0]), axis=1) valid_indices = np.where(non_zero_mask)[0] # Filter vertices to only include non-zero ones filtered_vertices = predicted_vertices[valid_indices] if GENERATE_DATASET_EDGES: start_time = time.time() edge_patches = generate_edge_patches(good_entry, filtered_vertices, colmap_pcloud) print(f"Time for generate_edge_patches: {time.time() - start_time:.4f} seconds") start_time = time.time() save_patches_dataset_class(edge_patches, EDGES_DATASET_DIR, good_entry['order_id']) print(f"Time for save_patches_dataset_class: {time.time() - start_time:.4f} seconds") return empty_solution() if len(valid_indices) == 0: print("No valid predicted vertices found") return empty_solution() # Create mapping from old indices to new indices old_to_new_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(valid_indices)} # Filter and update connections filtered_connections = [] for start_idx, end_idx in connections: if start_idx in old_to_new_mapping and end_idx in old_to_new_mapping: new_start = old_to_new_mapping[start_idx] new_end = old_to_new_mapping[end_idx] if new_start != new_end: # Ensure we don't connect a vertex to itself filtered_connections.append((new_start, new_end)) start_time = time.time() #forward_patches = generate_edge_patches_forward_10d(good_entry, filtered_vertices, colmap_pcloud) forward_patches = generate_edge_patches_forward(good_entry, filtered_vertices) print(f"Time for generate_edge_patches_forward: {time.time() - start_time:.4f} seconds") new_connections = [] predict_class_time_total = 0 if len(forward_patches) > 0: for i, patch in enumerate(forward_patches): start_idx, end_idx = patch['connection'] start_time_loop = time.time() pred_class, pred_score = predict_class_from_patch(pnet_class_model, patch, device=device) predict_class_time_total += (time.time() - start_time_loop) if pred_score > edge_threshold: new_connections.append((start_idx, end_idx)) print(f"Total time for predict_class_from_patch loop: {predict_class_time_total:.4f} seconds") predicted_vertices = np.array(filtered_vertices) if only_predicted_connections: connections = new_connections else: connections = filtered_connections + new_connections # Remove duplicates from connections connections = list(set(connections)) return predicted_vertices, connections def predict_wireframe_old(entry) -> Tuple[np.ndarray, List[int]]: """ Predict 3D wireframe from a dataset entry. """ good_entry = convert_entry_to_human_readable(entry) vert_edge_per_image = {} for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(zip(good_entry['gestalt'], good_entry['depth'], good_entry['K'], good_entry['R'], good_entry['t'], good_entry['image_ids'], good_entry['ade'] # Added ade20k segmentation )): colmap_rec = good_entry['colmap_binary'] K = np.array(K) R = np.array(R) t = np.array(t) # Resize gestalt segmentation to match depth map size depth_size = (np.array(depth).shape[1], np.array(depth).shape[0]) # W, H gest_seg = gest.resize(depth_size) gest_seg_np = np.array(gest_seg).astype(np.uint8) # Get 2D vertices and edges first vertices, connections = get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th=25.) # Check if we have enough to proceed if (len(vertices) < 2) or (len(connections) < 1): print(f'Not enough vertices or connections found in image {i}, skipping.') vert_edge_per_image[i] = [], [], np.empty((0, 3)) continue # Call the refactored function to get 3D points vertices_3d = create_3d_wireframe_single_image( vertices, connections, depth, colmap_rec, img_id, ade_seg, K, R, t ) # Store original 2D vertices, connections, and computed 3D points vert_edge_per_image[i] = vertices, connections, vertices_3d # Merge vertices from all images all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 0.5) all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d, keep_largest=False) all_3d_vertices_clean, connections_3d_clean = prune_too_far(all_3d_vertices_clean, connections_3d_clean, colmap_rec, th = 1.5) if (len(all_3d_vertices_clean) < 2) or len(connections_3d_clean) < 1: print (f'Not enough vertices or connections in the 3D vertices') return empty_solution() return all_3d_vertices_clean, connections_3d_clean def generate_patches_v2(extracted_points, extracted_colors, extracted_ids, whole_pcloud, wf_vertices): patches = [] whole_points = whole_pcloud['points'] whole_colors = whole_pcloud['colors'] whole_ids = whole_pcloud['ids'] wf_vertices = np.array(wf_vertices) if wf_vertices is not None else np.empty((0, 3)) for cluster_idx, (cluster_points, cluster_colors, cluster_ids) in enumerate(zip(extracted_points, extracted_colors, extracted_ids)): if len(cluster_points) == 0: continue # Calculate center as mean of cluster points cluster_center = np.mean(cluster_points, axis=0) # Define cube edge length cube_edge_length = 4.0 half_edge = cube_edge_length / 2.0 # Find points within cube bounds within_cube_mask = ( (whole_points[:, 0] >= cluster_center[0] - half_edge) & (whole_points[:, 0] <= cluster_center[0] + half_edge) & (whole_points[:, 1] >= cluster_center[1] - half_edge) & (whole_points[:, 1] <= cluster_center[1] + half_edge) & (whole_points[:, 2] >= cluster_center[2] - half_edge) & (whole_points[:, 2] <= cluster_center[2] + half_edge) ) if not np.any(within_cube_mask): continue # Extract points within cube cube_points = whole_points[within_cube_mask] cube_colors = whole_colors[within_cube_mask] cube_point_ids = whole_ids[within_cube_mask] # Shift points to center at origin cube_points_centered = cube_points - cluster_center # Create 7D point cloud patch_7d = np.zeros((len(cube_points_centered), 7)) patch_7d[:, :3] = cube_points_centered # xyz coordinates centered at origin patch_7d[:, 3:6] = cube_colors * 2.0 - 1.0 # rgb colors normalized to [-1, 1] # Set flag: 1 if point is in current cluster, -1 otherwise cluster_ids_set = set(cluster_ids) for i, pid in enumerate(cube_point_ids): if pid in cluster_ids_set: patch_7d[i, 6] = 1.0 else: patch_7d[i, 6] = -1.0 # Find closest wf_vertex to cluster center assigned_wf_vertex = None if len(wf_vertices) > 0: # Calculate distances from cluster center to all GT vertices distances_to_gt = np.linalg.norm(wf_vertices - cluster_center, axis=1) # Find GT vertices within 1 meter of cluster center within_radius_mask = distances_to_gt <= 1.0 if np.any(within_radius_mask): # Find the closest GT vertex within 1 meter closest_idx = np.argmin(distances_to_gt[within_radius_mask]) # Get the actual index in the original array valid_indices = np.where(within_radius_mask)[0] actual_closest_idx = valid_indices[closest_idx] # Shift the assigned GT vertex to be relative to origin assigned_wf_vertex = wf_vertices[actual_closest_idx] - cluster_center patch = { 'patch_7d': patch_7d, 'cluster_center': cluster_center, 'cube_edge_length': cube_edge_length, 'cluster_idx': cluster_idx, 'assigned_wf_vertex': assigned_wf_vertex, 'cube_point_ids': cube_point_ids, 'cluster_point_ids': cluster_ids } patches.append(patch) # Visualize the patch using PyVista if False: # Set to False to disable visualization plotter = pv.Plotter() # Create point cloud for this patch patch_cloud = pv.PolyData(cube_points_centered) # Color points based on cluster membership flag patch_colors = [] for i in range(len(cube_points_centered)): if patch_7d[i, 6] == 1.0: # Point is in cluster patch_colors.append([1.0, 0.0, 0.0]) # Red for cluster points else: patch_colors.append([0.0, 0.0, 1.0]) # Blue for other points patch_cloud["colors"] = np.array(patch_colors) plotter.add_mesh(patch_cloud, scalars="colors", rgb=True, point_size=8, render_points_as_spheres=True) # Add cube wireframe to show extraction bounds cube_bounds = [ -half_edge, half_edge, # x_min, x_max -half_edge, half_edge, # y_min, y_max -half_edge, half_edge # z_min, z_max ] cube_wireframe = pv.Box(bounds=cube_bounds) plotter.add_mesh(cube_wireframe, style='wireframe', color='gray', line_width=2) # Add sphere for assigned GT vertex if available if assigned_wf_vertex is not None: gt_sphere = pv.Sphere(radius=0.1, center=assigned_wf_vertex) plotter.add_mesh(gt_sphere, color="green", opacity=0.7) # Add sphere at origin to show patch center origin_sphere = pv.Sphere(radius=0.05, center=[0, 0, 0]) plotter.add_mesh(origin_sphere, color="yellow", opacity=0.8) plotter.show(title=f"Patch {cluster_idx} - Edge length: {cube_edge_length}m") return patches def generate_patches_v3(extracted_points, extracted_colors, extracted_ids, whole_pcloud, wf_vertices): patches = [] whole_points = whole_pcloud['points'] whole_colors = whole_pcloud['colors'] # Now 7D: [r, g, b, ade, gestalt_r, gestalt_g, gestalt_b] whole_ids = whole_pcloud['ids'] wf_vertices = np.array(wf_vertices) if wf_vertices is not None else np.empty((0, 3)) for cluster_idx, (cluster_points, cluster_colors, cluster_ids) in enumerate(zip(extracted_points, extracted_colors, extracted_ids)): if len(cluster_points) == 0: continue # Calculate center as mean of cluster points cluster_center = np.mean(cluster_points, axis=0) # Define cube edge length cube_edge_length = 8.0 half_edge = cube_edge_length / 2.0 # Find points within cube bounds within_cube_mask = ( (whole_points[:, 0] >= cluster_center[0] - half_edge) & (whole_points[:, 0] <= cluster_center[0] + half_edge) & (whole_points[:, 1] >= cluster_center[1] - half_edge) & (whole_points[:, 1] <= cluster_center[1] + half_edge) & (whole_points[:, 2] >= cluster_center[2] - half_edge) & (whole_points[:, 2] <= cluster_center[2] + half_edge) ) if not np.any(within_cube_mask): continue # Extract points within cube cube_points = whole_points[within_cube_mask] cube_colors_7d = whole_colors[within_cube_mask] # Now 7D colors cube_point_ids = whole_ids[within_cube_mask] # Shift points to center at origin cube_points_centered = cube_points - cluster_center # Create 10D point cloud: [x, y, z, r, g, b, ade, gestalt_r, gestalt_g, gestalt_b] patch_10d = np.zeros((len(cube_points_centered), 10)) patch_10d[:, :3] = cube_points_centered # xyz coordinates centered at origin patch_10d[:, 3:6] = cube_colors_7d[:, :3] * 2.0 - 1.0 # rgb colors normalized to [-1, 1] patch_10d[:, 6] = cube_colors_7d[:, 3] * 2.0 - 1.0 # ade feature normalized to [-1, 1] patch_10d[:, 7:10] = cube_colors_7d[:, 4:7] * 2.0 - 1.0 # gestalt colors normalized to [-1, 1] # Set flag: 1 if point is in current cluster, -1 otherwise cluster_ids_set = set(cluster_ids) cluster_flag = np.full(len(cube_point_ids), -1.0) for i, pid in enumerate(cube_point_ids): if pid in cluster_ids_set: cluster_flag[i] = 1.0 # Add cluster flag as 11th dimension patch_11d = np.zeros((len(cube_points_centered), 11)) patch_11d[:, :10] = patch_10d patch_11d[:, 10] = cluster_flag # Find closest wf_vertex to cluster center assigned_wf_vertex = None if len(wf_vertices) > 0: # Calculate distances from cluster center to all GT vertices distances_to_gt = np.linalg.norm(wf_vertices - cluster_center, axis=1) # Find GT vertices within 1 meter of cluster center within_radius_mask = distances_to_gt <= 1.0 if np.any(within_radius_mask): # Find the closest GT vertex within 1 meter closest_idx = np.argmin(distances_to_gt[within_radius_mask]) # Get the actual index in the original array valid_indices = np.where(within_radius_mask)[0] actual_closest_idx = valid_indices[closest_idx] # Shift the assigned GT vertex to be relative to origin assigned_wf_vertex = wf_vertices[actual_closest_idx] - cluster_center patch = { 'patch_11d': patch_11d, # Changed from patch_7d to patch_11d 'cluster_center': cluster_center, 'cube_edge_length': cube_edge_length, 'cluster_idx': cluster_idx, 'assigned_wf_vertex': assigned_wf_vertex, 'cube_point_ids': cube_point_ids, 'cluster_point_ids': cluster_ids } patches.append(patch) # Visualize the patch using PyVista if False: # Set to False to disable visualization plotter = pv.Plotter() # Create point cloud for this patch patch_cloud = pv.PolyData(cube_points_centered) # Color points based on cluster membership flag patch_colors = [] for i in range(len(cube_points_centered)): if patch_11d[i, 10] == 1.0: # Point is in cluster patch_colors.append([1.0, 0.0, 0.0]) # Red for cluster points else: patch_colors.append([0.0, 0.0, 1.0]) # Blue for other points patch_cloud["colors"] = np.array(patch_colors) plotter.add_mesh(patch_cloud, scalars="colors", rgb=True, point_size=8, render_points_as_spheres=True) # Add cube wireframe to show extraction bounds cube_bounds = [ -half_edge, half_edge, # x_min, x_max -half_edge, half_edge, # y_min, y_max -half_edge, half_edge # z_min, z_max ] cube_wireframe = pv.Box(bounds=cube_bounds) plotter.add_mesh(cube_wireframe, style='wireframe', color='gray', line_width=2) # Add sphere for assigned GT vertex if available if assigned_wf_vertex is not None: gt_sphere = pv.Sphere(radius=0.1, center=assigned_wf_vertex) plotter.add_mesh(gt_sphere, color="green", opacity=0.7) # Add sphere at origin to show patch center origin_sphere = pv.Sphere(radius=0.05, center=[0, 0, 0]) plotter.add_mesh(origin_sphere, color="yellow", opacity=0.8) plotter.show(title=f"Patch {cluster_idx} - Edge length: {cube_edge_length}m") return patches def get_visible_points(colmap_rec, img_id_substring, R=None, t=None): # 1) Find the matching COLMAP image to get its associated 3D points # This part remains to identify which 3D points are relevant for this image view found_img = None for img_id_c, col_img_obj in colmap_rec.images.items(): # Renamed col_img to col_img_obj to avoid conflict if img_id_substring in col_img_obj.name: found_img = col_img_obj break if found_img is None: print(f"Image substring {img_id_substring} not found in COLMAP.") return [], [], [] # 2) Gather 3D points that this image sees (according to COLMAP) points_xyz_world = [] points_idxs = [] for pid, p3D in colmap_rec.points3D.items(): if found_img.has_point3D(pid): points_xyz_world.append(p3D.xyz) # world coords points_idxs.append(pid) if not points_xyz_world: print(f"No 3D points associated with {found_img.name} in COLMAP.") return [], [], [] points_xyz_world = np.array(points_xyz_world) # (N, 3) points_idxs = np.array(points_idxs) # (N,) points_xyz_world_h = np.hstack((points_xyz_world, np.ones((points_xyz_world.shape[0], 1)))) # (N, 4) # World to Camera transformation matrix world_to_cam_mat = np.eye(4) world_to_cam_mat[:3, :3] = R world_to_cam_mat[:3, 3] = t.flatten() points_cam_h = (world_to_cam_mat @ points_xyz_world_h.T).T # (N, 4) points_cam = points_cam_h[:, :3] / points_cam_h[:, 3, np.newaxis] # (N, 3) in camera coordinates return points_cam, points_xyz_world, points_idxs def project_points_to_2d(points_cam, K, H, W): uv = [] valid_indices = [] # Track which original points are valid for i in range(points_cam.shape[0]): p_cam = points_cam[i] # Ensure p_cam[2] (depth) is positive if p_cam[2] <= 0: continue # Project to image plane using K u_i = (K[0, 0] * p_cam[0] / p_cam[2]) + K[0, 2] v_i = (K[1, 1] * p_cam[1] / p_cam[2]) + K[1, 2] u_i_int = int(round(u_i)) v_i_int = int(round(v_i)) # Check in-bounds if 0 <= u_i_int < W and 0 <= v_i_int < H: uv.append((u_i_int, v_i_int)) valid_indices.append(i) # Store original index uv = np.array(uv, dtype=int) # shape (M,2) valid_indices = np.array(valid_indices) # shape (M,) return uv, valid_indices def project_points_to_2d_colmap(points_xyz_world, found_img, H, W): uv_colmap = [] valid_indices_colmap = [] for i, xyz in enumerate(points_xyz_world): proj = found_img.project_point(xyz) # returns (u, v) in image coords or None if proj is not None: u_i, v_i = proj u_i = int(round(u_i)) v_i = int(round(v_i)) # Check in-bounds if 0 <= u_i < W and 0 <= v_i < H: uv_colmap.append((u_i, v_i)) valid_indices_colmap.append(i) # Store original index uv_colmap = np.array(uv_colmap, dtype=int) valid_indices_colmap = np.array(valid_indices_colmap) return uv_colmap, valid_indices_colmap def get_apex_or_eave_points(type, uv, gest_seg_np, house_mask, valid_indices, points_xyz_world, points_cam, points_idxs): # Apex if type == 'apex': apex_color = np.array(gestalt_color_mapping['apex']) elif type == 'eave_end': apex_color = np.array(gestalt_color_mapping['eave_end_point']) elif type == 'flashing_end_point': apex_color = np.array(gestalt_color_mapping['flashing_end_point']) apex_mask = cv2.inRange(gest_seg_np, apex_color-10., apex_color+10.) filtered_points_xyz = [] filtered_point_idxs = [] filtered_points_color = [] filtered_vertices_apex = [] filtered_vertices_apex_uv = [] if apex_mask.sum() > 0: output = cv2.connectedComponentsWithStats(apex_mask, 8, cv2.CV_32S) (numLabels, labels, stats, centroids) = output for i in range(1, numLabels): cur_mask = labels == i # Dilate the current mask to make it slightly larger kernel = np.ones((5,5), np.uint8) cur_mask = cv2.dilate(cur_mask.astype(np.uint8), kernel, iterations=2).astype(bool) color = np.random.rand(3) # Create boolean mask for points in current apex mask and house mask valid_points_mask = cur_mask[uv[:, 1], uv[:, 0]] & house_mask[uv[:, 1], uv[:, 0]] for z in range(5): if np.sum(valid_points_mask) < 5: cur_mask = cv2.dilate(cur_mask.astype(np.uint8), kernel, iterations=1).astype(bool) valid_points_mask = cur_mask[uv[:, 1], uv[:, 0]] & house_mask[uv[:, 1], uv[:, 0]] else: break if np.any(valid_points_mask): # Get indices of valid points valid_point_indices = valid_indices[valid_points_mask] # Get 3D points in camera coordinates for depth filtering valid_world_points = points_xyz_world[valid_point_indices] valid_cam_points = points_cam[valid_point_indices] # Compute depths (Z coordinates in camera space) depths = valid_cam_points[:, 2] # Find minimum depth and filter points within min_depth + 2 meters if len(depths) > 0: min_depth = np.min(depths) depth_filter = depths <= (min_depth + 2.0) # Apply depth filter final_valid_indices = valid_point_indices[depth_filter] # Only add if we have valid points after depth filtering if len(final_valid_indices) > 0: # Add corresponding points to filtered lists filtered_points_xyz.append(points_xyz_world[final_valid_indices]) filtered_point_idxs.append(points_idxs[final_valid_indices]) filtered_points_color.append([color] * np.sum(depth_filter)) # Find the point with lowest depth in the filtered points lowest_depth_idx = np.argmin(depths[depth_filter]) lowest_depth_point = final_valid_indices[lowest_depth_idx] filtered_vertices_apex.append(points_xyz_world[lowest_depth_point]) filtered_vertices_apex_uv.append(centroids[i]) return filtered_points_xyz, filtered_point_idxs, filtered_points_color, filtered_vertices_apex, filtered_vertices_apex_uv def get_vertexes(uv, gest_seg_np, house_mask, valid_indices, points_xyz_world, points_cam, points_idxs): filtered_points_xyz_apex, filtered_point_idxs_apex, filtered_points_color_apex, filtered_vertices_apex, filtered_vertices_apex_uv = get_apex_or_eave_points('apex', uv, gest_seg_np, house_mask, valid_indices, points_xyz_world, points_cam, points_idxs) filtered_points_xyz_eave, filtered_point_idxs_eave, filtered_points_color_eave, filtered_vertices_eave, filtered_vertices_eave_uv = get_apex_or_eave_points('eave_end', uv, gest_seg_np, house_mask, valid_indices, points_xyz_world, points_cam, points_idxs) filtered_points_xyz_flashing, filtered_point_idxs_flashing, filtered_points_color_flashing, filtered_vertices_flashing, filtered_vertices_flashing_uv = get_apex_or_eave_points('flashing_end_point', uv, gest_seg_np, house_mask, valid_indices, points_xyz_world, points_cam, points_idxs) #print(len(filtered_points_xyz_apex), len(filtered_points_xyz_eave), len(filtered_vertices_apex), len(filtered_vertices_eave), len(filtered_point_idxs_apex), len(filtered_point_idxs_eave)) # Combine filtered points from apex, eave_end, and flashing_end_point filtered_points_xyz = filtered_points_xyz_apex + filtered_points_xyz_eave + filtered_points_xyz_flashing filtered_point_idxs = filtered_point_idxs_apex + filtered_point_idxs_eave + filtered_point_idxs_flashing filtered_points_color = filtered_points_color_apex + filtered_points_color_eave + filtered_points_color_flashing #filtered_points_xyz = np.array(filtered_points_xyz[::-1]) if filtered_points_xyz else np.empty((0, 3)) #filtered_point_idxs = np.array(filtered_point_idxs[::-1]) if filtered_point_idxs else np.empty((0,)) #filtered_points_color = np.array(filtered_points_color[::-1]) if filtered_points_color else np.empty((0, 3)) filtered_vertices_apex = np.array(filtered_vertices_apex) if filtered_vertices_apex else np.empty((0, 3)) filtered_vertices_apex_uv = np.array(filtered_vertices_apex_uv) if filtered_vertices_apex_uv else np.empty((0, 2)) filtered_vertices_eave = np.array(filtered_vertices_eave) if filtered_vertices_eave else np.empty((0, 3)) filtered_vertices_eave_uv = np.array(filtered_vertices_eave_uv) if filtered_vertices_eave_uv else np.empty((0, 2)) filtered_vertices_flashing = np.array(filtered_vertices_flashing) if filtered_vertices_flashing else np.empty((0, 3)) filtered_vertices_flashing_uv = np.array(filtered_vertices_flashing_uv) if filtered_vertices_flashing_uv else np.empty((0, 2)) #print(len(filtered_points_xyz), len(filtered_point_idxs), len(filtered_vertices_apex), len(filtered_vertices_apex_uv), len(filtered_vertices_eave), len(filtered_vertices_eave_uv)) return filtered_points_xyz, filtered_point_idxs, filtered_points_color, filtered_vertices_apex, filtered_vertices_apex_uv, filtered_vertices_eave, filtered_vertices_eave_uv, filtered_vertices_flashing, filtered_vertices_flashing_uv def get_connections(gest_seg_np, filtered_vertices_apex, filtered_vertices_eave, filtered_vertices_apex_uv, filtered_vertices_eave_uv): connections = [] edge_classes = ['eave', 'ridge', 'rake', 'valley'] edge_th = 25.0 # threshold for proximity to line segments # Combine apex and eave_end vertices and their UV coordinates all_vertices_3d = [] all_vertices_uv = [] vertex_types = [] # Add apex vertices for i, (vertex_3d, vertex_uv) in enumerate(zip(filtered_vertices_apex, filtered_vertices_apex_uv)): all_vertices_3d.append(vertex_3d) all_vertices_uv.append(vertex_uv) vertex_types.append('apex') # Add eave_end vertices for i, (vertex_3d, vertex_uv) in enumerate(zip(filtered_vertices_eave, filtered_vertices_eave_uv)): all_vertices_3d.append(vertex_3d) all_vertices_uv.append(vertex_uv) vertex_types.append('eave_end') all_vertices_3d = np.array(all_vertices_3d) all_vertices_uv = np.array(all_vertices_uv) if len(all_vertices_uv) < 2: vertices_formatted = [] for uv, vertex_type in zip(all_vertices_uv, vertex_types): vertices_formatted.append({ 'xy': np.array(uv, dtype=float), 'type': vertex_type }) return vertices_formatted, [], all_vertices_3d for edge_class in edge_classes: edge_color = np.array(gestalt_color_mapping[edge_class]) mask_raw = cv2.inRange(gest_seg_np, edge_color-10, edge_color+10) # Morphological operations to clean up the mask kernel = np.ones((5, 5), np.uint8) mask = cv2.morphologyEx(mask_raw, cv2.MORPH_CLOSE, kernel) if mask.sum() == 0: continue # Connected components output = cv2.connectedComponentsWithStats(mask, 8, cv2.CV_32S) (numLabels, labels, stats, centroids) = output # Skip the background stats, centroids = stats[1:], centroids[1:] label_indices = range(1, numLabels) # For each connected component, do a line fit for lbl in label_indices: ys, xs = np.where(labels == lbl) if len(xs) < 2: continue # Fit a line using cv2.fitLine pts_for_fit = np.column_stack([xs, ys]).astype(np.float32) line_params = cv2.fitLine(pts_for_fit, distType=cv2.DIST_L2, param=0, reps=0.01, aeps=0.01) vx, vy, x0, y0 = line_params.ravel() # Find line segment endpoints by projecting points onto the line proj = ((xs - x0)*vx + (ys - y0)*vy) proj_min, proj_max = proj.min(), proj.max() p1 = np.array([x0 + proj_min*vx, y0 + proj_min*vy]) p2 = np.array([x0 + proj_max*vx, y0 + proj_max*vy]) # Find vertices that are close to this line segment if len(all_vertices_uv) < 2: continue # Calculate distance from each vertex UV to the line segment dists = [] for vertex_uv in all_vertices_uv: dist = point_to_segment_dist(vertex_uv, p1, p2) dists.append(dist) dists = np.array(dists) # Find vertices that are near this line segment near_mask = (dists <= edge_th) near_indices = np.where(near_mask)[0] if len(near_indices) < 2: continue # Connect each pair among these near vertices for i in range(len(near_indices)): for j in range(i+1, len(near_indices)): idx_a = near_indices[i] idx_b = near_indices[j] # Create connection tuple (using sorted indices for consistency) conn = tuple(sorted((idx_a, idx_b))) if conn not in connections: connections.append(conn) # Convert all_vertices_uv and vertex_types to the required format vertices_formatted = [] for uv, vertex_type in zip(all_vertices_uv, vertex_types): vertices_formatted.append({ 'xy': np.array(uv, dtype=float), 'type': vertex_type }) return vertices_formatted, connections, all_vertices_3d def visualize_3d_wireframe(colmap_rec, filtered_points_xyz, filtered_points_color, vertices_3d, connections): segmented_points_3d = [] # Visualize with the segmented depth points in blue pcd_all = o3d.geometry.PointCloud() pcd_filtered = o3d.geometry.PointCloud() pcd_depth = o3d.geometry.PointCloud() # All points in gray all_points = [] all_colors = [] for p3D in colmap_rec.points3D.values(): all_points.append(p3D.xyz) all_colors.append([0.5, 0.5, 0.5]) # Gray color if all_points: pcd_all.points = o3d.utility.Vector3dVector(np.array(all_points)) pcd_all.colors = o3d.utility.Vector3dVector(np.array(all_colors)) # Filtered COLMAP points in red if len(filtered_points_xyz) > 0: pcd_filtered.points = o3d.utility.Vector3dVector(filtered_points_xyz) pcd_filtered.colors = o3d.utility.Vector3dVector(np.array(filtered_points_color)) # Segmented depth points in blue if len(segmented_points_3d) > 0: pcd_depth.points = o3d.utility.Vector3dVector(segmented_points_3d) pcd_depth.colors = o3d.utility.Vector3dVector(np.full((len(segmented_points_3d), 3), [0.0, 0.0, 1.0])) # Visualize all point clouds and spheres geometries = [pcd_all] if len(filtered_points_xyz) > 0: geometries.append(pcd_filtered) if len(segmented_points_3d) > 0: geometries.append(pcd_depth) #o3d.visualization.draw_geometries(geometries, window_name=f"Combined Point Cloud - {img_id_substring}") def generate_patches(colmap_rec, filtered_points_idxs, frame, filtered_vertices, vertices_formatted): patches = [] gt_vertices = frame['wf_vertices'] # Process each group of filtered points for group_idx, point_idxs in enumerate(filtered_points_idxs): # Get 3D coordinates and colors for this group group_points_3d = [] group_colors = [] assigned_gt_vertex = None for pid in point_idxs: p3d = colmap_rec.points3D[pid] group_points_3d.append(p3d.xyz) group_colors.append(p3d.color) group_points_3d = np.array(group_points_3d) group_colors = np.array(group_colors) # Calculate centroid of filtered points # Find the closest GT vertex to the centroid of filtered points centroid = np.mean(group_points_3d, axis=0) if len(gt_vertices) > 0: # Calculate distances from centroid to all GT vertices distances_to_gt = [] for gt_vertex in gt_vertices: distance = np.linalg.norm(gt_vertex - centroid) distances_to_gt.append(distance) # Find the closest GT vertex min_distance_idx = np.argmin(distances_to_gt) closest_gt_vertex = gt_vertices[min_distance_idx] min_distance = distances_to_gt[min_distance_idx] # Define ball radius (you can adjust this value) ball_radius = 2.0 # meters # Use closest GT vertex as centroid if it's within the ball radius if min_distance <= ball_radius: assigned_gt_vertex = closest_gt_vertex # If no GT vertex is close enough, skip this group else: assigned_gt_vertex = None else: # No GT vertices available, use original centroid centroid = np.mean(group_points_3d, axis=0) # Define ball radius (you can adjust this value) ball_radius = 2.0 # meters # Find all COLMAP points within the ball around centroid patch_points_3d = [] patch_colors = [] patch_point_ids = [] for pid, p3d in colmap_rec.points3D.items(): distance = np.linalg.norm(p3d.xyz - centroid) if distance <= ball_radius: patch_points_3d.append(p3d.xyz) patch_colors.append(p3d.color) patch_point_ids.append(pid) patch_points_3d = np.array(patch_points_3d) # Calculate offset to center the patch patch_centroid = np.mean(patch_points_3d, axis=0) offset = -patch_centroid # Shift points to center them around origin patch_points_3d += offset # Also shift the assigned GT vertex by the same offset if it exists if assigned_gt_vertex is not None: assigned_gt_vertex = assigned_gt_vertex + offset patch_colors = np.array(patch_colors) # Create 7D point cloud for this patch # [x, y, z, r, g, b, in_filtered_flag] patch_7d = np.zeros((len(patch_points_3d), 7)) patch_7d[:, :3] = patch_points_3d # xyz coordinates patch_7d[:, 3:6] = patch_colors / 255.0 # rgb colors normalized to [0,1] # Set in_filtered_flag: 1 if point was in original filtered set, 0 otherwise for i, pid in enumerate(patch_point_ids): if pid in point_idxs: patch_7d[i, 6] = 1.0 else: patch_7d[i, 6] = -1.0 if len(filtered_vertices) > 0 and filtered_vertices[group_idx] is not None: initial_pred = filtered_vertices[group_idx] + offset else: initial_pred = None if vertices_formatted[group_idx] is not None: # Get the xy coordinates of the vertex vertex_class = vertices_formatted[group_idx]['type'] patches.append({ 'patch_7d': patch_7d, 'centroid': centroid, 'radius': ball_radius, 'point_ids': patch_point_ids, 'filtered_point_ids': point_idxs, 'group_idx': group_idx, 'assigned_gt_vertex': assigned_gt_vertex, 'offset': offset, 'initial_pred': initial_pred, 'vertex_class': vertex_class }) if False: # Create plotter plotter = pv.Plotter() # Create point cloud for this patch patch_cloud = pv.PolyData(patch_points_3d) # Color points: red for filtered points, blue for other points patch_point_colors = [] for i, pid in enumerate(patch_point_ids): if pid in point_idxs: patch_point_colors.append([255, 0, 0]) # Red for filtered points else: patch_point_colors.append([0, 0, 255]) # Blue for other points patch_cloud["colors"] = np.array(patch_point_colors) plotter.add_mesh(patch_cloud, scalars="colors", rgb=True, point_size=8, render_points_as_spheres=True) # Create sphere to visualize GT vertex if available if assigned_gt_vertex is not None: gt_sphere = pv.Sphere(radius=0.1, center=assigned_gt_vertex) plotter.add_mesh(gt_sphere, color="green", opacity=0.5) if initial_pred is not None: # Create sphere to visualize initial prediction pred_sphere = pv.Sphere(radius=0.1, center=initial_pred) plotter.add_mesh(pred_sphere, color="orange", opacity=0.5) plotter.show(title=f"Patch {group_idx}") return patches def our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id_substring, ade_seg, depth, K=None, R=None, t=None, frame=None): """ Identify apex and eave-end vertices, then detect lines for eave/ridge/rake/valley. Also find all COLMAP points that project into apex or eave_end masks. """ #-------------------------------------------------------------------------------- # Step A: Collect apex and eave_end vertices #-------------------------------------------------------------------------------- if not isinstance(gest_seg_np, np.ndarray): gest_seg_np = np.array(gest_seg_np) H, W = gest_seg_np.shape[:2] # Get camera parameters from COLMAP reconstruction if not provided if False: # Find the matching COLMAP image found_img = None for img_id_c, col_img_obj in colmap_rec.images.items(): if img_id_substring in col_img_obj.name: found_img = col_img_obj break if found_img is not None: # Get camera intrinsic matrix K = found_img.camera.calibration_matrix() # Get world-to-camera transformation matrix world_to_cam = found_img.cam_from_world.matrix() R = world_to_cam[:3, :3] t = world_to_cam[:3, 3] else: print(f"Image substring {img_id_substring} not found in COLMAP.") return [], [], [], [], [] points_cam, points_xyz_world, points_idxs = get_visible_points(colmap_rec, img_id_substring, R=R, t=t) uv, valid_indices = project_points_to_2d(points_cam, K, H, W) if len(uv) == 0: print(f"No points projected into image bounds for {img_id_substring} using K,R,t.") return [], [], [], [], [] house_mask = get_house_mask(ade_seg) filtered_points_xyz, filtered_point_idxs, filtered_points_color, filtered_vertices_apex, filtered_vertices_apex_uv, filtered_vertices_eave, filtered_vertices_eave_uv, _, _ = get_vertexes(uv, gest_seg_np, house_mask, valid_indices, points_xyz_world, points_cam, points_idxs) vertices_formatted, connections, all_vertices_3d = get_connections(gest_seg_np, filtered_vertices_apex, filtered_vertices_eave, filtered_vertices_apex_uv, filtered_vertices_eave_uv) #print(len(vertices_formatted), len(connections), len(all_vertices_3d)) #patches = generate_patches(colmap_rec, filtered_point_idxs, frame, all_vertices_3d, vertices_formatted) patches = None #visualize_3d_wireframe(colmap_rec, filtered_points_xyz, filtered_points_color, all_vertices_3d, connections) return vertices_formatted, connections, all_vertices_3d, patches, filtered_point_idxs