# This script is the main entry point for training a PointNet-based
# classification model.
#
# It imports the training function `train_pointnet` from the
# `fast_pointnet_class` module.
#
# The script defines file paths for the input dataset and for the directory
# where the trained model will be saved, and it ensures that the model save
# directory exists before training starts.
#
# Finally, it starts training by calling `train_pointnet` with the dataset
# path, the model save path, and a batch size.
| from fast_pointnet_class import train_pointnet | |
| import os | |
| if __name__ == "__main__": | |
| # Load the dataset | |
| dataset_path = "<YOUR_DATASET_PATH_HERE>" | |
| model_save_path = "<YOUR_MODEL_SAVE_PATH_HERE>" | |
| os.makedirs(model_save_path, exist_ok=True) | |
| # Train the model | |
| train_pointnet(dataset_path, model_save_path, batch_size=4) |