File size: 901 Bytes
33113fd
 
 
 
 
 
 
 
 
 
 
 
79c53b0
95c76d8
 
 
 
 
9518589
 
95c76d8
 
 
 
33113fd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# This script serves as the main entry point for training a PointNet-based
# classification model.
#
# It imports the necessary training function `train_pointnet` from the
# `fast_pointnet_class` module.
#
# The script defines file paths for the input dataset and the directory
# where the trained model will be saved. It ensures that the model saving
# directory exists before starting the training.
#
# Finally, it initiates the training process by calling the `train_pointnet`
# function with the specified dataset path, model save path, and a batch size.
from fast_pointnet_class import train_pointnet
import os

# Fallback values used when the corresponding command-line option is omitted.
# Either replace these placeholders or pass the options on the command line.
DEFAULT_DATASET_PATH = "<YOUR_DATASET_PATH_HERE>"
DEFAULT_MODEL_SAVE_PATH = "<YOUR_MODEL_SAVE_PATH_HERE>"
DEFAULT_BATCH_SIZE = 4


def parse_args(argv=None):
    """Parse and validate the command-line options for training.

    Args:
        argv: Argument list to parse (without the program name). ``None``
            means argparse reads ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``dataset_path``, ``model_save_path``
        and ``batch_size`` attributes.

    Raises:
        SystemExit: via ``parser.error`` (exit status 2) when a path is still
            an unreplaced ``<...>`` placeholder or the batch size is not a
            positive integer.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Train a PointNet classification model."
    )
    parser.add_argument(
        "--dataset-path",
        default=DEFAULT_DATASET_PATH,
        help="path to the input dataset (default: %(default)s)",
    )
    parser.add_argument(
        "--model-save-path",
        default=DEFAULT_MODEL_SAVE_PATH,
        help="directory where the trained model is written "
        "(created if missing; default: %(default)s)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=DEFAULT_BATCH_SIZE,
        help="training batch size (default: %(default)s)",
    )
    args = parser.parse_args(argv)

    # Fail fast instead of creating a literal "<...>" directory and handing a
    # bogus dataset path to the trainer.
    unset = [
        flag
        for flag, value in (
            ("--dataset-path", args.dataset_path),
            ("--model-save-path", args.model_save_path),
        )
        if value.startswith("<") and value.endswith(">")
    ]
    if unset:
        parser.error(
            "placeholder path(s) not replaced; set "
            + ", ".join(unset)
            + " or edit the defaults in this script"
        )
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")
    return args


def main(argv=None):
    """Ensure the output directory exists, then run PointNet training.

    Args:
        argv: Optional argument list forwarded to ``parse_args``.
    """
    args = parse_args(argv)

    # exist_ok=True makes re-running into an existing output directory harmless.
    os.makedirs(args.model_save_path, exist_ok=True)

    train_pointnet(
        args.dataset_path,
        args.model_save_path,
        batch_size=args.batch_size,
    )


if __name__ == "__main__":
    main()