# File: hoho / train_pnet_class.py
# Repository page header: jskvrna, "Final submission code", commit 9518589, 901 Bytes
# This script serves as the main entry point for training a PointNet-based
# classification model.
#
# It imports the necessary training function `train_pointnet` from the
# `fast_pointnet_class` module.
#
# The script defines file paths for the input dataset and the directory
# where the trained model will be saved. It ensures that the model saving
# directory exists before starting the training.
#
# Finally, it initiates the training process by calling the `train_pointnet`
# function with the specified dataset path, model save path, and a batch size.
from fast_pointnet_class import train_pointnet
import os
def main(dataset_path, model_save_path, batch_size):
    """Create the model output directory and start PointNet training.

    Args:
        dataset_path: Path of the input dataset; forwarded unchanged as the
            first positional argument of `train_pointnet`.
        model_save_path: Directory for the trained model; created (including
            missing parents) if it does not exist, then forwarded as the
            second positional argument of `train_pointnet`.
        batch_size: Forwarded to `train_pointnet` as the `batch_size` keyword.
    """
    # exist_ok=True lets the script be re-run against an existing directory.
    os.makedirs(model_save_path, exist_ok=True)

    # Train the model
    train_pointnet(dataset_path, model_save_path, batch_size=batch_size)


if __name__ == "__main__":
    # Placeholder paths: replace both values with real locations before
    # running. Nothing is loaded here; the paths are only handed to main().
    dataset_path = "<YOUR_DATASET_PATH_HERE>"
    model_save_path = "<YOUR_MODEL_SAVE_PATH_HERE>"

    main(dataset_path, model_save_path, batch_size=4)