# arm-model/model/yolo_detection.py
# Uploaded by pragadeeshv23 via huggingface_hub (commit 5b86813, verified)
#!/usr/bin/env python3
"""
YOLOv11 Road Anomaly Detection Module
Based on: Nature Scientific Reports (Nov 2025) - YOLOv11 + CNN-BiGRU
Handles spatial detection of road anomalies:
- Pothole, Alligator Crack, Longitudinal Crack, Transverse Crack
Optimised for NVIDIA RTX 2050 (4 GB VRAM) + i5-12450H.
Image size defaults to 416 to fit within 4 GB VRAM budget.
"""
import os
import sys
import time
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple, Union
import cv2
import numpy as np
import torch
from ultralytics import YOLO
# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
)
logger = logging.getLogger("YOLOv11Detector")
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Detection classes, index-aligned with the trained model's label order.
CLASS_NAMES = [
    "Alligator Crack",
    "Longitudinal Crack",
    "Pothole",
    "Transverse Crack",
]

# Per-class BGR colours used when annotating frames.
CLASS_COLORS = {
    0: (0, 0, 255),    # Red – Alligator Crack
    1: (0, 165, 255),  # Orange – Longitudinal Crack
    2: (0, 255, 255),  # Yellow – Pothole
    3: (255, 0, 0),    # Blue – Transverse Crack
}

# Severity scale (0-1) per anomaly type (paper Table 3).
SEVERITY_WEIGHTS = {
    0: 0.80,  # Alligator Crack – structural
    1: 0.50,  # Longitudinal Crack – moderate
    2: 0.95,  # Pothole – critical
    3: 0.60,  # Transverse Crack – moderate-high
}
# ═══════════════════════════════════════════════════════════════════════════
# YOLOv11 Detector
# ═══════════════════════════════════════════════════════════════════════════
class YOLOv11Detector:
    """
    Wrapper around Ultralytics YOLO for road-anomaly detection.

    Supports:
      * Training on custom road-anomaly datasets (YOLO format)
      * Single-image / batch / video / live-camera inference
      * TensorRT / ONNX / TFLite export for edge deployment
    """

    def __init__(
        self,
        model_path: str = "yolo11n.pt",
        img_size: int = 416,
        conf_threshold: float = 0.2,
        iou_threshold: float = 0.45,
        device: Optional[Union[int, str]] = None,
        classes: Optional[List[str]] = None,
    ):
        """
        Args:
            model_path: path to a ``.pt`` checkpoint or an Ultralytics model name.
            img_size: inference/training image size (416 fits a 4 GB VRAM budget).
            conf_threshold: default confidence threshold for predictions.
            iou_threshold: default NMS IoU threshold.
            device: CUDA index or "cpu"; auto-selected when None.
            classes: class-name list; defaults to module-level CLASS_NAMES.
        """
        self.img_size = img_size
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold
        self.classes = classes or CLASS_NAMES
        # Auto-select device: first CUDA GPU when available, else CPU.
        if device is None:
            self.device = 0 if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        logger.info("Loading model: %s (device=%s)", model_path, self.device)
        self.model = YOLO(model_path)
        logger.info("Model loaded – %d classes", len(self.classes))

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    def train(
        self,
        data_yaml: str,
        epochs: int = 100,
        batch: int = 4,
        optimizer: str = "AdamW",
        lr0: float = 0.001,
        weight_decay: float = 0.0005,
        warmup_epochs: float = 3.0,
        augment: bool = True,
        mosaic: float = 0.5,
        mixup: float = 0.0,
        cache: Union[bool, str] = "disk",
        amp: bool = True,
        workers: int = 4,
        project: str = "road_anomaly",
        name: str = "yolov11_experiment",
        resume: bool = False,
        exist_ok: bool = True,
        **extra_args,
    ) -> Any:
        """
        Train YOLOv11 on a road-anomaly dataset.

        Args:
            data_yaml: path to a YOLO ``data.yaml`` dataset description.
            **extra_args: forwarded verbatim to ``ultralytics.YOLO.train`` and
                may override any of the defaults assembled below.

        Returns:
            The ultralytics training Results object.
        """
        train_args = dict(
            data=data_yaml,
            epochs=epochs,
            imgsz=self.img_size,
            batch=batch,
            device=self.device,
            workers=workers,
            optimizer=optimizer,
            lr0=lr0,
            weight_decay=weight_decay,
            warmup_epochs=warmup_epochs,
            augment=augment,
            mosaic=mosaic,
            mixup=mixup,
            cache=cache,
            amp=amp,
            project=project,
            name=name,
            resume=resume,
            exist_ok=exist_ok,
            save=True,
            val=True,
            plots=True,
        )
        # Caller-supplied kwargs take precedence over the defaults above.
        train_args.update(extra_args)
        logger.info("🚀 Starting YOLOv11 training for %d epochs …", epochs)
        results = self.model.train(**train_args)
        logger.info("✅ Training complete")
        return results

    # ------------------------------------------------------------------
    # Validation
    # ------------------------------------------------------------------
    def validate(self, data_yaml: str, split: str = "val", **kwargs) -> Any:
        """Run validation on the given split and return the metrics object."""
        return self.model.val(
            data=data_yaml,
            split=split,
            imgsz=self.img_size,
            device=self.device,
            **kwargs,
        )

    # ------------------------------------------------------------------
    # Single-image detection
    # ------------------------------------------------------------------
    def detect(
        self,
        source: Union[str, np.ndarray],
        conf_threshold: Optional[float] = None,
        iou_threshold: Optional[float] = None,
        verbose: bool = False,
    ) -> List[Dict[str, Any]]:
        """
        Detect road anomalies in one image (path or numpy BGR array).

        Returns a list of dicts:
            [{"bbox": [x1,y1,x2,y2], "class_id": int,
              "class_name": str, "confidence": float, "severity": float}, …]
        """
        # FIX: explicit None checks so a caller-supplied 0.0 threshold is
        # honoured instead of silently falling back to the default
        # (``0.0 or default`` evaluates to ``default``).
        conf = self.conf_threshold if conf_threshold is None else conf_threshold
        iou = self.iou_threshold if iou_threshold is None else iou_threshold
        results = self.model.predict(
            source=source,
            conf=conf,
            iou=iou,
            imgsz=self.img_size,
            device=self.device,
            verbose=verbose,
        )
        detections: List[Dict[str, Any]] = []
        for box in results[0].boxes:
            cls_id = int(box.cls[0])
            det = {
                "bbox": box.xyxy[0].cpu().numpy().astype(int).tolist(),
                "class_id": cls_id,
                # Tolerate models with more classes than our label list.
                "class_name": self.classes[cls_id] if cls_id < len(self.classes) else f"class_{cls_id}",
                "confidence": float(box.conf[0]),
                "severity": SEVERITY_WEIGHTS.get(cls_id, 0.5),
            }
            detections.append(det)
        return detections

    # ------------------------------------------------------------------
    # Batch detection
    # ------------------------------------------------------------------
    def detect_batch(
        self,
        image_dir: str,
        conf_threshold: Optional[float] = None,
        save: bool = True,
        save_txt: bool = True,
        project: str = "inference_results",
    ) -> Any:
        """Run detection on all images in a directory; returns ultralytics results."""
        # FIX: same falsy-default bug as in detect() — honour conf_threshold=0.0.
        conf = self.conf_threshold if conf_threshold is None else conf_threshold
        return self.model.predict(
            source=image_dir,
            conf=conf,
            iou=self.iou_threshold,
            imgsz=self.img_size,
            device=self.device,
            save=save,
            save_txt=save_txt,
            project=project,
            verbose=False,
        )

    # ------------------------------------------------------------------
    # Real-time video detection
    # ------------------------------------------------------------------
    def detect_realtime(
        self,
        video_source: Union[str, int] = 0,
        output_path: Optional[str] = "output.mp4",
        show: bool = True,
        return_detections: bool = False,
    ) -> Optional[List[List[Dict[str, Any]]]]:
        """
        Real-time detection on a video file or camera feed.

        Args:
            video_source: video file path or camera index (default 0).
            output_path: path to write annotated video (None to skip).
            show: display live window (press 'q' to quit).
            return_detections: if True, collect per-frame detections list.

        Returns:
            Per-frame detection list (if return_detections=True), else None.
            None is also returned when the source cannot be opened.
        """
        cap = cv2.VideoCapture(video_source)
        if not cap.isOpened():
            logger.error("Cannot open video source: %s", video_source)
            return None
        # Cameras often report 0 fps; fall back to a sane default.
        fps_in = int(cap.get(cv2.CAP_PROP_FPS)) or 30
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        writer = None
        if output_path:
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            writer = cv2.VideoWriter(output_path, fourcc, fps_in, (w, h))
        all_detections: List[List[Dict[str, Any]]] = []
        frame_times: List[float] = []
        frame_idx = 0
        logger.info("Processing video – %dx%d @ %d fps (%d frames)", w, h, fps_in, total)
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            t0 = time.perf_counter()
            detections = self.detect(frame)
            dt = time.perf_counter() - t0
            frame_times.append(dt)
            if return_detections:
                all_detections.append(detections)
            # Annotate frame with boxes and the instantaneous FPS overlay.
            annotated = self.draw(frame, detections)
            live_fps = 1.0 / dt if dt > 0 else 0
            cv2.putText(annotated, f"FPS: {live_fps:.1f}", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            if writer:
                writer.write(annotated)
            if show:
                cv2.imshow("YOLOv11 Road Anomaly Detection", annotated)
                if cv2.waitKey(1) & 0xFF == ord("q"):
                    break
            frame_idx += 1
            if frame_idx % 200 == 0:
                # FIX: guard the rolling-average division against a zero mean
                # (possible when the timer resolution rounds dt down to 0).
                window = float(np.mean(frame_times[-200:]))
                if window > 0:
                    logger.info(" frame %d/%d avg FPS=%.1f", frame_idx, total, 1.0 / window)
        cap.release()
        if writer:
            writer.release()
            logger.info("Saved annotated video → %s", output_path)
        cv2.destroyAllWindows()
        # FIX: guard the summary stats — np.mean([]) yields NaN and the FPS
        # division would blow up when the source produced no frames.
        if frame_times:
            mean_dt = float(np.mean(frame_times))
            avg_ms = mean_dt * 1000
            avg_fps = 1.0 / mean_dt if mean_dt > 0 else 0.0
            logger.info("Average latency: %.1f ms | Average FPS: %.1f", avg_ms, avg_fps)
        else:
            logger.warning("No frames were read from source: %s", video_source)
        return all_detections if return_detections else None

    # ------------------------------------------------------------------
    # Feature extraction (for CNN-BiGRU pipeline)
    # ------------------------------------------------------------------
    def extract_features(
        self,
        frame: np.ndarray,
    ) -> Tuple[List[Dict[str, Any]], List[np.ndarray]]:
        """
        Detect objects and crop the bounding-box regions.

        Returns:
            (detections, crops) – crops are BGR np arrays, one per detection.
            Degenerate (empty) boxes are skipped, so ``len(crops)`` may be
            smaller than ``len(detections)``.
        """
        detections = self.detect(frame)
        crops: List[np.ndarray] = []
        h, w = frame.shape[:2]
        for det in detections:
            x1, y1, x2, y2 = det["bbox"]
            # Clamp boxes to the image bounds before slicing.
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(w, x2), min(h, y2)
            crop = frame[y1:y2, x1:x2]
            if crop.size > 0:
                crops.append(crop)
        return detections, crops

    # ------------------------------------------------------------------
    # Drawing helpers
    # ------------------------------------------------------------------
    @staticmethod
    def draw(
        image: np.ndarray,
        detections: List[Dict[str, Any]],
        thickness: int = 2,
    ) -> np.ndarray:
        """Draw bounding boxes + labels on image (returns an annotated copy)."""
        img = image.copy()
        for det in detections:
            x1, y1, x2, y2 = det["bbox"]
            cls_id = det["class_id"]
            color = CLASS_COLORS.get(cls_id, (0, 255, 0))
            label = f"{det['class_name']} {det['confidence']:.2f}"
            cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness)
            # Filled background behind the label text for readability.
            (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
            cv2.rectangle(img, (x1, y1 - th - 8), (x1 + tw, y1), color, -1)
            cv2.putText(img, label, (x1, y1 - 4),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
        return img

    # ------------------------------------------------------------------
    # Export
    # ------------------------------------------------------------------
    def export(
        self,
        format: str = "engine",
        half: bool = True,
        int8: bool = False,
        workspace: int = 2,
        output_dir: Optional[str] = None,
    ) -> str:
        """
        Export the model for deployment.

        Supported formats: 'onnx', 'engine' (TensorRT), 'tflite',
        'torchscript', 'openvino', etc.

        Args:
            half: export with FP16 weights.
            int8: enable INT8 quantisation (format permitting).
            workspace: TensorRT workspace size in GB ('engine' only).
            output_dir: currently unused; kept for interface stability.

        Returns:
            Path to the exported artifact (as reported by ultralytics).
        """
        kwargs: Dict[str, Any] = dict(
            format=format,
            imgsz=self.img_size,
            half=half,
            int8=int8,
            device=self.device,
        )
        if format == "engine":
            kwargs["workspace"] = workspace
        logger.info("Exporting model → %s (half=%s, int8=%s)", format, half, int8)
        path = self.model.export(**kwargs)
        logger.info("Export saved → %s", path)
        return path

    # ------------------------------------------------------------------
    # Utility
    # ------------------------------------------------------------------
    def __repr__(self) -> str:
        return (
            f"YOLOv11Detector(img_size={self.img_size}, "
            f"conf={self.conf_threshold}, device={self.device}, "
            f"classes={len(self.classes)})"
        )
# ═══════════════════════════════════════════════════════════════════════════
# Quick self-test
# ═══════════════════════════════════════════════════════════════════════════
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="YOLOv11 Road Anomaly Detector")
    sub = parser.add_subparsers(dest="cmd")

    # --- train ---
    p_train = sub.add_parser("train", help="Train on a dataset")
    p_train.add_argument("--data", required=True, help="data.yaml path")
    p_train.add_argument("--model", default="yolo11n.pt", help="Base model")
    p_train.add_argument("--epochs", type=int, default=100)
    p_train.add_argument("--batch", type=int, default=16)
    p_train.add_argument("--imgsz", type=int, default=416)

    # --- detect ---
    p_det = sub.add_parser("detect", help="Run detection on image/video")
    p_det.add_argument("--model", required=True, help="Weights path")
    p_det.add_argument("--source", required=True, help="Image / video / 'camera'")
    p_det.add_argument("--conf", type=float, default=0.2)
    p_det.add_argument("--output", default=None)

    # --- export ---
    p_exp = sub.add_parser("export", help="Export model")
    p_exp.add_argument("--model", required=True)
    p_exp.add_argument("--format", default="engine",
                       choices=["onnx", "engine", "tflite", "torchscript", "openvino"])
    # FIX: the original used action="store_true" with default=True, which made
    # FP16 impossible to disable from the CLI. Keep --half for backward
    # compatibility and add --no-half to actually turn it off.
    p_exp.add_argument("--half", dest="half", action="store_true", default=True,
                       help="Export with FP16 weights (default)")
    p_exp.add_argument("--no-half", dest="half", action="store_false",
                       help="Disable FP16 export")

    args = parser.parse_args()
    if args.cmd == "train":
        det = YOLOv11Detector(model_path=args.model, img_size=args.imgsz)
        det.train(data_yaml=args.data, epochs=args.epochs, batch=args.batch)
    elif args.cmd == "detect":
        det = YOLOv11Detector(model_path=args.model, conf_threshold=args.conf)
        if args.source.lower() == "camera":
            det.detect_realtime(video_source=0, show=True)
        elif Path(args.source).is_dir():
            det.detect_batch(args.source, save=True)
        elif Path(args.source).suffix.lower() in (".mp4", ".avi", ".mov", ".mkv"):
            det.detect_realtime(video_source=args.source, output_path=args.output, show=True)
        else:
            dets = det.detect(args.source)
            # FIX: cv2.imread returns None for unreadable paths; fail with a
            # clear message instead of crashing inside draw().
            img = cv2.imread(args.source)
            if img is None:
                sys.exit(f"Cannot read image: {args.source}")
            out = det.draw(img, dets)
            if args.output:
                cv2.imwrite(args.output, out)
                print(f"Saved → {args.output}")
            for d in dets:
                print(f"  {d['class_name']:>20s} conf={d['confidence']:.3f} severity={d['severity']:.2f}")
    elif args.cmd == "export":
        det = YOLOv11Detector(model_path=args.model)
        det.export(format=args.format, half=args.half)
    else:
        parser.print_help()