jskvrna commited on
Commit
e0d1066
·
1 Parent(s): a4de7b0

Improves training setup and parameters

Browse files

Adds directory creation for saving the model, ensuring the path exists before training starts.

Updates training parameters to adjust the balance between classification and scoring losses, and introduces a class weight parameter. These modifications are aimed at improving training performance and potentially addressing class imbalance issues.

Files changed (2) hide show
  1. train_pnet.py +3 -0
  2. train_pnet_cluster.py +1 -1
train_pnet.py CHANGED
@@ -1,4 +1,5 @@
1
  from fast_pointnet import train_pointnet
 
2
 
3
  if __name__ == "__main__":
4
 
@@ -6,5 +7,7 @@ if __name__ == "__main__":
6
  dataset_path = "/home/skvrnjan/personal/hohocustom/"
7
  model_save_path = "/home/skvrnjan/personal/hoho_pnet/"
8
 
 
 
9
  # Train the model
10
  train_pointnet(dataset_path, model_save_path)
 
1
  from fast_pointnet import train_pointnet
2
+ import os
3
 
4
  if __name__ == "__main__":
5
 
 
7
  dataset_path = "/home/skvrnjan/personal/hohocustom/"
8
  model_save_path = "/home/skvrnjan/personal/hoho_pnet/"
9
 
10
+ os.makedirs(model_save_path, exist_ok=True)
11
+
12
  # Train the model
13
  train_pointnet(dataset_path, model_save_path)
train_pnet_cluster.py CHANGED
@@ -7,4 +7,4 @@ if __name__ == "__main__":
7
  model_save_path = "/mnt/personal/skvrnjan/hoho_pnet/initial.pth"
8
 
9
  # Train the model
10
- train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=128, lr=0.001, score_weight=0.1)
 
7
  model_save_path = "/mnt/personal/skvrnjan/hoho_pnet/initial.pth"
8
 
9
  # Train the model
10
+ train_pointnet(dataset_path, model_save_path, epochs=100, batch_size=128, lr=0.001, score_weight=1.0, class_weight=0.5)