jskvrna commited on
Commit
db367cd
·
1 Parent(s): 322171e

Improves wireframe extraction and prediction

Browse files

Refactors wireframe extraction by merging overlapping point groups and updating connections.

Adds edge patch generation for training edge prediction models using point clouds near wireframe edges and non-edges, improving data creation and model training.

Adds visualization tools for point clouds and predicted vertices with connections to help with debugging and evaluation.

Simplifies and streamlines vertex prediction by applying a PointNet model to extracted patches and filtering zero vertices.

Files changed (1) hide show
  1. predict.py +619 -11
predict.py CHANGED
@@ -12,10 +12,17 @@ import open3d as o3d
12
  from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local
13
  import pyvista as pv
14
  from fast_pointnet import save_patches_dataset, predict_vertex_from_patch
 
 
 
15
 
16
- GENERATE_DATASET = True
17
- #DATASET_DIR = '/home/skvrnjan/personal/hohocustom/'
18
- DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom'
 
 
 
 
19
 
20
  def convert_entry_to_human_readable(entry):
21
  out = {}
@@ -544,7 +551,7 @@ def extract_vertices_from_whole_pcloud(colmap_rec, idxs_points, all_connections)
544
  # and update connections accordingly
545
  updated_connections = []
546
  if extracted_points:
547
- print(f"Merging groups based on point overlap... Processing {len(extracted_points)} groups")
548
  # Create a list to track which groups to keep
549
  groups_to_keep = []
550
  merged_groups = set() # Track which groups have been merged
@@ -579,19 +586,36 @@ def extract_vertices_from_whole_pcloud(colmap_rec, idxs_points, all_connections)
579
  merged_points = np.vstack([merged_points, extracted_points[j]]) if len(merged_points) > 0 else extracted_points[j]
580
  merged_colors = np.vstack([merged_colors, extracted_colors[j]]) if len(merged_colors) > 0 else extracted_colors[j]
581
  merged_ids.update(ids_j)
 
582
  merged_groups.add(j)
583
 
584
  # Add the merged group to the list of groups to keep
585
  if len(merged_points) > 0:
 
586
  groups_to_keep.append((merged_points, merged_colors, np.array(list(merged_ids))))
 
 
 
 
587
 
588
  # Update extracted_points, extracted_colors, and extracted_ids with filtered results
589
  extracted_points = [group[0] for group in groups_to_keep]
590
  extracted_colors = [group[1] for group in groups_to_keep]
591
  extracted_ids = [group[2] for group in groups_to_keep]
592
 
593
- print(f"After merging, number of groups: {len(extracted_points)}")
 
 
 
 
 
 
 
 
 
594
 
 
 
595
 
596
  # Create visualization showing extracted points for each group as balls within their mean
597
  if False:
@@ -629,17 +653,481 @@ def extract_vertices_from_whole_pcloud(colmap_rec, idxs_points, all_connections)
629
 
630
  plotter.show(title=f"Extracted Points within {ball_radius}m radius - Spheres at group means")
631
 
632
- return extracted_points, extracted_colors, extracted_ids, whole_pcloud
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
633
 
634
- def predict_wireframe(entry, pnet_model) -> Tuple[np.ndarray, List[int]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
635
  """
636
  Predict 3D wireframe from a dataset entry.
637
  """
638
  good_entry = convert_entry_to_human_readable(entry)
639
  colmap_rec = good_entry['colmap_binary']
640
 
 
 
 
 
 
641
  vert_edge_per_image = {}
642
  idxs_points = []
 
643
  for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(zip(good_entry['gestalt'],
644
  good_entry['depth'],
645
  good_entry['K'],
@@ -660,6 +1148,7 @@ def predict_wireframe(entry, pnet_model) -> Tuple[np.ndarray, List[int]]:
660
 
661
  vertices_ours, connections_ours, vertices_3d_ours, patches, filtered_point_idxs = our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id, ade_seg, depth, K=K, R=R, t=t, frame=good_entry)
662
  idxs_points.append(filtered_point_idxs)
 
663
 
664
  '''
665
  if GENERATE_DATASET:
@@ -722,21 +1211,119 @@ def predict_wireframe(entry, pnet_model) -> Tuple[np.ndarray, List[int]]:
722
 
723
  vert_edge_per_image[i] = vertices, connections, vertices_3d
724
 
725
- extracted_points, extracted_colors, extracted_ids, whole_pcloud = extract_vertices_from_whole_pcloud(colmap_rec, idxs_points)
726
 
727
  patches = generate_patches_v2(extracted_points, extracted_colors, extracted_ids, whole_pcloud, good_entry['wf_vertices'])
728
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
729
  if GENERATE_DATASET:
730
  save_patches_dataset(patches, DATASET_DIR, img_id)
731
  return empty_solution()
732
 
733
  # Merge vertices from all images
734
- all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 0.1)
735
- all_3d_vertices_clean, connections_3d_clean = all_3d_vertices, connections_3d
736
  #all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d, keep_largest=False)
737
  #all_3d_vertices_clean, connections_3d_clean = prune_too_far(all_3d_vertices_clean, connections_3d_clean, colmap_rec, th = 1.5)
738
 
739
- if (len(all_3d_vertices_clean) < 2) or len(connections_3d_clean) < 1 and False:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
740
  print (f'Not enough vertices or connections in the 3D vertices')
741
  return empty_solution()
742
 
@@ -1352,6 +1939,27 @@ def our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id_substring, ade_se
1352
 
1353
  H, W = gest_seg_np.shape[:2]
1354
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1355
  points_cam, points_xyz_world, points_idxs = get_visible_points(colmap_rec, img_id_substring, R=R, t=t)
1356
 
1357
  uv, valid_indices = project_points_to_2d(points_cam, K, H, W)
 
12
  from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local
13
  import pyvista as pv
14
  from fast_pointnet import save_patches_dataset, predict_vertex_from_patch
15
+ from fast_voxel import predict_vertex_from_patch_voxel
16
+ import time
17
+ from fast_pointnet_class import save_patches_dataset as save_patches_dataset_class
18
 
19
+ GENERATE_DATASET = False
20
+ DATASET_DIR = '/home/skvrnjan/personal/hohocustom/'
21
+ #DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom'
22
+
23
+ GENERATE_DATASET_EDGES = True
24
+ #EDGES_DATASET_DIR = '/home/skvrnjan/personal/hohocustom_edges/'
25
+ EDGES_DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom_edges'
26
 
27
  def convert_entry_to_human_readable(entry):
28
  out = {}
 
551
  # and update connections accordingly
552
  updated_connections = []
553
  if extracted_points:
554
+ #print(f"Merging groups based on point overlap... Processing {len(extracted_points)} groups")
555
  # Create a list to track which groups to keep
556
  groups_to_keep = []
557
  merged_groups = set() # Track which groups have been merged
 
586
  merged_points = np.vstack([merged_points, extracted_points[j]]) if len(merged_points) > 0 else extracted_points[j]
587
  merged_colors = np.vstack([merged_colors, extracted_colors[j]]) if len(merged_colors) > 0 else extracted_colors[j]
588
  merged_ids.update(ids_j)
589
+ merged_indices.append(j)
590
  merged_groups.add(j)
591
 
592
  # Add the merged group to the list of groups to keep
593
  if len(merged_points) > 0:
594
+ new_group_idx = len(groups_to_keep)
595
  groups_to_keep.append((merged_points, merged_colors, np.array(list(merged_ids))))
596
+
597
+ # Update mapping for all merged indices
598
+ for old_idx in merged_indices:
599
+ old_to_new_mapping[old_idx] = new_group_idx
600
 
601
  # Update extracted_points, extracted_colors, and extracted_ids with filtered results
602
  extracted_points = [group[0] for group in groups_to_keep]
603
  extracted_colors = [group[1] for group in groups_to_keep]
604
  extracted_ids = [group[2] for group in groups_to_keep]
605
 
606
+ # Update connections based on the new mapping
607
+ for start_idx, end_idx in all_flattened_connections:
608
+ if start_idx in old_to_new_mapping and end_idx in old_to_new_mapping:
609
+ new_start = old_to_new_mapping[start_idx]
610
+ new_end = old_to_new_mapping[end_idx]
611
+ # Only add connection if vertices are still different after merging
612
+ if new_start != new_end:
613
+ connection = tuple(sorted((new_start, new_end)))
614
+ if connection not in updated_connections:
615
+ updated_connections.append(connection)
616
 
617
+ #print(f"After merging, number of groups: {len(extracted_points)}")
618
+ #print(f"Updated connections: {updated_connections}")
619
 
620
  # Create visualization showing extracted points for each group as balls within their mean
621
  if False:
 
653
 
654
  plotter.show(title=f"Extracted Points within {ball_radius}m radius - Spheres at group means")
655
 
656
+ return extracted_points, extracted_colors, extracted_ids, whole_pcloud, updated_connections
657
+
658
def visu_pcloud_and_preds(colmap_rec, extracted_ids, extracted_points, extracted_colors, predicted_vertices, connections):
    """Debug view of the reconstruction: the full COLMAP cloud in gray,
    each extracted point group in a random color with a sphere at its mean,
    predicted vertices as black spheres, and predicted connections as red lines.

    Zero vertices (0, 0, 0) mark groups with no prediction and are skipped.
    Opens an interactive pyvista window; returns nothing.
    """
    if not extracted_ids:
        return

    plotter = pv.Plotter()

    # Background: the whole COLMAP point cloud rendered uniformly gray.
    bg_points = [p3D.xyz for _, p3D in colmap_rec.points3D.items()]
    if bg_points:
        bg_cloud = pv.PolyData(np.array(bg_points))
        bg_cloud["colors"] = np.array([[0.8, 0.8, 0.8]] * len(bg_points))  # Gray color
        plotter.add_mesh(bg_cloud, scalars="colors", rgb=True, point_size=3, render_points_as_spheres=True)

    def _is_zero(vertex):
        # A (0, 0, 0) vertex is the sentinel for "no prediction".
        return np.allclose(vertex, [0.0, 0.0, 0.0])

    for group_idx, (group_points, _group_colors) in enumerate(zip(extracted_points, extracted_colors)):
        if len(group_points) == 0:
            continue

        # One random color per group, shared between its sphere and its points.
        group_color = None
        group_mean = np.mean(group_points, axis=0)
        mean_sphere = pv.Sphere(radius=0.2, center=group_mean)
        group_color = np.random.rand(3)
        plotter.add_mesh(mean_sphere, color=group_color, opacity=0.5)
        plotter.add_mesh(pv.PolyData(group_points), color=group_color, point_size=6, render_points_as_spheres=True)

        # Predicted vertex for this group, when present and non-zero.
        if group_idx < len(predicted_vertices):
            pred_vertex = predicted_vertices[group_idx]
            if not _is_zero(pred_vertex):
                plotter.add_mesh(pv.Sphere(radius=0.15, center=pred_vertex), color="black", opacity=1.)

    # Red segments between valid (non-zero) predicted vertices.
    if len(predicted_vertices) > 0 and len(connections) > 0:
        kept_indices = [i for i, v in enumerate(predicted_vertices) if not _is_zero(v)]
        kept_vertices = [predicted_vertices[i] for i in kept_indices]

        if len(kept_vertices) > 1:
            kept_vertices = np.array(kept_vertices)
            for start_idx, end_idx in connections:
                if start_idx in kept_indices and end_idx in kept_indices:
                    a = kept_vertices[kept_indices.index(start_idx)]
                    b = kept_vertices[kept_indices.index(end_idx)]
                    plotter.add_mesh(pv.Line(a, b), color="red", line_width=3)

    ball_radius = 1.0  # meters
    plotter.show(title=f"Extracted Points within {ball_radius}m radius - Spheres at group means")
725
def generate_edge_patches(frame):
    """
    Build training patches for an edge classifier from one dataset frame.

    For every ground-truth wireframe edge, the COLMAP points lying inside a
    cylinder around the (slightly extended) edge segment become a positive
    patch. An equal number of negative patches is sampled from unconnected
    vertex pairs, rejecting candidates whose cylinder overlaps any positive
    cylinder with IoU > 0.5 (approximate cylinder volumes).

    Args:
        frame: human-readable entry containing 'wf_vertices', 'wf_edges' and
            'colmap_binary' (a COLMAP reconstruction with `points3D`).

    Returns:
        list[dict]: positive patches followed by negative ones. Each patch has
        'patch_6d' (points centered on the segment midpoint, xyz + rgb in
        [-1, 1]), 'connection', 'line_start'/'line_end' (relative to center),
        'cylinder_radius', 'point_indices', binary 'label', and world 'center'.
    """
    raw_vertices = frame['wf_vertices']
    connections = frame['wf_edges']

    # Use an explicit length/None check: bare truthiness raises ValueError
    # when 'wf_vertices' is already a multi-element numpy array.
    if raw_vertices is not None and len(raw_vertices) > 0:
        vertices = np.asarray(raw_vertices)
    else:
        vertices = np.empty((0, 3))

    positive_patches = []
    negative_patches = []

    cylinder_radius = 0.5

    colmap = frame['colmap_binary']

    # Create the 6D (xyz + rgb) point cloud from the COLMAP reconstruction.
    colmap_points_6d = []
    for pid, p3D in colmap.points3D.items():
        point_6d = np.concatenate([p3D.xyz, p3D.color / 255.0])  # color normalized to [0, 1]
        colmap_points_6d.append(point_6d)

    colmap_points_6d = np.array(colmap_points_6d) if colmap_points_6d else np.empty((0, 6))

    # Rescale colors from [0, 1] to [-1, 1] for the network input.
    colmap_points_6d[:, 3:] = colmap_points_6d[:, 3:] * 2 - 1

    # 3D coordinates only, for the vectorized geometry below.
    colmap_points_3d = colmap_points_6d[:, :3]

    # Positive patches: one per ground-truth edge.
    for connection in connections:
        start_idx, end_idx = connection
        extraction = _extract_cylinder_points(colmap_points_6d, colmap_points_3d,
                                              vertices[start_idx], vertices[end_idx],
                                              cylinder_radius)
        if extraction is None:
            continue  # degenerate edge or too few supporting points
        points_centered, point_indices, line_midpoint, _, _, _ = extraction

        positive_patches.append({
            'patch_6d': points_centered,
            'connection': connection,
            'line_start': vertices[start_idx] - line_midpoint,
            'line_end': vertices[end_idx] - line_midpoint,
            'cylinder_radius': cylinder_radius,
            'point_indices': point_indices,
            'label': 1,  # positive: a real edge
            'center': line_midpoint
        })

    # Negative patches: sample unconnected vertex pairs, one per positive patch.
    num_negative_patches = len(positive_patches)

    if num_negative_patches > 0 and len(vertices) >= 2:
        # Order-independent set of connected pairs for O(1) lookup.
        connected_pairs = set(tuple(sorted(conn)) for conn in connections)

        # All ordered vertex pairs without self-pairs.
        vertex_indices = np.arange(len(vertices))
        all_pairs = np.array(np.meshgrid(vertex_indices, vertex_indices)).T.reshape(-1, 2)
        all_pairs = all_pairs[all_pairs[:, 0] != all_pairs[:, 1]]

        # Compare in sorted form so (a, b) and (b, a) both match connected_pairs.
        all_pairs_sorted = np.sort(all_pairs, axis=1)
        unconnected_mask = np.array([tuple(pair) not in connected_pairs for pair in all_pairs_sorted])
        unconnected_pairs = all_pairs[unconnected_mask]

        if len(unconnected_pairs) > 0:
            # World-space cylinders of the positive patches for overlap rejection.
            positive_cylinders = []
            for pos_patch in positive_patches:
                positive_cylinders.append({
                    'start': pos_patch['line_start'] + pos_patch['center'],
                    'end': pos_patch['line_end'] + pos_patch['center'],
                    'radius': pos_patch['cylinder_radius']
                })

            # Oversample 3x to compensate for rejected candidates.
            num_to_sample = min(num_negative_patches * 3, len(unconnected_pairs))
            sampled_indices = np.random.choice(len(unconnected_pairs), size=num_to_sample, replace=False)
            sampled_pairs = unconnected_pairs[sampled_indices]

            for idx1, idx2 in sampled_pairs:
                if len(negative_patches) >= num_negative_patches:
                    break

                extraction = _extract_cylinder_points(colmap_points_6d, colmap_points_3d,
                                                      vertices[idx1], vertices[idx2],
                                                      cylinder_radius)
                if extraction is None:
                    continue  # degenerate pair or too few supporting points
                points_centered, point_indices, line_midpoint, ext_start, ext_end, ext_length = extraction

                # Reject negatives whose cylinder overlaps a positive too much.
                if _overlaps_positive_cylinder(ext_start, ext_end, ext_length,
                                               cylinder_radius, positive_cylinders):
                    continue

                negative_patches.append({
                    'patch_6d': points_centered,
                    'connection': (idx1, idx2),
                    'line_start': vertices[idx1] - line_midpoint,
                    'line_end': vertices[idx2] - line_midpoint,
                    'cylinder_radius': cylinder_radius,
                    'point_indices': point_indices,
                    'label': 0,  # negative: not an edge
                    'center': line_midpoint  # Center of the patch
                })

    print(f"Generated {len(positive_patches)} positive patches and {len(negative_patches)} negative patches")
    all_patches = positive_patches + negative_patches

    # Visualize edge patches
    if False:  # Set to True to enable visualization
        _visualize_edge_patches(all_patches, colmap_points_3d)

    return all_patches


def _extract_cylinder_points(colmap_points_6d, colmap_points_3d, start_vertex, end_vertex,
                             cylinder_radius, extension_length=0.25, min_points=10):
    """Collect the 6D points inside a cylinder around the segment, extended by
    `extension_length` (meters) on both ends for extra context.

    Returns (points_centered, point_indices, line_midpoint, extended_start,
    extended_end, extended_line_length), or None when the segment is
    degenerate (coincident endpoints) or at most `min_points` points fall
    inside the cylinder. Points are centered on the ORIGINAL segment midpoint.
    """
    line_vector = end_vertex - start_vertex
    line_length = np.linalg.norm(line_vector)
    if line_length == 0:
        return None  # no direction: would divide by zero below

    line_direction = line_vector / line_length
    extended_start = start_vertex - extension_length * line_direction
    extended_end = end_vertex + extension_length * line_direction
    extended_line_length = line_length + 2 * extension_length

    # Project every point onto the extended axis (vectorized).
    start_to_points = colmap_points_3d - extended_start[np.newaxis, :]
    projection_lengths = np.dot(start_to_points, line_direction)
    within_bounds = (projection_lengths >= 0) & (projection_lengths <= extended_line_length)

    # Perpendicular distance from every point to the axis.
    closest_points_on_line = extended_start[np.newaxis, :] + projection_lengths[:, np.newaxis] * line_direction[np.newaxis, :]
    perpendicular_distances = np.linalg.norm(colmap_points_3d - closest_points_on_line, axis=1)

    within_cylinder = within_bounds & (perpendicular_distances <= cylinder_radius)
    if np.sum(within_cylinder) <= min_points:
        return None

    # Center the patch at the midpoint of the original (non-extended) segment.
    line_midpoint = (start_vertex + end_vertex) / 2
    points_centered = colmap_points_6d[within_cylinder].copy()
    points_centered[:, :3] -= line_midpoint

    return (points_centered, np.where(within_cylinder)[0], line_midpoint,
            extended_start, extended_end, extended_line_length)


def _overlaps_positive_cylinder(ext_start, ext_end, ext_length, radius, positive_cylinders):
    """Return True when the candidate cylinder has IoU > 0.5 (on approximate
    cylinder volumes) with any positive-patch cylinder."""
    candidate = {'start': ext_start, 'end': ext_end, 'radius': radius}
    candidate_volume = np.pi * radius**2 * ext_length
    for pos_cylinder in positive_cylinders:
        overlap_volume = calculate_cylinder_overlap_volume(candidate, pos_cylinder)
        pos_height = np.linalg.norm(pos_cylinder['end'] - pos_cylinder['start'])
        pos_volume = np.pi * pos_cylinder['radius']**2 * pos_height
        union_volume = candidate_volume + pos_volume - overlap_volume
        # IoU threshold of 0.5 (the old inline comment claiming 0.2 was stale).
        if union_volume > 0 and overlap_volume / union_volume > 0.5:
            return True
    return False


def _visualize_edge_patches(all_patches, colmap_points_3d):
    """Render all patches: green = edge, red = non-edge, with cylinder bounds."""
    plotter = pv.Plotter()

    # Whole point cloud in gray as context.
    if len(colmap_points_3d) > 0:
        whole_cloud = pv.PolyData(colmap_points_3d)
        whole_cloud["colors"] = np.full((len(colmap_points_3d), 3), [0.5, 0.5, 0.5])
        plotter.add_mesh(whole_cloud, scalars="colors", rgb=True, point_size=3, render_points_as_spheres=True)

    for patch in all_patches:
        patch_color = 'green' if patch['label'] == 1 else 'red'
        # Shift stored (centered) coordinates back into world space.
        points_world = patch['patch_6d'][:, :3] + patch['center']
        line_start = patch['line_start'] + patch['center']
        line_end = patch['line_end'] + patch['center']

        if len(points_world) > 0:
            plotter.add_mesh(pv.PolyData(points_world), color=patch_color, point_size=8, render_points_as_spheres=True)

        # Endpoints as larger spheres, plus the segment itself.
        plotter.add_mesh(pv.Sphere(radius=0.1, center=line_start), color='black', opacity=0.8)
        plotter.add_mesh(pv.Sphere(radius=0.1, center=line_end), color='white', opacity=0.8)
        plotter.add_mesh(pv.Line(line_start, line_end), color=patch_color, line_width=5)

        # Wireframe cylinder showing the extraction bounds (incl. extensions).
        axis = line_end - line_start
        axis_len = np.linalg.norm(axis)
        cylinder_mesh = pv.Cylinder(center=patch['center'], direction=axis / axis_len,
                                    radius=patch['cylinder_radius'], height=axis_len + 2 * 0.25)
        plotter.add_mesh(cylinder_mesh, color=patch_color, opacity=0.2, style='wireframe')

    positive_count = sum(1 for patch in all_patches if patch['label'] == 1)
    negative_count = sum(1 for patch in all_patches if patch['label'] == 0)
    title = f"Edge Patches - Positive (Green): {positive_count}, Negative (Red): {negative_count}"

    plotter.show(title=title)
1015
 
1016
def calculate_cylinder_overlap_volume(cyl1, cyl2):
    """
    Approximate the intersection volume of two finite cylinders.

    Each cylinder is a dict with keys 'start' and 'end' (3D endpoints of the
    axis segment) and 'radius'. The approximation projects cylinder 2 onto
    cylinder 1's axis to obtain an overlap length, then multiplies it by the
    intersection (lens) area of the two circular cross sections at the closest
    approach of the axes.

    Args:
        cyl1, cyl2: dicts with 'start' (array-like, 3), 'end' (array-like, 3),
            'radius' (float).

    Returns:
        float: approximate overlap volume, always >= 0. Returns 0.0 for a
        degenerate (zero-length) cylinder, when the axes are at least
        r1 + r2 apart, or when the axial projections do not overlap.
    """
    p1_start, p1_end = cyl1['start'], cyl1['end']
    p2_start, p2_end = cyl2['start'], cyl2['end']
    r1, r2 = cyl1['radius'], cyl2['radius']

    # Cylinder axes and lengths.
    axis1 = p1_end - p1_start
    axis2 = p2_end - p2_start
    len1 = np.linalg.norm(axis1)
    len2 = np.linalg.norm(axis2)

    if len1 == 0 or len2 == 0:
        return 0.0  # degenerate cylinder has no volume

    axis1_norm = axis1 / len1
    axis2_norm = axis2 / len2

    # Distance between the axis segments via the closest-point formulation
    # for two lines (clamped to the segments in the non-parallel case).
    w = p1_start - p2_start
    a = np.dot(axis1_norm, axis1_norm)
    b = np.dot(axis1_norm, axis2_norm)
    c = np.dot(axis2_norm, axis2_norm)
    d = np.dot(axis1_norm, w)
    e = np.dot(axis2_norm, w)

    denom = a * c - b * b
    if abs(denom) < 1e-10:  # Lines are parallel
        # For a unit axis, |axis1_norm x w| is the perpendicular distance.
        cross_product = np.cross(axis1_norm, w)
        if axis1_norm.shape[0] == 3:  # 3D case
            dist = np.linalg.norm(cross_product)
        else:  # 2D case
            dist = abs(cross_product)
    else:
        # Closest points on both (infinite) lines, clamped to the segments.
        t1 = (b * e - c * d) / denom
        t2 = (a * e - b * d) / denom
        t1 = np.clip(t1, 0, len1)
        t2 = np.clip(t2, 0, len2)
        point1 = p1_start + t1 * axis1_norm
        point2 = p2_start + t2 * axis2_norm
        dist = np.linalg.norm(point1 - point2)

    # Radially disjoint cylinders cannot intersect.
    if dist >= (r1 + r2):
        return 0.0

    # Overlap interval along cylinder 1's axis (projection of cylinder 2).
    proj_start = np.dot(p2_start - p1_start, axis1_norm)
    proj_end = np.dot(p2_end - p1_start, axis1_norm)
    overlap_start = max(0, min(proj_start, proj_end))
    overlap_end = min(len1, max(proj_start, proj_end))
    overlap_length = max(0, overlap_end - overlap_start)

    if overlap_length <= 0:
        return 0.0

    # Approximate the volume as cross-section intersection area x overlap length.
    if dist < abs(r1 - r2):
        # One circular cross section lies entirely inside the other.
        smaller_radius = min(r1, r2)
        overlap_volume = np.pi * smaller_radius**2 * overlap_length
    else:
        # Partial overlap: lens area of two intersecting circles.
        # dist < r1 + r2 is guaranteed here by the early return above
        # (the original code re-checked it in a dead branch).
        d1 = (r1**2 - r2**2 + dist**2) / (2 * dist) if dist > 0 else 0
        d2 = dist - d1

        if 0 <= d1 <= r1 and 0 <= d2 <= r2:
            area1 = r1**2 * np.arccos(d1 / r1) - d1 * np.sqrt(r1**2 - d1**2)
            area2 = r2**2 * np.arccos(d2 / r2) - d2 * np.sqrt(r2**2 - d2**2)
            intersection_area = area1 + area2
        else:
            # Numerical fallback: assume the smaller circle is fully covered.
            intersection_area = np.pi * min(r1, r2)**2

        overlap_volume = intersection_area * overlap_length

    return max(0.0, overlap_volume)
1115
+
1116
+ def predict_wireframe(entry, pnet_model, voxel_model) -> Tuple[np.ndarray, List[int]]:
1117
  """
1118
  Predict 3D wireframe from a dataset entry.
1119
  """
1120
  good_entry = convert_entry_to_human_readable(entry)
1121
  colmap_rec = good_entry['colmap_binary']
1122
 
1123
+ if GENERATE_DATASET_EDGES:
1124
+ patches = generate_edge_patches(good_entry)
1125
+ save_patches_dataset_class(patches, EDGES_DATASET_DIR, good_entry['order_id'])
1126
+ return empty_solution()
1127
+
1128
  vert_edge_per_image = {}
1129
  idxs_points = []
1130
+ all_connections = []
1131
  for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(zip(good_entry['gestalt'],
1132
  good_entry['depth'],
1133
  good_entry['K'],
 
1148
 
1149
  vertices_ours, connections_ours, vertices_3d_ours, patches, filtered_point_idxs = our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id, ade_seg, depth, K=K, R=R, t=t, frame=good_entry)
1150
  idxs_points.append(filtered_point_idxs)
1151
+ all_connections.append(connections_ours)
1152
 
1153
  '''
1154
  if GENERATE_DATASET:
 
1211
 
1212
  vert_edge_per_image[i] = vertices, connections, vertices_3d
1213
 
1214
+ extracted_points, extracted_colors, extracted_ids, whole_pcloud, connections = extract_vertices_from_whole_pcloud(colmap_rec, idxs_points, all_connections)
1215
 
1216
  patches = generate_patches_v2(extracted_points, extracted_colors, extracted_ids, whole_pcloud, good_entry['wf_vertices'])
1217
 
1218
+ # Predict vertices from patches using the neural network
1219
+ predicted_vertices = []
1220
+ for patch in patches:
1221
+ pred_vertex, pred_dist, pred_class = predict_vertex_from_patch(pnet_model, patch, device='cuda')
1222
+
1223
+ #visu_patch_and_pred(patch, pred_vertex, pred_dist, pred_class)
1224
+
1225
+ if pred_class > 0.5:
1226
+ predicted_vertices.append(pred_vertex)
1227
+ else:
1228
+ predicted_vertices.append(np.array([0.0, 0.0, 0.0])) # Append a zero vertex if not predicted
1229
+
1230
+ #pred_vertex_voxel, pred_dist_voxel, pred_class_voxel = predict_vertex_from_patch_voxel(voxel_model, patch, device='cuda')
1231
+ #visu_patch_and_pred(patch, pred_vertex_voxel, pred_dist_voxel, pred_class_voxel)
1232
+
1233
+ predicted_vertices = np.array(predicted_vertices) if predicted_vertices else np.empty((0, 3))
1234
+
1235
+ #visu_pcloud_and_preds(colmap_rec, extracted_ids, extracted_points, extracted_colors, predicted_vertices, connections)
1236
+
1237
  if GENERATE_DATASET:
1238
  save_patches_dataset(patches, DATASET_DIR, img_id)
1239
  return empty_solution()
1240
 
1241
  # Merge vertices from all images
1242
+ #all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 0.1)
1243
+ #all_3d_vertices_clean, connections_3d_clean = all_3d_vertices, connections_3d
1244
  #all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d, keep_largest=False)
1245
  #all_3d_vertices_clean, connections_3d_clean = prune_too_far(all_3d_vertices_clean, connections_3d_clean, colmap_rec, th = 1.5)
1246
 
1247
+ #if (len(all_3d_vertices_clean) < 2) or len(connections_3d_clean) < 1 and False:
1248
+ # print (f'Not enough vertices or connections in the 3D vertices')
1249
+ # return empty_solution()
1250
+
1251
+ # Filter out zero vertices and update connections accordingly
1252
+ non_zero_mask = ~np.all(np.isclose(predicted_vertices, [0.0, 0.0, 0.0]), axis=1)
1253
+ valid_indices = np.where(non_zero_mask)[0]
1254
+
1255
+ if len(valid_indices) == 0:
1256
+ print("No valid predicted vertices found")
1257
+ return empty_solution()
1258
+
1259
+ # Filter vertices to only include non-zero ones
1260
+ filtered_vertices = predicted_vertices[valid_indices]
1261
+
1262
+ # Create mapping from old indices to new indices
1263
+ old_to_new_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(valid_indices)}
1264
+
1265
+ # Filter and update connections
1266
+ filtered_connections = []
1267
+ for start_idx, end_idx in connections:
1268
+ if start_idx in old_to_new_mapping and end_idx in old_to_new_mapping:
1269
+ new_start = old_to_new_mapping[start_idx]
1270
+ new_end = old_to_new_mapping[end_idx]
1271
+ if new_start != new_end: # Ensure we don't connect a vertex to itself
1272
+ filtered_connections.append((new_start, new_end))
1273
+
1274
+ #print(f"Filtered vertices: {len(filtered_vertices)} from {len(predicted_vertices)}")
1275
+ #print(f"Filtered connections: {len(filtered_connections)} from {len(connections)}")
1276
+
1277
+ predicted_vertices = np.array(filtered_vertices)
1278
+ connections = filtered_connections
1279
+
1280
+ return predicted_vertices, connections
1281
+
1282
+ def predict_wireframe_old(entry) -> Tuple[np.ndarray, List[int]]:
1283
+ """
1284
+ Predict 3D wireframe from a dataset entry.
1285
+ """
1286
+ good_entry = convert_entry_to_human_readable(entry)
1287
+ vert_edge_per_image = {}
1288
+ for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(zip(good_entry['gestalt'],
1289
+ good_entry['depth'],
1290
+ good_entry['K'],
1291
+ good_entry['R'],
1292
+ good_entry['t'],
1293
+ good_entry['image_ids'],
1294
+ good_entry['ade'] # Added ade20k segmentation
1295
+ )):
1296
+ colmap_rec = good_entry['colmap_binary']
1297
+ K = np.array(K)
1298
+ R = np.array(R)
1299
+ t = np.array(t)
1300
+ # Resize gestalt segmentation to match depth map size
1301
+ depth_size = (np.array(depth).shape[1], np.array(depth).shape[0]) # W, H
1302
+ gest_seg = gest.resize(depth_size)
1303
+ gest_seg_np = np.array(gest_seg).astype(np.uint8)
1304
+
1305
+ # Get 2D vertices and edges first
1306
+ vertices, connections = get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th=25.)
1307
+
1308
+ # Check if we have enough to proceed
1309
+ if (len(vertices) < 2) or (len(connections) < 1):
1310
+ print(f'Not enough vertices or connections found in image {i}, skipping.')
1311
+ vert_edge_per_image[i] = [], [], np.empty((0, 3))
1312
+ continue
1313
+
1314
+ # Call the refactored function to get 3D points
1315
+ vertices_3d = create_3d_wireframe_single_image(
1316
+ vertices, connections, depth, colmap_rec, img_id, ade_seg, K, R, t
1317
+ )
1318
+ # Store original 2D vertices, connections, and computed 3D points
1319
+ vert_edge_per_image[i] = vertices, connections, vertices_3d
1320
+
1321
+ # Merge vertices from all images
1322
+ all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 0.5)
1323
+ all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d, keep_largest=False)
1324
+ all_3d_vertices_clean, connections_3d_clean = prune_too_far(all_3d_vertices_clean, connections_3d_clean, colmap_rec, th = 1.5)
1325
+
1326
+ if (len(all_3d_vertices_clean) < 2) or len(connections_3d_clean) < 1:
1327
  print (f'Not enough vertices or connections in the 3D vertices')
1328
  return empty_solution()
1329
 
 
1939
 
1940
  H, W = gest_seg_np.shape[:2]
1941
 
1942
+ # Get camera parameters from COLMAP reconstruction if not provided
1943
+ if False:
1944
+ # Find the matching COLMAP image
1945
+ found_img = None
1946
+ for img_id_c, col_img_obj in colmap_rec.images.items():
1947
+ if img_id_substring in col_img_obj.name:
1948
+ found_img = col_img_obj
1949
+ break
1950
+
1951
+ if found_img is not None:
1952
+ # Get camera intrinsic matrix
1953
+ K = found_img.camera.calibration_matrix()
1954
+
1955
+ # Get world-to-camera transformation matrix
1956
+ world_to_cam = found_img.cam_from_world.matrix()
1957
+ R = world_to_cam[:3, :3]
1958
+ t = world_to_cam[:3, 3]
1959
+ else:
1960
+ print(f"Image substring {img_id_substring} not found in COLMAP.")
1961
+ return [], [], [], [], []
1962
+
1963
  points_cam, points_xyz_world, points_idxs = get_visible_points(colmap_rec, img_id_substring, R=R, t=t)
1964
 
1965
  uv, valid_indices = project_points_to_2d(points_cam, K, H, W)