jskvrna committed on
Commit
81255ac
·
1 Parent(s): 5d8d206

Improves PointNet classification model.

Browse files

This commit introduces a significantly enhanced PointNet
implementation for binary classification of 6D point cloud
patches.

Key improvements include:

- A deeper architecture with residual connections for enhanced
feature extraction.
- An attention mechanism for improved feature weighting.
- Multi-scale feature aggregation for richer representations.
- An enhanced classification head with residual connections and varying dropout rates for better generalization.

Additionally, the training batch size was reduced to 128, and the model save path was updated to reflect the new, improved model.

Files changed (2) hide show
  1. fast_pointnet_class.py +161 -40
  2. train_pnet_class_cluster.py +2 -2
fast_pointnet_class.py CHANGED
@@ -10,77 +10,199 @@ import json
10
 
11
  class ClassificationPointNet(nn.Module):
12
  """
13
- PointNet implementation for binary classification from 6D point cloud patches.
14
  Takes 6D point clouds (x,y,z,r,g,b) and predicts binary classification (edge/not edge).
 
15
  """
16
  def __init__(self, input_dim=6, max_points=1024):
17
  super(ClassificationPointNet, self).__init__()
18
  self.max_points = max_points
19
 
20
- # Point-wise MLPs for feature extraction (deeper network)
21
  self.conv1 = nn.Conv1d(input_dim, 64, 1)
22
- self.conv2 = nn.Conv1d(64, 128, 1)
23
- self.conv3 = nn.Conv1d(128, 256, 1)
24
- self.conv4 = nn.Conv1d(256, 512, 1)
25
- self.conv5 = nn.Conv1d(512, 1024, 1)
26
- self.conv6 = nn.Conv1d(1024, 2048, 1) # Additional layer
27
-
28
- # Classification head (deeper with more capacity)
29
- self.fc1 = nn.Linear(2048, 1024)
30
- self.fc2 = nn.Linear(1024, 512)
31
- self.fc3 = nn.Linear(512, 256)
32
- self.fc4 = nn.Linear(256, 128)
33
- self.fc5 = nn.Linear(128, 64)
34
- self.fc6 = nn.Linear(64, 1) # Single output for binary classification
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  # Batch normalization layers
37
  self.bn1 = nn.BatchNorm1d(64)
38
- self.bn2 = nn.BatchNorm1d(128)
39
- self.bn3 = nn.BatchNorm1d(256)
40
- self.bn4 = nn.BatchNorm1d(512)
41
- self.bn5 = nn.BatchNorm1d(1024)
42
- self.bn6 = nn.BatchNorm1d(2048)
43
-
44
- # Dropout layers
45
- self.dropout1 = nn.Dropout(0.3)
46
- self.dropout2 = nn.Dropout(0.4)
47
- self.dropout3 = nn.Dropout(0.5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  self.dropout4 = nn.Dropout(0.4)
49
- self.dropout5 = nn.Dropout(0.3)
 
 
 
50
 
51
  def forward(self, x):
52
  """
53
- Forward pass
54
  Args:
55
  x: (batch_size, input_dim, max_points) tensor
56
  Returns:
57
- classification: (batch_size, 1) tensor of logits (sigmoid for probability)
58
  """
59
  batch_size = x.size(0)
60
 
61
- # Point-wise feature extraction
62
  x1 = F.relu(self.bn1(self.conv1(x)))
63
  x2 = F.relu(self.bn2(self.conv2(x1)))
 
 
64
  x3 = F.relu(self.bn3(self.conv3(x2)))
65
  x4 = F.relu(self.bn4(self.conv4(x3)))
 
 
 
66
  x5 = F.relu(self.bn5(self.conv5(x4)))
67
  x6 = F.relu(self.bn6(self.conv6(x5)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- # Global max pooling
70
- global_features = torch.max(x6, 2)[0] # (batch_size, 2048)
71
 
72
- # Classification head
73
- x = F.relu(self.fc1(global_features))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  x = self.dropout1(x)
75
- x = F.relu(self.fc2(x))
 
 
76
  x = self.dropout2(x)
77
- x = F.relu(self.fc3(x))
 
78
  x = self.dropout3(x)
79
- x = F.relu(self.fc4(x))
 
 
 
 
80
  x = self.dropout4(x)
81
- x = F.relu(self.fc5(x))
 
82
  x = self.dropout5(x)
83
- classification = self.fc6(x) # (batch_size, 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  return classification
86
 
@@ -401,6 +523,5 @@ def predict_class_from_patch(model: ClassificationPointNet, patch: Dict, device:
401
  outputs = model(patch_tensor) # (1, 1)
402
  probability = torch.sigmoid(outputs).item()
403
  predicted_class = int(probability > 0.5)
404
- confidence = probability if predicted_class == 1 else (1 - probability)
405
 
406
- return predicted_class, confidence
 
10
 
11
  class ClassificationPointNet(nn.Module):
12
  """
13
+ Enhanced PointNet implementation for binary classification from 6D point cloud patches.
14
  Takes 6D point clouds (x,y,z,r,g,b) and predicts binary classification (edge/not edge).
15
+ Features: Residual connections, attention mechanism, multi-scale features, deeper architecture.
16
  """
17
  def __init__(self, input_dim=6, max_points=1024):
18
  super(ClassificationPointNet, self).__init__()
19
  self.max_points = max_points
20
 
21
+ # Point-wise MLPs with residual connections (much deeper)
22
  self.conv1 = nn.Conv1d(input_dim, 64, 1)
23
+ self.conv2 = nn.Conv1d(64, 64, 1)
24
+ self.conv3 = nn.Conv1d(64, 128, 1)
25
+ self.conv4 = nn.Conv1d(128, 128, 1)
26
+ self.conv5 = nn.Conv1d(128, 256, 1)
27
+ self.conv6 = nn.Conv1d(256, 256, 1)
28
+ self.conv7 = nn.Conv1d(256, 512, 1)
29
+ self.conv8 = nn.Conv1d(512, 512, 1)
30
+ self.conv9 = nn.Conv1d(512, 1024, 1)
31
+ self.conv10 = nn.Conv1d(1024, 1024, 1)
32
+ self.conv11 = nn.Conv1d(1024, 2048, 1)
33
+
34
+ # Residual connection layers
35
+ self.res_conv1 = nn.Conv1d(64, 128, 1)
36
+ self.res_conv2 = nn.Conv1d(128, 256, 1)
37
+ self.res_conv3 = nn.Conv1d(256, 512, 1)
38
+ self.res_conv4 = nn.Conv1d(512, 1024, 1)
39
+
40
+ # Self-attention mechanism
41
+ self.attention = nn.MultiheadAttention(embed_dim=2048, num_heads=8, batch_first=True)
42
+ self.attention_norm = nn.LayerNorm(2048)
43
+
44
+ # Multi-scale feature aggregation
45
+ self.scale_conv1 = nn.Conv1d(2048, 512, 1)
46
+ self.scale_conv2 = nn.Conv1d(2048, 512, 1)
47
+ self.scale_conv3 = nn.Conv1d(2048, 512, 1)
48
+
49
+ # Enhanced classification head with residual connections
50
+ self.fc1 = nn.Linear(4096, 2048) # Increased input due to multi-scale features
51
+ self.fc2 = nn.Linear(2048, 2048)
52
+ self.fc3 = nn.Linear(2048, 1024)
53
+ self.fc4 = nn.Linear(1024, 1024)
54
+ self.fc5 = nn.Linear(1024, 512)
55
+ self.fc6 = nn.Linear(512, 512)
56
+ self.fc7 = nn.Linear(512, 256)
57
+ self.fc8 = nn.Linear(256, 128)
58
+ self.fc9 = nn.Linear(128, 64)
59
+ self.fc10 = nn.Linear(64, 1)
60
+
61
+ # Residual connections for FC layers
62
+ self.fc_res1 = nn.Linear(2048, 1024)
63
+ self.fc_res2 = nn.Linear(1024, 512)
64
+ self.fc_res3 = nn.Linear(512, 256)
65
 
66
  # Batch normalization layers
67
  self.bn1 = nn.BatchNorm1d(64)
68
+ self.bn2 = nn.BatchNorm1d(64)
69
+ self.bn3 = nn.BatchNorm1d(128)
70
+ self.bn4 = nn.BatchNorm1d(128)
71
+ self.bn5 = nn.BatchNorm1d(256)
72
+ self.bn6 = nn.BatchNorm1d(256)
73
+ self.bn7 = nn.BatchNorm1d(512)
74
+ self.bn8 = nn.BatchNorm1d(512)
75
+ self.bn9 = nn.BatchNorm1d(1024)
76
+ self.bn10 = nn.BatchNorm1d(1024)
77
+ self.bn11 = nn.BatchNorm1d(2048)
78
+
79
+ # Scale batch norms
80
+ self.scale_bn1 = nn.BatchNorm1d(512)
81
+ self.scale_bn2 = nn.BatchNorm1d(512)
82
+ self.scale_bn3 = nn.BatchNorm1d(512)
83
+
84
+ # FC batch norms
85
+ self.fc_bn1 = nn.BatchNorm1d(2048)
86
+ self.fc_bn2 = nn.BatchNorm1d(2048)
87
+ self.fc_bn3 = nn.BatchNorm1d(1024)
88
+ self.fc_bn4 = nn.BatchNorm1d(1024)
89
+ self.fc_bn5 = nn.BatchNorm1d(512)
90
+ self.fc_bn6 = nn.BatchNorm1d(512)
91
+ self.fc_bn7 = nn.BatchNorm1d(256)
92
+ self.fc_bn8 = nn.BatchNorm1d(128)
93
+
94
+ # Dropout layers with varying rates
95
+ self.dropout1 = nn.Dropout(0.1)
96
+ self.dropout2 = nn.Dropout(0.2)
97
+ self.dropout3 = nn.Dropout(0.3)
98
  self.dropout4 = nn.Dropout(0.4)
99
+ self.dropout5 = nn.Dropout(0.5)
100
+ self.dropout6 = nn.Dropout(0.4)
101
+ self.dropout7 = nn.Dropout(0.3)
102
+ self.dropout8 = nn.Dropout(0.2)
103
 
104
  def forward(self, x):
105
  """
106
+ Forward pass with residual connections and attention
107
  Args:
108
  x: (batch_size, input_dim, max_points) tensor
109
  Returns:
110
+ classification: (batch_size, 1) tensor of logits
111
  """
112
  batch_size = x.size(0)
113
 
114
+ # Deep point-wise feature extraction with residual connections
115
  x1 = F.relu(self.bn1(self.conv1(x)))
116
  x2 = F.relu(self.bn2(self.conv2(x1)))
117
+ x2 = x2 + x1 # Residual connection
118
+
119
  x3 = F.relu(self.bn3(self.conv3(x2)))
120
  x4 = F.relu(self.bn4(self.conv4(x3)))
121
+ res1 = self.res_conv1(x2)
122
+ x4 = x4 + res1 # Residual connection
123
+
124
  x5 = F.relu(self.bn5(self.conv5(x4)))
125
  x6 = F.relu(self.bn6(self.conv6(x5)))
126
+ res2 = self.res_conv2(x4)
127
+ x6 = x6 + res2 # Residual connection
128
+
129
+ x7 = F.relu(self.bn7(self.conv7(x6)))
130
+ x8 = F.relu(self.bn8(self.conv8(x7)))
131
+ res3 = self.res_conv3(x6)
132
+ x8 = x8 + res3 # Residual connection
133
+
134
+ x9 = F.relu(self.bn9(self.conv9(x8)))
135
+ x10 = F.relu(self.bn10(self.conv10(x9)))
136
+ res4 = self.res_conv4(x8)
137
+ x10 = x10 + res4 # Residual connection
138
+
139
+ x11 = F.relu(self.bn11(self.conv11(x10)))
140
+
141
+ # Multi-scale global pooling
142
+ # Max pooling
143
+ global_max = torch.max(x11, 2)[0] # (batch_size, 2048)
144
 
145
+ # Average pooling
146
+ global_avg = torch.mean(x11, 2) # (batch_size, 2048)
147
 
148
+ # Attention-based pooling
149
+ x11_transposed = x11.transpose(1, 2) # (batch_size, max_points, 2048)
150
+ attended, _ = self.attention(x11_transposed, x11_transposed, x11_transposed)
151
+ attended = self.attention_norm(attended + x11_transposed)
152
+ global_att = torch.mean(attended, 1) # (batch_size, 2048)
153
+
154
+ # Multi-scale feature extraction
155
+ scale1 = F.relu(self.scale_bn1(self.scale_conv1(x11)))
156
+ scale1_pool = torch.max(scale1, 2)[0]
157
+
158
+ scale2 = F.relu(self.scale_bn2(self.scale_conv2(x11)))
159
+ scale2_pool = torch.mean(scale2, 2)
160
+
161
+ scale3 = F.relu(self.scale_bn3(self.scale_conv3(x11)))
162
+ scale3_pool = torch.std(scale3, 2)
163
+
164
+ # Concatenate all global features
165
+ global_features = torch.cat([
166
+ global_max, global_avg, global_att,
167
+ scale1_pool, scale2_pool, scale3_pool
168
+ ], dim=1) # (batch_size, 4096)
169
+
170
+ # Enhanced classification head with residual connections
171
+ x = F.relu(self.fc_bn1(self.fc1(global_features)))
172
  x = self.dropout1(x)
173
+
174
+ x = F.relu(self.fc_bn2(self.fc2(x)))
175
+ identity1 = x
176
  x = self.dropout2(x)
177
+
178
+ x = F.relu(self.fc_bn3(self.fc3(x)))
179
  x = self.dropout3(x)
180
+
181
+ x = F.relu(self.fc_bn4(self.fc4(x)))
182
+ res_fc1 = self.fc_res1(identity1)
183
+ x = x + res_fc1 # Residual connection
184
+ identity2 = x
185
  x = self.dropout4(x)
186
+
187
+ x = F.relu(self.fc_bn5(self.fc5(x)))
188
  x = self.dropout5(x)
189
+
190
+ x = F.relu(self.fc_bn6(self.fc6(x)))
191
+ res_fc2 = self.fc_res2(identity2)
192
+ x = x + res_fc2 # Residual connection
193
+ identity3 = x
194
+ x = self.dropout6(x)
195
+
196
+ x = F.relu(self.fc_bn7(self.fc7(x)))
197
+ x = self.dropout7(x)
198
+
199
+ x = F.relu(self.fc_bn8(self.fc8(x)))
200
+ res_fc3 = self.fc_res3(identity3)
201
+ x = x + res_fc3 # Residual connection
202
+ x = self.dropout8(x)
203
+
204
+ x = F.relu(self.fc9(x))
205
+ classification = self.fc10(x) # (batch_size, 1)
206
 
207
  return classification
208
 
 
523
  outputs = model(patch_tensor) # (1, 1)
524
  probability = torch.sigmoid(outputs).item()
525
  predicted_class = int(probability > 0.5)
 
526
 
527
+ return predicted_class, probability
train_pnet_class_cluster.py CHANGED
@@ -5,9 +5,9 @@ if __name__ == "__main__":
5
 
6
  # Load the dataset
7
  dataset_path = "/mnt/personal/skvrnjan/hohocustom_edges/"
8
- model_save_path = "/mnt/personal/skvrnjan/hoho_pnet_edges_v2/initial.pth"
9
 
10
  os.makedirs(model_save_path, exist_ok=True)
11
 
12
  # Train the model
13
- train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=512, lr=0.001)
 
5
 
6
  # Load the dataset
7
  dataset_path = "/mnt/personal/skvrnjan/hohocustom_edges/"
8
+ model_save_path = "/mnt/personal/skvrnjan/hoho_pnet_edges_stronger/initial.pth"
9
 
10
  os.makedirs(model_save_path, exist_ok=True)
11
 
12
  # Train the model
13
+ train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=128, lr=0.001)