| import os |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import pickle |
| from torch.utils.data import Dataset, DataLoader |
| from typing import List, Dict, Tuple, Optional |
| import json |
|
|
class FastPointNet(nn.Module):
    """
    Fast PointNet implementation for 3D vertex prediction from point cloud patches.

    Takes 7D point clouds (x, y, z, r, g, b, filtered_flag) shaped
    (batch, input_dim, num_points) and predicts 3D vertex coordinates, plus
    optional confidence-score and classification heads. Enhanced with a deeper
    architecture and more parameters for better generalization.
    """

    def __init__(self, input_dim=7, output_dim=3, max_points=1024, predict_score=True, predict_class=True, num_classes=1):
        """
        Args:
            input_dim: Number of per-point input features.
            output_dim: Dimensionality of the regressed vertex position.
            max_points: Nominal number of points per patch. Stored for
                reference only; the forward pass accepts any point count.
            predict_score: If True, build the distance-to-GT score head.
            predict_class: If True, build the classification head.
            num_classes: Output width of the classification head.
        """
        super(FastPointNet, self).__init__()
        self.max_points = max_points
        self.predict_score = predict_score
        self.predict_class = predict_class
        self.num_classes = num_classes

        # Point-wise feature extractor: kernel-size-1 Conv1d acts as a shared
        # per-point MLP, as in the original PointNet.
        self.conv1 = nn.Conv1d(input_dim, 128, 1)
        self.conv2 = nn.Conv1d(128, 256, 1)
        self.conv3 = nn.Conv1d(256, 512, 1)
        self.conv4 = nn.Conv1d(512, 1024, 1)
        self.conv5 = nn.Conv1d(1024, 1024, 1)
        self.conv6 = nn.Conv1d(1024, 2048, 1)

        # Shared trunk applied to the pooled global feature.
        self.shared_fc1 = nn.Linear(2048, 1024)
        self.shared_fc2 = nn.Linear(1024, 512)

        # Position regression head.
        self.pos_fc1 = nn.Linear(512, 512)
        self.pos_fc2 = nn.Linear(512, 256)
        self.pos_fc3 = nn.Linear(256, 128)
        self.pos_fc4 = nn.Linear(128, output_dim)

        # Optional score head: predicts distance to GT (lower is better).
        if self.predict_score:
            self.score_fc1 = nn.Linear(512, 512)
            self.score_fc2 = nn.Linear(512, 256)
            self.score_fc3 = nn.Linear(256, 128)
            self.score_fc4 = nn.Linear(128, 64)
            self.score_fc5 = nn.Linear(64, 1)

        # Optional classification head: emits raw logits (no sigmoid here,
        # matching BCEWithLogitsLoss in training).
        if self.predict_class:
            self.class_fc1 = nn.Linear(512, 512)
            self.class_fc2 = nn.Linear(512, 256)
            self.class_fc3 = nn.Linear(256, 128)
            self.class_fc4 = nn.Linear(128, 64)
            self.class_fc5 = nn.Linear(64, num_classes)

        # Batch norms for the conv stack.
        self.bn1 = nn.BatchNorm1d(128)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(512)
        self.bn4 = nn.BatchNorm1d(1024)
        self.bn5 = nn.BatchNorm1d(1024)
        self.bn6 = nn.BatchNorm1d(2048)

        # Graduated dropout rates shared across all heads.
        self.dropout_light = nn.Dropout(0.2)
        self.dropout_medium = nn.Dropout(0.3)
        self.dropout_heavy = nn.Dropout(0.4)

    def forward(self, x):
        """
        Forward pass.

        Args:
            x: (batch_size, input_dim, num_points) tensor.

        Returns:
            position alone when no extra heads are enabled, otherwise a tuple
            of the enabled outputs in the order (position[, score][, classification]):
            - position: (batch_size, output_dim) predicted 3D coordinates
            - score: (batch_size, 1) predicted distance to GT (if predict_score=True)
            - classification: (batch_size, num_classes) class logits (if predict_class=True)
        """
        # Per-point feature extraction.
        x1 = F.relu(self.bn1(self.conv1(x)))
        x2 = F.relu(self.bn2(self.conv2(x1)))
        x3 = F.relu(self.bn3(self.conv3(x2)))
        x4 = F.relu(self.bn4(self.conv4(x3)))
        x5 = F.relu(self.bn5(self.conv5(x4)))
        x6 = F.relu(self.bn6(self.conv6(x5)))

        # Symmetric (permutation-invariant) pooling over the point dimension;
        # combining max and mean pooling keeps both sharp and smooth cues.
        max_pool = torch.max(x6, 2)[0]
        avg_pool = torch.mean(x6, 2)
        global_features = max_pool + avg_pool

        # Shared trunk.
        shared1 = self.dropout_light(F.relu(self.shared_fc1(global_features)))
        shared_features = self.dropout_medium(F.relu(self.shared_fc2(shared1)))

        # Position head.
        pos1 = self.dropout_light(F.relu(self.pos_fc1(shared_features)))
        pos2 = self.dropout_medium(F.relu(self.pos_fc2(pos1)))
        pos3 = self.dropout_light(F.relu(self.pos_fc3(pos2)))
        position = self.pos_fc4(pos3)

        outputs = [position]

        if self.predict_score:
            score1 = self.dropout_light(F.relu(self.score_fc1(shared_features)))
            score2 = self.dropout_medium(F.relu(self.score_fc2(score1)))
            score3 = self.dropout_light(F.relu(self.score_fc3(score2)))
            score4 = self.dropout_light(F.relu(self.score_fc4(score3)))
            # Final ReLU keeps the predicted distance non-negative.
            score = F.relu(self.score_fc5(score4))
            outputs.append(score)

        if self.predict_class:
            class1 = self.dropout_light(F.relu(self.class_fc1(shared_features)))
            class2 = self.dropout_medium(F.relu(self.class_fc2(class1)))
            class3 = self.dropout_light(F.relu(self.class_fc3(class2)))
            class4 = self.dropout_light(F.relu(self.class_fc4(class3)))
            classification = self.class_fc5(class4)
            outputs.append(classification)

        # The previous if/elif chain contained two dead-duplicated branches
        # that returned identical tuples; a single expression is equivalent.
        return outputs[0] if len(outputs) == 1 else tuple(outputs)
|
|
class PatchDataset(Dataset):
    """
    Dataset class for loading saved patches for PointNet training.

    Each sample is loaded from a pickle file produced by the patch-generation
    pipeline, then randomly subsampled or zero-padded to a fixed point count.
    """

    def __init__(self, dataset_dir: str, max_points: int = 1024, augment: bool = True):
        """
        Args:
            dataset_dir: Directory containing ``*.pkl`` patch files.
            max_points: Fixed number of points each sample is resampled/padded to.
            augment: If True, apply ``self._augment_patch`` to samples with a
                GT vertex. NOTE(review): ``_augment_patch`` is not defined in
                this file — it must be supplied elsewhere (e.g. a subclass),
                otherwise augment must stay False.
        """
        self.dataset_dir = dataset_dir
        self.max_points = max_points
        self.augment = augment

        # Sort for a deterministic sample order across runs and platforms
        # (os.listdir order is arbitrary).
        self.patch_files = sorted(
            os.path.join(dataset_dir, name)
            for name in os.listdir(dataset_dir)
            if name.endswith('.pkl')
        )

        print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")

    def __len__(self):
        return len(self.patch_files)

    def __getitem__(self, idx):
        """
        Load and process a patch for training.

        Returns:
            patch_data: (7, max_points) float tensor of point cloud data
            target: (3,) float tensor of target 3D coordinates (zeros if no GT)
            valid_mask: (max_points,) boolean tensor indicating real points
            initial_pred: (3,) float tensor of the cluster-center initial guess
            classification: scalar float tensor (1 if GT vertex present, else 0)
        """
        patch_file = self.patch_files[idx]

        with open(patch_file, 'rb') as f:
            patch_info = pickle.load(f)

        patch_7d = patch_info['patch_7d']
        target = patch_info.get('assigned_wf_vertex', None)
        initial_pred = patch_info.get('cluster_center', None)

        # Binary label: does this patch have an assigned GT vertex?
        has_gt_vertex = 1.0 if target is not None else 0.0

        if target is None:
            # Placeholder target; the position loss is masked out for these
            # samples by the training loop.
            target = np.zeros(3)
        else:
            target = np.array(target)

        num_points = patch_7d.shape[0]

        if num_points >= self.max_points:
            # Random subsample (without replacement) down to max_points.
            indices = np.random.choice(num_points, self.max_points, replace=False)
            patch_sampled = patch_7d[indices]
            valid_mask = np.ones(self.max_points, dtype=bool)
        else:
            # Zero-pad. Derive the feature width from the data rather than
            # hard-coding 7, so other feature layouts keep working.
            patch_sampled = np.zeros((self.max_points, patch_7d.shape[1]))
            patch_sampled[:num_points] = patch_7d
            valid_mask = np.zeros(self.max_points, dtype=bool)
            valid_mask[:num_points] = True

        if self.augment and has_gt_vertex > 0:
            patch_sampled, target = self._augment_patch(patch_sampled, valid_mask, target)

        # (N, C) -> (C, N) as expected by the Conv1d-based model.
        patch_tensor = torch.from_numpy(patch_sampled.T).float()
        target_tensor = torch.from_numpy(target).float()
        valid_mask_tensor = torch.from_numpy(valid_mask)

        if initial_pred is not None:
            initial_pred_tensor = torch.from_numpy(initial_pred).float()
        else:
            initial_pred_tensor = torch.zeros(3).float()

        classification_tensor = torch.tensor(has_gt_vertex).float()

        return patch_tensor, target_tensor, valid_mask_tensor, initial_pred_tensor, classification_tensor
|
|
def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
    """
    Save patches from the prediction pipeline to build a training dataset.

    Existing files are left untouched, so re-running on the same entry is a
    cheap no-op for already-saved patches.

    Args:
        patches: List of patch dictionaries from generate_patches()
        dataset_dir: Directory to save the dataset
        entry_id: Unique identifier for this entry/image
    """
    os.makedirs(dataset_dir, exist_ok=True)

    saved = 0
    for i, patch in enumerate(patches):
        filepath = os.path.join(dataset_dir, f"{entry_id}_patch_{i}.pkl")

        # Skip patches already written by a previous run.
        if os.path.exists(filepath):
            continue

        with open(filepath, 'wb') as f:
            pickle.dump(patch, f)
        saved += 1

    # Report the number actually written; previously this printed
    # len(patches) even when every file was skipped.
    print(f"Saved {saved}/{len(patches)} patches for entry {entry_id}")
|
|
| |
def collate_fn(batch):
    """Collate patch samples into batched tensors, dropping empty patches.

    A sample whose valid mask contains no True entries is discarded; if that
    leaves nothing, None is returned so the training loop can skip the batch.
    """
    # Keep only samples with at least one valid point (mask is element 2).
    kept = [sample for sample in batch if sample[2].sum() > 0]
    if not kept:
        return None

    # Transpose list-of-tuples into per-field groups and stack each one.
    stacked = tuple(torch.stack(group) for group in zip(*kept))
    patch_data, targets, valid_masks, initial_preds, classifications = stacked

    return patch_data, targets, valid_masks, initial_preds, classifications
|
|
| |
def init_weights(m):
    """Initialize a module: Xavier-uniform weights with zero bias for conv
    and linear layers, unit scale and zero shift for batch norms.

    Intended for use with ``model.apply(init_weights)``.
    """
    # Conv1d and Linear share the exact same initialization scheme.
    if isinstance(m, (nn.Conv1d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm1d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
|
|
def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32, lr: float = 0.001,
                   score_weight: float = 0.1, class_weight: float = 0.5):
    """
    Train the FastPointNet model on saved patches.

    Args:
        dataset_dir: Directory containing saved patch files
        model_save_path: Path to save the trained model
        epochs: Number of training epochs
        batch_size: Training batch size
        lr: Learning rate
        score_weight: Weight for the distance prediction loss
        class_weight: Weight for the classification loss

    Returns:
        The trained FastPointNet model.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on device: {device}")

    dataset = PatchDataset(dataset_dir, max_points=1024, augment=False)
    print(f"Dataset loaded with {len(dataset)} samples")

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8,
                            collate_fn=collate_fn, drop_last=True)

    model = FastPointNet(input_dim=7, output_dim=3, max_points=1024, predict_score=True, predict_class=True, num_classes=1)
    model.apply(init_weights)
    model.to(device)

    # Per-head losses: MSE for position and distance score, BCE-with-logits
    # for the has-vertex classification (the model emits raw logits).
    position_criterion = nn.MSELoss()
    score_criterion = nn.MSELoss()
    classification_criterion = nn.BCEWithLogitsLoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        total_pos_loss = 0.0
        total_score_loss = 0.0
        total_class_loss = 0.0
        num_batches = 0

        for batch_idx, batch_data in enumerate(dataloader):
            if batch_data is None:
                # collate_fn returned None: every sample in the batch was empty.
                continue

            patch_data, targets, valid_masks, initial_preds, classifications = batch_data
            patch_data = patch_data.to(device)
            targets = targets.to(device)
            classifications = classifications.to(device)

            optimizer.zero_grad()
            predictions, predicted_scores, predicted_classes = model(patch_data)

            # The score head regresses the model's own positional error.
            actual_distances = torch.norm(predictions - targets, dim=1, keepdim=True)

            # Position/score losses apply only to patches with a GT vertex.
            has_gt_mask = classifications > 0.5

            if has_gt_mask.sum() > 0:
                pos_loss = position_criterion(predictions[has_gt_mask], targets[has_gt_mask])
                score_loss = score_criterion(predicted_scores[has_gt_mask], actual_distances[has_gt_mask])
            else:
                pos_loss = torch.tensor(0.0, device=device)
                score_loss = torch.tensor(0.0, device=device)

            # squeeze(1), not squeeze(): a batch reduced to one valid sample
            # by collate_fn would otherwise collapse (1, 1) to a 0-d tensor,
            # mismatching the (1,) target shape expected by BCEWithLogitsLoss.
            class_loss = classification_criterion(predicted_classes.squeeze(1), classifications)

            total_batch_loss = pos_loss + score_weight * score_loss + class_weight * class_loss

            total_batch_loss.backward()
            optimizer.step()

            total_loss += total_batch_loss.item()
            total_pos_loss += pos_loss.item()
            total_score_loss += score_loss.item()
            total_class_loss += class_loss.item()
            num_batches += 1

            if batch_idx % 50 == 0:
                print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
                      f"Total Loss: {total_batch_loss.item():.6f}, "
                      f"Pos Loss: {pos_loss.item():.6f}, "
                      f"Score Loss: {score_loss.item():.6f}, "
                      f"Class Loss: {class_loss.item():.6f}")

        avg_loss = total_loss / num_batches if num_batches > 0 else 0
        avg_pos_loss = total_pos_loss / num_batches if num_batches > 0 else 0
        avg_score_loss = total_score_loss / num_batches if num_batches > 0 else 0
        avg_class_loss = total_class_loss / num_batches if num_batches > 0 else 0

        print(f"Epoch {epoch+1}/{epochs} completed, "
              f"Avg Total Loss: {avg_loss:.6f}, "
              f"Avg Pos Loss: {avg_pos_loss:.6f}, "
              f"Avg Score Loss: {avg_score_loss:.6f}, "
              f"Avg Class Loss: {avg_class_loss:.6f}")

        scheduler.step()

        # Per-epoch checkpoint. splitext (instead of str.replace('.pth', ...))
        # keeps the name unique even when model_save_path has a different or
        # missing extension, or '.pth' occurs elsewhere in the path.
        root, ext = os.path.splitext(model_save_path)
        checkpoint_path = f"{root}_epoch_{epoch + 1}{ext or '.pth'}"
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch + 1,
            'loss': avg_loss,
        }, checkpoint_path)

    # Final model snapshot after all epochs.
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epochs,
    }, model_save_path)

    print(f"Model saved to {model_save_path}")
    return model
|
|
def load_pointnet_model(model_path: str, device: torch.device = None, predict_score: bool = True,
                        predict_class: bool = True, num_classes: int = 1) -> "FastPointNet":
    """
    Load a trained FastPointNet model.

    Args:
        model_path: Path to the saved checkpoint (as written by train_pointnet).
        device: Device to load the model on; auto-detected when None.
        predict_score: Whether the saved model has the score head.
        predict_class: Whether the saved model has the classification head.
            Previously this was not configurable even though training builds
            the class head, so checkpoints without it could not be loaded.
        num_classes: Output width of the classification head, if present.

    Returns:
        Loaded FastPointNet model in eval mode on the requested device.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The architecture flags must match the checkpoint exactly, otherwise
    # load_state_dict (strict by default) raises on missing/unexpected keys.
    model = FastPointNet(input_dim=7, output_dim=3, max_points=1024,
                         predict_score=predict_score, predict_class=predict_class,
                         num_classes=num_classes)

    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    model.to(device)
    model.eval()

    return model
|
|
def predict_vertex_from_patch(model: "FastPointNet", patch: Dict, device: torch.device = None) -> Tuple[np.ndarray, Optional[float], Optional[float]]:
    """
    Predict 3D vertex coordinates, confidence score, and classification from a
    patch using a trained PointNet.

    Args:
        model: Trained FastPointNet model (should already be in eval mode).
        patch: Dictionary containing 'patch_7d' (N, 7) point array and
            'cluster_center' (3,) offset used to de-normalize the prediction.
            (The annotation previously claimed np.ndarray and the docstring
            referenced a nonexistent 'offset' key; the code uses a dict with
            'cluster_center'.)
        device: Device to run prediction on; auto-detected when None.

    Returns:
        Tuple of (predicted_coordinates, confidence_score, classification_score):
        predicted_coordinates: (3,) numpy array of predicted 3D coordinates
        confidence_score: predicted distance to GT (lower is better), or None
            when the model has no score head
        classification_score: probability of GT vertex presence (0-1), or None
            when the model has no classification head
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    patch_7d = patch['patch_7d']

    # Resample to the fixed point count the network was trained with.
    max_points = 1024
    num_points = patch_7d.shape[0]

    if num_points >= max_points:
        # NOTE(review): random subsampling makes inference nondeterministic
        # for large patches — seed np.random if reproducibility matters.
        indices = np.random.choice(num_points, max_points, replace=False)
        patch_sampled = patch_7d[indices]
    else:
        # Zero-pad; feature width taken from the data instead of hard-coded.
        patch_sampled = np.zeros((max_points, patch_7d.shape[1]))
        patch_sampled[:num_points] = patch_7d

    # (N, C) -> (1, C, N): batch of one for the Conv1d-based model.
    patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0)
    patch_tensor = patch_tensor.to(device)

    with torch.no_grad():
        outputs = model(patch_tensor)

    # Unpack according to which heads the model was built with; the
    # classification head emits logits, so apply sigmoid here.
    if model.predict_score and model.predict_class:
        position, score, classification = outputs
        score = score.cpu().numpy().squeeze()
        classification = torch.sigmoid(classification).cpu().numpy().squeeze()
    elif model.predict_score:
        position, score = outputs
        score = score.cpu().numpy().squeeze()
        classification = None
    elif model.predict_class:
        position, classification = outputs
        score = None
        classification = torch.sigmoid(classification).cpu().numpy().squeeze()
    else:
        position = outputs
        score = None
        classification = None

    position = position.cpu().numpy().squeeze()

    # Patch coordinates are centered on the cluster center; shift the
    # prediction back into the world frame.
    position = position + patch['cluster_center']

    return position, score, classification