xx
Browse files- predict.py +350 -193
- train.py +20 -10
predict.py
CHANGED
|
@@ -16,9 +16,11 @@ from fast_pointnet import save_patches_dataset, predict_vertex_from_patch
|
|
| 16 |
#import time
|
| 17 |
from fast_pointnet_class import save_patches_dataset as save_patches_dataset_class
|
| 18 |
from fast_pointnet_class import predict_class_from_patch
|
|
|
|
| 19 |
from scipy.spatial.distance import cdist
|
| 20 |
from scipy.optimize import linear_sum_assignment
|
| 21 |
import torch
|
|
|
|
| 22 |
|
| 23 |
GENERATE_DATASET = False
|
| 24 |
DATASET_DIR = '/home/skvrnjan/personal/hohocustom/'
|
|
@@ -1179,6 +1181,114 @@ def generate_edge_patches_forward(frame, pred_vertices):
|
|
| 1179 |
|
| 1180 |
return forward_patches
|
| 1181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1182 |
def calculate_cylinder_overlap_volume(cyl1, cyl2):
|
| 1183 |
"""
|
| 1184 |
Calculate the intersection volume between two cylinders using numpy vectorization.
|
|
@@ -1280,119 +1390,232 @@ def calculate_cylinder_overlap_volume(cyl1, cyl2):
|
|
| 1280 |
return max(0.0, overlap_volume)
|
| 1281 |
|
| 1282 |
def create_pcloud(colmap_rec, frame):
|
| 1283 |
-
|
| 1284 |
-
|
| 1285 |
-
|
| 1286 |
-
|
| 1287 |
-
|
| 1288 |
-
|
| 1289 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1290 |
|
| 1291 |
-
|
| 1292 |
-
|
| 1293 |
-
|
| 1294 |
-
|
| 1295 |
-
for i, (K, R, t, img_id, ade, gestalt, depth) in enumerate(zip(frame['K'], frame['R'], frame['t'], frame['image_ids'], frame['ade'], frame['gestalt'], frame['depth'])):
|
| 1296 |
-
for all_imgsid in all_imgs_ids:
|
| 1297 |
-
if all_imgsid == img_id:
|
| 1298 |
-
all_imgs_K.append(np.array(K))
|
| 1299 |
-
all_imgs_R.append(np.array(R))
|
| 1300 |
-
all_imgs_t.append(np.array(t))
|
| 1301 |
-
|
| 1302 |
-
ade_mask = get_house_mask(ade)
|
| 1303 |
-
all_imgs_ade.append(np.array(ade_mask))
|
| 1304 |
-
|
| 1305 |
-
depth_size = (np.array(depth).shape[1], np.array(depth).shape[0]) # W, H
|
| 1306 |
-
gest_seg = gestalt.resize(depth_size)
|
| 1307 |
-
gest_seg_np = np.array(gest_seg).astype(np.uint8)
|
| 1308 |
-
all_imgs_gestalt.append(np.array(gest_seg_np))
|
| 1309 |
|
| 1310 |
-
|
| 1311 |
-
|
| 1312 |
-
|
| 1313 |
-
|
| 1314 |
-
|
| 1315 |
-
|
| 1316 |
-
|
| 1317 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1318 |
|
| 1319 |
-
|
| 1320 |
-
|
| 1321 |
-
|
| 1322 |
-
|
| 1323 |
-
|
| 1324 |
-
|
| 1325 |
-
|
| 1326 |
-
|
| 1327 |
-
|
| 1328 |
-
|
| 1329 |
-
|
| 1330 |
-
|
| 1331 |
-
|
| 1332 |
-
|
| 1333 |
-
|
| 1334 |
-
|
| 1335 |
-
|
| 1336 |
-
|
| 1337 |
-
|
| 1338 |
-
|
| 1339 |
-
|
| 1340 |
-
|
| 1341 |
-
|
| 1342 |
-
|
| 1343 |
-
|
| 1344 |
-
|
| 1345 |
-
|
| 1346 |
-
|
| 1347 |
-
|
| 1348 |
-
|
| 1349 |
-
|
| 1350 |
-
|
| 1351 |
-
|
| 1352 |
-
|
| 1353 |
-
|
| 1354 |
-
|
| 1355 |
-
|
| 1356 |
-
|
| 1357 |
-
|
| 1358 |
-
|
| 1359 |
-
|
| 1360 |
-
|
| 1361 |
-
|
| 1362 |
-
|
| 1363 |
-
|
| 1364 |
-
|
| 1365 |
-
|
| 1366 |
-
|
| 1367 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1368 |
|
| 1369 |
-
points_xyz_world = np.array(points_xyz_world) if points_xyz_world else np.empty((0, 3))
|
| 1370 |
-
points_colors = np.array(points_colors) if points_colors else np.empty((0, 3))
|
| 1371 |
-
points_idxs = np.array(points_idxs) if points_idxs else np.empty((0,))
|
| 1372 |
-
points_ade = np.array(points_ade) if points_ade else np.empty((0,))
|
| 1373 |
|
| 1374 |
-
# Create 7D point cloud from COLMAP data (xyz + rgb + img_count)
|
| 1375 |
if points_xyz_world.shape[0] > 0:
|
| 1376 |
-
colmap_points_7d = np.zeros((
|
| 1377 |
-
colmap_points_7d[:, :3] = points_xyz_world
|
| 1378 |
-
colmap_points_7d[:, 3:6] = points_colors
|
| 1379 |
-
colmap_points_7d[:, 6] = points_idxs
|
| 1380 |
-
|
| 1381 |
-
whole_pcloud = {
|
| 1382 |
-
|
| 1383 |
-
|
| 1384 |
-
|
| 1385 |
-
|
| 1386 |
-
|
| 1387 |
-
|
| 1388 |
-
|
| 1389 |
-
|
|
|
|
|
|
|
| 1390 |
else:
|
| 1391 |
-
whole_pcloud = {
|
| 1392 |
-
|
| 1393 |
-
|
| 1394 |
-
|
| 1395 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1396 |
return whole_pcloud
|
| 1397 |
|
| 1398 |
def predict_wireframe(entry, pnet_model, voxel_model, pnet_class_model, config) -> Tuple[np.ndarray, List[int]]:
|
|
@@ -1401,11 +1624,19 @@ def predict_wireframe(entry, pnet_model, voxel_model, pnet_class_model, config)
|
|
| 1401 |
"""
|
| 1402 |
|
| 1403 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1404 |
|
| 1405 |
good_entry = convert_entry_to_human_readable(entry)
|
| 1406 |
colmap_rec = good_entry['colmap_binary']
|
| 1407 |
|
| 1408 |
-
|
|
|
|
|
|
|
|
|
|
| 1409 |
|
| 1410 |
vertex_threshold = config.get('vertex_threshold', 0.5)
|
| 1411 |
edge_threshold = config.get('edge_threshold', 0.5)
|
|
@@ -1415,8 +1646,6 @@ def predict_wireframe(entry, pnet_model, voxel_model, pnet_class_model, config)
|
|
| 1415 |
idxs_points = []
|
| 1416 |
all_connections = []
|
| 1417 |
|
| 1418 |
-
print(f"Processing {len(good_entry['gestalt'])} images")
|
| 1419 |
-
|
| 1420 |
for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(zip(good_entry['gestalt'],
|
| 1421 |
good_entry['depth'],
|
| 1422 |
good_entry['K'],
|
|
@@ -1425,6 +1654,7 @@ def predict_wireframe(entry, pnet_model, voxel_model, pnet_class_model, config)
|
|
| 1425 |
good_entry['image_ids'],
|
| 1426 |
good_entry['ade'] # Added ade20k segmentation
|
| 1427 |
)):
|
|
|
|
| 1428 |
# Visualize gestalt segmentation
|
| 1429 |
K = np.array(K)
|
| 1430 |
R = np.array(R)
|
|
@@ -1436,107 +1666,35 @@ def predict_wireframe(entry, pnet_model, voxel_model, pnet_class_model, config)
|
|
| 1436 |
gest_seg_np = np.array(gest_seg).astype(np.uint8)
|
| 1437 |
|
| 1438 |
vertices_ours, connections_ours, vertices_3d_ours, patches, filtered_point_idxs = our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id, ade_seg, depth, K=K, R=R, t=t, frame=good_entry)
|
|
|
|
| 1439 |
idxs_points.append(filtered_point_idxs)
|
| 1440 |
all_connections.append(connections_ours)
|
| 1441 |
|
| 1442 |
-
'''
|
| 1443 |
-
if GENERATE_DATASET:
|
| 1444 |
-
save_patches_dataset(patches, DATASET_DIR, img_id)
|
| 1445 |
-
continue
|
| 1446 |
-
'''
|
| 1447 |
-
#for idx, patch in enumerate(patches):
|
| 1448 |
-
#pred_vertex, pred_dist, pred_class = predict_vertex_from_patch(pnet_model, patch, device=device)
|
| 1449 |
-
|
| 1450 |
-
#vertices_3d_ours[idx] = pred_vertex
|
| 1451 |
-
|
| 1452 |
-
#visu_patch_and_pred(patch, pred_vertex, pred_dist, pred_class)
|
| 1453 |
-
|
| 1454 |
-
# x = 0
|
| 1455 |
-
|
| 1456 |
vertices, connections, vertices_3d = vertices_ours, connections_ours, vertices_3d_ours
|
| 1457 |
-
# Get 2D vertices and edges first
|
| 1458 |
-
#vertices, connections = get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th=25.)
|
| 1459 |
-
|
| 1460 |
-
#gt_verts = []
|
| 1461 |
-
#gt_verts, gt_connects, gt_verts3d = get_gt_vertices_and_edges(good_entry, i, depth, colmap_rec, K, R, t, img_id, ade_seg)
|
| 1462 |
-
#vertices, connections = gt_verts, gt_connects
|
| 1463 |
-
|
| 1464 |
-
if False:
|
| 1465 |
-
gest.save(f'gestalt/{img_id}.png')
|
| 1466 |
-
# Save ADE20k segmentation
|
| 1467 |
-
# ade_seg is already a PIL Image
|
| 1468 |
-
try:
|
| 1469 |
-
ade_seg.save(f'ade_segmentations/{img_id}_ade.png')
|
| 1470 |
-
except Exception as e:
|
| 1471 |
-
print(f"Could not save ADE segmentation for {img_id}: {e}")
|
| 1472 |
-
save_gestalt_with_proj(gest_seg_np, gt_verts, img_id)
|
| 1473 |
-
# Define a local helper function to draw crosses and save the image
|
| 1474 |
-
|
| 1475 |
-
# Draw crosses on the ADE segmentation image and save it
|
| 1476 |
-
# 'vertices' here refers to gt_verts
|
| 1477 |
-
draw_crosses_on_image(ade_seg, vertices, f'crosses_{img_id}.png', color=(0, 0, 0), size=5)
|
| 1478 |
-
|
| 1479 |
-
# Check if we have enough to proceed
|
| 1480 |
-
if (len(vertices) < 2) or (len(connections) < 1) and False:
|
| 1481 |
-
print(f'Not enough vertices or connections found in image {i}, skipping.')
|
| 1482 |
-
vert_edge_per_image[i] = [], [], np.empty((0, 3))
|
| 1483 |
-
continue
|
| 1484 |
-
|
| 1485 |
-
# Call the refactored function to get 3D points
|
| 1486 |
-
#vertices_3d = create_3d_wireframe_single_image(vertices, connections, depth, colmap_rec, img_id, ade_seg, K, R, t)
|
| 1487 |
-
#vertices_3d = gt_verts3d
|
| 1488 |
-
# Store original 2D vertices, connections, and computed 3D points
|
| 1489 |
-
#connections = []
|
| 1490 |
-
|
| 1491 |
-
if False:
|
| 1492 |
-
pcd, geometries = plot_reconstruction_local(None, colmap_rec, points=True, cameras=True, crop_outliers=True)
|
| 1493 |
-
wireframe = plot_wireframe_local(None, good_entry['wf_vertices'], good_entry['wf_edges'], good_entry['wf_classifications'])
|
| 1494 |
-
wireframe2 = plot_wireframe_local(None, vertices_3d_ours, connections_ours, None, color='rgb(255, 0, 0)')
|
| 1495 |
-
wireframe3 = plot_wireframe_local(None, vertices_3d, connections, None, color='rgb(0, 0, 255)')
|
| 1496 |
-
bpo_cams = plot_bpo_cameras_from_entry_local(None, good_entry)
|
| 1497 |
-
|
| 1498 |
-
visu_all = [pcd] + geometries + wireframe + bpo_cams + wireframe2 + wireframe3
|
| 1499 |
-
#o3d.visualization.draw_geometries(visu_all, window_name="3D Reconstruction")
|
| 1500 |
|
| 1501 |
vert_edge_per_image[i] = vertices, connections, vertices_3d
|
| 1502 |
-
|
| 1503 |
extracted_points, extracted_colors, extracted_ids, whole_pcloud, connections = extract_vertices_from_whole_pcloud(colmap_rec, idxs_points, all_connections)
|
| 1504 |
|
| 1505 |
-
|
|
|
|
|
|
|
| 1506 |
|
| 1507 |
-
# Predict vertices from patches using the neural network
|
| 1508 |
predicted_vertices = []
|
| 1509 |
-
for patch in patches:
|
| 1510 |
pred_vertex, pred_dist, pred_class = predict_vertex_from_patch(pnet_model, patch, device=device)
|
| 1511 |
|
| 1512 |
-
#visu_patch_and_pred(patch, pred_vertex, pred_dist, pred_class)
|
| 1513 |
-
|
| 1514 |
if pred_class > vertex_threshold:
|
| 1515 |
predicted_vertices.append(pred_vertex)
|
| 1516 |
else:
|
| 1517 |
predicted_vertices.append(np.array([0.0, 0.0, 0.0])) # Append a zero vertex if not predicted
|
| 1518 |
-
|
| 1519 |
-
#pred_vertex_voxel, pred_dist_voxel, pred_class_voxel = predict_vertex_from_patch_voxel(voxel_model, patch, device=device)
|
| 1520 |
-
#visu_patch_and_pred(patch, pred_vertex_voxel, pred_dist_voxel, pred_class_voxel)
|
| 1521 |
|
| 1522 |
predicted_vertices = np.array(predicted_vertices) if predicted_vertices else np.empty((0, 3))
|
| 1523 |
|
| 1524 |
-
#visu_pcloud_and_preds(colmap_rec, extracted_ids, extracted_points, extracted_colors, predicted_vertices, connections)
|
| 1525 |
-
|
| 1526 |
if GENERATE_DATASET:
|
| 1527 |
save_patches_dataset(patches, DATASET_DIR, img_id)
|
| 1528 |
return empty_solution()
|
| 1529 |
|
| 1530 |
-
# Merge vertices from all images
|
| 1531 |
-
#all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 0.1)
|
| 1532 |
-
#all_3d_vertices_clean, connections_3d_clean = all_3d_vertices, connections_3d
|
| 1533 |
-
#all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d, keep_largest=False)
|
| 1534 |
-
#all_3d_vertices_clean, connections_3d_clean = prune_too_far(all_3d_vertices_clean, connections_3d_clean, colmap_rec, th = 1.5)
|
| 1535 |
-
|
| 1536 |
-
#if (len(all_3d_vertices_clean) < 2) or len(connections_3d_clean) < 1 and False:
|
| 1537 |
-
# print (f'Not enough vertices or connections in the 3D vertices')
|
| 1538 |
-
# return empty_solution()
|
| 1539 |
-
|
| 1540 |
# Filter out zero vertices and update connections accordingly
|
| 1541 |
non_zero_mask = ~np.all(np.isclose(predicted_vertices, [0.0, 0.0, 0.0]), axis=1)
|
| 1542 |
valid_indices = np.where(non_zero_mask)[0]
|
|
@@ -1544,11 +1702,9 @@ def predict_wireframe(entry, pnet_model, voxel_model, pnet_class_model, config)
|
|
| 1544 |
# Filter vertices to only include non-zero ones
|
| 1545 |
filtered_vertices = predicted_vertices[valid_indices]
|
| 1546 |
|
| 1547 |
-
#patches = generate_edge_patches(good_entry, filtered_vertices, colmap_pcloud)
|
| 1548 |
if GENERATE_DATASET_EDGES:
|
| 1549 |
patches = generate_edge_patches(good_entry, filtered_vertices, colmap_pcloud)
|
| 1550 |
save_patches_dataset_class(patches, EDGES_DATASET_DIR, good_entry['order_id'])
|
| 1551 |
-
|
| 1552 |
return empty_solution()
|
| 1553 |
|
| 1554 |
if len(valid_indices) == 0:
|
|
@@ -1566,17 +1722,18 @@ def predict_wireframe(entry, pnet_model, voxel_model, pnet_class_model, config)
|
|
| 1566 |
new_end = old_to_new_mapping[end_idx]
|
| 1567 |
if new_start != new_end: # Ensure we don't connect a vertex to itself
|
| 1568 |
filtered_connections.append((new_start, new_end))
|
| 1569 |
-
|
| 1570 |
-
#print(f"Filtered vertices: {len(filtered_vertices)} from {len(predicted_vertices)}")
|
| 1571 |
-
#print(f"Filtered connections: {len(filtered_connections)} from {len(connections)}")
|
| 1572 |
|
|
|
|
|
|
|
| 1573 |
forward_patches = generate_edge_patches_forward(good_entry, filtered_vertices)
|
|
|
|
| 1574 |
new_connections = []
|
| 1575 |
if len(forward_patches) > 0:
|
| 1576 |
-
for patch in forward_patches:
|
| 1577 |
start_idx, end_idx = patch['connection']
|
| 1578 |
|
| 1579 |
pred_class, pred_score = predict_class_from_patch(pnet_class_model, patch, device=device)
|
|
|
|
| 1580 |
if pred_score > edge_threshold:
|
| 1581 |
new_connections.append((start_idx, end_idx))
|
| 1582 |
|
|
|
|
| 16 |
#import time
|
| 17 |
from fast_pointnet_class import save_patches_dataset as save_patches_dataset_class
|
| 18 |
from fast_pointnet_class import predict_class_from_patch
|
| 19 |
+
from fast_pointnet_class_10d import predict_class_from_patch as predict_class_from_patch_10d
|
| 20 |
from scipy.spatial.distance import cdist
|
| 21 |
from scipy.optimize import linear_sum_assignment
|
| 22 |
import torch
|
| 23 |
+
import time
|
| 24 |
|
| 25 |
GENERATE_DATASET = False
|
| 26 |
DATASET_DIR = '/home/skvrnjan/personal/hohocustom/'
|
|
|
|
| 1181 |
|
| 1182 |
return forward_patches
|
| 1183 |
|
| 1184 |
+
def generate_edge_patches_forward_10d(frame, pred_vertices, colmap_pcloud,
                                      cylinder_radius=0.5, extension_length=0.25,
                                      min_points=10):
    """Build unlabeled 10D edge patches for every pair of predicted vertices.

    For each vertex pair (i, j) with i < j, the points of the COLMAP cloud that
    fall inside a cylinder around the segment between the two vertices are
    collected, shifted so the segment midpoint is the origin, and returned as
    a candidate-edge patch.

    Args:
        frame: Unused; kept so the signature stays parallel to the other
            patch generators in this file.
        pred_vertices: Array-like of shape (V, 3) with the predicted vertices.
        colmap_pcloud: Dict with
            'points_7d' - (N, 7) array, columns 0:3 are xyz and 3:6 are rgb
                          (rescaled here with ``c * 2 - 1``),
            'ade'       - length-N truthy/falsy values, mapped to +1 / -1,
            'gestalt'   - length-N sequence; each item is a list of RGB
                          triples observed for that point.
        cylinder_radius: Radius of the selection cylinder around each segment.
        extension_length: How far the segment is extended beyond both
            vertices before selecting points.
        min_points: A pair is skipped when its cylinder holds this many
            points or fewer.

    Returns:
        List of dicts with keys 'patch_10d' (xyz centred on the midpoint,
        rgb, ade, gestalt), 'connection' ((i, j) indices into
        ``pred_vertices``), 'line_start', 'line_end' (vertices relative to the
        midpoint), 'cylinder_radius', 'point_indices' (row indices into
        'points_7d') and 'center' (the midpoint).

    The input ``colmap_pcloud`` is never modified.
    """
    vertices = np.asarray(pred_vertices, dtype=float)

    # Copy before normalising: a plain slice of 'points_7d' is a view, so an
    # in-place colour rescale would rewrite the caller's cloud on every call.
    points_7d = np.array(colmap_pcloud['points_7d'], dtype=float)[:, :7]
    num_points = len(points_7d)

    # Nothing to pair up or nothing to select from.
    if num_points == 0 or len(vertices) < 2:
        return []

    colmap_points_3d = points_7d[:, :3]
    colors = points_7d[:, 3:6] * 2 - 1  # rgb rescaled with c * 2 - 1
    ade = np.where(colmap_pcloud['ade'], 1, -1)  # truthy -> 1, falsy -> -1

    # Fuse the per-image gestalt observations of each point by averaging them
    # (a point without observations gets black).
    fused_gestalt = []
    for observations in colmap_pcloud['gestalt']:
        if len(observations) == 0:
            fused_gestalt.append(np.array([0, 0, 0]))
        elif len(observations) == 1:
            fused_gestalt.append(observations[0])
        else:
            fused_gestalt.append(
                np.mean(np.array(observations), axis=0).astype(np.uint8)
            )
    gestalt = (np.array(fused_gestalt) / 255) * 2 - 1  # 0..255 -> [-1, 1]

    # Combined 10D cloud: xyz + rgb + ade + gestalt.
    colmap_points_10d = np.zeros((num_points, 10))
    colmap_points_10d[:, :3] = colmap_points_3d
    colmap_points_10d[:, 3:6] = colors
    colmap_points_10d[:, 6] = ade
    colmap_points_10d[:, 7:10] = gestalt

    forward_patches = []

    for i in range(len(vertices)):
        start_vertex = vertices[i]
        for j in range(i + 1, len(vertices)):
            end_vertex = vertices[j]

            line_vector = end_vertex - start_vertex
            line_length = np.linalg.norm(line_vector)

            # Coincident (or NaN) vertices define no direction; skip them
            # explicitly instead of dividing by zero.
            if not line_length > 0:
                continue

            line_direction = line_vector / line_length

            # Extend the segment on both ends for a little extra context.
            extended_start = start_vertex - extension_length * line_direction
            extended_line_length = line_length + 2 * extension_length

            # Signed distance of every point along the extended segment.
            start_to_points = colmap_points_3d - extended_start[np.newaxis, :]
            projection_lengths = np.dot(start_to_points, line_direction)
            within_bounds = (projection_lengths >= 0) & (projection_lengths <= extended_line_length)

            # Perpendicular distance of every point to the segment's line.
            closest_points_on_line = (
                extended_start[np.newaxis, :]
                + projection_lengths[:, np.newaxis] * line_direction[np.newaxis, :]
            )
            perpendicular_distances = np.linalg.norm(
                colmap_points_3d - closest_points_on_line, axis=1
            )

            within_cylinder = within_bounds & (perpendicular_distances <= cylinder_radius)

            if np.count_nonzero(within_cylinder) <= min_points:
                continue

            # Centre on the midpoint of the original (not extended) segment.
            line_midpoint = (start_vertex + end_vertex) / 2

            # Boolean indexing already returns a fresh array, safe to edit.
            points_centered = colmap_points_10d[within_cylinder]
            points_centered[:, :3] -= line_midpoint

            forward_patches.append({
                'patch_10d': points_centered,
                'connection': (i, j),
                'line_start': start_vertex - line_midpoint,
                'line_end': end_vertex - line_midpoint,
                'cylinder_radius': cylinder_radius,
                'point_indices': np.flatnonzero(within_cylinder),
                'center': line_midpoint,
            })

    return forward_patches
|
| 1291 |
+
|
| 1292 |
def calculate_cylinder_overlap_volume(cyl1, cyl2):
|
| 1293 |
"""
|
| 1294 |
Calculate the intersection volume between two cylinders using numpy vectorization.
|
|
|
|
| 1390 |
return max(0.0, overlap_volume)
|
| 1391 |
|
| 1392 |
def create_pcloud(colmap_rec, frame):
    """Assemble a per-point feature cloud from a COLMAP reconstruction and a frame.

    Frame images are matched to COLMAP images by comparing the frame's
    ``image_ids`` with each COLMAP image's ``name``. For every matched image
    the 3D points it observes (``has_point3D``) are projected with the frame's
    K, R, t; the house mask (``get_house_mask`` of the ADE image) and the
    gestalt segmentation (resized to the depth map's size with nearest
    neighbour) are sampled at the rounded pixel position.

    Args:
        colmap_rec: Reconstruction exposing ``images`` and ``points3D`` maps.
        frame: Dict with parallel sequences under 'K', 'R', 't', 'image_ids',
            'ade', 'gestalt' and 'depth'.

    Returns:
        Dict with keys
        'points_7d'    - (N, 7) array: xyz, rgb (colour / 255), point id,
                         rows sorted by point id,
        'imgs'         - per point, the image ids that observe it,
        'uv'           - per point, one rounded (u, v) per observing image
                         ((-1, -1) when the point is not in front of the camera),
        'all_imgs_ids' - names of all COLMAP images,
        'all_imgs_K', 'all_imgs_R', 'all_imgs_t' - matrices of the matched
                         frame images in frame order (t reshaped to (3, 1)),
        'ade'          - bool array; the mask value from the last processed
                         image in which the point projected inside the image
                         (False if there was none),
        'gestalt'      - per point, one uint8 RGB triple per observing image
                         (zeros when the projection was not usable).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --- 1. Per-image inputs, prepared on the CPU ---------------------------
    colmap_image_by_name = {
        image.name: image for _, image in colmap_rec.images.items()
    }

    frame_img_data = {}
    image_order = []
    frame_columns = zip(
        frame['K'], frame['R'], frame['t'], frame['image_ids'],
        frame['ade'], frame['gestalt'], frame['depth']
    )
    for K_raw, R_raw, t_raw, image_name, ade_img, gestalt_img, depth_img in frame_columns:
        # Frame images unknown to COLMAP contribute nothing.
        if image_name not in colmap_image_by_name:
            continue

        image_order.append(image_name)
        depth_arr = np.array(depth_img)
        height, width = depth_arr.shape[0], depth_arr.shape[1]

        house_mask = get_house_mask(ade_img)

        gestalt_resized = gestalt_img.resize((width, height), Image.Resampling.NEAREST)

        frame_img_data[image_name] = {
            'K_np': np.array(K_raw),
            'R_np': np.array(R_raw),
            't_np': np.array(t_raw).reshape(3, 1),
            'ade_mask_np': house_mask,
            'gestalt_seg_np': np.array(gestalt_resized).astype(np.uint8),
            'H': height,
            'W': width,
        }

    # --- 2. Project the observed points image by image ----------------------
    accumulator = {}  # point id -> everything gathered for that point

    # Position and colour of every COLMAP point, fetched once up front.
    colmap_point_attrs = {
        pid: {'xyz': p3D.xyz, 'color': p3D.color / 255.0}
        for pid, p3D in colmap_rec.points3D.items()
    }

    for image_name in image_order:
        if image_name not in frame_img_data:
            continue

        colmap_image = colmap_image_by_name[image_name]
        info = frame_img_data[image_name]
        height, width = info['H'], info['W']

        # Move this image's data to the compute device.
        K_dev = torch.from_numpy(info['K_np']).float().to(device)
        R_dev = torch.from_numpy(info['R_np']).float().to(device)
        t_dev = torch.from_numpy(info['t_np']).float().to(device)
        mask_dev = torch.from_numpy(info['ade_mask_np']).bool().to(device)
        gestalt_dev = torch.from_numpy(info['gestalt_seg_np']).to(device)

        # Points observed by this image (checked one by one on the CPU).
        seen_pids = [pid for pid in colmap_point_attrs if colmap_image.has_point3D(pid)]
        if not seen_pids:
            continue

        count = len(seen_pids)
        seen_xyz = np.array([colmap_point_attrs[pid]['xyz'] for pid in seen_pids])
        xyz_dev = torch.from_numpy(seen_xyz).float().to(device)

        # Batched world -> camera -> pixel projection.
        xyz_h = torch.cat((xyz_dev, torch.ones(count, 1, device=device)), dim=1)
        cam_coords = torch.hstack((R_dev, t_dev)) @ xyz_h.T
        cam_z = cam_coords[2, :]
        in_front = cam_z > 1e-6
        pixels_h = K_dev @ cam_coords

        # Points that cannot be divided safely keep the placeholder -1.
        u_proj = torch.full_like(cam_z, -1.0, dtype=torch.float32)
        v_proj = torch.full_like(cam_z, -1.0, dtype=torch.float32)
        dividable = in_front & (torch.abs(cam_z) > 1e-6)
        if torch.any(dividable):
            u_proj[dividable] = pixels_h[0, dividable] / cam_z[dividable]
            v_proj[dividable] = pixels_h[1, dividable] / cam_z[dividable]

        u_pix = torch.round(u_proj).long()
        v_pix = torch.round(v_proj).long()

        usable = (
            (u_pix >= 0) & (u_pix < width)
            & (v_pix >= 0) & (v_pix < height)
            & in_front
        )

        # Defaults for every point; overwritten where the projection is usable.
        ade_hit = torch.zeros(count, dtype=torch.bool, device=device)
        gestalt_hit = torch.zeros(count, 3, dtype=torch.uint8, device=device)
        if torch.any(usable):
            rows = v_pix[usable]
            cols = u_pix[usable]
            ade_hit[usable] = mask_dev[rows, cols]
            gestalt_hit[usable] = gestalt_dev[rows, cols]

        # Back to the CPU for the per-point bookkeeping.
        u_cpu = u_pix.cpu().numpy()
        v_cpu = v_pix.cpu().numpy()
        usable_cpu = usable.cpu().numpy()
        ade_cpu = ade_hit.cpu().numpy()
        gestalt_cpu = gestalt_hit.cpu().numpy()

        for idx, pid in enumerate(seen_pids):
            entry = accumulator.get(pid)
            if entry is None:
                entry = {
                    'xyz': colmap_point_attrs[pid]['xyz'],
                    'color': colmap_point_attrs[pid]['color'],
                    'imgs_seen_by': [],
                    'uv_projections': [],
                    'ade_status': False,
                    'gestalt_values': [],
                }
                accumulator[pid] = entry

            entry['imgs_seen_by'].append(image_name)
            entry['uv_projections'].append((u_cpu[idx], v_cpu[idx]))

            if usable_cpu[idx]:
                # The most recent usable projection decides the ADE status.
                entry['ade_status'] = ade_cpu[idx]
                entry['gestalt_values'].append(gestalt_cpu[idx])
            else:
                entry['gestalt_values'].append(np.array([0, 0, 0], dtype=np.uint8))

    # --- 3. Final assembly on the CPU ---------------------------------------
    # Sorting the ids makes the row order independent of insertion order.
    ordered_pids = sorted(accumulator.keys())
    records = [accumulator[pid] for pid in ordered_pids]

    xyz_rows = [rec['xyz'] for rec in records]
    color_rows = [rec['color'] for rec in records]
    seen_by = [rec['imgs_seen_by'] for rec in records]
    uv_lists = [rec['uv_projections'] for rec in records]
    ade_flags = [rec['ade_status'] for rec in records]
    gestalt_lists = [rec['gestalt_values'] for rec in records]

    points_xyz_world = np.array(xyz_rows) if xyz_rows else np.empty((0, 3))
    points_colors = np.array(color_rows) if color_rows else np.empty((0, 3))
    points_idxs = np.array(ordered_pids, dtype=int) if ordered_pids else np.empty((0,), dtype=int)
    points_ade = np.array(ade_flags, dtype=bool) if ade_flags else np.empty((0,), dtype=bool)

    all_colmap_names = [image.name for _, image in colmap_rec.images.items()]
    matched = [
        frame_img_data[image_name]
        for image_name in frame['image_ids']
        if image_name in frame_img_data
    ]
    frame_K = [item['K_np'] for item in matched]
    frame_R = [item['R_np'] for item in matched]
    frame_t = [item['t_np'] for item in matched]

    # With zero points every array and list below is simply empty, so a single
    # construction path covers both the populated and the empty case.
    colmap_points_7d = np.zeros((points_xyz_world.shape[0], 7))
    colmap_points_7d[:, :3] = points_xyz_world
    colmap_points_7d[:, 3:6] = points_colors
    colmap_points_7d[:, 6] = points_idxs

    whole_pcloud = {
        'points_7d': colmap_points_7d,
        'imgs': seen_by,
        'uv': uv_lists,
        'all_imgs_ids': all_colmap_names,
        'all_imgs_K': frame_K,
        'all_imgs_R': frame_R,
        'all_imgs_t': frame_t,
        'ade': points_ade,
        'gestalt': gestalt_lists,
    }
    return whole_pcloud
|
| 1620 |
|
| 1621 |
def predict_wireframe(entry, pnet_model, voxel_model, pnet_class_model, config) -> Tuple[np.ndarray, List[int]]:
|
|
|
|
| 1624 |
"""
|
| 1625 |
|
| 1626 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 1627 |
+
# Optional step (currently disabled): delete the specified keys from the entry
|
| 1628 |
+
#keys_to_delete = ['wf_vertices', 'wf_edges', 'wf_classifications']
|
| 1629 |
+
#for key in keys_to_delete:
|
| 1630 |
+
# if key in entry:
|
| 1631 |
+
# del entry[key]
|
| 1632 |
|
| 1633 |
good_entry = convert_entry_to_human_readable(entry)
|
| 1634 |
colmap_rec = good_entry['colmap_binary']
|
| 1635 |
|
| 1636 |
+
#start_time = time.time()
|
| 1637 |
+
#colmap_pcloud = create_pcloud(colmap_rec, good_entry)
|
| 1638 |
+
#end_time = time.time()
|
| 1639 |
+
#print(f"create_pcloud took {end_time - start_time:.4f} seconds")
|
| 1640 |
|
| 1641 |
vertex_threshold = config.get('vertex_threshold', 0.5)
|
| 1642 |
edge_threshold = config.get('edge_threshold', 0.5)
|
|
|
|
| 1646 |
idxs_points = []
|
| 1647 |
all_connections = []
|
| 1648 |
|
|
|
|
|
|
|
| 1649 |
for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(zip(good_entry['gestalt'],
|
| 1650 |
good_entry['depth'],
|
| 1651 |
good_entry['K'],
|
|
|
|
| 1654 |
good_entry['image_ids'],
|
| 1655 |
good_entry['ade'] # Added ade20k segmentation
|
| 1656 |
)):
|
| 1657 |
+
|
| 1658 |
# Visualize gestalt segmentation
|
| 1659 |
K = np.array(K)
|
| 1660 |
R = np.array(R)
|
|
|
|
| 1666 |
gest_seg_np = np.array(gest_seg).astype(np.uint8)
|
| 1667 |
|
| 1668 |
vertices_ours, connections_ours, vertices_3d_ours, patches, filtered_point_idxs = our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id, ade_seg, depth, K=K, R=R, t=t, frame=good_entry)
|
| 1669 |
+
|
| 1670 |
idxs_points.append(filtered_point_idxs)
|
| 1671 |
all_connections.append(connections_ours)
|
| 1672 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1673 |
vertices, connections, vertices_3d = vertices_ours, connections_ours, vertices_3d_ours
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1674 |
|
| 1675 |
vert_edge_per_image[i] = vertices, connections, vertices_3d
|
| 1676 |
+
|
| 1677 |
extracted_points, extracted_colors, extracted_ids, whole_pcloud, connections = extract_vertices_from_whole_pcloud(colmap_rec, idxs_points, all_connections)
|
| 1678 |
|
| 1679 |
+
wf_vertices = good_entry.get('wf_vertices', None)
|
| 1680 |
+
|
| 1681 |
+
patches = generate_patches_v2(extracted_points, extracted_colors, extracted_ids, whole_pcloud, wf_vertices)
|
| 1682 |
|
|
|
|
| 1683 |
predicted_vertices = []
|
| 1684 |
+
for i, patch in enumerate(patches):
|
| 1685 |
pred_vertex, pred_dist, pred_class = predict_vertex_from_patch(pnet_model, patch, device=device)
|
| 1686 |
|
|
|
|
|
|
|
| 1687 |
if pred_class > vertex_threshold:
|
| 1688 |
predicted_vertices.append(pred_vertex)
|
| 1689 |
else:
|
| 1690 |
predicted_vertices.append(np.array([0.0, 0.0, 0.0])) # Append a zero vertex if not predicted
|
|
|
|
|
|
|
|
|
|
| 1691 |
|
| 1692 |
predicted_vertices = np.array(predicted_vertices) if predicted_vertices else np.empty((0, 3))
|
| 1693 |
|
|
|
|
|
|
|
| 1694 |
if GENERATE_DATASET:
|
| 1695 |
save_patches_dataset(patches, DATASET_DIR, img_id)
|
| 1696 |
return empty_solution()
|
| 1697 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1698 |
# Filter out zero vertices and update connections accordingly
|
| 1699 |
non_zero_mask = ~np.all(np.isclose(predicted_vertices, [0.0, 0.0, 0.0]), axis=1)
|
| 1700 |
valid_indices = np.where(non_zero_mask)[0]
|
|
|
|
| 1702 |
# Filter vertices to only include non-zero ones
|
| 1703 |
filtered_vertices = predicted_vertices[valid_indices]
|
| 1704 |
|
|
|
|
| 1705 |
if GENERATE_DATASET_EDGES:
|
| 1706 |
patches = generate_edge_patches(good_entry, filtered_vertices, colmap_pcloud)
|
| 1707 |
save_patches_dataset_class(patches, EDGES_DATASET_DIR, good_entry['order_id'])
|
|
|
|
| 1708 |
return empty_solution()
|
| 1709 |
|
| 1710 |
if len(valid_indices) == 0:
|
|
|
|
| 1722 |
new_end = old_to_new_mapping[end_idx]
|
| 1723 |
if new_start != new_end: # Ensure we don't connect a vertex to itself
|
| 1724 |
filtered_connections.append((new_start, new_end))
|
|
|
|
|
|
|
|
|
|
| 1725 |
|
| 1726 |
+
# Generate forward edge patches
|
| 1727 |
+
#forward_patches = generate_edge_patches_forward_10d(good_entry, filtered_vertices, colmap_pcloud)
|
| 1728 |
forward_patches = generate_edge_patches_forward(good_entry, filtered_vertices)
|
| 1729 |
+
|
| 1730 |
new_connections = []
|
| 1731 |
if len(forward_patches) > 0:
|
| 1732 |
+
for i, patch in enumerate(forward_patches):
|
| 1733 |
start_idx, end_idx = patch['connection']
|
| 1734 |
|
| 1735 |
pred_class, pred_score = predict_class_from_patch(pnet_class_model, patch, device=device)
|
| 1736 |
+
|
| 1737 |
if pred_score > edge_threshold:
|
| 1738 |
new_connections.append((start_idx, end_idx))
|
| 1739 |
|
train.py
CHANGED
|
@@ -17,11 +17,13 @@ from tqdm import tqdm
|
|
| 17 |
from fast_pointnet import load_pointnet_model
|
| 18 |
from fast_voxel import load_3dcnn_model
|
| 19 |
from fast_pointnet_class import load_pointnet_model as load_pointnet_class_model
|
|
|
|
| 20 |
import torch
|
|
|
|
| 21 |
|
| 22 |
-
|
| 23 |
-
ds = load_dataset("usm3d/hoho25k", cache_dir="/mnt/personal/skvrnjan/hoho25k/", trust_remote_code=True)
|
| 24 |
-
ds = ds.shuffle()
|
| 25 |
|
| 26 |
scores_hss = []
|
| 27 |
scores_f1 = []
|
|
@@ -31,12 +33,13 @@ show_visu = False
|
|
| 31 |
|
| 32 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 33 |
|
| 34 |
-
|
| 35 |
-
pnet_model = load_pointnet_model(model_path="/mnt/personal/skvrnjan/hoho_pnet/initial_epoch_100.pth", device=device, predict_score=True)
|
| 36 |
#pnet_model = None
|
| 37 |
|
| 38 |
#pnet_class_model = load_pointnet_class_model(model_path="/home/skvrnjan/personal/hoho_pnet_edges_v2/initial_epoch_100.pth", device=device)
|
| 39 |
-
pnet_class_model =
|
|
|
|
| 40 |
#pnet_class_model = None
|
| 41 |
|
| 42 |
#voxel_model = load_3dcnn_model(model_path="/home/skvrnjan/personal/hoho_voxel/initial_epoch_100.pth", device=device, predict_score=True)
|
|
@@ -45,13 +48,20 @@ voxel_model = None
|
|
| 45 |
config = {'vertex_threshold': 0.4, 'edge_threshold': 0.6, 'only_predicted_connections': False}
|
| 46 |
|
| 47 |
idx = 0
|
|
|
|
| 48 |
for a in tqdm(ds['train'], desc="Processing dataset"):
|
| 49 |
#plot_all_modalities(a)
|
| 50 |
#pred_vertices, pred_edges = predict_wireframe_old(a)
|
| 51 |
-
|
| 52 |
try:
|
| 53 |
-
|
|
|
|
| 54 |
#pred_vertices, pred_edges = predict_wireframe_old(a)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
except:
|
| 56 |
pred_vertices, pred_edges = empty_solution()
|
| 57 |
|
|
@@ -72,8 +82,8 @@ for a in tqdm(ds['train'], desc="Processing dataset"):
|
|
| 72 |
o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")
|
| 73 |
|
| 74 |
idx += 1
|
| 75 |
-
|
| 76 |
-
|
| 77 |
|
| 78 |
for i in range(10):
|
| 79 |
print("END OF DATASET")
|
|
|
|
| 17 |
from fast_pointnet import load_pointnet_model
|
| 18 |
from fast_voxel import load_3dcnn_model
|
| 19 |
from fast_pointnet_class import load_pointnet_model as load_pointnet_class_model
|
| 20 |
+
from fast_pointnet_class_10d import load_pointnet_model as load_pointnet_class_model_10d
|
| 21 |
import torch
|
| 22 |
+
import time
|
| 23 |
|
| 24 |
+
ds = load_dataset("usm3d/hoho25k", cache_dir="/media/skvrnjan/sd/hoho25k/", trust_remote_code=True)
|
| 25 |
+
#ds = load_dataset("usm3d/hoho25k", cache_dir="/mnt/personal/skvrnjan/hoho25k/", trust_remote_code=True)
|
| 26 |
+
#ds = ds.shuffle()
|
| 27 |
|
| 28 |
scores_hss = []
|
| 29 |
scores_f1 = []
|
|
|
|
| 33 |
|
| 34 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 35 |
|
| 36 |
+
pnet_model = load_pointnet_model(model_path="pnet.pth", device=device, predict_score=True)
|
| 37 |
+
#pnet_model = load_pointnet_model(model_path="/mnt/personal/skvrnjan/hoho_pnet/initial_epoch_100.pth", device=device, predict_score=True)
|
| 38 |
#pnet_model = None
|
| 39 |
|
| 40 |
#pnet_class_model = load_pointnet_class_model(model_path="/home/skvrnjan/personal/hoho_pnet_edges_v2/initial_epoch_100.pth", device=device)
|
| 41 |
+
#pnet_class_model = load_pointnet_class_model_10d(model_path="/home/skvrnjan/personal/hoho_pnet_edges_10d/initial_epoch_75.pth", device=device)
|
| 42 |
+
pnet_class_model = load_pointnet_class_model(model_path="pnet_class.pth", device=device)
|
| 43 |
#pnet_class_model = None
|
| 44 |
|
| 45 |
#voxel_model = load_3dcnn_model(model_path="/home/skvrnjan/personal/hoho_voxel/initial_epoch_100.pth", device=device, predict_score=True)
|
|
|
|
| 48 |
config = {'vertex_threshold': 0.4, 'edge_threshold': 0.6, 'only_predicted_connections': False}
|
| 49 |
|
| 50 |
idx = 0
|
| 51 |
+
prediction_times = []
|
| 52 |
for a in tqdm(ds['train'], desc="Processing dataset"):
|
| 53 |
#plot_all_modalities(a)
|
| 54 |
#pred_vertices, pred_edges = predict_wireframe_old(a)
|
| 55 |
+
pred_vertices, pred_edges = predict_wireframe(a.copy(), pnet_model, voxel_model, pnet_class_model, config)
|
| 56 |
try:
|
| 57 |
+
start_time = time.time()
|
| 58 |
+
pred_vertices, pred_edges = predict_wireframe(a.copy(), pnet_model, voxel_model, pnet_class_model, config)
|
| 59 |
#pred_vertices, pred_edges = predict_wireframe_old(a)
|
| 60 |
+
end_time = time.time()
|
| 61 |
+
prediction_time = end_time - start_time
|
| 62 |
+
prediction_times.append(prediction_time)
|
| 63 |
+
mean_time = np.mean(prediction_times)
|
| 64 |
+
print(f"Prediction time: {prediction_time:.4f} seconds, Mean time: {mean_time:.4f} seconds")
|
| 65 |
except:
|
| 66 |
pred_vertices, pred_edges = empty_solution()
|
| 67 |
|
|
|
|
| 82 |
o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")
|
| 83 |
|
| 84 |
idx += 1
|
| 85 |
+
if idx >= 100: # Limit to the first 100 samples for testing
|
| 86 |
+
break
|
| 87 |
|
| 88 |
for i in range(10):
|
| 89 |
print("END OF DATASET")
|