Generates edge patches from predicted vertices
Browse filesImplements edge patch generation based on predicted vertices instead of ground truth.
It finds the closest ground truth vertex for each predicted vertex within a distance threshold and propagates ground truth connections to predicted vertices.
This facilitates training using predicted wireframes.
Also, the IoU threshold for positive patch overlap is increased to 1.0, effectively disabling it.
Enables generation of edge datasets from predicted vertices.
Adds visualization of ground truth vertices and connections in blue to the edge patch visualization.
- predict.py +71 -20
- train.py +10 -10
predict.py
CHANGED
|
@@ -16,14 +16,16 @@ from fast_voxel import predict_vertex_from_patch_voxel
|
|
| 16 |
import time
|
| 17 |
from fast_pointnet_class import save_patches_dataset as save_patches_dataset_class
|
| 18 |
from fast_pointnet_class import predict_class_from_patch
|
|
|
|
|
|
|
| 19 |
|
| 20 |
GENERATE_DATASET = False
|
| 21 |
DATASET_DIR = '/home/skvrnjan/personal/hohocustom/'
|
| 22 |
#DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom'
|
| 23 |
|
| 24 |
-
GENERATE_DATASET_EDGES =
|
| 25 |
-
EDGES_DATASET_DIR = '/home/skvrnjan/personal/hohocustom_edges/'
|
| 26 |
-
|
| 27 |
|
| 28 |
def convert_entry_to_human_readable(entry):
|
| 29 |
out = {}
|
|
@@ -723,12 +725,41 @@ def visu_pcloud_and_preds(colmap_rec, extracted_ids, extracted_points, extracted
|
|
| 723 |
ball_radius = 1.0 # meters
|
| 724 |
plotter.show(title=f"Extracted Points within {ball_radius}m radius - Spheres at group means")
|
| 725 |
|
| 726 |
-
def generate_edge_patches(frame):
|
| 727 |
-
|
| 728 |
-
|
| 729 |
-
|
| 730 |
-
vertices = np.array(
|
| 731 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 732 |
positive_patches = []
|
| 733 |
negative_patches = []
|
| 734 |
|
|
@@ -897,7 +928,7 @@ def generate_edge_patches(frame):
|
|
| 897 |
union_volume = current_volume + pos_volume - overlap_volume
|
| 898 |
if union_volume > 0:
|
| 899 |
iou = overlap_volume / union_volume
|
| 900 |
-
if iou >
|
| 901 |
has_overlap = True
|
| 902 |
break
|
| 903 |
|
|
@@ -954,7 +985,7 @@ def generate_edge_patches(frame):
|
|
| 954 |
all_patches = positive_patches + negative_patches
|
| 955 |
|
| 956 |
# Visualize edge patches
|
| 957 |
-
if
|
| 958 |
# Create plotter
|
| 959 |
plotter = pv.Plotter()
|
| 960 |
|
|
@@ -965,6 +996,24 @@ def generate_edge_patches(frame):
|
|
| 965 |
whole_cloud["colors"] = gray_colors
|
| 966 |
plotter.add_mesh(whole_cloud, scalars="colors", rgb=True, point_size=3, render_points_as_spheres=True)
|
| 967 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 968 |
# Visualize each patch
|
| 969 |
for patch_idx, patch in enumerate(all_patches):
|
| 970 |
# Use green for positive (edge), red for negative (non-edge)
|
|
@@ -1008,7 +1057,7 @@ def generate_edge_patches(frame):
|
|
| 1008 |
# Set title based on label distribution
|
| 1009 |
positive_count = sum(1 for patch in all_patches if patch['label'] == 1)
|
| 1010 |
negative_count = sum(1 for patch in all_patches if patch['label'] == 0)
|
| 1011 |
-
title = f"Edge Patches - Positive (Green): {positive_count}, Negative (Red): {negative_count}"
|
| 1012 |
|
| 1013 |
plotter.show(title=title)
|
| 1014 |
|
|
@@ -1332,11 +1381,6 @@ def predict_wireframe(entry, pnet_model, voxel_model, pnet_class_model, config)
|
|
| 1332 |
edge_threshold = config.get('edge_threshold', 0.5)
|
| 1333 |
only_predicted_connections = config.get('only_predicted_connections', False)
|
| 1334 |
|
| 1335 |
-
if GENERATE_DATASET_EDGES:
|
| 1336 |
-
patches = generate_edge_patches(good_entry)
|
| 1337 |
-
#save_patches_dataset_class(patches, EDGES_DATASET_DIR, good_entry['order_id'])
|
| 1338 |
-
return empty_solution()
|
| 1339 |
-
|
| 1340 |
vert_edge_per_image = {}
|
| 1341 |
idxs_points = []
|
| 1342 |
all_connections = []
|
|
@@ -1467,13 +1511,20 @@ def predict_wireframe(entry, pnet_model, voxel_model, pnet_class_model, config)
|
|
| 1467 |
non_zero_mask = ~np.all(np.isclose(predicted_vertices, [0.0, 0.0, 0.0]), axis=1)
|
| 1468 |
valid_indices = np.where(non_zero_mask)[0]
|
| 1469 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1470 |
if len(valid_indices) == 0:
|
| 1471 |
print("No valid predicted vertices found")
|
| 1472 |
return empty_solution()
|
| 1473 |
|
| 1474 |
-
# Filter vertices to only include non-zero ones
|
| 1475 |
-
filtered_vertices = predicted_vertices[valid_indices]
|
| 1476 |
-
|
| 1477 |
# Create mapping from old indices to new indices
|
| 1478 |
old_to_new_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(valid_indices)}
|
| 1479 |
|
|
|
|
| 16 |
import time
|
| 17 |
from fast_pointnet_class import save_patches_dataset as save_patches_dataset_class
|
| 18 |
from fast_pointnet_class import predict_class_from_patch
|
| 19 |
+
from scipy.spatial.distance import cdist
|
| 20 |
+
from scipy.optimize import linear_sum_assignment
|
| 21 |
|
| 22 |
GENERATE_DATASET = False
|
| 23 |
DATASET_DIR = '/home/skvrnjan/personal/hohocustom/'
|
| 24 |
#DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom'
|
| 25 |
|
| 26 |
+
GENERATE_DATASET_EDGES = True
|
| 27 |
+
#EDGES_DATASET_DIR = '/home/skvrnjan/personal/hohocustom_edges/'
|
| 28 |
+
EDGES_DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom_edges'
|
| 29 |
|
| 30 |
def convert_entry_to_human_readable(entry):
|
| 31 |
out = {}
|
|
|
|
| 725 |
ball_radius = 1.0 # meters
|
| 726 |
plotter.show(title=f"Extracted Points within {ball_radius}m radius - Spheres at group means")
|
| 727 |
|
| 728 |
+
def generate_edge_patches(frame, pred_vertices):
|
| 729 |
+
gt_vertices = np.array(frame['wf_vertices']) if frame['wf_vertices'] else np.empty((0, 3))
|
| 730 |
+
gt_connections = frame['wf_edges']
|
| 731 |
+
|
| 732 |
+
vertices = np.array(pred_vertices) if pred_vertices is not None and len(pred_vertices) > 0 else np.empty((0, 3))
|
| 733 |
+
|
| 734 |
+
# Find closest GT vertex for each predicted vertex
|
| 735 |
+
connections = []
|
| 736 |
+
if len(vertices) > 0 and len(gt_vertices) > 0:
|
| 737 |
+
# For each GT vertex, find the closest predicted vertex
|
| 738 |
+
gt_to_pred_mapping = {}
|
| 739 |
+
for gt_idx, gt_vertex in enumerate(gt_vertices):
|
| 740 |
+
# Calculate distances from this GT vertex to all predicted vertices
|
| 741 |
+
distances = np.linalg.norm(vertices - gt_vertex, axis=1)
|
| 742 |
+
|
| 743 |
+
# Find the closest predicted vertex
|
| 744 |
+
closest_pred_idx = np.argmin(distances)
|
| 745 |
+
closest_distance = distances[closest_pred_idx]
|
| 746 |
+
|
| 747 |
+
# Only map if within distance threshold
|
| 748 |
+
distance_threshold = 1.5
|
| 749 |
+
if closest_distance <= distance_threshold:
|
| 750 |
+
gt_to_pred_mapping[gt_idx] = closest_pred_idx
|
| 751 |
+
|
| 752 |
+
# Propagate GT connections to predicted vertices
|
| 753 |
+
for gt_connection in gt_connections:
|
| 754 |
+
gt_start, gt_end = gt_connection
|
| 755 |
+
if gt_start in gt_to_pred_mapping and gt_end in gt_to_pred_mapping:
|
| 756 |
+
pred_start = gt_to_pred_mapping[gt_start]
|
| 757 |
+
pred_end = gt_to_pred_mapping[gt_end]
|
| 758 |
+
connections.append((pred_start, pred_end))
|
| 759 |
+
|
| 760 |
+
print(f"Matched {len(gt_to_pred_mapping)} GT vertices to predicted vertices")
|
| 761 |
+
print(f"Propagated {len(connections)} connections from GT to predicted vertices")
|
| 762 |
+
|
| 763 |
positive_patches = []
|
| 764 |
negative_patches = []
|
| 765 |
|
|
|
|
| 928 |
union_volume = current_volume + pos_volume - overlap_volume
|
| 929 |
if union_volume > 0:
|
| 930 |
iou = overlap_volume / union_volume
|
| 931 |
+
if iou > 1.: # 0.2 IoU threshold
|
| 932 |
has_overlap = True
|
| 933 |
break
|
| 934 |
|
|
|
|
| 985 |
all_patches = positive_patches + negative_patches
|
| 986 |
|
| 987 |
# Visualize edge patches
|
| 988 |
+
if False: # Set to True to enable visualization
|
| 989 |
# Create plotter
|
| 990 |
plotter = pv.Plotter()
|
| 991 |
|
|
|
|
| 996 |
whole_cloud["colors"] = gray_colors
|
| 997 |
plotter.add_mesh(whole_cloud, scalars="colors", rgb=True, point_size=3, render_points_as_spheres=True)
|
| 998 |
|
| 999 |
+
# Add GT vertices and connections in blue
|
| 1000 |
+
gt_vertices = np.array(frame['wf_vertices']) if frame['wf_vertices'] else np.empty((0, 3))
|
| 1001 |
+
gt_connections = frame['wf_edges']
|
| 1002 |
+
|
| 1003 |
+
if len(gt_vertices) > 0:
|
| 1004 |
+
# Add GT vertices as blue spheres
|
| 1005 |
+
for gt_vertex in gt_vertices:
|
| 1006 |
+
gt_sphere = pv.Sphere(radius=0.15, center=gt_vertex)
|
| 1007 |
+
plotter.add_mesh(gt_sphere, color='blue', opacity=0.8)
|
| 1008 |
+
|
| 1009 |
+
# Add GT connections as blue lines
|
| 1010 |
+
for gt_connection in gt_connections:
|
| 1011 |
+
gt_start_idx, gt_end_idx = gt_connection
|
| 1012 |
+
if gt_start_idx < len(gt_vertices) and gt_end_idx < len(gt_vertices):
|
| 1013 |
+
gt_line_points = np.array([gt_vertices[gt_start_idx], gt_vertices[gt_end_idx]])
|
| 1014 |
+
gt_line = pv.Line(gt_line_points[0], gt_line_points[1])
|
| 1015 |
+
plotter.add_mesh(gt_line, color='blue', line_width=8)
|
| 1016 |
+
|
| 1017 |
# Visualize each patch
|
| 1018 |
for patch_idx, patch in enumerate(all_patches):
|
| 1019 |
# Use green for positive (edge), red for negative (non-edge)
|
|
|
|
| 1057 |
# Set title based on label distribution
|
| 1058 |
positive_count = sum(1 for patch in all_patches if patch['label'] == 1)
|
| 1059 |
negative_count = sum(1 for patch in all_patches if patch['label'] == 0)
|
| 1060 |
+
title = f"Edge Patches - Positive (Green): {positive_count}, Negative (Red): {negative_count}, GT (Blue)"
|
| 1061 |
|
| 1062 |
plotter.show(title=title)
|
| 1063 |
|
|
|
|
| 1381 |
edge_threshold = config.get('edge_threshold', 0.5)
|
| 1382 |
only_predicted_connections = config.get('only_predicted_connections', False)
|
| 1383 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1384 |
vert_edge_per_image = {}
|
| 1385 |
idxs_points = []
|
| 1386 |
all_connections = []
|
|
|
|
| 1511 |
non_zero_mask = ~np.all(np.isclose(predicted_vertices, [0.0, 0.0, 0.0]), axis=1)
|
| 1512 |
valid_indices = np.where(non_zero_mask)[0]
|
| 1513 |
|
| 1514 |
+
# Filter vertices to only include non-zero ones
|
| 1515 |
+
filtered_vertices = predicted_vertices[valid_indices]
|
| 1516 |
+
|
| 1517 |
+
#patches = generate_edge_patches(good_entry, filtered_vertices)
|
| 1518 |
+
if GENERATE_DATASET_EDGES:
|
| 1519 |
+
patches = generate_edge_patches(good_entry, filtered_vertices)
|
| 1520 |
+
save_patches_dataset_class(patches, EDGES_DATASET_DIR, good_entry['order_id'])
|
| 1521 |
+
|
| 1522 |
+
return empty_solution()
|
| 1523 |
+
|
| 1524 |
if len(valid_indices) == 0:
|
| 1525 |
print("No valid predicted vertices found")
|
| 1526 |
return empty_solution()
|
| 1527 |
|
|
|
|
|
|
|
|
|
|
| 1528 |
# Create mapping from old indices to new indices
|
| 1529 |
old_to_new_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(valid_indices)}
|
| 1530 |
|
train.py
CHANGED
|
@@ -19,9 +19,9 @@ from fast_voxel import load_3dcnn_model
|
|
| 19 |
from fast_pointnet_class import load_pointnet_model as load_pointnet_class_model
|
| 20 |
import torch
|
| 21 |
|
| 22 |
-
ds = load_dataset("usm3d/hoho25k", cache_dir="/media/skvrnjan/sd/hoho25k/", trust_remote_code=True)
|
| 23 |
-
|
| 24 |
-
|
| 25 |
|
| 26 |
scores_hss = []
|
| 27 |
scores_f1 = []
|
|
@@ -31,11 +31,11 @@ show_visu = False
|
|
| 31 |
|
| 32 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 33 |
|
| 34 |
-
pnet_model = load_pointnet_model(model_path="/home/skvrnjan/personal/hoho_pnet/initial_epoch_100.pth", device=device, predict_score=True)
|
| 35 |
-
|
| 36 |
|
| 37 |
-
pnet_class_model = load_pointnet_class_model(model_path="/home/skvrnjan/personal/hoho_pnet_edges_v2/initial_epoch_100.pth", device=device)
|
| 38 |
-
|
| 39 |
|
| 40 |
#voxel_model = load_3dcnn_model(model_path="/home/skvrnjan/personal/hoho_voxel/initial_epoch_100.pth", device=device, predict_score=True)
|
| 41 |
voxel_model = None
|
|
@@ -46,7 +46,7 @@ idx = 0
|
|
| 46 |
for a in tqdm(ds['train'], desc="Processing dataset"):
|
| 47 |
#plot_all_modalities(a)
|
| 48 |
#pred_vertices, pred_edges = predict_wireframe_old(a)
|
| 49 |
-
pred_vertices, pred_edges = predict_wireframe(a, pnet_model, voxel_model, pnet_class_model, config)
|
| 50 |
try:
|
| 51 |
pred_vertices, pred_edges = predict_wireframe(a, pnet_model, voxel_model, pnet_class_model, config)
|
| 52 |
#pred_vertices, pred_edges = predict_wireframe_old(a)
|
|
@@ -70,8 +70,8 @@ for a in tqdm(ds['train'], desc="Processing dataset"):
|
|
| 70 |
o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")
|
| 71 |
|
| 72 |
idx += 1
|
| 73 |
-
if idx >= 100: # Limit to first 10 samples for testing
|
| 74 |
-
|
| 75 |
|
| 76 |
for i in range(10):
|
| 77 |
print("END OF DATASET")
|
|
|
|
| 19 |
from fast_pointnet_class import load_pointnet_model as load_pointnet_class_model
|
| 20 |
import torch
|
| 21 |
|
| 22 |
+
#ds = load_dataset("usm3d/hoho25k", cache_dir="/media/skvrnjan/sd/hoho25k/", trust_remote_code=True)
|
| 23 |
+
ds = load_dataset("usm3d/hoho25k", cache_dir="/mnt/personal/skvrnjan/hoho25k/", trust_remote_code=True)
|
| 24 |
+
ds = ds.shuffle()
|
| 25 |
|
| 26 |
scores_hss = []
|
| 27 |
scores_f1 = []
|
|
|
|
| 31 |
|
| 32 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 33 |
|
| 34 |
+
#pnet_model = load_pointnet_model(model_path="/home/skvrnjan/personal/hoho_pnet/initial_epoch_100.pth", device=device, predict_score=True)
|
| 35 |
+
pnet_model = None
|
| 36 |
|
| 37 |
+
#pnet_class_model = load_pointnet_class_model(model_path="/home/skvrnjan/personal/hoho_pnet_edges_v2/initial_epoch_100.pth", device=device)
|
| 38 |
+
pnet_class_model = None
|
| 39 |
|
| 40 |
#voxel_model = load_3dcnn_model(model_path="/home/skvrnjan/personal/hoho_voxel/initial_epoch_100.pth", device=device, predict_score=True)
|
| 41 |
voxel_model = None
|
|
|
|
| 46 |
for a in tqdm(ds['train'], desc="Processing dataset"):
|
| 47 |
#plot_all_modalities(a)
|
| 48 |
#pred_vertices, pred_edges = predict_wireframe_old(a)
|
| 49 |
+
#pred_vertices, pred_edges = predict_wireframe(a, pnet_model, voxel_model, pnet_class_model, config)
|
| 50 |
try:
|
| 51 |
pred_vertices, pred_edges = predict_wireframe(a, pnet_model, voxel_model, pnet_class_model, config)
|
| 52 |
#pred_vertices, pred_edges = predict_wireframe_old(a)
|
|
|
|
| 70 |
o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")
|
| 71 |
|
| 72 |
idx += 1
|
| 73 |
+
#if idx >= 100: # Limit to first 10 samples for testing
|
| 74 |
+
# break
|
| 75 |
|
| 76 |
for i in range(10):
|
| 77 |
print("END OF DATASET")
|