jskvrna committed on
Commit
5291a52
·
1 Parent(s): 075717a

Adds PointNet class for classification

Browse files

Implements a PointNet model for binary classification of 6D point cloud patches.

Includes a dataset class for loading and augmenting patch data, along with
helper functions for saving patches, creating a data loader with a custom
collate function, initializing weights, training the model, and loading the
trained model for prediction.

Files changed (2) hide show
  1. fast_pointnet_class.py +38 -160
  2. fast_pointnet_class_deeper.py +527 -0
fast_pointnet_class.py CHANGED
@@ -10,199 +10,77 @@ import json
10
 
11
  class ClassificationPointNet(nn.Module):
12
  """
13
- Enhanced PointNet implementation for binary classification from 6D point cloud patches.
14
  Takes 6D point clouds (x,y,z,r,g,b) and predicts binary classification (edge/not edge).
15
- Features: Residual connections, attention mechanism, multi-scale features, deeper architecture.
16
  """
17
  def __init__(self, input_dim=6, max_points=1024):
18
  super(ClassificationPointNet, self).__init__()
19
  self.max_points = max_points
20
 
21
- # Point-wise MLPs with residual connections (much deeper)
22
  self.conv1 = nn.Conv1d(input_dim, 64, 1)
23
- self.conv2 = nn.Conv1d(64, 64, 1)
24
- self.conv3 = nn.Conv1d(64, 128, 1)
25
- self.conv4 = nn.Conv1d(128, 128, 1)
26
- self.conv5 = nn.Conv1d(128, 256, 1)
27
- self.conv6 = nn.Conv1d(256, 256, 1)
28
- self.conv7 = nn.Conv1d(256, 512, 1)
29
- self.conv8 = nn.Conv1d(512, 512, 1)
30
- self.conv9 = nn.Conv1d(512, 1024, 1)
31
- self.conv10 = nn.Conv1d(1024, 1024, 1)
32
- self.conv11 = nn.Conv1d(1024, 2048, 1)
33
-
34
- # Residual connection layers
35
- self.res_conv1 = nn.Conv1d(64, 128, 1)
36
- self.res_conv2 = nn.Conv1d(128, 256, 1)
37
- self.res_conv3 = nn.Conv1d(256, 512, 1)
38
- self.res_conv4 = nn.Conv1d(512, 1024, 1)
39
-
40
- # Self-attention mechanism
41
- self.attention = nn.MultiheadAttention(embed_dim=2048, num_heads=8, batch_first=True)
42
- self.attention_norm = nn.LayerNorm(2048)
43
-
44
- # Multi-scale feature aggregation
45
- self.scale_conv1 = nn.Conv1d(2048, 512, 1)
46
- self.scale_conv2 = nn.Conv1d(2048, 512, 1)
47
- self.scale_conv3 = nn.Conv1d(2048, 512, 1)
48
-
49
- # Enhanced classification head with residual connections
50
- self.fc1 = nn.Linear(7680, 2048) # Updated input size: 2048*3 + 512*3 = 7680
51
- self.fc2 = nn.Linear(2048, 2048)
52
- self.fc3 = nn.Linear(2048, 1024)
53
- self.fc4 = nn.Linear(1024, 1024)
54
- self.fc5 = nn.Linear(1024, 512)
55
- self.fc6 = nn.Linear(512, 512)
56
- self.fc7 = nn.Linear(512, 256)
57
- self.fc8 = nn.Linear(256, 128)
58
- self.fc9 = nn.Linear(128, 64)
59
- self.fc10 = nn.Linear(64, 1)
60
-
61
- # Residual connections for FC layers
62
- self.fc_res1 = nn.Linear(2048, 1024)
63
- self.fc_res2 = nn.Linear(1024, 512)
64
- self.fc_res3 = nn.Linear(512, 128)
65
 
66
  # Batch normalization layers
67
  self.bn1 = nn.BatchNorm1d(64)
68
- self.bn2 = nn.BatchNorm1d(64)
69
- self.bn3 = nn.BatchNorm1d(128)
70
- self.bn4 = nn.BatchNorm1d(128)
71
- self.bn5 = nn.BatchNorm1d(256)
72
- self.bn6 = nn.BatchNorm1d(256)
73
- self.bn7 = nn.BatchNorm1d(512)
74
- self.bn8 = nn.BatchNorm1d(512)
75
- self.bn9 = nn.BatchNorm1d(1024)
76
- self.bn10 = nn.BatchNorm1d(1024)
77
- self.bn11 = nn.BatchNorm1d(2048)
78
-
79
- # Scale batch norms
80
- self.scale_bn1 = nn.BatchNorm1d(512)
81
- self.scale_bn2 = nn.BatchNorm1d(512)
82
- self.scale_bn3 = nn.BatchNorm1d(512)
83
-
84
- # FC batch norms
85
- self.fc_bn1 = nn.BatchNorm1d(2048)
86
- self.fc_bn2 = nn.BatchNorm1d(2048)
87
- self.fc_bn3 = nn.BatchNorm1d(1024)
88
- self.fc_bn4 = nn.BatchNorm1d(1024)
89
- self.fc_bn5 = nn.BatchNorm1d(512)
90
- self.fc_bn6 = nn.BatchNorm1d(512)
91
- self.fc_bn7 = nn.BatchNorm1d(256)
92
- self.fc_bn8 = nn.BatchNorm1d(128)
93
-
94
- # Dropout layers with varying rates
95
- self.dropout1 = nn.Dropout(0.1)
96
- self.dropout2 = nn.Dropout(0.2)
97
- self.dropout3 = nn.Dropout(0.3)
98
  self.dropout4 = nn.Dropout(0.4)
99
- self.dropout5 = nn.Dropout(0.5)
100
- self.dropout6 = nn.Dropout(0.4)
101
- self.dropout7 = nn.Dropout(0.3)
102
- self.dropout8 = nn.Dropout(0.2)
103
 
104
  def forward(self, x):
105
  """
106
- Forward pass with residual connections and attention
107
  Args:
108
  x: (batch_size, input_dim, max_points) tensor
109
  Returns:
110
- classification: (batch_size, 1) tensor of logits
111
  """
112
  batch_size = x.size(0)
113
 
114
- # Deep point-wise feature extraction with residual connections
115
  x1 = F.relu(self.bn1(self.conv1(x)))
116
  x2 = F.relu(self.bn2(self.conv2(x1)))
117
- x2 = x2 + x1 # Residual connection
118
-
119
  x3 = F.relu(self.bn3(self.conv3(x2)))
120
  x4 = F.relu(self.bn4(self.conv4(x3)))
121
- res1 = self.res_conv1(x2)
122
- x4 = x4 + res1 # Residual connection
123
-
124
  x5 = F.relu(self.bn5(self.conv5(x4)))
125
  x6 = F.relu(self.bn6(self.conv6(x5)))
126
- res2 = self.res_conv2(x4)
127
- x6 = x6 + res2 # Residual connection
128
-
129
- x7 = F.relu(self.bn7(self.conv7(x6)))
130
- x8 = F.relu(self.bn8(self.conv8(x7)))
131
- res3 = self.res_conv3(x6)
132
- x8 = x8 + res3 # Residual connection
133
-
134
- x9 = F.relu(self.bn9(self.conv9(x8)))
135
- x10 = F.relu(self.bn10(self.conv10(x9)))
136
- res4 = self.res_conv4(x8)
137
- x10 = x10 + res4 # Residual connection
138
-
139
- x11 = F.relu(self.bn11(self.conv11(x10)))
140
-
141
- # Multi-scale global pooling
142
- # Max pooling
143
- global_max = torch.max(x11, 2)[0] # (batch_size, 2048)
144
 
145
- # Average pooling
146
- global_avg = torch.mean(x11, 2) # (batch_size, 2048)
147
 
148
- # Attention-based pooling
149
- x11_transposed = x11.transpose(1, 2) # (batch_size, max_points, 2048)
150
- attended, _ = self.attention(x11_transposed, x11_transposed, x11_transposed)
151
- attended = self.attention_norm(attended + x11_transposed)
152
- global_att = torch.mean(attended, 1) # (batch_size, 2048)
153
-
154
- # Multi-scale feature extraction
155
- scale1 = F.relu(self.scale_bn1(self.scale_conv1(x11)))
156
- scale1_pool = torch.max(scale1, 2)[0]
157
-
158
- scale2 = F.relu(self.scale_bn2(self.scale_conv2(x11)))
159
- scale2_pool = torch.mean(scale2, 2)
160
-
161
- scale3 = F.relu(self.scale_bn3(self.scale_conv3(x11)))
162
- scale3_pool = torch.std(scale3, 2)
163
-
164
- # Concatenate all global features
165
- global_features = torch.cat([
166
- global_max, global_avg, global_att,
167
- scale1_pool, scale2_pool, scale3_pool
168
- ], dim=1) # (batch_size, 7680) = 2048*3 + 512*3
169
-
170
- # Enhanced classification head with residual connections
171
- x = F.relu(self.fc_bn1(self.fc1(global_features)))
172
  x = self.dropout1(x)
173
-
174
- x = F.relu(self.fc_bn2(self.fc2(x)))
175
- identity1 = x
176
  x = self.dropout2(x)
177
-
178
- x = F.relu(self.fc_bn3(self.fc3(x)))
179
  x = self.dropout3(x)
180
-
181
- x = F.relu(self.fc_bn4(self.fc4(x)))
182
- res_fc1 = self.fc_res1(identity1)
183
- x = x + res_fc1 # Residual connection
184
- identity2 = x
185
  x = self.dropout4(x)
186
-
187
- x = F.relu(self.fc_bn5(self.fc5(x)))
188
  x = self.dropout5(x)
189
-
190
- x = F.relu(self.fc_bn6(self.fc6(x)))
191
- res_fc2 = self.fc_res2(identity2)
192
- x = x + res_fc2 # Residual connection
193
- identity3 = x
194
- x = self.dropout6(x)
195
-
196
- x = F.relu(self.fc_bn7(self.fc7(x)))
197
- x = self.dropout7(x)
198
-
199
- x = F.relu(self.fc_bn8(self.fc8(x)))
200
- res_fc3 = self.fc_res3(identity3)
201
- x = x + res_fc3 # Residual connection
202
- x = self.dropout8(x)
203
-
204
- x = F.relu(self.fc9(x))
205
- classification = self.fc10(x) # (batch_size, 1)
206
 
207
  return classification
208
 
 
10
 
11
  class ClassificationPointNet(nn.Module):
12
  """
13
+ PointNet implementation for binary classification from 6D point cloud patches.
14
  Takes 6D point clouds (x,y,z,r,g,b) and predicts binary classification (edge/not edge).
 
15
  """
16
  def __init__(self, input_dim=6, max_points=1024):
17
  super(ClassificationPointNet, self).__init__()
18
  self.max_points = max_points
19
 
20
+ # Point-wise MLPs for feature extraction (deeper network)
21
  self.conv1 = nn.Conv1d(input_dim, 64, 1)
22
+ self.conv2 = nn.Conv1d(64, 128, 1)
23
+ self.conv3 = nn.Conv1d(128, 256, 1)
24
+ self.conv4 = nn.Conv1d(256, 512, 1)
25
+ self.conv5 = nn.Conv1d(512, 1024, 1)
26
+ self.conv6 = nn.Conv1d(1024, 2048, 1) # Additional layer
27
+
28
+ # Classification head (deeper with more capacity)
29
+ self.fc1 = nn.Linear(2048, 1024)
30
+ self.fc2 = nn.Linear(1024, 512)
31
+ self.fc3 = nn.Linear(512, 256)
32
+ self.fc4 = nn.Linear(256, 128)
33
+ self.fc5 = nn.Linear(128, 64)
34
+ self.fc6 = nn.Linear(64, 1) # Single output for binary classification
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  # Batch normalization layers
37
  self.bn1 = nn.BatchNorm1d(64)
38
+ self.bn2 = nn.BatchNorm1d(128)
39
+ self.bn3 = nn.BatchNorm1d(256)
40
+ self.bn4 = nn.BatchNorm1d(512)
41
+ self.bn5 = nn.BatchNorm1d(1024)
42
+ self.bn6 = nn.BatchNorm1d(2048)
43
+
44
+ # Dropout layers
45
+ self.dropout1 = nn.Dropout(0.3)
46
+ self.dropout2 = nn.Dropout(0.4)
47
+ self.dropout3 = nn.Dropout(0.5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  self.dropout4 = nn.Dropout(0.4)
49
+ self.dropout5 = nn.Dropout(0.3)
 
 
 
50
 
51
  def forward(self, x):
52
  """
53
+ Forward pass
54
  Args:
55
  x: (batch_size, input_dim, max_points) tensor
56
  Returns:
57
+ classification: (batch_size, 1) tensor of logits (sigmoid for probability)
58
  """
59
  batch_size = x.size(0)
60
 
61
+ # Point-wise feature extraction
62
  x1 = F.relu(self.bn1(self.conv1(x)))
63
  x2 = F.relu(self.bn2(self.conv2(x1)))
 
 
64
  x3 = F.relu(self.bn3(self.conv3(x2)))
65
  x4 = F.relu(self.bn4(self.conv4(x3)))
 
 
 
66
  x5 = F.relu(self.bn5(self.conv5(x4)))
67
  x6 = F.relu(self.bn6(self.conv6(x5)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
+ # Global max pooling
70
+ global_features = torch.max(x6, 2)[0] # (batch_size, 2048)
71
 
72
+ # Classification head
73
+ x = F.relu(self.fc1(global_features))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  x = self.dropout1(x)
75
+ x = F.relu(self.fc2(x))
 
 
76
  x = self.dropout2(x)
77
+ x = F.relu(self.fc3(x))
 
78
  x = self.dropout3(x)
79
+ x = F.relu(self.fc4(x))
 
 
 
 
80
  x = self.dropout4(x)
81
+ x = F.relu(self.fc5(x))
 
82
  x = self.dropout5(x)
83
+ classification = self.fc6(x) # (batch_size, 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  return classification
86
 
fast_pointnet_class_deeper.py ADDED
@@ -0,0 +1,527 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ import pickle
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from typing import List, Dict, Tuple, Optional
9
+ import json
10
+
11
+ class ClassificationPointNet(nn.Module):
12
+ """
13
+ Enhanced PointNet implementation for binary classification from 6D point cloud patches.
14
+ Takes 6D point clouds (x,y,z,r,g,b) and predicts binary classification (edge/not edge).
15
+ Features: Residual connections, attention mechanism, multi-scale features, deeper architecture.
16
+ """
17
+ def __init__(self, input_dim=6, max_points=1024):
18
+ super(ClassificationPointNet, self).__init__()
19
+ self.max_points = max_points
20
+
21
+ # Point-wise MLPs with residual connections (much deeper)
22
+ self.conv1 = nn.Conv1d(input_dim, 64, 1)
23
+ self.conv2 = nn.Conv1d(64, 64, 1)
24
+ self.conv3 = nn.Conv1d(64, 128, 1)
25
+ self.conv4 = nn.Conv1d(128, 128, 1)
26
+ self.conv5 = nn.Conv1d(128, 256, 1)
27
+ self.conv6 = nn.Conv1d(256, 256, 1)
28
+ self.conv7 = nn.Conv1d(256, 512, 1)
29
+ self.conv8 = nn.Conv1d(512, 512, 1)
30
+ self.conv9 = nn.Conv1d(512, 1024, 1)
31
+ self.conv10 = nn.Conv1d(1024, 1024, 1)
32
+ self.conv11 = nn.Conv1d(1024, 2048, 1)
33
+
34
+ # Residual connection layers
35
+ self.res_conv1 = nn.Conv1d(64, 128, 1)
36
+ self.res_conv2 = nn.Conv1d(128, 256, 1)
37
+ self.res_conv3 = nn.Conv1d(256, 512, 1)
38
+ self.res_conv4 = nn.Conv1d(512, 1024, 1)
39
+
40
+ # Self-attention mechanism
41
+ self.attention = nn.MultiheadAttention(embed_dim=2048, num_heads=8, batch_first=True)
42
+ self.attention_norm = nn.LayerNorm(2048)
43
+
44
+ # Multi-scale feature aggregation
45
+ self.scale_conv1 = nn.Conv1d(2048, 512, 1)
46
+ self.scale_conv2 = nn.Conv1d(2048, 512, 1)
47
+ self.scale_conv3 = nn.Conv1d(2048, 512, 1)
48
+
49
+ # Enhanced classification head with residual connections
50
+ self.fc1 = nn.Linear(7680, 2048) # Updated input size: 2048*3 + 512*3 = 7680
51
+ self.fc2 = nn.Linear(2048, 2048)
52
+ self.fc3 = nn.Linear(2048, 1024)
53
+ self.fc4 = nn.Linear(1024, 1024)
54
+ self.fc5 = nn.Linear(1024, 512)
55
+ self.fc6 = nn.Linear(512, 512)
56
+ self.fc7 = nn.Linear(512, 256)
57
+ self.fc8 = nn.Linear(256, 128)
58
+ self.fc9 = nn.Linear(128, 64)
59
+ self.fc10 = nn.Linear(64, 1)
60
+
61
+ # Residual connections for FC layers
62
+ self.fc_res1 = nn.Linear(2048, 1024)
63
+ self.fc_res2 = nn.Linear(1024, 512)
64
+ self.fc_res3 = nn.Linear(512, 128)
65
+
66
+ # Batch normalization layers
67
+ self.bn1 = nn.BatchNorm1d(64)
68
+ self.bn2 = nn.BatchNorm1d(64)
69
+ self.bn3 = nn.BatchNorm1d(128)
70
+ self.bn4 = nn.BatchNorm1d(128)
71
+ self.bn5 = nn.BatchNorm1d(256)
72
+ self.bn6 = nn.BatchNorm1d(256)
73
+ self.bn7 = nn.BatchNorm1d(512)
74
+ self.bn8 = nn.BatchNorm1d(512)
75
+ self.bn9 = nn.BatchNorm1d(1024)
76
+ self.bn10 = nn.BatchNorm1d(1024)
77
+ self.bn11 = nn.BatchNorm1d(2048)
78
+
79
+ # Scale batch norms
80
+ self.scale_bn1 = nn.BatchNorm1d(512)
81
+ self.scale_bn2 = nn.BatchNorm1d(512)
82
+ self.scale_bn3 = nn.BatchNorm1d(512)
83
+
84
+ # FC batch norms
85
+ self.fc_bn1 = nn.BatchNorm1d(2048)
86
+ self.fc_bn2 = nn.BatchNorm1d(2048)
87
+ self.fc_bn3 = nn.BatchNorm1d(1024)
88
+ self.fc_bn4 = nn.BatchNorm1d(1024)
89
+ self.fc_bn5 = nn.BatchNorm1d(512)
90
+ self.fc_bn6 = nn.BatchNorm1d(512)
91
+ self.fc_bn7 = nn.BatchNorm1d(256)
92
+ self.fc_bn8 = nn.BatchNorm1d(128)
93
+
94
+ # Dropout layers with varying rates
95
+ self.dropout1 = nn.Dropout(0.1)
96
+ self.dropout2 = nn.Dropout(0.2)
97
+ self.dropout3 = nn.Dropout(0.3)
98
+ self.dropout4 = nn.Dropout(0.4)
99
+ self.dropout5 = nn.Dropout(0.5)
100
+ self.dropout6 = nn.Dropout(0.4)
101
+ self.dropout7 = nn.Dropout(0.3)
102
+ self.dropout8 = nn.Dropout(0.2)
103
+
104
+ def forward(self, x):
105
+ """
106
+ Forward pass with residual connections and attention
107
+ Args:
108
+ x: (batch_size, input_dim, max_points) tensor
109
+ Returns:
110
+ classification: (batch_size, 1) tensor of logits
111
+ """
112
+ batch_size = x.size(0)
113
+
114
+ # Deep point-wise feature extraction with residual connections
115
+ x1 = F.relu(self.bn1(self.conv1(x)))
116
+ x2 = F.relu(self.bn2(self.conv2(x1)))
117
+ x2 = x2 + x1 # Residual connection
118
+
119
+ x3 = F.relu(self.bn3(self.conv3(x2)))
120
+ x4 = F.relu(self.bn4(self.conv4(x3)))
121
+ res1 = self.res_conv1(x2)
122
+ x4 = x4 + res1 # Residual connection
123
+
124
+ x5 = F.relu(self.bn5(self.conv5(x4)))
125
+ x6 = F.relu(self.bn6(self.conv6(x5)))
126
+ res2 = self.res_conv2(x4)
127
+ x6 = x6 + res2 # Residual connection
128
+
129
+ x7 = F.relu(self.bn7(self.conv7(x6)))
130
+ x8 = F.relu(self.bn8(self.conv8(x7)))
131
+ res3 = self.res_conv3(x6)
132
+ x8 = x8 + res3 # Residual connection
133
+
134
+ x9 = F.relu(self.bn9(self.conv9(x8)))
135
+ x10 = F.relu(self.bn10(self.conv10(x9)))
136
+ res4 = self.res_conv4(x8)
137
+ x10 = x10 + res4 # Residual connection
138
+
139
+ x11 = F.relu(self.bn11(self.conv11(x10)))
140
+
141
+ # Multi-scale global pooling
142
+ # Max pooling
143
+ global_max = torch.max(x11, 2)[0] # (batch_size, 2048)
144
+
145
+ # Average pooling
146
+ global_avg = torch.mean(x11, 2) # (batch_size, 2048)
147
+
148
+ # Attention-based pooling
149
+ x11_transposed = x11.transpose(1, 2) # (batch_size, max_points, 2048)
150
+ attended, _ = self.attention(x11_transposed, x11_transposed, x11_transposed)
151
+ attended = self.attention_norm(attended + x11_transposed)
152
+ global_att = torch.mean(attended, 1) # (batch_size, 2048)
153
+
154
+ # Multi-scale feature extraction
155
+ scale1 = F.relu(self.scale_bn1(self.scale_conv1(x11)))
156
+ scale1_pool = torch.max(scale1, 2)[0]
157
+
158
+ scale2 = F.relu(self.scale_bn2(self.scale_conv2(x11)))
159
+ scale2_pool = torch.mean(scale2, 2)
160
+
161
+ scale3 = F.relu(self.scale_bn3(self.scale_conv3(x11)))
162
+ scale3_pool = torch.std(scale3, 2)
163
+
164
+ # Concatenate all global features
165
+ global_features = torch.cat([
166
+ global_max, global_avg, global_att,
167
+ scale1_pool, scale2_pool, scale3_pool
168
+ ], dim=1) # (batch_size, 7680) = 2048*3 + 512*3
169
+
170
+ # Enhanced classification head with residual connections
171
+ x = F.relu(self.fc_bn1(self.fc1(global_features)))
172
+ x = self.dropout1(x)
173
+
174
+ x = F.relu(self.fc_bn2(self.fc2(x)))
175
+ identity1 = x
176
+ x = self.dropout2(x)
177
+
178
+ x = F.relu(self.fc_bn3(self.fc3(x)))
179
+ x = self.dropout3(x)
180
+
181
+ x = F.relu(self.fc_bn4(self.fc4(x)))
182
+ res_fc1 = self.fc_res1(identity1)
183
+ x = x + res_fc1 # Residual connection
184
+ identity2 = x
185
+ x = self.dropout4(x)
186
+
187
+ x = F.relu(self.fc_bn5(self.fc5(x)))
188
+ x = self.dropout5(x)
189
+
190
+ x = F.relu(self.fc_bn6(self.fc6(x)))
191
+ res_fc2 = self.fc_res2(identity2)
192
+ x = x + res_fc2 # Residual connection
193
+ identity3 = x
194
+ x = self.dropout6(x)
195
+
196
+ x = F.relu(self.fc_bn7(self.fc7(x)))
197
+ x = self.dropout7(x)
198
+
199
+ x = F.relu(self.fc_bn8(self.fc8(x)))
200
+ res_fc3 = self.fc_res3(identity3)
201
+ x = x + res_fc3 # Residual connection
202
+ x = self.dropout8(x)
203
+
204
+ x = F.relu(self.fc9(x))
205
+ classification = self.fc10(x) # (batch_size, 1)
206
+
207
+ return classification
208
+
209
+ class PatchClassificationDataset(Dataset):
210
+ """
211
+ Dataset class for loading saved patches for PointNet classification training.
212
+ """
213
+
214
+ def __init__(self, dataset_dir: str, max_points: int = 1024, augment: bool = True):
215
+ self.dataset_dir = dataset_dir
216
+ self.max_points = max_points
217
+ self.augment = augment
218
+
219
+ # Load patch files
220
+ self.patch_files = []
221
+ for file in os.listdir(dataset_dir):
222
+ if file.endswith('.pkl'):
223
+ self.patch_files.append(os.path.join(dataset_dir, file))
224
+
225
+ print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")
226
+
227
+ def __len__(self):
228
+ return len(self.patch_files)
229
+
230
+ def __getitem__(self, idx):
231
+ """
232
+ Load and process a patch for training.
233
+ Returns:
234
+ patch_data: (6, max_points) tensor of point cloud data
235
+ label: scalar tensor for binary classification (0 or 1)
236
+ valid_mask: (max_points,) boolean tensor indicating valid points
237
+ """
238
+ patch_file = self.patch_files[idx]
239
+
240
+ with open(patch_file, 'rb') as f:
241
+ patch_info = pickle.load(f)
242
+
243
+ patch_6d = patch_info['patch_6d'] # (N, 6)
244
+ label = patch_info.get('label', 0) # Get binary classification label (0 or 1)
245
+
246
+ # Pad or sample points to max_points
247
+ num_points = patch_6d.shape[0]
248
+
249
+ if num_points >= self.max_points:
250
+ # Randomly sample max_points
251
+ indices = np.random.choice(num_points, self.max_points, replace=False)
252
+ patch_sampled = patch_6d[indices]
253
+ valid_mask = np.ones(self.max_points, dtype=bool)
254
+ else:
255
+ # Pad with zeros
256
+ patch_sampled = np.zeros((self.max_points, 6))
257
+ patch_sampled[:num_points] = patch_6d
258
+ valid_mask = np.zeros(self.max_points, dtype=bool)
259
+ valid_mask[:num_points] = True
260
+
261
+ # Data augmentation
262
+ if self.augment:
263
+ patch_sampled = self._augment_patch(patch_sampled, valid_mask)
264
+
265
+ # Convert to tensors and transpose for conv1d (channels first)
266
+ patch_tensor = torch.from_numpy(patch_sampled.T).float() # (6, max_points)
267
+ label_tensor = torch.tensor(label, dtype=torch.float32) # Float for BCE loss
268
+ valid_mask_tensor = torch.from_numpy(valid_mask)
269
+
270
+ return patch_tensor, label_tensor, valid_mask_tensor
271
+
272
+ def _augment_patch(self, patch, valid_mask):
273
+ """
274
+ Apply data augmentation to the patch.
275
+ """
276
+ valid_points = patch[valid_mask]
277
+
278
+ if len(valid_points) == 0:
279
+ return patch
280
+
281
+ # Random rotation around z-axis
282
+ angle = np.random.uniform(0, 2 * np.pi)
283
+ cos_angle = np.cos(angle)
284
+ sin_angle = np.sin(angle)
285
+ rotation_matrix = np.array([
286
+ [cos_angle, -sin_angle, 0],
287
+ [sin_angle, cos_angle, 0],
288
+ [0, 0, 1]
289
+ ])
290
+
291
+ # Apply rotation to xyz coordinates
292
+ valid_points[:, :3] = valid_points[:, :3] @ rotation_matrix.T
293
+
294
+ # Random jittering
295
+ noise = np.random.normal(0, 0.01, valid_points[:, :3].shape)
296
+ valid_points[:, :3] += noise
297
+
298
+ # Random scaling
299
+ scale = np.random.uniform(0.9, 1.1)
300
+ valid_points[:, :3] *= scale
301
+
302
+ patch[valid_mask] = valid_points
303
+ return patch
304
+
305
+ def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
306
+ """
307
+ Save patches from prediction pipeline to create a training dataset.
308
+
309
+ Args:
310
+ patches: List of patch dictionaries from generate_patches()
311
+ dataset_dir: Directory to save the dataset
312
+ entry_id: Unique identifier for this entry/image
313
+ """
314
+ os.makedirs(dataset_dir, exist_ok=True)
315
+
316
+ for i, patch in enumerate(patches):
317
+ # Create unique filename
318
+ filename = f"{entry_id}_patch_{i}.pkl"
319
+ filepath = os.path.join(dataset_dir, filename)
320
+
321
+ # Skip if file already exists
322
+ if os.path.exists(filepath):
323
+ continue
324
+
325
+ # Save patch data
326
+ with open(filepath, 'wb') as f:
327
+ pickle.dump(patch, f)
328
+
329
+ print(f"Saved {len(patches)} patches for entry {entry_id}")
330
+
331
+ # Create dataloader with custom collate function to filter invalid samples
332
+ def collate_fn(batch):
333
+ valid_batch = []
334
+ for patch_data, label, valid_mask in batch:
335
+ # Filter out invalid samples (no valid points)
336
+ if valid_mask.sum() > 0:
337
+ valid_batch.append((patch_data, label, valid_mask))
338
+
339
+ if len(valid_batch) == 0:
340
+ return None
341
+
342
+ # Stack valid samples
343
+ patch_data = torch.stack([item[0] for item in valid_batch])
344
+ labels = torch.stack([item[1] for item in valid_batch])
345
+ valid_masks = torch.stack([item[2] for item in valid_batch])
346
+
347
+ return patch_data, labels, valid_masks
348
+
349
+ # Initialize weights using Xavier/Glorot initialization
350
+ def init_weights(m):
351
+ if isinstance(m, nn.Conv1d):
352
+ nn.init.xavier_uniform_(m.weight)
353
+ if m.bias is not None:
354
+ nn.init.zeros_(m.bias)
355
+ elif isinstance(m, nn.Linear):
356
+ nn.init.xavier_uniform_(m.weight)
357
+ if m.bias is not None:
358
+ nn.init.zeros_(m.bias)
359
+ elif isinstance(m, nn.BatchNorm1d):
360
+ nn.init.ones_(m.weight)
361
+ nn.init.zeros_(m.bias)
362
+
363
+ def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32,
364
+ lr: float = 0.001):
365
+ """
366
+ Train the ClassificationPointNet model on saved patches.
367
+
368
+ Args:
369
+ dataset_dir: Directory containing saved patch files
370
+ model_save_path: Path to save the trained model
371
+ epochs: Number of training epochs
372
+ batch_size: Training batch size
373
+ lr: Learning rate
374
+ """
375
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
376
+ print(f"Training on device: {device}")
377
+
378
+ # Create dataset and dataloader
379
+ dataset = PatchClassificationDataset(dataset_dir, max_points=1024, augment=True)
380
+ print(f"Dataset loaded with {len(dataset)} samples")
381
+
382
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8,
383
+ collate_fn=collate_fn, drop_last=True)
384
+
385
+ # Initialize model
386
+ model = ClassificationPointNet(input_dim=6, max_points=1024)
387
+ model.apply(init_weights)
388
+ model.to(device)
389
+
390
+ # Loss function and optimizer (BCE for binary classification)
391
+ criterion = nn.BCEWithLogitsLoss()
392
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
393
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
394
+
395
+ # Training loop
396
+ model.train()
397
+ for epoch in range(epochs):
398
+ total_loss = 0.0
399
+ correct = 0
400
+ total = 0
401
+ num_batches = 0
402
+
403
+ for batch_idx, batch_data in enumerate(dataloader):
404
+ if batch_data is None: # Skip invalid batches
405
+ continue
406
+
407
+ patch_data, labels, valid_masks = batch_data
408
+ patch_data = patch_data.to(device) # (batch_size, 6, max_points)
409
+ labels = labels.to(device).unsqueeze(1) # (batch_size, 1)
410
+
411
+ # Forward pass
412
+ optimizer.zero_grad()
413
+ outputs = model(patch_data) # (batch_size, 1)
414
+ loss = criterion(outputs, labels)
415
+
416
+ # Backward pass
417
+ loss.backward()
418
+ optimizer.step()
419
+
420
+ # Statistics
421
+ total_loss += loss.item()
422
+ predicted = (torch.sigmoid(outputs) > 0.5).float()
423
+ total += labels.size(0)
424
+ correct += (predicted == labels).sum().item()
425
+ num_batches += 1
426
+
427
+ if batch_idx % 50 == 0:
428
+ print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
429
+ f"Loss: {loss.item():.6f}, "
430
+ f"Accuracy: {100 * correct / total:.2f}%")
431
+
432
+ avg_loss = total_loss / num_batches if num_batches > 0 else 0
433
+ accuracy = 100 * correct / total if total > 0 else 0
434
+
435
+ print(f"Epoch {epoch+1}/{epochs} completed, "
436
+ f"Avg Loss: {avg_loss:.6f}, "
437
+ f"Accuracy: {accuracy:.2f}%")
438
+
439
+ scheduler.step()
440
+
441
+ # Save model checkpoint every epoch
442
+ checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
443
+ torch.save({
444
+ 'model_state_dict': model.state_dict(),
445
+ 'optimizer_state_dict': optimizer.state_dict(),
446
+ 'epoch': epoch + 1,
447
+ 'loss': avg_loss,
448
+ 'accuracy': accuracy,
449
+ }, checkpoint_path)
450
+
451
+ # Save the trained model
452
+ torch.save({
453
+ 'model_state_dict': model.state_dict(),
454
+ 'optimizer_state_dict': optimizer.state_dict(),
455
+ 'epoch': epochs,
456
+ }, model_save_path)
457
+
458
+ print(f"Model saved to {model_save_path}")
459
+ return model
460
+
461
+ def load_pointnet_model(model_path: str, device: torch.device = None) -> ClassificationPointNet:
462
+ """
463
+ Load a trained ClassificationPointNet model.
464
+
465
+ Args:
466
+ model_path: Path to the saved model
467
+ device: Device to load the model on
468
+
469
+ Returns:
470
+ Loaded ClassificationPointNet model
471
+ """
472
+ if device is None:
473
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
474
+
475
+ model = ClassificationPointNet(input_dim=6, max_points=1024)
476
+
477
+ checkpoint = torch.load(model_path, map_location=device)
478
+ model.load_state_dict(checkpoint['model_state_dict'])
479
+
480
+ model.to(device)
481
+ model.eval()
482
+
483
+ return model
484
+
485
+ def predict_class_from_patch(model: ClassificationPointNet, patch: Dict, device: torch.device = None) -> Tuple[int, float]:
486
+ """
487
+ Predict binary classification from a patch using trained PointNet.
488
+
489
+ Args:
490
+ model: Trained ClassificationPointNet model
491
+ patch: Dictionary containing patch data with 'patch_6d' key
492
+ device: Device to run prediction on
493
+
494
+ Returns:
495
+ tuple of (predicted_class, confidence)
496
+ predicted_class: int (0 for not edge, 1 for edge)
497
+ confidence: float representing confidence score (0-1)
498
+ """
499
+ if device is None:
500
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
501
+
502
+ patch_6d = patch['patch_6d'] # (N, 6)
503
+
504
+ # Prepare input
505
+ max_points = 1024
506
+ num_points = patch_6d.shape[0]
507
+
508
+ if num_points >= max_points:
509
+ # Sample points
510
+ indices = np.random.choice(num_points, max_points, replace=False)
511
+ patch_sampled = patch_6d[indices]
512
+ else:
513
+ # Pad with zeros
514
+ patch_sampled = np.zeros((max_points, 6))
515
+ patch_sampled[:num_points] = patch_6d
516
+
517
+ # Convert to tensor
518
+ patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0) # (1, 6, max_points)
519
+ patch_tensor = patch_tensor.to(device)
520
+
521
+ # Predict
522
+ with torch.no_grad():
523
+ outputs = model(patch_tensor) # (1, 1)
524
+ probability = torch.sigmoid(outputs).item()
525
+ predicted_class = int(probability > 0.5)
526
+
527
+ return predicted_class, probability