# DamageLensAI/src/training/train_yolo.py
# Source: uploaded by junaid17 in commit eef8873 ("Upload 43 files"), 2.24 kB.
import logging
from shutil import copy2, rmtree
from ultralytics import YOLO
from src.config import BASE_DIR, CHECKPOINT_DIR, DEVICE
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Dataset YAML handed to YOLO training as its `data` argument.
YOLO_DATASET_CONFIG = BASE_DIR.joinpath("data", "yolo", "dataset_custom.yaml")
# Weights file the YOLO model is loaded from before training starts.
YOLO_BASE_MODEL = CHECKPOINT_DIR.joinpath("yolo11m.pt")
def run_yolo_training(*, epochs=1, imgsz=416, batch=4):
    """Train the YOLO base model and save its best weights.

    Loads ``YOLO_BASE_MODEL``, trains it on ``YOLO_DATASET_CONFIG`` inside a
    temporary run directory under ``CHECKPOINT_DIR``, copies the resulting
    ``weights/best.pt`` to ``CHECKPOINT_DIR / "damage_detector.pt"`` and
    removes the temporary run directory again. The temporary directory is
    removed whether or not training succeeds.

    Args:
        epochs: Number of epochs passed to ``model.train``.
        imgsz: Image size passed to ``model.train``.
        batch: Batch size passed to ``model.train``.

    Returns:
        Path of the saved ``damage_detector.pt`` file.

    Raises:
        FileNotFoundError: If the dataset config, the base model, or the
            ``best.pt`` file expected after training does not exist.
    """
    logger.info("Initializing YOLO training pipeline...")

    if not YOLO_DATASET_CONFIG.exists():
        raise FileNotFoundError(
            f"YOLO dataset config not found: {YOLO_DATASET_CONFIG}"
        )
    if not YOLO_BASE_MODEL.exists():
        raise FileNotFoundError(
            f"YOLO base model not found: {YOLO_BASE_MODEL}"
        )

    # Device index 0 is used when DEVICE is "cuda"; anything else runs on CPU.
    yolo_device = 0 if DEVICE == "cuda" else "cpu"
    checkpoint_root = CHECKPOINT_DIR.resolve()
    temp_run_name = "temp_yolo_run"
    temp_run_dir = checkpoint_root / temp_run_name
    final_model_path = checkpoint_root / "damage_detector.pt"

    # exist_ok=True lets training write into an already existing run
    # directory, so clear leftovers from an earlier aborted run first;
    # otherwise a stale best.pt could satisfy the existence check below.
    if temp_run_dir.exists():
        rmtree(temp_run_dir)

    try:
        logger.info("Loading YOLO base model...")
        model = YOLO(str(YOLO_BASE_MODEL.resolve()))

        logger.info("Starting YOLO training...")
        model.train(
            data=str(YOLO_DATASET_CONFIG.resolve()),
            imgsz=imgsz,
            batch=batch,
            epochs=epochs,
            device=yolo_device,
            project=str(checkpoint_root),
            name=temp_run_name,
            exist_ok=True,
        )

        best_model_path = temp_run_dir / "weights" / "best.pt"
        if not best_model_path.exists():
            raise FileNotFoundError(
                f"YOLO best model not found: {best_model_path}"
            )

        copy2(best_model_path, final_model_path)
        logger.info("Final YOLO model saved at: %s", final_model_path)
    finally:
        # Clean up the temporary training folder on failure as well, so
        # aborted runs do not leave artifacts behind in the checkpoint dir.
        if temp_run_dir.exists():
            rmtree(temp_run_dir)
            logger.info("Temporary YOLO training artifacts deleted.")

    return final_model_path
if __name__ == "__main__":
    # Configure logging only when this file is executed as a script.
    log_format = "%(asctime)s - %(levelname)s - %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)

    # Run the full training pipeline and report where the model was saved.
    model_path = run_yolo_training()

    print("\nYOLO training completed successfully.")
    print(f"Saved model: {model_path}")