# Source: safety_analyzer/train_yolo.py
# Last change: "Update safety_analyzer/train_yolo.py" by PrashanthB461 (commit 4ea55e1, verified)
from ultralytics import YOLO
import logging
# Configure root logging for the script: INFO level and above, timestamped lines.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Logger for this module, named after the module itself.
logger = logging.getLogger(__name__)
def _default_device():
    """Return GPU index 0 when CUDA is available, otherwise ``"cpu"``."""
    # Imported lazily: only needed when the caller did not choose a device.
    # torch is a hard dependency of ultralytics, so it is installed alongside it.
    import torch

    return 0 if torch.cuda.is_available() else "cpu"


def train_yolov8(
    data="path/to/data.yaml",
    *,
    weights="yolov8n.pt",
    epochs=100,
    imgsz=640,
    batch=16,
    device=None,
    patience=50,
    project="runs/train",
    name="safety_model",
    output_path="yolov8_safety.pt",
):
    """Fine-tune a pretrained YOLOv8 model on a dataset and save the result.

    Every argument defaults to the value this script previously hard-coded,
    so a bare ``train_yolov8()`` call keeps working. The one exception is
    ``device``, which now falls back to CPU instead of failing when no GPU
    is present.

    Args:
        data: Path to the dataset YAML passed to ``model.train``. The default
            is a placeholder and must be replaced with a real dataset file.
        weights: Pretrained checkpoint to start training from.
        epochs: Number of training epochs.
        imgsz: Training image size.
        batch: Batch size.
        device: Device passed to ``model.train``. ``None`` (the default)
            uses GPU 0 when CUDA is available and CPU otherwise. Pass an
            explicit value (e.g. ``0`` or ``"cpu"``) to force a device.
        patience: Early-stopping patience, in epochs.
        project: Output directory for training runs.
        name: Experiment name under ``project``. Because ``exist_ok=True``,
            an existing directory with this name is reused.
        output_path: File the trained model is saved to after training.

    Returns:
        The ``YOLO`` model object after training.

    Raises:
        Exception: Any error raised while loading, training or saving is
            logged with its traceback and then re-raised unchanged.
    """
    try:
        if device is None:
            device = _default_device()
        logger.info("Using training device: %s", device)

        # Start from pretrained weights rather than training from scratch.
        model = YOLO(weights)
        logger.info("Loaded pretrained YOLOv8 model from %s", weights)

        model.train(
            data=data,
            epochs=epochs,
            imgsz=imgsz,
            batch=batch,
            device=device,
            patience=patience,
            project=project,
            name=name,
            exist_ok=True,  # reuse <project>/<name> instead of creating name2, name3, ...
        )
        logger.info("Model training completed")

        # NOTE(review): recent ultralytics releases reload the best checkpoint
        # into `model` once training ends. The raw checkpoints also live under
        # <project>/<name>/weights/. Confirm for the pinned ultralytics version.
        model.save(output_path)
        logger.info("Saved trained model as %s", output_path)
        return model
    except Exception as e:
        # Top-level boundary: record the full traceback, then let it propagate.
        logger.exception("Error during training: %s", e)
        raise
if __name__ == "__main__":
    # Executed as a script (not imported): announce the run, then train.
    logger.info("Starting YOLOv8 model training...")
    train_yolov8()