jskvrna commited on
Commit
d4c0222
·
1 Parent(s): c4db724

Generates edge patches from predicted vertices

Browse files

Implements edge patch generation based on predicted vertices instead of ground truth.

It finds the closest ground truth vertex for each predicted vertex within a distance threshold and propagates ground truth connections to predicted vertices.
This facilitates training using predicted wireframes.

Also, the IoU threshold for positive patch overlap is increased to 1.0, effectively disabling it.

Enables generation of edge datasets from predicted vertices.

Adds visualization of ground truth vertices and connections in blue to the edge patch visualization.

Files changed (2) hide show
  1. predict.py +71 -20
  2. train.py +10 -10
predict.py CHANGED
@@ -16,14 +16,16 @@ from fast_voxel import predict_vertex_from_patch_voxel
16
  import time
17
  from fast_pointnet_class import save_patches_dataset as save_patches_dataset_class
18
  from fast_pointnet_class import predict_class_from_patch
 
 
19
 
20
  GENERATE_DATASET = False
21
  DATASET_DIR = '/home/skvrnjan/personal/hohocustom/'
22
  #DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom'
23
 
24
- GENERATE_DATASET_EDGES = False
25
- EDGES_DATASET_DIR = '/home/skvrnjan/personal/hohocustom_edges/'
26
- #EDGES_DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom_edges'
27
 
28
  def convert_entry_to_human_readable(entry):
29
  out = {}
@@ -723,12 +725,41 @@ def visu_pcloud_and_preds(colmap_rec, extracted_ids, extracted_points, extracted
723
  ball_radius = 1.0 # meters
724
  plotter.show(title=f"Extracted Points within {ball_radius}m radius - Spheres at group means")
725
 
726
- def generate_edge_patches(frame):
727
- vertices = frame['wf_vertices']
728
- connections = frame['wf_edges']
729
-
730
- vertices = np.array(vertices) if vertices else np.empty((0, 3))
731
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
732
  positive_patches = []
733
  negative_patches = []
734
 
@@ -897,7 +928,7 @@ def generate_edge_patches(frame):
897
  union_volume = current_volume + pos_volume - overlap_volume
898
  if union_volume > 0:
899
  iou = overlap_volume / union_volume
900
- if iou > 0.5: # 0.2 IoU threshold
901
  has_overlap = True
902
  break
903
 
@@ -954,7 +985,7 @@ def generate_edge_patches(frame):
954
  all_patches = positive_patches + negative_patches
955
 
956
  # Visualize edge patches
957
- if True: # Set to True to enable visualization
958
  # Create plotter
959
  plotter = pv.Plotter()
960
 
@@ -965,6 +996,24 @@ def generate_edge_patches(frame):
965
  whole_cloud["colors"] = gray_colors
966
  plotter.add_mesh(whole_cloud, scalars="colors", rgb=True, point_size=3, render_points_as_spheres=True)
967
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
968
  # Visualize each patch
969
  for patch_idx, patch in enumerate(all_patches):
970
  # Use green for positive (edge), red for negative (non-edge)
@@ -1008,7 +1057,7 @@ def generate_edge_patches(frame):
1008
  # Set title based on label distribution
1009
  positive_count = sum(1 for patch in all_patches if patch['label'] == 1)
1010
  negative_count = sum(1 for patch in all_patches if patch['label'] == 0)
1011
- title = f"Edge Patches - Positive (Green): {positive_count}, Negative (Red): {negative_count}"
1012
 
1013
  plotter.show(title=title)
1014
 
@@ -1332,11 +1381,6 @@ def predict_wireframe(entry, pnet_model, voxel_model, pnet_class_model, config)
1332
  edge_threshold = config.get('edge_threshold', 0.5)
1333
  only_predicted_connections = config.get('only_predicted_connections', False)
1334
 
1335
- if GENERATE_DATASET_EDGES:
1336
- patches = generate_edge_patches(good_entry)
1337
- #save_patches_dataset_class(patches, EDGES_DATASET_DIR, good_entry['order_id'])
1338
- return empty_solution()
1339
-
1340
  vert_edge_per_image = {}
1341
  idxs_points = []
1342
  all_connections = []
@@ -1467,13 +1511,20 @@ def predict_wireframe(entry, pnet_model, voxel_model, pnet_class_model, config)
1467
  non_zero_mask = ~np.all(np.isclose(predicted_vertices, [0.0, 0.0, 0.0]), axis=1)
1468
  valid_indices = np.where(non_zero_mask)[0]
1469
 
 
 
 
 
 
 
 
 
 
 
1470
  if len(valid_indices) == 0:
1471
  print("No valid predicted vertices found")
1472
  return empty_solution()
1473
 
1474
- # Filter vertices to only include non-zero ones
1475
- filtered_vertices = predicted_vertices[valid_indices]
1476
-
1477
  # Create mapping from old indices to new indices
1478
  old_to_new_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(valid_indices)}
1479
 
 
16
  import time
17
  from fast_pointnet_class import save_patches_dataset as save_patches_dataset_class
18
  from fast_pointnet_class import predict_class_from_patch
19
+ from scipy.spatial.distance import cdist
20
+ from scipy.optimize import linear_sum_assignment
21
 
22
  GENERATE_DATASET = False
23
  DATASET_DIR = '/home/skvrnjan/personal/hohocustom/'
24
  #DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom'
25
 
26
+ GENERATE_DATASET_EDGES = True
27
+ #EDGES_DATASET_DIR = '/home/skvrnjan/personal/hohocustom_edges/'
28
+ EDGES_DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom_edges'
29
 
30
  def convert_entry_to_human_readable(entry):
31
  out = {}
 
725
  ball_radius = 1.0 # meters
726
  plotter.show(title=f"Extracted Points within {ball_radius}m radius - Spheres at group means")
727
 
728
+ def generate_edge_patches(frame, pred_vertices):
729
+ gt_vertices = np.array(frame['wf_vertices']) if frame['wf_vertices'] else np.empty((0, 3))
730
+ gt_connections = frame['wf_edges']
731
+
732
+ vertices = np.array(pred_vertices) if pred_vertices is not None and len(pred_vertices) > 0 else np.empty((0, 3))
733
+
734
+ # Find closest GT vertex for each predicted vertex
735
+ connections = []
736
+ if len(vertices) > 0 and len(gt_vertices) > 0:
737
+ # For each GT vertex, find the closest predicted vertex
738
+ gt_to_pred_mapping = {}
739
+ for gt_idx, gt_vertex in enumerate(gt_vertices):
740
+ # Calculate distances from this GT vertex to all predicted vertices
741
+ distances = np.linalg.norm(vertices - gt_vertex, axis=1)
742
+
743
+ # Find the closest predicted vertex
744
+ closest_pred_idx = np.argmin(distances)
745
+ closest_distance = distances[closest_pred_idx]
746
+
747
+ # Only map if within distance threshold
748
+ distance_threshold = 1.5
749
+ if closest_distance <= distance_threshold:
750
+ gt_to_pred_mapping[gt_idx] = closest_pred_idx
751
+
752
+ # Propagate GT connections to predicted vertices
753
+ for gt_connection in gt_connections:
754
+ gt_start, gt_end = gt_connection
755
+ if gt_start in gt_to_pred_mapping and gt_end in gt_to_pred_mapping:
756
+ pred_start = gt_to_pred_mapping[gt_start]
757
+ pred_end = gt_to_pred_mapping[gt_end]
758
+ connections.append((pred_start, pred_end))
759
+
760
+ print(f"Matched {len(gt_to_pred_mapping)} GT vertices to predicted vertices")
761
+ print(f"Propagated {len(connections)} connections from GT to predicted vertices")
762
+
763
  positive_patches = []
764
  negative_patches = []
765
 
 
928
  union_volume = current_volume + pos_volume - overlap_volume
929
  if union_volume > 0:
930
  iou = overlap_volume / union_volume
931
+ if iou > 1.: # 0.2 IoU threshold
932
  has_overlap = True
933
  break
934
 
 
985
  all_patches = positive_patches + negative_patches
986
 
987
  # Visualize edge patches
988
+ if False: # Set to True to enable visualization
989
  # Create plotter
990
  plotter = pv.Plotter()
991
 
 
996
  whole_cloud["colors"] = gray_colors
997
  plotter.add_mesh(whole_cloud, scalars="colors", rgb=True, point_size=3, render_points_as_spheres=True)
998
 
999
+ # Add GT vertices and connections in blue
1000
+ gt_vertices = np.array(frame['wf_vertices']) if frame['wf_vertices'] else np.empty((0, 3))
1001
+ gt_connections = frame['wf_edges']
1002
+
1003
+ if len(gt_vertices) > 0:
1004
+ # Add GT vertices as blue spheres
1005
+ for gt_vertex in gt_vertices:
1006
+ gt_sphere = pv.Sphere(radius=0.15, center=gt_vertex)
1007
+ plotter.add_mesh(gt_sphere, color='blue', opacity=0.8)
1008
+
1009
+ # Add GT connections as blue lines
1010
+ for gt_connection in gt_connections:
1011
+ gt_start_idx, gt_end_idx = gt_connection
1012
+ if gt_start_idx < len(gt_vertices) and gt_end_idx < len(gt_vertices):
1013
+ gt_line_points = np.array([gt_vertices[gt_start_idx], gt_vertices[gt_end_idx]])
1014
+ gt_line = pv.Line(gt_line_points[0], gt_line_points[1])
1015
+ plotter.add_mesh(gt_line, color='blue', line_width=8)
1016
+
1017
  # Visualize each patch
1018
  for patch_idx, patch in enumerate(all_patches):
1019
  # Use green for positive (edge), red for negative (non-edge)
 
1057
  # Set title based on label distribution
1058
  positive_count = sum(1 for patch in all_patches if patch['label'] == 1)
1059
  negative_count = sum(1 for patch in all_patches if patch['label'] == 0)
1060
+ title = f"Edge Patches - Positive (Green): {positive_count}, Negative (Red): {negative_count}, GT (Blue)"
1061
 
1062
  plotter.show(title=title)
1063
 
 
1381
  edge_threshold = config.get('edge_threshold', 0.5)
1382
  only_predicted_connections = config.get('only_predicted_connections', False)
1383
 
 
 
 
 
 
1384
  vert_edge_per_image = {}
1385
  idxs_points = []
1386
  all_connections = []
 
1511
  non_zero_mask = ~np.all(np.isclose(predicted_vertices, [0.0, 0.0, 0.0]), axis=1)
1512
  valid_indices = np.where(non_zero_mask)[0]
1513
 
1514
+ # Filter vertices to only include non-zero ones
1515
+ filtered_vertices = predicted_vertices[valid_indices]
1516
+
1517
+ #patches = generate_edge_patches(good_entry, filtered_vertices)
1518
+ if GENERATE_DATASET_EDGES:
1519
+ patches = generate_edge_patches(good_entry, filtered_vertices)
1520
+ save_patches_dataset_class(patches, EDGES_DATASET_DIR, good_entry['order_id'])
1521
+
1522
+ return empty_solution()
1523
+
1524
  if len(valid_indices) == 0:
1525
  print("No valid predicted vertices found")
1526
  return empty_solution()
1527
 
 
 
 
1528
  # Create mapping from old indices to new indices
1529
  old_to_new_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(valid_indices)}
1530
 
train.py CHANGED
@@ -19,9 +19,9 @@ from fast_voxel import load_3dcnn_model
19
  from fast_pointnet_class import load_pointnet_model as load_pointnet_class_model
20
  import torch
21
 
22
- ds = load_dataset("usm3d/hoho25k", cache_dir="/media/skvrnjan/sd/hoho25k/", trust_remote_code=True)
23
- #ds = load_dataset("usm3d/hoho25k", cache_dir="/mnt/personal/skvrnjan/hoho25k/", trust_remote_code=True)
24
- #ds = ds.shuffle()
25
 
26
  scores_hss = []
27
  scores_f1 = []
@@ -31,11 +31,11 @@ show_visu = False
31
 
32
  device = "cuda" if torch.cuda.is_available() else "cpu"
33
 
34
- pnet_model = load_pointnet_model(model_path="/home/skvrnjan/personal/hoho_pnet/initial_epoch_100.pth", device=device, predict_score=True)
35
- #pnet_model = None
36
 
37
- pnet_class_model = load_pointnet_class_model(model_path="/home/skvrnjan/personal/hoho_pnet_edges_v2/initial_epoch_100.pth", device=device)
38
- #pnet_class_model = None
39
 
40
  #voxel_model = load_3dcnn_model(model_path="/home/skvrnjan/personal/hoho_voxel/initial_epoch_100.pth", device=device, predict_score=True)
41
  voxel_model = None
@@ -46,7 +46,7 @@ idx = 0
46
  for a in tqdm(ds['train'], desc="Processing dataset"):
47
  #plot_all_modalities(a)
48
  #pred_vertices, pred_edges = predict_wireframe_old(a)
49
- pred_vertices, pred_edges = predict_wireframe(a, pnet_model, voxel_model, pnet_class_model, config)
50
  try:
51
  pred_vertices, pred_edges = predict_wireframe(a, pnet_model, voxel_model, pnet_class_model, config)
52
  #pred_vertices, pred_edges = predict_wireframe_old(a)
@@ -70,8 +70,8 @@ for a in tqdm(ds['train'], desc="Processing dataset"):
70
  o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")
71
 
72
  idx += 1
73
- if idx >= 100: # Limit to first 10 samples for testing
74
- break
75
 
76
  for i in range(10):
77
  print("END OF DATASET")
 
19
  from fast_pointnet_class import load_pointnet_model as load_pointnet_class_model
20
  import torch
21
 
22
+ #ds = load_dataset("usm3d/hoho25k", cache_dir="/media/skvrnjan/sd/hoho25k/", trust_remote_code=True)
23
+ ds = load_dataset("usm3d/hoho25k", cache_dir="/mnt/personal/skvrnjan/hoho25k/", trust_remote_code=True)
24
+ ds = ds.shuffle()
25
 
26
  scores_hss = []
27
  scores_f1 = []
 
31
 
32
  device = "cuda" if torch.cuda.is_available() else "cpu"
33
 
34
+ #pnet_model = load_pointnet_model(model_path="/home/skvrnjan/personal/hoho_pnet/initial_epoch_100.pth", device=device, predict_score=True)
35
+ pnet_model = None
36
 
37
+ #pnet_class_model = load_pointnet_class_model(model_path="/home/skvrnjan/personal/hoho_pnet_edges_v2/initial_epoch_100.pth", device=device)
38
+ pnet_class_model = None
39
 
40
  #voxel_model = load_3dcnn_model(model_path="/home/skvrnjan/personal/hoho_voxel/initial_epoch_100.pth", device=device, predict_score=True)
41
  voxel_model = None
 
46
  for a in tqdm(ds['train'], desc="Processing dataset"):
47
  #plot_all_modalities(a)
48
  #pred_vertices, pred_edges = predict_wireframe_old(a)
49
+ #pred_vertices, pred_edges = predict_wireframe(a, pnet_model, voxel_model, pnet_class_model, config)
50
  try:
51
  pred_vertices, pred_edges = predict_wireframe(a, pnet_model, voxel_model, pnet_class_model, config)
52
  #pred_vertices, pred_edges = predict_wireframe_old(a)
 
70
  o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")
71
 
72
  idx += 1
73
+ #if idx >= 100: # Limit to first 10 samples for testing
74
+ # break
75
 
76
  for i in range(10):
77
  print("END OF DATASET")