# subm/dgcnn.py
# Author: Neritz
# Commit 31f43c9 (verified): Add handcrafted_submission_2026 contents (model-repo form for S23DR2026 submission)
"""DGCNN backbone — drop-in replacement for PointNet.
EdgeConv with dynamic graph KNN captures local geometric structure
better than PointNet's global aggregation.
Ref: Wang et al., "Dynamic Graph CNN for Learning on Point Clouds", TOG 2019
https://github.com/antao97/dgcnn.pytorch
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
def knn(x, k):
    """Find the k nearest neighbours of every point in feature space.

    Distances are squared Euclidean, computed via the expansion
    ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2. The negated distance is ranked
    with ``topk`` so the closest points come first. Each point is normally
    its own nearest neighbour (distance ~0).

    Args:
        x: (B, C, N) tensor of point features.
        k: number of neighbours to return per point (must be <= N).

    Returns:
        (B, N, k) long tensor of neighbour indices into the N axis.
    """
    gram = torch.matmul(x.transpose(2, 1), x)            # (B, N, N) dot products
    inner = -2 * gram
    sq_norms = torch.sum(x ** 2, dim=1, keepdim=True)     # (B, 1, N)
    neg_sq_dist = -sq_norms - inner - sq_norms.transpose(2, 1)
    _, neighbour_idx = neg_sq_dist.topk(k=k, dim=-1)
    return neighbour_idx
def get_graph_feature(x, k=20, idx=None):
    """Build EdgeConv edge features from a KNN graph.

    For every point i and each of its neighbours j the feature is the
    channel-wise concatenation ``[x_j - x_i, x_i]``.

    Args:
        x: (B, C, N) point features.
        k: neighbour count used when ``idx`` is None. It is clamped to N so
            point sets smaller than ``k`` still work (``topk`` would raise).
        idx: optional precomputed (B, N, k') neighbour indices into the N
            points of each batch element. When given, the neighbour count is
            taken from ``idx.size(-1)``, not from ``k``.

    Returns:
        (B, 2*C, N, k_eff) edge features, where k_eff is the effective
        neighbour count (``idx.size(-1)``).
    """
    B, C, N = x.shape
    if idx is None:
        # topk() raises when k > N; fall back to using every point.
        idx = knn(x, k=min(k, N))  # (B, N, k)
    # Trust the index tensor's own width: a caller-supplied idx may have been
    # built with a different k than this function's argument.
    k = idx.size(-1)
    # Offset each batch element's indices so they address rows of the
    # flattened (B*N, C) view below.
    idx_base = torch.arange(0, B, device=x.device).view(-1, 1, 1) * N
    flat_idx = (idx + idx_base).view(-1)
    x = x.transpose(2, 1).contiguous()  # (B, N, C)
    neighbours = x.view(B * N, -1)[flat_idx, :].view(B, N, k, C)
    # expand is a zero-copy broadcast; torch.cat materialises it below.
    centre = x.view(B, N, 1, C).expand(-1, -1, k, -1)  # (B, N, k, C)
    feature = torch.cat((neighbours - centre, centre), dim=3)
    return feature.permute(0, 3, 1, 2).contiguous()  # (B, 2*C, N, k)
class EdgeConv(nn.Module):
    """One EdgeConv layer: shared 1x1 conv over edge features, max over neighbours.

    Args:
        in_channels: C, channels of the incoming point features.
        out_channels: channels produced for each point.
        k: neighbour count handed to ``get_graph_feature``.
    """

    def __init__(self, in_channels, out_channels, k=20):
        super().__init__()
        self.k = k
        # Edge features are [x_j - x_i, x_i], so the conv sees twice the width.
        edge_channels = 2 * in_channels
        layers = [
            nn.Conv2d(edge_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Map (B, C, N) point features to (B, out_channels, N)."""
        edges = get_graph_feature(x, k=self.k)  # (B, 2*C, N, k)
        per_edge = self.conv(edges)             # (B, out, N, k)
        pooled, _ = per_edge.max(dim=-1)        # (B, out, N)
        return pooled
class DGCNNBackbone(nn.Module):
    """Four stacked EdgeConv layers followed by global max + mean pooling.

    Same interface as PointNetBackbone: (B, C, N) -> (B, out_dim).

    Args:
        in_channels: C, channels per input point.
        k: neighbour count used by every EdgeConv layer.
        emb_dims: width of the pointwise embedding before pooling.
    """

    def __init__(self, in_channels, k=20, emb_dims=1024):
        super().__init__()
        self.k = k
        self.edge_conv1 = EdgeConv(in_channels, 64, k)
        self.edge_conv2 = EdgeConv(64, 64, k)
        self.edge_conv3 = EdgeConv(64, 128, k)
        self.edge_conv4 = EdgeConv(128, 256, k)
        # Skip-concatenation of all four EdgeConv outputs (64+64+128+256).
        concat_channels = sum((64, 64, 128, 256))
        self.conv5 = nn.Sequential(
            nn.Conv1d(concat_channels, emb_dims, 1, bias=False),
            nn.BatchNorm1d(emb_dims),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Max-pooled and mean-pooled embeddings are concatenated.
        self.out_dim = 2 * emb_dims

    def forward(self, x):
        """Return the (B, 2*emb_dims) global descriptor for a (B, C, N) input."""
        stages = (self.edge_conv1, self.edge_conv2, self.edge_conv3, self.edge_conv4)
        skips = []
        hidden = x
        for stage in stages:
            hidden = stage(hidden)
            skips.append(hidden)
        pointwise = self.conv5(torch.cat(skips, dim=1))  # (B, emb_dims, N)
        pooled_max = pointwise.max(dim=-1)[0]            # (B, emb_dims)
        pooled_avg = pointwise.mean(dim=-1)              # (B, emb_dims)
        return torch.cat([pooled_max, pooled_avg], dim=1)
class DGCNNVertexClassifier(nn.Module):
    """DGCNN vertex classifier, using the same heads as the PointNet version.

    The forward pass returns a tuple:
        cls_logits: (B, 1) raw logits.
        offset: (B, 3) unconstrained regression output.
        confidence: (B, 1) values squashed into (0, 1) by a Sigmoid.
    """

    def __init__(self, in_channels=11, k=10, emb_dims=512):
        super().__init__()
        self.backbone = DGCNNBackbone(in_channels, k, emb_dims)
        width = self.backbone.out_dim
        self.cls_head = self._deep_head(width, 1)
        self.offset_head = self._deep_head(width, 3)
        self.conf_head = nn.Sequential(
            nn.Linear(width, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _deep_head(in_dim, out_dim):
        """Build Linear-BN-LeakyReLU-Dropout(0.3)-Linear-LeakyReLU-Linear."""
        return nn.Sequential(
            nn.Linear(in_dim, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, out_dim),
        )

    def forward(self, x):
        """Run the shared backbone once and apply the three heads in order."""
        shared = self.backbone(x)
        return self.cls_head(shared), self.offset_head(shared), self.conf_head(shared)
class DGCNNEdgeClassifier(nn.Module):
    """DGCNN edge classifier, using the same head as the PointNet version.

    Maps a (B, C, N) point tensor to (B, 1) raw logits; no sigmoid is applied.
    """

    def __init__(self, in_channels=6, k=10, emb_dims=256):
        super().__init__()
        self.backbone = DGCNNBackbone(in_channels, k, emb_dims)
        first_hidden, second_hidden = 512, 256
        self.head = nn.Sequential(
            nn.Linear(self.backbone.out_dim, first_hidden),
            nn.BatchNorm1d(first_hidden),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5),
            nn.Linear(first_hidden, second_hidden),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(second_hidden, 1),
        )

    def forward(self, x):
        """Return one logit per input point set."""
        return self.head(self.backbone(x))