jskvrna committed on
Commit 322171e · 1 Parent(s): 7cf6fd9

Adds PointNet classification for edge detection


Implements a PointNet-based architecture for classifying 6D point cloud patches, designed to identify edges within 3D data. This includes a dataset class for loading and augmenting patches, a training pipeline, and functions for prediction. It also updates the 3D CNN prediction function and incorporates point merging based on overlap to refine vertex extraction.
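A minimal usage sketch of the new module (the dataset directory, checkpoint path, and the dummy patch below are placeholders, not part of the commit; the function signatures and the 'patch_6d' key follow fast_pointnet_class.py):

import numpy as np
from fast_pointnet_class import train_pointnet, load_pointnet_model, predict_class_from_patch

# Train on a directory of previously saved .pkl patch dicts (placeholder paths)
train_pointnet(dataset_dir="patch_dataset", model_save_path="pointnet_edge.pth",
               epochs=100, batch_size=32, lr=0.001)

# Reload the checkpoint and classify a single 6D (x, y, z, r, g, b) patch
model = load_pointnet_model("pointnet_edge.pth")
dummy_patch = {"patch_6d": np.random.rand(500, 6)}  # illustrative patch only
pred_class, confidence = predict_class_from_patch(model, dummy_patch)
print(f"edge={bool(pred_class)}, confidence={confidence:.3f}")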

Files changed (3)
  1. fast_pointnet_class.py +407 -0
  2. fast_voxel.py +1 -1
  3. predict.py +31 -11
fast_pointnet_class.py ADDED
@@ -0,0 +1,407 @@
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import pickle
+from torch.utils.data import Dataset, DataLoader
+from typing import List, Dict, Tuple, Optional
+import json
+
+class ClassificationPointNet(nn.Module):
+    """
+    PointNet implementation for binary classification from 6D point cloud patches.
+    Takes 6D point clouds (x,y,z,r,g,b) and predicts binary classification (edge/not edge).
+    """
+    def __init__(self, input_dim=6, max_points=1024):
+        super(ClassificationPointNet, self).__init__()
+        self.max_points = max_points
+
+        # Point-wise MLPs for feature extraction (deeper network)
+        self.conv1 = nn.Conv1d(input_dim, 64, 1)
+        self.conv2 = nn.Conv1d(64, 128, 1)
+        self.conv3 = nn.Conv1d(128, 256, 1)
+        self.conv4 = nn.Conv1d(256, 512, 1)
+        self.conv5 = nn.Conv1d(512, 1024, 1)
+        self.conv6 = nn.Conv1d(1024, 2048, 1)  # Additional layer
+
+        # Classification head (deeper with more capacity)
+        self.fc1 = nn.Linear(2048, 1024)
+        self.fc2 = nn.Linear(1024, 512)
+        self.fc3 = nn.Linear(512, 256)
+        self.fc4 = nn.Linear(256, 128)
+        self.fc5 = nn.Linear(128, 64)
+        self.fc6 = nn.Linear(64, 1)  # Single output for binary classification
+
+        # Batch normalization layers
+        self.bn1 = nn.BatchNorm1d(64)
+        self.bn2 = nn.BatchNorm1d(128)
+        self.bn3 = nn.BatchNorm1d(256)
+        self.bn4 = nn.BatchNorm1d(512)
+        self.bn5 = nn.BatchNorm1d(1024)
+        self.bn6 = nn.BatchNorm1d(2048)
+
+        # Dropout layers
+        self.dropout1 = nn.Dropout(0.3)
+        self.dropout2 = nn.Dropout(0.4)
+        self.dropout3 = nn.Dropout(0.5)
+        self.dropout4 = nn.Dropout(0.4)
+        self.dropout5 = nn.Dropout(0.3)
+
+    def forward(self, x):
+        """
+        Forward pass
+        Args:
+            x: (batch_size, input_dim, max_points) tensor
+        Returns:
+            classification: (batch_size, 1) tensor of logits (sigmoid for probability)
+        """
+        batch_size = x.size(0)
+
+        # Point-wise feature extraction
+        x1 = F.relu(self.bn1(self.conv1(x)))
+        x2 = F.relu(self.bn2(self.conv2(x1)))
+        x3 = F.relu(self.bn3(self.conv3(x2)))
+        x4 = F.relu(self.bn4(self.conv4(x3)))
+        x5 = F.relu(self.bn5(self.conv5(x4)))
+        x6 = F.relu(self.bn6(self.conv6(x5)))
+
+        # Global max pooling
+        global_features = torch.max(x6, 2)[0]  # (batch_size, 2048)
+
+        # Classification head
+        x = F.relu(self.fc1(global_features))
+        x = self.dropout1(x)
+        x = F.relu(self.fc2(x))
+        x = self.dropout2(x)
+        x = F.relu(self.fc3(x))
+        x = self.dropout3(x)
+        x = F.relu(self.fc4(x))
+        x = self.dropout4(x)
+        x = F.relu(self.fc5(x))
+        x = self.dropout5(x)
+        classification = self.fc6(x)  # (batch_size, 1)
+
+        return classification
+
+class PatchClassificationDataset(Dataset):
+    """
+    Dataset class for loading saved patches for PointNet classification training.
+    """
+
+    def __init__(self, dataset_dir: str, max_points: int = 1024, augment: bool = True):
+        self.dataset_dir = dataset_dir
+        self.max_points = max_points
+        self.augment = augment
+
+        # Load patch files
+        self.patch_files = []
+        for file in os.listdir(dataset_dir):
+            if file.endswith('.pkl'):
+                self.patch_files.append(os.path.join(dataset_dir, file))
+
+        print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")
+
+    def __len__(self):
+        return len(self.patch_files)
+
+    def __getitem__(self, idx):
+        """
+        Load and process a patch for training.
+        Returns:
+            patch_data: (6, max_points) tensor of point cloud data
+            label: scalar tensor for binary classification (0 or 1)
+            valid_mask: (max_points,) boolean tensor indicating valid points
+        """
+        patch_file = self.patch_files[idx]
+
+        with open(patch_file, 'rb') as f:
+            patch_info = pickle.load(f)
+
+        patch_6d = patch_info['patch_6d']  # (N, 6)
+        label = patch_info.get('label', 0)  # Get binary classification label (0 or 1)
+
+        # Pad or sample points to max_points
+        num_points = patch_6d.shape[0]
+
+        if num_points >= self.max_points:
+            # Randomly sample max_points
+            indices = np.random.choice(num_points, self.max_points, replace=False)
+            patch_sampled = patch_6d[indices]
+            valid_mask = np.ones(self.max_points, dtype=bool)
+        else:
+            # Pad with zeros
+            patch_sampled = np.zeros((self.max_points, 6))
+            patch_sampled[:num_points] = patch_6d
+            valid_mask = np.zeros(self.max_points, dtype=bool)
+            valid_mask[:num_points] = True
+
+        # Data augmentation
+        if self.augment:
+            patch_sampled = self._augment_patch(patch_sampled, valid_mask)
+
+        # Convert to tensors and transpose for conv1d (channels first)
+        patch_tensor = torch.from_numpy(patch_sampled.T).float()  # (6, max_points)
+        label_tensor = torch.tensor(label, dtype=torch.float32)  # Float for BCE loss
+        valid_mask_tensor = torch.from_numpy(valid_mask)
+
+        return patch_tensor, label_tensor, valid_mask_tensor
+
+    def _augment_patch(self, patch, valid_mask):
+        """
+        Apply data augmentation to the patch.
+        """
+        valid_points = patch[valid_mask]
+
+        if len(valid_points) == 0:
+            return patch
+
+        # Random rotation around z-axis
+        angle = np.random.uniform(0, 2 * np.pi)
+        cos_angle = np.cos(angle)
+        sin_angle = np.sin(angle)
+        rotation_matrix = np.array([
+            [cos_angle, -sin_angle, 0],
+            [sin_angle, cos_angle, 0],
+            [0, 0, 1]
+        ])
+
+        # Apply rotation to xyz coordinates
+        valid_points[:, :3] = valid_points[:, :3] @ rotation_matrix.T
+
+        # Random jittering
+        noise = np.random.normal(0, 0.01, valid_points[:, :3].shape)
+        valid_points[:, :3] += noise
+
+        # Random scaling
+        scale = np.random.uniform(0.9, 1.1)
+        valid_points[:, :3] *= scale
+
+        patch[valid_mask] = valid_points
+        return patch
+
+def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
+    """
+    Save patches from prediction pipeline to create a training dataset.
+
+    Args:
+        patches: List of patch dictionaries from generate_patches()
+        dataset_dir: Directory to save the dataset
+        entry_id: Unique identifier for this entry/image
+    """
+    os.makedirs(dataset_dir, exist_ok=True)
+
+    for i, patch in enumerate(patches):
+        # Create unique filename
+        filename = f"{entry_id}_patch_{i}.pkl"
+        filepath = os.path.join(dataset_dir, filename)
+
+        # Skip if file already exists
+        if os.path.exists(filepath):
+            continue
+
+        # Save patch data
+        with open(filepath, 'wb') as f:
+            pickle.dump(patch, f)
+
+    print(f"Saved {len(patches)} patches for entry {entry_id}")
+
+# Custom collate function for the dataloader: filters out samples with no valid points
+def collate_fn(batch):
+    valid_batch = []
+    for patch_data, label, valid_mask in batch:
+        # Filter out invalid samples (no valid points)
+        if valid_mask.sum() > 0:
+            valid_batch.append((patch_data, label, valid_mask))
+
+    if len(valid_batch) == 0:
+        return None
+
+    # Stack valid samples
+    patch_data = torch.stack([item[0] for item in valid_batch])
+    labels = torch.stack([item[1] for item in valid_batch])
+    valid_masks = torch.stack([item[2] for item in valid_batch])
+
+    return patch_data, labels, valid_masks
+
+# Initialize weights using Xavier/Glorot initialization
+def init_weights(m):
+    if isinstance(m, nn.Conv1d):
+        nn.init.xavier_uniform_(m.weight)
+        if m.bias is not None:
+            nn.init.zeros_(m.bias)
+    elif isinstance(m, nn.Linear):
+        nn.init.xavier_uniform_(m.weight)
+        if m.bias is not None:
+            nn.init.zeros_(m.bias)
+    elif isinstance(m, nn.BatchNorm1d):
+        nn.init.ones_(m.weight)
+        nn.init.zeros_(m.bias)
+
+def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32,
+                   lr: float = 0.001):
+    """
+    Train the ClassificationPointNet model on saved patches.
+
+    Args:
+        dataset_dir: Directory containing saved patch files
+        model_save_path: Path to save the trained model
+        epochs: Number of training epochs
+        batch_size: Training batch size
+        lr: Learning rate
+    """
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    print(f"Training on device: {device}")
+
+    # Create dataset and dataloader
+    dataset = PatchClassificationDataset(dataset_dir, max_points=1024, augment=True)
+    print(f"Dataset loaded with {len(dataset)} samples")
+
+    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8,
+                            collate_fn=collate_fn, drop_last=True)
+
+    # Initialize model
+    model = ClassificationPointNet(input_dim=6, max_points=1024)
+    model.apply(init_weights)
+    model.to(device)
+
+    # Loss function and optimizer (BCE for binary classification)
+    criterion = nn.BCEWithLogitsLoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
+    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
+
+    # Training loop
+    model.train()
+    for epoch in range(epochs):
+        total_loss = 0.0
+        correct = 0
+        total = 0
+        num_batches = 0
+
+        for batch_idx, batch_data in enumerate(dataloader):
+            if batch_data is None:  # Skip invalid batches
+                continue
+
+            patch_data, labels, valid_masks = batch_data
+            patch_data = patch_data.to(device)  # (batch_size, 6, max_points)
+            labels = labels.to(device).unsqueeze(1)  # (batch_size, 1)
+
+            # Forward pass
+            optimizer.zero_grad()
+            outputs = model(patch_data)  # (batch_size, 1)
+            loss = criterion(outputs, labels)
+
+            # Backward pass
+            loss.backward()
+            optimizer.step()
+
+            # Statistics
+            total_loss += loss.item()
+            predicted = (torch.sigmoid(outputs) > 0.5).float()
+            total += labels.size(0)
+            correct += (predicted == labels).sum().item()
+            num_batches += 1
+
+            if batch_idx % 50 == 0:
+                print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
+                      f"Loss: {loss.item():.6f}, "
+                      f"Accuracy: {100 * correct / total:.2f}%")
+
+        avg_loss = total_loss / num_batches if num_batches > 0 else 0
+        accuracy = 100 * correct / total if total > 0 else 0
+
+        print(f"Epoch {epoch+1}/{epochs} completed, "
+              f"Avg Loss: {avg_loss:.6f}, "
+              f"Accuracy: {accuracy:.2f}%")
+
+        scheduler.step()
+
+        # Save model checkpoint every 10 epochs
+        if (epoch + 1) % 10 == 0:
+            checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
+            torch.save({
+                'model_state_dict': model.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+                'epoch': epoch + 1,
+                'loss': avg_loss,
+                'accuracy': accuracy,
+            }, checkpoint_path)
+
+    # Save the trained model
+    torch.save({
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'epoch': epochs,
+    }, model_save_path)
+
+    print(f"Model saved to {model_save_path}")
+    return model
+
+def load_pointnet_model(model_path: str, device: torch.device = None) -> ClassificationPointNet:
+    """
+    Load a trained ClassificationPointNet model.
+
+    Args:
+        model_path: Path to the saved model
+        device: Device to load the model on
+
+    Returns:
+        Loaded ClassificationPointNet model
+    """
+    if device is None:
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    model = ClassificationPointNet(input_dim=6, max_points=1024)
+
+    checkpoint = torch.load(model_path, map_location=device)
+    model.load_state_dict(checkpoint['model_state_dict'])
+
+    model.to(device)
+    model.eval()
+
+    return model
+
+def predict_class_from_patch(model: ClassificationPointNet, patch: Dict, device: torch.device = None) -> Tuple[int, float]:
+    """
+    Predict binary classification from a patch using trained PointNet.
+
+    Args:
+        model: Trained ClassificationPointNet model
+        patch: Dictionary containing patch data with 'patch_6d' key
+        device: Device to run prediction on
+
+    Returns:
+        tuple of (predicted_class, confidence)
+        predicted_class: int (0 for not edge, 1 for edge)
+        confidence: float representing confidence score (0-1)
+    """
+    if device is None:
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    patch_6d = patch['patch_6d']  # (N, 6)
+
+    # Prepare input
+    max_points = 1024
+    num_points = patch_6d.shape[0]
+
+    if num_points >= max_points:
+        # Sample points
+        indices = np.random.choice(num_points, max_points, replace=False)
+        patch_sampled = patch_6d[indices]
+    else:
+        # Pad with zeros
+        patch_sampled = np.zeros((max_points, 6))
+        patch_sampled[:num_points] = patch_6d
+
+    # Convert to tensor
+    patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0)  # (1, 6, max_points)
+    patch_tensor = patch_tensor.to(device)
+
+    # Predict
+    with torch.no_grad():
+        outputs = model(patch_tensor)  # (1, 1)
+        probability = torch.sigmoid(outputs).item()
+        predicted_class = int(probability > 0.5)
+        confidence = probability if predicted_class == 1 else (1 - probability)
+
+    return predicted_class, confidence
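
The training data consumed by train_pointnet is just a directory of pickled patch dictionaries; here is a short sketch of how such a dataset might be written with the new save_patches_dataset helper (the toy patches and labels below are fabricated for illustration; in the pipeline the dictionaries come from patch generation and carry 'patch_6d' plus an optional binary 'label'):

import numpy as np
from fast_pointnet_class import save_patches_dataset

# Two toy patches: one marked as an edge (label 1), one as background (label 0)
patches = [
    {"patch_6d": np.random.rand(800, 6), "label": 1},
    {"patch_6d": np.random.rand(300, 6), "label": 0},
]

# Writes <entry_id>_patch_<i>.pkl files into the dataset directory, skipping existing ones
save_patches_dataset(patches, dataset_dir="patch_dataset", entry_id="entry_000")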
fast_voxel.py CHANGED
@@ -531,7 +531,7 @@ def load_3dcnn_model(model_path: str, device: torch.device = None, voxel_size: i
 
     return model
 
-def predict_vertex_from_patch(model: Fast3DCNN, patch: np.ndarray, device: torch.device = None, voxel_size: int = 32) -> Tuple[np.ndarray, float, float]:
+def predict_vertex_from_patch_voxel(model: Fast3DCNN, patch: np.ndarray, device: torch.device = None, voxel_size: int = 32) -> Tuple[np.ndarray, float, float]:
     """
     Predict 3D vertex coordinates, confidence score, and classification from a patch using trained 3D CNN.
 
predict.py CHANGED
@@ -393,18 +393,18 @@ def visu_patch_and_pred(patch, pred, pred_dist, pred_class):
     plotter = pv.Plotter()
 
     # Create point cloud for this patch
-    offset = patch.get('offset', None)  # Offset if available
+    offset = patch.get('cluster_center', None)  # Offset if available
     patch_points_3d = np.array(patch['patch_7d'][:, :3])
-    patch_points_3d = patch_points_3d - offset
+    patch_points_3d = patch_points_3d + offset
     patch_cloud = pv.PolyData(patch_points_3d)
 
-    point_idxs = patch['filtered_point_ids']  # List of point indices that are filtered
-    patch_point_ids = patch['point_ids']  # Assuming the 7th column contains point IDs
-    assigned_gt_vertex = patch.get('assigned_gt_vertex', None)  # GT vertex if available
-    initial_pred = patch.get('initial_pred', None)  # Initial prediction if available
-    initial_pred = initial_pred - offset
+    point_idxs = patch['cluster_point_ids']  # List of point indices that are filtered
+    patch_point_ids = patch['cube_point_ids']  # Assuming the 7th column contains point IDs
+    assigned_gt_vertex = patch.get('assigned_wf_vertex', None)  # GT vertex if available
+    initial_pred = None
 
-    assigned_gt_vertex = assigned_gt_vertex - offset
+    if assigned_gt_vertex is not None:
+        assigned_gt_vertex = assigned_gt_vertex + offset
 
     # Color points: red for filtered points, blue for other points
     patch_point_colors = []
@@ -436,17 +436,34 @@ def visu_patch_and_pred(patch, pred, pred_dist, pred_class):
     title_text = f"Patch x\nPred dist: {pred_dist:.4f}\nPred class: {pred_class}"
     plotter.show(title=title_text)
 
-def extract_vertices_from_whole_pcloud(colmap_rec, idxs_points):
+def extract_vertices_from_whole_pcloud(colmap_rec, idxs_points, all_connections):
     # Filter COLMAP points and colors based on idxs_points
     filtered_colmap_points = []
     filtered_colmap_colors = []
    filtered_colmap_ids = []
     all_filtered_ids_list = []
     all_extracted_groups = []
+    all_flattened_connections = []
+    group_to_flattened_mapping = {}  # Maps (group_idx, local_vertex_idx) to flattened_idx
 
+    # Flatten all groups and create mapping for connections
+    flattened_idx = 0
     for group_idx, point_ids_group in enumerate(idxs_points):
-        for point_ids in point_ids_group:
+        cur_connections = all_connections[group_idx]
+        group_to_flattened_mapping[group_idx] = {}
+
+        for local_idx, point_ids in enumerate(point_ids_group):
             all_extracted_groups.append(point_ids)
+            group_to_flattened_mapping[group_idx][local_idx] = flattened_idx
+            flattened_idx += 1
+
+        # Convert connections to flattened indices
+        for conn in cur_connections:
+            start_idx, end_idx = conn
+            if start_idx in group_to_flattened_mapping[group_idx] and end_idx in group_to_flattened_mapping[group_idx]:
+                flattened_start = group_to_flattened_mapping[group_idx][start_idx]
+                flattened_end = group_to_flattened_mapping[group_idx][end_idx]
+                all_flattened_connections.append((flattened_start, flattened_end))
 
     # Collect all filtered point IDs from all images
     for group_idxs in idxs_points:
@@ -523,13 +540,15 @@ def extract_vertices_from_whole_pcloud(colmap_rec, idxs_points):
         extracted_colors.append(np.array(group_extracted_colors) if group_extracted_colors else np.empty((0, 3)))
         extracted_ids.append(np.array(group_extracted_ids) if group_extracted_ids else np.empty((0,)))
 
-
     # Filter extracted_points to merge groups that share more than 50% of their points
+    # and update connections accordingly
+    updated_connections = []
     if extracted_points:
         print(f"Merging groups based on point overlap... Processing {len(extracted_points)} groups")
         # Create a list to track which groups to keep
         groups_to_keep = []
         merged_groups = set()  # Track which groups have been merged
+        old_to_new_mapping = {}  # Maps old flattened index to new index
 
         for i, (points_i, colors_i, ids_i) in enumerate(zip(extracted_points, extracted_colors, extracted_ids)):
             if i in merged_groups or len(ids_i) == 0:
@@ -539,6 +558,7 @@ def extract_vertices_from_whole_pcloud(colmap_rec, idxs_points):
             merged_points = points_i.copy()
             merged_colors = colors_i.copy()
             merged_ids = set(ids_i)
+            merged_indices = [i]  # Track which original indices are merged
 
             # Check all subsequent groups for overlap
             for j in range(i + 1, len(extracted_points)):