Improves wireframe extraction and prediction
Refactors wireframe extraction by merging overlapping point groups and updating connections.
Adds edge patch generation for training edge prediction models using point clouds near wireframe edges and non-edges, improving data creation and model training.
Adds visualization tools for point clouds and predicted vertices with connections to help debugging and evaluation.
Simplifies and streamlines vertex prediction by applying a PointNet model to extracted patches and filtering zero vertices.
- predict.py +619 -11
predict.py
CHANGED
|
@@ -12,10 +12,17 @@ import open3d as o3d
|
|
| 12 |
from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local
|
| 13 |
import pyvista as pv
|
| 14 |
from fast_pointnet import save_patches_dataset, predict_vertex_from_patch
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
-
GENERATE_DATASET =
|
| 17 |
-
|
| 18 |
-
DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom'
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
def convert_entry_to_human_readable(entry):
|
| 21 |
out = {}
|
|
@@ -544,7 +551,7 @@ def extract_vertices_from_whole_pcloud(colmap_rec, idxs_points, all_connections)
|
|
| 544 |
# and update connections accordingly
|
| 545 |
updated_connections = []
|
| 546 |
if extracted_points:
|
| 547 |
-
print(f"Merging groups based on point overlap... Processing {len(extracted_points)} groups")
|
| 548 |
# Create a list to track which groups to keep
|
| 549 |
groups_to_keep = []
|
| 550 |
merged_groups = set() # Track which groups have been merged
|
|
@@ -579,19 +586,36 @@ def extract_vertices_from_whole_pcloud(colmap_rec, idxs_points, all_connections)
|
|
| 579 |
merged_points = np.vstack([merged_points, extracted_points[j]]) if len(merged_points) > 0 else extracted_points[j]
|
| 580 |
merged_colors = np.vstack([merged_colors, extracted_colors[j]]) if len(merged_colors) > 0 else extracted_colors[j]
|
| 581 |
merged_ids.update(ids_j)
|
|
|
|
| 582 |
merged_groups.add(j)
|
| 583 |
|
| 584 |
# Add the merged group to the list of groups to keep
|
| 585 |
if len(merged_points) > 0:
|
|
|
|
| 586 |
groups_to_keep.append((merged_points, merged_colors, np.array(list(merged_ids))))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 587 |
|
| 588 |
# Update extracted_points, extracted_colors, and extracted_ids with filtered results
|
| 589 |
extracted_points = [group[0] for group in groups_to_keep]
|
| 590 |
extracted_colors = [group[1] for group in groups_to_keep]
|
| 591 |
extracted_ids = [group[2] for group in groups_to_keep]
|
| 592 |
|
| 593 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 594 |
|
|
|
|
|
|
|
| 595 |
|
| 596 |
# Create visualization showing extracted points for each group as balls within their mean
|
| 597 |
if False:
|
|
@@ -629,17 +653,481 @@ def extract_vertices_from_whole_pcloud(colmap_rec, idxs_points, all_connections)
|
|
| 629 |
|
| 630 |
plotter.show(title=f"Extracted Points within {ball_radius}m radius - Spheres at group means")
|
| 631 |
|
| 632 |
-
return extracted_points, extracted_colors, extracted_ids, whole_pcloud
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 633 |
|
| 634 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 635 |
"""
|
| 636 |
Predict 3D wireframe from a dataset entry.
|
| 637 |
"""
|
| 638 |
good_entry = convert_entry_to_human_readable(entry)
|
| 639 |
colmap_rec = good_entry['colmap_binary']
|
| 640 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 641 |
vert_edge_per_image = {}
|
| 642 |
idxs_points = []
|
|
|
|
| 643 |
for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(zip(good_entry['gestalt'],
|
| 644 |
good_entry['depth'],
|
| 645 |
good_entry['K'],
|
|
@@ -660,6 +1148,7 @@ def predict_wireframe(entry, pnet_model) -> Tuple[np.ndarray, List[int]]:
|
|
| 660 |
|
| 661 |
vertices_ours, connections_ours, vertices_3d_ours, patches, filtered_point_idxs = our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id, ade_seg, depth, K=K, R=R, t=t, frame=good_entry)
|
| 662 |
idxs_points.append(filtered_point_idxs)
|
|
|
|
| 663 |
|
| 664 |
'''
|
| 665 |
if GENERATE_DATASET:
|
|
@@ -722,21 +1211,119 @@ def predict_wireframe(entry, pnet_model) -> Tuple[np.ndarray, List[int]]:
|
|
| 722 |
|
| 723 |
vert_edge_per_image[i] = vertices, connections, vertices_3d
|
| 724 |
|
| 725 |
-
extracted_points, extracted_colors, extracted_ids, whole_pcloud = extract_vertices_from_whole_pcloud(colmap_rec, idxs_points)
|
| 726 |
|
| 727 |
patches = generate_patches_v2(extracted_points, extracted_colors, extracted_ids, whole_pcloud, good_entry['wf_vertices'])
|
| 728 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 729 |
if GENERATE_DATASET:
|
| 730 |
save_patches_dataset(patches, DATASET_DIR, img_id)
|
| 731 |
return empty_solution()
|
| 732 |
|
| 733 |
# Merge vertices from all images
|
| 734 |
-
all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 0.1)
|
| 735 |
-
all_3d_vertices_clean, connections_3d_clean = all_3d_vertices, connections_3d
|
| 736 |
#all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d, keep_largest=False)
|
| 737 |
#all_3d_vertices_clean, connections_3d_clean = prune_too_far(all_3d_vertices_clean, connections_3d_clean, colmap_rec, th = 1.5)
|
| 738 |
|
| 739 |
-
if (len(all_3d_vertices_clean) < 2) or len(connections_3d_clean) < 1 and False:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 740 |
print (f'Not enough vertices or connections in the 3D vertices')
|
| 741 |
return empty_solution()
|
| 742 |
|
|
@@ -1352,6 +1939,27 @@ def our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id_substring, ade_se
|
|
| 1352 |
|
| 1353 |
H, W = gest_seg_np.shape[:2]
|
| 1354 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1355 |
points_cam, points_xyz_world, points_idxs = get_visible_points(colmap_rec, img_id_substring, R=R, t=t)
|
| 1356 |
|
| 1357 |
uv, valid_indices = project_points_to_2d(points_cam, K, H, W)
|
|
|
|
| 12 |
from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local
|
| 13 |
import pyvista as pv
|
| 14 |
from fast_pointnet import save_patches_dataset, predict_vertex_from_patch
|
| 15 |
+
from fast_voxel import predict_vertex_from_patch_voxel
|
| 16 |
+
import time
|
| 17 |
+
from fast_pointnet_class import save_patches_dataset as save_patches_dataset_class
|
| 18 |
|
| 19 |
+
GENERATE_DATASET = False
|
| 20 |
+
DATASET_DIR = '/home/skvrnjan/personal/hohocustom/'
|
| 21 |
+
#DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom'
|
| 22 |
+
|
| 23 |
+
GENERATE_DATASET_EDGES = True
|
| 24 |
+
#EDGES_DATASET_DIR = '/home/skvrnjan/personal/hohocustom_edges/'
|
| 25 |
+
EDGES_DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom_edges'
|
| 26 |
|
| 27 |
def convert_entry_to_human_readable(entry):
|
| 28 |
out = {}
|
|
|
|
| 551 |
# and update connections accordingly
|
| 552 |
updated_connections = []
|
| 553 |
if extracted_points:
|
| 554 |
+
#print(f"Merging groups based on point overlap... Processing {len(extracted_points)} groups")
|
| 555 |
# Create a list to track which groups to keep
|
| 556 |
groups_to_keep = []
|
| 557 |
merged_groups = set() # Track which groups have been merged
|
|
|
|
| 586 |
merged_points = np.vstack([merged_points, extracted_points[j]]) if len(merged_points) > 0 else extracted_points[j]
|
| 587 |
merged_colors = np.vstack([merged_colors, extracted_colors[j]]) if len(merged_colors) > 0 else extracted_colors[j]
|
| 588 |
merged_ids.update(ids_j)
|
| 589 |
+
merged_indices.append(j)
|
| 590 |
merged_groups.add(j)
|
| 591 |
|
| 592 |
# Add the merged group to the list of groups to keep
|
| 593 |
if len(merged_points) > 0:
|
| 594 |
+
new_group_idx = len(groups_to_keep)
|
| 595 |
groups_to_keep.append((merged_points, merged_colors, np.array(list(merged_ids))))
|
| 596 |
+
|
| 597 |
+
# Update mapping for all merged indices
|
| 598 |
+
for old_idx in merged_indices:
|
| 599 |
+
old_to_new_mapping[old_idx] = new_group_idx
|
| 600 |
|
| 601 |
# Update extracted_points, extracted_colors, and extracted_ids with filtered results
|
| 602 |
extracted_points = [group[0] for group in groups_to_keep]
|
| 603 |
extracted_colors = [group[1] for group in groups_to_keep]
|
| 604 |
extracted_ids = [group[2] for group in groups_to_keep]
|
| 605 |
|
| 606 |
+
# Update connections based on the new mapping
|
| 607 |
+
for start_idx, end_idx in all_flattened_connections:
|
| 608 |
+
if start_idx in old_to_new_mapping and end_idx in old_to_new_mapping:
|
| 609 |
+
new_start = old_to_new_mapping[start_idx]
|
| 610 |
+
new_end = old_to_new_mapping[end_idx]
|
| 611 |
+
# Only add connection if vertices are still different after merging
|
| 612 |
+
if new_start != new_end:
|
| 613 |
+
connection = tuple(sorted((new_start, new_end)))
|
| 614 |
+
if connection not in updated_connections:
|
| 615 |
+
updated_connections.append(connection)
|
| 616 |
|
| 617 |
+
#print(f"After merging, number of groups: {len(extracted_points)}")
|
| 618 |
+
#print(f"Updated connections: {updated_connections}")
|
| 619 |
|
| 620 |
# Create visualization showing extracted points for each group as balls within their mean
|
| 621 |
if False:
|
|
|
|
| 653 |
|
| 654 |
plotter.show(title=f"Extracted Points within {ball_radius}m radius - Spheres at group means")
|
| 655 |
|
| 656 |
+
return extracted_points, extracted_colors, extracted_ids, whole_pcloud, updated_connections
|
| 657 |
+
|
| 658 |
+
def visu_pcloud_and_preds(colmap_rec, extracted_ids, extracted_points, extracted_colors, predicted_vertices, connections):
    """Interactive debug view of the reconstruction pipeline output.

    Shows, in one PyVista scene:
      * every COLMAP 3D point in gray,
      * each extracted point group in a random color, with a translucent
        sphere at the group mean,
      * each non-zero predicted vertex as a solid black sphere,
      * red line segments for connections between valid predicted vertices.

    A predicted vertex equal to (0, 0, 0) is treated as "no prediction" and
    skipped. Does nothing when `extracted_ids` is empty.
    """
    if not extracted_ids:
        return

    plotter = pv.Plotter()

    # Background: the full COLMAP cloud, uniformly gray.
    bg_points = []
    bg_colors = []
    for pid, p3D in colmap_rec.points3D.items():
        bg_points.append(p3D.xyz)
        bg_colors.append([0.8, 0.8, 0.8])  # Gray color

    if bg_points:
        bg_points = np.array(bg_points)
        bg_colors = np.array(bg_colors)
        bg_cloud = pv.PolyData(bg_points)
        bg_cloud["colors"] = np.array(bg_colors)
        plotter.add_mesh(bg_cloud, scalars="colors", rgb=True, point_size=3, render_points_as_spheres=True)

    # One random color per group: mean sphere + member points + predicted vertex.
    for group_idx, (group_points, _group_colors) in enumerate(zip(extracted_points, extracted_colors)):
        if len(group_points) == 0:
            continue

        group_mean = np.mean(group_points, axis=0)
        group_color = np.random.rand(3)

        plotter.add_mesh(pv.Sphere(radius=0.2, center=group_mean), color=group_color, opacity=0.5)
        plotter.add_mesh(pv.PolyData(group_points), color=group_color, point_size=6, render_points_as_spheres=True)

        # Matching network prediction, if present and not the zero sentinel.
        if group_idx < len(predicted_vertices):
            pred_vertex = predicted_vertices[group_idx]
            if not np.allclose(pred_vertex, [0.0, 0.0, 0.0]):
                plotter.add_mesh(pv.Sphere(radius=0.15, center=pred_vertex), color="black", opacity=1.)

    # Connections between predicted vertices (indices refer to the group list,
    # so they must be remapped onto the surviving non-zero predictions).
    if len(predicted_vertices) > 0 and len(connections) > 0:
        valid_pred_vertices = []
        valid_indices = []
        for i, pred_vertex in enumerate(predicted_vertices):
            if not np.allclose(pred_vertex, [0.0, 0.0, 0.0]):
                valid_pred_vertices.append(pred_vertex)
                valid_indices.append(i)

        if len(valid_pred_vertices) > 1:
            valid_pred_vertices = np.array(valid_pred_vertices)
            for start_idx, end_idx in connections:
                if start_idx in valid_indices and end_idx in valid_indices:
                    a = valid_pred_vertices[valid_indices.index(start_idx)]
                    b = valid_pred_vertices[valid_indices.index(end_idx)]
                    plotter.add_mesh(pv.Line(a, b), color="red", line_width=3)

    ball_radius = 1.0  # meters
    plotter.show(title=f"Extracted Points within {ball_radius}m radius - Spheres at group means")
|
| 724 |
+
|
| 725 |
+
def _extract_cylinder_patch(points_6d, points_3d, start_vertex, end_vertex,
                            cylinder_radius, extension_length=0.25, min_points=10):
    """Collect the 6D points inside a cylinder around the segment start->end.

    The segment is extended by `extension_length` meters on both ends for
    extra context. Points are re-centered on the midpoint of the ORIGINAL
    (non-extended) segment.

    Returns (points_centered, point_indices, line_midpoint), or None when the
    segment is degenerate (zero length) or at most `min_points` points fall
    inside the cylinder.
    """
    line_vector = end_vertex - start_vertex
    line_length = np.linalg.norm(line_vector)
    if line_length == 0:
        return None  # degenerate edge: direction is undefined

    line_direction = line_vector / line_length
    extended_start = start_vertex - extension_length * line_direction
    extended_line_length = line_length + 2 * extension_length

    # Vectorized point-in-cylinder test: project every point on the axis,
    # keep those within the extended segment and within the radius.
    start_to_points = points_3d - extended_start[np.newaxis, :]
    projection_lengths = np.dot(start_to_points, line_direction)
    within_bounds = (projection_lengths >= 0) & (projection_lengths <= extended_line_length)
    closest_on_line = extended_start[np.newaxis, :] + projection_lengths[:, np.newaxis] * line_direction[np.newaxis, :]
    perpendicular_distances = np.linalg.norm(points_3d - closest_on_line, axis=1)
    within_cylinder = within_bounds & (perpendicular_distances <= cylinder_radius)

    if np.sum(within_cylinder) <= min_points:
        return None

    line_midpoint = (start_vertex + end_vertex) / 2
    points_centered = points_6d[within_cylinder].copy()
    points_centered[:, :3] -= line_midpoint
    return points_centered, np.where(within_cylinder)[0], line_midpoint


def _visualize_edge_patches(all_patches, colmap_points_3d):
    """Interactive PyVista view of edge patches (green = edge, red = non-edge)."""
    plotter = pv.Plotter()

    if len(colmap_points_3d) > 0:
        whole_cloud = pv.PolyData(colmap_points_3d)
        whole_cloud["colors"] = np.full((len(colmap_points_3d), 3), [0.5, 0.5, 0.5])
        plotter.add_mesh(whole_cloud, scalars="colors", rgb=True, point_size=3, render_points_as_spheres=True)

    for patch in all_patches:
        patch_color = 'green' if patch['label'] == 1 else 'red'
        center = patch['center']
        line_start = patch['line_start'] + center
        line_end = patch['line_end'] + center

        # Patch points back in world coordinates.
        points_world = patch['patch_6d'][:, :3] + center
        if len(points_world) > 0:
            plotter.add_mesh(pv.PolyData(points_world), color=patch_color, point_size=8, render_points_as_spheres=True)

        plotter.add_mesh(pv.Sphere(radius=0.1, center=line_start), color='black', opacity=0.8)
        plotter.add_mesh(pv.Sphere(radius=0.1, center=line_end), color='white', opacity=0.8)
        plotter.add_mesh(pv.Line(line_start, line_end), color=patch_color, line_width=5)

        # Wireframe cylinder showing the extraction bounds (incl. extensions).
        cylinder_direction = (line_end - line_start) / np.linalg.norm(line_end - line_start)
        cylinder_height = np.linalg.norm(line_end - line_start) + 2 * 0.25  # Including extensions
        cylinder_mesh = pv.Cylinder(center=center, direction=cylinder_direction,
                                    radius=patch['cylinder_radius'], height=cylinder_height)
        plotter.add_mesh(cylinder_mesh, color=patch_color, opacity=0.2, style='wireframe')

    positive_count = sum(1 for patch in all_patches if patch['label'] == 1)
    negative_count = sum(1 for patch in all_patches if patch['label'] == 0)
    plotter.show(title=f"Edge Patches - Positive (Green): {positive_count}, Negative (Red): {negative_count}")


def generate_edge_patches(frame):
    """Build edge/non-edge training patches from a frame's GT wireframe.

    For every ground-truth edge, the COLMAP points inside a cylinder around
    the edge form a positive patch (label 1). An equal number of negative
    patches (label 0) is sampled from unconnected vertex pairs, rejecting
    candidates whose cylinder has IoU > 0.5 with any positive cylinder.

    Each patch dict carries: 'patch_6d' (points centered on the edge
    midpoint, xyz + colors scaled to [-1, 1]), 'connection', 'line_start',
    'line_end', 'cylinder_radius', 'point_indices', 'label', 'center'.

    frame: dict with 'wf_vertices', 'wf_edges' and 'colmap_binary'
    (a reconstruction exposing `points3D` with `.xyz` and `.color`).
    Returns the list positive_patches + negative_patches.
    """
    vertices = frame['wf_vertices']
    connections = frame['wf_edges']

    vertices = np.array(vertices) if vertices else np.empty((0, 3))

    positive_patches = []
    negative_patches = []

    cylinder_radius = 0.5
    extension_length = 0.25  # 25 cm of extra context on each end of an edge

    colmap = frame['colmap_binary']

    # 6D (xyz + normalized rgb) point cloud from the COLMAP reconstruction.
    colmap_points_6d = []
    for pid, p3D in colmap.points3D.items():
        colmap_points_6d.append(np.concatenate([p3D.xyz, p3D.color / 255.0]))
    colmap_points_6d = np.array(colmap_points_6d) if colmap_points_6d else np.empty((0, 6))

    # Rescale colors from [0, 1] to [-1, 1] for the network.
    colmap_points_6d[:, 3:] = colmap_points_6d[:, 3:] * 2 - 1
    colmap_points_3d = colmap_points_6d[:, :3]

    # Positive patches: one per ground-truth edge.
    for connection in connections:
        start_idx, end_idx = connection
        start_vertex = vertices[start_idx]
        end_vertex = vertices[end_idx]

        extracted = _extract_cylinder_patch(colmap_points_6d, colmap_points_3d,
                                            start_vertex, end_vertex,
                                            cylinder_radius, extension_length)
        if extracted is None:
            continue
        points_centered, point_indices_in_cylinder, line_midpoint = extracted

        positive_patches.append({
            'patch_6d': points_centered,
            'connection': connection,
            'line_start': start_vertex - line_midpoint,
            'line_end': end_vertex - line_midpoint,
            'cylinder_radius': cylinder_radius,
            'point_indices': point_indices_in_cylinder,
            'label': 1,  # Positive label for edge
            'center': line_midpoint
        })

    # Negative patches: random unconnected vertex pairs, rejected when their
    # cylinder overlaps a positive cylinder too strongly.
    num_negative_patches = len(positive_patches)

    if num_negative_patches > 0 and len(vertices) >= 2:
        connected_pairs = set(tuple(sorted(conn)) for conn in connections)

        vertex_indices = np.arange(len(vertices))
        all_pairs = np.array(np.meshgrid(vertex_indices, vertex_indices)).T.reshape(-1, 2)
        all_pairs = all_pairs[all_pairs[:, 0] != all_pairs[:, 1]]
        all_pairs_sorted = np.sort(all_pairs, axis=1)

        unconnected_mask = np.array([tuple(pair) not in connected_pairs for pair in all_pairs_sorted])
        unconnected_pairs = all_pairs[unconnected_mask]

        if len(unconnected_pairs) > 0:
            # Positive cylinders (non-extended segments) for the overlap test.
            positive_cylinders = []
            for pos_patch in positive_patches:
                positive_cylinders.append({
                    'start': pos_patch['line_start'] + pos_patch['center'],
                    'end': pos_patch['line_end'] + pos_patch['center'],
                    'radius': pos_patch['cylinder_radius']
                })

            # Oversample 3x to compensate for rejected candidates.
            num_to_sample = min(num_negative_patches * 3, len(unconnected_pairs))
            sampled_indices = np.random.choice(len(unconnected_pairs), size=num_to_sample, replace=False)
            sampled_pairs = unconnected_pairs[sampled_indices]

            for idx1, idx2 in sampled_pairs:
                if len(negative_patches) >= num_negative_patches:
                    break

                start_vertex = vertices[idx1]
                end_vertex = vertices[idx2]

                line_vector = end_vertex - start_vertex
                line_length = np.linalg.norm(line_vector)
                if line_length == 0:
                    continue  # coincident vertices: no valid direction
                line_direction = line_vector / line_length
                extended_start = start_vertex - extension_length * line_direction
                extended_end = end_vertex + extension_length * line_direction
                extended_line_length = line_length + 2 * extension_length

                current_cylinder = {
                    'start': extended_start,
                    'end': extended_end,
                    'radius': cylinder_radius
                }
                current_volume = np.pi * cylinder_radius**2 * extended_line_length

                # Reject candidates whose cylinder IoU with any positive
                # cylinder exceeds 0.5.
                has_overlap = False
                for pos_cylinder in positive_cylinders:
                    overlap_volume = calculate_cylinder_overlap_volume(current_cylinder, pos_cylinder)
                    pos_height = np.linalg.norm(pos_cylinder['end'] - pos_cylinder['start'])
                    pos_volume = np.pi * pos_cylinder['radius']**2 * pos_height
                    union_volume = current_volume + pos_volume - overlap_volume
                    if union_volume > 0 and overlap_volume / union_volume > 0.5:
                        has_overlap = True
                        break

                if has_overlap:
                    continue

                extracted = _extract_cylinder_patch(colmap_points_6d, colmap_points_3d,
                                                    start_vertex, end_vertex,
                                                    cylinder_radius, extension_length)
                if extracted is None:
                    continue
                points_centered, point_indices_in_cylinder, line_midpoint = extracted

                negative_patches.append({
                    'patch_6d': points_centered,
                    'connection': (idx1, idx2),
                    'line_start': start_vertex - line_midpoint,
                    'line_end': end_vertex - line_midpoint,
                    'cylinder_radius': cylinder_radius,
                    'point_indices': point_indices_in_cylinder,
                    'label': 0,  # Negative label for non-edge
                    'center': line_midpoint  # Center of the patch
                })

    print(f"Generated {len(positive_patches)} positive patches and {len(negative_patches)} negative patches")
    all_patches = positive_patches + negative_patches

    if False:  # Set to True to enable visualization
        _visualize_edge_patches(all_patches, colmap_points_3d)

    return all_patches
|
| 1015 |
|
| 1016 |
+
def calculate_cylinder_overlap_volume(cyl1, cyl2):
    """Approximate the intersection volume of two finite cylinders.

    Each cylinder is a dict with 'start' and 'end' (3D endpoints of the axis,
    numpy arrays) and 'radius'. This is a deliberate approximation: the
    cross-section intersection (circle-circle lens area at the closest axis
    approach) is assumed uniform along the axial overlap of cylinder 2
    projected onto cylinder 1's axis.

    Returns 0.0 for degenerate (zero-length) cylinders, axes farther apart
    than the sum of radii, or no axial overlap.
    """
    p1_start, p1_end = cyl1['start'], cyl1['end']
    p2_start, p2_end = cyl2['start'], cyl2['end']
    r1, r2 = cyl1['radius'], cyl2['radius']

    axis1 = p1_end - p1_start
    axis2 = p2_end - p2_start
    len1 = np.linalg.norm(axis1)
    len2 = np.linalg.norm(axis2)

    if len1 == 0 or len2 == 0:
        return 0.0  # degenerate cylinder has no volume

    axis1_norm = axis1 / len1
    axis2_norm = axis2 / len2

    # Distance between the axes (line-line distance, clamped to the segments).
    w = p1_start - p2_start
    a = np.dot(axis1_norm, axis1_norm)
    b = np.dot(axis1_norm, axis2_norm)
    c = np.dot(axis2_norm, axis2_norm)
    d = np.dot(axis1_norm, w)
    e = np.dot(axis2_norm, w)

    denom = a * c - b * b
    if abs(denom) < 1e-10:
        # Parallel axes: perpendicular distance between the lines.
        # Axes here are always 3D vectors, so the cross product is a vector.
        dist = np.linalg.norm(np.cross(axis1_norm, w))
    else:
        # Closest points on both axis segments.
        t1 = np.clip((b * e - c * d) / denom, 0, len1)
        t2 = np.clip((a * e - b * d) / denom, 0, len2)
        point1 = p1_start + t1 * axis1_norm
        point2 = p2_start + t2 * axis2_norm
        dist = np.linalg.norm(point1 - point2)

    # Cross-sections cannot intersect at all.
    if dist >= (r1 + r2):
        return 0.0

    # Axial overlap: project cylinder 2's endpoints onto cylinder 1's axis.
    proj_start = np.dot(p2_start - p1_start, axis1_norm)
    proj_end = np.dot(p2_end - p1_start, axis1_norm)
    overlap_start = max(0, min(proj_start, proj_end))
    overlap_end = min(len1, max(proj_start, proj_end))
    overlap_length = max(0, overlap_end - overlap_start)

    if overlap_length <= 0:
        return 0.0

    if dist < abs(r1 - r2):
        # One cross-section lies fully inside the other.
        intersection_area = np.pi * min(r1, r2)**2
    else:
        # Partial overlap: circle-circle lens area. dist < r1 + r2 is
        # guaranteed here by the early return above, so no extra branch
        # is needed (the original code carried an unreachable else).
        d1 = (r1**2 - r2**2 + dist**2) / (2 * dist) if dist > 0 else 0
        d2 = dist - d1

        if 0 <= d1 <= r1 and 0 <= d2 <= r2:
            area1 = r1**2 * np.arccos(d1 / r1) - d1 * np.sqrt(r1**2 - d1**2)
            area2 = r2**2 * np.arccos(d2 / r2) - d2 * np.sqrt(r2**2 - d2**2)
            intersection_area = area1 + area2
        else:
            # Lens formula out of range; fall back to the smaller disc.
            intersection_area = np.pi * min(r1, r2)**2

    return max(0.0, intersection_area * overlap_length)
|
| 1115 |
+
|
| 1116 |
+
def predict_wireframe(entry, pnet_model, voxel_model) -> Tuple[np.ndarray, List[int]]:
|
| 1117 |
"""
|
| 1118 |
Predict 3D wireframe from a dataset entry.
|
| 1119 |
"""
|
| 1120 |
good_entry = convert_entry_to_human_readable(entry)
|
| 1121 |
colmap_rec = good_entry['colmap_binary']
|
| 1122 |
|
| 1123 |
+
if GENERATE_DATASET_EDGES:
|
| 1124 |
+
patches = generate_edge_patches(good_entry)
|
| 1125 |
+
save_patches_dataset_class(patches, EDGES_DATASET_DIR, good_entry['order_id'])
|
| 1126 |
+
return empty_solution()
|
| 1127 |
+
|
| 1128 |
vert_edge_per_image = {}
|
| 1129 |
idxs_points = []
|
| 1130 |
+
all_connections = []
|
| 1131 |
for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(zip(good_entry['gestalt'],
|
| 1132 |
good_entry['depth'],
|
| 1133 |
good_entry['K'],
|
|
|
|
| 1148 |
|
| 1149 |
vertices_ours, connections_ours, vertices_3d_ours, patches, filtered_point_idxs = our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id, ade_seg, depth, K=K, R=R, t=t, frame=good_entry)
|
| 1150 |
idxs_points.append(filtered_point_idxs)
|
| 1151 |
+
all_connections.append(connections_ours)
|
| 1152 |
|
| 1153 |
'''
|
| 1154 |
if GENERATE_DATASET:
|
|
|
|
| 1211 |
|
| 1212 |
vert_edge_per_image[i] = vertices, connections, vertices_3d
|
| 1213 |
|
| 1214 |
+
extracted_points, extracted_colors, extracted_ids, whole_pcloud, connections = extract_vertices_from_whole_pcloud(colmap_rec, idxs_points, all_connections)
|
| 1215 |
|
| 1216 |
patches = generate_patches_v2(extracted_points, extracted_colors, extracted_ids, whole_pcloud, good_entry['wf_vertices'])
|
| 1217 |
|
| 1218 |
+
# Predict vertices from patches using the neural network
|
| 1219 |
+
predicted_vertices = []
|
| 1220 |
+
for patch in patches:
|
| 1221 |
+
pred_vertex, pred_dist, pred_class = predict_vertex_from_patch(pnet_model, patch, device='cuda')
|
| 1222 |
+
|
| 1223 |
+
#visu_patch_and_pred(patch, pred_vertex, pred_dist, pred_class)
|
| 1224 |
+
|
| 1225 |
+
if pred_class > 0.5:
|
| 1226 |
+
predicted_vertices.append(pred_vertex)
|
| 1227 |
+
else:
|
| 1228 |
+
predicted_vertices.append(np.array([0.0, 0.0, 0.0])) # Append a zero vertex if not predicted
|
| 1229 |
+
|
| 1230 |
+
#pred_vertex_voxel, pred_dist_voxel, pred_class_voxel = predict_vertex_from_patch_voxel(voxel_model, patch, device='cuda')
|
| 1231 |
+
#visu_patch_and_pred(patch, pred_vertex_voxel, pred_dist_voxel, pred_class_voxel)
|
| 1232 |
+
|
| 1233 |
+
predicted_vertices = np.array(predicted_vertices) if predicted_vertices else np.empty((0, 3))
|
| 1234 |
+
|
| 1235 |
+
#visu_pcloud_and_preds(colmap_rec, extracted_ids, extracted_points, extracted_colors, predicted_vertices, connections)
|
| 1236 |
+
|
| 1237 |
if GENERATE_DATASET:
|
| 1238 |
save_patches_dataset(patches, DATASET_DIR, img_id)
|
| 1239 |
return empty_solution()
|
| 1240 |
|
| 1241 |
# Merge vertices from all images
|
| 1242 |
+
#all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 0.1)
|
| 1243 |
+
#all_3d_vertices_clean, connections_3d_clean = all_3d_vertices, connections_3d
|
| 1244 |
#all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d, keep_largest=False)
|
| 1245 |
#all_3d_vertices_clean, connections_3d_clean = prune_too_far(all_3d_vertices_clean, connections_3d_clean, colmap_rec, th = 1.5)
|
| 1246 |
|
| 1247 |
+
#if (len(all_3d_vertices_clean) < 2) or len(connections_3d_clean) < 1 and False:
|
| 1248 |
+
# print (f'Not enough vertices or connections in the 3D vertices')
|
| 1249 |
+
# return empty_solution()
|
| 1250 |
+
|
| 1251 |
+
# Filter out zero vertices and update connections accordingly
|
| 1252 |
+
non_zero_mask = ~np.all(np.isclose(predicted_vertices, [0.0, 0.0, 0.0]), axis=1)
|
| 1253 |
+
valid_indices = np.where(non_zero_mask)[0]
|
| 1254 |
+
|
| 1255 |
+
if len(valid_indices) == 0:
|
| 1256 |
+
print("No valid predicted vertices found")
|
| 1257 |
+
return empty_solution()
|
| 1258 |
+
|
| 1259 |
+
# Filter vertices to only include non-zero ones
|
| 1260 |
+
filtered_vertices = predicted_vertices[valid_indices]
|
| 1261 |
+
|
| 1262 |
+
# Create mapping from old indices to new indices
|
| 1263 |
+
old_to_new_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(valid_indices)}
|
| 1264 |
+
|
| 1265 |
+
# Filter and update connections
|
| 1266 |
+
filtered_connections = []
|
| 1267 |
+
for start_idx, end_idx in connections:
|
| 1268 |
+
if start_idx in old_to_new_mapping and end_idx in old_to_new_mapping:
|
| 1269 |
+
new_start = old_to_new_mapping[start_idx]
|
| 1270 |
+
new_end = old_to_new_mapping[end_idx]
|
| 1271 |
+
if new_start != new_end: # Ensure we don't connect a vertex to itself
|
| 1272 |
+
filtered_connections.append((new_start, new_end))
|
| 1273 |
+
|
| 1274 |
+
#print(f"Filtered vertices: {len(filtered_vertices)} from {len(predicted_vertices)}")
|
| 1275 |
+
#print(f"Filtered connections: {len(filtered_connections)} from {len(connections)}")
|
| 1276 |
+
|
| 1277 |
+
predicted_vertices = np.array(filtered_vertices)
|
| 1278 |
+
connections = filtered_connections
|
| 1279 |
+
|
| 1280 |
+
return predicted_vertices, connections
|
| 1281 |
+
|
| 1282 |
+
def predict_wireframe_old(entry) -> Tuple[np.ndarray, List[int]]:
|
| 1283 |
+
"""
|
| 1284 |
+
Predict 3D wireframe from a dataset entry.
|
| 1285 |
+
"""
|
| 1286 |
+
good_entry = convert_entry_to_human_readable(entry)
|
| 1287 |
+
vert_edge_per_image = {}
|
| 1288 |
+
for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(zip(good_entry['gestalt'],
|
| 1289 |
+
good_entry['depth'],
|
| 1290 |
+
good_entry['K'],
|
| 1291 |
+
good_entry['R'],
|
| 1292 |
+
good_entry['t'],
|
| 1293 |
+
good_entry['image_ids'],
|
| 1294 |
+
good_entry['ade'] # Added ade20k segmentation
|
| 1295 |
+
)):
|
| 1296 |
+
colmap_rec = good_entry['colmap_binary']
|
| 1297 |
+
K = np.array(K)
|
| 1298 |
+
R = np.array(R)
|
| 1299 |
+
t = np.array(t)
|
| 1300 |
+
# Resize gestalt segmentation to match depth map size
|
| 1301 |
+
depth_size = (np.array(depth).shape[1], np.array(depth).shape[0]) # W, H
|
| 1302 |
+
gest_seg = gest.resize(depth_size)
|
| 1303 |
+
gest_seg_np = np.array(gest_seg).astype(np.uint8)
|
| 1304 |
+
|
| 1305 |
+
# Get 2D vertices and edges first
|
| 1306 |
+
vertices, connections = get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th=25.)
|
| 1307 |
+
|
| 1308 |
+
# Check if we have enough to proceed
|
| 1309 |
+
if (len(vertices) < 2) or (len(connections) < 1):
|
| 1310 |
+
print(f'Not enough vertices or connections found in image {i}, skipping.')
|
| 1311 |
+
vert_edge_per_image[i] = [], [], np.empty((0, 3))
|
| 1312 |
+
continue
|
| 1313 |
+
|
| 1314 |
+
# Call the refactored function to get 3D points
|
| 1315 |
+
vertices_3d = create_3d_wireframe_single_image(
|
| 1316 |
+
vertices, connections, depth, colmap_rec, img_id, ade_seg, K, R, t
|
| 1317 |
+
)
|
| 1318 |
+
# Store original 2D vertices, connections, and computed 3D points
|
| 1319 |
+
vert_edge_per_image[i] = vertices, connections, vertices_3d
|
| 1320 |
+
|
| 1321 |
+
# Merge vertices from all images
|
| 1322 |
+
all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 0.5)
|
| 1323 |
+
all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d, keep_largest=False)
|
| 1324 |
+
all_3d_vertices_clean, connections_3d_clean = prune_too_far(all_3d_vertices_clean, connections_3d_clean, colmap_rec, th = 1.5)
|
| 1325 |
+
|
| 1326 |
+
if (len(all_3d_vertices_clean) < 2) or len(connections_3d_clean) < 1:
|
| 1327 |
print (f'Not enough vertices or connections in the 3D vertices')
|
| 1328 |
return empty_solution()
|
| 1329 |
|
|
|
|
| 1939 |
|
| 1940 |
H, W = gest_seg_np.shape[:2]
|
| 1941 |
|
| 1942 |
+
# Get camera parameters from COLMAP reconstruction if not provided
|
| 1943 |
+
if False:
|
| 1944 |
+
# Find the matching COLMAP image
|
| 1945 |
+
found_img = None
|
| 1946 |
+
for img_id_c, col_img_obj in colmap_rec.images.items():
|
| 1947 |
+
if img_id_substring in col_img_obj.name:
|
| 1948 |
+
found_img = col_img_obj
|
| 1949 |
+
break
|
| 1950 |
+
|
| 1951 |
+
if found_img is not None:
|
| 1952 |
+
# Get camera intrinsic matrix
|
| 1953 |
+
K = found_img.camera.calibration_matrix()
|
| 1954 |
+
|
| 1955 |
+
# Get world-to-camera transformation matrix
|
| 1956 |
+
world_to_cam = found_img.cam_from_world.matrix()
|
| 1957 |
+
R = world_to_cam[:3, :3]
|
| 1958 |
+
t = world_to_cam[:3, 3]
|
| 1959 |
+
else:
|
| 1960 |
+
print(f"Image substring {img_id_substring} not found in COLMAP.")
|
| 1961 |
+
return [], [], [], [], []
|
| 1962 |
+
|
| 1963 |
points_cam, points_xyz_world, points_idxs = get_visible_points(colmap_rec, img_id_substring, R=R, t=t)
|
| 1964 |
|
| 1965 |
uv, valid_indices = project_points_to_2d(points_cam, K, H, W)
|