# hoho / predict.py
# (Hugging Face page residue: author jskvrna, commit 34ac397 "Small bug fixes.")
# This script is designed for 3D wireframe reconstruction, primarily focusing on
# buildings, using multi-view imagery and associated 3D data.
# It leverages COLMAP reconstructions, depth maps, and semantic segmentations
# (ADE20k and Gestalt) to identify and predict structural elements.
# Core tasks include:
# - Processing and aligning 2D image data (segmentations, depth) with 3D COLMAP point clouds.
# - Extracting initial 2D/3D vertex candidates from segmentation maps.
# - Generating local point cloud patches around these candidates.
# - Employing machine learning models (e.g., PointNet variants) to refine vertex locations
# and classify potential edges between them.
# - Optionally, generating datasets of these patches for training ML models.
# - Merging information from multiple views to produce a final 3D wireframe.
import numpy as np
from typing import Tuple, List
from hoho2025.example_solutions import empty_solution, read_colmap_rec, get_vertices_and_edges_from_segmentation, get_house_mask, fit_scale_robust_median, get_uv_depth, merge_vertices_3d, prune_not_connected, prune_too_far, point_to_segment_dist
from hoho2025.color_mappings import ade20k_color_mapping, gestalt_color_mapping
from PIL import Image, ImageDraw
#from visu import save_gestalt_with_proj, draw_crosses_on_image
import os
import pycolmap
from PIL import Image as PImage
import cv2
#import open3d as o3d
#from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local
#import pyvista as pv
#from fast_pointnet import save_patches_dataset, predict_vertex_from_patch
from fast_pointnet_v2 import save_patches_dataset, predict_vertex_from_patch
#from fast_voxel import predict_vertex_from_patch_voxel
#import time
from fast_pointnet_class import save_patches_dataset as save_patches_dataset_class
from fast_pointnet_class import predict_class_from_patch
#from fast_pointnet_class_10d import predict_class_from_patch as predict_class_from_patch_10d
from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment
import torch
import time
from collections import Counter
# Toggle to dump local point-cloud patches around vertex candidates
# (training data for the vertex-refinement PointNet).
GENERATE_DATASET = False
#DATASET_DIR = '/path/to/your/hohocustom/'
DATASET_DIR = '/path/to/your/hohocustom_v4/'
# Toggle to dump edge patches between vertex-candidate pairs
# (training data for the edge classifier).
GENERATE_DATASET_EDGES = False
#EDGES_DATASET_DIR = '/path/to/your/hohocustom_edges/'
EDGES_DATASET_DIR = '/path/to/your/hohocustom_edges_10d_v5/'
def convert_entry_to_human_readable(entry):
    """Return a copy of `entry` with COLMAP blobs decoded.

    Any key whose name contains 'colmap' is run through read_colmap_rec();
    every other value is copied through unchanged. The entry's 'order_id'
    is additionally exposed under '__key__'.
    """
    readable = {}
    for key, value in entry.items():
        readable[key] = read_colmap_rec(value) if 'colmap' in key else value
    readable['__key__'] = entry['order_id']
    return readable
def get_gt_vertices_and_edges(entry, i, depth, colmap_rec, k, r, t, img_id, ade_seg):
    """Project ground-truth wireframe vertices into view `img_id`, drop the
    occluded ones, and return the visibility-filtered wireframe.

    Parameters
    ----------
    entry : dict
        Dataset entry holding 'wf_vertices' ((N, 3) world coords) and 'wf_edges'.
    i : int
        View index (unused here; kept for call-site compatibility).
    depth : PIL.Image or np.ndarray
        Dense depth map for this view.
    colmap_rec : pycolmap.Reconstruction
        COLMAP reconstruction used to obtain sparse depth.
    k, r, t : np.ndarray
        Camera intrinsics (3x3), world-to-camera rotation (3x3) and translation (3,).
    img_id : str
        Image identifier within the COLMAP reconstruction.
    ade_seg : PIL.Image
        ADE20k segmentation of the view (used for the house mask).

    Returns
    -------
    wf_vertices : list of dict
        Visible vertices as {'xy': (u, v), 'type': 'apex'} in image coordinates.
    wf_edges : list of tuple
        GT edges re-indexed over the visible vertices only.
    wf_vertices_3d_visible : np.ndarray
        (M, 3) world coordinates of the visible vertices.
    """
    # BUG FIX: the original called get_fitted_dense_depth() without the
    # K/R/t arguments that its signature requires (TypeError at runtime).
    depth_fitted, depth_sparse, found_sparse, col_img = get_fitted_dense_depth(
        depth, colmap_rec, img_id, ade_seg, k, r, t)
    wf_vertices = np.array(entry['wf_vertices'])
    wf_edges = entry['wf_edges']
    occlusion_status = []  # one flag per GT vertex: True if occluded
    if wf_vertices.shape[0] > 0:
        # World -> camera -> image projection of all GT vertices.
        wf_vertices_cam = (r @ wf_vertices.T) + t.reshape(3, 1)
        wf_vertices_img_homogeneous = k @ wf_vertices_cam
        projected_gt_vertices_2d = (
            wf_vertices_img_homogeneous[:2, :] / wf_vertices_img_homogeneous[2, :]
        ).T
        # Sample the fitted (dense) and sparse depth maps at each projection;
        # NaN marks projections that fall outside the map bounds.
        gt_projected_depth_fitted_values = []
        gt_projected_depth_sparse_values = []
        map_height, map_width = depth_fitted.shape
        for idx in range(projected_gt_vertices_2d.shape[0]):
            px, py = projected_gt_vertices_2d[idx]
            ix, iy = int(round(px)), int(round(py))
            if 0 <= iy < map_height and 0 <= ix < map_width:
                gt_projected_depth_fitted_values.append(depth_fitted[iy, ix])
                gt_projected_depth_sparse_values.append(depth_sparse[iy, ix])
            else:
                gt_projected_depth_fitted_values.append(np.nan)
                gt_projected_depth_sparse_values.append(np.nan)
        # A vertex counts as occluded when its camera-frame depth lies well
        # behind the surface recorded in the fitted depth map.
        # NOTE(review): the 200-unit tolerance looks large if depths are in
        # metres — presumably tuned for this dataset's scale; confirm.
        gt_vertices_depth_in_camera_system = wf_vertices_cam[2, :]
        for idx in range(projected_gt_vertices_2d.shape[0]):
            true_depth_of_vertex = gt_vertices_depth_in_camera_system[idx]
            depth_from_fitted_map = gt_projected_depth_fitted_values[idx]
            if np.isnan(true_depth_of_vertex) or true_depth_of_vertex > depth_from_fitted_map + 200.:
                occlusion_status.append(True)
            else:
                # A NaN map depth (projection out of bounds) compares False,
                # so such vertices are treated as visible.
                occlusion_status.append(False)
    if wf_vertices.shape[0] > 0:
        # Keep only non-occluded vertices and remap edge endpoints onto the
        # compacted index space.
        visible_vertices_indices = [idx for idx, occluded in enumerate(occlusion_status) if not occluded]
        old_to_new_indices_map = {old_idx: new_idx for new_idx, old_idx in enumerate(visible_vertices_indices)}
        new_wf_vertices = []
        if projected_gt_vertices_2d.shape[0] > 0:
            for idx in visible_vertices_indices:
                new_wf_vertices.append({'xy': projected_gt_vertices_2d[idx], 'type': 'apex'})
        wf_vertices = new_wf_vertices
        visible_edges = []
        for edge_start, edge_end in wf_edges:
            if edge_start in old_to_new_indices_map and edge_end in old_to_new_indices_map:
                visible_edges.append((old_to_new_indices_map[edge_start], old_to_new_indices_map[edge_end]))
        wf_edges = visible_edges
    else:
        wf_vertices = []
        wf_edges = []
    # 3D world coordinates of the surviving (visible) vertices.
    wf_vertices_3d_visible = np.empty((0, 3))
    original_gt_3d_vertices = np.array(entry['wf_vertices'])
    if original_gt_3d_vertices.shape[0] > 0 and len(occlusion_status) == original_gt_3d_vertices.shape[0]:
        visible_indices = [idx for idx, occluded_flag in enumerate(occlusion_status) if not occluded_flag]
        if visible_indices:
            wf_vertices_3d_visible = original_gt_3d_vertices[visible_indices]
    return wf_vertices, wf_edges, wf_vertices_3d_visible
def project_vertices_to_3d(uv, depth_vert, col_img, K, R, t):
    """
    Back-project 2D vertices with depths to 3D world coordinates.

    Parameters
    ----------
    uv : np.ndarray
        (N, 2) pixel coordinates (u, v).
    depth_vert : np.ndarray
        (N,) depth value per vertex (camera-frame Z).
    col_img : pycolmap.Image
        Unused; kept for call-site compatibility.
    K : np.ndarray
        (3, 3) camera intrinsic matrix.
    R : np.ndarray
        (3, 3) world-to-camera rotation.
    t : np.ndarray
        (3,) world-to-camera translation.

    Returns
    -------
    vertices_3d : np.ndarray
        (N, 3) vertex coordinates in 3D world space.
    """
    uv = np.asarray(uv, dtype=float)
    depth_vert = np.asarray(depth_vert, dtype=float)
    # Normalized camera rays, scaled by depth -> camera-frame points.
    xy_local = np.ones((len(uv), 3))
    xy_local[:, 0] = (uv[:, 0] - K[0, 2]) / K[0, 0]
    xy_local[:, 1] = (uv[:, 1] - K[1, 2]) / K[1, 1]
    vertices_3d_local = xy_local * depth_vert[..., None]
    # Invert the world->camera rigid transform and map points to world space.
    # (Replaces the original cv2 homogeneous round-trip with plain numpy; the
    # transform is rigid so w stays 1 and the results are identical.)
    world_to_cam = np.eye(4)
    world_to_cam[:3, :3] = R
    world_to_cam[:3, 3] = t.reshape(3)
    cam_to_world = np.linalg.inv(world_to_cam)
    vertices_3d = vertices_3d_local @ cam_to_world[:3, :3].T + cam_to_world[:3, 3]
    return vertices_3d
def get_fitted_dense_depth(depth, colmap_rec, img_id, ade20k_seg, K, R, t):
    """
    Gets sparse depth from COLMAP, computes a house mask, fits dense depth to
    sparse depth within the mask, and returns the fitted dense depth.

    Parameters
    ----------
    depth : np.ndarray
        Initial dense depth map (H, W), assumed to be in millimetres.
    colmap_rec : pycolmap.Reconstruction
        COLMAP reconstruction data.
    img_id : str
        Identifier for the current image within the COLMAP reconstruction.
    ade20k_seg : PIL.Image
        ADE20k segmentation map for the image.
    K : np.ndarray
        Camera intrinsic matrix (3x3).
    R : np.ndarray
        Camera rotation matrix (3x3).
    t : np.ndarray
        Camera translation vector (3,).

    Returns
    -------
    depth_fitted : np.ndarray
        Dense depth scaled to align with sparse depth within the house mask (H, W).
        Falls back to the metre-scaled input when no sparse depth is found.
    depth_sparse : np.ndarray
        The sparse depth map obtained from COLMAP (H, W); zeros on fallback.
    found_sparse : bool
        True if sparse depth points were found for this image, False otherwise.
    col_img : pycolmap.Image or None
        Matched COLMAP image, or None when no sparse depth was found.
    """
    depth_np = np.array(depth) / 1000.  # convert mm to metres
    depth_sparse, found_sparse, col_img = get_sparse_depth_custom(colmap_rec, img_id, depth_np, K, R, t)
    if not found_sparse:
        print(f'No sparse depth found for image {img_id}')
        # Best effort: return the metre-scaled input with an empty sparse map.
        return depth_np, np.zeros_like(depth_np), False, None
    # Restrict the scale fit to house pixels so background depth noise does
    # not bias the alignment.
    house_mask = get_house_mask(ade20k_seg)
    k, depth_fitted = fit_scale_robust_median(depth_np, depth_sparse, validity_mask=house_mask)
    print(f"Fitted depth scale k={k:.4f} for image {img_id}")
    return depth_fitted, depth_sparse, True, col_img
def get_sparse_depth_custom(colmap_rec, img_id_substring, depth, K, R, t):
    """
    Build a sparse depth map for the COLMAP image whose name contains
    `img_id_substring`, projecting that image's 3D points with the supplied
    K/R/t (rather than COLMAP's own camera). Pixels hit by a projected point
    receive its camera-frame depth; every other pixel stays 0.

    Returns (depth_map (H, W) float32, found_flag, pycolmap.Image or None).
    """
    H, W = depth.shape
    # Locate the COLMAP image whose name matches the substring; it tells us
    # which 3D points are visible in this view.
    found_img = None
    for _, candidate in colmap_rec.images.items():
        if img_id_substring in candidate.name:
            found_img = candidate
            break
    if found_img is None:
        print(f"Image substring {img_id_substring} not found in COLMAP.")
        return np.zeros((H, W), dtype=np.float32), False, None
    # Gather world-space coordinates of the points this image observes.
    points_xyz_world = [p3D.xyz for pid, p3D in colmap_rec.points3D.items()
                        if found_img.has_point3D(pid)]
    if not points_xyz_world:
        print(f"No 3D points associated with {found_img.name} in COLMAP.")
        return np.zeros((H, W), dtype=np.float32), False, found_img
    points_xyz_world = np.array(points_xyz_world)
    # World -> camera via a 4x4 homogeneous transform assembled from R, t.
    world_to_cam_mat = np.eye(4)
    world_to_cam_mat[:3, :3] = R
    world_to_cam_mat[:3, 3] = t.flatten()
    homog = np.hstack((points_xyz_world, np.ones((points_xyz_world.shape[0], 1))))
    cam_h = (world_to_cam_mat @ homog.T).T
    points_cam = cam_h[:, :3] / cam_h[:, 3, np.newaxis]
    # Pinhole projection with K; keep only points in front of the camera
    # that land inside the image bounds.
    uv = []
    z_vals = []
    for p_cam in points_cam:
        if p_cam[2] <= 0:
            continue
        u_px = int(round((K[0, 0] * p_cam[0] / p_cam[2]) + K[0, 2]))
        v_px = int(round((K[1, 1] * p_cam[1] / p_cam[2]) + K[1, 2]))
        if 0 <= u_px < W and 0 <= v_px < H:
            uv.append((u_px, v_px))
            z_vals.append(p_cam[2])
    if not uv:
        print(f"No points projected into image bounds for {img_id_substring} using K,R,t.")
        return np.zeros((H, W), dtype=np.float32), False, found_img
    uv = np.array(uv, dtype=int)
    z_vals = np.array(z_vals)
    # Scatter the depths into the output map (last write wins on collisions).
    depth_out = np.zeros((H, W), dtype=np.float32)
    positive = z_vals > 0
    if np.any(positive):
        depth_out[uv[positive, 1], uv[positive, 0]] = z_vals[positive]
    return depth_out, True, found_img
def create_3d_wireframe_single_image(vertices, connections, depth, colmap_rec,
                                     img_id, ade_seg, K, R, t):
    """
    Lift one image's 2D wireframe vertices into 3D world coordinates.

    Parameters
    ----------
    vertices : list of dict
        2D vertex dicts (e.g. {"xy": (x, y), "type": ...}).
    connections : list of tuple
        2D edges as index pairs into `vertices`.
    depth : PIL.Image
        Initial dense depth map.
    colmap_rec : pycolmap.Reconstruction
        COLMAP reconstruction data.
    img_id : str
        Identifier of the image within the COLMAP reconstruction.
    ade_seg : PIL.Image
        ADE20k segmentation map for the image.
    K, R, t : np.ndarray
        Camera intrinsics, rotation and translation.

    Returns
    -------
    np.ndarray
        (N, 3) world-space vertex coordinates; an empty (0, 3) array when
        there are fewer than two vertices or no connections.
    """
    # Guard clause: nothing meaningful to reconstruct.
    if len(vertices) < 2 or len(connections) < 1:
        print(f'Warning: create_3d_wireframe_single_image called with insufficient vertices/connections for image {img_id}')
        return np.empty((0, 3))
    # Dense depth aligned to COLMAP sparse depth inside the house mask.
    depth_fitted, depth_sparse, found_sparse, col_img = get_fitted_dense_depth(
        depth, colmap_rec, img_id, ade_seg, K, R, t
    )
    # Per-vertex (u, v) and depth, then back-projection to world space.
    uv, depth_vert = get_uv_depth(vertices, depth_fitted, depth_sparse, 10)
    return project_vertices_to_3d(uv, depth_vert, col_img, K, R, t)
def visu_patch_and_pred(patch, pred, pred_dist, pred_class):
    # Interactive 3D debug view of a single point-cloud patch together with the
    # ground-truth vertex (green), an optional initial guess (orange) and the
    # model's predicted vertex (red).
    # NOTE(review): depends on `pv` (pyvista), whose import is commented out at
    # the top of this file — calling this function as-is raises NameError;
    # re-enable the pyvista import before use.
    # `patch` keys used here (assumed layout — confirm against the patch builder):
    #   'patch_7d'           (N, 7) array; columns 0:3 are XYZ relative coords
    #   'cluster_center'     offset added back to obtain absolute coordinates
    #   'cluster_point_ids'  ids of the filtered (highlighted) points
    #   'cube_point_ids'     ids of every point in the patch cube
    #   'assigned_wf_vertex' optional GT vertex in patch-local coordinates
    # Create plotter
    plotter = pv.Plotter()
    # Create point cloud for this patch
    offset = patch.get('cluster_center', None) # Offset if available
    patch_points_3d = np.array(patch['patch_7d'][:, :3])
    patch_points_3d = patch_points_3d + offset
    patch_cloud = pv.PolyData(patch_points_3d)
    point_idxs = patch['cluster_point_ids'] # List of point indices that are filtered
    patch_point_ids = patch['cube_point_ids'] # Assuming the 7th column contains point IDs
    assigned_gt_vertex = patch.get('assigned_wf_vertex', None) # GT vertex if available
    initial_pred = None
    if assigned_gt_vertex is not None:
        assigned_gt_vertex = assigned_gt_vertex + offset
    # Color points: red for filtered points, blue for other points
    patch_point_colors = []
    for i, pid in enumerate(patch_point_ids):
        if pid in point_idxs:
            patch_point_colors.append([255, 0, 0]) # Red for filtered points
        else:
            patch_point_colors.append([0, 0, 255]) # Blue for other points
    patch_cloud["colors"] = np.array(patch_point_colors)
    plotter.add_mesh(patch_cloud, scalars="colors", rgb=True, point_size=8, render_points_as_spheres=True)
    # Create sphere to visualize GT vertex if available
    if assigned_gt_vertex is not None:
        gt_sphere = pv.Sphere(radius=0.1, center=assigned_gt_vertex)
        plotter.add_mesh(gt_sphere, color="green", opacity=0.5)
    # NOTE(review): initial_pred is always None above, so this branch is dead;
    # apparently kept for parity with an earlier revision.
    if initial_pred is not None:
        # Create sphere to visualize initial prediction
        pred_sphere = pv.Sphere(radius=0.1, center=initial_pred)
        plotter.add_mesh(pred_sphere, color="orange", opacity=0.5)
    if pred is not None:
        # Create sphere to visualize predicted vertex
        pred_sphere = pv.Sphere(radius=0.1, center=pred)
        plotter.add_mesh(pred_sphere, color="red", opacity=0.5)
    # Add text annotations for prediction values
    title_text = f"Patch x\nPred dist: {pred_dist:.4f}\nPred class: {pred_class}"
    plotter.show(title=title_text)
def extract_vertices_from_whole_pcloud(colmap_rec, idxs_points, all_connections):
    """Turn per-image vertex candidates (COLMAP point-id groups) into merged
    3D vertex clusters plus re-indexed connections.

    Parameters
    ----------
    colmap_rec : pycolmap.Reconstruction
        Full reconstruction whose points3D are sampled.
    idxs_points : list[list[list[int]]]
        Per image, a list of groups; each group is the COLMAP point ids
        backing one 2D vertex candidate.
    all_connections : list[list[tuple[int, int]]]
        Per image, edges between that image's groups (local indices).

    Returns
    -------
    extracted_points, extracted_colors, extracted_ids : list of np.ndarray
        One entry per merged vertex cluster.
    whole_pcloud : dict
        Full point cloud as {'points', 'colors', 'ids'} arrays.
    updated_connections : list[tuple[int, int]]
        Deduplicated edges re-indexed over the merged clusters.
    """
    # --- Flatten per-image groups into one global list and remap each
    # image's local edge indices onto the flattened numbering.
    all_filtered_ids_list = []
    all_extracted_groups = []
    all_flattened_connections = []
    group_to_flattened_mapping = {}  # group_idx -> {local_idx: flattened_idx}
    flattened_idx = 0
    for group_idx, point_ids_group in enumerate(idxs_points):
        cur_connections = all_connections[group_idx]
        group_to_flattened_mapping[group_idx] = {}
        for local_idx, point_ids in enumerate(point_ids_group):
            all_extracted_groups.append(point_ids)
            group_to_flattened_mapping[group_idx][local_idx] = flattened_idx
            flattened_idx += 1
        for start_idx, end_idx in cur_connections:
            if start_idx in group_to_flattened_mapping[group_idx] and end_idx in group_to_flattened_mapping[group_idx]:
                all_flattened_connections.append((
                    group_to_flattened_mapping[group_idx][start_idx],
                    group_to_flattened_mapping[group_idx][end_idx]))
    # Collect all candidate point ids across all images.
    for group_idxs in idxs_points:
        for point_ids in group_idxs:
            all_filtered_ids_list.extend(point_ids)
    all_filtered_ids_set = set(all_filtered_ids_list)
    # --- Split the reconstruction into "all points" and "points referenced
    # by at least one candidate group".
    filtered_colmap_points = []
    filtered_colmap_colors = []
    filtered_colmap_ids = []
    all_colmap_points = []
    all_colmap_colors = []
    all_colmap_ids = []
    for pid, p3D in colmap_rec.points3D.items():
        all_colmap_points.append(p3D.xyz)
        all_colmap_colors.append(p3D.color / 255.0)  # normalize colors to [0,1]
        all_colmap_ids.append(pid)
        if pid in all_filtered_ids_set:
            filtered_colmap_points.append(p3D.xyz)
            filtered_colmap_colors.append(p3D.color / 255.0)
            filtered_colmap_ids.append(pid)
    all_colmap_points = np.array(all_colmap_points) if all_colmap_points else np.empty((0, 3))
    all_colmap_colors = np.array(all_colmap_colors) if all_colmap_colors else np.empty((0, 3))
    all_colmap_ids = np.array(all_colmap_ids) if all_colmap_ids else np.empty((0,))
    whole_pcloud = {'points': all_colmap_points,
                    'colors': all_colmap_colors,
                    'ids': all_colmap_ids}
    filtered_colmap_points = np.array(filtered_colmap_points) if filtered_colmap_points else np.empty((0, 3))
    filtered_colmap_colors = np.array(filtered_colmap_colors) if filtered_colmap_colors else np.empty((0, 3))
    filtered_colmap_ids = np.array(filtered_colmap_ids) if filtered_colmap_ids else np.empty((0,))
    # --- For each candidate group, grab every filtered point within
    # ball_radius of the group's centroid.
    ball_radius = 0.5  # meters
    # PERF FIX: the original rebuilt a Python list of all filtered ids for
    # every single pid membership test (O(N) per test, O(N^2) overall);
    # a set built once gives O(1) lookups.
    filtered_ids_set = set(filtered_colmap_ids.tolist())
    extracted_points = []
    extracted_colors = []
    extracted_ids = []
    for point_ids_group in all_extracted_groups:
        group_points_3d = []
        for pid in point_ids_group:
            if pid in filtered_ids_set:
                idx = np.where(filtered_colmap_ids == pid)[0][0]
                group_points_3d.append(filtered_colmap_points[idx])
        if not group_points_3d:
            # BUG FIX: append empty placeholders instead of skipping, so the
            # cluster indices stay aligned with all_flattened_connections
            # (mirrors the v2 implementation of this function).
            extracted_points.append(np.empty((0, 3)))
            extracted_colors.append(np.empty((0, 3)))
            extracted_ids.append(np.empty((0,)))
            continue
        center = np.mean(np.array(group_points_3d), axis=0)
        group_extracted_points = []
        group_extracted_colors = []
        group_extracted_ids = []
        if len(filtered_colmap_points) > 0:
            distances_to_center = np.linalg.norm(filtered_colmap_points - center, axis=1)
            within_radius_mask = distances_to_center <= ball_radius
            if np.any(within_radius_mask):
                group_extracted_points.extend(filtered_colmap_points[within_radius_mask])
                group_extracted_colors.extend(filtered_colmap_colors[within_radius_mask])
                group_extracted_ids.extend(filtered_colmap_ids[within_radius_mask])
        extracted_points.append(np.array(group_extracted_points) if group_extracted_points else np.empty((0, 3)))
        extracted_colors.append(np.array(group_extracted_colors) if group_extracted_colors else np.empty((0, 3)))
        extracted_ids.append(np.array(group_extracted_ids) if group_extracted_ids else np.empty((0,)))
    # --- Merge clusters sharing >50% of their point ids (relative to the
    # smaller cluster) and remap the flattened connections onto the merged
    # clusters. (The dead `if False:` pyvista visualization block that
    # followed in the original has been removed.)
    updated_connections = []
    if extracted_points:
        groups_to_keep = []
        merged_groups = set()       # indices already absorbed into a merge
        old_to_new_mapping = {}     # flattened index -> merged cluster index
        for i, (points_i, colors_i, ids_i) in enumerate(zip(extracted_points, extracted_colors, extracted_ids)):
            if i in merged_groups or len(ids_i) == 0:
                continue
            merged_points = points_i.copy()
            merged_colors = colors_i.copy()
            merged_ids = set(ids_i)
            merged_indices = [i]
            for j in range(i + 1, len(extracted_points)):
                if j in merged_groups or len(extracted_ids[j]) == 0:
                    continue
                ids_j = set(extracted_ids[j])
                intersection = merged_ids.intersection(ids_j)
                smaller_group_size = min(len(merged_ids), len(ids_j))
                if smaller_group_size > 0:
                    overlap_percentage = len(intersection) / smaller_group_size
                    if overlap_percentage > 0.5:
                        merged_points = np.vstack([merged_points, extracted_points[j]]) if len(merged_points) > 0 else extracted_points[j]
                        merged_colors = np.vstack([merged_colors, extracted_colors[j]]) if len(merged_colors) > 0 else extracted_colors[j]
                        merged_ids.update(ids_j)
                        merged_indices.append(j)
                        merged_groups.add(j)
            if len(merged_points) > 0:
                new_group_idx = len(groups_to_keep)
                groups_to_keep.append((merged_points, merged_colors, np.array(list(merged_ids))))
                for old_idx in merged_indices:
                    old_to_new_mapping[old_idx] = new_group_idx
        extracted_points = [group[0] for group in groups_to_keep]
        extracted_colors = [group[1] for group in groups_to_keep]
        extracted_ids = [group[2] for group in groups_to_keep]
        for start_idx, end_idx in all_flattened_connections:
            if start_idx in old_to_new_mapping and end_idx in old_to_new_mapping:
                new_start = old_to_new_mapping[start_idx]
                new_end = old_to_new_mapping[end_idx]
                # Drop self-loops created by merging; deduplicate undirected edges.
                if new_start != new_end:
                    connection = tuple(sorted((new_start, new_end)))
                    if connection not in updated_connections:
                        updated_connections.append(connection)
    return extracted_points, extracted_colors, extracted_ids, whole_pcloud, updated_connections
from collections import Counter # Ensure Counter is imported
def extract_vertices_from_whole_pcloud_v2(colmap_pcloud, idxs_points, all_connections):
# Extract initial data from colmap_pcloud
# points_7d contains: x, y, z, r, g, b, pid (r,g,b are already normalized to [0,1])
all_colmap_points_xyz = colmap_pcloud['points_7d'][:, :3]
all_colmap_rgb_colors = colmap_pcloud['points_7d'][:, 3:6]
all_colmap_ids = colmap_pcloud['points_7d'][:, 6].astype(int)
# ADE feature: 1.0 if ade_count > 0, else 0.0
# colmap_pcloud['ade'] stores the count of times a point was seen in an ADE house mask
all_colmap_ade_feature = (np.array(colmap_pcloud['ade']) > 0).astype(float).reshape(-1, 1)
# Gestalt feature: Fused Gestalt color by majority vote, normalized to [0,1]
# colmap_pcloud['gestalt'] is a list of lists; each inner list contains uint8 RGB arrays from different views
all_colmap_fused_gestalt_colors_normalized = np.zeros((len(all_colmap_points_xyz), 3))
for i, gestalt_obs_for_point_i in enumerate(colmap_pcloud['gestalt']):
if gestalt_obs_for_point_i:
# Convert list of np.arrays to list of tuples to make them hashable for Counter
# Ensure gestalt_obs_for_point_i contains hashable items, e.g. tuples
try:
# If gestalt_obs_for_point_i contains numpy arrays:
counts = Counter(map(tuple, gestalt_obs_for_point_i))
except TypeError:
# If gestalt_obs_for_point_i already contains tuples or other hashables:
counts = Counter(gestalt_obs_for_point_i)
if counts:
most_common_gestalt_tuple = counts.most_common(1)[0][0]
fused_gestalt_rgb_uint8 = np.array(most_common_gestalt_tuple)
all_colmap_fused_gestalt_colors_normalized[i] = fused_gestalt_rgb_uint8 / 255.0
else:
all_colmap_fused_gestalt_colors_normalized[i] = np.array([0.0, 0.0, 0.0]) # Default if counts is empty
else:
all_colmap_fused_gestalt_colors_normalized[i] = np.array([0.0, 0.0, 0.0]) # Default if no observations
# Combine into 7D colors [R, G, B, ADE, Gestalt_R, Gestalt_G, Gestalt_B]
all_colmap_colors_7d = np.hstack((
all_colmap_rgb_colors,
all_colmap_ade_feature,
all_colmap_fused_gestalt_colors_normalized
))
# Flatten all groups and create mapping for connections
all_filtered_ids_list = []
all_extracted_groups = [] # List of lists of point_ids
all_flattened_connections = []
group_to_flattened_mapping = {} # Maps (group_idx, local_vertex_idx) to flattened_idx
flattened_idx = 0
for group_idx, point_ids_group in enumerate(idxs_points): # idxs_points is list of lists of point_ids
cur_connections = all_connections[group_idx]
group_to_flattened_mapping[group_idx] = {}
for local_idx, point_ids in enumerate(point_ids_group): # point_ids is a list of pids for one vertex candidate
all_extracted_groups.append(point_ids) # Store the list of pids
all_filtered_ids_list.extend(point_ids) # Add all pids to a flat list
group_to_flattened_mapping[group_idx][local_idx] = flattened_idx
flattened_idx += 1
for conn in cur_connections:
start_idx, end_idx = conn
if start_idx in group_to_flattened_mapping[group_idx] and end_idx in group_to_flattened_mapping[group_idx]:
flattened_start = group_to_flattened_mapping[group_idx][start_idx]
flattened_end = group_to_flattened_mapping[group_idx][end_idx]
all_flattened_connections.append((flattened_start, flattened_end))
all_filtered_ids_set = set(all_filtered_ids_list)
# Extract only the points that are part of any initial group
filtered_colmap_points_xyz_list = []
filtered_colmap_colors_7d_list = []
filtered_colmap_ids_list = []
for i, pid in enumerate(all_colmap_ids):
if pid in all_filtered_ids_set:
filtered_colmap_points_xyz_list.append(all_colmap_points_xyz[i])
filtered_colmap_colors_7d_list.append(all_colmap_colors_7d[i])
filtered_colmap_ids_list.append(pid)
filtered_colmap_points_xyz_arr = np.array(filtered_colmap_points_xyz_list) if filtered_colmap_points_xyz_list else np.empty((0, 3))
filtered_colmap_colors_7d_arr = np.array(filtered_colmap_colors_7d_list) if filtered_colmap_colors_7d_list else np.empty((0, 7))
filtered_colmap_ids_arr = np.array(filtered_colmap_ids_list) if filtered_colmap_ids_list else np.empty((0,), dtype=int)
# This whole_pcloud is created by this function, reflecting the full dataset with 7D colors
whole_pcloud_internal = {
'points': all_colmap_points_xyz,
'colors': all_colmap_colors_7d, # Now 7D
'ids': all_colmap_ids
}
# Extract points within ball radius for each group
ball_radius = 0.5 # meters
extracted_points_groups = []
extracted_colors_7d_groups = []
extracted_ids_groups = []
for point_ids_in_one_group in all_extracted_groups: # point_ids_in_one_group is a list of PIDs
current_group_points_xyz = []
# Get 3D coordinates of points in this specific initial group
# These PIDs should be in all_colmap_ids
indices_in_all_colmap = [np.where(all_colmap_ids == pid)[0][0] for pid in point_ids_in_one_group if pid in all_colmap_ids]
if not indices_in_all_colmap:
extracted_points_groups.append(np.empty((0,3)))
extracted_colors_7d_groups.append(np.empty((0,7)))
extracted_ids_groups.append(np.empty((0,), dtype=int))
continue
current_group_points_xyz = all_colmap_points_xyz[indices_in_all_colmap]
if current_group_points_xyz.shape[0] == 0:
extracted_points_groups.append(np.empty((0,3)))
extracted_colors_7d_groups.append(np.empty((0,7)))
extracted_ids_groups.append(np.empty((0,), dtype=int))
continue
center = np.mean(current_group_points_xyz, axis=0)
# Find points from the *filtered_colmap_points_xyz_arr* (points belonging to *any* initial group)
# that are within ball_radius of this group's center.
group_extracted_points_list = []
group_extracted_colors_7d_list = []
group_extracted_ids_list = []
if len(filtered_colmap_points_xyz_arr) > 0:
distances_to_center = np.linalg.norm(filtered_colmap_points_xyz_arr - center, axis=1)
within_radius_mask = distances_to_center <= ball_radius
if np.any(within_radius_mask):
group_extracted_points_list.extend(filtered_colmap_points_xyz_arr[within_radius_mask])
group_extracted_colors_7d_list.extend(filtered_colmap_colors_7d_arr[within_radius_mask])
group_extracted_ids_list.extend(filtered_colmap_ids_arr[within_radius_mask])
extracted_points_groups.append(np.array(group_extracted_points_list) if group_extracted_points_list else np.empty((0, 3)))
extracted_colors_7d_groups.append(np.array(group_extracted_colors_7d_list) if group_extracted_colors_7d_list else np.empty((0, 7)))
extracted_ids_groups.append(np.array(group_extracted_ids_list) if group_extracted_ids_list else np.empty((0,), dtype=int))
# Filter extracted_points to merge groups that share more than 50% of their points
updated_connections = []
final_extracted_points = []
final_extracted_colors_7d = []
final_extracted_ids = []
if extracted_points_groups:
groups_to_keep_data = []
merged_groups_indices = set()
old_to_new_mapping = {}
for i in range(len(extracted_points_groups)):
if i in merged_groups_indices or len(extracted_ids_groups[i]) == 0:
continue
current_merged_points = extracted_points_groups[i].copy()
current_merged_colors_7d = extracted_colors_7d_groups[i].copy()
current_merged_ids_set = set(extracted_ids_groups[i])
indices_in_this_merged_group = [i]
for j in range(i + 1, len(extracted_points_groups)):
if j in merged_groups_indices or len(extracted_ids_groups[j]) == 0:
continue
ids_j_set = set(extracted_ids_groups[j])
intersection = current_merged_ids_set.intersection(ids_j_set)
smaller_group_size = min(len(current_merged_ids_set), len(ids_j_set))
if smaller_group_size > 0:
overlap_percentage = len(intersection) / smaller_group_size
if overlap_percentage > 0.5:
current_merged_points = np.vstack([current_merged_points, extracted_points_groups[j]]) if len(current_merged_points) > 0 else extracted_points_groups[j]
current_merged_colors_7d = np.vstack([current_merged_colors_7d, extracted_colors_7d_groups[j]]) if len(current_merged_colors_7d) > 0 else extracted_colors_7d_groups[j]
current_merged_ids_set.update(ids_j_set)
indices_in_this_merged_group.append(j)
merged_groups_indices.add(j)
if len(current_merged_points) > 0:
new_group_idx = len(groups_to_keep_data)
groups_to_keep_data.append((current_merged_points, current_merged_colors_7d, np.array(list(current_merged_ids_set))))
for old_idx in indices_in_this_merged_group:
old_to_new_mapping[old_idx] = new_group_idx
final_extracted_points = [group_data[0] for group_data in groups_to_keep_data]
final_extracted_colors_7d = [group_data[1] for group_data in groups_to_keep_data]
final_extracted_ids = [group_data[2] for group_data in groups_to_keep_data]
for start_idx, end_idx in all_flattened_connections:
if start_idx in old_to_new_mapping and end_idx in old_to_new_mapping:
new_start = old_to_new_mapping[start_idx]
new_end = old_to_new_mapping[end_idx]
if new_start != new_end:
connection = tuple(sorted((new_start, new_end)))
if connection not in updated_connections:
updated_connections.append(connection)
# Visualization part (remains largely unchanged, uses random colors for spheres)
if False: # Set to True to enable visualization
if final_extracted_points:
# Ensure pyvista is imported if this block is enabled
# import pyvista as pv
plotter = pv.Plotter()
# Add all COLMAP points (from whole_pcloud_internal) in gray
if len(whole_pcloud_internal['points']) > 0:
# For visualization, use only RGB part of 7D colors or a fixed color
# Here, using fixed gray color as in original
vis_colors = np.full((len(whole_pcloud_internal['points']), 3), [0.8, 0.8, 0.8])
point_cloud = pv.PolyData(whole_pcloud_internal['points'])
point_cloud["colors"] = vis_colors
plotter.add_mesh(point_cloud, scalars="colors", rgb=True, point_size=3, render_points_as_spheres=True)
for group_idx, (group_points_xyz, _) in enumerate(zip(final_extracted_points, final_extracted_colors_7d)):
if len(group_points_xyz) > 0:
group_mean = np.mean(group_points_xyz, axis=0)
sphere = pv.Sphere(radius=0.2, center=group_mean)
group_color_vis = np.random.rand(3) # Random color for sphere
plotter.add_mesh(sphere, color=group_color_vis, opacity=0.7)
group_cloud = pv.PolyData(group_points_xyz)
# Use the same random color for points in this group for visualization consistency
plotter.add_mesh(group_cloud, color=group_color_vis, point_size=6, render_points_as_spheres=True)
plotter.show(title=f"Extracted Points within {ball_radius}m radius - Spheres at group means")
return final_extracted_points, final_extracted_colors_7d, final_extracted_ids, whole_pcloud_internal, updated_connections
def visu_pcloud_and_preds(colmap_rec, extracted_ids, extracted_points, extracted_colors, predicted_vertices, connections):
    """Interactively plot the COLMAP cloud, extracted groups and predictions.

    Renders the full reconstruction in gray, each extracted point group in a
    random colour with a sphere at its centroid, black spheres at valid
    (non-zero) predicted vertices, and red lines for predicted connections.
    Does nothing when ``extracted_ids`` is empty.
    """
    if not extracted_ids:
        return
    plotter = pv.Plotter()
    # Full reconstruction rendered as a uniform gray backdrop.
    base_points = []
    base_colors = []
    for pid, p3D in colmap_rec.points3D.items():
        base_points.append(p3D.xyz)
        base_colors.append([0.8, 0.8, 0.8])  # gray
    if base_points:
        base_points = np.array(base_points)
        base_colors = np.array(base_colors)
        backdrop = pv.PolyData(base_points)
        backdrop["colors"] = np.array(base_colors)
        plotter.add_mesh(backdrop, scalars="colors", rgb=True, point_size=3, render_points_as_spheres=True)
    for grp_idx, (grp_pts, grp_cols) in enumerate(zip(extracted_points, extracted_colors)):
        if len(grp_pts) == 0:
            continue
        # Sphere at the group centroid, group points in the same random colour.
        centroid = np.mean(grp_pts, axis=0)
        rand_color = np.random.rand(3)
        plotter.add_mesh(pv.Sphere(radius=0.2, center=centroid), color=rand_color, opacity=0.5)
        plotter.add_mesh(pv.PolyData(grp_pts), color=rand_color, point_size=6, render_points_as_spheres=True)
        # Black sphere at the refined vertex prediction, when valid.
        if grp_idx < len(predicted_vertices):
            pred_pt = predicted_vertices[grp_idx]
            if not np.allclose(pred_pt, [0.0, 0.0, 0.0]):  # zero vector marks "no prediction"
                plotter.add_mesh(pv.Sphere(radius=0.15, center=pred_pt), color="black", opacity=1.)
    # Red line segments between connected, valid predicted vertices.
    if len(predicted_vertices) > 0 and len(connections) > 0:
        keep_pts = []
        keep_idx = []
        for i, pred_pt in enumerate(predicted_vertices):
            if not np.allclose(pred_pt, [0.0, 0.0, 0.0]):
                keep_pts.append(pred_pt)
                keep_idx.append(i)
        if len(keep_pts) > 1:
            keep_pts = np.array(keep_pts)
            for a, b in connections:
                if a in keep_idx and b in keep_idx:
                    seg_pts = np.array([keep_pts[keep_idx.index(a)], keep_pts[keep_idx.index(b)]])
                    plotter.add_mesh(pv.Line(seg_pts[0], seg_pts[1]), color="red", line_width=3)
    ball_radius = 1.0  # metres; display text only
    plotter.show(title=f"Extracted Points within {ball_radius}m radius - Spheres at group means")
def _fuse_gestalt_observations(gestalt_lists):
    """Majority-vote fuse per-point gestalt RGB observations.

    Each entry of ``gestalt_lists`` holds the uint8 RGB gestalt colours
    observed for one 3D point across views.  Points with no observations get
    (0, 0, 0).  Returns an (N, 3) array still in the raw 0-255 range.
    """
    fused = []
    for observations in gestalt_lists:
        if len(observations) == 0:
            fused.append(np.array([0, 0, 0]))
        elif len(observations) == 1:
            fused.append(observations[0])
        else:
            # Tuples are hashable, so Counter can vote over them.
            votes = Counter(tuple(rgb) for rgb in observations)
            winner = votes.most_common(1)[0][0]
            fused.append(np.array(winner, dtype=np.uint8))
    return np.array(fused)


def _extended_cylinder(points_3d, start_vertex, end_vertex, radius, extension):
    """Select points inside a finite cylinder around segment start->end.

    The segment is extended by ``extension`` metres beyond both endpoints for
    extra context.  Returns ``(mask, ext_start, ext_end, ext_length)`` where
    ``mask`` flags points whose axial projection lies on the extended segment
    and whose radial distance is within ``radius``.  Returns None for a
    degenerate (zero-length) segment, which would otherwise divide by zero.
    """
    line_vector = end_vertex - start_vertex
    line_length = np.linalg.norm(line_vector)
    if line_length < 1e-6:
        return None
    line_direction = line_vector / line_length
    ext_start = start_vertex - extension * line_direction
    ext_end = end_vertex + extension * line_direction
    ext_length = line_length + 2.0 * extension
    # Axial projection of every point onto the (extended) segment.
    projections = np.dot(points_3d - ext_start[np.newaxis, :], line_direction)
    within_bounds = (projections >= 0) & (projections <= ext_length)
    # Radial distance from each point to its foot on the segment axis.
    feet = ext_start[np.newaxis, :] + projections[:, np.newaxis] * line_direction[np.newaxis, :]
    radial = np.linalg.norm(points_3d - feet, axis=1)
    return within_bounds & (radial <= radius), ext_start, ext_end, ext_length


def generate_edge_patches(frame, pred_vertices, colmap_pcloud):
    """Build labelled 10D edge patches for edge-classifier training.

    Ground-truth wireframe edges from ``frame`` are transferred onto the
    predicted vertices (nearest-neighbour matching within 1.5 m) to obtain
    positive edges; an equal number of negative (unconnected) vertex pairs is
    then sampled, rejecting candidates whose cylinders overlap a positive one
    too much (IoU > 0.25).  For every accepted edge, the COLMAP points inside
    a cylinder around the segment are encoded as 10D features
    (xyz + rgb + ade + gestalt, all in [-1, 1]) centred at the edge midpoint.

    Args:
        frame: dict with 'wf_vertices' / 'wf_edges' ground-truth wireframe.
        pred_vertices: predicted 3D vertex positions (iterable of xyz).
        colmap_pcloud: dict with 'points_7d' (N,7: xyz, rgb in [0,1], pid),
            'ade' (per-point counts) and 'gestalt' (per-point lists of uint8
            RGB observations).

    Returns:
        List of patch dicts (positives followed by negatives), each carrying
        'patch_10d', 'connection', 'line_start'/'line_end' (midpoint-relative),
        'cylinder_radius', 'point_indices', 'label' and 'center'.
    """
    gt_vertices = np.array(frame['wf_vertices']) if frame['wf_vertices'] else np.empty((0, 3))
    gt_connections = frame['wf_edges']
    vertices = np.array(pred_vertices) if pred_vertices is not None and len(pred_vertices) > 0 else np.empty((0, 3))
    # Transfer GT connectivity onto the predicted vertices.
    connections = []
    if len(vertices) > 0 and len(gt_vertices) > 0:
        gt_to_pred_mapping = {}
        distance_threshold = 1.5  # metres; GT vertices farther than this stay unmatched
        for gt_idx, gt_vertex in enumerate(gt_vertices):
            distances = np.linalg.norm(vertices - gt_vertex, axis=1)
            closest_pred_idx = np.argmin(distances)
            if distances[closest_pred_idx] <= distance_threshold:
                gt_to_pred_mapping[gt_idx] = closest_pred_idx
        # Propagate GT connections through the mapping.
        for gt_start, gt_end in gt_connections:
            if gt_start in gt_to_pred_mapping and gt_end in gt_to_pred_mapping:
                connections.append((gt_to_pred_mapping[gt_start], gt_to_pred_mapping[gt_end]))
        print(f"Matched {len(gt_to_pred_mapping)} GT vertices to predicted vertices")
        print(f"Propagated {len(connections)} connections from GT to predicted vertices")
    positive_patches = []
    negative_patches = []
    cylinder_radius = 1.0  # metres
    extension_length = 1.0  # metres of extra cylinder length beyond each endpoint
    # BUGFIX: the slice [:, :7] is a numpy *view*; normalising it in place
    # silently corrupted colmap_pcloud['points_7d'] in the caller (and would
    # double-normalise on a second call).  Copy before modifying.
    points_6d = colmap_pcloud['points_7d'][:, :7].copy()
    points_6d[:, 3:6] = points_6d[:, 3:6] * 2 - 1  # RGB from [0, 1] to [-1, 1]
    ade = colmap_pcloud['ade']
    ade = np.where(ade, 1, -1)  # presence flag in {-1, 1}
    gestalt = _fuse_gestalt_observations(colmap_pcloud['gestalt'])
    gestalt = (gestalt / 255) * 2 - 1  # to [-1, 1]
    colmap_points_3d = points_6d[:, :3]
    # Combined 10D cloud: xyz + rgb + ade + gestalt.
    colmap_points_10d = np.zeros((len(colmap_points_3d), 10))
    colmap_points_10d[:, :3] = colmap_points_3d
    colmap_points_10d[:, 3:6] = points_6d[:, 3:6]
    colmap_points_10d[:, 6] = ade
    colmap_points_10d[:, 7:10] = gestalt
    # Positive patches: one per propagated GT connection.
    for connection in connections:
        start_idx, end_idx = connection
        start_vertex = vertices[start_idx]
        end_vertex = vertices[end_idx]
        cyl = _extended_cylinder(colmap_points_3d, start_vertex, end_vertex, cylinder_radius, extension_length)
        if cyl is None:
            continue  # degenerate edge (coincident vertices)
        within_cylinder = cyl[0]
        if np.sum(within_cylinder) <= 5:
            continue  # too sparse to be informative
        points_in_cylinder = colmap_points_10d[within_cylinder]
        point_indices_in_cylinder = np.where(within_cylinder)[0]
        # Centre the patch at the midpoint of the *original* (unextended) line.
        line_midpoint = (start_vertex + end_vertex) / 2
        points_centered = points_in_cylinder.copy()
        points_centered[:, :3] -= line_midpoint
        positive_patches.append({
            'patch_10d': points_centered,
            'connection': connection,
            'line_start': start_vertex - line_midpoint,
            'line_end': end_vertex - line_midpoint,
            'cylinder_radius': cylinder_radius,
            'point_indices': point_indices_in_cylinder,
            'label': 1,  # positive: real edge
            'center': line_midpoint
        })
    # Negative patches: random unconnected vertex pairs, one per positive.
    num_negative_patches = len(positive_patches)
    if num_negative_patches > 0 and len(vertices) >= 2:
        connected_pairs = set(tuple(sorted(conn)) for conn in connections)
        vertex_indices = np.arange(len(vertices))
        all_pairs = np.array(np.meshgrid(vertex_indices, vertex_indices)).T.reshape(-1, 2)
        all_pairs = all_pairs[all_pairs[:, 0] != all_pairs[:, 1]]  # drop self-pairs
        all_pairs_sorted = np.sort(all_pairs, axis=1)  # match connected_pairs format
        unconnected_mask = np.array([tuple(pair) not in connected_pairs for pair in all_pairs_sorted])
        unconnected_pairs = all_pairs[unconnected_mask]
        if len(unconnected_pairs) > 0:
            # Pre-compute positive cylinders (world coordinates) for overlap rejection.
            positive_cylinders = []
            for pos_patch in positive_patches:
                positive_cylinders.append({
                    'start': pos_patch['line_start'] + pos_patch['center'],
                    'end': pos_patch['line_end'] + pos_patch['center'],
                    'radius': pos_patch['cylinder_radius']
                })
            # Sample extra candidates to compensate for overlap rejections.
            num_to_sample = min(num_negative_patches * 3, len(unconnected_pairs))
            sampled_indices = np.random.choice(len(unconnected_pairs), size=num_to_sample, replace=False)
            for idx1, idx2 in unconnected_pairs[sampled_indices]:
                if len(negative_patches) >= num_negative_patches:
                    break
                start_vertex = vertices[idx1]
                end_vertex = vertices[idx2]
                cyl = _extended_cylinder(colmap_points_3d, start_vertex, end_vertex, cylinder_radius, extension_length)
                if cyl is None:
                    continue  # degenerate pair (coincident vertices)
                within_cylinder, extended_start, extended_end, extended_line_length = cyl
                # Reject negatives whose cylinder overlaps a positive one (IoU > 0.25).
                current_cylinder = {
                    'start': extended_start,
                    'end': extended_end,
                    'radius': cylinder_radius
                }
                has_overlap = False
                for pos_cylinder in positive_cylinders:
                    overlap_volume = calculate_cylinder_overlap_volume(current_cylinder, pos_cylinder)
                    current_volume = np.pi * cylinder_radius**2 * extended_line_length
                    pos_height = np.linalg.norm(pos_cylinder['end'] - pos_cylinder['start'])
                    pos_volume = np.pi * pos_cylinder['radius']**2 * pos_height
                    union_volume = current_volume + pos_volume - overlap_volume
                    if union_volume > 0 and overlap_volume / union_volume > 0.25:
                        has_overlap = True
                        break
                if has_overlap:
                    continue
                if np.sum(within_cylinder) <= 10:
                    continue  # too few supporting points
                points_in_cylinder = colmap_points_10d[within_cylinder]
                point_indices_in_cylinder = np.where(within_cylinder)[0]
                line_midpoint = (start_vertex + end_vertex) / 2
                points_centered = points_in_cylinder.copy()
                points_centered[:, :3] -= line_midpoint
                negative_patches.append({
                    'patch_10d': points_centered,
                    'connection': (idx1, idx2),
                    'line_start': start_vertex - line_midpoint,
                    'line_end': end_vertex - line_midpoint,
                    'cylinder_radius': cylinder_radius,
                    'point_indices': point_indices_in_cylinder,
                    'label': 0,  # negative: no edge
                    'center': line_midpoint
                })
    print(f"Generated {len(positive_patches)} positive patches and {len(negative_patches)} negative patches")
    all_patches = positive_patches + negative_patches
    # Optional debug visualisation of all patches against the GT wireframe.
    if False:  # Set to True to enable visualization
        plotter = pv.Plotter()
        # Whole point cloud in gray.
        if len(colmap_points_10d) > 0:
            whole_cloud = pv.PolyData(colmap_points_3d)
            gray_colors = np.full((len(colmap_points_3d), 3), [0.5, 0.5, 0.5])
            whole_cloud["colors"] = gray_colors
            plotter.add_mesh(whole_cloud, scalars="colors", rgb=True, point_size=3, render_points_as_spheres=True)
        # GT wireframe in blue.
        gt_vertices = np.array(frame['wf_vertices']) if frame['wf_vertices'] else np.empty((0, 3))
        gt_connections = frame['wf_edges']
        if len(gt_vertices) > 0:
            for gt_vertex in gt_vertices:
                plotter.add_mesh(pv.Sphere(radius=0.15, center=gt_vertex), color='blue', opacity=0.8)
            for gt_start_idx, gt_end_idx in gt_connections:
                if gt_start_idx < len(gt_vertices) and gt_end_idx < len(gt_vertices):
                    gt_line = pv.Line(gt_vertices[gt_start_idx], gt_vertices[gt_end_idx])
                    plotter.add_mesh(gt_line, color='blue', line_width=8)
        # Each patch: green = positive (edge), red = negative (non-edge).
        for patch in all_patches:
            patch_color = 'green' if patch['label'] == 1 else 'red'
            center = patch['center']
            points_world = patch['patch_10d'][:, :3] + center
            if len(points_world) > 0:
                plotter.add_mesh(pv.PolyData(points_world), color=patch_color, point_size=8, render_points_as_spheres=True)
            line_start = patch['line_start']
            line_end = patch['line_end']
            plotter.add_mesh(pv.Sphere(radius=0.1, center=line_start + center), color='black', opacity=0.8)
            plotter.add_mesh(pv.Sphere(radius=0.1, center=line_end + center), color='white', opacity=0.8)
            plotter.add_mesh(pv.Line(line_start + center, line_end + center), color=patch_color, line_width=5)
            # Wireframe cylinder showing the extraction bounds.
            cylinder_direction = (line_end - line_start) / np.linalg.norm(line_end - line_start)
            # NOTE(review): visualisation-only height; extraction above used 1 m extensions.
            cylinder_height = np.linalg.norm(line_end - line_start) + 2 * 0.25
            cylinder_mesh = pv.Cylinder(center=center, direction=cylinder_direction,
                                        radius=patch['cylinder_radius'], height=cylinder_height)
            plotter.add_mesh(cylinder_mesh, color=patch_color, opacity=0.2, style='wireframe')
        positive_count = sum(1 for patch in all_patches if patch['label'] == 1)
        negative_count = sum(1 for patch in all_patches if patch['label'] == 0)
        title = f"Edge Patches - Positive (Green): {positive_count}, Negative (Red): {negative_count}, GT (Blue)"
        plotter.show(title=title)
    return all_patches
def generate_edge_patches_forward(frame, pred_vertices):
    """Build unlabelled 6D cylinder patches for every predicted-vertex pair.

    For each unordered pair of predicted vertices, collects the COLMAP points
    lying inside a 0.5 m cylinder around the connecting segment (extended by
    0.25 m beyond each endpoint) and packages them as a candidate edge patch
    for inference-time edge classification.

    Args:
        frame: dict providing 'colmap_binary', a pycolmap reconstruction
            whose points3D entries expose ``.xyz`` and ``.color`` (0-255).
        pred_vertices: sequence of predicted 3D vertex positions.

    Returns:
        List of patch dicts with 'patch_6d' (midpoint-centred xyz + rgb in
        [-1, 1]), 'connection' (i, j), 'line_start'/'line_end'
        (midpoint-relative), 'cylinder_radius', 'point_indices' and 'center'.
    """
    vertices = pred_vertices
    cylinder_radius = 0.5  # metres
    extension_length = 0.25  # metres of extra context beyond each endpoint
    colmap = frame['colmap_binary']
    # Assemble the 6D cloud (xyz + rgb) from the COLMAP reconstruction.
    colmap_points_6d = []
    for pid, p3D in colmap.points3D.items():
        colmap_points_6d.append(np.concatenate([p3D.xyz, p3D.color / 255.0]))
    colmap_points_6d = np.array(colmap_points_6d) if colmap_points_6d else np.empty((0, 6))
    colmap_points_6d[:, 3:] = colmap_points_6d[:, 3:] * 2 - 1  # rgb from [0, 1] to [-1, 1]
    colmap_points_3d = colmap_points_6d[:, :3]
    forward_patches = []
    # One candidate patch per unordered vertex pair.
    for i in range(len(vertices)):
        for j in range(i + 1, len(vertices)):
            start_vertex = vertices[i]
            end_vertex = vertices[j]
            line_vector = end_vertex - start_vertex
            line_length = np.linalg.norm(line_vector)
            # BUGFIX: guard against coincident vertices; the original divided
            # by zero here, producing NaN directions and runtime warnings
            # (the 10d variant already has this guard).
            if line_length < 1e-6:
                continue
            line_direction = line_vector / line_length
            extended_start = start_vertex - extension_length * line_direction
            extended_line_length = line_length + 2 * extension_length
            # Axial projection of all points onto the extended segment.
            start_to_points = colmap_points_3d - extended_start[np.newaxis, :]
            projection_lengths = np.dot(start_to_points, line_direction)
            within_bounds = (projection_lengths >= 0) & (projection_lengths <= extended_line_length)
            # Perpendicular distance from every point to the segment axis.
            closest_points_on_line = extended_start[np.newaxis, :] + projection_lengths[:, np.newaxis] * line_direction[np.newaxis, :]
            perpendicular_distances = np.linalg.norm(colmap_points_3d - closest_points_on_line, axis=1)
            within_cylinder = within_bounds & (perpendicular_distances <= cylinder_radius)
            if np.sum(within_cylinder) <= 10:
                continue  # too few supporting points
            points_in_cylinder = colmap_points_6d[within_cylinder]
            point_indices_in_cylinder = np.where(within_cylinder)[0]
            # Centre the patch at the midpoint of the original (unextended) line.
            line_midpoint = (start_vertex + end_vertex) / 2
            points_centered = points_in_cylinder.copy()
            points_centered[:, :3] -= line_midpoint
            forward_patches.append({
                'patch_6d': points_centered,
                'connection': (i, j),
                'line_start': start_vertex - line_midpoint,
                'line_end': end_vertex - line_midpoint,
                'cylinder_radius': cylinder_radius,
                'point_indices': point_indices_in_cylinder,
                'center': line_midpoint
            })
    return forward_patches
def generate_edge_patches_forward_10d(frame, pred_vertices, colmap_pcloud):
    """Create candidate 10D edge patches for every pair of predicted vertices.

    Each unordered vertex pair defines a segment; the COLMAP points inside a
    1 m cylinder around that segment (extended 1 m past both endpoints) are
    gathered and encoded as 10D features (xyz + rgb + ade + gestalt, all
    normalised to [-1, 1]), centred on the segment midpoint.

    Args:
        frame: kept for interface parity with the 6D variant (not read here).
        pred_vertices: predicted 3D vertex positions.
        colmap_pcloud: dict with 'points_7d' (xyz, rgb in [0, 1], pid),
            'ade' (per-point counts) and 'gestalt' (per-point lists of uint8
            RGB observations).

    Returns:
        List of unlabelled candidate patch dicts, one per sufficiently
        populated vertex pair.
    """
    verts = np.array(pred_vertices) if pred_vertices is not None and len(pred_vertices) > 0 else np.empty((0, 3))
    candidate_patches = []
    radius = 1.0  # cylinder radius in metres
    raw_cloud = colmap_pcloud['points_7d']  # columns: x, y, z, r, g, b (0-1), pid
    xyz = raw_cloud[:, :3]
    rgb01 = raw_cloud[:, 3:6]
    rgb_signed = rgb01 * 2.0 - 1.0  # rgb rescaled to [-1, 1]
    # ADE counts become a binary presence feature in {-1, +1}.
    ade_signed = np.where(colmap_pcloud['ade'] > 0, 1.0, -1.0).reshape(-1, 1)
    gestalt_obs = colmap_pcloud['gestalt']  # per-point lists of uint8 RGB triples
    num_points = len(xyz)
    if num_points > 0:
        # Majority-vote fuse the gestalt observations of each point.
        gestalt_signed = np.zeros((num_points, 3))
        for idx, obs_list in enumerate(gestalt_obs):
            if not obs_list:
                gestalt_signed[idx] = np.array([-1.0, -1.0, -1.0])
                continue
            votes = Counter(tuple(rgb) for rgb in obs_list)
            if votes:
                winner = np.array(votes.most_common(1)[0][0], dtype=np.uint8)
                gestalt_signed[idx] = (winner / 255.0) * 2.0 - 1.0
            else:
                gestalt_signed[idx] = np.array([-1.0, -1.0, -1.0])
        # Combined 10D cloud: xyz + rgb + ade + gestalt.
        cloud_10d = np.hstack((xyz, rgb_signed, ade_signed, gestalt_signed))
    else:
        cloud_10d = np.empty((0, 10))
    if len(verts) >= 2 and len(cloud_10d) > 0:
        extension = 1.0  # metres of extra cylinder length per endpoint
        for a in range(len(verts)):
            for b in range(a + 1, len(verts)):  # unordered pairs only
                v_start = verts[a]
                v_end = verts[b]
                seg = v_end - v_start
                seg_len = np.linalg.norm(seg)
                if seg_len < 1e-6:
                    continue  # skip degenerate (zero-length) segments
                direction = seg / seg_len
                ext_start = v_start - extension * direction
                ext_len = seg_len + 2 * extension
                # Axial projection of every point onto the extended segment.
                proj = np.dot(xyz - ext_start[np.newaxis, :], direction)
                on_segment = (proj >= 0) & (proj <= ext_len)
                # Radial distance from each point to the segment axis.
                feet = ext_start[np.newaxis, :] + proj[:, np.newaxis] * direction[np.newaxis, :]
                radial = np.linalg.norm(xyz - feet, axis=1)
                inside = on_segment & (radial <= radius)
                if np.sum(inside) <= 5:
                    continue  # not enough supporting points for a patch
                midpoint = (v_start + v_end) / 2
                patch_points = cloud_10d[inside].copy()
                patch_points[:, :3] -= midpoint  # centre xyz on the midpoint
                candidate_patches.append({
                    'patch_10d': patch_points,
                    'connection': (a, b),  # indices into pred_vertices
                    'line_start': v_start - midpoint,
                    'line_end': v_end - midpoint,
                    'cylinder_radius': radius,
                    'point_indices': np.where(inside)[0],
                    'center': midpoint
                })
    # Optional debugging visualisation (disabled by default).
    if False:
        plotter = pv.Plotter()
        if len(xyz) > 0:
            whole_cloud = pv.PolyData(xyz)
            whole_cloud["colors"] = rgb01
            plotter.add_mesh(whole_cloud, scalars="colors", rgb=True, point_size=3, render_points_as_spheres=True)
        for vert_idx, vert_pos in enumerate(verts):
            plotter.add_mesh(pv.Sphere(radius=0.1, center=vert_pos), color='cyan', opacity=0.8)
            plotter.add_point_labels([vert_pos], [f"V{vert_idx}"], point_size=20, font_size=10)
        for patch in candidate_patches:
            world_pts = patch['patch_10d'][:, :3] + patch['center']
            if len(world_pts) > 0:
                cyl_cloud = pv.PolyData(world_pts)
                cyl_cloud["colors"] = (patch['patch_10d'][:, 3:6] + 1.0) / 2.0
                plotter.add_mesh(cyl_cloud, scalars="colors", rgb=True, point_size=8, render_points_as_spheres=True)
            p0 = patch['line_start'] + patch['center']
            p1 = patch['line_end'] + patch['center']
            plotter.add_mesh(pv.Sphere(radius=0.05, center=p0), color='black', opacity=0.8)
            plotter.add_mesh(pv.Sphere(radius=0.05, center=p1), color='white', opacity=0.8)
            plotter.add_mesh(pv.Line(p0, p1), color='orange', line_width=3)
            axis = patch['line_end'] - patch['line_start']
            height = np.linalg.norm(axis)
            if height > 1e-6:
                cyl_mesh = pv.Cylinder(center=patch['center'], direction=axis / height,
                                       radius=patch['cylinder_radius'], height=height + 2 * 1.0)
                plotter.add_mesh(cyl_mesh, color='orange', opacity=0.15, style='wireframe')
        plotter.show(title=f"Candidate Edge Patches (10d_forward): {len(candidate_patches)}")
    return candidate_patches
def calculate_cylinder_overlap_volume(cyl1, cyl2):
    """
    Estimate the intersection volume of two cylinders.

    Each cylinder is a dict with 'start' and 'end' (3D endpoints of the axis
    segment) and 'radius'. Returns an approximate overlap volume; 0.0 when a
    cylinder is degenerate (zero-length axis) or the two do not intersect.
    """
    s1, e1, r1 = cyl1['start'], cyl1['end'], cyl1['radius']
    s2, e2, r2 = cyl2['start'], cyl2['end'], cyl2['radius']
    dir1 = e1 - s1
    dir2 = e2 - s2
    h1 = np.linalg.norm(dir1)
    h2 = np.linalg.norm(dir2)
    # Degenerate (zero-length) cylinders enclose no volume.
    if h1 == 0 or h2 == 0:
        return 0.0
    u1 = dir1 / h1
    u2 = dir2 / h2
    # Closest-approach distance between the two axis segments
    # (standard line-line distance setup).
    w = s1 - s2
    a = np.dot(u1, u1)
    b = np.dot(u1, u2)
    c = np.dot(u2, u2)
    d = np.dot(u1, w)
    e = np.dot(u2, w)
    denom = a * c - b * b
    if abs(denom) < 1e-10:
        # Parallel axes: perpendicular distance between the two lines.
        cp = np.cross(u1, w)
        dist = np.linalg.norm(cp) if u1.shape[0] == 3 else abs(cp)
    else:
        # Closest points on each finite axis segment (parameters clamped).
        t1 = np.clip((b * e - c * d) / denom, 0, h1)
        t2 = np.clip((a * e - b * d) / denom, 0, h2)
        dist = np.linalg.norm((s1 + t1 * u1) - (s2 + t2 * u2))
    # Radially disjoint -> no intersection at all.
    if dist >= (r1 + r2):
        return 0.0
    # Axial overlap: project cylinder 2's endpoints onto cylinder 1's axis.
    p_a = np.dot(s2 - s1, u1)
    p_b = np.dot(e2 - s1, u1)
    lo = max(0, min(p_a, p_b))
    hi = min(h1, max(p_a, p_b))
    overlap_length = max(0, hi - lo)
    if overlap_length <= 0:
        return 0.0
    # Cross-section intersection area times the axial overlap length.
    if dist < abs(r1 - r2):
        # One circle fully contains the other.
        volume = np.pi * min(r1, r2) ** 2 * overlap_length
    elif dist < (r1 + r2):
        # Partial overlap: lens-shaped circle-circle intersection area.
        d1 = (r1**2 - r2**2 + dist**2) / (2 * dist) if dist > 0 else 0
        d2 = dist - d1
        if 0 <= d1 <= r1 and 0 <= d2 <= r2:
            seg1 = r1**2 * np.arccos(d1 / r1) - d1 * np.sqrt(r1**2 - d1**2)
            seg2 = r2**2 * np.arccos(d2 / r2) - d2 * np.sqrt(r2**2 - d2**2)
            lens_area = seg1 + seg2
        else:
            lens_area = np.pi * min(r1, r2) ** 2
        volume = lens_area * overlap_length
    else:
        volume = 0.0
    return max(0.0, volume)
def create_pcloud(colmap_rec, frame):
    """
    Build an enriched 7D point cloud from a COLMAP reconstruction and a frame.

    Every COLMAP 3D point is projected (on GPU when available) into each frame
    view that observes it; the ADE20k house mask and the gestalt segmentation
    are sampled at the projected pixel and accumulated per point.

    Parameters
    ----------
    colmap_rec : pycolmap reconstruction exposing ``images`` and ``points3D``.
    frame : dict of parallel per-view lists: 'K', 'R', 't', 'image_ids',
        'ade', 'gestalt', 'depth'.
        NOTE(review): 'image_ids' entries are matched against COLMAP image
        *names* below — assumes the frame uses the same identifiers; confirm.

    Returns
    -------
    dict with keys:
        'points_7d'   : (N, 7) array [x, y, z, r, g, b, colmap_point_id]
        'imgs'        : per-point list of image ids that observed the point
        'uv'          : per-point list of rounded (u, v) projections
        'all_imgs_ids': names of all COLMAP images
        'all_imgs_K' / 'all_imgs_R' / 'all_imgs_t': intrinsics/extrinsics of
            the frame views that matched a COLMAP image
        'ade'         : (N,) count of views where the point hit the house mask
        'gestalt'     : per-point list of sampled gestalt RGB triplets
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    #print(f"create_pcloud using device: {device}")
    # 1. Preprocess image data from the frame and colmap (mostly on CPU)
    img_id_to_colmap_img_obj_map = {
        img_obj.name: img_obj for img_obj_name, img_obj in colmap_rec.images.items()
    }
    frame_img_data = {}
    ordered_frame_img_ids = []
    for K_val, R_val, t_val, img_id_val, ade_val, gestalt_val, depth_val in zip(
            frame['K'], frame['R'], frame['t'], frame['image_ids'],
            frame['ade'], frame['gestalt'], frame['depth']
    ):
        # Skip frame views that have no corresponding COLMAP image.
        if img_id_val not in img_id_to_colmap_img_obj_map:
            continue
        ordered_frame_img_ids.append(img_id_val)
        depth_np = np.array(depth_val)
        # All 2D data is resampled to the depth-map resolution.
        depth_H, depth_W = depth_np.shape[0], depth_np.shape[1]
        ade_mask_np = get_house_mask(ade_val)
        gest_seg_pil = gestalt_val.resize((depth_W, depth_H), Image.Resampling.NEAREST)
        gest_seg_np = np.array(gest_seg_pil).astype(np.uint8)
        frame_img_data[img_id_val] = {
            'K_np': np.array(K_val),
            'R_np': np.array(R_val),
            't_np': np.array(t_val).reshape(3,1),
            'ade_mask_np': ade_mask_np,
            'gestalt_seg_np': gest_seg_np,
            'H': depth_H,
            'W': depth_W
        }
    # 2. Process 3D points by iterating through images
    point_data_accumulator = {}  # Key: pid, accumulates data on CPU
    # Pre-fetch all COLMAP point data to avoid repeated dictionary lookups
    colmap_points_data_cpu = {
        pid: {'xyz': p3D.xyz, 'color': p3D.color / 255.0}
        for pid, p3D in colmap_rec.points3D.items()
    }
    for img_id in ordered_frame_img_ids:
        if img_id not in frame_img_data:
            continue
        col_img_obj = img_id_to_colmap_img_obj_map[img_id]
        img_data = frame_img_data[img_id]
        K_np, R_np, t_np = img_data['K_np'], img_data['R_np'], img_data['t_np']
        ade_mask_np, gestalt_seg_np = img_data['ade_mask_np'], img_data['gestalt_seg_np']
        H, W = img_data['H'], img_data['W']
        # Convert current image data to GPU tensors
        K_gpu = torch.from_numpy(K_np).float().to(device)
        R_gpu = torch.from_numpy(R_np).float().to(device)
        t_gpu = torch.from_numpy(t_np).float().to(device)
        ade_mask_gpu = torch.from_numpy(ade_mask_np).bool().to(device)
        gestalt_seg_gpu = torch.from_numpy(gestalt_seg_np).to(device)  # uint8 is fine
        visible_pids_in_img = []
        visible_xyz_coords_list = []
        # NOTE(review): linear scan over all 3D points for every image
        # (O(points x images)); has_point3D keeps this part CPU-bound.
        for pid, p3D_data in colmap_points_data_cpu.items():
            if col_img_obj.has_point3D(pid):  # This check remains CPU-bound
                visible_pids_in_img.append(pid)
                visible_xyz_coords_list.append(p3D_data['xyz'])
        if not visible_pids_in_img:
            continue
        num_visible_points = len(visible_pids_in_img)
        world_pts_np = np.array(visible_xyz_coords_list)
        world_pts_gpu = torch.from_numpy(world_pts_np).float().to(device)
        # Batch projection on GPU
        world_pts_h_gpu = torch.cat((world_pts_gpu, torch.ones(num_visible_points, 1, device=device)), dim=1)
        P_world_to_cam_gpu = torch.hstack((R_gpu, t_gpu))
        cam_coords_proj_gpu = P_world_to_cam_gpu @ world_pts_h_gpu.T
        cam_coords_z_gpu = cam_coords_proj_gpu[2, :]
        in_front_mask_gpu = cam_coords_z_gpu > 1e-6
        pixel_coords_h_gpu = K_gpu @ cam_coords_proj_gpu
        # Projections default to (-1, -1) for points that cannot be projected.
        u_proj_gpu = torch.full_like(cam_coords_z_gpu, -1.0, dtype=torch.float32)
        v_proj_gpu = torch.full_like(cam_coords_z_gpu, -1.0, dtype=torch.float32)
        # Avoid division by zero/small numbers for points not truly in front or on optical center
        valid_depth_mask_gpu = in_front_mask_gpu & (torch.abs(cam_coords_z_gpu) > 1e-6)
        if torch.any(valid_depth_mask_gpu):
            u_proj_gpu[valid_depth_mask_gpu] = pixel_coords_h_gpu[0, valid_depth_mask_gpu] / cam_coords_z_gpu[valid_depth_mask_gpu]
            v_proj_gpu[valid_depth_mask_gpu] = pixel_coords_h_gpu[1, valid_depth_mask_gpu] / cam_coords_z_gpu[valid_depth_mask_gpu]
        u_rounded_gpu = torch.round(u_proj_gpu).long()
        v_rounded_gpu = torch.round(v_proj_gpu).long()
        is_in_bounds_gpu = (u_rounded_gpu >= 0) & (u_rounded_gpu < W) & \
                           (v_rounded_gpu >= 0) & (v_rounded_gpu < H) & \
                           in_front_mask_gpu  # Re-check in_front_mask_gpu as rounding might affect edge cases slightly
        # Sample ADE and Gestalt on GPU for points in bounds
        # Initialize with default values for all points, then update for those in bounds
        sampled_ade_status_gpu = torch.zeros(num_visible_points, dtype=torch.bool, device=device)
        sampled_gestalt_values_gpu = torch.zeros(num_visible_points, 3, dtype=torch.uint8, device=device)
        # Create a mask for points that are valid for sampling (in_bounds and in_front)
        valid_for_sampling_mask_gpu = is_in_bounds_gpu
        if torch.any(valid_for_sampling_mask_gpu):
            u_sample_gpu = u_rounded_gpu[valid_for_sampling_mask_gpu]
            v_sample_gpu = v_rounded_gpu[valid_for_sampling_mask_gpu]
            sampled_ade_status_gpu[valid_for_sampling_mask_gpu] = ade_mask_gpu[v_sample_gpu, u_sample_gpu]
            sampled_gestalt_values_gpu[valid_for_sampling_mask_gpu] = gestalt_seg_gpu[v_sample_gpu, u_sample_gpu]
        # Transfer necessary results back to CPU for accumulation
        u_rounded_cpu = u_rounded_gpu.cpu().numpy()
        v_rounded_cpu = v_rounded_gpu.cpu().numpy()
        is_in_bounds_cpu = is_in_bounds_gpu.cpu().numpy()  # Use the original is_in_bounds_gpu for logic
        sampled_ade_status_cpu = sampled_ade_status_gpu.cpu().numpy()
        sampled_gestalt_values_cpu = sampled_gestalt_values_gpu.cpu().numpy()
        # Update accumulator (on CPU)
        for i in range(num_visible_points):
            pid = visible_pids_in_img[i]
            if pid not in point_data_accumulator:
                point_data_accumulator[pid] = {
                    'xyz': colmap_points_data_cpu[pid]['xyz'],
                    'color': colmap_points_data_cpu[pid]['color'],
                    'imgs_seen_by': [],
                    'uv_projections': [],
                    'ade_count': 0,  # Count of times seen in ADE segmentation
                    'gestalt_values': []
                }
            acc = point_data_accumulator[pid]
            acc['imgs_seen_by'].append(img_id)
            acc['uv_projections'].append((u_rounded_cpu[i], v_rounded_cpu[i]))
            if is_in_bounds_cpu[i]:  # This point was projected within bounds and in front
                if sampled_ade_status_cpu[i]:
                    acc['ade_count'] += 1
                acc['gestalt_values'].append(sampled_gestalt_values_cpu[i])
            else:  # Point projected out of bounds, behind, or failed depth check
                acc['gestalt_values'].append(np.array([0,0,0], dtype=np.uint8))
        # Optional: clear GPU cache if memory is a concern for many images
        # if device.type == 'cuda':
        #     torch.cuda.empty_cache()
    # 3. Final data assembly (on CPU)
    points_xyz_world_list = []
    points_colors_list = []
    points_idxs_list = []
    points_imgs_seen_by_list = []
    points_uv_projections_per_point_list = []
    points_ade_count_final_list = []
    points_gestalt_values_per_point_list = []
    # Ensure consistent order if downstream code relies on it, though original didn't specify sorting for pids
    # Using sorted_pids for reproducibility if point_data_accumulator keys order changes.
    sorted_pids = sorted(point_data_accumulator.keys())
    for pid in sorted_pids:
        data = point_data_accumulator[pid]
        points_xyz_world_list.append(data['xyz'])
        points_colors_list.append(data['color'])
        points_idxs_list.append(pid)
        points_imgs_seen_by_list.append(data['imgs_seen_by'])
        points_uv_projections_per_point_list.append(data['uv_projections'])
        points_ade_count_final_list.append(data['ade_count'])
        points_gestalt_values_per_point_list.append(data['gestalt_values'])
    points_xyz_world = np.array(points_xyz_world_list) if points_xyz_world_list else np.empty((0, 3))
    points_colors = np.array(points_colors_list) if points_colors_list else np.empty((0, 3))
    points_idxs = np.array(points_idxs_list, dtype=int) if points_idxs_list else np.empty((0,), dtype=int)  # Ensure dtype for pids
    points_ade = np.array(points_ade_count_final_list, dtype=int) if points_ade_count_final_list else np.empty((0,), dtype=int)
    output_all_colmap_img_ids = [img_obj.name for img_obj_name, img_obj in colmap_rec.images.items()]
    output_frame_K, output_frame_R, output_frame_t = [], [], []
    for img_id_val in frame['image_ids']:
        if img_id_val in frame_img_data:
            data = frame_img_data[img_id_val]
            output_frame_K.append(data['K_np'])
            output_frame_R.append(data['R_np'])
            output_frame_t.append(data['t_np'])
    if points_xyz_world.shape[0] > 0:
        # Pack xyz + rgb + point id into a single (N, 7) array.
        colmap_points_7d = np.zeros((points_xyz_world.shape[0], 7))
        colmap_points_7d[:, :3] = points_xyz_world
        colmap_points_7d[:, 3:6] = points_colors
        colmap_points_7d[:, 6] = points_idxs
        whole_pcloud = {
            'points_7d': colmap_points_7d,
            'imgs': points_imgs_seen_by_list,
            'uv': points_uv_projections_per_point_list,
            'all_imgs_ids': output_all_colmap_img_ids,
            'all_imgs_K': output_frame_K,
            'all_imgs_R': output_frame_R,
            'all_imgs_t': output_frame_t,
            'ade': points_ade,
            'gestalt': points_gestalt_values_per_point_list
        }
    else:
        # Empty reconstruction: keep the same schema with empty containers.
        whole_pcloud = {
            'points_7d': np.empty((0, 7)),
            'imgs': [],
            'uv': [],
            'all_imgs_ids': output_all_colmap_img_ids,
            'all_imgs_K': output_frame_K,
            'all_imgs_R': output_frame_R,
            'all_imgs_t': output_frame_t,
            'ade': np.empty((0,), dtype=int),
            'gestalt': []
        }
    return whole_pcloud
def predict_wireframe(entry, pnet_model, voxel_model, pnet_class_model, config) -> Tuple[np.ndarray, List[int]]:
    """
    Predict 3D wireframe from a dataset entry.

    Pipeline:
      1. Build an enriched point cloud from the COLMAP reconstruction.
      2. Per image, extract vertex candidates and candidate 2D connections.
      3. Merge candidates across views and cut local patches around them.
      4. Refine/score each vertex with ``pnet_model``; keep scores above
         the vertex threshold.
      5. Generate edge patches and classify them with ``pnet_class_model``.

    Parameters
    ----------
    entry : raw dataset entry; converted with convert_entry_to_human_readable.
    pnet_model : vertex-refinement PointNet model.
    voxel_model : unused in this function (kept for call compatibility).
    pnet_class_model : edge-classification PointNet model.
    config : dict; optional keys 'vertex_threshold' (default 0.5),
        'edge_threshold' (default 0.5), 'only_predicted_connections'
        (default False).

    Returns
    -------
    (vertices, connections), or ``empty_solution()`` when running in
    dataset-generation mode or when no valid vertex survives the threshold.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    good_entry = convert_entry_to_human_readable(entry)
    colmap_rec = good_entry['colmap_binary']
    start_time = time.time()
    colmap_pcloud = create_pcloud(colmap_rec, good_entry)
    print(f"Time for create_pcloud: {time.time() - start_time:.4f} seconds")
    vertex_threshold = config.get('vertex_threshold', 0.5)
    edge_threshold = config.get('edge_threshold', 0.5)
    only_predicted_connections = config.get('only_predicted_connections', False)
    vert_edge_per_image = {}  # NOTE(review): filled below but never read afterwards
    idxs_points = []
    all_connections = []
    our_get_vertices_time_total = 0
    for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(zip(good_entry['gestalt'],
                                                                    good_entry['depth'],
                                                                    good_entry['K'],
                                                                    good_entry['R'],
                                                                    good_entry['t'],
                                                                    good_entry['image_ids'],
                                                                    good_entry['ade']  # Added ade20k segmentation
                                                                    )):
        # Visualize gestalt segmentation
        K = np.array(K)
        R = np.array(R)
        t = np.array(t)
        # Resize gestalt segmentation to match depth map size
        depth_size = (np.array(depth).shape[1], np.array(depth).shape[0])  # W, H
        gest_seg = gest.resize(depth_size)
        gest_seg_np = np.array(gest_seg).astype(np.uint8)
        start_time_loop = time.time()
        vertices_ours, connections_ours, vertices_3d_ours, patches, filtered_point_idxs = our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id, ade_seg, depth, K=K, R=R, t=t, frame=good_entry)
        our_get_vertices_time_total += (time.time() - start_time_loop)
        idxs_points.append(filtered_point_idxs)
        all_connections.append(connections_ours)
        vertices, connections, vertices_3d = vertices_ours, connections_ours, vertices_3d_ours
        vert_edge_per_image[i] = vertices, connections, vertices_3d
    print(f"Total time for our_get_vertices_and_edges loop: {our_get_vertices_time_total:.4f} seconds")
    start_time = time.time()
    extracted_points, extracted_colors, extracted_ids, whole_pcloud, connections = extract_vertices_from_whole_pcloud_v2(colmap_pcloud, idxs_points, all_connections)
    print(f"Time for extract_vertices_from_whole_pcloud_v2: {time.time() - start_time:.4f} seconds")
    # GT wireframe vertices (present only for training entries).
    wf_vertices = good_entry.get('wf_vertices', None)
    start_time = time.time()
    # NOTE: `patches` from the loop above is intentionally overwritten here.
    patches = generate_patches_v3(extracted_points, extracted_colors, extracted_ids, whole_pcloud, wf_vertices)
    print(f"Time for generate_patches_v3: {time.time() - start_time:.4f} seconds")
    # Dataset-generation mode: dump vertex patches and return a dummy solution.
    if GENERATE_DATASET:
        start_time = time.time()
        # NOTE(review): img_id here is the loop variable left over from the
        # last image iteration above — confirm it is the intended dataset key.
        save_patches_dataset(patches, DATASET_DIR, img_id)
        print(f"Time for save_patches_dataset: {time.time() - start_time:.4f} seconds")
        return empty_solution()
    predicted_vertices = []
    predict_vertex_time_total = 0
    for i, patch in enumerate(patches):
        start_time_loop = time.time()
        pred_vertex, pred_dist, pred_class = predict_vertex_from_patch(pnet_model, patch, device=device)
        predict_vertex_time_total += (time.time() - start_time_loop)
        if pred_class > vertex_threshold:
            predicted_vertices.append(pred_vertex)
        else:
            predicted_vertices.append(np.array([0.0, 0.0, 0.0]))  # Append a zero vertex if not predicted
    print(f"Total time for predict_vertex_from_patch loop: {predict_vertex_time_total:.4f} seconds")
    predicted_vertices = np.array(predicted_vertices) if predicted_vertices else np.empty((0, 3))
    # Filter out zero vertices and update connections accordingly
    non_zero_mask = ~np.all(np.isclose(predicted_vertices, [0.0, 0.0, 0.0]), axis=1)
    valid_indices = np.where(non_zero_mask)[0]
    # Filter vertices to only include non-zero ones
    filtered_vertices = predicted_vertices[valid_indices]
    # Edge-dataset-generation mode: dump edge patches and return a dummy solution.
    if GENERATE_DATASET_EDGES:
        start_time = time.time()
        edge_patches = generate_edge_patches(good_entry, filtered_vertices, colmap_pcloud)
        print(f"Time for generate_edge_patches: {time.time() - start_time:.4f} seconds")
        start_time = time.time()
        save_patches_dataset_class(edge_patches, EDGES_DATASET_DIR, good_entry['order_id'])
        print(f"Time for save_patches_dataset_class: {time.time() - start_time:.4f} seconds")
        return empty_solution()
    if len(valid_indices) == 0:
        print("No valid predicted vertices found")
        return empty_solution()
    # Create mapping from old indices to new indices
    old_to_new_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(valid_indices)}
    # Filter and update connections
    filtered_connections = []
    for start_idx, end_idx in connections:
        if start_idx in old_to_new_mapping and end_idx in old_to_new_mapping:
            new_start = old_to_new_mapping[start_idx]
            new_end = old_to_new_mapping[end_idx]
            if new_start != new_end:  # Ensure we don't connect a vertex to itself
                filtered_connections.append((new_start, new_end))
    start_time = time.time()
    #forward_patches = generate_edge_patches_forward_10d(good_entry, filtered_vertices, colmap_pcloud)
    forward_patches = generate_edge_patches_forward(good_entry, filtered_vertices)
    print(f"Time for generate_edge_patches_forward: {time.time() - start_time:.4f} seconds")
    new_connections = []
    predict_class_time_total = 0
    if len(forward_patches) > 0:
        for i, patch in enumerate(forward_patches):
            start_idx, end_idx = patch['connection']
            start_time_loop = time.time()
            pred_class, pred_score = predict_class_from_patch(pnet_class_model, patch, device=device)
            predict_class_time_total += (time.time() - start_time_loop)
            if pred_score > edge_threshold:
                new_connections.append((start_idx, end_idx))
    print(f"Total time for predict_class_from_patch loop: {predict_class_time_total:.4f} seconds")
    predicted_vertices = np.array(filtered_vertices)
    if only_predicted_connections:
        connections = new_connections
    else:
        connections = filtered_connections + new_connections
    # Remove duplicates from connections
    connections = list(set(connections))
    return predicted_vertices, connections
def predict_wireframe_old(entry) -> Tuple[np.ndarray, List[int]]:
    """
    Predict 3D wireframe from a dataset entry.

    Baseline pipeline: per-image 2D vertex/edge extraction from the gestalt
    segmentation, lifting to 3D via depth + COLMAP, then cross-view merging
    and pruning.
    """
    good_entry = convert_entry_to_human_readable(entry)
    colmap_rec = good_entry['colmap_binary']
    per_image = zip(good_entry['gestalt'], good_entry['depth'], good_entry['K'],
                    good_entry['R'], good_entry['t'], good_entry['image_ids'],
                    good_entry['ade'])
    vert_edge_per_image = {}
    for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(per_image):
        K, R, t = np.array(K), np.array(R), np.array(t)
        # Bring the gestalt segmentation to the depth-map resolution (W, H).
        depth_np = np.array(depth)
        gest_seg_np = np.array(gest.resize((depth_np.shape[1], depth_np.shape[0]))).astype(np.uint8)
        # 2D vertices and edges from the segmentation.
        vertices, connections = get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th=25.)
        if len(vertices) < 2 or len(connections) < 1:
            print(f'Not enough vertices or connections found in image {i}, skipping.')
            vert_edge_per_image[i] = [], [], np.empty((0, 3))
            continue
        # Lift the 2D structure into 3D for this single view.
        vertices_3d = create_3d_wireframe_single_image(
            vertices, connections, depth, colmap_rec, img_id, ade_seg, K, R, t
        )
        vert_edge_per_image[i] = vertices, connections, vertices_3d
    # Fuse the per-view wireframes and prune outliers / disconnected parts.
    all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 0.5)
    all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d, keep_largest=False)
    all_3d_vertices_clean, connections_3d_clean = prune_too_far(all_3d_vertices_clean, connections_3d_clean, colmap_rec, th=1.5)
    if len(all_3d_vertices_clean) < 2 or len(connections_3d_clean) < 1:
        print(f'Not enough vertices or connections in the 3D vertices')
        return empty_solution()
    return all_3d_vertices_clean, connections_3d_clean
def generate_patches_v2(extracted_points, extracted_colors, extracted_ids, whole_pcloud, wf_vertices):
    """
    Build 4 m cubic point-cloud patches (7D points) around candidate clusters.

    For every non-empty cluster, all points of the full cloud inside an
    axis-aligned cube centered on the cluster mean are extracted, centered at
    the origin, and stored as [x, y, z, r, g, b, flag] where RGB is rescaled
    to [-1, 1] and flag is +1 for points belonging to the cluster, -1
    otherwise. When ground-truth wireframe vertices are supplied, the closest
    one within 1 m of the cluster center is attached (origin-relative).

    Parameters
    ----------
    extracted_points, extracted_colors, extracted_ids : per-cluster lists of
        point coordinates, colors in [0, 1], and COLMAP point ids.
    whole_pcloud : dict with 'points' (N, 3), 'colors' (N, 3), 'ids' (N,).
    wf_vertices : optional (V, 3) array-like of GT vertices, or None.

    Returns
    -------
    list of patch dicts with keys 'patch_7d', 'cluster_center',
    'cube_edge_length', 'cluster_idx', 'assigned_wf_vertex',
    'cube_point_ids', 'cluster_point_ids'.
    """
    patches = []
    whole_points = whole_pcloud['points']
    whole_colors = whole_pcloud['colors']
    whole_ids = whole_pcloud['ids']
    wf_vertices = np.array(wf_vertices) if wf_vertices is not None else np.empty((0, 3))
    cube_edge_length = 4.0
    half_edge = cube_edge_length / 2.0
    for cluster_idx, (cluster_points, cluster_colors, cluster_ids) in enumerate(
            zip(extracted_points, extracted_colors, extracted_ids)):
        if len(cluster_points) == 0:
            continue
        cluster_center = np.mean(cluster_points, axis=0)
        # Axis-aligned cube membership test (inclusive bounds on all axes).
        within_cube_mask = np.all(np.abs(whole_points - cluster_center) <= half_edge, axis=1)
        if not np.any(within_cube_mask):
            continue
        cube_points = whole_points[within_cube_mask]
        cube_colors = whole_colors[within_cube_mask]
        cube_point_ids = whole_ids[within_cube_mask]
        # Center the patch at the origin.
        cube_points_centered = cube_points - cluster_center
        patch_7d = np.zeros((len(cube_points_centered), 7))
        patch_7d[:, :3] = cube_points_centered
        patch_7d[:, 3:6] = cube_colors * 2.0 - 1.0  # RGB normalized to [-1, 1]
        # Vectorized cluster-membership flag (replaces the per-point loop).
        in_cluster = np.isin(cube_point_ids, cluster_ids)
        patch_7d[:, 6] = np.where(in_cluster, 1.0, -1.0)
        # Attach the closest GT vertex within 1 m of the center, if any.
        assigned_wf_vertex = None
        if len(wf_vertices) > 0:
            distances_to_gt = np.linalg.norm(wf_vertices - cluster_center, axis=1)
            within_radius_mask = distances_to_gt <= 1.0
            if np.any(within_radius_mask):
                candidate_indices = np.where(within_radius_mask)[0]
                closest = candidate_indices[np.argmin(distances_to_gt[within_radius_mask])]
                # Shift the assigned GT vertex to be relative to the origin.
                assigned_wf_vertex = wf_vertices[closest] - cluster_center
        patches.append({
            'patch_7d': patch_7d,
            'cluster_center': cluster_center,
            'cube_edge_length': cube_edge_length,
            'cluster_idx': cluster_idx,
            'assigned_wf_vertex': assigned_wf_vertex,
            'cube_point_ids': cube_point_ids,
            'cluster_point_ids': cluster_ids
        })
        # NOTE: the original dead `if False:` PyVista visualization block
        # (referencing the commented-out `pv` import) was removed.
    return patches
def generate_patches_v3(extracted_points, extracted_colors, extracted_ids, whole_pcloud, wf_vertices):
    """
    Build 8 m cubic point-cloud patches (11D points) around candidate clusters.

    Like generate_patches_v2, but the full cloud carries 7D colors
    [r, g, b, ade, gestalt_r, gestalt_g, gestalt_b]; each patch point becomes
    [x, y, z, r, g, b, ade, gest_r, gest_g, gest_b, flag] with all feature
    channels rescaled from [0, 1] to [-1, 1] and flag = +1 for cluster
    members, -1 otherwise. When GT wireframe vertices are supplied, the
    closest one within 1 m of the cluster center is attached (origin-relative).

    Parameters
    ----------
    extracted_points, extracted_colors, extracted_ids : per-cluster lists.
    whole_pcloud : dict with 'points' (N, 3), 'colors' (N, 7), 'ids' (N,).
    wf_vertices : optional (V, 3) array-like of GT vertices, or None.

    Returns
    -------
    list of patch dicts with keys 'patch_11d', 'cluster_center',
    'cube_edge_length', 'cluster_idx', 'assigned_wf_vertex',
    'cube_point_ids', 'cluster_point_ids'.
    """
    patches = []
    whole_points = whole_pcloud['points']
    whole_colors = whole_pcloud['colors']  # 7D: [r, g, b, ade, gestalt_r, gestalt_g, gestalt_b]
    whole_ids = whole_pcloud['ids']
    wf_vertices = np.array(wf_vertices) if wf_vertices is not None else np.empty((0, 3))
    cube_edge_length = 8.0
    half_edge = cube_edge_length / 2.0
    for cluster_idx, (cluster_points, cluster_colors, cluster_ids) in enumerate(
            zip(extracted_points, extracted_colors, extracted_ids)):
        if len(cluster_points) == 0:
            continue
        cluster_center = np.mean(cluster_points, axis=0)
        # Axis-aligned cube membership test (inclusive bounds on all axes).
        within_cube_mask = np.all(np.abs(whole_points - cluster_center) <= half_edge, axis=1)
        if not np.any(within_cube_mask):
            continue
        cube_points = whole_points[within_cube_mask]
        cube_colors_7d = whole_colors[within_cube_mask]
        cube_point_ids = whole_ids[within_cube_mask]
        # Center the patch at the origin.
        cube_points_centered = cube_points - cluster_center
        # Build the 11D patch directly (the original built a 10D array and
        # then copied it into an 11D one — that intermediate is unnecessary).
        patch_11d = np.zeros((len(cube_points_centered), 11))
        patch_11d[:, :3] = cube_points_centered
        patch_11d[:, 3:6] = cube_colors_7d[:, :3] * 2.0 - 1.0    # RGB -> [-1, 1]
        patch_11d[:, 6] = cube_colors_7d[:, 3] * 2.0 - 1.0       # ADE feature -> [-1, 1]
        patch_11d[:, 7:10] = cube_colors_7d[:, 4:7] * 2.0 - 1.0  # gestalt RGB -> [-1, 1]
        # Vectorized cluster-membership flag (replaces the per-point loop).
        in_cluster = np.isin(cube_point_ids, cluster_ids)
        patch_11d[:, 10] = np.where(in_cluster, 1.0, -1.0)
        # Attach the closest GT vertex within 1 m of the center, if any.
        assigned_wf_vertex = None
        if len(wf_vertices) > 0:
            distances_to_gt = np.linalg.norm(wf_vertices - cluster_center, axis=1)
            within_radius_mask = distances_to_gt <= 1.0
            if np.any(within_radius_mask):
                candidate_indices = np.where(within_radius_mask)[0]
                closest = candidate_indices[np.argmin(distances_to_gt[within_radius_mask])]
                # Shift the assigned GT vertex to be relative to the origin.
                assigned_wf_vertex = wf_vertices[closest] - cluster_center
        patches.append({
            'patch_11d': patch_11d,
            'cluster_center': cluster_center,
            'cube_edge_length': cube_edge_length,
            'cluster_idx': cluster_idx,
            'assigned_wf_vertex': assigned_wf_vertex,
            'cube_point_ids': cube_point_ids,
            'cluster_point_ids': cluster_ids
        })
        # NOTE: the original dead `if False:` PyVista visualization block
        # (referencing the commented-out `pv` import) was removed.
    return patches
def get_visible_points(colmap_rec, img_id_substring, R=None, t=None):
    """
    Collect the COLMAP 3D points observed by the image whose name contains
    ``img_id_substring`` and transform them into the given camera frame.

    Returns (points_cam, points_xyz_world, points_idxs); three empty lists
    when the image is not found or it observes no 3D points.
    """
    # 1) Locate the COLMAP image whose name contains the requested substring.
    found_img = next(
        (img for _, img in colmap_rec.images.items() if img_id_substring in img.name),
        None,
    )
    if found_img is None:
        print(f"Image substring {img_id_substring} not found in COLMAP.")
        return [], [], []
    # 2) Gather the 3D points this image observes (according to COLMAP).
    observed = [(p3D.xyz, pid) for pid, p3D in colmap_rec.points3D.items()
                if found_img.has_point3D(pid)]
    if not observed:
        print(f"No 3D points associated with {found_img.name} in COLMAP.")
        return [], [], []
    points_xyz_world = np.array([xyz for xyz, _ in observed])          # (N, 3)
    points_idxs = np.array([pid for _, pid in observed])               # (N,)
    # Homogeneous world -> camera transform using the provided extrinsics.
    ones = np.ones((points_xyz_world.shape[0], 1))
    points_xyz_world_h = np.hstack((points_xyz_world, ones))           # (N, 4)
    world_to_cam_mat = np.eye(4)
    world_to_cam_mat[:3, :3] = R
    world_to_cam_mat[:3, 3] = t.flatten()
    points_cam_h = (world_to_cam_mat @ points_xyz_world_h.T).T         # (N, 4)
    points_cam = points_cam_h[:, :3] / points_cam_h[:, 3, np.newaxis]  # (N, 3)
    return points_cam, points_xyz_world, points_idxs
def project_points_to_2d(points_cam, K, H, W):
    """
    Project camera-space 3D points onto the image plane.

    Parameters
    ----------
    points_cam : (N, 3) array of points in camera coordinates (z = depth).
    K : (3, 3) intrinsic matrix.
    H, W : image height and width in pixels.

    Returns
    -------
    uv : (M, 2) int array of in-bounds integer pixel coordinates (u, v).
        Fix: when nothing projects into the image this is now an empty array
        of shape (0, 2) — the original returned shape (0,), which broke
        callers indexing ``uv[:, 0]`` / ``uv[:, 1]``.
    valid_indices : (M,) int array of indices into ``points_cam`` for each
        row of ``uv``; points behind the camera (z <= 0) or projecting
        outside the image are dropped.
    """
    points_cam = np.asarray(points_cam, dtype=float)
    if points_cam.shape[0] == 0:
        return np.empty((0, 2), dtype=int), np.empty((0,), dtype=int)
    z = points_cam[:, 2]
    in_front = z > 0
    # Pinhole projection, vectorized; guard z <= 0 against division by zero.
    safe_z = np.where(in_front, z, 1.0)
    u = K[0, 0] * points_cam[:, 0] / safe_z + K[0, 2]
    v = K[1, 1] * points_cam[:, 1] / safe_z + K[1, 2]
    # np.rint matches Python 3's round() (banker's rounding on .5 halves),
    # so results are identical to the original int(round(...)) loop.
    u_int = np.rint(u).astype(int)
    v_int = np.rint(v).astype(int)
    valid = in_front & (u_int >= 0) & (u_int < W) & (v_int >= 0) & (v_int < H)
    valid_indices = np.flatnonzero(valid)
    if valid_indices.size == 0:
        return np.empty((0, 2), dtype=int), valid_indices
    uv = np.stack((u_int[valid], v_int[valid]), axis=1)  # shape (M, 2)
    return uv, valid_indices
def project_points_to_2d_colmap(points_xyz_world, found_img, H, W):
    """
    Project world-space points into an image using COLMAP's own camera model.

    Parameters
    ----------
    points_xyz_world : iterable of (3,) world coordinates.
    found_img : pycolmap image providing ``project_point(xyz)`` -> (u, v)
        or None when the point cannot be projected.
    H, W : image height and width in pixels.

    Returns
    -------
    uv_colmap : (M, 2) int array of in-bounds pixel coordinates.
        Fix: when nothing projects into the image this is now an empty array
        of shape (0, 2) — the original returned shape (0,), which broke
        callers indexing ``uv[:, 0]`` / ``uv[:, 1]``.
    valid_indices_colmap : (M,) int array of source indices for each row.
    """
    uv_colmap = []
    valid_indices_colmap = []
    for i, xyz in enumerate(points_xyz_world):
        proj = found_img.project_point(xyz)  # (u, v) in image coords or None
        if proj is None:
            continue
        u_i = int(round(proj[0]))
        v_i = int(round(proj[1]))
        # Keep only projections inside the image.
        if 0 <= u_i < W and 0 <= v_i < H:
            uv_colmap.append((u_i, v_i))
            valid_indices_colmap.append(i)
    if not uv_colmap:
        return np.empty((0, 2), dtype=int), np.empty((0,), dtype=int)
    return np.array(uv_colmap, dtype=int), np.array(valid_indices_colmap)
def get_apex_or_eave_points(type, uv, gest_seg_np, house_mask, valid_indices, points_xyz_world, points_cam, points_idxs):
    """
    Collect 3D vertex candidates for one gestalt vertex class.

    Thresholds the gestalt segmentation around the class color, finds
    connected components, and for each component gathers the projected 3D
    points falling inside it (and inside the house mask), keeping only points
    within 2 m of the closest one. The closest surviving point of each
    component is kept as its candidate vertex.

    Parameters
    ----------
    type : one of 'apex', 'eave_end', 'flashing_end_point'.
    uv : (N, 2) int pixel projections of the visible 3D points.
    gest_seg_np : (H, W, 3) uint8 gestalt segmentation image.
    house_mask : (H, W) boolean building mask.
    valid_indices : (N,) indices mapping uv rows back into the full arrays.
    points_xyz_world, points_cam : full world/camera point arrays.
    points_idxs : COLMAP point ids aligned with points_xyz_world.

    Returns
    -------
    (filtered_points_xyz, filtered_point_idxs, filtered_points_color,
     filtered_vertices, filtered_vertices_uv) — one entry per component.

    Raises
    ------
    ValueError
        If ``type`` is not a supported class name. (Previously an unknown
        type crashed later with a NameError on ``apex_color``.)
    """
    if type == 'apex':
        apex_color = np.array(gestalt_color_mapping['apex'])
    elif type == 'eave_end':
        apex_color = np.array(gestalt_color_mapping['eave_end_point'])
    elif type == 'flashing_end_point':
        apex_color = np.array(gestalt_color_mapping['flashing_end_point'])
    else:
        raise ValueError(f"Unsupported vertex type: {type!r}")
    # Color threshold (+/- 10 per channel) selects pixels of this class.
    apex_mask = cv2.inRange(gest_seg_np, apex_color-10., apex_color+10.)
    filtered_points_xyz = []
    filtered_point_idxs = []
    filtered_points_color = []
    filtered_vertices_apex = []
    filtered_vertices_apex_uv = []
    if apex_mask.sum() > 0:
        output = cv2.connectedComponentsWithStats(apex_mask, 8, cv2.CV_32S)
        (numLabels, labels, stats, centroids) = output
        for i in range(1, numLabels):  # label 0 is the background
            cur_mask = labels == i
            # Dilate the component so nearby projections are captured.
            kernel = np.ones((5,5), np.uint8)
            cur_mask = cv2.dilate(cur_mask.astype(np.uint8), kernel, iterations=2).astype(bool)
            color = np.random.rand(3)
            # Points whose projection lands in this component and the house mask.
            valid_points_mask = cur_mask[uv[:, 1], uv[:, 0]] & house_mask[uv[:, 1], uv[:, 0]]
            # Keep dilating (up to 5 extra rounds) until >= 5 points are inside.
            for z in range(5):
                if np.sum(valid_points_mask) < 5:
                    cur_mask = cv2.dilate(cur_mask.astype(np.uint8), kernel, iterations=1).astype(bool)
                    valid_points_mask = cur_mask[uv[:, 1], uv[:, 0]] & house_mask[uv[:, 1], uv[:, 0]]
                else:
                    break
            if np.any(valid_points_mask):
                valid_point_indices = valid_indices[valid_points_mask]
                valid_world_points = points_xyz_world[valid_point_indices]
                valid_cam_points = points_cam[valid_point_indices]
                # Depth filter: keep points within 2 m of the closest one
                # (rejects background points seen through the component).
                depths = valid_cam_points[:, 2]
                if len(depths) > 0:
                    min_depth = np.min(depths)
                    depth_filter = depths <= (min_depth + 2.0)
                    final_valid_indices = valid_point_indices[depth_filter]
                    if len(final_valid_indices) > 0:
                        filtered_points_xyz.append(points_xyz_world[final_valid_indices])
                        filtered_point_idxs.append(points_idxs[final_valid_indices])
                        filtered_points_color.append([color] * np.sum(depth_filter))
                        # Candidate vertex: surviving point closest to the camera.
                        lowest_depth_idx = np.argmin(depths[depth_filter])
                        lowest_depth_point = final_valid_indices[lowest_depth_idx]
                        filtered_vertices_apex.append(points_xyz_world[lowest_depth_point])
                        filtered_vertices_apex_uv.append(centroids[i])
    return filtered_points_xyz, filtered_point_idxs, filtered_points_color, filtered_vertices_apex, filtered_vertices_apex_uv
def get_vertexes(uv, gest_seg_np, house_mask, valid_indices, points_xyz_world, points_cam, points_idxs):
    """Collect candidate wireframe vertices for all supported Gestalt classes.

    Runs get_apex_or_eave_points() once per vertex class ('apex', 'eave_end',
    'flashing_end_point') and concatenates the per-class grouped point lists,
    while keeping the per-class vertex arrays separate.

    Returns a 9-tuple:
        filtered_points_xyz   : list of (Ni, 3) arrays of grouped world points
        filtered_point_idxs   : list of per-group COLMAP point-id arrays
        filtered_points_color : list of per-group debug colors
        apex vertices (K, 3), apex UVs (K, 2),
        eave_end vertices (M, 3), eave_end UVs (M, 2),
        flashing_end_point vertices (P, 3), flashing UVs (P, 2)
    Empty classes yield (0, 3) / (0, 2) arrays so callers can index safely.
    """
    vertex_classes = ('apex', 'eave_end', 'flashing_end_point')
    filtered_points_xyz = []
    filtered_point_idxs = []
    filtered_points_color = []
    vertices_3d = {}
    vertices_uv = {}
    for cls in vertex_classes:
        pts_xyz, pt_idxs, pts_color, verts, verts_uv = get_apex_or_eave_points(
            cls, uv, gest_seg_np, house_mask, valid_indices,
            points_xyz_world, points_cam, points_idxs)
        # Grouped point lists are concatenated across classes; class order
        # (apex -> eave_end -> flashing_end_point) matches the original code.
        filtered_points_xyz += pts_xyz
        filtered_point_idxs += pt_idxs
        filtered_points_color += pts_color
        # Per-class vertex arrays; empty classes become correctly-shaped arrays.
        vertices_3d[cls] = np.array(verts) if verts else np.empty((0, 3))
        vertices_uv[cls] = np.array(verts_uv) if verts_uv else np.empty((0, 2))
    return (filtered_points_xyz, filtered_point_idxs, filtered_points_color,
            vertices_3d['apex'], vertices_uv['apex'],
            vertices_3d['eave_end'], vertices_uv['eave_end'],
            vertices_3d['flashing_end_point'], vertices_uv['flashing_end_point'])
def get_connections(gest_seg_np, filtered_vertices_apex, filtered_vertices_eave, filtered_vertices_apex_uv, filtered_vertices_eave_uv):
    """Infer 2D edge connections between apex / eave_end vertex candidates.

    For each edge class ('eave', 'ridge', 'rake', 'valley') the Gestalt
    segmentation is thresholded around the class color, cleaned with a
    morphological close, and split into connected components.  A straight
    line is fitted to every component; all vertices whose UV position lies
    within `edge_th` pixels of the fitted segment are pairwise connected.

    Returns:
        vertices_formatted : list of {'xy': (2,) float array, 'type': str}
        connections        : list of unique (i, j) index tuples into the
                             combined vertex list (i < j), in discovery order
        all_vertices_3d    : (N, 3) array of the corresponding 3D vertices
    """
    edge_classes = ['eave', 'ridge', 'rake', 'valley']
    edge_th = 25.0  # max pixel distance between a vertex and a fitted edge segment
    # Stack apex and eave_end candidates into one indexable vertex list.
    all_vertices_3d = []
    all_vertices_uv = []
    vertex_types = []
    for vertex_3d, vertex_uv in zip(filtered_vertices_apex, filtered_vertices_apex_uv):
        all_vertices_3d.append(vertex_3d)
        all_vertices_uv.append(vertex_uv)
        vertex_types.append('apex')
    for vertex_3d, vertex_uv in zip(filtered_vertices_eave, filtered_vertices_eave_uv):
        all_vertices_3d.append(vertex_3d)
        all_vertices_uv.append(vertex_uv)
        vertex_types.append('eave_end')
    all_vertices_3d = np.array(all_vertices_3d)
    all_vertices_uv = np.array(all_vertices_uv)
    vertices_formatted = [{'xy': np.array(uv, dtype=float), 'type': vertex_type}
                          for uv, vertex_type in zip(all_vertices_uv, vertex_types)]
    # Fewer than two vertices -> nothing to connect.
    if len(all_vertices_uv) < 2:
        return vertices_formatted, [], all_vertices_3d
    connections = []
    seen = set()  # O(1) duplicate check instead of scanning the list each time
    kernel = np.ones((5, 5), np.uint8)  # loop-invariant morphology kernel
    for edge_class in edge_classes:
        edge_color = np.array(gestalt_color_mapping[edge_class])
        # +-10 tolerance around the class color absorbs compression artifacts.
        mask_raw = cv2.inRange(gest_seg_np, edge_color - 10, edge_color + 10)
        mask = cv2.morphologyEx(mask_raw, cv2.MORPH_CLOSE, kernel)
        if mask.sum() == 0:
            continue
        numLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask, 8, cv2.CV_32S)
        # Label 0 is the background; fit one line per remaining component.
        for lbl in range(1, numLabels):
            ys, xs = np.where(labels == lbl)
            if len(xs) < 2:
                continue
            pts_for_fit = np.column_stack([xs, ys]).astype(np.float32)
            line_params = cv2.fitLine(pts_for_fit, distType=cv2.DIST_L2,
                                      param=0, reps=0.01, aeps=0.01)
            vx, vy, x0, y0 = line_params.ravel()
            # Project component pixels onto the line to get segment endpoints.
            proj = (xs - x0) * vx + (ys - y0) * vy
            p1 = np.array([x0 + proj.min() * vx, y0 + proj.min() * vy])
            p2 = np.array([x0 + proj.max() * vx, y0 + proj.max() * vy])
            # Distance from every vertex UV to this line segment.
            dists = np.array([point_to_segment_dist(vertex_uv, p1, p2)
                              for vertex_uv in all_vertices_uv])
            near_indices = np.where(dists <= edge_th)[0]
            if len(near_indices) < 2:
                continue
            # Connect every pair of vertices lying close to this segment.
            for i in range(len(near_indices)):
                for j in range(i + 1, len(near_indices)):
                    conn = tuple(sorted((near_indices[i], near_indices[j])))
                    if conn not in seen:
                        seen.add(conn)
                        connections.append(conn)
    return vertices_formatted, connections, all_vertices_3d
def visualize_3d_wireframe(colmap_rec, filtered_points_xyz, filtered_points_color, vertices_3d, connections):
    """Debug helper: assemble open3d point clouds for visual inspection.

    Shows all COLMAP reconstruction points in gray and the filtered points in
    their per-group colors.  `vertices_3d` and `connections` are currently
    unused, and the final draw_geometries call is commented out, so this only
    builds the geometries.  The file-level `import open3d as o3d` is commented
    out, so we import lazily and no-op when open3d is unavailable (the
    original code raised NameError on `o3d` whenever called).
    """
    try:
        import open3d as o3d  # file-level import is commented out at the top of the file
    except ImportError:
        print("open3d is not installed; skipping 3D wireframe visualization.")
        return
    segmented_points_3d = []  # placeholder for depth-segmented points (always empty here)
    pcd_all = o3d.geometry.PointCloud()
    pcd_filtered = o3d.geometry.PointCloud()
    pcd_depth = o3d.geometry.PointCloud()
    # All reconstruction points in gray.
    all_points = []
    all_colors = []
    for p3D in colmap_rec.points3D.values():
        all_points.append(p3D.xyz)
        all_colors.append([0.5, 0.5, 0.5])  # Gray color
    if all_points:
        pcd_all.points = o3d.utility.Vector3dVector(np.array(all_points))
        pcd_all.colors = o3d.utility.Vector3dVector(np.array(all_colors))
    # Filtered COLMAP points keep their per-group colors.
    if len(filtered_points_xyz) > 0:
        pcd_filtered.points = o3d.utility.Vector3dVector(filtered_points_xyz)
        pcd_filtered.colors = o3d.utility.Vector3dVector(np.array(filtered_points_color))
    # Segmented depth points in blue (list is empty; kept for future use).
    if len(segmented_points_3d) > 0:
        pcd_depth.points = o3d.utility.Vector3dVector(segmented_points_3d)
        pcd_depth.colors = o3d.utility.Vector3dVector(np.full((len(segmented_points_3d), 3), [0.0, 0.0, 1.0]))
    geometries = [pcd_all]
    if len(filtered_points_xyz) > 0:
        geometries.append(pcd_filtered)
    if len(segmented_points_3d) > 0:
        geometries.append(pcd_depth)
    #o3d.visualization.draw_geometries(geometries, window_name=f"Combined Point Cloud - {img_id_substring}")
def generate_patches(colmap_rec, filtered_points_idxs, frame, filtered_vertices, vertices_formatted):
    """Build per-candidate 7D point-cloud patches for ML training/inference.

    For every group of filtered COLMAP point ids, extract a spherical patch
    of radius `ball_radius` around the group centroid from the full
    reconstruction.  Each patch point is encoded as [x, y, z, r, g, b, flag]
    with flag=+1 for points belonging to the original group and -1 otherwise;
    coordinates are re-centered on the patch centroid.  When ground-truth
    wireframe vertices are available in `frame['wf_vertices']`, the closest
    GT vertex within `ball_radius` of the group centroid is attached
    (shifted by the same offset) as the training target.

    Args:
        colmap_rec: pycolmap reconstruction (reads points3D: id -> point with
            .xyz and .color attributes).
        filtered_points_idxs: list of per-group arrays of COLMAP point ids.
        frame: dict with ground-truth data; only 'wf_vertices' is read.
        filtered_vertices: per-group initial 3D vertex estimates (or None).
        vertices_formatted: per-group {'xy', 'type'} dicts (or None entries).
            May be shorter than filtered_points_idxs (e.g. flashing groups
            have no formatted vertex); such groups are skipped.

    Returns:
        list of patch dicts (patch_7d, centroid, radius, point_ids,
        filtered_point_ids, group_idx, assigned_gt_vertex, offset,
        initial_pred, vertex_class).
    """
    patches = []
    gt_vertices = frame['wf_vertices']
    for group_idx, point_idxs in enumerate(filtered_points_idxs):
        # Hoisted membership set: `pid in point_idxs` on an array is O(n)
        # per test and ran inside a per-patch-point loop below.
        point_idx_set = set(point_idxs)
        # Gather the group's 3D points and colors.
        group_points_3d = []
        group_colors = []
        assigned_gt_vertex = None
        for pid in point_idxs:
            p3d = colmap_rec.points3D[pid]
            group_points_3d.append(p3d.xyz)
            group_colors.append(p3d.color)
        group_points_3d = np.array(group_points_3d)
        group_colors = np.array(group_colors)
        # Patch center: centroid of the group's filtered points.
        centroid = np.mean(group_points_3d, axis=0)
        ball_radius = 2.0  # meters; search radius around the centroid
        if len(gt_vertices) > 0:
            # Attach the closest GT vertex as supervision target, but only if
            # it lies within the patch ball.  If none is close enough the
            # patch is still emitted with assigned_gt_vertex = None.
            distances_to_gt = [np.linalg.norm(gt_vertex - centroid) for gt_vertex in gt_vertices]
            min_distance_idx = np.argmin(distances_to_gt)
            if distances_to_gt[min_distance_idx] <= ball_radius:
                assigned_gt_vertex = gt_vertices[min_distance_idx]
        # Gather every reconstruction point inside the ball.
        # NOTE(review): linear scan over all points per group; fine for
        # COLMAP-sized clouds, a KD-tree would help if this becomes hot.
        patch_points_3d = []
        patch_colors = []
        patch_point_ids = []
        for pid, p3d in colmap_rec.points3D.items():
            if np.linalg.norm(p3d.xyz - centroid) <= ball_radius:
                patch_points_3d.append(p3d.xyz)
                patch_colors.append(p3d.color)
                patch_point_ids.append(pid)
        patch_points_3d = np.array(patch_points_3d)
        if len(patch_points_3d) == 0:
            # No points in the ball: previously this produced a NaN patch
            # centroid and a zero-length patch; skip the group instead.
            continue
        # Re-center the patch on its own centroid.
        patch_centroid = np.mean(patch_points_3d, axis=0)
        offset = -patch_centroid
        patch_points_3d += offset
        # Shift the GT target into the same patch-local frame.
        if assigned_gt_vertex is not None:
            assigned_gt_vertex = assigned_gt_vertex + offset
        patch_colors = np.array(patch_colors)
        # 7D encoding: [x, y, z, r, g, b, in_group_flag].
        patch_7d = np.zeros((len(patch_points_3d), 7))
        patch_7d[:, :3] = patch_points_3d
        # Colors normalized to [0, 1]; assumes 0-255 inputs -- TODO confirm
        # against the COLMAP point color convention.
        patch_7d[:, 3:6] = patch_colors / 255.0
        for i, pid in enumerate(patch_point_ids):
            patch_7d[i, 6] = 1.0 if pid in point_idx_set else -1.0
        # Initial vertex estimate (if any) shifted into patch coordinates.
        # Bounds check added: the original indexed past short lists.
        if group_idx < len(filtered_vertices) and filtered_vertices[group_idx] is not None:
            initial_pred = filtered_vertices[group_idx] + offset
        else:
            initial_pred = None
        # Skip groups without a formatted vertex entry; the bounds check also
        # guards against callers passing fewer formatted vertices than groups.
        if group_idx < len(vertices_formatted) and vertices_formatted[group_idx] is not None:
            vertex_class = vertices_formatted[group_idx]['type']
            patches.append({
                'patch_7d': patch_7d,
                'centroid': centroid,
                'radius': ball_radius,
                'point_ids': patch_point_ids,
                'filtered_point_ids': point_idxs,
                'group_idx': group_idx,
                'assigned_gt_vertex': assigned_gt_vertex,
                'offset': offset,
                'initial_pred': initial_pred,
                'vertex_class': vertex_class
            })
    return patches
def our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id_substring, ade_seg, depth, K=None, R=None, t=None, frame=None):
    """
    Identify apex and eave-end vertices, then detect lines for eave/ridge/rake/valley.
    Also find all COLMAP points that project into apex or eave_end masks.

    Args:
        gest_seg_np: Gestalt segmentation image (converted to ndarray if needed).
        colmap_rec: pycolmap reconstruction.
        img_id_substring: substring identifying this view inside COLMAP image names.
        ade_seg: ADE20k segmentation (used to build the house mask).
        depth: depth map (currently unused in this function).
        K, R, t: camera intrinsics / world-to-camera rotation / translation.
            If any is None they are recovered from the COLMAP reconstruction.
        frame: optional GT frame (only needed for patch generation, disabled).

    Returns:
        (vertices_formatted, connections, all_vertices_3d, patches,
        filtered_point_idxs); empty lists when the view cannot be processed.
        `patches` is always None on this (inference) path.
    """
    #--------------------------------------------------------------------------------
    # Step A: Collect apex and eave_end vertices
    #--------------------------------------------------------------------------------
    if not isinstance(gest_seg_np, np.ndarray):
        gest_seg_np = np.array(gest_seg_np)
    H, W = gest_seg_np.shape[:2]
    # Recover camera parameters from COLMAP when the caller does not supply
    # them.  (This fallback was previously disabled behind `if False:`, which
    # made the K/R/t defaults unusable.)
    if K is None or R is None or t is None:
        found_img = None
        for img_id_c, col_img_obj in colmap_rec.images.items():
            if img_id_substring in col_img_obj.name:
                found_img = col_img_obj
                break
        if found_img is None:
            print(f"Image substring {img_id_substring} not found in COLMAP.")
            return [], [], [], [], []
        # Camera intrinsic matrix.
        K = found_img.camera.calibration_matrix()
        # World-to-camera transform: x_cam = R @ x_world + t.
        world_to_cam = found_img.cam_from_world.matrix()
        R = world_to_cam[:3, :3]
        t = world_to_cam[:3, 3]
    points_cam, points_xyz_world, points_idxs = get_visible_points(colmap_rec, img_id_substring, R=R, t=t)
    uv, valid_indices = project_points_to_2d(points_cam, K, H, W)
    if len(uv) == 0:
        print(f"No points projected into image bounds for {img_id_substring} using K,R,t.")
        return [], [], [], [], []
    house_mask = get_house_mask(ade_seg)
    (filtered_points_xyz, filtered_point_idxs, filtered_points_color,
     filtered_vertices_apex, filtered_vertices_apex_uv,
     filtered_vertices_eave, filtered_vertices_eave_uv,
     _, _) = get_vertexes(uv, gest_seg_np, house_mask, valid_indices, points_xyz_world, points_cam, points_idxs)
    vertices_formatted, connections, all_vertices_3d = get_connections(
        gest_seg_np, filtered_vertices_apex, filtered_vertices_eave,
        filtered_vertices_apex_uv, filtered_vertices_eave_uv)
    # Patch generation and 3D visualization are disabled on the inference path.
    #patches = generate_patches(colmap_rec, filtered_point_idxs, frame, all_vertices_3d, vertices_formatted)
    patches = None
    #visualize_3d_wireframe(colmap_rec, filtered_points_xyz, filtered_points_color, all_vertices_3d, connections)
    return vertices_formatted, connections, all_vertices_3d, patches, filtered_point_idxs