PrashanthB461 commited on
Commit
4ea55e1
·
verified ·
1 Parent(s): 9cc7878

Update safety_analyzer/train_yolo.py

Browse files
Files changed (1) hide show
  1. safety_analyzer/train_yolo.py +17 -27
safety_analyzer/train_yolo.py CHANGED
@@ -1,48 +1,38 @@
1
  from ultralytics import YOLO
2
  import logging
3
- import torch
4
- import shutil
5
- import os
6
 
7
  # Setup logging
8
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
9
  logger = logging.getLogger(__name__)
10
 
11
- def train_model():
12
  try:
13
- # Load pretrained YOLOv8 model
14
  model = YOLO("yolov8n.pt")
15
- logger.info("Loaded pretrained model: yolov8n.pt")
16
 
17
  # Train the model
18
  model.train(
19
- data="data.yaml",
20
- epochs=50,
21
- imgsz=640,
22
- batch=16,
23
- device=0 if torch.cuda.is_available() else "cpu",
24
- name="safety_model",
25
- project="runs/train",
26
- patience=10
 
27
  )
28
- logger.info("Training completed. Model saved in runs/train/safety_model/weights/")
29
 
30
- # Validate the model
31
- model.val()
32
- logger.info("Validation completed.")
33
 
34
- # Copy the best model to the project root as yolov8_safety.pt
35
- best_model_path = "runs/train/safety_model/weights/best.pt"
36
- if os.path.exists(best_model_path):
37
- shutil.copy(best_model_path, "yolov8_safety.pt")
38
- logger.info("Model copied to yolov8_safety.pt")
39
- else:
40
- logger.error("Best model not found at expected path")
41
- raise FileNotFoundError("Best model not found")
42
  except Exception as e:
43
  logger.error(f"Error during training: {e}")
44
  raise
45
 
46
  if __name__ == "__main__":
47
  logger.info("Starting YOLOv8 model training...")
48
- train_model()
 
1
  from ultralytics import YOLO
2
  import logging
 
 
 
3
 
4
  # Setup logging
5
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
6
  logger = logging.getLogger(__name__)
7
 
8
+ def train_yolov8():
9
  try:
10
+ # Load a pretrained YOLOv8 model
11
  model = YOLO("yolov8n.pt")
12
+ logger.info("Loaded pretrained YOLOv8n model")
13
 
14
  # Train the model
15
  model.train(
16
+ data="path/to/data.yaml", # Path to your data.yaml file
17
+ epochs=100, # Number of epochs
18
+ imgsz=640, # Image size
19
+ batch=16, # Batch size
20
+ device=0, # Use GPU (0) if available
21
+ patience=50, # Early stopping patience
22
+ project="runs/train", # Output directory
23
+ name="safety_model", # Experiment name
24
+ exist_ok=True
25
  )
26
+ logger.info("Model training completed")
27
 
28
+ # Save the trained model
29
+ model.save("yolov8_safety.pt")
30
+ logger.info("Saved trained model as yolov8_safety.pt")
31
 
 
 
 
 
 
 
 
 
32
  except Exception as e:
33
  logger.error(f"Error during training: {e}")
34
  raise
35
 
36
  if __name__ == "__main__":
37
  logger.info("Starting YOLOv8 model training...")
38
+ train_yolov8()