PrashanthB461 commited on
Commit
153c5f0
·
verified ·
1 Parent(s): 2055b5c

Create train_yolo.py

Browse files
Files changed (1) hide show
  1. safety_analyzer/train_yolo.py +50 -0
safety_analyzer/train_yolo.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ```python
2
+ from ultralytics import YOLO
3
+ import logging
4
+ import torch
5
+ import shutil
6
+ import os
7
+
8
+ # Setup logging
9
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
10
+ logger = logging.getLogger(__name__)
11
+
12
+ def train_model():
13
+ try:
14
+ # Load pretrained YOLOv8 model
15
+ model = YOLO("yolov8n.pt")
16
+ logger.info("Loaded pretrained model: yolov8n.pt")
17
+
18
+ # Train the model
19
+ model.train(
20
+ data="data.yaml",
21
+ epochs=50,
22
+ imgsz=640,
23
+ batch=16,
24
+ device=0 if torch.cuda.is_available() else "cpu",
25
+ name="safety_model",
26
+ project="runs/train",
27
+ patience=10
28
+ )
29
+ logger.info("Training completed. Model saved in runs/train/safety_model/weights/")
30
+
31
+ # Validate the model
32
+ model.val()
33
+ logger.info("Validation completed.")
34
+
35
+ # Copy the best model to the project root as yolov8_safety.pt
36
+ best_model_path = "runs/train/safety_model/weights/best.pt"
37
+ if os.path.exists(best_model_path):
38
+ shutil.copy(best_model_path, "yolov8_safety.pt")
39
+ logger.info("Model copied to yolov8_safety.pt")
40
+ else:
41
+ logger.error("Best model not found at expected path")
42
+ raise FileNotFoundError("Best model not found")
43
+ except Exception as e:
44
+ logger.error(f"Error during training: {e}")
45
+ raise
46
+
47
+ if __name__ == "__main__":
48
+ logger.info("Starting YOLOv8 model training...")
49
+ train_model()
50
+ ```