Implements FastPointNet for vertex prediction
Browse filesAdds a FastPointNet model for predicting 3D vertex coordinates from point cloud patches.
Includes a dataset class for loading and augmenting patch data.
Also adds training and prediction functions for the model.
Patches are generated and saved for training the PointNet model,
allowing for iterative refinement of vertex predictions.
The patch generation process includes filtering COLMAP points within a
ball around identified vertices and creating a 7D point cloud
representation. GT vertex assignment and data augmentation are incorporated
to improve the training data.
- fast_pointnet.py +421 -0
- predict.py +175 -18
- train.py +1 -1
fast_pointnet.py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pickle
|
| 7 |
+
from torch.utils.data import Dataset, DataLoader
|
| 8 |
+
from typing import List, Dict, Tuple, Optional
|
| 9 |
+
import json
|
| 10 |
+
|
| 11 |
+
class FastPointNet(nn.Module):
|
| 12 |
+
"""
|
| 13 |
+
Fast PointNet implementation for 3D vertex prediction from point cloud patches.
|
| 14 |
+
Takes 7D point clouds (x,y,z,r,g,b,filtered_flag) and predicts 3D vertex coordinates.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
def __init__(self, input_dim=7, output_dim=3, max_points=1024, predict_score=True):
|
| 18 |
+
super(FastPointNet, self).__init__()
|
| 19 |
+
self.max_points = max_points
|
| 20 |
+
self.predict_score = predict_score
|
| 21 |
+
|
| 22 |
+
# Point-wise MLPs
|
| 23 |
+
self.conv1 = nn.Conv1d(input_dim, 64, 1)
|
| 24 |
+
self.conv2 = nn.Conv1d(64, 128, 1)
|
| 25 |
+
self.conv3 = nn.Conv1d(128, 256, 1)
|
| 26 |
+
|
| 27 |
+
# Global feature extraction
|
| 28 |
+
self.conv4 = nn.Conv1d(256, 512, 1)
|
| 29 |
+
self.conv5 = nn.Conv1d(512, 1024, 1)
|
| 30 |
+
|
| 31 |
+
# Shared features
|
| 32 |
+
self.shared_fc = nn.Linear(1024, 512)
|
| 33 |
+
|
| 34 |
+
# Position prediction head
|
| 35 |
+
self.pos_fc1 = nn.Linear(512, 256)
|
| 36 |
+
self.pos_fc2 = nn.Linear(256, output_dim)
|
| 37 |
+
|
| 38 |
+
# Score prediction head (predicts distance to GT)
|
| 39 |
+
if self.predict_score:
|
| 40 |
+
self.score_fc1 = nn.Linear(512, 256)
|
| 41 |
+
self.score_fc2 = nn.Linear(256, 128)
|
| 42 |
+
self.score_fc3 = nn.Linear(128, 1) # Single score output
|
| 43 |
+
|
| 44 |
+
self.dropout = nn.Dropout(0.3)
|
| 45 |
+
self.bn1 = nn.BatchNorm1d(64)
|
| 46 |
+
self.bn2 = nn.BatchNorm1d(128)
|
| 47 |
+
self.bn3 = nn.BatchNorm1d(256)
|
| 48 |
+
self.bn4 = nn.BatchNorm1d(512)
|
| 49 |
+
self.bn5 = nn.BatchNorm1d(1024)
|
| 50 |
+
|
| 51 |
+
def forward(self, x):
|
| 52 |
+
"""
|
| 53 |
+
Forward pass
|
| 54 |
+
Args:
|
| 55 |
+
x: (batch_size, input_dim, max_points) tensor
|
| 56 |
+
Returns:
|
| 57 |
+
if predict_score=True: tuple of (position, score)
|
| 58 |
+
position: (batch_size, output_dim) tensor of predicted 3D coordinates
|
| 59 |
+
score: (batch_size, 1) tensor of predicted distance to GT
|
| 60 |
+
else: (batch_size, output_dim) tensor of predicted 3D coordinates
|
| 61 |
+
"""
|
| 62 |
+
batch_size = x.size(0)
|
| 63 |
+
|
| 64 |
+
# Point-wise feature extraction
|
| 65 |
+
x = F.relu(self.bn1(self.conv1(x)))
|
| 66 |
+
x = F.relu(self.bn2(self.conv2(x)))
|
| 67 |
+
x = F.relu(self.bn3(self.conv3(x)))
|
| 68 |
+
x = F.relu(self.bn4(self.conv4(x)))
|
| 69 |
+
x = F.relu(self.bn5(self.conv5(x)))
|
| 70 |
+
|
| 71 |
+
# Global max pooling
|
| 72 |
+
x = torch.max(x, 2)[0] # (batch_size, 1024)
|
| 73 |
+
|
| 74 |
+
# Shared features
|
| 75 |
+
shared_features = F.relu(self.shared_fc(x))
|
| 76 |
+
shared_features = self.dropout(shared_features)
|
| 77 |
+
|
| 78 |
+
# Position prediction
|
| 79 |
+
pos_features = F.relu(self.pos_fc1(shared_features))
|
| 80 |
+
pos_features = self.dropout(pos_features)
|
| 81 |
+
position = self.pos_fc2(pos_features)
|
| 82 |
+
|
| 83 |
+
if self.predict_score:
|
| 84 |
+
# Score prediction (distance to GT)
|
| 85 |
+
score_features = F.relu(self.score_fc1(shared_features))
|
| 86 |
+
score_features = self.dropout(score_features)
|
| 87 |
+
score_features = F.relu(self.score_fc2(score_features))
|
| 88 |
+
score_features = self.dropout(score_features)
|
| 89 |
+
score = F.relu(self.score_fc3(score_features)) # Ensure positive distance
|
| 90 |
+
|
| 91 |
+
return position, score
|
| 92 |
+
else:
|
| 93 |
+
return position
|
| 94 |
+
|
| 95 |
+
class PatchDataset(Dataset):
|
| 96 |
+
"""
|
| 97 |
+
Dataset class for loading saved patches for PointNet training.
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
def __init__(self, dataset_dir: str, max_points: int = 1024, augment: bool = True):
|
| 101 |
+
self.dataset_dir = dataset_dir
|
| 102 |
+
self.max_points = max_points
|
| 103 |
+
self.augment = augment
|
| 104 |
+
|
| 105 |
+
# Load patch files
|
| 106 |
+
self.patch_files = []
|
| 107 |
+
for file in os.listdir(dataset_dir):
|
| 108 |
+
if file.endswith('.pkl'):
|
| 109 |
+
self.patch_files.append(os.path.join(dataset_dir, file))
|
| 110 |
+
|
| 111 |
+
print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")
|
| 112 |
+
|
| 113 |
+
def __len__(self):
|
| 114 |
+
return len(self.patch_files)
|
| 115 |
+
|
| 116 |
+
def __getitem__(self, idx):
|
| 117 |
+
"""
|
| 118 |
+
Load and process a patch for training.
|
| 119 |
+
Returns:
|
| 120 |
+
patch_data: (7, max_points) tensor of point cloud data
|
| 121 |
+
target: (3,) tensor of target 3D coordinates
|
| 122 |
+
valid_mask: (max_points,) boolean tensor indicating valid points
|
| 123 |
+
distance_to_gt: scalar tensor of distance from initial prediction to GT
|
| 124 |
+
"""
|
| 125 |
+
patch_file = self.patch_files[idx]
|
| 126 |
+
|
| 127 |
+
with open(patch_file, 'rb') as f:
|
| 128 |
+
patch_info = pickle.load(f)
|
| 129 |
+
|
| 130 |
+
patch_7d = patch_info['patch_7d'] # (N, 7)
|
| 131 |
+
target = patch_info['assigned_gt_vertex'] # (3,) or None
|
| 132 |
+
initial_pred = patch_info.get('initial_pred', None) # (3,) or None
|
| 133 |
+
|
| 134 |
+
# Skip patches without ground truth
|
| 135 |
+
if target is None:
|
| 136 |
+
# Return dummy data that will be filtered out
|
| 137 |
+
dummy_patch = np.zeros((self.max_points, 7))
|
| 138 |
+
dummy_target = np.zeros(3)
|
| 139 |
+
valid_mask = np.zeros(self.max_points, dtype=bool)
|
| 140 |
+
dummy_distance = np.array([0.0])
|
| 141 |
+
return (torch.from_numpy(dummy_patch.T).float(),
|
| 142 |
+
torch.from_numpy(dummy_target).float(),
|
| 143 |
+
torch.from_numpy(valid_mask),
|
| 144 |
+
torch.from_numpy(dummy_distance).float())
|
| 145 |
+
|
| 146 |
+
target = np.array(target)
|
| 147 |
+
|
| 148 |
+
# Normalize colors from [0,1] to [-1,1]
|
| 149 |
+
patch_7d[:, 3:6] = patch_7d[:, 3:6] * 2.0 - 1.0
|
| 150 |
+
|
| 151 |
+
# Pad or sample points to max_points
|
| 152 |
+
num_points = patch_7d.shape[0]
|
| 153 |
+
|
| 154 |
+
if num_points >= self.max_points:
|
| 155 |
+
# Randomly sample max_points
|
| 156 |
+
indices = np.random.choice(num_points, self.max_points, replace=False)
|
| 157 |
+
patch_sampled = patch_7d[indices]
|
| 158 |
+
valid_mask = np.ones(self.max_points, dtype=bool)
|
| 159 |
+
else:
|
| 160 |
+
# Pad with zeros
|
| 161 |
+
patch_sampled = np.zeros((self.max_points, 7))
|
| 162 |
+
patch_sampled[:num_points] = patch_7d
|
| 163 |
+
valid_mask = np.zeros(self.max_points, dtype=bool)
|
| 164 |
+
valid_mask[:num_points] = True
|
| 165 |
+
|
| 166 |
+
# Data augmentation
|
| 167 |
+
if self.augment:
|
| 168 |
+
patch_sampled = self._augment_patch(patch_sampled, valid_mask)
|
| 169 |
+
target = self._augment_target(target)
|
| 170 |
+
|
| 171 |
+
# Convert to tensors and transpose for conv1d (channels first)
|
| 172 |
+
patch_tensor = torch.from_numpy(patch_sampled.T).float() # (7, max_points)
|
| 173 |
+
target_tensor = torch.from_numpy(target).float() # (3,)
|
| 174 |
+
valid_mask_tensor = torch.from_numpy(valid_mask)
|
| 175 |
+
initial_pred = torch.from_numpy(initial_pred).float()
|
| 176 |
+
|
| 177 |
+
return patch_tensor, target_tensor, valid_mask_tensor, initial_pred
|
| 178 |
+
|
| 179 |
+
def _augment_patch(self, patch: np.ndarray, valid_mask: np.ndarray, target: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 180 |
+
"""Apply data augmentation to patch and target with consistent transformations."""
|
| 181 |
+
# Only augment valid points
|
| 182 |
+
valid_points = patch[valid_mask]
|
| 183 |
+
|
| 184 |
+
if len(valid_points) == 0:
|
| 185 |
+
return patch, target
|
| 186 |
+
|
| 187 |
+
# Random rotation around Z-axis
|
| 188 |
+
angle = np.random.uniform(0, 2 * np.pi)
|
| 189 |
+
cos_angle, sin_angle = np.cos(angle), np.sin(angle)
|
| 190 |
+
rotation_matrix = np.array([
|
| 191 |
+
[cos_angle, -sin_angle, 0],
|
| 192 |
+
[sin_angle, cos_angle, 0],
|
| 193 |
+
[0, 0, 1]
|
| 194 |
+
])
|
| 195 |
+
|
| 196 |
+
# Apply rotation to patch coordinates
|
| 197 |
+
valid_points[:, :3] = valid_points[:, :3] @ rotation_matrix.T
|
| 198 |
+
|
| 199 |
+
# Apply same rotation to target
|
| 200 |
+
target_augmented = target @ rotation_matrix.T
|
| 201 |
+
|
| 202 |
+
# Add small random noise to coordinates
|
| 203 |
+
noise = np.random.normal(0, 0.01, valid_points[:, :3].shape)
|
| 204 |
+
valid_points[:, :3] += noise
|
| 205 |
+
|
| 206 |
+
# Color jittering
|
| 207 |
+
color_noise = np.random.normal(0, 0.02, valid_points[:, 3:6].shape)
|
| 208 |
+
valid_points[:, 3:6] = np.clip(valid_points[:, 3:6] + color_noise, 0, 1)
|
| 209 |
+
|
| 210 |
+
patch[valid_mask] = valid_points
|
| 211 |
+
return patch, target_augmented
|
| 212 |
+
|
| 213 |
+
def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
|
| 214 |
+
"""
|
| 215 |
+
Save patches from prediction pipeline to create a training dataset.
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
patches: List of patch dictionaries from generate_patches()
|
| 219 |
+
dataset_dir: Directory to save the dataset
|
| 220 |
+
entry_id: Unique identifier for this entry/image
|
| 221 |
+
"""
|
| 222 |
+
os.makedirs(dataset_dir, exist_ok=True)
|
| 223 |
+
|
| 224 |
+
for i, patch in enumerate(patches):
|
| 225 |
+
# Create unique filename
|
| 226 |
+
filename = f"{entry_id}_patch_{i}.pkl"
|
| 227 |
+
filepath = os.path.join(dataset_dir, filename)
|
| 228 |
+
|
| 229 |
+
# Save patch data
|
| 230 |
+
with open(filepath, 'wb') as f:
|
| 231 |
+
pickle.dump(patch, f)
|
| 232 |
+
|
| 233 |
+
print(f"Saved {len(patches)} patches for entry {entry_id}")
|
| 234 |
+
|
| 235 |
+
def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32, lr: float = 0.001,
|
| 236 |
+
score_weight: float = 0.1):
|
| 237 |
+
"""
|
| 238 |
+
Train the FastPointNet model on saved patches.
|
| 239 |
+
|
| 240 |
+
Args:
|
| 241 |
+
dataset_dir: Directory containing saved patch files
|
| 242 |
+
model_save_path: Path to save the trained model
|
| 243 |
+
epochs: Number of training epochs
|
| 244 |
+
batch_size: Training batch size
|
| 245 |
+
lr: Learning rate
|
| 246 |
+
score_weight: Weight for the distance prediction loss
|
| 247 |
+
"""
|
| 248 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 249 |
+
print(f"Training on device: {device}")
|
| 250 |
+
|
| 251 |
+
# Create dataset and dataloader
|
| 252 |
+
dataset = PatchDataset(dataset_dir, max_points=1024, augment=True)
|
| 253 |
+
|
| 254 |
+
# Filter out invalid samples
|
| 255 |
+
valid_indices = []
|
| 256 |
+
for i in range(len(dataset)):
|
| 257 |
+
_, target, valid_mask, _ = dataset[i]
|
| 258 |
+
if valid_mask.sum() > 0 and not torch.all(target == 0):
|
| 259 |
+
valid_indices.append(i)
|
| 260 |
+
|
| 261 |
+
print(f"Found {len(valid_indices)} valid patches out of {len(dataset)}")
|
| 262 |
+
|
| 263 |
+
# Create subset with valid samples
|
| 264 |
+
valid_dataset = torch.utils.data.Subset(dataset, valid_indices)
|
| 265 |
+
dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
|
| 266 |
+
|
| 267 |
+
# Initialize model with score prediction
|
| 268 |
+
model = FastPointNet(input_dim=7, output_dim=3, max_points=1024, predict_score=True)
|
| 269 |
+
model.to(device)
|
| 270 |
+
|
| 271 |
+
# Loss functions
|
| 272 |
+
position_criterion = nn.MSELoss()
|
| 273 |
+
score_criterion = nn.MSELoss()
|
| 274 |
+
|
| 275 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
|
| 276 |
+
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
|
| 277 |
+
|
| 278 |
+
# Training loop
|
| 279 |
+
model.train()
|
| 280 |
+
for epoch in range(epochs):
|
| 281 |
+
total_loss = 0.0
|
| 282 |
+
total_pos_loss = 0.0
|
| 283 |
+
total_score_loss = 0.0
|
| 284 |
+
num_batches = 0
|
| 285 |
+
|
| 286 |
+
for batch_idx, (patch_data, targets, valid_masks, distances) in enumerate(dataloader):
|
| 287 |
+
patch_data = patch_data.to(device) # (batch_size, 7, max_points)
|
| 288 |
+
targets = targets.to(device) # (batch_size, 3)
|
| 289 |
+
distances = distances.to(device) # (batch_size, 1)
|
| 290 |
+
|
| 291 |
+
# Forward pass
|
| 292 |
+
optimizer.zero_grad()
|
| 293 |
+
predictions, predicted_scores = model(patch_data)
|
| 294 |
+
|
| 295 |
+
# Compute actual distance from predictions to targets
|
| 296 |
+
actual_distances = torch.norm(predictions - targets, dim=1, keepdim=True)
|
| 297 |
+
|
| 298 |
+
# Compute losses
|
| 299 |
+
pos_loss = position_criterion(predictions, targets)
|
| 300 |
+
score_loss = score_criterion(predicted_scores, actual_distances)
|
| 301 |
+
|
| 302 |
+
# Combined loss
|
| 303 |
+
total_batch_loss = pos_loss + score_weight * score_loss
|
| 304 |
+
|
| 305 |
+
# Backward pass
|
| 306 |
+
total_batch_loss.backward()
|
| 307 |
+
optimizer.step()
|
| 308 |
+
|
| 309 |
+
total_loss += total_batch_loss.item()
|
| 310 |
+
total_pos_loss += pos_loss.item()
|
| 311 |
+
total_score_loss += score_loss.item()
|
| 312 |
+
num_batches += 1
|
| 313 |
+
|
| 314 |
+
if batch_idx % 50 == 0:
|
| 315 |
+
print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
|
| 316 |
+
f"Total Loss: {total_batch_loss.item():.6f}, "
|
| 317 |
+
f"Pos Loss: {pos_loss.item():.6f}, "
|
| 318 |
+
f"Score Loss: {score_loss.item():.6f}")
|
| 319 |
+
|
| 320 |
+
avg_loss = total_loss / num_batches if num_batches > 0 else 0
|
| 321 |
+
avg_pos_loss = total_pos_loss / num_batches if num_batches > 0 else 0
|
| 322 |
+
avg_score_loss = total_score_loss / num_batches if num_batches > 0 else 0
|
| 323 |
+
|
| 324 |
+
print(f"Epoch {epoch+1}/{epochs} completed, "
|
| 325 |
+
f"Avg Total Loss: {avg_loss:.6f}, "
|
| 326 |
+
f"Avg Pos Loss: {avg_pos_loss:.6f}, "
|
| 327 |
+
f"Avg Score Loss: {avg_score_loss:.6f}")
|
| 328 |
+
|
| 329 |
+
scheduler.step()
|
| 330 |
+
|
| 331 |
+
# Save model checkpoint every epoch
|
| 332 |
+
checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
|
| 333 |
+
torch.save({
|
| 334 |
+
'model_state_dict': model.state_dict(),
|
| 335 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 336 |
+
'epoch': epoch + 1,
|
| 337 |
+
'loss': avg_loss,
|
| 338 |
+
}, checkpoint_path)
|
| 339 |
+
|
| 340 |
+
# Save the trained model
|
| 341 |
+
torch.save({
|
| 342 |
+
'model_state_dict': model.state_dict(),
|
| 343 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 344 |
+
'epoch': epochs,
|
| 345 |
+
}, model_save_path)
|
| 346 |
+
|
| 347 |
+
print(f"Model saved to {model_save_path}")
|
| 348 |
+
return model
|
| 349 |
+
|
| 350 |
+
def load_pointnet_model(model_path: str, device: torch.device = None, predict_score: bool = True) -> FastPointNet:
|
| 351 |
+
"""
|
| 352 |
+
Load a trained FastPointNet model.
|
| 353 |
+
|
| 354 |
+
Args:
|
| 355 |
+
model_path: Path to the saved model
|
| 356 |
+
device: Device to load the model on
|
| 357 |
+
predict_score: Whether the model predicts scores
|
| 358 |
+
|
| 359 |
+
Returns:
|
| 360 |
+
Loaded FastPointNet model
|
| 361 |
+
"""
|
| 362 |
+
if device is None:
|
| 363 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 364 |
+
|
| 365 |
+
model = FastPointNet(input_dim=7, output_dim=3, max_points=1024, predict_score=predict_score)
|
| 366 |
+
|
| 367 |
+
checkpoint = torch.load(model_path, map_location=device)
|
| 368 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
| 369 |
+
|
| 370 |
+
model.to(device)
|
| 371 |
+
model.eval()
|
| 372 |
+
|
| 373 |
+
return model
|
| 374 |
+
|
| 375 |
+
def predict_vertex_from_patch(model: FastPointNet, patch_7d: np.ndarray, device: torch.device = None) -> Tuple[np.ndarray, float]:
|
| 376 |
+
"""
|
| 377 |
+
Predict 3D vertex coordinates and confidence score from a patch using trained PointNet.
|
| 378 |
+
|
| 379 |
+
Args:
|
| 380 |
+
model: Trained FastPointNet model
|
| 381 |
+
patch_7d: (N, 7) numpy array of point cloud data
|
| 382 |
+
device: Device to run prediction on
|
| 383 |
+
|
| 384 |
+
Returns:
|
| 385 |
+
tuple of (predicted_coordinates, confidence_score)
|
| 386 |
+
predicted_coordinates: (3,) numpy array of predicted 3D coordinates
|
| 387 |
+
confidence_score: float representing predicted distance to GT (lower is better)
|
| 388 |
+
"""
|
| 389 |
+
if device is None:
|
| 390 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 391 |
+
|
| 392 |
+
model.eval()
|
| 393 |
+
|
| 394 |
+
# Prepare input
|
| 395 |
+
max_points = 1024
|
| 396 |
+
num_points = patch_7d.shape[0]
|
| 397 |
+
|
| 398 |
+
if num_points >= max_points:
|
| 399 |
+
# Sample points
|
| 400 |
+
indices = np.random.choice(num_points, max_points, replace=False)
|
| 401 |
+
patch_sampled = patch_7d[indices]
|
| 402 |
+
else:
|
| 403 |
+
# Pad with zeros
|
| 404 |
+
patch_sampled = np.zeros((max_points, 7))
|
| 405 |
+
patch_sampled[:num_points] = patch_7d
|
| 406 |
+
|
| 407 |
+
# Convert to tensor
|
| 408 |
+
patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0) # (1, 7, max_points)
|
| 409 |
+
patch_tensor = patch_tensor.to(device)
|
| 410 |
+
|
| 411 |
+
# Predict
|
| 412 |
+
with torch.no_grad():
|
| 413 |
+
if model.predict_score:
|
| 414 |
+
position, score = model(patch_tensor)
|
| 415 |
+
position = position.cpu().numpy().squeeze()
|
| 416 |
+
score = score.cpu().numpy().squeeze()
|
| 417 |
+
return position, score
|
| 418 |
+
else:
|
| 419 |
+
position = model(patch_tensor)
|
| 420 |
+
position = position.cpu().numpy().squeeze()
|
| 421 |
+
return position, None
|
predict.py
CHANGED
|
@@ -10,6 +10,11 @@ from PIL import Image as PImage
|
|
| 10 |
import cv2
|
| 11 |
import open3d as o3d
|
| 12 |
from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
def convert_entry_to_human_readable(entry):
|
| 15 |
out = {}
|
|
@@ -389,11 +394,6 @@ def predict_wireframe(entry) -> Tuple[np.ndarray, List[int]]:
|
|
| 389 |
good_entry = convert_entry_to_human_readable(entry)
|
| 390 |
colmap_rec = good_entry['colmap_binary']
|
| 391 |
|
| 392 |
-
colmap_pcloud = []
|
| 393 |
-
for i, p3D in colmap_rec.points3D.items():
|
| 394 |
-
p3D.color = np.array([0, 0, 0])
|
| 395 |
-
colmap_pcloud.append(p3D)
|
| 396 |
-
|
| 397 |
vert_edge_per_image = {}
|
| 398 |
for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(zip(good_entry['gestalt'],
|
| 399 |
good_entry['depth'],
|
|
@@ -413,7 +413,13 @@ def predict_wireframe(entry) -> Tuple[np.ndarray, List[int]]:
|
|
| 413 |
gest_seg = gest.resize(depth_size)
|
| 414 |
gest_seg_np = np.array(gest_seg).astype(np.uint8)
|
| 415 |
|
| 416 |
-
vertices_ours, connections_ours, vertices_3d_ours = our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id, ade_seg, depth, K=K, R=R, t=t)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 417 |
vertices, connections, vertices_3d = vertices_ours, connections_ours, vertices_3d_ours
|
| 418 |
# Get 2D vertices and edges first
|
| 419 |
#vertices, connections = get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th=25.)
|
|
@@ -444,7 +450,7 @@ def predict_wireframe(entry) -> Tuple[np.ndarray, List[int]]:
|
|
| 444 |
continue
|
| 445 |
|
| 446 |
# Call the refactored function to get 3D points
|
| 447 |
-
|
| 448 |
#vertices_3d = gt_verts3d
|
| 449 |
# Store original 2D vertices, connections, and computed 3D points
|
| 450 |
|
|
@@ -480,6 +486,9 @@ def predict_wireframe(entry) -> Tuple[np.ndarray, List[int]]:
|
|
| 480 |
# Visualize the point cloud
|
| 481 |
o3d.visualization.draw_geometries([pcd], window_name="COLMAP Point Cloud")
|
| 482 |
'''
|
|
|
|
|
|
|
|
|
|
| 483 |
# Merge vertices from all images
|
| 484 |
all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 0.5)
|
| 485 |
all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d, keep_largest=False)
|
|
@@ -627,9 +636,9 @@ def get_apex_or_eave_points(apex, uv, gest_seg_np, house_mask, valid_indices, po
|
|
| 627 |
final_valid_indices = valid_point_indices[depth_filter]
|
| 628 |
|
| 629 |
# Add corresponding points to filtered lists
|
| 630 |
-
filtered_points_xyz.
|
| 631 |
-
filtered_point_idxs.
|
| 632 |
-
filtered_points_color.
|
| 633 |
|
| 634 |
# Find the point with lowest depth in the filtered points
|
| 635 |
if len(final_valid_indices) > 0:
|
|
@@ -637,9 +646,6 @@ def get_apex_or_eave_points(apex, uv, gest_seg_np, house_mask, valid_indices, po
|
|
| 637 |
lowest_depth_point = final_valid_indices[lowest_depth_idx]
|
| 638 |
|
| 639 |
filtered_vertices_apex.append(points_xyz_world[lowest_depth_point])
|
| 640 |
-
filtered_points_xyz.append(points_xyz_world[lowest_depth_point])
|
| 641 |
-
filtered_point_idxs.append(points_idxs[lowest_depth_point])
|
| 642 |
-
filtered_points_color.append(np.array([1., 1., 0.]))
|
| 643 |
filtered_vertices_apex_uv.append(centroids[i])
|
| 644 |
|
| 645 |
return filtered_points_xyz, filtered_point_idxs, filtered_points_color, filtered_vertices_apex, filtered_vertices_apex_uv
|
|
@@ -653,9 +659,9 @@ def get_vertexes(uv, gest_seg_np, house_mask, valid_indices, points_xyz_world, p
|
|
| 653 |
filtered_point_idxs = filtered_point_idxs_apex + filtered_point_idxs_eave
|
| 654 |
filtered_points_color = filtered_points_color_apex + filtered_points_color_eave
|
| 655 |
|
| 656 |
-
filtered_points_xyz = np.array(filtered_points_xyz[::-1]) if filtered_points_xyz else np.empty((0, 3))
|
| 657 |
-
filtered_point_idxs = np.array(filtered_point_idxs[::-1]) if filtered_point_idxs else np.empty((0,))
|
| 658 |
-
filtered_points_color = np.array(filtered_points_color[::-1]) if filtered_points_color else np.empty((0, 3))
|
| 659 |
filtered_vertices_apex = np.array(filtered_vertices_apex) if filtered_vertices_apex else np.empty((0, 3))
|
| 660 |
filtered_vertices_apex_uv = np.array(filtered_vertices_apex_uv) if filtered_vertices_apex_uv else np.empty((0, 2))
|
| 661 |
filtered_vertices_eave = np.array(filtered_vertices_eave) if filtered_vertices_eave else np.empty((0, 3))
|
|
@@ -803,7 +809,156 @@ def visualize_3d_wireframe(colmap_rec, filtered_points_xyz, filtered_points_colo
|
|
| 803 |
|
| 804 |
#o3d.visualization.draw_geometries(geometries, window_name=f"Combined Point Cloud - {img_id_substring}")
|
| 805 |
|
| 806 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 807 |
"""
|
| 808 |
Identify apex and eave-end vertices, then detect lines for eave/ridge/rake/valley.
|
| 809 |
Also find all COLMAP points that project into apex or eave_end masks.
|
|
@@ -838,6 +993,8 @@ def our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id_substring, ade_se
|
|
| 838 |
|
| 839 |
vertices_formatted, connections, all_vertices_3d = get_connections(gest_seg_np, filtered_vertices_apex, filtered_vertices_eave, filtered_vertices_apex_uv, filtered_vertices_eave_uv)
|
| 840 |
|
|
|
|
|
|
|
| 841 |
#visualize_3d_wireframe(colmap_rec, filtered_points_xyz, filtered_points_color, all_vertices_3d, connections)
|
| 842 |
|
| 843 |
-
return vertices_formatted, connections, all_vertices_3d
|
|
|
|
| 10 |
import cv2
|
| 11 |
import open3d as o3d
|
| 12 |
from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local
|
| 13 |
+
import pyvista as pv
|
| 14 |
+
from fast_pointnet import save_patches_dataset
|
| 15 |
+
|
| 16 |
+
GENERATE_DATASET = True
|
| 17 |
+
DATASET_DIR = '/home/skvrnjan/personal/hohocustom/'
|
| 18 |
|
| 19 |
def convert_entry_to_human_readable(entry):
|
| 20 |
out = {}
|
|
|
|
| 394 |
good_entry = convert_entry_to_human_readable(entry)
|
| 395 |
colmap_rec = good_entry['colmap_binary']
|
| 396 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
vert_edge_per_image = {}
|
| 398 |
for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(zip(good_entry['gestalt'],
|
| 399 |
good_entry['depth'],
|
|
|
|
| 413 |
gest_seg = gest.resize(depth_size)
|
| 414 |
gest_seg_np = np.array(gest_seg).astype(np.uint8)
|
| 415 |
|
| 416 |
+
vertices_ours, connections_ours, vertices_3d_ours, patches = our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id, ade_seg, depth, K=K, R=R, t=t, frame=good_entry)
|
| 417 |
+
|
| 418 |
+
if GENERATE_DATASET:
|
| 419 |
+
save_patches_dataset(patches, DATASET_DIR, img_id)
|
| 420 |
+
|
| 421 |
+
continue
|
| 422 |
+
|
| 423 |
vertices, connections, vertices_3d = vertices_ours, connections_ours, vertices_3d_ours
|
| 424 |
# Get 2D vertices and edges first
|
| 425 |
#vertices, connections = get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th=25.)
|
|
|
|
| 450 |
continue
|
| 451 |
|
| 452 |
# Call the refactored function to get 3D points
|
| 453 |
+
vertices_3d = create_3d_wireframe_single_image(vertices, connections, depth, colmap_rec, img_id, ade_seg, K, R, t)
|
| 454 |
#vertices_3d = gt_verts3d
|
| 455 |
# Store original 2D vertices, connections, and computed 3D points
|
| 456 |
|
|
|
|
| 486 |
# Visualize the point cloud
|
| 487 |
o3d.visualization.draw_geometries([pcd], window_name="COLMAP Point Cloud")
|
| 488 |
'''
|
| 489 |
+
if GENERATE_DATASET:
|
| 490 |
+
return empty_solution()
|
| 491 |
+
|
| 492 |
# Merge vertices from all images
|
| 493 |
all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 0.5)
|
| 494 |
all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d, keep_largest=False)
|
|
|
|
| 636 |
final_valid_indices = valid_point_indices[depth_filter]
|
| 637 |
|
| 638 |
# Add corresponding points to filtered lists
|
| 639 |
+
filtered_points_xyz.append(points_xyz_world[final_valid_indices])
|
| 640 |
+
filtered_point_idxs.append(points_idxs[final_valid_indices])
|
| 641 |
+
filtered_points_color.append([color] * np.sum(depth_filter))
|
| 642 |
|
| 643 |
# Find the point with lowest depth in the filtered points
|
| 644 |
if len(final_valid_indices) > 0:
|
|
|
|
| 646 |
lowest_depth_point = final_valid_indices[lowest_depth_idx]
|
| 647 |
|
| 648 |
filtered_vertices_apex.append(points_xyz_world[lowest_depth_point])
|
|
|
|
|
|
|
|
|
|
| 649 |
filtered_vertices_apex_uv.append(centroids[i])
|
| 650 |
|
| 651 |
return filtered_points_xyz, filtered_point_idxs, filtered_points_color, filtered_vertices_apex, filtered_vertices_apex_uv
|
|
|
|
| 659 |
filtered_point_idxs = filtered_point_idxs_apex + filtered_point_idxs_eave
|
| 660 |
filtered_points_color = filtered_points_color_apex + filtered_points_color_eave
|
| 661 |
|
| 662 |
+
#filtered_points_xyz = np.array(filtered_points_xyz[::-1]) if filtered_points_xyz else np.empty((0, 3))
|
| 663 |
+
#filtered_point_idxs = np.array(filtered_point_idxs[::-1]) if filtered_point_idxs else np.empty((0,))
|
| 664 |
+
#filtered_points_color = np.array(filtered_points_color[::-1]) if filtered_points_color else np.empty((0, 3))
|
| 665 |
filtered_vertices_apex = np.array(filtered_vertices_apex) if filtered_vertices_apex else np.empty((0, 3))
|
| 666 |
filtered_vertices_apex_uv = np.array(filtered_vertices_apex_uv) if filtered_vertices_apex_uv else np.empty((0, 2))
|
| 667 |
filtered_vertices_eave = np.array(filtered_vertices_eave) if filtered_vertices_eave else np.empty((0, 3))
|
|
|
|
| 809 |
|
| 810 |
#o3d.visualization.draw_geometries(geometries, window_name=f"Combined Point Cloud - {img_id_substring}")
|
| 811 |
|
| 812 |
+
def generate_patches(colmap_rec, filtered_points_idxs, frame, filtered_vertices):
|
| 813 |
+
patches = []
|
| 814 |
+
|
| 815 |
+
gt_vertices = frame['wf_vertices']
|
| 816 |
+
|
| 817 |
+
# Process each group of filtered points
|
| 818 |
+
for group_idx, point_idxs in enumerate(filtered_points_idxs):
|
| 819 |
+
if len(point_idxs) == 0:
|
| 820 |
+
continue
|
| 821 |
+
|
| 822 |
+
# Get 3D coordinates and colors for this group
|
| 823 |
+
group_points_3d = []
|
| 824 |
+
group_colors = []
|
| 825 |
+
assigned_gt_vertex = None
|
| 826 |
+
|
| 827 |
+
for pid in point_idxs:
|
| 828 |
+
p3d = colmap_rec.points3D[pid]
|
| 829 |
+
group_points_3d.append(p3d.xyz)
|
| 830 |
+
group_colors.append(p3d.color)
|
| 831 |
+
|
| 832 |
+
if len(group_points_3d) == 0:
|
| 833 |
+
continue
|
| 834 |
+
|
| 835 |
+
group_points_3d = np.array(group_points_3d)
|
| 836 |
+
group_colors = np.array(group_colors)
|
| 837 |
+
|
| 838 |
+
# Calculate centroid of filtered points
|
| 839 |
+
# Find the closest GT vertex to the centroid of filtered points
|
| 840 |
+
centroid = np.mean(group_points_3d, axis=0)
|
| 841 |
+
|
| 842 |
+
if len(gt_vertices) > 0:
|
| 843 |
+
# Calculate distances from centroid to all GT vertices
|
| 844 |
+
distances_to_gt = []
|
| 845 |
+
for gt_vertex in gt_vertices:
|
| 846 |
+
distance = np.linalg.norm(gt_vertex - centroid)
|
| 847 |
+
distances_to_gt.append(distance)
|
| 848 |
+
|
| 849 |
+
# Find the closest GT vertex
|
| 850 |
+
min_distance_idx = np.argmin(distances_to_gt)
|
| 851 |
+
closest_gt_vertex = gt_vertices[min_distance_idx]
|
| 852 |
+
min_distance = distances_to_gt[min_distance_idx]
|
| 853 |
+
|
| 854 |
+
# Define ball radius (you can adjust this value)
|
| 855 |
+
ball_radius = 2.0 # meters
|
| 856 |
+
|
| 857 |
+
# Use closest GT vertex as centroid if it's within the ball radius
|
| 858 |
+
if min_distance <= ball_radius:
|
| 859 |
+
assigned_gt_vertex = closest_gt_vertex
|
| 860 |
+
# If no GT vertex is close enough, skip this group
|
| 861 |
+
else:
|
| 862 |
+
continue
|
| 863 |
+
else:
|
| 864 |
+
# No GT vertices available, use original centroid
|
| 865 |
+
centroid = np.mean(group_points_3d, axis=0)
|
| 866 |
+
|
| 867 |
+
# Define ball radius (you can adjust this value)
|
| 868 |
+
ball_radius = 2.0 # meters
|
| 869 |
+
|
| 870 |
+
# Find all COLMAP points within the ball around centroid
|
| 871 |
+
patch_points_3d = []
|
| 872 |
+
patch_colors = []
|
| 873 |
+
patch_point_ids = []
|
| 874 |
+
|
| 875 |
+
for pid, p3d in colmap_rec.points3D.items():
|
| 876 |
+
distance = np.linalg.norm(p3d.xyz - centroid)
|
| 877 |
+
if distance <= ball_radius:
|
| 878 |
+
patch_points_3d.append(p3d.xyz)
|
| 879 |
+
patch_colors.append(p3d.color)
|
| 880 |
+
patch_point_ids.append(pid)
|
| 881 |
+
|
| 882 |
+
if len(patch_points_3d) == 0:
|
| 883 |
+
continue
|
| 884 |
+
|
| 885 |
+
patch_points_3d = np.array(patch_points_3d)
|
| 886 |
+
|
| 887 |
+
# Calculate offset to center the patch
|
| 888 |
+
patch_centroid = np.mean(patch_points_3d, axis=0)
|
| 889 |
+
offset = -patch_centroid
|
| 890 |
+
|
| 891 |
+
# Shift points to center them around origin
|
| 892 |
+
patch_points_3d += offset
|
| 893 |
+
|
| 894 |
+
# Also shift the assigned GT vertex by the same offset if it exists
|
| 895 |
+
if assigned_gt_vertex is not None:
|
| 896 |
+
assigned_gt_vertex = assigned_gt_vertex + offset
|
| 897 |
+
patch_colors = np.array(patch_colors)
|
| 898 |
+
|
| 899 |
+
# Create 7D point cloud for this patch
|
| 900 |
+
# [x, y, z, r, g, b, in_filtered_flag]
|
| 901 |
+
patch_7d = np.zeros((len(patch_points_3d), 7))
|
| 902 |
+
patch_7d[:, :3] = patch_points_3d # xyz coordinates
|
| 903 |
+
patch_7d[:, 3:6] = patch_colors / 255.0 # rgb colors normalized to [0,1]
|
| 904 |
+
|
| 905 |
+
# Set in_filtered_flag: 1 if point was in original filtered set, 0 otherwise
|
| 906 |
+
for i, pid in enumerate(patch_point_ids):
|
| 907 |
+
if pid in point_idxs:
|
| 908 |
+
patch_7d[i, 6] = 1.0
|
| 909 |
+
else:
|
| 910 |
+
patch_7d[i, 6] = 0.0
|
| 911 |
+
|
| 912 |
+
if filtered_vertices[group_idx] is not None:
|
| 913 |
+
initial_pred = filtered_vertices[group_idx] + offset
|
| 914 |
+
else:
|
| 915 |
+
initial_pred = None
|
| 916 |
+
|
| 917 |
+
patches.append({
|
| 918 |
+
'patch_7d': patch_7d,
|
| 919 |
+
'centroid': centroid,
|
| 920 |
+
'radius': ball_radius,
|
| 921 |
+
'point_ids': patch_point_ids,
|
| 922 |
+
'filtered_point_ids': point_idxs,
|
| 923 |
+
'group_idx': group_idx,
|
| 924 |
+
'assigned_gt_vertex': assigned_gt_vertex,
|
| 925 |
+
'offset': offset,
|
| 926 |
+
'initial_pred': initial_pred
|
| 927 |
+
})
|
| 928 |
+
|
| 929 |
+
if False:
|
| 930 |
+
# Create plotter
|
| 931 |
+
plotter = pv.Plotter()
|
| 932 |
+
|
| 933 |
+
# Create point cloud for this patch
|
| 934 |
+
patch_cloud = pv.PolyData(patch_points_3d)
|
| 935 |
+
|
| 936 |
+
# Color points: red for filtered points, blue for other points
|
| 937 |
+
patch_point_colors = []
|
| 938 |
+
for i, pid in enumerate(patch_point_ids):
|
| 939 |
+
if pid in point_idxs:
|
| 940 |
+
patch_point_colors.append([255, 0, 0]) # Red for filtered points
|
| 941 |
+
else:
|
| 942 |
+
patch_point_colors.append([0, 0, 255]) # Blue for other points
|
| 943 |
+
|
| 944 |
+
patch_cloud["colors"] = np.array(patch_point_colors)
|
| 945 |
+
plotter.add_mesh(patch_cloud, scalars="colors", rgb=True, point_size=8, render_points_as_spheres=True)
|
| 946 |
+
|
| 947 |
+
# Create sphere to visualize GT vertex if available
|
| 948 |
+
if assigned_gt_vertex is not None:
|
| 949 |
+
gt_sphere = pv.Sphere(radius=0.1, center=assigned_gt_vertex)
|
| 950 |
+
plotter.add_mesh(gt_sphere, color="green", opacity=0.5)
|
| 951 |
+
|
| 952 |
+
if initial_pred is not None:
|
| 953 |
+
# Create sphere to visualize initial prediction
|
| 954 |
+
pred_sphere = pv.Sphere(radius=0.1, center=initial_pred)
|
| 955 |
+
plotter.add_mesh(pred_sphere, color="orange", opacity=0.5)
|
| 956 |
+
|
| 957 |
+
plotter.show(title=f"Patch {group_idx}")
|
| 958 |
+
|
| 959 |
+
return patches
|
| 960 |
+
|
| 961 |
+
def our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id_substring, ade_seg, depth, K=None, R=None, t=None, frame=None):
|
| 962 |
"""
|
| 963 |
Identify apex and eave-end vertices, then detect lines for eave/ridge/rake/valley.
|
| 964 |
Also find all COLMAP points that project into apex or eave_end masks.
|
|
|
|
| 993 |
|
| 994 |
vertices_formatted, connections, all_vertices_3d = get_connections(gest_seg_np, filtered_vertices_apex, filtered_vertices_eave, filtered_vertices_apex_uv, filtered_vertices_eave_uv)
|
| 995 |
|
| 996 |
+
patches = generate_patches(colmap_rec, filtered_point_idxs, frame, all_vertices_3d)
|
| 997 |
+
|
| 998 |
#visualize_3d_wireframe(colmap_rec, filtered_points_xyz, filtered_points_color, all_vertices_3d, connections)
|
| 999 |
|
| 1000 |
+
return vertices_formatted, connections, all_vertices_3d, patches
|
train.py
CHANGED
|
@@ -22,7 +22,7 @@ scores_iou = []
|
|
| 22 |
show_visu = False
|
| 23 |
|
| 24 |
idx = 0
|
| 25 |
-
for a in ds['
|
| 26 |
#plot_all_modalities(a)
|
| 27 |
#pred_vertices, pred_edges = predict_wireframe(a)
|
| 28 |
try:
|
|
|
|
| 22 |
show_visu = False
|
| 23 |
|
| 24 |
idx = 0
|
| 25 |
+
for a in ds['train']:
|
| 26 |
#plot_all_modalities(a)
|
| 27 |
#pred_vertices, pred_edges = predict_wireframe(a)
|
| 28 |
try:
|