DavidHanSZ committed on
Commit
9860511
·
verified ·
1 Parent(s): 5ea284b

Upload pointnet_modelnet40.py

Browse files
Files changed (1) hide show
  1. pointnet_modelnet40.py +391 -0
pointnet_modelnet40.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PointNet for ModelNet40 Classification
3
+ Based on: "PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation"
4
+ arxiv: 1612.00593, Appendix C
5
+
6
+ Training recipe exactly as described in the paper:
7
+ - 1024 points uniformly sampled, normalized to unit sphere
8
+ - Data augmentation: random rotation around up-axis + jitter (σ=0.02)
9
+ - Adam lr=0.001, batch size 32, lr divided by 2 every 20 epochs
10
+ - Batch-norm decay rate: starts at 0.5, increases to 0.99 (paper recipe; this script keeps PyTorch's default BatchNorm momentum)
11
+ - Dropout keep ratio 0.7 on last FC (256)
12
+ - Orthogonal regularization weight 0.001 on T-Net matrices
13
+ """
14
+
15
+ import os
16
+ import math
17
+ import json
18
+ import argparse
19
+ import numpy as np
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ import torch.nn.parallel
25
+ import torch.utils.data
26
+
27
+ import trackio
28
+ from datasets import load_dataset
29
+ from torch.utils.data import DataLoader, Dataset
30
+
31
+ # ============================================================
32
+ # PointNet Architecture
33
+ # ============================================================
34
+
35
class TNet(nn.Module):
    """Transformation network: a mini-PointNet that predicts a k×k matrix.

    Per-point features are produced by three 1×1 convolutions (64, 128 and
    1024 channels), max-pooled over the point axis, and regressed to
    ``k * k`` values by three fully connected layers.  The last layer starts
    with zero weights and an identity bias, so a freshly constructed network
    outputs the identity transform for every input.

    Args:
        k: Number of input channels and side length of the predicted matrix.
    """

    def __init__(self, k=3):
        super().__init__()
        self.k = k
        self.conv1 = nn.Conv1d(k, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)
        # Initialize the output as the identity matrix.  The parameters are
        # updated under no_grad instead of through ``.data`` so autograd's
        # bookkeeping is not bypassed.
        nn.init.zeros_(self.fc3.weight)
        with torch.no_grad():
            self.fc3.bias.copy_(torch.eye(k).flatten())

    def forward(self, x):
        """Predict one k×k matrix per sample.

        Args:
            x: Tensor of shape (B, k, N).

        Returns:
            Tensor of shape (B, k, k).
        """
        bs = x.size(0)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, dim=2, keepdim=False)[0]  # global max pool over points
        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)
        return x.view(bs, self.k, self.k)
65
+
66
+
67
class PointNetClassification(nn.Module):
    """PointNet for 3D object classification (ModelNet40).

    Pipeline: 3×3 input transform → shared MLP (64, 64) → 64×64 feature
    transform → shared MLP (64, 128, 1024) → global max pool → fully
    connected classifier (512, 256, num_classes) with dropout before the
    final layer.

    Args:
        num_classes: Number of output logits.
        dropout: Drop probability applied after the 256-unit layer.
    """

    def __init__(self, num_classes=40, dropout=0.3):
        super().__init__()
        self.num_classes = num_classes
        self.dropout = dropout

        # Input transform (3x3)
        self.input_transform = TNet(k=3)

        # Shared MLP after input transform
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.conv2 = nn.Conv1d(64, 64, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)

        # Feature transform (64x64)
        self.feature_transform = TNet(k=64)

        # Shared MLP after feature transform
        self.conv3 = nn.Conv1d(64, 64, 1)
        self.conv4 = nn.Conv1d(64, 128, 1)
        self.conv5 = nn.Conv1d(128, 1024, 1)
        self.bn3 = nn.BatchNorm1d(64)
        self.bn4 = nn.BatchNorm1d(128)
        self.bn5 = nn.BatchNorm1d(1024)

        # Classification head
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)
        self.bn6 = nn.BatchNorm1d(512)
        self.bn7 = nn.BatchNorm1d(256)

    def forward(self, x):
        """Classify a batch of point clouds.

        Args:
            x: Tensor of shape (B, 3, N).

        Returns:
            Tuple ``(logits, trans_3x3, trans_64x64)`` where ``logits`` has
            shape (B, num_classes) and the two transforms are the matrices
            predicted by the input and feature T-Nets; they are returned so
            the caller can regularize them.

        Raises:
            ValueError: If ``x`` is not a 3-D tensor with 3 channels.
        """
        # Fail early with a readable message; the most common mistake is
        # passing (B, N, 3), which otherwise surfaces as a conv shape error.
        if x.dim() != 3 or x.size(1) != 3:
            raise ValueError(
                f"expected input of shape (B, 3, N), got {tuple(x.shape)}"
            )

        # Input transform
        trans_3x3 = self.input_transform(x)
        x = torch.bmm(trans_3x3, x)  # apply transform

        # Shared MLP (64, 64)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))

        # Feature transform
        trans_64x64 = self.feature_transform(x)
        x = torch.bmm(trans_64x64, x)

        # Shared MLP (64, 128, 1024)
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))

        # Global max pooling → (B, 1024)
        x = torch.max(x, dim=2, keepdim=False)[0]

        # Classifier
        x = F.relu(self.bn6(self.fc1(x)))
        x = F.relu(self.bn7(self.fc2(x)))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.fc3(x)
        return x, trans_3x3, trans_64x64
131
+
132
+
133
+ # ============================================================
134
+ # Data Loading & Augmentation
135
+ # ============================================================
136
+
137
def augment_pointcloud(pc, train=True, jitter_sigma=0.02, up_axis=2):
    """Randomly rotate each cloud about one axis and add Gaussian jitter.

    Follows the augmentation described in Section 5.1 of the PointNet paper.

    Args:
        pc: Tensor of shape (B, N, 3).
        train: When false, ``pc`` is returned untouched.
        jitter_sigma: Standard deviation of the zero-mean Gaussian noise
            added to every coordinate.
        up_axis: Index (0, 1 or 2) of the coordinate axis to rotate about.
            NOTE(review): the default keeps the z-axis; confirm that z is
            really the up direction of the dataset in use.

    Returns:
        Tensor of shape (B, N, 3); one rotation angle is drawn per cloud.

    Raises:
        ValueError: If ``up_axis`` is not 0, 1 or 2.
    """
    if not train:
        return pc
    if up_axis not in (0, 1, 2):
        raise ValueError(f"up_axis must be 0, 1 or 2, got {up_axis}")
    batch_size, _, _ = pc.shape

    # 1. Random rotation around the up-axis, one angle per cloud.
    theta = torch.rand(batch_size, device=pc.device) * 2 * math.pi
    cos, sin = torch.cos(theta), torch.sin(theta)

    # The two in-plane axes follow ``up_axis`` cyclically, which yields a
    # right-handed rotation; for up_axis=2 this is the usual z-rotation
    # [[cos, -sin, 0], [sin, cos, 0], [0, 0, 1]].
    i, j = (up_axis + 1) % 3, (up_axis + 2) % 3
    rot = torch.zeros(batch_size, 3, 3, device=pc.device)
    rot[:, up_axis, up_axis] = 1.0
    rot[:, i, i] = cos
    rot[:, i, j] = -sin
    rot[:, j, i] = sin
    rot[:, j, j] = cos

    # Points are row vectors, so multiply by the transposed rotation.
    pc = torch.bmm(pc, rot.transpose(1, 2))

    # 2. Jitter with Gaussian noise.
    return pc + torch.randn_like(pc) * jitter_sigma
154
+
155
+
156
class ModelNet40Dataset(Dataset):
    """Wrap a HuggingFace ModelNet40 split as a PyTorch dataset.

    Every record must provide ``'inputs'`` (a sequence of xyz points) and
    ``'label'`` (an integer class id).

    Args:
        dataset: Indexable collection of records.
        num_points: Number of points returned per sample.
        train: When true, points are subsampled with NumPy's global random
            state, so every access draws a new subset.  When false, the
            subset is drawn from a generator seeded with the sample index,
            so evaluation sees the same points on every pass.
    """

    def __init__(self, dataset, num_points=1024, train=True):
        self.data = dataset
        self.num_points = num_points
        self.train = train

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(points, label)`` for record ``idx``.

        ``points`` is a float32 tensor of shape (3, num_points), centered on
        its centroid and scaled so the farthest point has unit norm;
        ``label`` is a scalar int64 tensor.
        """
        sample = self.data[idx]
        points = np.asarray(sample['inputs'], dtype=np.float32)  # (N, 3)

        # Subsample to num_points; draw with replacement only when the cloud
        # has fewer points than requested.
        n = points.shape[0]
        with_replacement = n < self.num_points
        if self.train:
            indices = np.random.choice(n, self.num_points, replace=with_replacement)
        else:
            # A per-sample seed keeps evaluation metrics comparable between
            # epochs; the modulo maps negative indices to a valid seed.
            rng = np.random.default_rng(int(idx) % len(self.data))
            indices = rng.choice(n, self.num_points, replace=with_replacement)
        points = points[indices]

        # Center and normalize to the unit sphere.
        points = points - points.mean(axis=0)
        max_norm = np.linalg.norm(points, axis=1).max()
        if max_norm > 0:  # all points coincide: leave them at the origin
            points = points / max_norm

        # Convert to (3, N) format for PointNet.
        points = torch.from_numpy(points).float().transpose(0, 1)
        label = torch.tensor(sample['label'], dtype=torch.long)
        return points, label
191
+
192
+
193
+ # ============================================================
194
+ # Training
195
+ # ============================================================
196
+
197
def orthogonality_loss(mat):
    """Penalize deviation of a batch of square matrices from orthogonality.

    Computes the Frobenius norm of ``A @ A^T - I`` for every matrix and
    averages over the batch.  Note that the norm is not squared here.

    Args:
        mat: Tensor of shape (B, k, k).

    Returns:
        Scalar tensor with the mean per-matrix norm.
    """
    k = mat.size(1)
    identity = torch.eye(k, device=mat.device)  # broadcasts over the batch
    gram = torch.bmm(mat, mat.transpose(1, 2))
    # matrix_norm defaults to the Frobenius norm over the last two dims and
    # replaces the deprecated torch.norm.
    return torch.linalg.matrix_norm(gram - identity).mean()
203
+
204
+
205
def train_epoch(model, loader, optimizer, device, orthogonal_weight=0.001):
    """Train ``model`` for one pass over ``loader``.

    Each batch is augmented (random rotation + jitter), then optimized on
    cross-entropy plus ``orthogonal_weight`` times the orthogonality penalty
    of both predicted transforms.

    Args:
        model: Module returning ``(logits, trans_3x3, trans_64x64)``.
        loader: Iterable of ``(points, labels)`` with points of shape
            (B, 3, N).
        optimizer: Optimizer over the model parameters.
        device: Device the batches are moved to.
        orthogonal_weight: Weight of the orthogonality regularizer.

    Returns:
        Tuple ``(mean_loss, accuracy)`` averaged over all seen samples; the
        loss includes the regularization term.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for points, labels in loader:
        points, labels = points.to(device), labels.to(device)
        bs = points.size(0)

        # The augmentation works on (B, N, 3); the model expects (B, 3, N).
        points = augment_pointcloud(points.transpose(1, 2).contiguous(), train=True)
        points = points.transpose(1, 2).contiguous()

        optimizer.zero_grad()

        logits, trans_3x3, trans_64x64 = model(points)

        # Classification loss
        cls_loss = F.cross_entropy(logits, labels)

        # Orthogonal regularization on both transforms
        ortho_loss = orthogonality_loss(trans_3x3) + orthogonality_loss(trans_64x64)
        loss = cls_loss + orthogonal_weight * ortho_loss

        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch average is
        # per-sample.
        total_loss += loss.item() * bs
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += bs

    if total == 0:
        raise ValueError("train_epoch received an empty loader; no samples to train on")
    return total_loss / total, correct / total
239
+
240
+
241
@torch.no_grad()
def evaluate(model, loader, device):
    """Evaluate ``model`` on ``loader`` without augmentation or gradients.

    Args:
        model: Module returning a 3-tuple whose first element is the logits.
        loader: Iterable of ``(points, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean_cross_entropy, accuracy)`` averaged over all samples.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    for points, labels in loader:
        points, labels = points.to(device), labels.to(device)
        bs = points.size(0)

        logits, _, _ = model(points)
        loss = F.cross_entropy(logits, labels)

        # Weight the batch-mean loss by batch size so the result is a
        # per-sample average even when the last batch is smaller.
        total_loss += loss.item() * bs
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += bs

    if total == 0:
        raise ValueError("evaluate received an empty loader; no samples to evaluate")
    return total_loss / total, correct / total
261
+
262
+
263
+ # ============================================================
264
+ # Main
265
+ # ============================================================
266
+
267
def main():
    """Train PointNet on ModelNet40 from the command line.

    Parses the CLI arguments, loads the dataset, trains for ``--epochs``
    epochs with Adam and a step LR schedule, logs metrics to trackio, keeps
    the checkpoint with the best test accuracy in ``--output_dir`` and, with
    ``--push_to_hub``, uploads the final weights and a config to the Hub.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=250)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--num_points', type=int, default=1024)
    parser.add_argument('--orthogonal_weight', type=float, default=0.001)
    parser.add_argument('--lr_decay_epochs', type=int, default=20)
    parser.add_argument('--dropout', type=float, default=0.3)
    parser.add_argument('--dataset', type=str, default='jxie/modelnet40-2048')
    parser.add_argument('--output_dir', type=str, default='./output')
    parser.add_argument('--push_to_hub', action='store_true')
    parser.add_argument('--hub_model_id', type=str, default=None)
    parser.add_argument('--num_workers', type=int, default=4)
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Initialize trackio
    trackio.init(
        project=os.environ.get("TRACKIO_PROJECT", "pointnet-modelnet40"),
        name=f"pointnet_lr{args.lr}_bs{args.batch_size}_pts{args.num_points}",
        config=vars(args),
    )

    # Load dataset
    print(f"Loading dataset: {args.dataset}")
    ds = load_dataset(args.dataset)
    train_ds = ModelNet40Dataset(ds['train'], num_points=args.num_points, train=True)
    test_ds = ModelNet40Dataset(ds['test'], num_points=args.num_points, train=False)

    # Pinned host memory only helps host-to-GPU copies; skip it on CPU.
    pin_memory = device.type == 'cuda'
    # drop_last avoids a trailing batch of size 1, which BatchNorm cannot
    # handle in training mode.
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=pin_memory,
                              drop_last=True)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False,
                             num_workers=args.num_workers, pin_memory=pin_memory)

    print(f"Train samples: {len(train_ds)}, Test samples: {len(test_ds)}")

    # Model
    model = PointNetClassification(num_classes=40, dropout=args.dropout).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {n_params:,}")

    # Optimizer: Adam as per paper
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 betas=(0.9, 0.999))  # "momentum 0.9" → β1=0.9

    # LR scheduler: divide by 2 every 20 epochs
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=args.lr_decay_epochs, gamma=0.5)

    best_acc = 0.0
    os.makedirs(args.output_dir, exist_ok=True)

    for epoch in range(1, args.epochs + 1):
        # Read the LR before stepping the scheduler so the logged value is
        # the one this epoch actually trained with.
        current_lr = optimizer.param_groups[0]['lr']

        train_loss, train_acc = train_epoch(model, train_loader, optimizer, device,
                                            orthogonal_weight=args.orthogonal_weight)
        test_loss, test_acc = evaluate(model, test_loader, device)
        scheduler.step()

        print(f"Epoch {epoch:3d} | LR: {current_lr:.6f} | "
              f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}% | "
              f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc*100:.2f}%")

        trackio.log({
            'train/loss': train_loss,
            'train/accuracy': train_acc,
            'test/loss': test_loss,
            'test/accuracy': test_acc,
            'lr': current_lr,
        }, step=epoch)

        if test_acc > best_acc:
            best_acc = test_acc
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'test_acc': test_acc,
                'args': vars(args),
            }
            torch.save(checkpoint, os.path.join(args.output_dir, 'best_model.pt'))
            print(f" ✓ New best model (acc: {test_acc*100:.2f}%)")

    print(f"\nTraining complete. Best test accuracy: {best_acc*100:.2f}%")
    trackio.log({'best/test_accuracy': best_acc}, step=args.epochs)
    trackio.finish()

    # Save final model in HF format.  These are the weights after the last
    # epoch; the best-accuracy checkpoint stays in best_model.pt.
    if args.push_to_hub:
        from huggingface_hub import HfApi
        hub_id = args.hub_model_id or "DavidHanSZ/pointnet-modelnet40"
        api = HfApi()
        # Create the repository when it is missing so the uploads below do
        # not fail after a complete training run.
        api.create_repo(repo_id=hub_id, repo_type='model', exist_ok=True)

        weights_path = os.path.join(args.output_dir, 'pytorch_model.bin')
        config_path = os.path.join(args.output_dir, 'config.json')

        # Save model with config
        torch.save(model.state_dict(), weights_path)

        config = {
            'architectures': ['PointNetClassification'],
            'num_classes': 40,
            'num_points': args.num_points,
            'dropout': args.dropout,
        }
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)

        api.upload_file(
            path_or_fileobj=weights_path,
            path_in_repo='pytorch_model.bin',
            repo_id=hub_id,
            repo_type='model',
        )
        api.upload_file(
            path_or_fileobj=config_path,
            path_in_repo='config.json',
            repo_id=hub_id,
            repo_type='model',
        )
        print(f"Model pushed to: https://huggingface.co/{hub_id}")
388
+
389
+
390
# Run training only when the file is executed as a script, not on import.
if __name__ == '__main__':
    main()