jskvrna commited on
Commit
c1c37b0
·
1 Parent(s): af9c931
Files changed (3) hide show
  1. fast_pointnet_class_v2.py +464 -0
  2. predict.py +11 -16
  3. train.py +3 -3
fast_pointnet_class_v2.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ import pickle
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from typing import List, Dict, Tuple, Optional
9
+ import json
10
+
11
+ class ClassificationPointNet(nn.Module):
12
+ """
13
+ PointNet implementation for binary classification from 10D point cloud patches.
14
+ Takes 10D point clouds and predicts binary classification (edge/not edge).
15
+ """
16
+ def __init__(self, input_dim=10, max_points=1024): # Changed input_dim default to 10
17
+ super(ClassificationPointNet, self).__init__()
18
+ self.max_points = max_points
19
+
20
+ # Point-wise MLPs for feature extraction (deeper network)
21
+ self.conv1 = nn.Conv1d(input_dim, 64, 1) # Changed input_dim here
22
+ self.conv2 = nn.Conv1d(64, 128, 1)
23
+ self.conv3 = nn.Conv1d(128, 256, 1)
24
+ self.conv4 = nn.Conv1d(256, 512, 1)
25
+ self.conv5 = nn.Conv1d(512, 1024, 1)
26
+ self.conv6 = nn.Conv1d(1024, 2048, 1) # Additional layer
27
+
28
+ # Classification head (deeper with more capacity)
29
+ self.fc1 = nn.Linear(2048, 1024)
30
+ self.fc2 = nn.Linear(1024, 512)
31
+ self.fc3 = nn.Linear(512, 256)
32
+ self.fc4 = nn.Linear(256, 128)
33
+ self.fc5 = nn.Linear(128, 64)
34
+ self.fc6 = nn.Linear(64, 1) # Single output for binary classification
35
+
36
+ # Batch normalization layers
37
+ self.bn1 = nn.BatchNorm1d(64)
38
+ self.bn2 = nn.BatchNorm1d(128)
39
+ self.bn3 = nn.BatchNorm1d(256)
40
+ self.bn4 = nn.BatchNorm1d(512)
41
+ self.bn5 = nn.BatchNorm1d(1024)
42
+ self.bn6 = nn.BatchNorm1d(2048)
43
+
44
+ # Dropout layers
45
+ self.dropout1 = nn.Dropout(0.3)
46
+ self.dropout2 = nn.Dropout(0.4)
47
+ self.dropout3 = nn.Dropout(0.5)
48
+ self.dropout4 = nn.Dropout(0.4)
49
+ self.dropout5 = nn.Dropout(0.3)
50
+
51
+ def forward(self, x):
52
+ """
53
+ Forward pass
54
+ Args:
55
+ x: (batch_size, input_dim, max_points) tensor
56
+ Returns:
57
+ classification: (batch_size, 1) tensor of logits (sigmoid for probability)
58
+ """
59
+ batch_size = x.size(0)
60
+
61
+ # Point-wise feature extraction
62
+ x1 = F.relu(self.bn1(self.conv1(x)))
63
+ x2 = F.relu(self.bn2(self.conv2(x1)))
64
+ x3 = F.relu(self.bn3(self.conv3(x2)))
65
+ x4 = F.relu(self.bn4(self.conv4(x3)))
66
+ x5 = F.relu(self.bn5(self.conv5(x4)))
67
+ x6 = F.relu(self.bn6(self.conv6(x5)))
68
+
69
+ # Global max pooling
70
+ global_features = torch.max(x6, 2)[0] # (batch_size, 2048)
71
+
72
+ # Classification head
73
+ x = F.relu(self.fc1(global_features))
74
+ x = self.dropout1(x)
75
+ x = F.relu(self.fc2(x))
76
+ x = self.dropout2(x)
77
+ x = F.relu(self.fc3(x))
78
+ x = self.dropout3(x)
79
+ x = F.relu(self.fc4(x))
80
+ x = self.dropout4(x)
81
+ x = F.relu(self.fc5(x))
82
+ x = self.dropout5(x)
83
+ classification = self.fc6(x) # (batch_size, 1)
84
+
85
+ return classification
86
+
87
+ class PatchClassificationDataset(Dataset):
88
+ """
89
+ Dataset class for loading saved patches for PointNet classification training.
90
+ """
91
+
92
+ def __init__(self, dataset_dir: str, max_points: int = 1024, augment: bool = True, input_dim: int = 10): # Added input_dim
93
+ self.dataset_dir = dataset_dir
94
+ self.max_points = max_points
95
+ self.augment = augment
96
+ self.input_dim = input_dim # Store input_dim
97
+
98
+ # Load patch files
99
+ self.patch_files = []
100
+ for file in os.listdir(dataset_dir):
101
+ if file.endswith('.pkl'):
102
+ self.patch_files.append(os.path.join(dataset_dir, file))
103
+
104
+ print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")
105
+
106
+ def __len__(self):
107
+ return len(self.patch_files)
108
+
109
+ def __getitem__(self, idx):
110
+ """
111
+ Load and process a patch for training.
112
+ Returns:
113
+ patch_data: (input_dim, max_points) tensor of point cloud data
114
+ label: scalar tensor for binary classification (0 or 1)
115
+ valid_mask: (max_points,) boolean tensor indicating valid points
116
+ """
117
+ patch_file = self.patch_files[idx]
118
+
119
+ with open(patch_file, 'rb') as f:
120
+ patch_info = pickle.load(f)
121
+
122
+ # Assuming the key in patch_info is now 'patch_10d' or similar, or that patch_info['patch_data'] is (N, 10)
123
+ # For this example, let's assume the key is 'patch_data' and it holds the 10D data.
124
+ # If your key is 'patch_10d', change 'patch_data' to 'patch_10d' below.
125
+ patch_data_nd = patch_info.get('patch_data', patch_info.get('patch_10d', patch_info.get('patch_6d'))) # Try to get 10d, fallback to 6d for now
126
+ if patch_data_nd.shape[1] != self.input_dim:
127
+ # This is a fallback or error handling if the loaded data isn't 10D.
128
+ # You might want to raise an error or handle this case specifically.
129
+ # For now, if it's 6D, we'll pad it to 10D with zeros as a placeholder.
130
+ # This part needs to be adjusted based on how your 10D data is actually stored.
131
+ print(f"Warning: Patch {patch_file} has {patch_data_nd.shape[1]} dimensions, expected {self.input_dim}. Padding with zeros if necessary.")
132
+ if patch_data_nd.shape[1] < self.input_dim:
133
+ padding = np.zeros((patch_data_nd.shape[0], self.input_dim - patch_data_nd.shape[1]))
134
+ patch_data_nd = np.concatenate((patch_data_nd, padding), axis=1)
135
+ elif patch_data_nd.shape[1] > self.input_dim:
136
+ patch_data_nd = patch_data_nd[:, :self.input_dim]
137
+
138
+
139
+ label = patch_info.get('label', 0) # Get binary classification label (0 or 1)
140
+
141
+ # Pad or sample points to max_points
142
+ num_points = patch_data_nd.shape[0]
143
+
144
+ if num_points >= self.max_points:
145
+ # Randomly sample max_points
146
+ indices = np.random.choice(num_points, self.max_points, replace=False)
147
+ patch_sampled = patch_data_nd[indices]
148
+ valid_mask = np.ones(self.max_points, dtype=bool)
149
+ else:
150
+ # Pad with zeros
151
+ patch_sampled = np.zeros((self.max_points, self.input_dim)) # Changed to self.input_dim
152
+ patch_sampled[:num_points] = patch_data_nd
153
+ valid_mask = np.zeros(self.max_points, dtype=bool)
154
+ valid_mask[:num_points] = True
155
+
156
+ # Data augmentation
157
+ if self.augment:
158
+ # Note: _augment_patch currently only augments xyz (first 3 dims).
159
+ # If other dimensions are geometric and need augmentation, this function needs an update.
160
+ patch_sampled = self._augment_patch(patch_sampled, valid_mask)
161
+
162
+ # Convert to tensors and transpose for conv1d (channels first)
163
+ patch_tensor = torch.from_numpy(patch_sampled.T).float() # (input_dim, max_points)
164
+ label_tensor = torch.tensor(label, dtype=torch.float32) # Float for BCE loss
165
+ valid_mask_tensor = torch.from_numpy(valid_mask)
166
+
167
+ return patch_tensor, label_tensor, valid_mask_tensor
168
+
169
+ def _augment_patch(self, patch, valid_mask):
170
+ """
171
+ Apply data augmentation to the patch.
172
+ Note: This implementation only augments the first 3 dimensions (assumed to be XYZ).
173
+ If your 10D representation has other geometric features that need augmentation,
174
+ this function should be updated accordingly.
175
+ """
176
+ valid_points_data = patch[valid_mask]
177
+
178
+ if len(valid_points_data) == 0:
179
+ return patch
180
+
181
+ # Extract XYZ for augmentation (first 3 columns)
182
+ valid_points_xyz = valid_points_data[:, :3].copy() # Operate on a copy
183
+
184
+ # Random rotation around z-axis
185
+ angle = np.random.uniform(0, 2 * np.pi)
186
+ cos_angle = np.cos(angle)
187
+ sin_angle = np.sin(angle)
188
+ rotation_matrix = np.array([
189
+ [cos_angle, -sin_angle, 0],
190
+ [sin_angle, cos_angle, 0],
191
+ [0, 0, 1]
192
+ ])
193
+
194
+ # Apply rotation to xyz coordinates
195
+ valid_points_xyz = valid_points_xyz @ rotation_matrix.T
196
+
197
+ # Random jittering
198
+ noise = np.random.normal(0, 0.01, valid_points_xyz.shape)
199
+ valid_points_xyz += noise
200
+
201
+ # Random scaling
202
+ scale = np.random.uniform(0.9, 1.1)
203
+ valid_points_xyz *= scale
204
+
205
+ # Update the original patch data
206
+ augmented_patch = patch.copy()
207
+ augmented_patch[valid_mask, :3] = valid_points_xyz
208
+
209
+ return augmented_patch
210
+
211
+ def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
212
+ """
213
+ Save patches from prediction pipeline to create a training dataset.
214
+ Ensure 'patch_data' (or 'patch_10d') in the patch dictionary contains the 10D data.
215
+
216
+ Args:
217
+ patches: List of patch dictionaries from generate_patches()
218
+ dataset_dir: Directory to save the dataset
219
+ entry_id: Unique identifier for this entry/image
220
+ """
221
+ os.makedirs(dataset_dir, exist_ok=True)
222
+
223
+ for i, patch in enumerate(patches):
224
+ # Create unique filename
225
+ filename = f"{entry_id}_patch_{i}.pkl"
226
+ filepath = os.path.join(dataset_dir, filename)
227
+
228
+ # Skip if file already exists
229
+ if os.path.exists(filepath):
230
+ continue
231
+
232
+ # Ensure the patch data being saved is 10D.
233
+ # Example: patch_data_key = 'patch_10d' or 'patch_data'
234
+ # if 'patch_data' not in patch or patch['patch_data'].shape[1] != 10:
235
+ # print(f"Warning: Patch {i} for entry {entry_id} does not seem to be 10D. Skipping or error handling needed.")
236
+ # continue
237
+
238
+ with open(filepath, 'wb') as f:
239
+ pickle.dump(patch, f)
240
+
241
+ print(f"Saved {len(patches)} patches for entry {entry_id}")
242
+
243
+ # Create dataloader with custom collate function to filter invalid samples
244
+ def collate_fn(batch):
245
+ valid_batch = []
246
+ for patch_data, label, valid_mask in batch:
247
+ # Filter out invalid samples (no valid points)
248
+ if valid_mask.sum() > 0:
249
+ valid_batch.append((patch_data, label, valid_mask))
250
+
251
+ if len(valid_batch) == 0:
252
+ return None
253
+
254
+ # Stack valid samples
255
+ patch_data = torch.stack([item[0] for item in valid_batch])
256
+ labels = torch.stack([item[1] for item in valid_batch])
257
+ valid_masks = torch.stack([item[2] for item in valid_batch])
258
+
259
+ return patch_data, labels, valid_masks
260
+
261
+ # Initialize weights using Xavier/Glorot initialization
262
+ def init_weights(m):
263
+ if isinstance(m, nn.Conv1d):
264
+ nn.init.xavier_uniform_(m.weight)
265
+ if m.bias is not None:
266
+ nn.init.zeros_(m.bias)
267
+ elif isinstance(m, nn.Linear):
268
+ nn.init.xavier_uniform_(m.weight)
269
+ if m.bias is not None:
270
+ nn.init.zeros_(m.bias)
271
+ elif isinstance(m, nn.BatchNorm1d):
272
+ nn.init.ones_(m.weight)
273
+ nn.init.zeros_(m.bias)
274
+
275
+ def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32,
276
+ lr: float = 0.001, input_dim: int = 10): # Added input_dim
277
+ """
278
+ Train the ClassificationPointNet model on saved patches.
279
+
280
+ Args:
281
+ dataset_dir: Directory containing saved patch files
282
+ model_save_path: Path to save the trained model
283
+ epochs: Number of training epochs
284
+ batch_size: Training batch size
285
+ lr: Learning rate
286
+ input_dim: Dimensionality of the input points (e.g., 10 for 10D)
287
+ """
288
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
289
+ print(f"Training on device: {device}")
290
+
291
+ # Create dataset and dataloader
292
+ dataset = PatchClassificationDataset(dataset_dir, max_points=1024, augment=True, input_dim=input_dim) # Pass input_dim
293
+ print(f"Dataset loaded with {len(dataset)} samples")
294
+
295
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8,
296
+ collate_fn=collate_fn, drop_last=True)
297
+
298
+ # Initialize model
299
+ model = ClassificationPointNet(input_dim=input_dim, max_points=1024) # Pass input_dim
300
+ model.apply(init_weights)
301
+ model.to(device)
302
+
303
+ # Loss function and optimizer (BCE for binary classification)
304
+ criterion = nn.BCEWithLogitsLoss()
305
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
306
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
307
+
308
+ # Training loop
309
+ model.train()
310
+ for epoch in range(epochs):
311
+ total_loss = 0.0
312
+ correct = 0
313
+ total = 0
314
+ num_batches = 0
315
+
316
+ for batch_idx, batch_data in enumerate(dataloader):
317
+ if batch_data is None: # Skip invalid batches
318
+ continue
319
+
320
+ patch_data, labels, valid_masks = batch_data
321
+ patch_data = patch_data.to(device) # (batch_size, input_dim, max_points)
322
+ labels = labels.to(device).unsqueeze(1) # (batch_size, 1)
323
+
324
+ # Forward pass
325
+ optimizer.zero_grad()
326
+ outputs = model(patch_data) # (batch_size, 1)
327
+ loss = criterion(outputs, labels)
328
+
329
+ # Backward pass
330
+ loss.backward()
331
+ optimizer.step()
332
+
333
+ # Statistics
334
+ total_loss += loss.item()
335
+ predicted = (torch.sigmoid(outputs) > 0.5).float()
336
+ total += labels.size(0)
337
+ correct += (predicted == labels).sum().item()
338
+ num_batches += 1
339
+
340
+ if batch_idx % 50 == 0:
341
+ print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
342
+ f"Loss: {loss.item():.6f}, "
343
+ f"Accuracy: {100 * correct / total:.2f}%")
344
+
345
+ avg_loss = total_loss / num_batches if num_batches > 0 else 0
346
+ accuracy = 100 * correct / total if total > 0 else 0
347
+
348
+ print(f"Epoch {epoch+1}/{epochs} completed, "
349
+ f"Avg Loss: {avg_loss:.6f}, "
350
+ f"Accuracy: {accuracy:.2f}%")
351
+
352
+ scheduler.step()
353
+
354
+ # Save model checkpoint every epoch
355
+ checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
356
+ torch.save({
357
+ 'model_state_dict': model.state_dict(),
358
+ 'optimizer_state_dict': optimizer.state_dict(),
359
+ 'epoch': epoch + 1,
360
+ 'loss': avg_loss,
361
+ 'accuracy': accuracy,
362
+ 'input_dim': input_dim, # Save input_dim with checkpoint
363
+ }, checkpoint_path)
364
+
365
+ # Save the trained model
366
+ torch.save({
367
+ 'model_state_dict': model.state_dict(),
368
+ 'optimizer_state_dict': optimizer.state_dict(),
369
+ 'epoch': epochs,
370
+ 'input_dim': input_dim, # Save input_dim with final model
371
+ }, model_save_path)
372
+
373
+ print(f"Model saved to {model_save_path}")
374
+ return model
375
+
376
+ def load_pointnet_model(model_path: str, device: torch.device = None) -> ClassificationPointNet:
377
+ """
378
+ Load a trained ClassificationPointNet model.
379
+
380
+ Args:
381
+ model_path: Path to the saved model
382
+ device: Device to load the model on
383
+
384
+ Returns:
385
+ Loaded ClassificationPointNet model
386
+ """
387
+ if device is None:
388
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
389
+
390
+ checkpoint = torch.load(model_path, map_location=device)
391
+
392
+ # Load input_dim from checkpoint if available, otherwise default to 10
393
+ # For older models saved without input_dim, you might need to specify it or assume a default.
394
+ input_dim = checkpoint.get('input_dim', 10)
395
+
396
+ model = ClassificationPointNet(input_dim=input_dim, max_points=1024) # Use loaded or default input_dim
397
+ model.load_state_dict(checkpoint['model_state_dict'])
398
+
399
+ model.to(device)
400
+ model.eval()
401
+
402
+ return model
403
+
404
+ def predict_class_from_patch(model: ClassificationPointNet, patch: Dict, device: torch.device = None) -> Tuple[int, float]:
405
+ """
406
+ Predict binary classification from a patch using trained PointNet.
407
+ Assumes the model's input_dim matches the data.
408
+
409
+ Args:
410
+ model: Trained ClassificationPointNet model
411
+ patch: Dictionary containing patch data. Expects a key like 'patch_data' or 'patch_10d' with (N, 10) shape.
412
+ device: Device to run prediction on
413
+
414
+ Returns:
415
+ tuple of (predicted_class, confidence)
416
+ predicted_class: int (0 for not edge, 1 for edge)
417
+ confidence: float representing confidence score (0-1)
418
+ """
419
+ if device is None:
420
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
421
+
422
+ # Determine input_dim from the model
423
+ input_dim = model.conv1.in_channels
424
+
425
+ # Assuming the key in patch_info is now 'patch_10d' or similar, or that patch_info['patch_data'] is (N, 10)
426
+ # For this example, let's assume the key is 'patch_data' and it holds the 10D data.
427
+ # If your key is 'patch_10d', change 'patch_data' to 'patch_10d' below.
428
+ patch_data_nd = patch.get('patch_data', patch.get('patch_10d', patch.get('patch_6d'))) # Try to get 10d, fallback to 6d
429
+
430
+ if patch_data_nd.shape[1] != input_dim:
431
+ # Handle dimension mismatch, e.g., by padding or raising an error
432
+ print(f"Warning: Input patch has {patch_data_nd.shape[1]} dimensions, but model expects {input_dim}. Adjusting...")
433
+ if patch_data_nd.shape[1] < input_dim:
434
+ padding = np.zeros((patch_data_nd.shape[0], input_dim - patch_data_nd.shape[1]))
435
+ patch_data_nd = np.concatenate((patch_data_nd, padding), axis=1)
436
+ elif patch_data_nd.shape[1] > input_dim:
437
+ patch_data_nd = patch_data_nd[:, :input_dim]
438
+
439
+ # Prepare input
440
+ max_points = model.max_points # Use max_points from the model instance
441
+ num_points = patch_data_nd.shape[0]
442
+
443
+ if num_points >= max_points:
444
+ # Sample points
445
+ indices = np.random.choice(num_points, max_points, replace=False)
446
+ patch_sampled = patch_data_nd[indices]
447
+ else:
448
+ # Pad with zeros
449
+ patch_sampled = np.zeros((max_points, input_dim)) # Use model's input_dim
450
+ patch_sampled[:num_points] = patch_data_nd
451
+
452
+ # Convert to tensor
453
+ patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0) # (1, input_dim, max_points)
454
+ patch_tensor = patch_tensor.to(device)
455
+
456
+ # Predict
457
+ model.eval() # Ensure model is in eval mode
458
+ with torch.no_grad():
459
+ outputs = model(patch_tensor) # (1, 1)
460
+ probability = torch.sigmoid(outputs).item()
461
+ predicted_class = int(probability > 0.5)
462
+
463
+ return predicted_class, probability
464
+
predict.py CHANGED
@@ -15,8 +15,8 @@ import cv2
15
  from fast_pointnet_v2 import save_patches_dataset, predict_vertex_from_patch
16
  #from fast_voxel import predict_vertex_from_patch_voxel
17
  #import time
18
- from fast_pointnet_class import save_patches_dataset as save_patches_dataset_class
19
- from fast_pointnet_class import predict_class_from_patch
20
  #from fast_pointnet_class_10d import predict_class_from_patch as predict_class_from_patch_10d
21
  from scipy.spatial.distance import cdist
22
  from scipy.optimize import linear_sum_assignment
@@ -28,9 +28,9 @@ GENERATE_DATASET = False
28
  #DATASET_DIR = '/home/skvrnjan/personal/hohocustom/'
29
  DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom_v4/'
30
 
31
- GENERATE_DATASET_EDGES = False
32
  #EDGES_DATASET_DIR = '/home/skvrnjan/personal/hohocustom_edges/'
33
- EDGES_DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom_edges_10d_1m/'
34
 
35
  def convert_entry_to_human_readable(entry):
36
  out = {}
@@ -1010,18 +1010,13 @@ def generate_edge_patches(frame, pred_vertices, colmap_pcloud):
1010
  elif len(point_gestalt_list) == 1:
1011
  fused_gestalt.append(point_gestalt_list[0])
1012
  else:
1013
- # Convert to numpy array for easier manipulation
1014
- gestalt_values = np.array(point_gestalt_list)
1015
-
1016
- # Method 1: Average the RGB values
1017
- fused_value = np.mean(gestalt_values, axis=0).astype(np.uint8)
1018
 
1019
- # Method 2: Majority voting per channel (commented out alternative)
1020
- # fused_value = np.array([
1021
- # np.bincount(gestalt_values[:, 0]).argmax(),
1022
- # np.bincount(gestalt_values[:, 1]).argmax(),
1023
- # np.bincount(gestalt_values[:, 2]).argmax()
1024
- # ])
1025
 
1026
  fused_gestalt.append(fused_value)
1027
 
@@ -1078,7 +1073,7 @@ def generate_edge_patches(frame, pred_vertices, colmap_pcloud):
1078
  # Find points within cylinder
1079
  within_cylinder = within_bounds & (perpendicular_distances <= cylinder_radius)
1080
 
1081
- if np.sum(within_cylinder) <= 10:
1082
  continue
1083
 
1084
  points_in_cylinder = colmap_points_10d[within_cylinder]
 
15
  from fast_pointnet_v2 import save_patches_dataset, predict_vertex_from_patch
16
  #from fast_voxel import predict_vertex_from_patch_voxel
17
  #import time
18
+ from fast_pointnet_class_v2 import save_patches_dataset as save_patches_dataset_class
19
+ from fast_pointnet_class_v2 import predict_class_from_patch
20
  #from fast_pointnet_class_10d import predict_class_from_patch as predict_class_from_patch_10d
21
  from scipy.spatial.distance import cdist
22
  from scipy.optimize import linear_sum_assignment
 
28
  #DATASET_DIR = '/home/skvrnjan/personal/hohocustom/'
29
  DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom_v4/'
30
 
31
+ GENERATE_DATASET_EDGES = True
32
  #EDGES_DATASET_DIR = '/home/skvrnjan/personal/hohocustom_edges/'
33
+ EDGES_DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom_edges_10d_v4/'
34
 
35
  def convert_entry_to_human_readable(entry):
36
  out = {}
 
1010
  elif len(point_gestalt_list) == 1:
1011
  fused_gestalt.append(point_gestalt_list[0])
1012
  else:
1013
+ # Convert to tuples for hashable voting
1014
+ gestalt_tuples = [tuple(gestalt_val) for gestalt_val in point_gestalt_list]
 
 
 
1015
 
1016
+ # Use Counter for majority voting
1017
+ counts = Counter(gestalt_tuples)
1018
+ most_common_tuple = counts.most_common(1)[0][0]
1019
+ fused_value = np.array(most_common_tuple, dtype=np.uint8)
 
 
1020
 
1021
  fused_gestalt.append(fused_value)
1022
 
 
1073
  # Find points within cylinder
1074
  within_cylinder = within_bounds & (perpendicular_distances <= cylinder_radius)
1075
 
1076
+ if np.sum(within_cylinder) <= 5:
1077
  continue
1078
 
1079
  points_in_cylinder = colmap_points_10d[within_cylinder]
train.py CHANGED
@@ -26,8 +26,8 @@ import time
26
 
27
  # --- Argument Parsing ---
28
  parser = argparse.ArgumentParser(description="Train and evaluate HoHo model with custom config.")
29
- parser.add_argument('--vertex_threshold', type=float, default=0.6, help='Vertex threshold for prediction.')
30
- parser.add_argument('--edge_threshold', type=float, default=0.65, help='Edge threshold for prediction.')
31
  parser.add_argument('--only_predicted_connections', type=lambda x: (str(x).lower() == 'true'), default=True, help='Use only predicted connections (True/False).')
32
  parser.add_argument('--max_samples', type=int, default=50000, help='Maximum number of samples to process.')
33
  parser.add_argument('--results_dir', type=str, default="results", help='Directory to save result files.')
@@ -75,7 +75,7 @@ voxel_model = None
75
 
76
  idx = 0
77
  prediction_times = []
78
- for a in tqdm(ds['validation'], desc="Processing dataset"):
79
  #plot_all_modalities(a)
80
  #pred_vertices, pred_edges = predict_wireframe_old(a)
81
  #pred_vertices, pred_edges = predict_wireframe(a.copy(), pnet_model, voxel_model, pnet_class_model, config)
 
26
 
27
  # --- Argument Parsing ---
28
  parser = argparse.ArgumentParser(description="Train and evaluate HoHo model with custom config.")
29
+ parser.add_argument('--vertex_threshold', type=float, default=0.72, help='Vertex threshold for prediction.')
30
+ parser.add_argument('--edge_threshold', type=float, default=0.72, help='Edge threshold for prediction.')
31
  parser.add_argument('--only_predicted_connections', type=lambda x: (str(x).lower() == 'true'), default=True, help='Use only predicted connections (True/False).')
32
  parser.add_argument('--max_samples', type=int, default=50000, help='Maximum number of samples to process.')
33
  parser.add_argument('--results_dir', type=str, default="results", help='Directory to save result files.')
 
75
 
76
  idx = 0
77
  prediction_times = []
78
+ for a in tqdm(ds['train'], desc="Processing dataset"):
79
  #plot_all_modalities(a)
80
  #pred_vertices, pred_edges = predict_wireframe_old(a)
81
  #pred_vertices, pred_edges = predict_wireframe(a.copy(), pnet_model, voxel_model, pnet_class_model, config)