| import torch |
| import os |
| import pickle |
| from torch.utils.data import Dataset, DataLoader |
| import numpy as np |
| from scipy.optimize import linear_sum_assignment |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| |
| |
| |
|
|
| |
# ---------------------------------------------------------------------------
# Dataset / loader configuration
# ---------------------------------------------------------------------------
DATA_DIR = '/mnt/personal/skvrnjan/hoho_fully'  # directory of per-sample .pkl files
SPLIT = 'train'   # NOTE(review): currently unused by WireframeDataset — confirm
MAX_POINTS = 8096  # points per cloud after random subsampling / padding
BATCH_SIZE = 32
NUM_WORKERS = 8    # DataLoader worker processes


# ---------------------------------------------------------------------------
# Model dimensions
# ---------------------------------------------------------------------------
PC_INPUT_FEATURES = 3             # xyz coordinates only
PC_ENCODER_OUTPUT_FEATURES = 128  # per-point feature width out of the encoder
MAX_VERTICES = 50                 # fixed number of vertex query slots
VERTEX_COORD_DIM = 3
GNN_HIDDEN_DIM = 64               # edge-head GNN embedding width
NUM_GNN_LAYERS = 2
HIDDEN_DIM = 256                  # vertex-decoder transformer width
NUM_DECODER_LAYERS = 3
NUM_HEADS = 8


# --- PointNet++ set-abstraction level 1 ---
SA1_NPOINT = 1024  # FPS-sampled centroids
SA1_RADIUS = 0.2   # ball-query radius
SA1_NSAMPLE = 32   # neighbours gathered per centroid
SA1_MLP = [64, 64, 128]


# --- Set-abstraction level 2 ---
SA2_NPOINT = 256
SA2_RADIUS = 0.4
SA2_NSAMPLE = 64
SA2_MLP = [128, 128, 256]


# --- Set-abstraction level 3 (group-all / global) ---
SA3_MLP = [256, 512, 1024]


# --- Feature-propagation (upsampling) MLP widths ---
FP3_MLP = [256, 256]
FP2_MLP = [256, 128]
FP1_MLP = [128, 128]


# --- Vertex decoder transformer ---
VERTEX_TRANSFORMER_DROPOUT = 0.1
VERTEX_TRANSFORMER_FFN_RATIO = 4  # FFN width = HIDDEN_DIM * this ratio


# --- Edge-prediction GNN ---
EDGE_GNN_NUM_HEADS = 4
EDGE_GNN_DROPOUT = 0.1
EDGE_K_NEIGHBORS = 8  # k for the vertex k-NN graph


# --- Optimization ---
NUM_EPOCHS = 100
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
GRADIENT_CLIP_MAX_NORM = 1.0


# --- Loss term weights (see compute_total_loss) ---
VERTEX_LOSS_WEIGHT = 1.0
EDGE_LOSS_WEIGHT = 0.5
CONFIDENCE_LOSS_WEIGHT = 0.3


# --- ReduceLROnPlateau scheduler ---
LR_SCHEDULER_FACTOR = 0.5
LR_SCHEDULER_PATIENCE = 10


# --- Logging / checkpointing cadence ---
CHECKPOINT_SAVE_FREQUENCY = 1  # epochs between checkpoints
LOG_FREQUENCY = 10             # batches between log lines


# Use the first CUDA device when available, otherwise CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
| |
| |
| |
|
|
| |
| |
|
|
| |
class PointNet2Encoder(nn.Module):
    """PointNet++ encoder: three set-abstraction (SA) levels followed by
    three feature-propagation (FP) levels back to full resolution.

    forward() returns per-point features plus the output of the final
    (group-all) SA level used as a global descriptor.
    """

    def __init__(self, input_features, output_features):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features

        # Downsampling path; each SA layer sees (centroid-relative xyz +
        # features), hence the `+ 3` on in_channel.
        self.sa1 = SetAbstractionLayer(
            npoint=SA1_NPOINT, radius=SA1_RADIUS, nsample=SA1_NSAMPLE,
            in_channel=input_features + 3, mlp=SA1_MLP
        )
        self.sa2 = SetAbstractionLayer(
            npoint=SA2_NPOINT, radius=SA2_RADIUS, nsample=SA2_NSAMPLE,
            in_channel=SA1_MLP[-1] + 3, mlp=SA2_MLP
        )
        # npoint=None -> "group all": the whole cloud becomes one group.
        self.sa3 = SetAbstractionLayer(
            npoint=None, radius=None, nsample=None,
            in_channel=SA2_MLP[-1] + 3, mlp=SA3_MLP
        )

        # Upsampling path: propagate coarse features back to finer levels.
        self.fp3 = FeaturePropagationLayer(in_channel=SA3_MLP[-1] + SA2_MLP[-1], mlp=FP3_MLP)
        self.fp2 = FeaturePropagationLayer(in_channel=FP3_MLP[-1] + SA1_MLP[-1], mlp=FP2_MLP)
        self.fp1 = FeaturePropagationLayer(in_channel=FP2_MLP[-1] + input_features, mlp=FP1_MLP + [output_features])

    def forward(self, xyz):
        """
        Args:
            xyz: (B, N, 3) point coordinates.

        Returns:
            l0_points: (B, N, output_features) per-point features.
            global_feature: output of the group-all SA level (see NOTE below).
        """
        B, N, _ = xyz.shape

        # When the only input features are coordinates, reuse xyz as features.
        points = xyz if self.input_features == 3 else None

        # Encoder (downsampling) levels.
        l1_xyz, l1_points = self.sa1(xyz, points)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)

        # Decoder (upsampling) levels, coarse to fine.
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        l0_points = self.fp1(xyz, l1_xyz, points, l1_points)

        # NOTE(review): squeeze(-1) is a no-op here — l3_points' trailing dim
        # is the SA3_MLP[-1]=1024 feature width, never 1. global_feature
        # therefore keeps whatever (B, ?, 1024) shape sa3 emits (which depends
        # on SetAbstractionLayer's group-all branch — see the note there), and
        # downstream code relies on broadcasting. Probably intended: a (B, 1024)
        # vector via squeeze(1). Verify before changing.
        global_feature = l3_points.squeeze(-1)

        return l0_points, global_feature
|
|
|
|
class SetAbstractionLayer(nn.Module):
    """PointNet++ set-abstraction level.

    Sampled mode (npoint set): FPS-samples `npoint` centroids, ball-queries
    up to `nsample` neighbours around each, and applies a shared point-wise
    MLP followed by max-pooling within each local group.

    Group-all mode (npoint is None or group_all=True): treats the whole
    cloud as a single group and pools it into one global descriptor.
    """

    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all=False):
        super().__init__()
        self.npoint = npoint    # number of FPS centroids; None => group-all
        self.radius = radius    # ball-query radius (sampled mode only)
        self.nsample = nsample  # max neighbours per group (sampled mode only)
        self.group_all = group_all

        # Shared MLP as 1x1 Conv2d + BN, applied per (group, member) pair.
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

    def forward(self, xyz, points):
        """
        Args:
            xyz: (B, N, 3) point coordinates.
            points: (B, N, D) per-point features, or None.

        Returns:
            new_xyz: (B, npoint, 3) centroids, or the (B, 1, 3) mean in
                group-all mode.
            new_points: (B, npoint, mlp[-1]) pooled features, or
                (B, 1, mlp[-1]) in group-all mode.
        """
        B, N, C = xyz.shape

        if self.group_all or self.npoint is None:
            # One group covering the entire cloud; its "centroid" is the mean.
            new_xyz = xyz.mean(dim=1, keepdim=True)
            if points is not None:
                # Note: coordinates here are absolute, unlike the sampled
                # branch which uses centroid-relative offsets.
                new_points = torch.cat([xyz, points], dim=-1)
            else:
                new_points = xyz
            # BUGFIX: arrange as (B, C, 1, N) so the max below pools over all
            # N points. Previously this was (B, C, N, 1), making the pool a
            # no-op over a singleton dim and returning per-point features
            # instead of a single global descriptor.
            new_points = new_points.transpose(1, 2).unsqueeze(2)
        else:
            # Sample centroids, then gather each centroid's neighbourhood.
            fps_idx = farthest_point_sample(xyz, self.npoint)
            new_xyz = index_points(xyz, fps_idx)

            idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
            grouped_xyz = index_points(xyz, idx)
            # Express each neighbourhood in its centroid's local frame.
            grouped_xyz_norm = grouped_xyz - new_xyz.unsqueeze(2)

            if points is not None:
                grouped_points = index_points(points, idx)
                new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)
            else:
                new_points = grouped_xyz_norm

            new_points = new_points.permute(0, 3, 1, 2)  # (B, C, npoint, nsample)

        # Shared point-wise MLP.
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))

        # Max-pool over group members -> one feature vector per group.
        new_points = torch.max(new_points, dim=-1)[0]
        new_points = new_points.transpose(1, 2)

        return new_xyz, new_points
|
|
|
|
class FeaturePropagationLayer(nn.Module):
    """PointNet++ feature-propagation level.

    Interpolates features from a coarse point set (xyz2/points2) onto a
    denser one (xyz1), optionally concatenates the dense set's own skip
    features (points1), and refines the result with a shared 1x1-conv MLP.
    """

    def __init__(self, in_channel, mlp):
        super().__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        # Build the Conv1d/BN stack from consecutive channel widths.
        widths = [in_channel] + list(mlp)
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            self.mlp_convs.append(nn.Conv1d(c_in, c_out, 1))
            self.mlp_bns.append(nn.BatchNorm1d(c_out))

    def forward(self, xyz1, xyz2, points1, points2):
        """
        Args:
            xyz1: (B, N1, 3) dense positions.
            xyz2: (B, N2, 3) coarse positions.
            points1: (B, N1, D1) dense skip features, or None.
            points2: (B, N2, D2) coarse features, or None.

        Returns:
            (B, N1, mlp[-1]) refined features, or None when both feature
            inputs are None.
        """
        if points2 is None:
            merged = points1
        else:
            upsampled = interpolate_features(xyz1, xyz2, points2)
            if points1 is None:
                merged = upsampled
            else:
                # Both feature sets must cover the same dense point count.
                assert points1.shape[1] == upsampled.shape[1], f"Point count mismatch: {points1.shape[1]} vs {upsampled.shape[1]}"
                merged = torch.cat([points1, upsampled], dim=-1)

        if merged is None:
            return None

        # Conv1d expects channels-first: (B, C, N1).
        merged = merged.transpose(1, 2)
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            merged = F.relu(bn(conv(merged)))
        return merged.transpose(1, 2)
|
|
|
|
def farthest_point_sample(xyz, npoint):
    """Iterative farthest point sampling.

    Starts from a random seed point per cloud and greedily adds the point
    farthest from the already-selected set.

    Args:
        xyz: (B, N, 3) point coordinates.
        npoint: number of points to sample per cloud.

    Returns:
        (B, npoint) long tensor of sampled point indices.
    """
    device = xyz.device
    B, N, C = xyz.shape
    # Allocate directly on the target device instead of CPU-alloc + .to().
    centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Running minimum squared distance from each point to the selected set.
    distance = torch.full((B, N), 1e10, device=device)
    # Random initial seed point per batch element.
    farthest = torch.randint(0, N, (B,), dtype=torch.long, device=device)

    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[torch.arange(B, device=device), farthest, :].view(B, 1, 3)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        # Update the running minimum, then pick the most distant point next.
        distance = torch.minimum(distance, dist)
        farthest = torch.max(distance, -1)[1]

    return centroids
|
|
|
|
def ball_query(radius, nsample, xyz, new_xyz):
    """Find up to `nsample` neighbours of each query point within `radius`.

    Args:
        radius: search radius.
        nsample: number of neighbour indices returned per query point.
        xyz: (B, N, 3) source points.
        new_xyz: (B, S, 3) query points (group centroids).

    Returns:
        (B, S, nsample) long tensor of indices into xyz. Slots beyond the
        number of in-radius points repeat the first in-radius neighbour;
        queries with no in-radius point at all fall back to index 0.
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape

    # Start from all indices (allocated directly on-device), then mark
    # points outside the ball with the sentinel value N.
    group_idx = torch.arange(N, dtype=torch.long, device=device).view(1, 1, N).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)
    group_idx[sqrdists > radius ** 2] = N

    # Keep the nsample smallest indices; sentinels sort to the back.
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]

    # Replace remaining sentinels with the first (closest-index) neighbour.
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]

    # If even the first slot was the sentinel (no point within radius),
    # fall back to point 0 so every returned index is valid.
    group_idx[group_idx == N] = 0

    return group_idx
|
|
|
|
def square_distance(src, dst):
    """Pairwise squared Euclidean distances between two point sets.

    Args:
        src: (B, N, C) points.
        dst: (B, M, C) points.

    Returns:
        (B, N, M) tensor where entry (b, n, m) = ||src[b,n] - dst[b,m]||^2.
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    # Expand ||a - b||^2 = -2 a.b + ||a||^2 + ||b||^2 batch-wise.
    cross_term = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    src_sq = torch.sum(src ** 2, -1).view(B, N, 1)
    dst_sq = torch.sum(dst ** 2, -1).view(B, 1, M)
    return cross_term + src_sq + dst_sq
|
|
|
|
def index_points(points, idx):
    """Gather point rows by per-batch indices.

    Args:
        points: (B, N, C) source tensor.
        idx: (B, ...) long tensor of indices into the N dimension.

    Returns:
        (B, ..., C) tensor with points[b, idx[b, ...], :].
    """
    B = points.shape[0]
    # Batch index shaped (B, 1, ..., 1); advanced indexing broadcasts it
    # against idx, so no explicit repeat is needed.
    lead_shape = [B] + [1] * (idx.dim() - 1)
    batch_indices = torch.arange(B, dtype=torch.long, device=points.device).view(lead_shape)
    return points[batch_indices, idx, :]
|
|
|
|
def interpolate_features(xyz1, xyz2, points2):
    """Inverse-distance-weighted feature interpolation (up to 3 neighbours).

    Args:
        xyz1: (B, N1, 3) target positions.
        xyz2: (B, N2, 3) source positions.
        points2: (B, N2, C) source features.

    Returns:
        (B, N1, C) features interpolated onto xyz1.
    """
    B, N1, _ = xyz1.shape
    N2 = xyz2.shape[1]

    if N2 == 1:
        # Single source point: broadcast its feature to every target.
        return points2.expand(B, N1, -1)

    # Squared distances to all sources, then keep the k nearest.
    k = min(3, N2)
    sq_dists = square_distance(xyz1, xyz2)
    sq_dists, nn_idx = sq_dists.sort(dim=-1)
    sq_dists, nn_idx = sq_dists[:, :, :k], nn_idx[:, :, :k]

    # Inverse-distance weights; clamp avoids division by (near-)zero.
    sq_dists = sq_dists.clamp(min=1e-10)
    weight = 1.0 / sq_dists
    weight = weight / torch.sum(weight, dim=-1, keepdim=True)

    # Weighted sum of the k neighbouring feature vectors.
    return torch.sum(index_points(points2, nn_idx) * weight.view(B, N1, k, 1), dim=2)
|
|
| |
class VertexPredictionHead(nn.Module):
    """DETR-style vertex decoder.

    A fixed set of `max_vertices` learned queries cross-attends to the
    encoded point features through a transformer decoder; each query then
    regresses one candidate vertex position plus a confidence logit.
    """

    def __init__(self, point_feature_dim, global_feature_dim, max_vertices, vertex_coord_dim=3,
                 hidden_dim=256, num_decoder_layers=3, num_heads=8):
        super().__init__()
        self.max_vertices = max_vertices
        self.vertex_coord_dim = vertex_coord_dim
        self.hidden_dim = hidden_dim

        # One learned query embedding per candidate vertex slot.
        self.vertex_queries = nn.Parameter(torch.randn(max_vertices, hidden_dim))

        # NOTE(review): projects the global feature down to a SINGLE scalar
        # that is broadcast-added onto the queries in forward(). If a
        # per-channel conditioning vector was intended, this should be
        # nn.Linear(global_feature_dim, hidden_dim) — confirm before changing.
        self.global_proj = nn.Linear(global_feature_dim, 1)

        # Project per-point features into the decoder width.
        self.point_proj = nn.Linear(point_feature_dim, hidden_dim)

        # Decoder: queries (tgt) cross-attend to point features (memory).
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * VERTEX_TRANSFORMER_FFN_RATIO,
            dropout=VERTEX_TRANSFORMER_DROPOUT,
            batch_first=True
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        # Regress xyz coordinates per query.
        self.vertex_coord_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, vertex_coord_dim)
        )

        # Raw confidence logit per query (paired with BCE-with-logits losses).
        self.vertex_conf_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

        # Learned positional encoding computed from raw xyz coordinates.
        self.pos_encoding = nn.Sequential(
            nn.Linear(3, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, hidden_dim)
        )

    def forward(self, point_features, global_feature, point_coords=None):
        """
        Args:
            point_features: (B, N, point_feature_dim) per-point features.
            global_feature: tensor whose LAST dim is global_feature_dim; the
                scalar projection below broadcasts over whatever leading
                shape the encoder supplies — see NOTE on global_proj.
            point_coords: optional (B, N, 3) coordinates for positional encoding.

        Returns:
            predicted_vertices: (B, max_vertices, vertex_coord_dim)
            vertex_confidence: (B, max_vertices) raw logits.
        """
        batch_size = point_features.shape[0]

        point_features_proj = self.point_proj(point_features)

        # Add learned positional encodings when coordinates are available.
        if point_coords is not None:
            pos_enc = self.pos_encoding(point_coords)
            point_features_proj = point_features_proj + pos_enc

        # Same query set for every sample in the batch.
        vertex_queries = self.vertex_queries.unsqueeze(0).repeat(batch_size, 1, 1)

        # Scalar global conditioning, broadcast onto all queries/channels.
        global_proj = self.global_proj(global_feature).squeeze(-1).unsqueeze(1)
        vertex_queries = vertex_queries + global_proj

        # Cross-attend queries to the (position-encoded) point features.
        vertex_features = self.transformer_decoder(
            tgt=vertex_queries,
            memory=point_features_proj
        )

        predicted_vertices = self.vertex_coord_head(vertex_features)

        vertex_confidence = self.vertex_conf_head(vertex_features).squeeze(-1)

        return predicted_vertices, vertex_confidence
| |
| |
class EdgePredictionHeadGNN(nn.Module):
    """Predict an existence logit for every unordered vertex pair.

    Raw vertex features are embedded, refined by a stack of graph-attention
    layers over a k-NN graph built from the vertex positions, and each pair
    of refined embeddings is scored by an MLP.
    """

    def __init__(self, vertex_feature_dim, gnn_hidden_dim, num_gnn_layers):
        super().__init__()
        self.vertex_feature_dim = vertex_feature_dim
        self.gnn_hidden_dim = gnn_hidden_dim
        self.num_gnn_layers = num_gnn_layers

        # Embed raw vertex features into the GNN width.
        self.vertex_proj = nn.Linear(vertex_feature_dim, gnn_hidden_dim)

        self.gnn_layers = nn.ModuleList()
        for i in range(num_gnn_layers):
            self.gnn_layers.append(
                GraphAttentionLayer(
                    in_features=gnn_hidden_dim,
                    out_features=gnn_hidden_dim,
                    num_heads=EDGE_GNN_NUM_HEADS,
                    dropout=EDGE_GNN_DROPOUT
                )
            )

        # Scores the concatenated embeddings of a vertex pair.
        self.edge_mlp = nn.Sequential(
            nn.Linear(gnn_hidden_dim * 2, gnn_hidden_dim),
            nn.ReLU(),
            nn.Dropout(EDGE_GNN_DROPOUT),
            nn.Linear(gnn_hidden_dim, gnn_hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(gnn_hidden_dim // 2, 1)
        )

        self.k_neighbors = EDGE_K_NEIGHBORS

    def forward(self, vertices):
        """
        Args:
            vertices: (B, V, vertex_feature_dim) vertex features; the first
                three components are treated as coordinates by the k-NN graph.

        Returns:
            edge_logits: (B, V*(V-1)/2) one logit per unordered pair.
            idx_pairs: (V*(V-1)/2, 2) vertex index pairs matching the logits.
        """
        batch_size, num_vertices, _ = vertices.shape

        vertex_features = self.vertex_proj(vertices)

        # Message passing restricted to a geometric k-NN graph.
        adjacency_matrix = self.construct_knn_graph(vertices, k=self.k_neighbors)
        for gnn_layer in self.gnn_layers:
            vertex_features = gnn_layer(vertex_features, adjacency_matrix)

        # All unordered pairs (i < j).
        idx_pairs = torch.combinations(torch.arange(num_vertices), r=2).to(vertices.device)

        v1_features = vertex_features[:, idx_pairs[:, 0], :]
        v2_features = vertex_features[:, idx_pairs[:, 1], :]
        edge_features = torch.cat([v1_features, v2_features], dim=2)

        edge_logits = self.edge_mlp(edge_features).squeeze(-1)

        return edge_logits, idx_pairs

    def construct_knn_graph(self, vertices, k):
        """Build a symmetric k-NN adjacency matrix from vertex positions.

        Args:
            vertices: (B, V, D) vertex coordinates; requires V > k so that
                topk(k + 1) is valid.
            k: neighbours per vertex (self excluded).

        Returns:
            (B, V, V) float adjacency with 1.0 where either endpoint is among
            the other's k nearest neighbours.
        """
        batch_size, num_vertices, _ = vertices.shape

        distances = torch.cdist(vertices, vertices, p=2)

        # k+1 because the nearest "neighbour" of each vertex is itself.
        _, knn_indices = torch.topk(distances, k + 1, dim=-1, largest=False)
        knn_indices = knn_indices[:, :, 1:]

        adjacency = torch.zeros(batch_size, num_vertices, num_vertices, device=vertices.device)

        # BUGFIX: build the index grids on vertices.device. The original CPU
        # aranges mixed with CUDA knn_indices raise a device-mismatch error
        # in advanced indexing when the model runs on GPU.
        batch_idx = torch.arange(batch_size, device=vertices.device).view(-1, 1, 1).expand(-1, num_vertices, k)
        vertex_idx = torch.arange(num_vertices, device=vertices.device).view(1, -1, 1).expand(batch_size, -1, k)

        adjacency[batch_idx, vertex_idx, knn_indices] = 1.0

        # Symmetrize: keep an edge if either direction selected it.
        adjacency = torch.max(adjacency, adjacency.transpose(-1, -2))

        return adjacency
|
|
|
|
class GraphAttentionLayer(nn.Module):
    """Multi-head self-attention restricted to a given adjacency structure.

    Attention between non-adjacent vertices is suppressed with a large
    negative additive mask; a residual + LayerNorm + FFN block follows,
    transformer-style.
    """

    def __init__(self, in_features, out_features, num_heads=1, dropout=0.1):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_heads = num_heads
        self.dropout = dropout

        assert out_features % num_heads == 0
        self.head_dim = out_features // num_heads

        # Explicit Q/K/V projections (applied before nn.MultiheadAttention,
        # which additionally has its own internal projections).
        self.W_q = nn.Linear(in_features, out_features)
        self.W_k = nn.Linear(in_features, out_features)
        self.W_v = nn.Linear(in_features, out_features)

        # NOTE(review): W_o is never used in forward(); it is dead weight but
        # kept so existing checkpoints continue to load.
        self.W_o = nn.Linear(out_features, out_features)

        self.attention = nn.MultiheadAttention(
            embed_dim=out_features,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        self.layer_norm = nn.LayerNorm(out_features)
        self.ffn = nn.Sequential(
            nn.Linear(out_features, out_features * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(out_features * 2, out_features)
        )
        self.layer_norm2 = nn.LayerNorm(out_features)

    def forward(self, x, adjacency_matrix):
        """
        Args:
            x: (B, V, in_features) vertex embeddings.
            adjacency_matrix: (B, V, V), 1.0 where attention is allowed.

        Returns:
            (B, V, out_features) updated vertex embeddings.
        """
        batch_size, num_vertices, _ = x.shape

        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # Additive mask: ~-1e9 on non-edges so softmax drives them to zero.
        attention_mask = (1 - adjacency_matrix) * (-1e9)

        # PERF: nn.MultiheadAttention accepts a 3-D float mask of shape
        # (batch * num_heads, L, S), so the whole batch runs in one call.
        # This replaces the original per-sample Python loop and computes the
        # identical result (the 2-D per-sample mask was applied to all heads,
        # which repeat_interleave reproduces head-by-head).
        attn_mask = attention_mask.repeat_interleave(self.num_heads, dim=0)
        attended_features, _ = self.attention(Q, K, V, attn_mask=attn_mask)

        # Residual uses Q (not x) since in_features may differ from out_features.
        x_residual = self.layer_norm(attended_features + Q)

        ffn_output = self.ffn(x_residual)
        output = self.layer_norm2(ffn_output + x_residual)

        return output
|
|
| |
class PointCloudToWireframe(nn.Module):
    """End-to-end point cloud -> wireframe network.

    Pipeline: PointNet++ encoder -> transformer vertex decoder (fixed number
    of vertex slots with confidences) -> GNN edge head that scores every
    unordered pair of predicted vertices.
    """

    def __init__(self,
                 pc_input_features=PC_INPUT_FEATURES,
                 pc_encoder_output_features=PC_ENCODER_OUTPUT_FEATURES,
                 max_vertices=MAX_VERTICES,
                 vertex_coord_dim=VERTEX_COORD_DIM,
                 gnn_hidden_dim=GNN_HIDDEN_DIM,
                 num_gnn_layers=NUM_GNN_LAYERS,
                 hidden_dim=HIDDEN_DIM,
                 num_decoder_layers=NUM_DECODER_LAYERS,
                 num_heads=NUM_HEADS):
        super().__init__()

        # Per-point features + global descriptor.
        self.encoder = PointNet2Encoder(pc_input_features, pc_encoder_output_features)

        # Vertex decoder; global_feature_dim matches the encoder's final SA width.
        self.vertex_head = VertexPredictionHead(
            point_feature_dim=pc_encoder_output_features,
            global_feature_dim=SA3_MLP[-1],
            max_vertices=max_vertices,
            vertex_coord_dim=vertex_coord_dim,
            hidden_dim=hidden_dim,
            num_decoder_layers=num_decoder_layers,
            num_heads=num_heads
        )

        # Edge head consumes the RAW predicted vertex coordinates
        # (vertex_feature_dim == vertex_coord_dim), not decoder embeddings.
        self.edge_head = EdgePredictionHeadGNN(
            vertex_feature_dim=vertex_coord_dim,
            gnn_hidden_dim=gnn_hidden_dim,
            num_gnn_layers=num_gnn_layers
        )

    def forward(self, point_cloud):
        """
        Args:
            point_cloud: (B, N, 3) xyz coordinates.

        Returns:
            dict with:
                vertices: (B, max_vertices, 3) predicted positions.
                vertex_confidence: (B, max_vertices) logits.
                edge_logits: (B, P) pair logits, P = max_vertices*(max_vertices-1)/2.
                edge_indices: (P, 2) vertex index pairs matching edge_logits.
        """
        batch_size, num_points, _ = point_cloud.shape

        # Encode the cloud.
        point_features, global_feature = self.encoder(point_cloud)

        # Decode candidate vertices; the raw coordinates also feed the
        # decoder's positional encoding.
        predicted_vertices, vertex_confidence = self.vertex_head(
            point_features, global_feature, point_coords=point_cloud
        )

        # Score every unordered pair of predicted vertices.
        edge_logits, edge_indices = self.edge_head(predicted_vertices)

        return {
            'vertices': predicted_vertices,
            'vertex_confidence': vertex_confidence,
            'edge_logits': edge_logits,
            'edge_indices': edge_indices
        }
|
|
class WireframeDataset(Dataset):
    def __init__(self, data_dir=DATA_DIR, split=SPLIT, transform=None, max_points=MAX_POINTS):
        """
        Dataset for point cloud to wireframe conversion.

        Args:
            data_dir: Directory containing the pickle files
            split: 'train', 'val', or 'test'.
                NOTE(review): currently unused — every .pkl in data_dir is
                listed regardless of split; confirm the intended layout.
            transform: Optional transforms to apply to point clouds
            max_points: Maximum number of points in the point cloud (default: 8096)
        """
        self.data_dir = data_dir
        self.split = split
        self.transform = transform
        self.max_points = max_points

        # Collect every pickle file in the directory; sorted for a
        # deterministic index -> file mapping.
        self.data_files = []
        for file in os.listdir(data_dir):
            if file.endswith('.pkl'):
                self.data_files.append(os.path.join(data_dir, file))

        self.data_files.sort()

    def __len__(self):
        # One sample per pickle file.
        return len(self.data_files)

    def __getitem__(self, idx):
        # SECURITY NOTE: pickle.load executes arbitrary code — only use with
        # trusted dataset files.
        with open(self.data_files[idx], 'rb') as f:
            sample_data = pickle.load(f)

        # Raw fields; gt_connections appears to be a list of (src, dst)
        # vertex index pairs (see the transpose below) — confirm schema.
        point_cloud = torch.tensor(sample_data['point_cloud'], dtype=torch.float32)
        point_colors = torch.tensor(sample_data['point_colors'], dtype=torch.float32)
        gt_vertices = torch.tensor(sample_data['gt_vertices'], dtype=torch.float32)
        gt_connections = sample_data['gt_connections']
        sample_id = sample_data['sample_id']

        # Normalize the cloud to exactly max_points points.
        current_points = point_cloud.shape[0]

        if current_points > self.max_points:
            # Too many points: keep a uniformly random subset.
            indices = torch.randperm(current_points)[:self.max_points]
            point_cloud = point_cloud[indices]
            point_colors = point_colors[indices]
        elif current_points < self.max_points:
            # Too few points: pad by duplicating random existing points.
            pad_size = self.max_points - current_points
            if current_points > 0:
                pad_indices = torch.randint(0, current_points, (pad_size,))
                pad_points = point_cloud[pad_indices]
                pad_colors = point_colors[pad_indices]
                point_cloud = torch.cat([point_cloud, pad_points], dim=0)
                point_colors = torch.cat([point_colors, pad_colors], dim=0)
            else:
                # Degenerate empty cloud: fall back to all zeros.
                point_cloud = torch.zeros(self.max_points, 3)
                point_colors = torch.zeros(self.max_points, 3)

        # Edge list as a (2, E) long tensor (empty tensor when no edges).
        if len(gt_connections) > 0:
            edge_indices = torch.tensor(gt_connections, dtype=torch.long).t()
        else:
            edge_indices = torch.zeros((2, 0), dtype=torch.long)

        # Optional transform; note it is applied to coordinates only, not colors.
        if self.transform:
            point_cloud = self.transform(point_cloud)

        return {
            'point_cloud': point_cloud,
            'point_colors': point_colors,
            'gt_vertices': gt_vertices,
            'edge_indices': edge_indices,
            'sample_id': sample_id
        }
|
|
def collate_fn(batch):
    """Collate samples with varying point/vertex/edge counts into one batch.

    Point clouds are zero-padded to the batch's max point count, GT vertices
    to the batch's max vertex count (with boolean validity masks), and each
    per-sample edge list becomes a dense symmetric adjacency matrix.
    """
    point_clouds = [s['point_cloud'] for s in batch]
    point_colors = [s['point_colors'] for s in batch]
    gt_vertices_list = [s['gt_vertices'] for s in batch]
    edge_indices_list = [s['edge_indices'] for s in batch]
    sample_ids = [s['sample_id'] for s in batch]

    max_vertices = max(v.shape[0] for v in gt_vertices_list)
    max_points = max(pc.shape[0] for pc in point_clouds)

    def _pad_rows(tensor, target_rows):
        # Zero-pad a (rows, 3) tensor up to target_rows rows.
        deficit = target_rows - tensor.shape[0]
        if deficit > 0:
            return torch.cat([tensor, torch.zeros(deficit, 3)], dim=0)
        return tensor

    # Pad and stack clouds and colors.
    point_clouds_batch = torch.stack([_pad_rows(pc, max_points) for pc in point_clouds])
    point_colors_batch = torch.stack([_pad_rows(c, max_points) for c in point_colors])

    # Pad vertices and record which slots hold real GT vertices.
    gt_vertices_batch = torch.stack(
        [_pad_rows(v, max_vertices) for v in gt_vertices_list]
    )
    vertex_masks_batch = torch.stack(
        [torch.arange(max_vertices) < v.shape[0] for v in gt_vertices_list]
    )

    # Dense symmetric adjacency per sample, sized to the batch max.
    adjacency_matrices = torch.zeros(len(batch), max_vertices, max_vertices)
    for i, edges in enumerate(edge_indices_list):
        if edges.shape[1] == 0:
            continue
        num_gt = gt_vertices_list[i].shape[0]
        src, dst = edges[0], edges[1]
        # Drop edges that reference out-of-range vertex indices.
        keep = (src < num_gt) & (dst < num_gt)
        src, dst = src[keep], dst[keep]
        adjacency_matrices[i, src, dst] = 1
        adjacency_matrices[i, dst, src] = 1

    return {
        'point_cloud': point_clouds_batch,
        'point_colors': point_colors_batch,
        'gt_vertices': gt_vertices_batch,
        'vertex_masks': vertex_masks_batch,
        'adjacency_matrices': adjacency_matrices,
        'edge_indices_list': edge_indices_list,
        'sample_ids': sample_ids
    }
|
|
| |
def compute_vertex_loss(pred_vertices, gt_vertices, vertex_masks, vertex_confidence):
    """
    Vertex position + confidence losses via Hungarian (optimal bipartite) matching.

    Args:
        pred_vertices: (B, P, 3) predicted vertex coordinates.
        gt_vertices: (B, G, 3) padded ground-truth vertices.
        vertex_masks: (B, G) bool mask of valid GT vertices.
        vertex_confidence: (B, P) confidence logits.

    Returns:
        (position_loss, confidence_loss) as 0-dim tensors, each averaged over
        the batch. Confidence targets are 1 for matched predictions, else 0.
    """
    batch_size = pred_vertices.shape[0]
    # BUGFIX: accumulate as 0-dim tensors, not Python floats — otherwise a
    # batch in which no sample yields a position loss returns a plain 0.0
    # and callers crash on .item()/backward().
    total_loss = pred_vertices.new_zeros(())
    total_confidence_loss = pred_vertices.new_zeros(())

    for b in range(batch_size):
        valid_mask = vertex_masks[b]
        gt_verts = gt_vertices[b][valid_mask]
        num_gt = gt_verts.shape[0]

        if num_gt == 0:
            # No GT vertices: every prediction should be low confidence.
            confidence_target = torch.zeros_like(vertex_confidence[b])
            conf_loss = F.binary_cross_entropy_with_logits(vertex_confidence[b], confidence_target)
            total_confidence_loss = total_confidence_loss + conf_loss
            continue

        pred_verts = pred_vertices[b]
        pred_conf = vertex_confidence[b]

        distances = torch.cdist(pred_verts, gt_verts)

        # Hungarian matching runs on CPU/NumPy; gradients flow through the
        # matched pairs, not through the assignment itself.
        cost_matrix = distances.detach().cpu().numpy()

        # Pad the short side with a prohibitive cost so the assignment is
        # square; padded rows/columns are discarded below.
        if distances.shape[0] < distances.shape[1]:
            padding = np.full((distances.shape[1] - distances.shape[0], distances.shape[1]), 1e6)
            cost_matrix = np.vstack([cost_matrix, padding])
        elif distances.shape[0] > distances.shape[1]:
            padding = np.full((distances.shape[0], distances.shape[0] - distances.shape[1]), 1e6)
            cost_matrix = np.hstack([cost_matrix, padding])

        pred_indices, gt_indices = linear_sum_assignment(cost_matrix)

        # Keep only assignments that landed inside the real matrix.
        valid_assignments = (pred_indices < pred_verts.shape[0]) & (gt_indices < num_gt)
        pred_indices = pred_indices[valid_assignments]
        gt_indices = gt_indices[valid_assignments]

        if len(pred_indices) > 0:
            # MSE over matched prediction/GT coordinate pairs.
            matched_pred = pred_verts[pred_indices]
            matched_gt = gt_verts[gt_indices]
            position_loss = F.mse_loss(matched_pred, matched_gt)
            total_loss = total_loss + position_loss

            # Matched slots should be confident; the rest should not.
            confidence_target = torch.zeros_like(pred_conf)
            confidence_target[pred_indices] = 1.0
            conf_loss = F.binary_cross_entropy_with_logits(pred_conf, confidence_target)
            total_confidence_loss = total_confidence_loss + conf_loss
        else:
            confidence_target = torch.zeros_like(pred_conf)
            conf_loss = F.binary_cross_entropy_with_logits(pred_conf, confidence_target)
            total_confidence_loss = total_confidence_loss + conf_loss

    return total_loss / batch_size, total_confidence_loss / batch_size
|
|
def compute_edge_loss(edge_logits, edge_indices, gt_adjacency_matrices, max_vertices=None):
    """
    Binary cross-entropy loss over all predicted vertex-pair logits.

    Args:
        edge_logits: (B, P) pair logits from the edge head.
        edge_indices: (P, 2) vertex index pairs matching the logits.
        gt_adjacency_matrices: (B, G, G) dense ground-truth adjacency,
            padded to the batch's max vertex count.
        max_vertices: size of the vertex slot grid the model predicts over.
            Defaults to the module-level MAX_VERTICES; parameterized so the
            function also works with models configured differently.

    Returns:
        0-dim BCE-with-logits loss averaged over all pairs and samples.
    """
    if max_vertices is None:
        max_vertices = MAX_VERTICES
    batch_size = gt_adjacency_matrices.shape[0]

    edge_targets = []
    for b in range(batch_size):
        gt_adj_for_sample = gt_adjacency_matrices[b]

        # Embed the (possibly smaller or larger) GT adjacency into the fixed
        # max_vertices x max_vertices grid so it lines up with edge_indices.
        target_adj_full_size = torch.zeros(
            max_vertices,
            max_vertices,
            device=gt_adj_for_sample.device,
            dtype=gt_adj_for_sample.dtype
        )
        copy_dim = min(max_vertices, gt_adj_for_sample.shape[0])
        target_adj_full_size[:copy_dim, :copy_dim] = gt_adj_for_sample[:copy_dim, :copy_dim]

        # Read off one target per predicted pair.
        targets = target_adj_full_size[edge_indices[:, 0], edge_indices[:, 1]]
        edge_targets.append(targets)

    edge_targets = torch.stack(edge_targets)
    edge_targets = edge_targets.to(edge_logits.device)

    edge_loss = F.binary_cross_entropy_with_logits(edge_logits, edge_targets)

    return edge_loss
|
|
def compute_total_loss(model_output, batch):
    """
    Weighted sum of the vertex position, vertex confidence and edge losses.

    Moves the ground-truth tensors to DEVICE, evaluates each loss component,
    and returns them alongside the weighted total.
    """
    # Ground truth onto the training device.
    gt_vertices = batch['gt_vertices'].to(DEVICE)
    vertex_masks = batch['vertex_masks'].to(DEVICE)
    gt_adjacency = batch['adjacency_matrices'].to(DEVICE)

    # Individual loss terms.
    vertex_pos_loss, vertex_conf_loss = compute_vertex_loss(
        model_output['vertices'],
        gt_vertices,
        vertex_masks,
        model_output['vertex_confidence'],
    )
    edge_loss = compute_edge_loss(
        model_output['edge_logits'], model_output['edge_indices'], gt_adjacency
    )

    # Weighted combination (weights are module-level constants).
    total_loss = (
        VERTEX_LOSS_WEIGHT * vertex_pos_loss
        + CONFIDENCE_LOSS_WEIGHT * vertex_conf_loss
        + EDGE_LOSS_WEIGHT * edge_loss
    )

    return {
        'total_loss': total_loss,
        'vertex_pos_loss': vertex_pos_loss,
        'vertex_conf_loss': vertex_conf_loss,
        'edge_loss': edge_loss
    }
|
|
| |
| |
| |
|
|
if __name__ == '__main__':
    # Dataset + loader (custom collate handles variable vertex/edge counts).
    dataset = WireframeDataset(data_dir=DATA_DIR, split=SPLIT)
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=NUM_WORKERS
    )

    # Build the model and move it to the training device.
    model = PointCloudToWireframe()

    model = model.to(DEVICE)
    print(f"Model loaded on device: {DEVICE}")

    # Adam + plateau scheduler driven by the epoch-average total loss.
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=LR_SCHEDULER_FACTOR, patience=LR_SCHEDULER_PATIENCE
    )

    model.train()
    print("Starting training...")

    for epoch in range(NUM_EPOCHS):
        # Running sums of each loss component, averaged at epoch end.
        epoch_losses = {
            'total_loss': 0.0,
            'vertex_pos_loss': 0.0,
            'vertex_conf_loss': 0.0,
            'edge_loss': 0.0
        }
        num_batches = 0

        for batch_idx, batch in enumerate(dataloader):
            # Only the point cloud is moved here; compute_total_loss moves
            # the ground-truth tensors to DEVICE itself.
            point_cloud = batch['point_cloud'].to(DEVICE)

            optimizer.zero_grad()

            output = model(point_cloud)

            losses = compute_total_loss(output, batch)

            losses['total_loss'].backward()

            # Gradient clipping for training stability.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=GRADIENT_CLIP_MAX_NORM)

            optimizer.step()

            # NOTE(review): .item() assumes every loss value is a 0-dim
            # tensor; compute_vertex_loss can return a plain float for a
            # batch with no valid GT vertices — confirm.
            for key in epoch_losses:
                epoch_losses[key] += losses[key].item()
            num_batches += 1

            # Periodic per-batch logging.
            if batch_idx % LOG_FREQUENCY == 0:
                print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Batch {batch_idx}/{len(dataloader)}")
                print(f" Total Loss: {losses['total_loss'].item():.4f}")
                print(f" Vertex Pos Loss: {losses['vertex_pos_loss'].item():.4f}")
                print(f" Vertex Conf Loss: {losses['vertex_conf_loss'].item():.4f}")
                print(f" Edge Loss: {losses['edge_loss'].item():.4f}")

        # Epoch averages.
        for key in epoch_losses:
            epoch_losses[key] /= num_batches

        # Plateau scheduler watches the average total loss.
        scheduler.step(epoch_losses['total_loss'])

        # Epoch summary.
        print(f"\nEpoch {epoch+1} Summary:")
        print(f" Avg Total Loss: {epoch_losses['total_loss']:.4f}")
        print(f" Avg Vertex Pos Loss: {epoch_losses['vertex_pos_loss']:.4f}")
        print(f" Avg Vertex Conf Loss: {epoch_losses['vertex_conf_loss']:.4f}")
        print(f" Avg Edge Loss: {epoch_losses['edge_loss']:.4f}")
        print(f" Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
        print("-" * 50)

        # Periodic full checkpoint (model + optimizer + scheduler + config).
        if (epoch + 1) % CHECKPOINT_SAVE_FREQUENCY == 0:
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'losses': epoch_losses,
                'config': {
                    'pc_input_features': PC_INPUT_FEATURES,
                    'pc_encoder_output_features': PC_ENCODER_OUTPUT_FEATURES,
                    'max_vertices': MAX_VERTICES,
                    'gnn_hidden_dim': GNN_HIDDEN_DIM,
                    'num_gnn_layers': NUM_GNN_LAYERS
                }
            }
            torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pth')
            print(f"Checkpoint saved: checkpoint_epoch_{epoch+1}.pth")

    # Final weights + architecture config for later reconstruction.
    torch.save({
        'model_state_dict': model.state_dict(),
        'model_config': {
            'pc_input_features': PC_INPUT_FEATURES,
            'pc_encoder_output_features': PC_ENCODER_OUTPUT_FEATURES,
            'max_vertices': MAX_VERTICES,
            'gnn_hidden_dim': GNN_HIDDEN_DIM,
            'num_gnn_layers': NUM_GNN_LAYERS
        }
    }, 'final_model.pth')

    print("Training completed!")
    print(f"Dataset size: {len(dataset)}")
|
|