"""DGCNN backbone — drop-in replacement for PointNet. EdgeConv with dynamic graph KNN captures local geometric structure better than PointNet's global aggregation. Ref: Wang et al., "Dynamic Graph CNN for Learning on Point Clouds", TOG 2019 https://github.com/antao97/dgcnn.pytorch """ import torch import torch.nn as nn import torch.nn.functional as F def knn(x, k): """Compute KNN graph. x: (B, C, N). Returns (B, N, k) indices.""" inner = -2 * torch.matmul(x.transpose(2, 1), x) # (B, N, N) xx = torch.sum(x ** 2, dim=1, keepdim=True) # (B, 1, N) pairwise_dist = -xx - inner - xx.transpose(2, 1) # (B, N, N) negative distances idx = pairwise_dist.topk(k=k, dim=-1)[1] # (B, N, k) return idx def get_graph_feature(x, k=20, idx=None): """Build edge features for EdgeConv. For each point, concatenate [x_j - x_i, x_i] for its k neighbors. Returns (B, 2*C, N, k). """ B, C, N = x.shape device = x.device if idx is None: idx = knn(x, k=k) # (B, N, k) idx_base = torch.arange(0, B, device=device).view(-1, 1, 1) * N idx = idx + idx_base idx = idx.view(-1) x = x.transpose(2, 1).contiguous() # (B, N, C) feature = x.view(B * N, -1)[idx, :] # (B*N*k, C) feature = feature.view(B, N, k, C) x = x.view(B, N, 1, C).repeat(1, 1, k, 1) # (B, N, k, C) feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2).contiguous() # (B, 2*C, N, k) return feature class EdgeConv(nn.Module): """Single EdgeConv layer.""" def __init__(self, in_channels, out_channels, k=20): super().__init__() self.k = k self.conv = nn.Sequential( nn.Conv2d(in_channels * 2, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), nn.LeakyReLU(0.2, inplace=True), ) def forward(self, x): # x: (B, C, N) feat = get_graph_feature(x, k=self.k) # (B, 2*C, N, k) feat = self.conv(feat) # (B, out, N, k) feat = feat.max(dim=-1)[0] # (B, out, N) return feat class DGCNNBackbone(nn.Module): """DGCNN backbone with multiple EdgeConv layers. Same interface as PointNetBackbone: (B, C, N) → (B, out_dim). """ def __init__(self, in_channels, k=20, emb_dims=1024): super().__init__() self.k = k self.edge_conv1 = EdgeConv(in_channels, 64, k) self.edge_conv2 = EdgeConv(64, 64, k) self.edge_conv3 = EdgeConv(64, 128, k) self.edge_conv4 = EdgeConv(128, 256, k) # Aggregate all EdgeConv outputs self.conv5 = nn.Sequential( nn.Conv1d(64 + 64 + 128 + 256, emb_dims, 1, bias=False), nn.BatchNorm1d(emb_dims), nn.LeakyReLU(0.2, inplace=True), ) self.out_dim = emb_dims * 2 # max + avg pooling def forward(self, x): """ Args: x: (B, C, N) Returns: global_feat: (B, out_dim) """ x1 = self.edge_conv1(x) # (B, 64, N) x2 = self.edge_conv2(x1) # (B, 64, N) x3 = self.edge_conv3(x2) # (B, 128, N) x4 = self.edge_conv4(x3) # (B, 256, N) x_cat = torch.cat([x1, x2, x3, x4], dim=1) # (B, 512, N) x5 = self.conv5(x_cat) # (B, emb_dims, N) x_max = x5.max(dim=-1)[0] # (B, emb_dims) x_avg = x5.mean(dim=-1) # (B, emb_dims) global_feat = torch.cat([x_max, x_avg], dim=1) # (B, 2*emb_dims) return global_feat class DGCNNVertexClassifier(nn.Module): """DGCNN vertex classifier — same heads as PointNet version.""" def __init__(self, in_channels=11, k=10, emb_dims=512): super().__init__() self.backbone = DGCNNBackbone(in_channels, k, emb_dims) feat_dim = self.backbone.out_dim self.cls_head = nn.Sequential( nn.Linear(feat_dim, 512), nn.BatchNorm1d(512), nn.LeakyReLU(0.2, inplace=True), nn.Dropout(0.3), nn.Linear(512, 128), nn.LeakyReLU(0.2, inplace=True), nn.Linear(128, 1), ) self.offset_head = nn.Sequential( nn.Linear(feat_dim, 512), nn.BatchNorm1d(512), nn.LeakyReLU(0.2, inplace=True), nn.Dropout(0.3), nn.Linear(512, 128), nn.LeakyReLU(0.2, inplace=True), nn.Linear(128, 3), ) self.conf_head = nn.Sequential( nn.Linear(feat_dim, 256), nn.BatchNorm1d(256), nn.LeakyReLU(0.2, inplace=True), nn.Linear(256, 1), nn.Sigmoid(), ) def forward(self, x): feat = self.backbone(x) cls_logits = self.cls_head(feat) offset = self.offset_head(feat) confidence = self.conf_head(feat) return cls_logits, offset, confidence class DGCNNEdgeClassifier(nn.Module): """DGCNN edge classifier — same heads as PointNet version.""" def __init__(self, in_channels=6, k=10, emb_dims=256): super().__init__() self.backbone = DGCNNBackbone(in_channels, k, emb_dims) feat_dim = self.backbone.out_dim self.head = nn.Sequential( nn.Linear(feat_dim, 512), nn.BatchNorm1d(512), nn.LeakyReLU(0.2, inplace=True), nn.Dropout(0.5), nn.Linear(512, 256), nn.LeakyReLU(0.2, inplace=True), nn.Dropout(0.3), nn.Linear(256, 1), ) def forward(self, x): feat = self.backbone(x) return self.head(feat)