Spaces:
Running
Running
import os

import mlflow
import torch
from ultralytics import YOLO
# Compute device for training: CUDA GPU index 0 when CUDA is available,
# otherwise the string "cpu". train() passes it as YOLO's `device` argument.
device = "cpu"
if torch.cuda.is_available():
    device = 0
# Report which backend was selected (runs once, at import time).
print("CPU" if device == "cpu" else "GPU")
def _log_final_metrics(results):
    """Log the headline validation metrics of a finished YOLO run to MLflow.

    Args:
        results: The object returned by ``YOLO.train()``. It can be ``None``
            (Ultralytics only fills it in when validation metrics exist), in
            which case nothing is logged.

    Metrics missing from ``results.results_dict`` are reported and skipped
    rather than logged as 0, so an absent value is never mistaken for a real
    score in the MLflow UI.
    """
    metrics = getattr(results, "results_dict", None) or {}
    # MLflow metric name -> key in Ultralytics' results_dict.
    metric_keys = {
        "mAP50": "metrics/mAP50(B)",
        "mAP50-95": "metrics/mAP50-95(B)",
        "precision": "metrics/precision(B)",
        "recall": "metrics/recall(B)",
    }
    for mlflow_name, yolo_key in metric_keys.items():
        value = metrics.get(yolo_key)
        if value is None:
            print(f"Metric {yolo_key!r} not found in training results; not logged")
            continue
        # float() normalises numpy scalar types before handing them to MLflow.
        mlflow.log_metric(mlflow_name, float(value))


def _log_run_artifacts(save_dir):
    """Upload selected files from a YOLO output directory as MLflow artifacts.

    Args:
        save_dir: Directory Ultralytics wrote this run's outputs to.

    Every file is optional; a missing one is reported and skipped.
    """
    # (path relative to save_dir, MLflow artifact sub-directory)
    artifacts = (
        ("weights/best.pt", "model"),
        ("results.csv", "metrics"),
        ("labels.jpg", "plots"),
        ("confusion_matrix.png", "plots"),
    )
    for rel_path, artifact_dir in artifacts:
        local_path = os.path.join(save_dir, rel_path)
        if os.path.exists(local_path):
            mlflow.log_artifact(local_path, artifact_path=artifact_dir)
        else:
            print(f"Artifact not found, skipping: {local_path}")


def train():
    """Fine-tune a YOLOv8 model on the license-plate dataset, tracked in MLflow.

    Reads the dataset config from ``<project root>/data/raw/data.yaml``
    (project root = two directories above this file) and trains with the
    hyper-parameters in ``params``. Everything is logged to one run of the
    ``license-plate-detection`` MLflow experiment:

    - the parameters;
    - the final validation metrics;
    - selected output files (best weights, ``results.csv``, plots).

    Side effects: sets the ``MLFLOW_TRACKING_URI``, ``MLFLOW_EXPERIMENT_NAME``
    and ``MLFLOW_KEEP_RUN_ACTIVE`` environment variables for this process.
    """
    ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
    data_path = os.path.join(ROOT_DIR, "data/raw/data.yaml")

    # YOLO writes outputs to <project>/<name>. A relative `project` would be
    # resolved against the current working directory, so anchor it at the
    # project root. YOLO can still append a numeric suffix (e.g. "..._run2")
    # when the directory already exists, so the real path is read back from
    # the trainer after training; `output_dir` is only the fallback.
    project_name = "experiments"
    run_name = "yolov8s_768_v2_run"
    project_dir = os.path.join(ROOT_DIR, project_name)
    output_dir = os.path.join(project_dir, run_name)

    # MLflow setup
    tracking_uri = "sqlite:///mlflow.db"
    experiment_name = "license-plate-detection"
    mlflow.set_tracking_uri(tracking_uri)
    mlflow.set_experiment(experiment_name)
    # Ultralytics has its own MLflow callback (active when mlflow is installed
    # and enabled in its settings). In recent versions it re-points the
    # tracking URI to MLFLOW_TRACKING_URI (default: runs/mlflow) and ends the
    # active run when training finishes. The metrics/artifacts logged below
    # would then go to a different run in a different store. These variables
    # make the callback reuse this script's store, experiment and run.
    os.environ["MLFLOW_TRACKING_URI"] = tracking_uri
    os.environ["MLFLOW_EXPERIMENT_NAME"] = experiment_name
    os.environ["MLFLOW_KEEP_RUN_ACTIVE"] = "True"

    # Training config: every value here is both logged and passed to train().
    params = {
        "model": "yolov8s",
        "epochs": 40,
        "imgsz": 768,
        "batch": 6,
        "optimizer": "auto",
        "mosaic": 0.3,
        "patience": 10,
        "device": device,
    }

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(params)

        # Load the pretrained checkpoint matching the logged "model" param.
        model = YOLO(f"{params['model']}.pt")

        results = model.train(
            data=data_path,
            epochs=params["epochs"],
            imgsz=params["imgsz"],
            device=params["device"],
            batch=params["batch"],
            optimizer=params["optimizer"],
            mosaic=params["mosaic"],
            patience=params["patience"],
            cache=False,
            workers=0,
            project=project_dir,
            name=run_name,
        )

        _log_final_metrics(results)

        # Prefer the directory YOLO actually wrote to over the expected path.
        trainer = getattr(model, "trainer", None)
        save_dir = getattr(trainer, "save_dir", None) or output_dir
        _log_run_artifacts(os.fspath(save_dir))

    print("Training + MLflow logging completed")
# Script entry point: start training only when this file is executed
# directly, never as a side effect of importing it.
if __name__ == "__main__":
    train()