pluto90's picture
files upload
50386b1
import os
from ultralytics import YOLO
import torch
import mlflow
# Training device handed to Ultralytics: CUDA device index 0 when a GPU is
# visible to torch, otherwise the string "cpu".
device = 0 if torch.cuda.is_available() else "cpu"

# Announce which backend was selected.
print("CPU" if device == "cpu" else "GPU")
def train():
    """Train YOLOv8s on the dataset in ``data/raw/data.yaml`` and log the run to MLflow.

    The function:

    1. resolves the project root (two directories above this file) and the
       dataset YAML beneath it;
    2. points MLflow at the ``sqlite:///mlflow.db`` tracking store and the
       ``license-plate-detection`` experiment;
    3. trains the model with the hyper-parameters in ``params``;
    4. logs the final validation metrics that the trainer reported;
    5. uploads the best weights, the results CSV and the diagnostic plots
       from the directory the trainer actually wrote to.

    Relies on the module-level ``device`` value for device selection.
    Returns ``None``.
    """
    # Project root
    ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
    data_path = os.path.join(ROOT_DIR, "data/raw/data.yaml")

    # Output directory. The project path is made absolute so YOLO writes under
    # the project root no matter which directory the script is launched from.
    project_name = "experiments"
    run_name = "yolov8s_768_v2_run"
    project_dir = os.path.join(ROOT_DIR, project_name)
    expected_output_dir = os.path.join(project_dir, run_name)

    # MLflow Setup
    # NOTE: this SQLite URI is relative, so the database file location depends
    # on the current working directory.
    mlflow.set_tracking_uri("sqlite:///mlflow.db")
    mlflow.set_experiment("license-plate-detection")

    # Training Config
    params = {
        "model": "yolov8s",
        "epochs": 40,
        "imgsz": 768,
        "batch": 6,
        "optimizer": "auto",
        "mosaic": 0.3,
        "device": device,
    }

    # Start MLflow run
    with mlflow.start_run(run_name=run_name):
        # log parameters
        mlflow.log_params(params)

        # load model (weights file name derived from the logged model name)
        model = YOLO(f"{params['model']}.pt")

        # train -- every logged hyper-parameter is passed explicitly so the
        # MLflow record matches what was actually used.
        results = model.train(
            data=data_path,
            epochs=params["epochs"],
            imgsz=params["imgsz"],
            device=params["device"],
            batch=params["batch"],
            optimizer=params["optimizer"],
            cache=False,
            workers=0,
            patience=10,
            mosaic=params["mosaic"],
            project=project_dir,
            name=run_name,
        )

        # log metrics -- only those the trainer really produced; a missing
        # metric is skipped instead of being recorded as a misleading 0.
        metrics = getattr(results, "results_dict", None) or {}
        metric_keys = {
            "mAP50": "metrics/mAP50(B)",
            "mAP50-95": "metrics/mAP50-95(B)",
            "precision": "metrics/precision(B)",
            "recall": "metrics/recall(B)",
        }
        for mlflow_name, yolo_key in metric_keys.items():
            value = metrics.get(yolo_key)
            if value is not None:
                mlflow.log_metric(mlflow_name, float(value))

        # log artifacts
        # -------------
        # Ask the trainer where it saved its output: the run folder may carry
        # a numeric suffix when a previous run with the same name exists.
        # Fall back to the expected location if the attribute is unavailable.
        trainer = getattr(model, "trainer", None)
        save_dir = getattr(trainer, "save_dir", None) or expected_output_dir
        save_dir = str(save_dir)

        # (path relative to the run directory, MLflow artifact sub-folder)
        artifacts = (
            ("weights/best.pt", "model"),         # 1. Best model
            ("results.csv", "metrics"),           # 2. Training results csv
            ("labels.jpg", "plots"),              # 3. Labels plot (if generated)
            ("confusion_matrix.png", "plots"),    # 4. Confusion matrix (if generated)
        )
        for relative_path, artifact_path in artifacts:
            local_path = os.path.join(save_dir, relative_path)
            if os.path.exists(local_path):
                mlflow.log_artifact(local_path, artifact_path=artifact_path)

    print("Training + MLflow logging completed")
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    train()