jskvrna commited on
Commit
9b4ba1f
·
1 Parent(s): a70b55e

Implements FastPointNet for vertex prediction

Browse files

Adds a FastPointNet model for predicting 3D vertex coordinates from point cloud patches.
Includes a dataset class for loading and augmenting patch data.
Also adds training and prediction functions for the model.

Patches are generated and saved for training the PointNet model,
allowing for iterative refinement of vertex predictions.
The patch generation process includes filtering COLMAP points within a
ball around identified vertices and creating a 7D point cloud
representation. GT vertex assignment and data augmentation are incorporated
to improve the training data.

Files changed (3) hide show
  1. fast_pointnet.py +421 -0
  2. predict.py +175 -18
  3. train.py +1 -1
fast_pointnet.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ import pickle
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from typing import List, Dict, Tuple, Optional
9
+ import json
10
+
11
class FastPointNet(nn.Module):
    """
    Lightweight PointNet-style regressor for patch-level vertex localization.

    Consumes a 7-channel point cloud (xyz position, rgb color, and a binary
    "was in the filtered set" flag) and regresses a 3D vertex position.
    Optionally a second head regresses the expected distance to ground truth,
    which serves as a confidence proxy (lower = better).
    """

    def __init__(self, input_dim=7, output_dim=3, max_points=1024, predict_score=True):
        super(FastPointNet, self).__init__()
        self.max_points = max_points
        self.predict_score = predict_score

        # Shared per-point encoder (1x1 convolutions act as point-wise MLPs).
        self.conv1 = nn.Conv1d(input_dim, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 256, 1)

        # Deeper point-wise layers feeding the global descriptor.
        self.conv4 = nn.Conv1d(256, 512, 1)
        self.conv5 = nn.Conv1d(512, 1024, 1)

        # Trunk shared by both prediction heads.
        self.shared_fc = nn.Linear(1024, 512)

        # Head 1: 3D position regression.
        self.pos_fc1 = nn.Linear(512, 256)
        self.pos_fc2 = nn.Linear(256, output_dim)

        # Head 2 (optional): predicted distance to the GT vertex.
        if self.predict_score:
            self.score_fc1 = nn.Linear(512, 256)
            self.score_fc2 = nn.Linear(256, 128)
            self.score_fc3 = nn.Linear(128, 1)  # single scalar per sample

        self.dropout = nn.Dropout(0.3)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(256)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(1024)

    def forward(self, x):
        """
        Run the network on a batch of fixed-size point clouds.

        Args:
            x: (batch, input_dim, max_points) float tensor.

        Returns:
            When predict_score is True: tuple (position, score) with position
            of shape (batch, output_dim) and score of shape (batch, 1).
            Otherwise: just position.
        """
        # Per-point feature lifting, input_dim -> 1024 channels.
        for conv, bn in ((self.conv1, self.bn1),
                         (self.conv2, self.bn2),
                         (self.conv3, self.bn3),
                         (self.conv4, self.bn4),
                         (self.conv5, self.bn5)):
            x = F.relu(bn(conv(x)))

        # Symmetric max-pool over the point axis -> order-invariant global feature.
        global_feat = x.max(dim=2).values  # (batch, 1024)

        trunk = self.dropout(F.relu(self.shared_fc(global_feat)))

        # Position head.
        pos_hidden = self.dropout(F.relu(self.pos_fc1(trunk)))
        position = self.pos_fc2(pos_hidden)

        if not self.predict_score:
            return position

        # Score head; final ReLU keeps the predicted distance non-negative.
        s = self.dropout(F.relu(self.score_fc1(trunk)))
        s = self.dropout(F.relu(self.score_fc2(s)))
        score = F.relu(self.score_fc3(s))
        return position, score
94
+
95
class PatchDataset(Dataset):
    """
    Dataset of saved point-cloud patches (one pickle file per patch) for
    training FastPointNet.

    Each pickle holds a dict with at least:
        'patch_7d':           (N, 7) array [x, y, z, r, g, b, filtered_flag]
        'assigned_gt_vertex': (3,) GT vertex in patch coordinates, or None
        'initial_pred':       (3,) initial vertex estimate, or None (optional)
    """

    def __init__(self, dataset_dir: str, max_points: int = 1024, augment: bool = True):
        self.dataset_dir = dataset_dir
        self.max_points = max_points
        self.augment = augment

        # Index every pickle file in the dataset directory.
        self.patch_files = [os.path.join(dataset_dir, name)
                            for name in os.listdir(dataset_dir)
                            if name.endswith('.pkl')]

        print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")

    def __len__(self):
        return len(self.patch_files)

    def __getitem__(self, idx):
        """
        Load and preprocess one patch for training.

        Returns:
            patch_tensor: (7, max_points) float tensor (channels first)
            target_tensor: (3,) float tensor of GT vertex coordinates
            valid_mask_tensor: (max_points,) bool tensor (True = real point)
            initial_pred_tensor: (3,) float tensor of the initial vertex
                estimate (zeros when the patch has none saved)

        Patches without an assigned GT vertex yield all-zero tensors with an
        all-False valid mask so callers can filter them out.
        """
        with open(self.patch_files[idx], 'rb') as f:
            patch_info = pickle.load(f)

        patch_7d = patch_info['patch_7d']                      # (N, 7)
        target = patch_info['assigned_gt_vertex']              # (3,) or None
        initial_pred = patch_info.get('initial_pred', None)    # (3,) or None

        if target is None:
            # Dummy sample, identifiable by its empty valid mask. All four
            # elements keep the same shapes as valid samples so default
            # DataLoader collation never breaks (the original returned a
            # (1,) fourth element here but (3,) for valid samples).
            return (torch.zeros(7, self.max_points),
                    torch.zeros(3),
                    torch.zeros(self.max_points, dtype=torch.bool),
                    torch.zeros(3))

        target = np.asarray(target, dtype=np.float64)
        patch_7d = np.asarray(patch_7d, dtype=np.float64).copy()

        # Remap colors from [0, 1] to [-1, 1].
        patch_7d[:, 3:6] = patch_7d[:, 3:6] * 2.0 - 1.0

        # Sample down or zero-pad up to a fixed point count.
        num_points = patch_7d.shape[0]
        if num_points >= self.max_points:
            indices = np.random.choice(num_points, self.max_points, replace=False)
            patch_sampled = patch_7d[indices]
            valid_mask = np.ones(self.max_points, dtype=bool)
        else:
            patch_sampled = np.zeros((self.max_points, 7))
            patch_sampled[:num_points] = patch_7d
            valid_mask = np.zeros(self.max_points, dtype=bool)
            valid_mask[:num_points] = True

        # Fall back to a zero estimate when no initial prediction was saved
        # (the original crashed on torch.from_numpy(None)).
        if initial_pred is None:
            initial_pred = np.zeros(3)
        else:
            initial_pred = np.asarray(initial_pred, dtype=np.float64)

        # Apply one shared random transform to points, target and the initial
        # prediction so their relative geometry is preserved. (The original
        # called _augment_patch with the wrong arity and a nonexistent
        # _augment_target, raising TypeError whenever augment=True.)
        if self.augment:
            patch_sampled, target, initial_pred = self._augment_patch(
                patch_sampled, valid_mask, target, initial_pred)

        # Channels-first layout for Conv1d.
        return (torch.from_numpy(patch_sampled.T).float(),   # (7, max_points)
                torch.from_numpy(target).float(),            # (3,)
                torch.from_numpy(valid_mask),
                torch.from_numpy(initial_pred).float())

    def _augment_patch(self, patch: np.ndarray, valid_mask: np.ndarray,
                       target: np.ndarray, initial_pred: np.ndarray):
        """Rotate about Z and jitter; returns (patch, target, initial_pred)."""
        valid_points = patch[valid_mask]
        if len(valid_points) == 0:
            return patch, target, initial_pred

        # Random rotation around the Z-axis, shared by all geometry.
        angle = np.random.uniform(0, 2 * np.pi)
        cos_a, sin_a = np.cos(angle), np.sin(angle)
        rotation = np.array([
            [cos_a, -sin_a, 0.0],
            [sin_a, cos_a, 0.0],
            [0.0, 0.0, 1.0],
        ])

        valid_points[:, :3] = valid_points[:, :3] @ rotation.T
        target = target @ rotation.T
        initial_pred = initial_pred @ rotation.T

        # Small positional jitter on real points only.
        valid_points[:, :3] += np.random.normal(0, 0.01, valid_points[:, :3].shape)

        # Color jitter; clip to [-1, 1] to match the normalization above
        # (the original clipped to [0, 1], discarding the negative half).
        color_noise = np.random.normal(0, 0.02, valid_points[:, 3:6].shape)
        valid_points[:, 3:6] = np.clip(valid_points[:, 3:6] + color_noise, -1, 1)

        patch[valid_mask] = valid_points
        return patch, target, initial_pred
212
+
213
def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
    """
    Persist patch dictionaries to disk as one pickle file per patch.

    Args:
        patches: Patch dictionaries produced by generate_patches().
        dataset_dir: Output directory for the dataset (created if missing).
        entry_id: Unique identifier for this entry/image, used in filenames.
    """
    os.makedirs(dataset_dir, exist_ok=True)

    for index, patch in enumerate(patches):
        # One file per patch: "<entry>_patch_<i>.pkl".
        out_path = os.path.join(dataset_dir, f"{entry_id}_patch_{index}.pkl")
        with open(out_path, 'wb') as fh:
            pickle.dump(patch, fh)

    print(f"Saved {len(patches)} patches for entry {entry_id}")
234
+
235
def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32, lr: float = 0.001,
                   score_weight: float = 0.1):
    """
    Train FastPointNet on a directory of saved patches.

    Args:
        dataset_dir: Directory containing saved patch .pkl files.
        model_save_path: Path for the final model checkpoint (.pth).
        epochs: Number of training epochs.
        batch_size: Training batch size.
        lr: Adam learning rate.
        score_weight: Weight of the distance-prediction loss term.

    Returns:
        The trained FastPointNet model.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on device: {device}")

    dataset = PatchDataset(dataset_dir, max_points=1024, augment=True)

    # Pre-scan for usable samples. Augmentation is temporarily disabled so the
    # scan is deterministic and does not waste RNG draws (the original scanned
    # with augmentation on). Samples with an empty mask or an all-zero target
    # are the dummy placeholders PatchDataset emits for patches without GT.
    dataset.augment = False
    valid_indices = []
    for i in range(len(dataset)):
        _, target, valid_mask, _ = dataset[i]
        if valid_mask.sum() > 0 and not torch.all(target == 0):
            valid_indices.append(i)
    dataset.augment = True

    print(f"Found {len(valid_indices)} valid patches out of {len(dataset)}")

    valid_dataset = torch.utils.data.Subset(dataset, valid_indices)
    dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    model = FastPointNet(input_dim=7, output_dim=3, max_points=1024, predict_score=True)
    model.to(device)

    # Separate criteria for the two heads.
    position_criterion = nn.MSELoss()
    score_criterion = nn.MSELoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        total_pos_loss = 0.0
        total_score_loss = 0.0
        num_batches = 0

        for batch_idx, (patch_data, targets, valid_masks, initial_preds) in enumerate(dataloader):
            patch_data = patch_data.to(device)  # (batch_size, 7, max_points)
            targets = targets.to(device)        # (batch_size, 3)
            # NOTE(review): initial_preds is loaded but unused in the loss;
            # the score head is supervised by the live prediction error below.

            optimizer.zero_grad()
            predictions, predicted_scores = model(patch_data)

            # Supervise the score head with the model's own position error so
            # it learns to predict its distance to GT.
            actual_distances = torch.norm(predictions - targets, dim=1, keepdim=True)

            pos_loss = position_criterion(predictions, targets)
            score_loss = score_criterion(predicted_scores, actual_distances)
            total_batch_loss = pos_loss + score_weight * score_loss

            total_batch_loss.backward()
            optimizer.step()

            total_loss += total_batch_loss.item()
            total_pos_loss += pos_loss.item()
            total_score_loss += score_loss.item()
            num_batches += 1

            if batch_idx % 50 == 0:
                print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
                      f"Total Loss: {total_batch_loss.item():.6f}, "
                      f"Pos Loss: {pos_loss.item():.6f}, "
                      f"Score Loss: {score_loss.item():.6f}")

        avg_loss = total_loss / num_batches if num_batches > 0 else 0
        avg_pos_loss = total_pos_loss / num_batches if num_batches > 0 else 0
        avg_score_loss = total_score_loss / num_batches if num_batches > 0 else 0

        print(f"Epoch {epoch+1}/{epochs} completed, "
              f"Avg Total Loss: {avg_loss:.6f}, "
              f"Avg Pos Loss: {avg_pos_loss:.6f}, "
              f"Avg Score Loss: {avg_score_loss:.6f}")

        scheduler.step()

        # Per-epoch checkpoint next to the final save path.
        checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch + 1,
            'loss': avg_loss,
        }, checkpoint_path)

    # Final model checkpoint.
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epochs,
    }, model_save_path)

    print(f"Model saved to {model_save_path}")
    return model
349
+
350
def load_pointnet_model(model_path: str, device: torch.device = None, predict_score: bool = True) -> FastPointNet:
    """
    Restore a trained FastPointNet from a checkpoint file.

    Args:
        model_path: Checkpoint path (as written by train_pointnet).
        device: Target device; auto-selects CUDA when available if None.
        predict_score: Whether the checkpointed model has the score head.

    Returns:
        The model in eval mode, moved to the requested device.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    net = FastPointNet(input_dim=7, output_dim=3, max_points=1024, predict_score=predict_score)

    # map_location keeps CUDA-saved checkpoints loadable on CPU-only hosts.
    checkpoint = torch.load(model_path, map_location=device)
    net.load_state_dict(checkpoint['model_state_dict'])

    net.to(device)
    net.eval()
    return net
374
+
375
def predict_vertex_from_patch(model: "FastPointNet", patch_7d: np.ndarray, device: torch.device = None) -> Tuple[np.ndarray, float]:
    """
    Predict 3D vertex coordinates (and confidence) for one patch.

    Args:
        model: Trained FastPointNet model.
        patch_7d: (N, 7) point cloud [x, y, z, r, g, b, filtered_flag].
        device: Device to run on; auto-selects CUDA when available if None.

    Returns:
        Tuple (predicted_coordinates, confidence_score):
            predicted_coordinates: (3,) numpy array of 3D coordinates.
            confidence_score: float — the model's predicted distance to GT
                (lower is better), or None when the model has no score head.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.eval()

    # Resample/pad to the fixed input size expected by the network.
    max_points = 1024
    num_points = patch_7d.shape[0]

    if num_points >= max_points:
        # Randomly subsample down to max_points.
        indices = np.random.choice(num_points, max_points, replace=False)
        patch_sampled = patch_7d[indices]
    else:
        # Zero-pad up to max_points.
        patch_sampled = np.zeros((max_points, 7))
        patch_sampled[:num_points] = patch_7d

    # (1, 7, max_points): channels-first with a batch dimension.
    patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0)
    patch_tensor = patch_tensor.to(device)

    with torch.no_grad():
        if model.predict_score:
            position, score = model(patch_tensor)
            position = position.cpu().numpy().squeeze()
            # float(...) so the score matches the documented return type
            # (the original returned a 0-d numpy array, not a float).
            score = float(score.cpu().numpy().squeeze())
            return position, score
        else:
            position = model(patch_tensor)
            position = position.cpu().numpy().squeeze()
            return position, None
predict.py CHANGED
@@ -10,6 +10,11 @@ from PIL import Image as PImage
10
  import cv2
11
  import open3d as o3d
12
  from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local
 
 
 
 
 
13
 
14
  def convert_entry_to_human_readable(entry):
15
  out = {}
@@ -389,11 +394,6 @@ def predict_wireframe(entry) -> Tuple[np.ndarray, List[int]]:
389
  good_entry = convert_entry_to_human_readable(entry)
390
  colmap_rec = good_entry['colmap_binary']
391
 
392
- colmap_pcloud = []
393
- for i, p3D in colmap_rec.points3D.items():
394
- p3D.color = np.array([0, 0, 0])
395
- colmap_pcloud.append(p3D)
396
-
397
  vert_edge_per_image = {}
398
  for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(zip(good_entry['gestalt'],
399
  good_entry['depth'],
@@ -413,7 +413,13 @@ def predict_wireframe(entry) -> Tuple[np.ndarray, List[int]]:
413
  gest_seg = gest.resize(depth_size)
414
  gest_seg_np = np.array(gest_seg).astype(np.uint8)
415
 
416
- vertices_ours, connections_ours, vertices_3d_ours = our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id, ade_seg, depth, K=K, R=R, t=t)
 
 
 
 
 
 
417
  vertices, connections, vertices_3d = vertices_ours, connections_ours, vertices_3d_ours
418
  # Get 2D vertices and edges first
419
  #vertices, connections = get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th=25.)
@@ -444,7 +450,7 @@ def predict_wireframe(entry) -> Tuple[np.ndarray, List[int]]:
444
  continue
445
 
446
  # Call the refactored function to get 3D points
447
- #vertices_3d = create_3d_wireframe_single_image(vertices, connections, depth, colmap_rec, img_id, ade_seg, K, R, t)
448
  #vertices_3d = gt_verts3d
449
  # Store original 2D vertices, connections, and computed 3D points
450
 
@@ -480,6 +486,9 @@ def predict_wireframe(entry) -> Tuple[np.ndarray, List[int]]:
480
  # Visualize the point cloud
481
  o3d.visualization.draw_geometries([pcd], window_name="COLMAP Point Cloud")
482
  '''
 
 
 
483
  # Merge vertices from all images
484
  all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 0.5)
485
  all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d, keep_largest=False)
@@ -627,9 +636,9 @@ def get_apex_or_eave_points(apex, uv, gest_seg_np, house_mask, valid_indices, po
627
  final_valid_indices = valid_point_indices[depth_filter]
628
 
629
  # Add corresponding points to filtered lists
630
- filtered_points_xyz.extend(points_xyz_world[final_valid_indices])
631
- filtered_point_idxs.extend(points_idxs[final_valid_indices])
632
- filtered_points_color.extend([color] * np.sum(depth_filter))
633
 
634
  # Find the point with lowest depth in the filtered points
635
  if len(final_valid_indices) > 0:
@@ -637,9 +646,6 @@ def get_apex_or_eave_points(apex, uv, gest_seg_np, house_mask, valid_indices, po
637
  lowest_depth_point = final_valid_indices[lowest_depth_idx]
638
 
639
  filtered_vertices_apex.append(points_xyz_world[lowest_depth_point])
640
- filtered_points_xyz.append(points_xyz_world[lowest_depth_point])
641
- filtered_point_idxs.append(points_idxs[lowest_depth_point])
642
- filtered_points_color.append(np.array([1., 1., 0.]))
643
  filtered_vertices_apex_uv.append(centroids[i])
644
 
645
  return filtered_points_xyz, filtered_point_idxs, filtered_points_color, filtered_vertices_apex, filtered_vertices_apex_uv
@@ -653,9 +659,9 @@ def get_vertexes(uv, gest_seg_np, house_mask, valid_indices, points_xyz_world, p
653
  filtered_point_idxs = filtered_point_idxs_apex + filtered_point_idxs_eave
654
  filtered_points_color = filtered_points_color_apex + filtered_points_color_eave
655
 
656
- filtered_points_xyz = np.array(filtered_points_xyz[::-1]) if filtered_points_xyz else np.empty((0, 3))
657
- filtered_point_idxs = np.array(filtered_point_idxs[::-1]) if filtered_point_idxs else np.empty((0,))
658
- filtered_points_color = np.array(filtered_points_color[::-1]) if filtered_points_color else np.empty((0, 3))
659
  filtered_vertices_apex = np.array(filtered_vertices_apex) if filtered_vertices_apex else np.empty((0, 3))
660
  filtered_vertices_apex_uv = np.array(filtered_vertices_apex_uv) if filtered_vertices_apex_uv else np.empty((0, 2))
661
  filtered_vertices_eave = np.array(filtered_vertices_eave) if filtered_vertices_eave else np.empty((0, 3))
@@ -803,7 +809,156 @@ def visualize_3d_wireframe(colmap_rec, filtered_points_xyz, filtered_points_colo
803
 
804
  #o3d.visualization.draw_geometries(geometries, window_name=f"Combined Point Cloud - {img_id_substring}")
805
 
806
- def our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id_substring, ade_seg, depth, K=None, R=None, t=None, ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
807
  """
808
  Identify apex and eave-end vertices, then detect lines for eave/ridge/rake/valley.
809
  Also find all COLMAP points that project into apex or eave_end masks.
@@ -838,6 +993,8 @@ def our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id_substring, ade_se
838
 
839
  vertices_formatted, connections, all_vertices_3d = get_connections(gest_seg_np, filtered_vertices_apex, filtered_vertices_eave, filtered_vertices_apex_uv, filtered_vertices_eave_uv)
840
 
 
 
841
  #visualize_3d_wireframe(colmap_rec, filtered_points_xyz, filtered_points_color, all_vertices_3d, connections)
842
 
843
- return vertices_formatted, connections, all_vertices_3d
 
10
  import cv2
11
  import open3d as o3d
12
  from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local
13
+ import pyvista as pv
14
+ from fast_pointnet import save_patches_dataset
15
+
16
+ GENERATE_DATASET = True
17
+ DATASET_DIR = '/home/skvrnjan/personal/hohocustom/'
18
 
19
  def convert_entry_to_human_readable(entry):
20
  out = {}
 
394
  good_entry = convert_entry_to_human_readable(entry)
395
  colmap_rec = good_entry['colmap_binary']
396
 
 
 
 
 
 
397
  vert_edge_per_image = {}
398
  for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(zip(good_entry['gestalt'],
399
  good_entry['depth'],
 
413
  gest_seg = gest.resize(depth_size)
414
  gest_seg_np = np.array(gest_seg).astype(np.uint8)
415
 
416
+ vertices_ours, connections_ours, vertices_3d_ours, patches = our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id, ade_seg, depth, K=K, R=R, t=t, frame=good_entry)
417
+
418
+ if GENERATE_DATASET:
419
+ save_patches_dataset(patches, DATASET_DIR, img_id)
420
+
421
+ continue
422
+
423
  vertices, connections, vertices_3d = vertices_ours, connections_ours, vertices_3d_ours
424
  # Get 2D vertices and edges first
425
  #vertices, connections = get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th=25.)
 
450
  continue
451
 
452
  # Call the refactored function to get 3D points
453
+ vertices_3d = create_3d_wireframe_single_image(vertices, connections, depth, colmap_rec, img_id, ade_seg, K, R, t)
454
  #vertices_3d = gt_verts3d
455
  # Store original 2D vertices, connections, and computed 3D points
456
 
 
486
  # Visualize the point cloud
487
  o3d.visualization.draw_geometries([pcd], window_name="COLMAP Point Cloud")
488
  '''
489
+ if GENERATE_DATASET:
490
+ return empty_solution()
491
+
492
  # Merge vertices from all images
493
  all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 0.5)
494
  all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d, keep_largest=False)
 
636
  final_valid_indices = valid_point_indices[depth_filter]
637
 
638
  # Add corresponding points to filtered lists
639
+ filtered_points_xyz.append(points_xyz_world[final_valid_indices])
640
+ filtered_point_idxs.append(points_idxs[final_valid_indices])
641
+ filtered_points_color.append([color] * np.sum(depth_filter))
642
 
643
  # Find the point with lowest depth in the filtered points
644
  if len(final_valid_indices) > 0:
 
646
  lowest_depth_point = final_valid_indices[lowest_depth_idx]
647
 
648
  filtered_vertices_apex.append(points_xyz_world[lowest_depth_point])
 
 
 
649
  filtered_vertices_apex_uv.append(centroids[i])
650
 
651
  return filtered_points_xyz, filtered_point_idxs, filtered_points_color, filtered_vertices_apex, filtered_vertices_apex_uv
 
659
  filtered_point_idxs = filtered_point_idxs_apex + filtered_point_idxs_eave
660
  filtered_points_color = filtered_points_color_apex + filtered_points_color_eave
661
 
662
+ #filtered_points_xyz = np.array(filtered_points_xyz[::-1]) if filtered_points_xyz else np.empty((0, 3))
663
+ #filtered_point_idxs = np.array(filtered_point_idxs[::-1]) if filtered_point_idxs else np.empty((0,))
664
+ #filtered_points_color = np.array(filtered_points_color[::-1]) if filtered_points_color else np.empty((0, 3))
665
  filtered_vertices_apex = np.array(filtered_vertices_apex) if filtered_vertices_apex else np.empty((0, 3))
666
  filtered_vertices_apex_uv = np.array(filtered_vertices_apex_uv) if filtered_vertices_apex_uv else np.empty((0, 2))
667
  filtered_vertices_eave = np.array(filtered_vertices_eave) if filtered_vertices_eave else np.empty((0, 3))
 
809
 
810
  #o3d.visualization.draw_geometries(geometries, window_name=f"Combined Point Cloud - {img_id_substring}")
811
 
812
def generate_patches(colmap_rec, filtered_points_idxs, frame, filtered_vertices):
    """
    Build 7D point-cloud training patches around groups of filtered COLMAP points.

    For each group, the GT wireframe vertex closest to the group centroid is
    assigned as the regression target when it lies within a 2 m ball; all
    COLMAP points inside that ball are collected, centered at their mean, and
    tagged with a flag marking membership in the original filtered set.

    Args:
        colmap_rec: COLMAP reconstruction exposing points3D as an {id: point}
            mapping where each point has .xyz and .color attributes.
        filtered_points_idxs: Iterable of per-group COLMAP point-id arrays.
        frame: Entry dict; 'wf_vertices' holds the GT wireframe vertices.
        filtered_vertices: Per-group initial 3D vertex estimates (or None).

    Returns:
        List of patch dicts with keys: patch_7d, centroid, radius, point_ids,
        filtered_point_ids, group_idx, assigned_gt_vertex, offset, initial_pred.
    """
    patches = []
    gt_vertices = frame['wf_vertices']
    ball_radius = 2.0  # meters; assignment and gathering radius around the centroid

    for group_idx, point_idxs in enumerate(filtered_points_idxs):
        if len(point_idxs) == 0:
            continue

        # 3D coordinates of this group's filtered points.
        group_points_3d = np.array([colmap_rec.points3D[pid].xyz for pid in point_idxs])
        if len(group_points_3d) == 0:
            continue

        # Assign the GT vertex closest to the group centroid, if close enough.
        # NOTE(review): the centroid itself is NOT moved to the GT vertex —
        # the ball below stays centered on the group mean.
        centroid = np.mean(group_points_3d, axis=0)
        assigned_gt_vertex = None
        if len(gt_vertices) > 0:
            distances_to_gt = np.linalg.norm(np.asarray(gt_vertices) - centroid, axis=1)
            min_distance_idx = int(np.argmin(distances_to_gt))
            if distances_to_gt[min_distance_idx] <= ball_radius:
                assigned_gt_vertex = gt_vertices[min_distance_idx]
            else:
                # No GT vertex near this group: not a usable training sample.
                continue

        # Gather every COLMAP point inside the ball around the centroid.
        patch_points_3d = []
        patch_colors = []
        patch_point_ids = []
        for pid, p3d in colmap_rec.points3D.items():
            if np.linalg.norm(p3d.xyz - centroid) <= ball_radius:
                patch_points_3d.append(p3d.xyz)
                patch_colors.append(p3d.color)
                patch_point_ids.append(pid)

        if len(patch_points_3d) == 0:
            continue

        patch_points_3d = np.array(patch_points_3d)
        patch_colors = np.array(patch_colors)

        # Center the patch at its own mean; remember the shift.
        offset = -np.mean(patch_points_3d, axis=0)
        patch_points_3d = patch_points_3d + offset

        # Keep the GT target in the same (centered) coordinate frame.
        if assigned_gt_vertex is not None:
            assigned_gt_vertex = assigned_gt_vertex + offset

        # 7D representation: [x, y, z, r, g, b, in_filtered_flag].
        patch_7d = np.zeros((len(patch_points_3d), 7))
        patch_7d[:, :3] = patch_points_3d
        patch_7d[:, 3:6] = patch_colors / 255.0  # rgb normalized to [0, 1]

        # Flag points belonging to the original filtered group. A set lookup
        # replaces the original O(n*m) per-point membership scan.
        filtered_set = set(int(pid) for pid in point_idxs)
        for i, pid in enumerate(patch_point_ids):
            patch_7d[i, 6] = 1.0 if int(pid) in filtered_set else 0.0

        # Initial vertex estimate, also shifted into the centered frame.
        if filtered_vertices[group_idx] is not None:
            initial_pred = filtered_vertices[group_idx] + offset
        else:
            initial_pred = None

        patches.append({
            'patch_7d': patch_7d,
            'centroid': centroid,
            'radius': ball_radius,
            'point_ids': patch_point_ids,
            'filtered_point_ids': point_idxs,
            'group_idx': group_idx,
            'assigned_gt_vertex': assigned_gt_vertex,
            'offset': offset,
            'initial_pred': initial_pred,
        })

    return patches
960
+
961
+ def our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id_substring, ade_seg, depth, K=None, R=None, t=None, frame=None):
962
  """
963
  Identify apex and eave-end vertices, then detect lines for eave/ridge/rake/valley.
964
  Also find all COLMAP points that project into apex or eave_end masks.
 
993
 
994
  vertices_formatted, connections, all_vertices_3d = get_connections(gest_seg_np, filtered_vertices_apex, filtered_vertices_eave, filtered_vertices_apex_uv, filtered_vertices_eave_uv)
995
 
996
+ patches = generate_patches(colmap_rec, filtered_point_idxs, frame, all_vertices_3d)
997
+
998
  #visualize_3d_wireframe(colmap_rec, filtered_points_xyz, filtered_points_color, all_vertices_3d, connections)
999
 
1000
+ return vertices_formatted, connections, all_vertices_3d, patches
train.py CHANGED
@@ -22,7 +22,7 @@ scores_iou = []
22
  show_visu = False
23
 
24
  idx = 0
25
- for a in ds['validation']:
26
  #plot_all_modalities(a)
27
  #pred_vertices, pred_edges = predict_wireframe(a)
28
  try:
 
22
  show_visu = False
23
 
24
  idx = 0
25
+ for a in ds['train']:
26
  #plot_all_modalities(a)
27
  #pred_vertices, pred_edges = predict_wireframe(a)
28
  try: