Improves PointNet classification model.
This commit introduces a significantly enhanced PointNet
implementation for binary classification of 6D point cloud
patches.
Key improvements include:
- A deeper architecture with residual connections for enhanced
feature extraction.
- An attention mechanism for improved feature weighting.
- Multi-scale feature aggregation for richer representations.
- An enhanced classification head with residual connections and varying dropout rates for better generalization.
Additionally, the training batch size was reduced to 128, and the model save path was updated to reflect the new, improved model.
- fast_pointnet_class.py +161 -40
- train_pnet_class_cluster.py +2 -2
fast_pointnet_class.py
CHANGED
|
@@ -10,77 +10,199 @@ import json
|
|
| 10 |
|
| 11 |
class ClassificationPointNet(nn.Module):
|
| 12 |
"""
|
| 13 |
-
PointNet implementation for binary classification from 6D point cloud patches.
|
| 14 |
Takes 6D point clouds (x,y,z,r,g,b) and predicts binary classification (edge/not edge).
|
|
|
|
| 15 |
"""
|
| 16 |
def __init__(self, input_dim=6, max_points=1024):
|
| 17 |
super(ClassificationPointNet, self).__init__()
|
| 18 |
self.max_points = max_points
|
| 19 |
|
| 20 |
-
# Point-wise MLPs
|
| 21 |
self.conv1 = nn.Conv1d(input_dim, 64, 1)
|
| 22 |
-
self.conv2 = nn.Conv1d(64,
|
| 23 |
-
self.conv3 = nn.Conv1d(
|
| 24 |
-
self.conv4 = nn.Conv1d(
|
| 25 |
-
self.conv5 = nn.Conv1d(
|
| 26 |
-
self.conv6 = nn.Conv1d(
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
self.
|
| 30 |
-
self.
|
| 31 |
-
self.
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
# Batch normalization layers
|
| 37 |
self.bn1 = nn.BatchNorm1d(64)
|
| 38 |
-
self.bn2 = nn.BatchNorm1d(
|
| 39 |
-
self.bn3 = nn.BatchNorm1d(
|
| 40 |
-
self.bn4 = nn.BatchNorm1d(
|
| 41 |
-
self.bn5 = nn.BatchNorm1d(
|
| 42 |
-
self.bn6 = nn.BatchNorm1d(
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
self.
|
| 46 |
-
self.
|
| 47 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
self.dropout4 = nn.Dropout(0.4)
|
| 49 |
-
self.dropout5 = nn.Dropout(0.
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
def forward(self, x):
|
| 52 |
"""
|
| 53 |
-
Forward pass
|
| 54 |
Args:
|
| 55 |
x: (batch_size, input_dim, max_points) tensor
|
| 56 |
Returns:
|
| 57 |
-
classification: (batch_size, 1) tensor of logits
|
| 58 |
"""
|
| 59 |
batch_size = x.size(0)
|
| 60 |
|
| 61 |
-
#
|
| 62 |
x1 = F.relu(self.bn1(self.conv1(x)))
|
| 63 |
x2 = F.relu(self.bn2(self.conv2(x1)))
|
|
|
|
|
|
|
| 64 |
x3 = F.relu(self.bn3(self.conv3(x2)))
|
| 65 |
x4 = F.relu(self.bn4(self.conv4(x3)))
|
|
|
|
|
|
|
|
|
|
| 66 |
x5 = F.relu(self.bn5(self.conv5(x4)))
|
| 67 |
x6 = F.relu(self.bn6(self.conv6(x5)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
-
#
|
| 70 |
-
|
| 71 |
|
| 72 |
-
#
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
x = self.dropout1(x)
|
| 75 |
-
|
|
|
|
|
|
|
| 76 |
x = self.dropout2(x)
|
| 77 |
-
|
|
|
|
| 78 |
x = self.dropout3(x)
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
x = self.dropout4(x)
|
| 81 |
-
|
|
|
|
| 82 |
x = self.dropout5(x)
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
return classification
|
| 86 |
|
|
@@ -401,6 +523,5 @@ def predict_class_from_patch(model: ClassificationPointNet, patch: Dict, device:
|
|
| 401 |
outputs = model(patch_tensor) # (1, 1)
|
| 402 |
probability = torch.sigmoid(outputs).item()
|
| 403 |
predicted_class = int(probability > 0.5)
|
| 404 |
-
confidence = probability if predicted_class == 1 else (1 - probability)
|
| 405 |
|
| 406 |
-
return predicted_class,
|
|
|
|
| 10 |
|
| 11 |
class ClassificationPointNet(nn.Module):
|
| 12 |
"""
|
| 13 |
+
Enhanced PointNet implementation for binary classification from 6D point cloud patches.
|
| 14 |
Takes 6D point clouds (x,y,z,r,g,b) and predicts binary classification (edge/not edge).
|
| 15 |
+
Features: Residual connections, attention mechanism, multi-scale features, deeper architecture.
|
| 16 |
"""
|
| 17 |
def __init__(self, input_dim=6, max_points=1024):
|
| 18 |
super(ClassificationPointNet, self).__init__()
|
| 19 |
self.max_points = max_points
|
| 20 |
|
| 21 |
+
# Point-wise MLPs with residual connections (much deeper)
|
| 22 |
self.conv1 = nn.Conv1d(input_dim, 64, 1)
|
| 23 |
+
self.conv2 = nn.Conv1d(64, 64, 1)
|
| 24 |
+
self.conv3 = nn.Conv1d(64, 128, 1)
|
| 25 |
+
self.conv4 = nn.Conv1d(128, 128, 1)
|
| 26 |
+
self.conv5 = nn.Conv1d(128, 256, 1)
|
| 27 |
+
self.conv6 = nn.Conv1d(256, 256, 1)
|
| 28 |
+
self.conv7 = nn.Conv1d(256, 512, 1)
|
| 29 |
+
self.conv8 = nn.Conv1d(512, 512, 1)
|
| 30 |
+
self.conv9 = nn.Conv1d(512, 1024, 1)
|
| 31 |
+
self.conv10 = nn.Conv1d(1024, 1024, 1)
|
| 32 |
+
self.conv11 = nn.Conv1d(1024, 2048, 1)
|
| 33 |
+
|
| 34 |
+
# Residual connection layers
|
| 35 |
+
self.res_conv1 = nn.Conv1d(64, 128, 1)
|
| 36 |
+
self.res_conv2 = nn.Conv1d(128, 256, 1)
|
| 37 |
+
self.res_conv3 = nn.Conv1d(256, 512, 1)
|
| 38 |
+
self.res_conv4 = nn.Conv1d(512, 1024, 1)
|
| 39 |
+
|
| 40 |
+
# Self-attention mechanism
|
| 41 |
+
self.attention = nn.MultiheadAttention(embed_dim=2048, num_heads=8, batch_first=True)
|
| 42 |
+
self.attention_norm = nn.LayerNorm(2048)
|
| 43 |
+
|
| 44 |
+
# Multi-scale feature aggregation
|
| 45 |
+
self.scale_conv1 = nn.Conv1d(2048, 512, 1)
|
| 46 |
+
self.scale_conv2 = nn.Conv1d(2048, 512, 1)
|
| 47 |
+
self.scale_conv3 = nn.Conv1d(2048, 512, 1)
|
| 48 |
+
|
| 49 |
+
# Enhanced classification head with residual connections
|
| 50 |
+
self.fc1 = nn.Linear(4096, 2048) # Increased input due to multi-scale features
|
| 51 |
+
self.fc2 = nn.Linear(2048, 2048)
|
| 52 |
+
self.fc3 = nn.Linear(2048, 1024)
|
| 53 |
+
self.fc4 = nn.Linear(1024, 1024)
|
| 54 |
+
self.fc5 = nn.Linear(1024, 512)
|
| 55 |
+
self.fc6 = nn.Linear(512, 512)
|
| 56 |
+
self.fc7 = nn.Linear(512, 256)
|
| 57 |
+
self.fc8 = nn.Linear(256, 128)
|
| 58 |
+
self.fc9 = nn.Linear(128, 64)
|
| 59 |
+
self.fc10 = nn.Linear(64, 1)
|
| 60 |
+
|
| 61 |
+
# Residual connections for FC layers
|
| 62 |
+
self.fc_res1 = nn.Linear(2048, 1024)
|
| 63 |
+
self.fc_res2 = nn.Linear(1024, 512)
|
| 64 |
+
self.fc_res3 = nn.Linear(512, 256)
|
| 65 |
|
| 66 |
# Batch normalization layers
|
| 67 |
self.bn1 = nn.BatchNorm1d(64)
|
| 68 |
+
self.bn2 = nn.BatchNorm1d(64)
|
| 69 |
+
self.bn3 = nn.BatchNorm1d(128)
|
| 70 |
+
self.bn4 = nn.BatchNorm1d(128)
|
| 71 |
+
self.bn5 = nn.BatchNorm1d(256)
|
| 72 |
+
self.bn6 = nn.BatchNorm1d(256)
|
| 73 |
+
self.bn7 = nn.BatchNorm1d(512)
|
| 74 |
+
self.bn8 = nn.BatchNorm1d(512)
|
| 75 |
+
self.bn9 = nn.BatchNorm1d(1024)
|
| 76 |
+
self.bn10 = nn.BatchNorm1d(1024)
|
| 77 |
+
self.bn11 = nn.BatchNorm1d(2048)
|
| 78 |
+
|
| 79 |
+
# Scale batch norms
|
| 80 |
+
self.scale_bn1 = nn.BatchNorm1d(512)
|
| 81 |
+
self.scale_bn2 = nn.BatchNorm1d(512)
|
| 82 |
+
self.scale_bn3 = nn.BatchNorm1d(512)
|
| 83 |
+
|
| 84 |
+
# FC batch norms
|
| 85 |
+
self.fc_bn1 = nn.BatchNorm1d(2048)
|
| 86 |
+
self.fc_bn2 = nn.BatchNorm1d(2048)
|
| 87 |
+
self.fc_bn3 = nn.BatchNorm1d(1024)
|
| 88 |
+
self.fc_bn4 = nn.BatchNorm1d(1024)
|
| 89 |
+
self.fc_bn5 = nn.BatchNorm1d(512)
|
| 90 |
+
self.fc_bn6 = nn.BatchNorm1d(512)
|
| 91 |
+
self.fc_bn7 = nn.BatchNorm1d(256)
|
| 92 |
+
self.fc_bn8 = nn.BatchNorm1d(128)
|
| 93 |
+
|
| 94 |
+
# Dropout layers with varying rates
|
| 95 |
+
self.dropout1 = nn.Dropout(0.1)
|
| 96 |
+
self.dropout2 = nn.Dropout(0.2)
|
| 97 |
+
self.dropout3 = nn.Dropout(0.3)
|
| 98 |
self.dropout4 = nn.Dropout(0.4)
|
| 99 |
+
self.dropout5 = nn.Dropout(0.5)
|
| 100 |
+
self.dropout6 = nn.Dropout(0.4)
|
| 101 |
+
self.dropout7 = nn.Dropout(0.3)
|
| 102 |
+
self.dropout8 = nn.Dropout(0.2)
|
| 103 |
|
| 104 |
def forward(self, x):
|
| 105 |
"""
|
| 106 |
+
Forward pass with residual connections and attention
|
| 107 |
Args:
|
| 108 |
x: (batch_size, input_dim, max_points) tensor
|
| 109 |
Returns:
|
| 110 |
+
classification: (batch_size, 1) tensor of logits
|
| 111 |
"""
|
| 112 |
batch_size = x.size(0)
|
| 113 |
|
| 114 |
+
# Deep point-wise feature extraction with residual connections
|
| 115 |
x1 = F.relu(self.bn1(self.conv1(x)))
|
| 116 |
x2 = F.relu(self.bn2(self.conv2(x1)))
|
| 117 |
+
x2 = x2 + x1 # Residual connection
|
| 118 |
+
|
| 119 |
x3 = F.relu(self.bn3(self.conv3(x2)))
|
| 120 |
x4 = F.relu(self.bn4(self.conv4(x3)))
|
| 121 |
+
res1 = self.res_conv1(x2)
|
| 122 |
+
x4 = x4 + res1 # Residual connection
|
| 123 |
+
|
| 124 |
x5 = F.relu(self.bn5(self.conv5(x4)))
|
| 125 |
x6 = F.relu(self.bn6(self.conv6(x5)))
|
| 126 |
+
res2 = self.res_conv2(x4)
|
| 127 |
+
x6 = x6 + res2 # Residual connection
|
| 128 |
+
|
| 129 |
+
x7 = F.relu(self.bn7(self.conv7(x6)))
|
| 130 |
+
x8 = F.relu(self.bn8(self.conv8(x7)))
|
| 131 |
+
res3 = self.res_conv3(x6)
|
| 132 |
+
x8 = x8 + res3 # Residual connection
|
| 133 |
+
|
| 134 |
+
x9 = F.relu(self.bn9(self.conv9(x8)))
|
| 135 |
+
x10 = F.relu(self.bn10(self.conv10(x9)))
|
| 136 |
+
res4 = self.res_conv4(x8)
|
| 137 |
+
x10 = x10 + res4 # Residual connection
|
| 138 |
+
|
| 139 |
+
x11 = F.relu(self.bn11(self.conv11(x10)))
|
| 140 |
+
|
| 141 |
+
# Multi-scale global pooling
|
| 142 |
+
# Max pooling
|
| 143 |
+
global_max = torch.max(x11, 2)[0] # (batch_size, 2048)
|
| 144 |
|
| 145 |
+
# Average pooling
|
| 146 |
+
global_avg = torch.mean(x11, 2) # (batch_size, 2048)
|
| 147 |
|
| 148 |
+
# Attention-based pooling
|
| 149 |
+
x11_transposed = x11.transpose(1, 2) # (batch_size, max_points, 2048)
|
| 150 |
+
attended, _ = self.attention(x11_transposed, x11_transposed, x11_transposed)
|
| 151 |
+
attended = self.attention_norm(attended + x11_transposed)
|
| 152 |
+
global_att = torch.mean(attended, 1) # (batch_size, 2048)
|
| 153 |
+
|
| 154 |
+
# Multi-scale feature extraction
|
| 155 |
+
scale1 = F.relu(self.scale_bn1(self.scale_conv1(x11)))
|
| 156 |
+
scale1_pool = torch.max(scale1, 2)[0]
|
| 157 |
+
|
| 158 |
+
scale2 = F.relu(self.scale_bn2(self.scale_conv2(x11)))
|
| 159 |
+
scale2_pool = torch.mean(scale2, 2)
|
| 160 |
+
|
| 161 |
+
scale3 = F.relu(self.scale_bn3(self.scale_conv3(x11)))
|
| 162 |
+
scale3_pool = torch.std(scale3, 2)
|
| 163 |
+
|
| 164 |
+
# Concatenate all global features
|
| 165 |
+
global_features = torch.cat([
|
| 166 |
+
global_max, global_avg, global_att,
|
| 167 |
+
scale1_pool, scale2_pool, scale3_pool
|
| 168 |
+
], dim=1) # (batch_size, 4096)
|
| 169 |
+
|
| 170 |
+
# Enhanced classification head with residual connections
|
| 171 |
+
x = F.relu(self.fc_bn1(self.fc1(global_features)))
|
| 172 |
x = self.dropout1(x)
|
| 173 |
+
|
| 174 |
+
x = F.relu(self.fc_bn2(self.fc2(x)))
|
| 175 |
+
identity1 = x
|
| 176 |
x = self.dropout2(x)
|
| 177 |
+
|
| 178 |
+
x = F.relu(self.fc_bn3(self.fc3(x)))
|
| 179 |
x = self.dropout3(x)
|
| 180 |
+
|
| 181 |
+
x = F.relu(self.fc_bn4(self.fc4(x)))
|
| 182 |
+
res_fc1 = self.fc_res1(identity1)
|
| 183 |
+
x = x + res_fc1 # Residual connection
|
| 184 |
+
identity2 = x
|
| 185 |
x = self.dropout4(x)
|
| 186 |
+
|
| 187 |
+
x = F.relu(self.fc_bn5(self.fc5(x)))
|
| 188 |
x = self.dropout5(x)
|
| 189 |
+
|
| 190 |
+
x = F.relu(self.fc_bn6(self.fc6(x)))
|
| 191 |
+
res_fc2 = self.fc_res2(identity2)
|
| 192 |
+
x = x + res_fc2 # Residual connection
|
| 193 |
+
identity3 = x
|
| 194 |
+
x = self.dropout6(x)
|
| 195 |
+
|
| 196 |
+
x = F.relu(self.fc_bn7(self.fc7(x)))
|
| 197 |
+
x = self.dropout7(x)
|
| 198 |
+
|
| 199 |
+
x = F.relu(self.fc_bn8(self.fc8(x)))
|
| 200 |
+
res_fc3 = self.fc_res3(identity3)
|
| 201 |
+
x = x + res_fc3 # Residual connection
|
| 202 |
+
x = self.dropout8(x)
|
| 203 |
+
|
| 204 |
+
x = F.relu(self.fc9(x))
|
| 205 |
+
classification = self.fc10(x) # (batch_size, 1)
|
| 206 |
|
| 207 |
return classification
|
| 208 |
|
|
|
|
| 523 |
outputs = model(patch_tensor) # (1, 1)
|
| 524 |
probability = torch.sigmoid(outputs).item()
|
| 525 |
predicted_class = int(probability > 0.5)
|
|
|
|
| 526 |
|
| 527 |
+
return predicted_class, probability
|
train_pnet_class_cluster.py
CHANGED
|
@@ -5,9 +5,9 @@ if __name__ == "__main__":
|
|
| 5 |
|
| 6 |
# Load the dataset
|
| 7 |
dataset_path = "/mnt/personal/skvrnjan/hohocustom_edges/"
|
| 8 |
-
model_save_path = "/mnt/personal/skvrnjan/
|
| 9 |
|
| 10 |
os.makedirs(model_save_path, exist_ok=True)
|
| 11 |
|
| 12 |
# Train the model
|
| 13 |
-
train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=
|
|
|
|
# NOTE(review): these statements live under `if __name__ == "__main__":` in
# train_pnet_class_cluster.py (see the hunk header).

# Load the dataset
dataset_path = "/mnt/personal/skvrnjan/hohocustom_edges/"
model_save_path = "/mnt/personal/skvrnjan/hoho_pnet_edges_stronger/initial.pth"

# BUG FIX: model_save_path names a *file* (initial.pth). The original code
# called os.makedirs(model_save_path, ...), which creates a directory with
# that exact name and makes the later save to this path fail. Create only
# the parent directory instead.
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)

# Train the model
train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=128, lr=0.001)