Spaces:
Build error
Build error
| from ultralytics import YOLO | |
# Wrapper that keeps a YOLO model together with the settings used to train it.
class YOLOTrainer:
    """Bundle a YOLO model with its training and validation settings.

    Args:
        model_config: Passed unchanged to ``YOLO(...)`` to build the model.
        data_config: Dataset configuration, forwarded as ``data=`` to both
            training and validation.
        batch_size: Forwarded to training as ``batch=``.
        img_size: Forwarded to training as ``imgsz=``.
        epochs: Forwarded to training as ``epochs=``.
        patience: Forwarded to training as ``patience=``.
    """

    def __init__(self, model_config, data_config, batch_size, img_size, epochs, patience):
        # Build the model from the given configuration and remember the
        # settings so train() / validate() can be called without arguments.
        self.model = YOLO(model_config)
        self.data_config = data_config
        self.batch_size = batch_size
        self.img_size = img_size
        self.epochs = epochs
        self.patience = patience

    def train(self):
        """Train the model with the stored settings.

        Returns:
            Whatever ``self.model.train(...)`` returns, so callers can inspect
            the outcome instead of having it silently discarded.
        """
        return self.model.train(
            data=self.data_config,
            batch=self.batch_size,
            imgsz=self.img_size,
            epochs=self.epochs,
            patience=self.patience,
        )

    def validate(self):
        """Validate the model on the stored dataset configuration.

        The dataset is passed explicitly so validation does not depend on
        state left behind by an earlier ``train()`` call.

        Returns:
            Whatever ``self.model.val(...)`` returns.
        """
        return self.model.val(data=self.data_config)
# Run the train / validate / save pipeline only when executed as a script
# (not when imported as a module).
if __name__ == "__main__":
    # Local import: only the script entry point needs it.
    from pathlib import Path

    # Model and dataset configuration files.
    model_config = 'yolov8m.yaml'
    data_config = 'dataset/data.yaml'

    # Training hyperparameters.
    batch_size = 16
    img_size = 640
    epochs = 100
    # Forwarded to training as ``patience=`` (early-stopping setting).
    patience = 20

    # Where the model is written once training and validation are done.
    save_path = 'model/best_model.pt'

    trainer = YOLOTrainer(model_config, data_config, batch_size, img_size, epochs, patience)
    trainer.train()
    trainer.validate()

    # Make sure the output directory exists first, so a long training run is
    # not lost to a missing-directory error at the very last step.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    trainer.model.save(save_path)