jskvrna commited on
Commit
95c76d8
·
1 Parent(s): e36bb23

Adds training scripts for PointNet classification

Browse files

Adds scripts to train PointNet models for classification tasks.

These scripts initialize training with specific dataset and model save paths,
along with customizable hyperparameters.

Files changed (2) hide show
  1. train_pnet_class.py +13 -0
  2. train_pnet_class_cluster.py +13 -0
train_pnet_class.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fast_pointnet import train_pointnet
2
+ import os
3
+
4
+ if __name__ == "__main__":
5
+
6
+ # Load the dataset
7
+ dataset_path = "/home/skvrnjan/personal/hohocustom_edges/"
8
+ model_save_path = "/home/skvrnjan/personal/hoho_pnet_edges/"
9
+
10
+ os.makedirs(model_save_path, exist_ok=True)
11
+
12
+ # Train the model
13
+ train_pointnet(dataset_path, model_save_path)
train_pnet_class_cluster.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fast_pointnet import train_pointnet
2
+ import os
3
+
4
+ if __name__ == "__main__":
5
+
6
+ # Load the dataset
7
+ dataset_path = "/mnt/personal/skvrnjan/hohocustom_edges/"
8
+ model_save_path = "/mnt/personal/skvrnjan/hoho_pnet_edges/initial.pth"
9
+
10
+ os.makedirs(model_save_path, exist_ok=True)
11
+
12
+ # Train the model
13
+ train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=128, lr=0.001, score_weight=1.0, class_weight=0.5)