# hoho / fully_deep.py
# (Hugging Face file-viewer header, kept as comments so the module parses:
#  uploaded by jskvrna — "Updates data directory path", commit 5012f1c,
#  raw / history / blame, 42.8 kB)
import torch
import os
import pickle
from torch.utils.data import Dataset, DataLoader
import numpy as np
from scipy.optimize import linear_sum_assignment
import torch.nn as nn
import torch.nn.functional as F
# =============================================================================
# CONFIGURATION PARAMETERS
# =============================================================================
# Dataset Configuration
DATA_DIR = '/mnt/personal/skvrnjan/hoho_fully'  # directory of per-sample .pkl files
SPLIT = 'train'  # stored on the dataset but not used for file selection (all .pkl files are loaded)
MAX_POINTS = 8096  # fixed point-cloud size after pad/downsample; NOTE(review): 8192 may have been intended
BATCH_SIZE = 32
NUM_WORKERS = 8
# Model Architecture Parameters
PC_INPUT_FEATURES = 3  # xyz only (colors are loaded but not fed to the model)
PC_ENCODER_OUTPUT_FEATURES = 128  # per-point feature width out of the encoder
MAX_VERTICES = 50  # number of vertex queries / maximum predicted wireframe vertices
VERTEX_COORD_DIM = 3
GNN_HIDDEN_DIM = 64
NUM_GNN_LAYERS = 2
# NOTE(review): HIDDEN_DIM appears to need to stay equal to SA2_NPOINT (256) —
# VertexPredictionHead's global-context broadcast only lines up because these
# two values coincide; confirm before changing either.
HIDDEN_DIM = 256
NUM_DECODER_LAYERS = 3
NUM_HEADS = 8
# PointNet2 Encoder Parameters
SA1_NPOINT = 1024
SA1_RADIUS = 0.2
SA1_NSAMPLE = 32
SA1_MLP = [64, 64, 128]
SA2_NPOINT = 256
SA2_RADIUS = 0.4
SA2_NSAMPLE = 64
SA2_MLP = [128, 128, 256]
SA3_MLP = [256, 512, 1024]  # Global pooling layer
FP3_MLP = [256, 256]
FP2_MLP = [256, 128]
FP1_MLP = [128, 128]  # Will add PC_ENCODER_OUTPUT_FEATURES at the end
# Vertex Prediction Head Parameters
VERTEX_TRANSFORMER_DROPOUT = 0.1
VERTEX_TRANSFORMER_FFN_RATIO = 4  # FFN width = HIDDEN_DIM * this ratio
# Edge Prediction Head Parameters
EDGE_GNN_NUM_HEADS = 4
EDGE_GNN_DROPOUT = 0.1
EDGE_K_NEIGHBORS = 8  # k for the initial k-NN graph over predicted vertices
# Training Configuration
NUM_EPOCHS = 100
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
GRADIENT_CLIP_MAX_NORM = 1.0
# Loss Weights
VERTEX_LOSS_WEIGHT = 1.0
EDGE_LOSS_WEIGHT = 0.5
CONFIDENCE_LOSS_WEIGHT = 0.3
# Learning Rate Scheduler Parameters
LR_SCHEDULER_FACTOR = 0.5
LR_SCHEDULER_PATIENCE = 10
# Checkpoint and Logging
CHECKPOINT_SAVE_FREQUENCY = 1  # Save every N epochs
LOG_FREQUENCY = 10  # Print progress every N batches
# Device Configuration
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# =============================================================================
# MODEL IMPLEMENTATION
# =============================================================================
# You would likely need a library like torch_geometric for GNNs
# from torch_geometric.nn import GATConv, EdgeConv # Example GNN layers
# --- 1. Point Cloud Encoder Backbone (Placeholder) ---
class PointNet2Encoder(nn.Module):
    """PointNet++-style encoder: a set-abstraction (SA) pyramid followed by
    feature-propagation (FP) layers that broadcast abstract features back onto
    every input point.

    Returns per-point features plus a feature from the last SA level — but see
    the NOTE in forward() about that feature's actual shape.
    """

    def __init__(self, input_features, output_features):
        # input_features: channels of the per-point input (3 = xyz coordinates).
        # output_features: channel width of the returned per-point features.
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features
        # Set Abstraction layers - adjusted for 8096 input points
        self.sa1 = SetAbstractionLayer(
            npoint=SA1_NPOINT, radius=SA1_RADIUS, nsample=SA1_NSAMPLE,
            in_channel=input_features + 3, mlp=SA1_MLP
        )
        self.sa2 = SetAbstractionLayer(
            npoint=SA2_NPOINT, radius=SA2_RADIUS, nsample=SA2_NSAMPLE,
            in_channel=SA1_MLP[-1] + 3, mlp=SA2_MLP
        )
        self.sa3 = SetAbstractionLayer(
            npoint=None, radius=None, nsample=None,  # Global pooling branch
            in_channel=SA2_MLP[-1] + 3, mlp=SA3_MLP
        )
        # Feature Propagation layers for point-wise features
        self.fp3 = FeaturePropagationLayer(in_channel=SA3_MLP[-1] + SA2_MLP[-1], mlp=FP3_MLP)
        self.fp2 = FeaturePropagationLayer(in_channel=FP3_MLP[-1] + SA1_MLP[-1], mlp=FP2_MLP)
        self.fp1 = FeaturePropagationLayer(in_channel=FP2_MLP[-1] + input_features, mlp=FP1_MLP + [output_features])

    def forward(self, xyz):
        # xyz: (B, N, 3) where N = 8096
        B, N, _ = xyz.shape
        # Initial features (can be empty or coordinates)
        points = xyz if self.input_features == 3 else None
        # Set Abstraction
        l1_xyz, l1_points = self.sa1(xyz, points)  # 8096 -> 1024 points
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)  # 1024 -> 256 points
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)  # "global" level; see NOTE below
        # Feature Propagation
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        l0_points = self.fp1(xyz, l1_xyz, points, l1_points)
        # NOTE(review): because SetAbstractionLayer's global branch max-pools
        # over a size-1 dim (see note there), l3_points is per-point,
        # (B, 256, 1024), not (B, 1024, 1) — so this squeeze(-1) is a no-op and
        # "global_feature" is actually (B, 256, 1024). VertexPredictionHead's
        # global_proj currently compensates for this shape; confirm both
        # together before fixing either.
        global_feature = l3_points.squeeze(-1)  # (B, 1024)
        return l0_points, global_feature  # (B, 8096, output_features), (B, 1024)
class SetAbstractionLayer(nn.Module):
    """PointNet++ set-abstraction layer.

    Samples centroids (farthest point sampling), groups neighbours (ball
    query), and applies a shared Conv2d MLP followed by a max-pool over each
    neighbourhood. With npoint=None (or group_all=True) it acts as the
    "global" level — but see the NOTE at the pooling step below.
    """

    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all=False):
        # npoint: number of sampled centroids (None -> global branch).
        # radius, nsample: ball-query radius and neighbourhood size.
        # in_channel: input channels (point features + 3 relative coordinates).
        # mlp: output widths of the shared per-point MLP.
        super().__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.group_all = group_all
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

    def forward(self, xyz, points):
        # xyz: (B, N, 3)
        # points: (B, N, C) or None
        B, N, C = xyz.shape
        if self.group_all or self.npoint is None:
            # Global branch: one mean centroid; all N points form a single group.
            new_xyz = xyz.mean(dim=1, keepdim=True)  # (B, 1, 3)
            if points is not None:
                new_points = torch.cat([xyz, points], dim=-1)  # (B, N, 3+C)
                new_points = new_points.transpose(1, 2).unsqueeze(-1)  # (B, 3+C, N, 1)
            else:
                new_points = xyz.transpose(1, 2).unsqueeze(-1)  # (B, 3, N, 1)
        else:
            # Farthest Point Sampling
            fps_idx = farthest_point_sample(xyz, self.npoint)  # (B, npoint)
            new_xyz = index_points(xyz, fps_idx)  # (B, npoint, 3)
            # Ball Query
            idx = ball_query(self.radius, self.nsample, xyz, new_xyz)  # (B, npoint, nsample)
            grouped_xyz = index_points(xyz, idx)  # (B, npoint, nsample, 3)
            grouped_xyz_norm = grouped_xyz - new_xyz.unsqueeze(2)  # Relative positions
            if points is not None:
                grouped_points = index_points(points, idx)  # (B, npoint, nsample, C)
                new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)  # (B, npoint, nsample, 3+C)
            else:
                new_points = grouped_xyz_norm  # (B, npoint, nsample, 3)
            new_points = new_points.permute(0, 3, 1, 2)  # (B, 3+C, npoint, nsample)
        # MLP
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))
        # Max pooling over the last dim (the nsample axis in the sampled branch).
        # NOTE(review): in the global branch the layout is (B, C, N, 1), so this
        # max runs over a size-1 dim and does NOT pool across the N points —
        # the "global" level therefore returns per-point features of shape
        # (B, N, mlp[-1]). Downstream code appears to depend on these shapes
        # (see PointNet2Encoder / VertexPredictionHead); confirm before fixing.
        new_points = torch.max(new_points, dim=-1)[0]  # (B, mlp[-1], npoint)
        new_points = new_points.transpose(1, 2)  # (B, npoint, mlp[-1])
        return new_xyz, new_points
class FeaturePropagationLayer(nn.Module):
    """PointNet++ feature propagation: upsample source features onto targets.

    Features living on the coarser point set (xyz2) are interpolated onto the
    denser set (xyz1), optionally concatenated with the target set's own
    features, and refined by a shared 1x1 Conv1d MLP.
    """

    def __init__(self, in_channel, mlp):
        # in_channel: channels after concatenating target + interpolated features.
        # mlp: output widths of the refinement MLP.
        super().__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        width_in = in_channel
        for width_out in mlp:
            self.mlp_convs.append(nn.Conv1d(width_in, width_out, 1))
            self.mlp_bns.append(nn.BatchNorm1d(width_out))
            width_in = width_out

    def forward(self, xyz1, xyz2, points1, points2):
        """Propagate features from (xyz2, points2) onto xyz1.

        Args:
            xyz1: (B, N1, 3) target coordinates.
            xyz2: (B, N2, 3) source coordinates.
            points1: (B, N1, C1) target features, or None.
            points2: (B, N2, C2) source features, or None.

        Returns:
            (B, N1, mlp[-1]) refined features, or None if both inputs are None.
        """
        if points2 is None:
            fused = points1
        else:
            upsampled = interpolate_features(xyz1, xyz2, points2)  # (B, N1, C2)
            if points1 is None:
                fused = upsampled
            else:
                # Both feature sets must describe the same N1 target points.
                assert points1.shape[1] == upsampled.shape[1], \
                    f"Point count mismatch: {points1.shape[1]} vs {upsampled.shape[1]}"
                fused = torch.cat([points1, upsampled], dim=-1)  # (B, N1, C1+C2)
        if fused is None:
            return None
        # Conv1d expects channels-first layout.
        out = fused.transpose(1, 2)  # (B, C, N1)
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            out = F.relu(bn(conv(out)))
        return out.transpose(1, 2)  # (B, N1, mlp[-1])
def farthest_point_sample(xyz, npoint):
    """Farthest point sampling.

    Iteratively selects `npoint` indices per batch element, each maximizing the
    distance to the already-selected set; the first pick is random.

    Args:
        xyz: (B, N, 3) point coordinates.
        npoint: number of indices to select.

    Returns:
        (B, npoint) long tensor of selected point indices.
    """
    device = xyz.device
    B, N, _ = xyz.shape
    selected = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Distance from every point to its nearest already-selected centroid.
    nearest_dist = torch.full((B, N), 1e10, device=device)
    # Random seed point per batch element (CPU RNG, then moved, as before).
    current = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    rows = torch.arange(B, device=device)
    for step in range(npoint):
        selected[:, step] = current
        anchor = xyz[rows, current, :].view(B, 1, 3)
        d = torch.sum((xyz - anchor) ** 2, -1)
        nearest_dist = torch.minimum(nearest_dist, d)
        # Next pick: the point farthest from the selected set.
        current = nearest_dist.argmax(dim=-1)
    return selected
def ball_query(radius, nsample, xyz, new_xyz):
    """Collect up to `nsample` neighbour indices within `radius` of each query.

    Shortfalls (balls with fewer than nsample points) are filled with the
    lowest in-ball index; completely empty balls fall back to index 0.

    Args:
        radius: ball radius.
        nsample: neighbours per query point.
        xyz: (B, N, 3) all points.
        new_xyz: (B, S, 3) query centroids.

    Returns:
        (B, S, nsample) long tensor of indices into xyz.
    """
    device = xyz.device
    B, N, _ = xyz.shape
    S = new_xyz.shape[1]
    # Start with every candidate index for every query point.
    candidates = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat(B, S, 1)
    # Mark out-of-ball candidates with sentinel N so they sort to the end.
    candidates[square_distance(new_xyz, xyz) > radius ** 2] = N
    candidates = candidates.sort(dim=-1)[0][:, :, :nsample]
    # Replace sentinels with each ball's first (closest-index) valid member.
    filler = candidates[:, :, 0:1].repeat(1, 1, nsample)
    sentinel = candidates == N
    candidates[sentinel] = filler[sentinel]
    # If a ball was entirely empty its filler is still N; clamp to 0, a valid
    # index since N > 0 is guaranteed by the dataloader logic.
    candidates[candidates == N] = 0
    return candidates
def square_distance(src, dst):
    """Pairwise squared Euclidean distances between two point sets.

    Uses the expansion ||a - b||^2 = -2 a·b + ||a||^2 + ||b||^2.

    Args:
        src: (B, N, C) points.
        dst: (B, M, C) points.

    Returns:
        (B, N, M) squared distances.
    """
    B, N, _ = src.shape
    M = dst.shape[1]
    cross = torch.matmul(src, dst.permute(0, 2, 1))       # (B, N, M)
    src_sq = torch.sum(src ** 2, -1).view(B, N, 1)        # (B, N, 1)
    dst_sq = torch.sum(dst ** 2, -1).view(B, 1, M)        # (B, 1, M)
    # Same left-to-right accumulation order as the textbook form.
    return -2 * cross + src_sq + dst_sq
def index_points(points, idx):
    """Gather point subsets along dim 1 using per-batch indices.

    Args:
        points: (B, N, C) source points.
        idx: (B, ...) long indices into dim 1 of `points`.

    Returns:
        (B, ..., C) gathered points, matching idx's shape plus channels.
    """
    B = points.shape[0]
    # Batch index broadcast to idx's shape: element (b, ...) reads batch b.
    view_shape = [B] + [1] * (idx.dim() - 1)
    batch = torch.arange(B, dtype=torch.long, device=points.device)
    batch = batch.view(view_shape).expand_as(idx)
    return points[batch, idx, :]
def interpolate_features(xyz1, xyz2, points2):
    """Inverse-distance-weighted interpolation of features from xyz2 onto xyz1.

    Args:
        xyz1: (B, N1, 3) target coordinates.
        xyz2: (B, N2, 3) source coordinates.
        points2: (B, N2, C) source features.

    Returns:
        (B, N1, C) features interpolated onto xyz1 from up to 3 nearest
        source points, weighted by inverse squared distance.
    """
    B, N1, _ = xyz1.shape
    N2 = xyz2.shape[1]
    if N2 == 1:
        # Single source point: its feature is broadcast to every target point.
        return points2.expand(B, N1, -1)
    dists = square_distance(xyz1, xyz2)  # (B, N1, N2)
    dists, idx = dists.sort(dim=-1)
    # Use min(3, N2) neighbours to handle cases with fewer source points.
    k = min(3, N2)
    dists, idx = dists[:, :, :k], idx[:, :, :k]
    # FIX: clamp out of place instead of the original in-place masked
    # assignment (dists[dists < 1e-10] = 1e-10). The sorted distances are part
    # of the autograd graph, and in-place writes on them can raise or corrupt
    # the backward pass; values are identical since squared distances are >= 0.
    dists = dists.clamp(min=1e-10)
    weight = 1.0 / dists  # (B, N1, k)
    weight = weight / torch.sum(weight, dim=-1, keepdim=True)  # normalize to sum to 1
    # Weighted sum of the k neighbours' features.
    return torch.sum(index_points(points2, idx) * weight.view(B, N1, k, 1), dim=2)
# --- 2. Vertex Prediction Head (Transformer-based) ---
class VertexPredictionHead(nn.Module):
    """DETR-style vertex head.

    Learnable vertex queries cross-attend to encoded point features through a
    transformer decoder; two MLP heads then regress per-query coordinates and
    existence logits.
    """

    def __init__(self, point_feature_dim, global_feature_dim, max_vertices, vertex_coord_dim=3,
                 hidden_dim=256, num_decoder_layers=3, num_heads=8):
        super().__init__()
        self.max_vertices = max_vertices
        self.vertex_coord_dim = vertex_coord_dim
        self.hidden_dim = hidden_dim
        # Learnable vertex queries (similar to DETR object queries)
        self.vertex_queries = nn.Parameter(torch.randn(max_vertices, hidden_dim))
        # NOTE(review): this maps the last dim to size 1, not to hidden_dim as
        # the surrounding comments suggest. It only yields a usable
        # (B, 1, hidden_dim) context in forward() because the encoder's
        # "global" feature is actually (B, SA2_NPOINT, 1024) (see notes in
        # SetAbstractionLayer/PointNet2Encoder) and SA2_NPOINT (256) happens to
        # equal HIDDEN_DIM (256). Confirm all three together before changing.
        self.global_proj = nn.Linear(global_feature_dim, 1)
        # Project point features to hidden dimension for cross-attention
        self.point_proj = nn.Linear(point_feature_dim, hidden_dim)
        # Transformer decoder layers
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * VERTEX_TRANSFORMER_FFN_RATIO,
            dropout=VERTEX_TRANSFORMER_DROPOUT,
            batch_first=True
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
        # Output heads
        self.vertex_coord_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, vertex_coord_dim)
        )
        # Confidence/existence head (predicts if vertex exists; emits logits)
        self.vertex_conf_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        # Learned positional encoding computed from raw xyz coordinates
        self.pos_encoding = nn.Sequential(
            nn.Linear(3, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, hidden_dim)
        )

    def forward(self, point_features, global_feature, point_coords=None):
        # point_features: (B, N, point_feature_dim)
        # global_feature: documented as (B, global_feature_dim); with the
        #   current encoder it is actually (B, 256, 1024) — see NOTE in __init__
        # point_coords: (B, N, 3) - optional point coordinates for positional encoding
        batch_size = point_features.shape[0]
        # Project features to hidden dimension
        point_features_proj = self.point_proj(point_features)  # (B, N, hidden_dim)
        # Add positional encoding if coordinates are provided
        if point_coords is not None:
            pos_enc = self.pos_encoding(point_coords)  # (B, N, hidden_dim)
            point_features_proj = point_features_proj + pos_enc
        # Prepare vertex queries
        vertex_queries = self.vertex_queries.unsqueeze(0).repeat(batch_size, 1, 1)  # (B, max_vertices, hidden_dim)
        # Add global context to vertex queries.
        # NOTE(review): with the current encoder output this produces
        # (B, 1, 256): Linear(1024->1) over (B, 256, 1024) gives (B, 256, 1),
        # then squeeze(-1).unsqueeze(1) -> (B, 1, 256). That broadcasts against
        # (B, max_vertices, hidden_dim) only because 256 == hidden_dim.
        global_proj = self.global_proj(global_feature).squeeze(-1).unsqueeze(1)
        vertex_queries = vertex_queries + global_proj  # broadcast over the query axis
        # Transformer decoder: vertex queries attend to point features
        vertex_features = self.transformer_decoder(
            tgt=vertex_queries,  # (B, max_vertices, hidden_dim)
            memory=point_features_proj  # (B, N, hidden_dim)
        )  # (B, max_vertices, hidden_dim)
        # Predict vertex coordinates
        predicted_vertices = self.vertex_coord_head(vertex_features)  # (B, max_vertices, 3)
        # Predict vertex confidence/existence (raw logits; no sigmoid applied)
        vertex_confidence = self.vertex_conf_head(vertex_features).squeeze(-1)  # (B, max_vertices)
        return predicted_vertices, vertex_confidence
# --- 3. Edge Prediction Head (GNN-based) ---
class EdgePredictionHeadGNN(nn.Module):
    """Edge head: builds a k-NN graph over the predicted vertices, refines
    vertex features with adjacency-masked graph attention, and scores every
    unordered vertex pair with an MLP (logits, no sigmoid).
    """

    def __init__(self, vertex_feature_dim, gnn_hidden_dim, num_gnn_layers):
        # vertex_feature_dim: per-vertex input width (3 = raw xyz coordinates).
        super().__init__()
        self.vertex_feature_dim = vertex_feature_dim
        self.gnn_hidden_dim = gnn_hidden_dim
        self.num_gnn_layers = num_gnn_layers
        # Initial vertex feature projection
        self.vertex_proj = nn.Linear(vertex_feature_dim, gnn_hidden_dim)
        # GNN layers using message passing
        self.gnn_layers = nn.ModuleList()
        for i in range(num_gnn_layers):
            self.gnn_layers.append(
                GraphAttentionLayer(
                    in_features=gnn_hidden_dim,
                    out_features=gnn_hidden_dim,
                    num_heads=EDGE_GNN_NUM_HEADS,
                    dropout=EDGE_GNN_DROPOUT
                )
            )
        # Edge classifier MLP over concatenated endpoint features
        self.edge_mlp = nn.Sequential(
            nn.Linear(gnn_hidden_dim * 2, gnn_hidden_dim),
            nn.ReLU(),
            nn.Dropout(EDGE_GNN_DROPOUT),
            nn.Linear(gnn_hidden_dim, gnn_hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(gnn_hidden_dim // 2, 1)
        )
        # Number of nearest neighbors for the initial spatial graph
        self.k_neighbors = EDGE_K_NEIGHBORS

    def forward(self, vertices):
        # vertices: (B, num_vertices, vertex_coord_dim)
        batch_size, num_vertices, _ = vertices.shape
        # Project vertex coordinates to hidden features
        vertex_features = self.vertex_proj(vertices)  # (B, num_vertices, gnn_hidden_dim)
        # Construct initial graph based on spatial proximity (k-NN)
        adjacency_matrix = self.construct_knn_graph(vertices, k=self.k_neighbors)  # (B, num_vertices, num_vertices)
        # Apply GNN layers (attention restricted to k-NN neighbours)
        for gnn_layer in self.gnn_layers:
            vertex_features = gnn_layer(vertex_features, adjacency_matrix)  # (B, num_vertices, gnn_hidden_dim)
        # Generate all unordered vertex pairs (i < j)
        idx_pairs = torch.combinations(torch.arange(num_vertices), r=2).to(vertices.device)  # (num_pairs, 2)
        # Gather features for all vertex pairs
        v1_features = vertex_features[:, idx_pairs[:, 0], :]  # (B, num_pairs, gnn_hidden_dim)
        v2_features = vertex_features[:, idx_pairs[:, 1], :]  # (B, num_pairs, gnn_hidden_dim)
        # Concatenate paired vertex features
        edge_features = torch.cat([v1_features, v2_features], dim=2)  # (B, num_pairs, gnn_hidden_dim * 2)
        # Predict edge logits
        edge_logits = self.edge_mlp(edge_features).squeeze(-1)  # (B, num_pairs)
        return edge_logits, idx_pairs

    def construct_knn_graph(self, vertices, k):
        # vertices: (B, num_vertices, 3)
        # NOTE: topk with k+1 below requires num_vertices > k (holds for
        # MAX_VERTICES=50 with EDGE_K_NEIGHBORS=8).
        batch_size, num_vertices, _ = vertices.shape
        # Compute pairwise distances
        distances = torch.cdist(vertices, vertices, p=2)  # (B, num_vertices, num_vertices)
        # Find k nearest neighbors for each vertex (+1 to include self)
        _, knn_indices = torch.topk(distances, k + 1, dim=-1, largest=False)
        knn_indices = knn_indices[:, :, 1:]  # Remove self-connection
        # Create adjacency matrix
        adjacency = torch.zeros(batch_size, num_vertices, num_vertices, device=vertices.device)
        # Scatter 1s at each vertex's k neighbour slots
        batch_idx = torch.arange(batch_size).view(-1, 1, 1).expand(-1, num_vertices, k)
        vertex_idx = torch.arange(num_vertices).view(1, -1, 1).expand(batch_size, -1, k)
        adjacency[batch_idx, vertex_idx, knn_indices] = 1.0
        # Symmetrize: an edge exists if either endpoint selected the other
        adjacency = torch.max(adjacency, adjacency.transpose(-1, -2))
        return adjacency
class GraphAttentionLayer(nn.Module):
    """Graph attention block: multi-head self-attention restricted to graph
    neighbours via an additive mask built from the adjacency matrix, followed
    by residual + LayerNorm and a position-wise FFN.

    NOTE(review): W_q/W_k/W_v feed nn.MultiheadAttention, which applies its own
    internal input projections, so inputs are effectively projected twice, and
    W_o is registered but never used in forward. Both are kept as-is to
    preserve checkpoint compatibility.
    """

    def __init__(self, in_features, out_features, num_heads=1, dropout=0.1):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_heads = num_heads
        self.dropout = dropout
        assert out_features % num_heads == 0
        self.head_dim = out_features // num_heads
        # Linear transformations for queries, keys, values
        self.W_q = nn.Linear(in_features, out_features)
        self.W_k = nn.Linear(in_features, out_features)
        self.W_v = nn.Linear(in_features, out_features)
        # Output projection (unused in forward — see class NOTE)
        self.W_o = nn.Linear(out_features, out_features)
        # Attention mechanism
        self.attention = nn.MultiheadAttention(
            embed_dim=out_features,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )
        # Layer normalization and residual connection
        self.layer_norm = nn.LayerNorm(out_features)
        self.ffn = nn.Sequential(
            nn.Linear(out_features, out_features * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(out_features * 2, out_features)
        )
        self.layer_norm2 = nn.LayerNorm(out_features)

    def forward(self, x, adjacency_matrix):
        """Apply masked self-attention plus FFN.

        Args:
            x: (B, num_vertices, in_features) vertex features.
            adjacency_matrix: (B, num_vertices, num_vertices) with 1 where
                attention is allowed.

        Returns:
            (B, num_vertices, out_features) updated vertex features.
        """
        batch_size, num_vertices, _ = x.shape
        # Project to query, key, value
        Q = self.W_q(x)  # (B, num_vertices, out_features)
        K = self.W_k(x)  # (B, num_vertices, out_features)
        V = self.W_v(x)  # (B, num_vertices, out_features)
        # Additive attention mask: 0 where an edge exists, -1e9 (~ -inf) elsewhere.
        attention_mask = (1 - adjacency_matrix) * (-1e9)  # (B, V, V)
        # FIX: the original looped over the batch in Python, calling attention
        # once per sample with a 2D mask. nn.MultiheadAttention accepts a
        # per-sample float mask of shape (B * num_heads, L, S), so repeating
        # each sample's mask once per head processes the whole batch in a
        # single call with identical results (eval mode).
        expanded_mask = attention_mask.repeat_interleave(self.num_heads, dim=0)
        attended_features, _ = self.attention(Q, K, V, attn_mask=expanded_mask)
        # Residual from Q (as in the original implementation), then LayerNorm.
        x_residual = self.layer_norm(attended_features + Q)
        # Feed-forward network with second residual + LayerNorm.
        ffn_output = self.ffn(x_residual)
        return self.layer_norm2(ffn_output + x_residual)
# --- Main Model ---
class PointCloudToWireframe(nn.Module):
def __init__(self,
pc_input_features=PC_INPUT_FEATURES,
pc_encoder_output_features=PC_ENCODER_OUTPUT_FEATURES,
max_vertices=MAX_VERTICES,
vertex_coord_dim=VERTEX_COORD_DIM,
gnn_hidden_dim=GNN_HIDDEN_DIM,
num_gnn_layers=NUM_GNN_LAYERS,
hidden_dim=HIDDEN_DIM,
num_decoder_layers=NUM_DECODER_LAYERS,
num_heads=NUM_HEADS):
super().__init__()
# Point cloud encoder using PointNet2-style architecture
self.encoder = PointNet2Encoder(pc_input_features, pc_encoder_output_features)
# Vertex prediction head using transformer decoder
self.vertex_head = VertexPredictionHead(
point_feature_dim=pc_encoder_output_features,
global_feature_dim=SA3_MLP[-1], # From PointNet2Encoder global feature
max_vertices=max_vertices,
vertex_coord_dim=vertex_coord_dim,
hidden_dim=hidden_dim,
num_decoder_layers=num_decoder_layers,
num_heads=num_heads
)
# Edge prediction head using GNN
self.edge_head = EdgePredictionHeadGNN(
vertex_feature_dim=vertex_coord_dim,
gnn_hidden_dim=gnn_hidden_dim,
num_gnn_layers=num_gnn_layers
)
def forward(self, point_cloud):
# point_cloud: (B, N, 3)
batch_size, num_points, _ = point_cloud.shape
# Encode point cloud
point_features, global_feature = self.encoder(point_cloud)
# point_features: (B, N, pc_encoder_output_features)
# global_feature: (B, 1024)
# Predict vertices
predicted_vertices, vertex_confidence = self.vertex_head(
point_features, global_feature, point_coords=point_cloud
)
# predicted_vertices: (B, max_vertices, 3)
# vertex_confidence: (B, max_vertices)
# Predict edges using GNN (using vertex coordinates directly)
edge_logits, edge_indices = self.edge_head(predicted_vertices)
# edge_logits: (B, num_potential_edges)
# edge_indices: (num_potential_edges, 2)
return {
'vertices': predicted_vertices,
'vertex_confidence': vertex_confidence,
'edge_logits': edge_logits,
'edge_indices': edge_indices
}
class WireframeDataset(Dataset):
    """Point-cloud-to-wireframe samples stored as one pickle file each.

    Every sample is normalized to exactly `max_points` points by random
    downsampling or by duplicating randomly chosen existing points.
    """

    def __init__(self, data_dir=DATA_DIR, split=SPLIT, transform=None, max_points=MAX_POINTS):
        """
        Args:
            data_dir: Directory containing the pickle files.
            split: 'train', 'val', or 'test' (stored; not used for file selection).
            transform: Optional transform applied to the point cloud.
            max_points: Fixed number of points per returned cloud (default: 8096).
        """
        self.data_dir = data_dir
        self.split = split
        self.transform = transform
        self.max_points = max_points
        # Deterministically ordered list of every .pkl file in the directory.
        self.data_files = sorted(
            os.path.join(data_dir, name)
            for name in os.listdir(data_dir)
            if name.endswith('.pkl')
        )

    def __len__(self):
        return len(self.data_files)

    def __getitem__(self, idx):
        # Each pickle holds one sample dict.
        with open(self.data_files[idx], 'rb') as f:
            sample = pickle.load(f)
        point_cloud = torch.tensor(sample['point_cloud'], dtype=torch.float32)
        point_colors = torch.tensor(sample['point_colors'], dtype=torch.float32)
        gt_vertices = torch.tensor(sample['gt_vertices'], dtype=torch.float32)
        connections = sample['gt_connections']  # list of (src, dst) tuples
        sample_id = sample['sample_id']
        # Normalize the cloud to exactly max_points points.
        n = point_cloud.shape[0]
        if n > self.max_points:
            # Random downsample.
            keep = torch.randperm(n)[:self.max_points]
            point_cloud = point_cloud[keep]
            point_colors = point_colors[keep]
        elif n < self.max_points:
            if n > 0:
                # Pad by duplicating randomly chosen existing points.
                extra = torch.randint(0, n, (self.max_points - n,))
                point_cloud = torch.cat([point_cloud, point_cloud[extra]], dim=0)
                point_colors = torch.cat([point_colors, point_colors[extra]], dim=0)
            else:
                # Degenerate sample with no points: all zeros.
                point_cloud = torch.zeros(self.max_points, 3)
                point_colors = torch.zeros(self.max_points, 3)
        # Connections as a (2, num_edges) long tensor (empty-safe).
        if len(connections) > 0:
            edge_indices = torch.tensor(connections, dtype=torch.long).t()
        else:
            edge_indices = torch.zeros((2, 0), dtype=torch.long)
        if self.transform:
            point_cloud = self.transform(point_cloud)
        return {
            'point_cloud': point_cloud,
            'point_colors': point_colors,
            'gt_vertices': gt_vertices,
            'edge_indices': edge_indices,
            'sample_id': sample_id
        }
def collate_fn(batch):
    """Collate samples with varying vertex/edge counts into padded tensors.

    Point clouds and colors are zero-padded to the batch's largest cloud;
    GT vertices are zero-padded to the batch's largest vertex count with a
    boolean validity mask; edges become symmetric adjacency matrices.
    """
    sample_ids = [s['sample_id'] for s in batch]
    point_clouds = [s['point_cloud'] for s in batch]
    point_colors = [s['point_colors'] for s in batch]
    gt_vertices_list = [s['gt_vertices'] for s in batch]
    edge_indices_list = [s['edge_indices'] for s in batch]

    max_vertices = max(v.shape[0] for v in gt_vertices_list)
    max_points = max(pc.shape[0] for pc in point_clouds)

    def _pad_rows(t, target):
        # Zero-pad a (rows, 3) tensor along dim 0 up to `target` rows.
        deficit = target - t.shape[0]
        if deficit > 0:
            return torch.cat([t, torch.zeros(deficit, 3)], dim=0)
        return t

    point_clouds_batch = torch.stack([_pad_rows(pc, max_points) for pc in point_clouds])
    point_colors_batch = torch.stack([_pad_rows(c, max_points) for c in point_colors])

    # Pad vertices and record which rows are real vs padding.
    padded_vertices = []
    vertex_masks = []
    for verts in gt_vertices_list:
        n = verts.shape[0]
        padded_vertices.append(_pad_rows(verts, max_vertices))
        vertex_masks.append(torch.arange(max_vertices) < n)  # True for real vertices
    gt_vertices_batch = torch.stack(padded_vertices)
    vertex_masks_batch = torch.stack(vertex_masks)

    # Symmetric adjacency matrices from the (2, num_edges) index tensors.
    batch_size = len(batch)
    adjacency_matrices = torch.zeros(batch_size, max_vertices, max_vertices)
    for i, edges in enumerate(edge_indices_list):
        if edges.shape[1] == 0:
            continue
        src, dst = edges[0], edges[1]
        # Drop edges referencing vertices beyond this sample's real count.
        limit = gt_vertices_list[i].shape[0]
        keep = (src < limit) & (dst < limit)
        s, d = src[keep], dst[keep]
        adjacency_matrices[i, s, d] = 1
        adjacency_matrices[i, d, s] = 1  # undirected graph

    return {
        'point_cloud': point_clouds_batch,
        'point_colors': point_colors_batch,
        'gt_vertices': gt_vertices_batch,
        'vertex_masks': vertex_masks_batch,
        'adjacency_matrices': adjacency_matrices,
        'edge_indices_list': edge_indices_list,  # originals kept for loss computation
        'sample_ids': sample_ids
    }
# Loss functions
def compute_vertex_loss(pred_vertices, gt_vertices, vertex_masks, vertex_confidence):
    """Vertex position + confidence losses via Hungarian matching.

    Args:
        pred_vertices: (B, max_vertices, 3) predicted coordinates.
        gt_vertices: (B, max_gt, 3) padded ground-truth vertices.
        vertex_masks: (B, max_gt) bool mask of valid (non-padded) GT vertices.
        vertex_confidence: (B, max_vertices) existence logits.

    Returns:
        (position_loss, confidence_loss) — both 0-dim tensors averaged over
        the batch.
    """
    batch_size = pred_vertices.shape[0]
    # FIX: accumulate as 0-dim tensors instead of Python floats. The original
    # returned a plain float whenever no vertices matched in the whole batch,
    # which crashed the training loop's losses[key].item() call.
    total_loss = pred_vertices.new_zeros(())
    total_confidence_loss = pred_vertices.new_zeros(())
    for b in range(batch_size):
        # Valid GT vertices for this sample.
        valid_mask = vertex_masks[b]
        gt_verts = gt_vertices[b][valid_mask]  # (num_valid_gt, 3)
        num_gt = gt_verts.shape[0]
        if num_gt == 0:
            # No GT vertices: every prediction should have low confidence.
            confidence_target = torch.zeros_like(vertex_confidence[b])
            total_confidence_loss = total_confidence_loss + F.binary_cross_entropy_with_logits(
                vertex_confidence[b], confidence_target)
            continue
        pred_verts = pred_vertices[b]  # (max_vertices, 3)
        pred_conf = vertex_confidence[b]  # (max_vertices,)
        # Pairwise distances between predicted and GT vertices.
        distances = torch.cdist(pred_verts, gt_verts)  # (max_vertices, num_gt)
        # Hungarian matching on a detached copy (scipy operates on numpy).
        cost_matrix = distances.detach().cpu().numpy()
        # Pad to square with prohibitive cost so every row/column can be
        # assigned; dummy assignments are filtered out below.
        if distances.shape[0] < distances.shape[1]:
            # More GT vertices than predictions.
            padding = np.full((distances.shape[1] - distances.shape[0], distances.shape[1]), 1e6)
            cost_matrix = np.vstack([cost_matrix, padding])
        elif distances.shape[0] > distances.shape[1]:
            # More predictions than GT vertices.
            padding = np.full((distances.shape[0], distances.shape[0] - distances.shape[1]), 1e6)
            cost_matrix = np.hstack([cost_matrix, padding])
        pred_indices, gt_indices = linear_sum_assignment(cost_matrix)
        # Keep only assignments pairing a real prediction with a real GT vertex.
        valid_assignments = (pred_indices < pred_verts.shape[0]) & (gt_indices < num_gt)
        pred_indices = pred_indices[valid_assignments]
        gt_indices = gt_indices[valid_assignments]
        if len(pred_indices) > 0:
            # Position loss over matched pairs.
            matched_pred = pred_verts[pred_indices]
            matched_gt = gt_verts[gt_indices]
            total_loss = total_loss + F.mse_loss(matched_pred, matched_gt)
            # Confidence target: 1 for matched predictions, 0 otherwise.
            confidence_target = torch.zeros_like(pred_conf)
            confidence_target[pred_indices] = 1.0
        else:
            # No valid matches: every prediction should have low confidence.
            confidence_target = torch.zeros_like(pred_conf)
        total_confidence_loss = total_confidence_loss + F.binary_cross_entropy_with_logits(
            pred_conf, confidence_target)
    return total_loss / batch_size, total_confidence_loss / batch_size
def compute_edge_loss(edge_logits, edge_indices, gt_adjacency_matrices):
    """Binary cross-entropy over all candidate edges.

    Args:
        edge_logits: (B, num_pairs) logits, one per candidate vertex pair.
        edge_indices: (num_pairs, 2) vertex-index pairs the logits refer to, as
            produced by torch.combinations over the predicted vertex set.
        gt_adjacency_matrices: (B, V_gt, V_gt) ground-truth adjacency, padded
            to the batch's max GT vertex count.

    Returns:
        Scalar BCE-with-logits loss tensor.
    """
    batch_size = gt_adjacency_matrices.shape[0]
    # Size of the candidate graph. edge_indices comes from
    # torch.combinations(arange(num_vertices)), whose maximum index is
    # num_vertices - 1, so the size can be derived from the indices themselves.
    # GENERALIZATION: this replaces the hard dependency on the global
    # MAX_VERTICES, keeping the function correct for any candidate-graph size
    # (identical result for the current caller).
    num_vertices = int(edge_indices.max().item()) + 1 if edge_indices.numel() > 0 else 0
    edge_targets = []
    for b in range(batch_size):
        gt_adj = gt_adjacency_matrices[b]  # (V_gt, V_gt)
        # Embed the GT adjacency into a candidate-sized matrix so it can be
        # indexed by edge_indices; pairs outside the GT range default to 0.
        target_adj = torch.zeros(
            num_vertices,
            num_vertices,
            device=gt_adj.device,
            dtype=gt_adj.dtype
        )
        # Copy only the overlapping square to avoid out-of-bounds reads/writes.
        copy_dim = min(num_vertices, gt_adj.shape[0])
        target_adj[:copy_dim, :copy_dim] = gt_adj[:copy_dim, :copy_dim]
        edge_targets.append(target_adj[edge_indices[:, 0], edge_indices[:, 1]])
    edge_targets = torch.stack(edge_targets)  # (B, num_pairs)
    edge_targets = edge_targets.to(edge_logits.device)
    return F.binary_cross_entropy_with_logits(edge_logits, edge_targets)
def compute_total_loss(model_output, batch):
    """Combine vertex position, vertex confidence, and edge losses.

    Args:
        model_output: dict from PointCloudToWireframe.forward.
        batch: dict from collate_fn (ground truth is moved to DEVICE here).

    Returns:
        Dict with the weighted 'total_loss' plus the unweighted components.
    """
    # Ground truth onto the training device.
    gt_vertices = batch['gt_vertices'].to(DEVICE)
    vertex_masks = batch['vertex_masks'].to(DEVICE)
    gt_adjacency = batch['adjacency_matrices'].to(DEVICE)

    # Individual loss components.
    vertex_pos_loss, vertex_conf_loss = compute_vertex_loss(
        model_output['vertices'],
        gt_vertices,
        vertex_masks,
        model_output['vertex_confidence'],
    )
    edge_loss = compute_edge_loss(
        model_output['edge_logits'], model_output['edge_indices'], gt_adjacency
    )

    # Weighted sum for the optimizer; components are reported unweighted.
    weighted_total = (
        VERTEX_LOSS_WEIGHT * vertex_pos_loss
        + CONFIDENCE_LOSS_WEIGHT * vertex_conf_loss
        + EDGE_LOSS_WEIGHT * edge_loss
    )
    return {
        'total_loss': weighted_total,
        'vertex_pos_loss': vertex_pos_loss,
        'vertex_conf_loss': vertex_conf_loss,
        'edge_loss': edge_loss,
    }
# =============================================================================
# MAIN TRAINING SCRIPT
# =============================================================================
if __name__ == '__main__':
    # ---- Data ----
    dataset = WireframeDataset(data_dir=DATA_DIR, split=SPLIT)
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=NUM_WORKERS
    )
    # ---- Model ----
    model = PointCloudToWireframe()
    model = model.to(DEVICE)
    print(f"Model loaded on device: {DEVICE}")
    # ---- Optimizer and LR scheduler (plateau on epoch-average total loss) ----
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=LR_SCHEDULER_FACTOR, patience=LR_SCHEDULER_PATIENCE
    )
    # ---- Training loop ----
    model.train()
    print("Starting training...")
    for epoch in range(NUM_EPOCHS):
        # Running sums of each loss component over the epoch.
        epoch_losses = {
            'total_loss': 0.0,
            'vertex_pos_loss': 0.0,
            'vertex_conf_loss': 0.0,
            'edge_loss': 0.0
        }
        num_batches = 0
        for batch_idx, batch in enumerate(dataloader):
            # Only the point cloud is moved here; the remaining GT tensors are
            # moved to DEVICE inside compute_total_loss.
            point_cloud = batch['point_cloud'].to(DEVICE)
            # Zero gradients
            optimizer.zero_grad()
            # Forward pass
            output = model(point_cloud)
            # Compute losses
            losses = compute_total_loss(output, batch)
            # Backward pass
            losses['total_loss'].backward()
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=GRADIENT_CLIP_MAX_NORM)
            # Update weights
            optimizer.step()
            # Accumulate losses.
            # NOTE(review): .item() assumes every loss entry is a tensor; if a
            # component is ever returned as a plain float this raises
            # AttributeError.
            for key in epoch_losses:
                epoch_losses[key] += losses[key].item()
            num_batches += 1
            # Periodic progress logging
            if batch_idx % LOG_FREQUENCY == 0:
                print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Batch {batch_idx}/{len(dataloader)}")
                print(f"  Total Loss: {losses['total_loss'].item():.4f}")
                print(f"  Vertex Pos Loss: {losses['vertex_pos_loss'].item():.4f}")
                print(f"  Vertex Conf Loss: {losses['vertex_conf_loss'].item():.4f}")
                print(f"  Edge Loss: {losses['edge_loss'].item():.4f}")
        # Average losses for the epoch
        for key in epoch_losses:
            epoch_losses[key] /= num_batches
        # Step the plateau scheduler on the epoch-average total loss
        scheduler.step(epoch_losses['total_loss'])
        # Print epoch summary
        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Avg Total Loss: {epoch_losses['total_loss']:.4f}")
        print(f"  Avg Vertex Pos Loss: {epoch_losses['vertex_pos_loss']:.4f}")
        print(f"  Avg Vertex Conf Loss: {epoch_losses['vertex_conf_loss']:.4f}")
        print(f"  Avg Edge Loss: {epoch_losses['edge_loss']:.4f}")
        print(f"  Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
        print("-" * 50)
        # Periodic checkpoint (every CHECKPOINT_SAVE_FREQUENCY epochs)
        if (epoch + 1) % CHECKPOINT_SAVE_FREQUENCY == 0:
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'losses': epoch_losses,
                'config': {
                    'pc_input_features': PC_INPUT_FEATURES,
                    'pc_encoder_output_features': PC_ENCODER_OUTPUT_FEATURES,
                    'max_vertices': MAX_VERTICES,
                    'gnn_hidden_dim': GNN_HIDDEN_DIM,
                    'num_gnn_layers': NUM_GNN_LAYERS
                }
            }
            torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pth')
            print(f"Checkpoint saved: checkpoint_epoch_{epoch+1}.pth")
    # ---- Final model (weights + architecture config only) ----
    torch.save({
        'model_state_dict': model.state_dict(),
        'model_config': {
            'pc_input_features': PC_INPUT_FEATURES,
            'pc_encoder_output_features': PC_ENCODER_OUTPUT_FEATURES,
            'max_vertices': MAX_VERTICES,
            'gnn_hidden_dim': GNN_HIDDEN_DIM,
            'num_gnn_layers': NUM_GNN_LAYERS
        }
    }, 'final_model.pth')
    print("Training completed!")
    print(f"Dataset size: {len(dataset)}")