jskvrna committed on
Commit
1d64568
·
1 Parent(s): 9c5e3ff
Files changed (2) hide show
  1. predict.py +350 -193
  2. train.py +20 -10
predict.py CHANGED
@@ -16,9 +16,11 @@ from fast_pointnet import save_patches_dataset, predict_vertex_from_patch
16
  #import time
17
  from fast_pointnet_class import save_patches_dataset as save_patches_dataset_class
18
  from fast_pointnet_class import predict_class_from_patch
 
19
  from scipy.spatial.distance import cdist
20
  from scipy.optimize import linear_sum_assignment
21
  import torch
 
22
 
23
  GENERATE_DATASET = False
24
  DATASET_DIR = '/home/skvrnjan/personal/hohocustom/'
@@ -1179,6 +1181,114 @@ def generate_edge_patches_forward(frame, pred_vertices):
1179
 
1180
  return forward_patches
1181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1182
  def calculate_cylinder_overlap_volume(cyl1, cyl2):
1183
  """
1184
  Calculate the intersection volume between two cylinders using numpy vectorization.
@@ -1280,119 +1390,232 @@ def calculate_cylinder_overlap_volume(cyl1, cyl2):
1280
  return max(0.0, overlap_volume)
1281
 
1282
  def create_pcloud(colmap_rec, frame):
1283
- all_imgs_ids = []
1284
- all_imgs = []
1285
- all_imgs_K = []
1286
- all_imgs_R = []
1287
- all_imgs_t = []
1288
- all_imgs_ade = []
1289
- all_imgs_gestalt = []
 
 
 
 
 
 
 
 
 
 
1290
 
1291
- for img_id_c, col_img_obj in colmap_rec.images.items(): # Renamed col_img to col_img_obj to avoid conflict
1292
- all_imgs_ids.append(col_img_obj.name)
1293
- all_imgs.append(col_img_obj)
1294
-
1295
- for i, (K, R, t, img_id, ade, gestalt, depth) in enumerate(zip(frame['K'], frame['R'], frame['t'], frame['image_ids'], frame['ade'], frame['gestalt'], frame['depth'])):
1296
- for all_imgsid in all_imgs_ids:
1297
- if all_imgsid == img_id:
1298
- all_imgs_K.append(np.array(K))
1299
- all_imgs_R.append(np.array(R))
1300
- all_imgs_t.append(np.array(t))
1301
-
1302
- ade_mask = get_house_mask(ade)
1303
- all_imgs_ade.append(np.array(ade_mask))
1304
-
1305
- depth_size = (np.array(depth).shape[1], np.array(depth).shape[0]) # W, H
1306
- gest_seg = gestalt.resize(depth_size)
1307
- gest_seg_np = np.array(gest_seg).astype(np.uint8)
1308
- all_imgs_gestalt.append(np.array(gest_seg_np))
1309
 
1310
- # 2) Gather 3D points that this image sees (according to COLMAP)
1311
- points_xyz_world = []
1312
- points_colors = []
1313
- points_idxs = []
1314
- points_imgs = []
1315
- points_uv = []
1316
- points_ade = []
1317
- points_gestalt = []
 
 
 
 
 
 
1318
 
1319
- for pid, p3D in colmap_rec.points3D.items():
1320
- found = False
1321
- found_in_ids = []
1322
- uv_projections = []
1323
- in_ade = False
1324
- gest = []
1325
-
1326
- for idx, img in enumerate(all_imgs):
1327
- if img.has_point3D(pid):
1328
- found = True
1329
- found_in_ids.append(img.name)
1330
-
1331
- # Project the 3D point to image coordinates using K, R, t
1332
- R = all_imgs_R[idx]
1333
- t = all_imgs_t[idx]
1334
- K = all_imgs_K[idx]
1335
-
1336
- xyz_homogeneous = np.append(p3D.xyz, 1.0)
1337
- world_to_cam_mat = np.hstack([R, t.reshape(3, 1)])
1338
- cam_coords = world_to_cam_mat @ xyz_homogeneous
1339
- if cam_coords[2] > 0: # Point is in front of camera
1340
- pixel_coords = np.dot(K, cam_coords)
1341
- u = pixel_coords[0] / pixel_coords[2]
1342
- v = pixel_coords[1] / pixel_coords[2]
1343
- u = round(u)
1344
- v = round(v)
1345
- uv_projections.append((u, v))
1346
-
1347
- # Check if point is inside ADE segmentation (house mask)
1348
- if 0 <= u < all_imgs_ade[idx].shape[1] and 0 <= v < all_imgs_ade[idx].shape[0]:
1349
- in_ade = all_imgs_ade[idx][v, u] # Point is inside house mask
1350
- else:
1351
- in_ade = False # Default to False if out of bounds
1352
-
1353
- # Check gestalt segmentation value at this point
1354
- if 0 <= u < all_imgs_gestalt[idx].shape[1] and 0 <= v < all_imgs_gestalt[idx].shape[0]:
1355
- gestalt_value = all_imgs_gestalt[idx][v, u]
1356
- gest.append(gestalt_value)
1357
- else:
1358
- gest.append(np.array([0,0,0])) # Default value for out-of-bounds
1359
-
1360
- if found:
1361
- points_xyz_world.append(p3D.xyz) # world coords
1362
- points_colors.append(p3D.color / 255.0) # normalize to [0,1]
1363
- points_idxs.append(pid)
1364
- points_imgs.append(found_in_ids)
1365
- points_uv.append(uv_projections)
1366
- points_ade.append(in_ade)
1367
- points_gestalt.append(gest)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1368
 
1369
- points_xyz_world = np.array(points_xyz_world) if points_xyz_world else np.empty((0, 3))
1370
- points_colors = np.array(points_colors) if points_colors else np.empty((0, 3))
1371
- points_idxs = np.array(points_idxs) if points_idxs else np.empty((0,))
1372
- points_ade = np.array(points_ade) if points_ade else np.empty((0,))
1373
 
1374
- # Create 7D point cloud from COLMAP data (xyz + rgb + img_count)
1375
  if points_xyz_world.shape[0] > 0:
1376
- colmap_points_7d = np.zeros((len(points_xyz_world), 7))
1377
- colmap_points_7d[:, :3] = points_xyz_world # xyz coordinates
1378
- colmap_points_7d[:, 3:6] = points_colors # rgb colors
1379
- colmap_points_7d[:, 6] = points_idxs
1380
-
1381
- whole_pcloud = {'points_7d': colmap_points_7d,
1382
- 'imgs': points_imgs,
1383
- 'uv': points_uv,
1384
- 'all_imgs_ids': all_imgs_ids,
1385
- 'all_imgs_K': all_imgs_K,
1386
- 'all_imgs_R': all_imgs_R,
1387
- 'all_imgs_t': all_imgs_t,
1388
- 'ade': points_ade,
1389
- 'gestalt': points_gestalt}
 
 
1390
  else:
1391
- whole_pcloud = {'points_7d': np.empty((0, 7)),
1392
- 'ids': np.empty((0,)),
1393
- 'imgs': [],
1394
- 'uv': []}
1395
-
 
 
 
 
 
 
1396
  return whole_pcloud
1397
 
1398
  def predict_wireframe(entry, pnet_model, voxel_model, pnet_class_model, config) -> Tuple[np.ndarray, List[int]]:
@@ -1401,11 +1624,19 @@ def predict_wireframe(entry, pnet_model, voxel_model, pnet_class_model, config)
1401
  """
1402
 
1403
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
 
 
 
1404
 
1405
  good_entry = convert_entry_to_human_readable(entry)
1406
  colmap_rec = good_entry['colmap_binary']
1407
 
1408
- colmap_pcloud = create_pcloud(colmap_rec, good_entry)
 
 
 
1409
 
1410
  vertex_threshold = config.get('vertex_threshold', 0.5)
1411
  edge_threshold = config.get('edge_threshold', 0.5)
@@ -1415,8 +1646,6 @@ def predict_wireframe(entry, pnet_model, voxel_model, pnet_class_model, config)
1415
  idxs_points = []
1416
  all_connections = []
1417
 
1418
- print(f"Processing {len(good_entry['gestalt'])} images")
1419
-
1420
  for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(zip(good_entry['gestalt'],
1421
  good_entry['depth'],
1422
  good_entry['K'],
@@ -1425,6 +1654,7 @@ def predict_wireframe(entry, pnet_model, voxel_model, pnet_class_model, config)
1425
  good_entry['image_ids'],
1426
  good_entry['ade'] # Added ade20k segmentation
1427
  )):
 
1428
  # Visualize gestalt segmentation
1429
  K = np.array(K)
1430
  R = np.array(R)
@@ -1436,107 +1666,35 @@ def predict_wireframe(entry, pnet_model, voxel_model, pnet_class_model, config)
1436
  gest_seg_np = np.array(gest_seg).astype(np.uint8)
1437
 
1438
  vertices_ours, connections_ours, vertices_3d_ours, patches, filtered_point_idxs = our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id, ade_seg, depth, K=K, R=R, t=t, frame=good_entry)
 
1439
  idxs_points.append(filtered_point_idxs)
1440
  all_connections.append(connections_ours)
1441
 
1442
- '''
1443
- if GENERATE_DATASET:
1444
- save_patches_dataset(patches, DATASET_DIR, img_id)
1445
- continue
1446
- '''
1447
- #for idx, patch in enumerate(patches):
1448
- #pred_vertex, pred_dist, pred_class = predict_vertex_from_patch(pnet_model, patch, device=device)
1449
-
1450
- #vertices_3d_ours[idx] = pred_vertex
1451
-
1452
- #visu_patch_and_pred(patch, pred_vertex, pred_dist, pred_class)
1453
-
1454
- # x = 0
1455
-
1456
  vertices, connections, vertices_3d = vertices_ours, connections_ours, vertices_3d_ours
1457
- # Get 2D vertices and edges first
1458
- #vertices, connections = get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th=25.)
1459
-
1460
- #gt_verts = []
1461
- #gt_verts, gt_connects, gt_verts3d = get_gt_vertices_and_edges(good_entry, i, depth, colmap_rec, K, R, t, img_id, ade_seg)
1462
- #vertices, connections = gt_verts, gt_connects
1463
-
1464
- if False:
1465
- gest.save(f'gestalt/{img_id}.png')
1466
- # Save ADE20k segmentation
1467
- # ade_seg is already a PIL Image
1468
- try:
1469
- ade_seg.save(f'ade_segmentations/{img_id}_ade.png')
1470
- except Exception as e:
1471
- print(f"Could not save ADE segmentation for {img_id}: {e}")
1472
- save_gestalt_with_proj(gest_seg_np, gt_verts, img_id)
1473
- # Define a local helper function to draw crosses and save the image
1474
-
1475
- # Draw crosses on the ADE segmentation image and save it
1476
- # 'vertices' here refers to gt_verts
1477
- draw_crosses_on_image(ade_seg, vertices, f'crosses_{img_id}.png', color=(0, 0, 0), size=5)
1478
-
1479
- # Check if we have enough to proceed
1480
- if (len(vertices) < 2) or (len(connections) < 1) and False:
1481
- print(f'Not enough vertices or connections found in image {i}, skipping.')
1482
- vert_edge_per_image[i] = [], [], np.empty((0, 3))
1483
- continue
1484
-
1485
- # Call the refactored function to get 3D points
1486
- #vertices_3d = create_3d_wireframe_single_image(vertices, connections, depth, colmap_rec, img_id, ade_seg, K, R, t)
1487
- #vertices_3d = gt_verts3d
1488
- # Store original 2D vertices, connections, and computed 3D points
1489
- #connections = []
1490
-
1491
- if False:
1492
- pcd, geometries = plot_reconstruction_local(None, colmap_rec, points=True, cameras=True, crop_outliers=True)
1493
- wireframe = plot_wireframe_local(None, good_entry['wf_vertices'], good_entry['wf_edges'], good_entry['wf_classifications'])
1494
- wireframe2 = plot_wireframe_local(None, vertices_3d_ours, connections_ours, None, color='rgb(255, 0, 0)')
1495
- wireframe3 = plot_wireframe_local(None, vertices_3d, connections, None, color='rgb(0, 0, 255)')
1496
- bpo_cams = plot_bpo_cameras_from_entry_local(None, good_entry)
1497
-
1498
- visu_all = [pcd] + geometries + wireframe + bpo_cams + wireframe2 + wireframe3
1499
- #o3d.visualization.draw_geometries(visu_all, window_name="3D Reconstruction")
1500
 
1501
  vert_edge_per_image[i] = vertices, connections, vertices_3d
1502
-
1503
  extracted_points, extracted_colors, extracted_ids, whole_pcloud, connections = extract_vertices_from_whole_pcloud(colmap_rec, idxs_points, all_connections)
1504
 
1505
- patches = generate_patches_v2(extracted_points, extracted_colors, extracted_ids, whole_pcloud, good_entry['wf_vertices'])
 
 
1506
 
1507
- # Predict vertices from patches using the neural network
1508
  predicted_vertices = []
1509
- for patch in patches:
1510
  pred_vertex, pred_dist, pred_class = predict_vertex_from_patch(pnet_model, patch, device=device)
1511
 
1512
- #visu_patch_and_pred(patch, pred_vertex, pred_dist, pred_class)
1513
-
1514
  if pred_class > vertex_threshold:
1515
  predicted_vertices.append(pred_vertex)
1516
  else:
1517
  predicted_vertices.append(np.array([0.0, 0.0, 0.0])) # Append a zero vertex if not predicted
1518
-
1519
- #pred_vertex_voxel, pred_dist_voxel, pred_class_voxel = predict_vertex_from_patch_voxel(voxel_model, patch, device=device)
1520
- #visu_patch_and_pred(patch, pred_vertex_voxel, pred_dist_voxel, pred_class_voxel)
1521
 
1522
  predicted_vertices = np.array(predicted_vertices) if predicted_vertices else np.empty((0, 3))
1523
 
1524
- #visu_pcloud_and_preds(colmap_rec, extracted_ids, extracted_points, extracted_colors, predicted_vertices, connections)
1525
-
1526
  if GENERATE_DATASET:
1527
  save_patches_dataset(patches, DATASET_DIR, img_id)
1528
  return empty_solution()
1529
 
1530
- # Merge vertices from all images
1531
- #all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 0.1)
1532
- #all_3d_vertices_clean, connections_3d_clean = all_3d_vertices, connections_3d
1533
- #all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d, keep_largest=False)
1534
- #all_3d_vertices_clean, connections_3d_clean = prune_too_far(all_3d_vertices_clean, connections_3d_clean, colmap_rec, th = 1.5)
1535
-
1536
- #if (len(all_3d_vertices_clean) < 2) or len(connections_3d_clean) < 1 and False:
1537
- # print (f'Not enough vertices or connections in the 3D vertices')
1538
- # return empty_solution()
1539
-
1540
  # Filter out zero vertices and update connections accordingly
1541
  non_zero_mask = ~np.all(np.isclose(predicted_vertices, [0.0, 0.0, 0.0]), axis=1)
1542
  valid_indices = np.where(non_zero_mask)[0]
@@ -1544,11 +1702,9 @@ def predict_wireframe(entry, pnet_model, voxel_model, pnet_class_model, config)
1544
  # Filter vertices to only include non-zero ones
1545
  filtered_vertices = predicted_vertices[valid_indices]
1546
 
1547
- #patches = generate_edge_patches(good_entry, filtered_vertices, colmap_pcloud)
1548
  if GENERATE_DATASET_EDGES:
1549
  patches = generate_edge_patches(good_entry, filtered_vertices, colmap_pcloud)
1550
  save_patches_dataset_class(patches, EDGES_DATASET_DIR, good_entry['order_id'])
1551
-
1552
  return empty_solution()
1553
 
1554
  if len(valid_indices) == 0:
@@ -1566,17 +1722,18 @@ def predict_wireframe(entry, pnet_model, voxel_model, pnet_class_model, config)
1566
  new_end = old_to_new_mapping[end_idx]
1567
  if new_start != new_end: # Ensure we don't connect a vertex to itself
1568
  filtered_connections.append((new_start, new_end))
1569
-
1570
- #print(f"Filtered vertices: {len(filtered_vertices)} from {len(predicted_vertices)}")
1571
- #print(f"Filtered connections: {len(filtered_connections)} from {len(connections)}")
1572
 
 
 
1573
  forward_patches = generate_edge_patches_forward(good_entry, filtered_vertices)
 
1574
  new_connections = []
1575
  if len(forward_patches) > 0:
1576
- for patch in forward_patches:
1577
  start_idx, end_idx = patch['connection']
1578
 
1579
  pred_class, pred_score = predict_class_from_patch(pnet_class_model, patch, device=device)
 
1580
  if pred_score > edge_threshold:
1581
  new_connections.append((start_idx, end_idx))
1582
 
 
16
  #import time
17
  from fast_pointnet_class import save_patches_dataset as save_patches_dataset_class
18
  from fast_pointnet_class import predict_class_from_patch
19
+ from fast_pointnet_class_10d import predict_class_from_patch as predict_class_from_patch_10d
20
  from scipy.spatial.distance import cdist
21
  from scipy.optimize import linear_sum_assignment
22
  import torch
23
+ import time
24
 
25
  GENERATE_DATASET = False
26
  DATASET_DIR = '/home/skvrnjan/personal/hohocustom/'
 
1181
 
1182
  return forward_patches
1183
 
1184
def generate_edge_patches_forward_10d(frame, pred_vertices, colmap_pcloud):
    """Build unlabeled 10-D point patches for every candidate edge.

    For each unordered pair of predicted vertices, the cloud points lying in a
    cylinder of radius 0.5 around the segment (extended by 0.25 past both
    endpoints) are collected and shifted so the segment midpoint is the origin.
    Pairs whose cylinder holds 10 points or fewer are skipped, as are pairs of
    coincident vertices.

    Patch channel layout: 0-2 xyz, 3-5 rgb mapped with ``c * 2 - 1``,
    6 house-mask flag as +1 / -1, 7-9 fused gestalt colour mapped with
    ``g / 255 * 2 - 1``.

    Args:
        frame: Not used by this function.
        pred_vertices: Sequence of 3-D vertex positions (indexable, each item
            supporting array arithmetic).
        colmap_pcloud: Mapping with ``'points_7d'`` (N x 7 array), ``'ade'``
            (N truthy flags) and ``'gestalt'`` (per point, a list of colour
            samples).  It is not modified.

    Returns:
        list of dicts with keys ``'patch_10d'``, ``'connection'`` (the vertex
        index pair ``(i, j)``, ``i < j``), ``'line_start'``, ``'line_end'``,
        ``'cylinder_radius'``, ``'point_indices'`` and ``'center'``.
    """
    cylinder_radius = 0.5     # meters
    extension_length = 0.25   # meters of extra context beyond each endpoint
    min_points = 10           # a patch needs strictly more points than this

    # np.array() copies: slicing 'points_7d' directly would give a view, and
    # the colour rescale below would then rewrite the caller's point cloud.
    cloud = np.array(colmap_pcloud['points_7d'], dtype=float)
    num_points = cloud.shape[0]
    vertices = pred_vertices
    num_vertices = len(vertices)

    # Nothing to cut patches from; also covers empty clouds that carry no
    # 'ade' / 'gestalt' entries.
    if num_points == 0 or num_vertices < 2:
        return []

    colmap_points_3d = cloud[:, :3]
    rgb = cloud[:, 3:6] * 2 - 1                      # rescale colours: c * 2 - 1
    ade = np.where(colmap_pcloud['ade'], 1, -1)      # flag -> +1 / -1

    # Fuse the per-view gestalt samples of each point by averaging them.
    fused_gestalt = []
    for samples in colmap_pcloud['gestalt']:
        if len(samples) == 0:
            fused_gestalt.append(np.array([0, 0, 0]))
        elif len(samples) == 1:
            fused_gestalt.append(samples[0])
        else:
            fused_gestalt.append(np.mean(np.array(samples), axis=0).astype(np.uint8))
    gestalt = (np.array(fused_gestalt) / 255) * 2 - 1

    colmap_points_10d = np.zeros((num_points, 10))
    colmap_points_10d[:, :3] = colmap_points_3d
    colmap_points_10d[:, 3:6] = rgb
    colmap_points_10d[:, 6] = ade
    colmap_points_10d[:, 7:10] = gestalt

    forward_patches = []

    for i in range(num_vertices - 1):
        for j in range(i + 1, num_vertices):
            start_vertex = vertices[i]
            end_vertex = vertices[j]

            line_vector = end_vertex - start_vertex
            line_length = np.linalg.norm(line_vector)
            if not line_length > 0:
                # Coincident (or NaN) endpoints define no direction.
                continue
            line_direction = line_vector / line_length

            extended_start = start_vertex - extension_length * line_direction
            extended_line_length = line_length + 2 * extension_length

            # Signed distance of every point along the extended segment.
            start_to_points = colmap_points_3d - extended_start[np.newaxis, :]
            projection_lengths = np.dot(start_to_points, line_direction)
            within_bounds = (projection_lengths >= 0) & (projection_lengths <= extended_line_length)

            # Perpendicular distance of every point to the segment's line.
            closest_points_on_line = (
                extended_start[np.newaxis, :]
                + projection_lengths[:, np.newaxis] * line_direction[np.newaxis, :]
            )
            perpendicular_distances = np.linalg.norm(
                colmap_points_3d - closest_points_on_line, axis=1
            )

            within_cylinder = within_bounds & (perpendicular_distances <= cylinder_radius)
            if np.count_nonzero(within_cylinder) <= min_points:
                continue

            # Center on the midpoint of the original (not extended) segment.
            line_midpoint = (start_vertex + end_vertex) / 2
            points_centered = colmap_points_10d[within_cylinder].copy()
            points_centered[:, :3] -= line_midpoint

            forward_patches.append({
                'patch_10d': points_centered,
                'connection': (i, j),
                'line_start': start_vertex - line_midpoint,
                'line_end': end_vertex - line_midpoint,
                'cylinder_radius': cylinder_radius,
                'point_indices': np.flatnonzero(within_cylinder),
                'center': line_midpoint,
            })

    return forward_patches
1291
+
1292
  def calculate_cylinder_overlap_volume(cyl1, cyl2):
1293
  """
1294
  Calculate the intersection volume between two cylinders using numpy vectorization.
 
1390
  return max(0.0, overlap_volume)
1391
 
1392
  def create_pcloud(colmap_rec, frame):
1393
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
1394
+ #print(f"create_pcloud using device: {device}")
1395
+
1396
+ # 1. Preprocess image data from the frame and colmap (mostly on CPU)
1397
+ img_id_to_colmap_img_obj_map = {
1398
+ img_obj.name: img_obj for img_obj_name, img_obj in colmap_rec.images.items()
1399
+ }
1400
+
1401
+ frame_img_data = {}
1402
+ ordered_frame_img_ids = []
1403
+
1404
+ for K_val, R_val, t_val, img_id_val, ade_val, gestalt_val, depth_val in zip(
1405
+ frame['K'], frame['R'], frame['t'], frame['image_ids'],
1406
+ frame['ade'], frame['gestalt'], frame['depth']
1407
+ ):
1408
+ if img_id_val not in img_id_to_colmap_img_obj_map:
1409
+ continue
1410
 
1411
+ ordered_frame_img_ids.append(img_id_val)
1412
+ depth_np = np.array(depth_val)
1413
+ depth_H, depth_W = depth_np.shape[0], depth_np.shape[1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1414
 
1415
+ ade_mask_np = get_house_mask(ade_val)
1416
+
1417
+ gest_seg_pil = gestalt_val.resize((depth_W, depth_H), Image.Resampling.NEAREST)
1418
+ gest_seg_np = np.array(gest_seg_pil).astype(np.uint8)
1419
+
1420
+ frame_img_data[img_id_val] = {
1421
+ 'K_np': np.array(K_val),
1422
+ 'R_np': np.array(R_val),
1423
+ 't_np': np.array(t_val).reshape(3,1),
1424
+ 'ade_mask_np': ade_mask_np,
1425
+ 'gestalt_seg_np': gest_seg_np,
1426
+ 'H': depth_H,
1427
+ 'W': depth_W
1428
+ }
1429
 
1430
+ # 2. Process 3D points by iterating through images
1431
+ point_data_accumulator = {} # Key: pid, accumulates data on CPU
1432
+
1433
+ # Pre-fetch all COLMAP point data to avoid repeated dictionary lookups
1434
+ colmap_points_data_cpu = {
1435
+ pid: {'xyz': p3D.xyz, 'color': p3D.color / 255.0}
1436
+ for pid, p3D in colmap_rec.points3D.items()
1437
+ }
1438
+
1439
+ for img_id in ordered_frame_img_ids:
1440
+ if img_id not in frame_img_data:
1441
+ continue
1442
+
1443
+ col_img_obj = img_id_to_colmap_img_obj_map[img_id]
1444
+ img_data = frame_img_data[img_id]
1445
+
1446
+ K_np, R_np, t_np = img_data['K_np'], img_data['R_np'], img_data['t_np']
1447
+ ade_mask_np, gestalt_seg_np = img_data['ade_mask_np'], img_data['gestalt_seg_np']
1448
+ H, W = img_data['H'], img_data['W']
1449
+
1450
+ # Convert current image data to GPU tensors
1451
+ K_gpu = torch.from_numpy(K_np).float().to(device)
1452
+ R_gpu = torch.from_numpy(R_np).float().to(device)
1453
+ t_gpu = torch.from_numpy(t_np).float().to(device)
1454
+ ade_mask_gpu = torch.from_numpy(ade_mask_np).bool().to(device)
1455
+ gestalt_seg_gpu = torch.from_numpy(gestalt_seg_np).to(device) # uint8 is fine
1456
+
1457
+ visible_pids_in_img = []
1458
+ visible_xyz_coords_list = []
1459
+
1460
+ for pid, p3D_data in colmap_points_data_cpu.items():
1461
+ if col_img_obj.has_point3D(pid): # This check remains CPU-bound
1462
+ visible_pids_in_img.append(pid)
1463
+ visible_xyz_coords_list.append(p3D_data['xyz'])
1464
+
1465
+ if not visible_pids_in_img:
1466
+ continue
1467
+
1468
+ num_visible_points = len(visible_pids_in_img)
1469
+ world_pts_np = np.array(visible_xyz_coords_list)
1470
+ world_pts_gpu = torch.from_numpy(world_pts_np).float().to(device)
1471
+
1472
+ # Batch projection on GPU
1473
+ world_pts_h_gpu = torch.cat((world_pts_gpu, torch.ones(num_visible_points, 1, device=device)), dim=1)
1474
+ P_world_to_cam_gpu = torch.hstack((R_gpu, t_gpu))
1475
+ cam_coords_proj_gpu = P_world_to_cam_gpu @ world_pts_h_gpu.T
1476
+
1477
+ cam_coords_z_gpu = cam_coords_proj_gpu[2, :]
1478
+ in_front_mask_gpu = cam_coords_z_gpu > 1e-6
1479
+
1480
+ pixel_coords_h_gpu = K_gpu @ cam_coords_proj_gpu
1481
+
1482
+ u_proj_gpu = torch.full_like(cam_coords_z_gpu, -1.0, dtype=torch.float32)
1483
+ v_proj_gpu = torch.full_like(cam_coords_z_gpu, -1.0, dtype=torch.float32)
1484
+
1485
+ # Avoid division by zero/small numbers for points not truly in front or on optical center
1486
+ valid_depth_mask_gpu = in_front_mask_gpu & (torch.abs(cam_coords_z_gpu) > 1e-6)
1487
+
1488
+ if torch.any(valid_depth_mask_gpu):
1489
+ u_proj_gpu[valid_depth_mask_gpu] = pixel_coords_h_gpu[0, valid_depth_mask_gpu] / cam_coords_z_gpu[valid_depth_mask_gpu]
1490
+ v_proj_gpu[valid_depth_mask_gpu] = pixel_coords_h_gpu[1, valid_depth_mask_gpu] / cam_coords_z_gpu[valid_depth_mask_gpu]
1491
+
1492
+ u_rounded_gpu = torch.round(u_proj_gpu).long()
1493
+ v_rounded_gpu = torch.round(v_proj_gpu).long()
1494
+
1495
+ is_in_bounds_gpu = (u_rounded_gpu >= 0) & (u_rounded_gpu < W) & \
1496
+ (v_rounded_gpu >= 0) & (v_rounded_gpu < H) & \
1497
+ in_front_mask_gpu # Re-check in_front_mask_gpu as rounding might affect edge cases slightly
1498
+
1499
+ # Sample ADE and Gestalt on GPU for points in bounds
1500
+ # Initialize with default values for all points, then update for those in bounds
1501
+ sampled_ade_status_gpu = torch.zeros(num_visible_points, dtype=torch.bool, device=device)
1502
+ sampled_gestalt_values_gpu = torch.zeros(num_visible_points, 3, dtype=torch.uint8, device=device)
1503
+
1504
+ # Create a mask for points that are valid for sampling (in_bounds and in_front)
1505
+ valid_for_sampling_mask_gpu = is_in_bounds_gpu
1506
+
1507
+ if torch.any(valid_for_sampling_mask_gpu):
1508
+ u_sample_gpu = u_rounded_gpu[valid_for_sampling_mask_gpu]
1509
+ v_sample_gpu = v_rounded_gpu[valid_for_sampling_mask_gpu]
1510
+
1511
+ sampled_ade_status_gpu[valid_for_sampling_mask_gpu] = ade_mask_gpu[v_sample_gpu, u_sample_gpu]
1512
+ sampled_gestalt_values_gpu[valid_for_sampling_mask_gpu] = gestalt_seg_gpu[v_sample_gpu, u_sample_gpu]
1513
+
1514
+ # Transfer necessary results back to CPU for accumulation
1515
+ u_rounded_cpu = u_rounded_gpu.cpu().numpy()
1516
+ v_rounded_cpu = v_rounded_gpu.cpu().numpy()
1517
+ is_in_bounds_cpu = is_in_bounds_gpu.cpu().numpy() # Use the original is_in_bounds_gpu for logic
1518
+ sampled_ade_status_cpu = sampled_ade_status_gpu.cpu().numpy()
1519
+ sampled_gestalt_values_cpu = sampled_gestalt_values_gpu.cpu().numpy()
1520
+
1521
+
1522
+ # Update accumulator (on CPU)
1523
+ for i in range(num_visible_points):
1524
+ pid = visible_pids_in_img[i]
1525
+
1526
+ if pid not in point_data_accumulator:
1527
+ point_data_accumulator[pid] = {
1528
+ 'xyz': colmap_points_data_cpu[pid]['xyz'],
1529
+ 'color': colmap_points_data_cpu[pid]['color'],
1530
+ 'imgs_seen_by': [],
1531
+ 'uv_projections': [],
1532
+ 'ade_status': False,
1533
+ 'gestalt_values': []
1534
+ }
1535
+
1536
+ acc = point_data_accumulator[pid]
1537
+ acc['imgs_seen_by'].append(img_id)
1538
+ acc['uv_projections'].append((u_rounded_cpu[i], v_rounded_cpu[i]))
1539
+
1540
+ if is_in_bounds_cpu[i]: # This point was projected within bounds and in front
1541
+ acc['ade_status'] = sampled_ade_status_cpu[i]
1542
+ acc['gestalt_values'].append(sampled_gestalt_values_cpu[i])
1543
+ else: # Point projected out of bounds, behind, or failed depth check
1544
+ acc['gestalt_values'].append(np.array([0,0,0], dtype=np.uint8))
1545
+
1546
+ # Optional: clear GPU cache if memory is a concern for many images
1547
+ # if device.type == 'cuda':
1548
+ # torch.cuda.empty_cache()
1549
+
1550
+
1551
+ # 3. Final data assembly (on CPU)
1552
+ points_xyz_world_list = []
1553
+ points_colors_list = []
1554
+ points_idxs_list = []
1555
+ points_imgs_seen_by_list = []
1556
+ points_uv_projections_per_point_list = []
1557
+ points_ade_status_final_list = []
1558
+ points_gestalt_values_per_point_list = []
1559
+
1560
+ # Ensure consistent order if downstream code relies on it, though original didn't specify sorting for pids
1561
+ # Using sorted_pids for reproducibility if point_data_accumulator keys order changes.
1562
+ sorted_pids = sorted(point_data_accumulator.keys())
1563
+
1564
+ for pid in sorted_pids:
1565
+ data = point_data_accumulator[pid]
1566
+ points_xyz_world_list.append(data['xyz'])
1567
+ points_colors_list.append(data['color'])
1568
+ points_idxs_list.append(pid)
1569
+ points_imgs_seen_by_list.append(data['imgs_seen_by'])
1570
+ points_uv_projections_per_point_list.append(data['uv_projections'])
1571
+ points_ade_status_final_list.append(data['ade_status'])
1572
+ points_gestalt_values_per_point_list.append(data['gestalt_values'])
1573
+
1574
+ points_xyz_world = np.array(points_xyz_world_list) if points_xyz_world_list else np.empty((0, 3))
1575
+ points_colors = np.array(points_colors_list) if points_colors_list else np.empty((0, 3))
1576
+ points_idxs = np.array(points_idxs_list, dtype=int) if points_idxs_list else np.empty((0,), dtype=int) # Ensure dtype for pids
1577
+ points_ade = np.array(points_ade_status_final_list, dtype=bool) if points_ade_status_final_list else np.empty((0,), dtype=bool)
1578
+
1579
+ output_all_colmap_img_ids = [img_obj.name for img_obj_name, img_obj in colmap_rec.images.items()]
1580
+ output_frame_K, output_frame_R, output_frame_t = [], [], []
1581
+
1582
+ for img_id_val in frame['image_ids']:
1583
+ if img_id_val in frame_img_data:
1584
+ data = frame_img_data[img_id_val]
1585
+ output_frame_K.append(data['K_np'])
1586
+ output_frame_R.append(data['R_np'])
1587
+ output_frame_t.append(data['t_np'])
1588
 
 
 
 
 
1589
 
 
1590
  if points_xyz_world.shape[0] > 0:
1591
+ colmap_points_7d = np.zeros((points_xyz_world.shape[0], 7))
1592
+ colmap_points_7d[:, :3] = points_xyz_world
1593
+ colmap_points_7d[:, 3:6] = points_colors
1594
+ colmap_points_7d[:, 6] = points_idxs
1595
+
1596
+ whole_pcloud = {
1597
+ 'points_7d': colmap_points_7d,
1598
+ 'imgs': points_imgs_seen_by_list,
1599
+ 'uv': points_uv_projections_per_point_list,
1600
+ 'all_imgs_ids': output_all_colmap_img_ids,
1601
+ 'all_imgs_K': output_frame_K,
1602
+ 'all_imgs_R': output_frame_R,
1603
+ 'all_imgs_t': output_frame_t,
1604
+ 'ade': points_ade,
1605
+ 'gestalt': points_gestalt_values_per_point_list
1606
+ }
1607
  else:
1608
+ whole_pcloud = {
1609
+ 'points_7d': np.empty((0, 7)),
1610
+ 'imgs': [],
1611
+ 'uv': [],
1612
+ 'all_imgs_ids': output_all_colmap_img_ids,
1613
+ 'all_imgs_K': output_frame_K,
1614
+ 'all_imgs_R': output_frame_R,
1615
+ 'all_imgs_t': output_frame_t,
1616
+ 'ade': np.empty((0,), dtype=bool),
1617
+ 'gestalt': []
1618
+ }
1619
  return whole_pcloud
1620
 
1621
  def predict_wireframe(entry, pnet_model, voxel_model, pnet_class_model, config) -> Tuple[np.ndarray, List[int]]:
 
1624
  """
1625
 
1626
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
1627
+ # Delete specified keys from the entry
1628
+ #keys_to_delete = ['wf_vertices', 'wf_edges', 'wf_classifications']
1629
+ #for key in keys_to_delete:
1630
+ # if key in entry:
1631
+ # del entry[key]
1632
 
1633
  good_entry = convert_entry_to_human_readable(entry)
1634
  colmap_rec = good_entry['colmap_binary']
1635
 
1636
+ #start_time = time.time()
1637
+ #colmap_pcloud = create_pcloud(colmap_rec, good_entry)
1638
+ #end_time = time.time()
1639
+ #print(f"create_pcloud took {end_time - start_time:.4f} seconds")
1640
 
1641
  vertex_threshold = config.get('vertex_threshold', 0.5)
1642
  edge_threshold = config.get('edge_threshold', 0.5)
 
1646
  idxs_points = []
1647
  all_connections = []
1648
 
 
 
1649
  for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(zip(good_entry['gestalt'],
1650
  good_entry['depth'],
1651
  good_entry['K'],
 
1654
  good_entry['image_ids'],
1655
  good_entry['ade'] # Added ade20k segmentation
1656
  )):
1657
+
1658
  # Visualize gestalt segmentation
1659
  K = np.array(K)
1660
  R = np.array(R)
 
1666
  gest_seg_np = np.array(gest_seg).astype(np.uint8)
1667
 
1668
  vertices_ours, connections_ours, vertices_3d_ours, patches, filtered_point_idxs = our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id, ade_seg, depth, K=K, R=R, t=t, frame=good_entry)
1669
+
1670
  idxs_points.append(filtered_point_idxs)
1671
  all_connections.append(connections_ours)
1672
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1673
  vertices, connections, vertices_3d = vertices_ours, connections_ours, vertices_3d_ours
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1674
 
1675
  vert_edge_per_image[i] = vertices, connections, vertices_3d
1676
+
1677
  extracted_points, extracted_colors, extracted_ids, whole_pcloud, connections = extract_vertices_from_whole_pcloud(colmap_rec, idxs_points, all_connections)
1678
 
1679
+ wf_vertices = good_entry.get('wf_vertices', None)
1680
+
1681
+ patches = generate_patches_v2(extracted_points, extracted_colors, extracted_ids, whole_pcloud, wf_vertices)
1682
 
 
1683
  predicted_vertices = []
1684
+ for i, patch in enumerate(patches):
1685
  pred_vertex, pred_dist, pred_class = predict_vertex_from_patch(pnet_model, patch, device=device)
1686
 
 
 
1687
  if pred_class > vertex_threshold:
1688
  predicted_vertices.append(pred_vertex)
1689
  else:
1690
  predicted_vertices.append(np.array([0.0, 0.0, 0.0])) # Append a zero vertex if not predicted
 
 
 
1691
 
1692
  predicted_vertices = np.array(predicted_vertices) if predicted_vertices else np.empty((0, 3))
1693
 
 
 
1694
  if GENERATE_DATASET:
1695
  save_patches_dataset(patches, DATASET_DIR, img_id)
1696
  return empty_solution()
1697
 
 
 
 
 
 
 
 
 
 
 
1698
  # Filter out zero vertices and update connections accordingly
1699
  non_zero_mask = ~np.all(np.isclose(predicted_vertices, [0.0, 0.0, 0.0]), axis=1)
1700
  valid_indices = np.where(non_zero_mask)[0]
 
1702
  # Filter vertices to only include non-zero ones
1703
  filtered_vertices = predicted_vertices[valid_indices]
1704
 
 
1705
  if GENERATE_DATASET_EDGES:
1706
  patches = generate_edge_patches(good_entry, filtered_vertices, colmap_pcloud)
1707
  save_patches_dataset_class(patches, EDGES_DATASET_DIR, good_entry['order_id'])
 
1708
  return empty_solution()
1709
 
1710
  if len(valid_indices) == 0:
 
1722
  new_end = old_to_new_mapping[end_idx]
1723
  if new_start != new_end: # Ensure we don't connect a vertex to itself
1724
  filtered_connections.append((new_start, new_end))
 
 
 
1725
 
1726
+ # Generate forward edge patches
1727
+ #forward_patches = generate_edge_patches_forward_10d(good_entry, filtered_vertices, colmap_pcloud)
1728
  forward_patches = generate_edge_patches_forward(good_entry, filtered_vertices)
1729
+
1730
  new_connections = []
1731
  if len(forward_patches) > 0:
1732
+ for i, patch in enumerate(forward_patches):
1733
  start_idx, end_idx = patch['connection']
1734
 
1735
  pred_class, pred_score = predict_class_from_patch(pnet_class_model, patch, device=device)
1736
+
1737
  if pred_score > edge_threshold:
1738
  new_connections.append((start_idx, end_idx))
1739
 
train.py CHANGED
@@ -17,11 +17,13 @@ from tqdm import tqdm
17
  from fast_pointnet import load_pointnet_model
18
  from fast_voxel import load_3dcnn_model
19
  from fast_pointnet_class import load_pointnet_model as load_pointnet_class_model
 
20
  import torch
 
21
 
22
- #ds = load_dataset("usm3d/hoho25k", cache_dir="/media/skvrnjan/sd/hoho25k/", trust_remote_code=True)
23
- ds = load_dataset("usm3d/hoho25k", cache_dir="/mnt/personal/skvrnjan/hoho25k/", trust_remote_code=True)
24
- ds = ds.shuffle()
25
 
26
  scores_hss = []
27
  scores_f1 = []
@@ -31,12 +33,13 @@ show_visu = False
31
 
32
  device = "cuda" if torch.cuda.is_available() else "cpu"
33
 
34
- #pnet_model = load_pointnet_model(model_path="/home/skvrnjan/personal/hoho_pnet/initial_epoch_100.pth", device=device, predict_score=True)
35
- pnet_model = load_pointnet_model(model_path="/mnt/personal/skvrnjan/hoho_pnet/initial_epoch_100.pth", device=device, predict_score=True)
36
  #pnet_model = None
37
 
38
  #pnet_class_model = load_pointnet_class_model(model_path="/home/skvrnjan/personal/hoho_pnet_edges_v2/initial_epoch_100.pth", device=device)
39
- pnet_class_model = load_pointnet_class_model(model_path="/mnt/personal/skvrnjan/hoho_pnet_edges_v2/initial_epoch_100.pth", device=device)
 
40
  #pnet_class_model = None
41
 
42
  #voxel_model = load_3dcnn_model(model_path="/home/skvrnjan/personal/hoho_voxel/initial_epoch_100.pth", device=device, predict_score=True)
@@ -45,13 +48,20 @@ voxel_model = None
45
  config = {'vertex_threshold': 0.4, 'edge_threshold': 0.6, 'only_predicted_connections': False}
46
 
47
  idx = 0
 
48
  for a in tqdm(ds['train'], desc="Processing dataset"):
49
  #plot_all_modalities(a)
50
  #pred_vertices, pred_edges = predict_wireframe_old(a)
51
- #pred_vertices, pred_edges = predict_wireframe(a, pnet_model, voxel_model, pnet_class_model, config)
52
  try:
53
- pred_vertices, pred_edges = predict_wireframe(a, pnet_model, voxel_model, pnet_class_model, config)
 
54
  #pred_vertices, pred_edges = predict_wireframe_old(a)
 
 
 
 
 
55
  except:
56
  pred_vertices, pred_edges = empty_solution()
57
 
@@ -72,8 +82,8 @@ for a in tqdm(ds['train'], desc="Processing dataset"):
72
  o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")
73
 
74
  idx += 1
75
- #if idx >= 100: # Limit to first 10 samples for testing
76
- # break
77
 
78
  for i in range(10):
79
  print("END OF DATASET")
 
17
  from fast_pointnet import load_pointnet_model
18
  from fast_voxel import load_3dcnn_model
19
  from fast_pointnet_class import load_pointnet_model as load_pointnet_class_model
20
+ from fast_pointnet_class_10d import load_pointnet_model as load_pointnet_class_model_10d
21
  import torch
22
+ import time
23
 
24
+ ds = load_dataset("usm3d/hoho25k", cache_dir="/media/skvrnjan/sd/hoho25k/", trust_remote_code=True)
25
+ #ds = load_dataset("usm3d/hoho25k", cache_dir="/mnt/personal/skvrnjan/hoho25k/", trust_remote_code=True)
26
+ #ds = ds.shuffle()
27
 
28
  scores_hss = []
29
  scores_f1 = []
 
33
 
34
  device = "cuda" if torch.cuda.is_available() else "cpu"
35
 
36
+ pnet_model = load_pointnet_model(model_path="pnet.pth", device=device, predict_score=True)
37
+ #pnet_model = load_pointnet_model(model_path="/mnt/personal/skvrnjan/hoho_pnet/initial_epoch_100.pth", device=device, predict_score=True)
38
  #pnet_model = None
39
 
40
  #pnet_class_model = load_pointnet_class_model(model_path="/home/skvrnjan/personal/hoho_pnet_edges_v2/initial_epoch_100.pth", device=device)
41
+ #pnet_class_model = load_pointnet_class_model_10d(model_path="/home/skvrnjan/personal/hoho_pnet_edges_10d/initial_epoch_75.pth", device=device)
42
+ pnet_class_model = load_pointnet_class_model(model_path="pnet_class.pth", device=device)
43
  #pnet_class_model = None
44
 
45
  #voxel_model = load_3dcnn_model(model_path="/home/skvrnjan/personal/hoho_voxel/initial_epoch_100.pth", device=device, predict_score=True)
 
48
  config = {'vertex_threshold': 0.4, 'edge_threshold': 0.6, 'only_predicted_connections': False}
49
 
50
  idx = 0
51
+ prediction_times = []
52
  for a in tqdm(ds['train'], desc="Processing dataset"):
53
  #plot_all_modalities(a)
54
  #pred_vertices, pred_edges = predict_wireframe_old(a)
55
+ pred_vertices, pred_edges = predict_wireframe(a.copy(), pnet_model, voxel_model, pnet_class_model, config)
56
  try:
57
+ start_time = time.time()
58
+ pred_vertices, pred_edges = predict_wireframe(a.copy(), pnet_model, voxel_model, pnet_class_model, config)
59
  #pred_vertices, pred_edges = predict_wireframe_old(a)
60
+ end_time = time.time()
61
+ prediction_time = end_time - start_time
62
+ prediction_times.append(prediction_time)
63
+ mean_time = np.mean(prediction_times)
64
+ print(f"Prediction time: {prediction_time:.4f} seconds, Mean time: {mean_time:.4f} seconds")
65
  except:
66
  pred_vertices, pred_edges = empty_solution()
67
 
 
82
  o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")
83
 
84
  idx += 1
85
+ if idx >= 100: # Limit to first 100 samples for testing
86
+ break
87
 
88
  for i in range(10):
89
  print("END OF DATASET")