# Panel_Fault_DC / train_model.py
# Added by Nawinkumar15 in commit 3c0ad87 ("Create train_model.py", verified).
import os
from ultralytics import YOLO
# --- Configuration ---
# Path to your data.yaml file.
# This must be a raw string: in a normal literal the "\U" of "C:\Users" starts a
# \UXXXXXXXX Unicode escape, which makes the whole file fail to parse.
DATA_YAML_PATH = r"C:\Users\Veera\Downloads\Yolov8n-env\data.yaml"
# Choose a pre-trained YOLOv8 model to start with.
PRETRAINED_MODEL = "yolov8n.pt"
# Training parameters
EPOCHS = 100  # Number of training epochs. Adjust based on your dataset size and desired accuracy.
IMG_SIZE = 640  # Image size for training (as per Roboflow preprocessing).
BATCH_SIZE = 16  # Reduced batch size since you don't have a GPU. Lower it further if you hit memory issues.
PROJECT_NAME = "Model_trained"  # Root directory for results; passed to train() as `project=`.
RUN_NAME = "best.pt"  # Run sub-directory inside PROJECT_NAME (passed as `name=`); a directory, not the weights file.
# --- Main Training Logic ---
def _best_weights_path(model):
    """Return the path of the best.pt written by the most recent ``model.train()`` call.

    Prefers the location recorded by the Ultralytics trainer (``model.trainer.best``),
    because results go to ``<PROJECT_NAME>/<RUN_NAME>/`` (not ``runs/detect/...``) and
    Ultralytics may append a numeric suffix to the run directory if it already exists.
    Falls back to ``<PROJECT_NAME>/<RUN_NAME>/weights/best.pt`` when the trainer does
    not expose that attribute.
    """
    trainer = getattr(model, "trainer", None)
    best = getattr(trainer, "best", None)
    if best:
        return os.fspath(best)
    return os.path.join(PROJECT_NAME, RUN_NAME, "weights", "best.pt")


def train_yolov8_model():
    """Train a YOLOv8 model on the dataset described by DATA_YAML_PATH.

    Loads PRETRAINED_MODEL and fine-tunes it with EPOCHS, IMG_SIZE and BATCH_SIZE.
    Results are written under ``<PROJECT_NAME>/<RUN_NAME>/``, with the best weights
    in ``weights/best.pt``. Failures are reported on stdout instead of being raised.

    Returns:
        The absolute path of the trained ``best.pt`` when it exists after training.
        Otherwise None: missing data.yaml, model load failure, training error, or
        weights not found.
    """
    print(f"Starting YOLOv8 model training with {PRETRAINED_MODEL}...")

    # 1. Check that data.yaml exists first. This is cheap and fails fast, before the
    #    model is loaded (which may require a download).
    if not os.path.isfile(DATA_YAML_PATH):
        print(f"Error: data.yaml not found at '{DATA_YAML_PATH}'.")
        print("Please ensure the 'data.yaml' file is in the correct location.")
        return None

    # 2. Load a pre-trained YOLOv8 model.
    try:
        model = YOLO(PRETRAINED_MODEL)
    except Exception as e:  # script boundary: report and stop instead of crashing
        print(f"Error loading pre-trained model: {e}")
        print("Please ensure you have an active internet connection if downloading for the first time.")
        return None
    print(f"Successfully loaded pre-trained model: {PRETRAINED_MODEL}")

    # 3. Train the model.
    print(f"Training model on dataset defined in: {DATA_YAML_PATH}")
    print(f"Training for {EPOCHS} epochs with image size {IMG_SIZE} and batch size {BATCH_SIZE} on CPU...")
    print("Training on CPU will be significantly slower.")
    try:
        model.train(
            data=DATA_YAML_PATH,
            epochs=EPOCHS,
            imgsz=IMG_SIZE,
            batch=BATCH_SIZE,
            project=PROJECT_NAME,
            name=RUN_NAME,
        )
    except Exception as e:  # script boundary: report and stop instead of crashing
        print(f"An error occurred during training: {e}")
        print("Common issues: insufficient CPU memory (try reducing batch_size), incorrect data.yaml paths.")
        return None
    print("\nTraining completed successfully!")

    # 4. Locate the weights where the trainer actually saved them.
    best_pt_path = _best_weights_path(model)
    if not os.path.exists(best_pt_path):
        print("Warning: 'best.pt' file not found in the expected location after training.")
        print(f"Please check the output directory: {os.path.abspath(os.path.dirname(best_pt_path))}")
        return None

    best_pt_path = os.path.abspath(best_pt_path)
    print(f"Your trained model (best.pt) is saved at: {best_pt_path}")
    print("You can now use this .pt file for local inference or upload it to Hugging Face.")
    return best_pt_path
def _main():
    """Script entry point: launch a single training run."""
    train_yolov8_model()


# Only start training when this file is executed directly, never on import.
if __name__ == "__main__":
    _main()