import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pickle
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Tuple, Optional
class FastPointNet(nn.Module):
"""
Fast PointNet implementation for 3D vertex prediction from point cloud patches.
Takes 7D point clouds (x,y,z,r,g,b,filtered_flag) and predicts 3D vertex coordinates.
    Uses a deep stack of point-wise 1x1 convolutions, combined max/average global pooling,
    and separate heads for position, distance-score, and vertex-presence prediction.
"""
def __init__(self, input_dim=7, output_dim=3, max_points=1024, predict_score=True, predict_class=True, num_classes=1):
super(FastPointNet, self).__init__()
self.max_points = max_points
self.predict_score = predict_score
self.predict_class = predict_class
self.num_classes = num_classes
# Enhanced point-wise MLPs with deeper architecture
self.conv1 = nn.Conv1d(input_dim, 128, 1)
self.conv2 = nn.Conv1d(128, 256, 1)
self.conv3 = nn.Conv1d(256, 512, 1)
self.conv4 = nn.Conv1d(512, 1024, 1)
# Additional layers for better feature extraction
self.conv5 = nn.Conv1d(1024, 1024, 1)
self.conv6 = nn.Conv1d(1024, 2048, 1)
# Larger shared features
self.shared_fc1 = nn.Linear(2048, 1024)
self.shared_fc2 = nn.Linear(1024, 512)
# Enhanced position prediction head
self.pos_fc1 = nn.Linear(512, 512)
self.pos_fc2 = nn.Linear(512, 256)
self.pos_fc3 = nn.Linear(256, 128)
self.pos_fc4 = nn.Linear(128, output_dim)
# Enhanced score prediction head
if self.predict_score:
self.score_fc1 = nn.Linear(512, 512)
self.score_fc2 = nn.Linear(512, 256)
self.score_fc3 = nn.Linear(256, 128)
self.score_fc4 = nn.Linear(128, 64)
self.score_fc5 = nn.Linear(64, 1)
# Classification head
if self.predict_class:
self.class_fc1 = nn.Linear(512, 512)
self.class_fc2 = nn.Linear(512, 256)
self.class_fc3 = nn.Linear(256, 128)
self.class_fc4 = nn.Linear(128, 64)
self.class_fc5 = nn.Linear(64, num_classes)
# Batch normalization layers
self.bn1 = nn.BatchNorm1d(128)
self.bn2 = nn.BatchNorm1d(256)
self.bn3 = nn.BatchNorm1d(512)
self.bn4 = nn.BatchNorm1d(1024)
self.bn5 = nn.BatchNorm1d(1024)
self.bn6 = nn.BatchNorm1d(2048)
# Dropout with different rates
self.dropout_light = nn.Dropout(0.2)
self.dropout_medium = nn.Dropout(0.3)
self.dropout_heavy = nn.Dropout(0.4)
def forward(self, x):
"""
Forward pass
Args:
x: (batch_size, input_dim, max_points) tensor
Returns:
Tuple containing predictions based on configuration:
- position: (batch_size, output_dim) tensor of predicted 3D coordinates
- score: (batch_size, 1) tensor of predicted distance to GT (if predict_score=True)
- classification: (batch_size, num_classes) tensor of class logits (if predict_class=True)
"""
batch_size = x.size(0)
        # Point-wise feature extraction through stacked 1x1 convolutions
x1 = F.relu(self.bn1(self.conv1(x)))
x2 = F.relu(self.bn2(self.conv2(x1)))
x3 = F.relu(self.bn3(self.conv3(x2)))
x4 = F.relu(self.bn4(self.conv4(x3)))
x5 = F.relu(self.bn5(self.conv5(x4)))
x6 = F.relu(self.bn6(self.conv6(x5)))
        # Order-invariant global pooling: element-wise max plus element-wise mean over points
max_pool = torch.max(x6, 2)[0] # (batch_size, 2048)
avg_pool = torch.mean(x6, 2) # (batch_size, 2048)
# Combine max and average pooling for richer global features
global_features = max_pool + avg_pool # (batch_size, 2048)
        # Shared fully connected features with dropout regularization
shared1 = F.relu(self.shared_fc1(global_features))
shared1 = self.dropout_light(shared1)
shared2 = F.relu(self.shared_fc2(shared1))
shared_features = self.dropout_medium(shared2)
        # Position prediction head
pos1 = F.relu(self.pos_fc1(shared_features))
pos1 = self.dropout_light(pos1)
pos2 = F.relu(self.pos_fc2(pos1))
pos2 = self.dropout_medium(pos2)
pos3 = F.relu(self.pos_fc3(pos2))
pos3 = self.dropout_light(pos3)
position = self.pos_fc4(pos3)
outputs = [position]
if self.predict_score:
# Enhanced score prediction
score1 = F.relu(self.score_fc1(shared_features))
score1 = self.dropout_light(score1)
score2 = F.relu(self.score_fc2(score1))
score2 = self.dropout_medium(score2)
score3 = F.relu(self.score_fc3(score2))
score3 = self.dropout_light(score3)
score4 = F.relu(self.score_fc4(score3))
score4 = self.dropout_light(score4)
score = F.relu(self.score_fc5(score4)) # Ensure positive distance
outputs.append(score)
if self.predict_class:
# Classification prediction
class1 = F.relu(self.class_fc1(shared_features))
class1 = self.dropout_light(class1)
class2 = F.relu(self.class_fc2(class1))
class2 = self.dropout_medium(class2)
class3 = F.relu(self.class_fc3(class2))
class3 = self.dropout_light(class3)
class4 = F.relu(self.class_fc4(class3))
class4 = self.dropout_light(class4)
classification = self.class_fc5(class4) # Raw logits
outputs.append(classification)
# Return outputs based on configuration
if len(outputs) == 1:
return outputs[0] # Only position
elif len(outputs) == 2:
if self.predict_score:
return outputs[0], outputs[1] # position, score
else:
return outputs[0], outputs[1] # position, classification
else:
return outputs[0], outputs[1], outputs[2] # position, score, classification
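# Hedged smoke-test sketch (not part of the original pipeline): build a model
# with both auxiliary heads and run random data through it to confirm the
# expected output shapes. Only runs when called explicitly, never on import.
def _example_forward():
    model = FastPointNet(input_dim=7, output_dim=3, max_points=1024,
                         predict_score=True, predict_class=True, num_classes=1)
    model.eval()  # use running BatchNorm statistics; no gradient bookkeeping needed
    dummy = torch.randn(2, 7, 1024)  # (batch_size, input_dim, max_points)
    with torch.no_grad():
        position, score, classification = model(dummy)
    # position: (2, 3), score: (2, 1), classification: (2, 1)
    print(position.shape, score.shape, classification.shape)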
class PatchDataset(Dataset):
"""
Dataset class for loading saved patches for PointNet training.
"""
def __init__(self, dataset_dir: str, max_points: int = 1024, augment: bool = True):
self.dataset_dir = dataset_dir
self.max_points = max_points
self.augment = augment
# Load patch files
self.patch_files = []
        for file in sorted(os.listdir(dataset_dir)):  # sorted for a deterministic sample order
if file.endswith('.pkl'):
self.patch_files.append(os.path.join(dataset_dir, file))
print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")
def __len__(self):
return len(self.patch_files)
def __getitem__(self, idx):
"""
Load and process a patch for training.
Returns:
patch_data: (7, max_points) tensor of point cloud data
target: (3,) tensor of target 3D coordinates
valid_mask: (max_points,) boolean tensor indicating valid points
distance_to_gt: scalar tensor of distance from initial prediction to GT
classification: scalar tensor for binary classification (1 if GT vertex present, 0 if not)
"""
patch_file = self.patch_files[idx]
with open(patch_file, 'rb') as f:
patch_info = pickle.load(f)
patch_7d = patch_info['patch_7d'] # (N, 7)
target = patch_info.get('assigned_wf_vertex', None) # (3,) or None
initial_pred = patch_info.get('cluster_center', None) # (3,) or None
# Determine classification label based on GT vertex presence
has_gt_vertex = 1.0 if target is not None else 0.0
# Handle patches without ground truth
if target is None:
            # Use a dummy target for consistency; the classification label marks the missing GT
target = np.zeros(3)
else:
target = np.array(target)
# Pad or sample points to max_points
num_points = patch_7d.shape[0]
if num_points >= self.max_points:
# Randomly sample max_points
indices = np.random.choice(num_points, self.max_points, replace=False)
patch_sampled = patch_7d[indices]
valid_mask = np.ones(self.max_points, dtype=bool)
else:
# Pad with zeros
patch_sampled = np.zeros((self.max_points, 7))
patch_sampled[:num_points] = patch_7d
valid_mask = np.zeros(self.max_points, dtype=bool)
valid_mask[:num_points] = True
# Data augmentation (only if GT vertex is present)
if self.augment and has_gt_vertex > 0:
patch_sampled, target = self._augment_patch(patch_sampled, valid_mask, target)
# Convert to tensors and transpose for conv1d (channels first)
patch_tensor = torch.from_numpy(patch_sampled.T).float() # (7, max_points)
target_tensor = torch.from_numpy(target).float() # (3,)
valid_mask_tensor = torch.from_numpy(valid_mask)
# Handle initial_pred
if initial_pred is not None:
initial_pred_tensor = torch.from_numpy(initial_pred).float()
else:
initial_pred_tensor = torch.zeros(3).float()
# Classification tensor
classification_tensor = torch.tensor(has_gt_vertex).float()
return patch_tensor, target_tensor, valid_mask_tensor, initial_pred_tensor, classification_tensor
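    def _augment_patch(self, patch_sampled, valid_mask, target):
        """
        Augment a sampled patch. NOTE: __getitem__ calls this method, but its
        body is not present in this excerpt; the code below is a hedged sketch
        of a standard point-cloud augmentation (random rotation about the z
        axis plus small Gaussian jitter). It assumes the xyz channels and the
        target vertex share the same local coordinate frame.
        """
        theta = np.random.uniform(0.0, 2.0 * np.pi)
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        rot = np.array([[cos_t, -sin_t, 0.0],
                        [sin_t, cos_t, 0.0],
                        [0.0, 0.0, 1.0]])
        # Rotate and jitter only the valid (non-padded) xyz coordinates
        xyz = patch_sampled[valid_mask, :3]
        jitter = np.random.normal(0.0, 0.005, size=xyz.shape)
        patch_sampled[valid_mask, :3] = xyz @ rot.T + jitter
        # Apply the same rotation to the target so supervision stays consistent
        target = target @ rot.T
        return patch_sampled, target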
def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
"""
Save patches from prediction pipeline to create a training dataset.
Args:
patches: List of patch dictionaries from generate_patches()
dataset_dir: Directory to save the dataset
entry_id: Unique identifier for this entry/image
"""
os.makedirs(dataset_dir, exist_ok=True)
for i, patch in enumerate(patches):
# Create unique filename
filename = f"{entry_id}_patch_{i}.pkl"
filepath = os.path.join(dataset_dir, filename)
# Skip if file already exists
if os.path.exists(filepath):
continue
# Save patch data
with open(filepath, 'wb') as f:
pickle.dump(patch, f)
print(f"Saved {len(patches)} patches for entry {entry_id}")
# Custom collate function for the training DataLoader: drops samples with no valid points
def collate_fn(batch):
valid_batch = []
for patch_data, target, valid_mask, initial_pred, classification in batch:
# Filter out invalid samples (no valid points)
if valid_mask.sum() > 0:
valid_batch.append((patch_data, target, valid_mask, initial_pred, classification))
if len(valid_batch) == 0:
return None
# Stack valid samples
patch_data = torch.stack([item[0] for item in valid_batch])
targets = torch.stack([item[1] for item in valid_batch])
valid_masks = torch.stack([item[2] for item in valid_batch])
initial_preds = torch.stack([item[3] for item in valid_batch])
classifications = torch.stack([item[4] for item in valid_batch])
return patch_data, targets, valid_masks, initial_preds, classifications
# Initialize weights using Xavier/Glorot initialization
def init_weights(m):
if isinstance(m, nn.Conv1d):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm1d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32, lr: float = 0.001,
score_weight: float = 0.1, class_weight: float = 0.5):
"""
Train the FastPointNet model on saved patches.
Args:
dataset_dir: Directory containing saved patch files
model_save_path: Path to save the trained model
epochs: Number of training epochs
batch_size: Training batch size
lr: Learning rate
score_weight: Weight for the distance prediction loss
class_weight: Weight for the classification loss
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Training on device: {device}")
# Create dataset and dataloader
dataset = PatchDataset(dataset_dir, max_points=1024, augment=False)
print(f"Dataset loaded with {len(dataset)} samples")
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8,
collate_fn=collate_fn, drop_last=True)
# Initialize model with score and classification prediction
model = FastPointNet(input_dim=7, output_dim=3, max_points=1024, predict_score=True, predict_class=True, num_classes=1)
model.apply(init_weights)
model.to(device)
# Loss functions
position_criterion = nn.MSELoss()
score_criterion = nn.MSELoss()
classification_criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
# Training loop
model.train()
for epoch in range(epochs):
total_loss = 0.0
total_pos_loss = 0.0
total_score_loss = 0.0
total_class_loss = 0.0
num_batches = 0
for batch_idx, batch_data in enumerate(dataloader):
if batch_data is None: # Skip invalid batches
continue
patch_data, targets, valid_masks, initial_preds, classifications = batch_data
patch_data = patch_data.to(device) # (batch_size, 7, max_points)
targets = targets.to(device) # (batch_size, 3)
classifications = classifications.to(device) # (batch_size,)
# Forward pass
optimizer.zero_grad()
predictions, predicted_scores, predicted_classes = model(patch_data)
            # Compute actual distance from predictions to targets; detached so the
            # score loss supervises only the score head, not the position head
            actual_distances = torch.norm(predictions - targets, dim=1, keepdim=True).detach()
# Only compute position and score losses for samples with GT vertices
has_gt_mask = classifications > 0.5
if has_gt_mask.sum() > 0:
# Position loss only for samples with GT vertices
pos_loss = position_criterion(predictions[has_gt_mask], targets[has_gt_mask])
score_loss = score_criterion(predicted_scores[has_gt_mask], actual_distances[has_gt_mask])
else:
pos_loss = torch.tensor(0.0, device=device)
score_loss = torch.tensor(0.0, device=device)
            # Classification loss for all samples; squeeze(-1) keeps the batch
            # dimension even when the filtered batch contains a single sample
            class_loss = classification_criterion(predicted_classes.squeeze(-1), classifications)
# Combined loss
total_batch_loss = pos_loss + score_weight * score_loss + class_weight * class_loss
# Backward pass
total_batch_loss.backward()
optimizer.step()
total_loss += total_batch_loss.item()
total_pos_loss += pos_loss.item()
total_score_loss += score_loss.item()
total_class_loss += class_loss.item()
num_batches += 1
if batch_idx % 50 == 0:
print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
f"Total Loss: {total_batch_loss.item():.6f}, "
f"Pos Loss: {pos_loss.item():.6f}, "
f"Score Loss: {score_loss.item():.6f}, "
f"Class Loss: {class_loss.item():.6f}")
avg_loss = total_loss / num_batches if num_batches > 0 else 0
avg_pos_loss = total_pos_loss / num_batches if num_batches > 0 else 0
avg_score_loss = total_score_loss / num_batches if num_batches > 0 else 0
avg_class_loss = total_class_loss / num_batches if num_batches > 0 else 0
print(f"Epoch {epoch+1}/{epochs} completed, "
f"Avg Total Loss: {avg_loss:.6f}, "
f"Avg Pos Loss: {avg_pos_loss:.6f}, "
f"Avg Score Loss: {avg_score_loss:.6f}, "
f"Avg Class Loss: {avg_class_loss:.6f}")
scheduler.step()
# Save model checkpoint every epoch
checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch + 1,
'loss': avg_loss,
}, checkpoint_path)
# Save the trained model
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epochs,
}, model_save_path)
print(f"Model saved to {model_save_path}")
return model
def load_pointnet_model(model_path: str, device: torch.device = None, predict_score: bool = True,
                        predict_class: bool = True) -> FastPointNet:
    """
    Load a trained FastPointNet model.
    Args:
        model_path: Path to the saved model
        device: Device to load the model on
        predict_score: Whether the checkpoint was trained with the score head
        predict_class: Whether the checkpoint was trained with the classification head
    Returns:
        Loaded FastPointNet model
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = FastPointNet(input_dim=7, output_dim=3, max_points=1024,
                         predict_score=predict_score, predict_class=predict_class)
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()
return model
def predict_vertex_from_patch(model: FastPointNet, patch: Dict, device: torch.device = None) -> Tuple[np.ndarray, Optional[float], Optional[float]]:
"""
Predict 3D vertex coordinates, confidence score, and classification from a patch using trained PointNet.
Args:
model: Trained FastPointNet model
        patch: Dictionary containing patch data with 'patch_7d' and 'cluster_center' keys
device: Device to run prediction on
Returns:
tuple of (predicted_coordinates, confidence_score, classification_score)
predicted_coordinates: (3,) numpy array of predicted 3D coordinates
confidence_score: float representing predicted distance to GT (lower is better)
classification_score: float representing probability of GT vertex presence (0-1)
"""
if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
patch_7d = patch['patch_7d'] # (N, 7)
# Prepare input
max_points = 1024
num_points = patch_7d.shape[0]
if num_points >= max_points:
# Sample points
indices = np.random.choice(num_points, max_points, replace=False)
patch_sampled = patch_7d[indices]
else:
# Pad with zeros
patch_sampled = np.zeros((max_points, 7))
patch_sampled[:num_points] = patch_7d
# Convert to tensor
patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0) # (1, 7, max_points)
patch_tensor = patch_tensor.to(device)
# Predict
with torch.no_grad():
outputs = model(patch_tensor)
    if model.predict_score and model.predict_class:
        position, score, classification = outputs
        position = position.cpu().numpy().squeeze()
        score = score.item()  # predicted distance to GT as a plain float
        classification = torch.sigmoid(classification).item()  # logit -> probability
    elif model.predict_score:
        position, score = outputs
        position = position.cpu().numpy().squeeze()
        score = score.item()
        classification = None
    elif model.predict_class:
        position, classification = outputs
        position = position.cpu().numpy().squeeze()
        score = None
        classification = torch.sigmoid(classification).item()  # logit -> probability
    else:
        position = outputs.cpu().numpy().squeeze()
        score = None
        classification = None
# Apply offset correction
offset = patch['cluster_center']
position += offset
return position, score, classification
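# Hedged end-to-end sketch with hypothetical paths: train on a directory of
# saved patches, reload the final checkpoint, and predict a vertex for one
# synthetic patch. './patch_dataset' and './fast_pointnet.pth' are made-up
# paths; substitute the real dataset and checkpoint locations.
if __name__ == "__main__":
    dataset_dir = "./patch_dataset"     # hypothetical directory of *.pkl patches
    model_path = "./fast_pointnet.pth"  # hypothetical checkpoint path
    train_pointnet(dataset_dir, model_path, epochs=10, batch_size=32, lr=1e-3)
    model = load_pointnet_model(model_path)
    position, score, classification = predict_vertex_from_patch(model, _example_patch_dict())
    print(f"vertex={position}, predicted_dist={score:.4f}, p(vertex)={classification:.4f}")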