|
|
"""This is the code for training the YOLO model for package detection.""" |
|
|
|
|
|
import logging |
|
|
from pathlib import Path |
|
|
from dataclasses import dataclass, field |
|
|
from typing import Any, Optional, Mapping |
|
|
|
|
|
from collections import Counter |
|
|
from ultralytics import YOLO |
|
|
import cv2 |
|
|
|
|
|
# Module-level logger named after this module; used by both classes below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
@dataclass
class PackageDetectorTrainer:
    """Train, validate and export a YOLO model for package detection.

    Attributes:
        conf: Dataset YAML path, forwarded to the trainer as ``data``.
        epochs: Number of training epochs.
        img_size: Training image size, forwarded as ``imgsz``.
        batch_size: Batch size, forwarded as ``batch``.
        device: Device string forwarded to the trainer.
        weights: Checkpoint name/path handed to ``YOLO(...)`` by :meth:`train`.
        model: The YOLO object created by :meth:`train`. It is not a
            constructor argument and stays ``None`` until training starts.
    """

    conf: str = field(default="src/deep_package_detection/data/data.yaml")
    epochs: int = field(default=100)
    img_size: int = field(default=640)
    batch_size: int = field(default=16)
    device: str = field(default="cuda")
    # Appended after the existing init fields so positional construction
    # (conf, epochs, img_size, batch_size, device) keeps working unchanged.
    weights: str = field(default="yolov8x-seg.pt")
    # A concrete default keeps repr()/== and early attribute access from
    # raising AttributeError on an instance that has not been trained yet.
    model: Any = field(default=None, init=False)

    def train(self) -> None:
        """Create a YOLO model from ``self.weights`` and train it.

        The created model is stored on ``self.model`` so that
        :meth:`validation` and :meth:`model_export` can reuse it.
        """
        logger.info("Start training the YOLO model for package detection.")
        self.model = YOLO(self.weights)
        self.model.train(
            data=self.conf,
            epochs=self.epochs,
            imgsz=self.img_size,
            batch=self.batch_size,
            device=self.device,
        )

    def validation(self) -> Any:
        """Validate the current model.

        Returns:
            Whatever ``self.model.val()`` returns.

        Raises:
            RuntimeError: If no model is available yet (``train`` not called).
        """
        model = self._require_model("validate")
        logger.info("Validating the YOLO model for package detection.")
        return model.val()

    def model_export(self) -> None:
        """Export the current model in ONNX format.

        Raises:
            RuntimeError: If no model is available yet (``train`` not called).
        """
        model = self._require_model("export")
        logger.info("Exporting the YOLO model for package detection.")
        model.export(format="onnx")

    def _require_model(self, action: str) -> Any:
        """Return ``self.model`` or fail with a clear message if it is unset."""
        if self.model is None:
            raise RuntimeError(
                f"Cannot {action} before a model exists; call train() first."
            )
        return self.model
|
|
|
|
|
|
|
|
@dataclass
class PackageDetectorInference:
    """Run a trained YOLO model and post-process its detections.

    Attributes:
        model_path: Location of the trained weights. A ``str`` or ``Path`` is
            accepted; after validation it is stored as a ``Path``.
        result_path: Directory for annotated images. ``None`` disables
            :meth:`plot_and_save_results`.
        confidence_threshold: A detection is counted/drawn only when its
            confidence is strictly greater than this value.
    """

    model_path: Optional[Any] = field(default=None)
    result_path: Optional[str] = field(default=None)
    confidence_threshold: float = field(default=0.6)

    def __post_init__(self) -> None:
        """Validate ``model_path`` and normalise it to a ``Path``.

        Raises:
            ValueError: If ``model_path`` is missing/empty or does not exist.
        """
        # Reject None and "" up front: Path("") would resolve to the current
        # directory and wrongly pass the existence check.
        if not self.model_path:
            raise ValueError("Model does not exist or the path is not correct.")
        model_path = Path(self.model_path)
        if not model_path.exists():
            raise ValueError("Model does not exist or the path is not correct.")
        self.model_path = model_path

    def load_model(self) -> Any:
        """Load and return a YOLO model from ``self.model_path``."""
        logger.info("Loading the trained model for package detection.")
        return YOLO(self.model_path)

    def inference(self, data_path: str) -> Any:
        """Run package detection on ``data_path`` and return the raw results.

        A fresh model is loaded through :meth:`load_model` on every call.

        Args:
            data_path: Source handed to the model call.

        Returns:
            Whatever the model call returns for ``data_path``.
        """
        if not Path(data_path).exists():
            # NOTE(review): this only logs and carries on, presumably because
            # the model call may accept sources that are not local paths
            # (e.g. URLs) - confirm before turning this into a hard failure.
            logger.error(
                "Data path does not exist or the path is not correct: %s",
                data_path,
            )
        model = self.load_model()
        return model(
            data_path,
            save=False,
            project=self.result_path,
            name="detections",
        )

    def count_packages(self, detections: Any) -> Mapping[str, Any]:
        """Count confident detections per class for every result.

        Args:
            detections: Iterable of result objects exposing ``boxes``,
                ``names`` and ``path``.

        Returns:
            A dict keyed by the image file name (directory part dropped, so
            same-named files from different folders overwrite each other).
            Each value is a list of ``{"class": <name>, "count": <int>}``.
        """
        counts = {}
        for result in detections:
            per_class = Counter(
                int(box.cls.item()) for box, _ in self._confident_boxes(result)
            )
            counts[Path(result.path).name] = [
                {"class": result.names[cls_id], "count": count}
                for cls_id, count in per_class.items()
            ]
        return counts

    def plot_and_save_results(self, detections: Any) -> None:
        """Save annotated copies of images that have confident detections.

        Does nothing when ``result_path`` is ``None``. Images without any
        detection above the threshold, and images OpenCV cannot read, are
        skipped. Output files are named ``<stem>_detections.jpg``.

        Args:
            detections: Iterable of result objects exposing ``boxes``,
                ``names`` and ``path``.
        """
        if self.result_path is None:
            return
        output_dir = Path(self.result_path)
        output_dir.mkdir(parents=True, exist_ok=True)
        logger.info("Saving high-confidence detection images to: %s", output_dir)

        for result in detections:
            confident = self._confident_boxes(result)
            if not confident:
                # Decide before decoding the image: nothing would be drawn.
                continue

            img = cv2.imread(str(result.path))
            if img is None:
                logger.warning("Could not read image: %s", result.path)
                continue

            self._draw_detections(img, result, confident)
            output_path = output_dir / f"{Path(result.path).stem}_detections.jpg"
            # imwrite reports failure through its return value, not an
            # exception, so only claim success when it says so.
            if cv2.imwrite(str(output_path), img):
                logger.info("Saved high-confidence detections to %s", output_path)
            else:
                logger.error("Failed to write annotated image to %s", output_path)

    def single_inference(self, detections: Any) -> Optional[Any]:
        """Annotate the first result only, for the application demo.

        Args:
            detections: Indexable collection of result objects.

        Returns:
            The first result's image converted from BGR to RGB with confident
            detections drawn on it, or ``None`` when there is no first result
            or the image cannot be read.
        """
        try:
            result = detections[0]
        except IndexError:
            logger.warning("No detections were provided for single inference.")
            return None

        img = cv2.imread(str(result.path))
        if img is None:
            logger.warning("Could not read image: %s", result.path)
            return None
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        self._draw_detections(img, result, self._confident_boxes(result))
        return img

    def _confident_boxes(self, result: Any) -> list:
        """Return ``(box, confidence)`` pairs strictly above the threshold."""
        boxes = result.boxes if result.boxes is not None else ()
        scored = ((box, float(box.conf.item())) for box in boxes)
        return [
            (box, conf) for box, conf in scored if conf > self.confidence_threshold
        ]

    def _draw_detections(self, img: Any, result: Any, confident: Any) -> None:
        """Draw a green rectangle and ``"<class> <conf>"`` label per box, in place."""
        for box, conf in confident:
            cls_id = int(box.cls.item())
            label = f"{result.names[cls_id]} {conf:.2f}"
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(
                img,
                label,
                # Clamp so the label stays inside the image for boxes at the top.
                (x1, max(y1 - 10, 0)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.6,
                (0, 255, 0),
                1,
                cv2.LINE_AA,
            )
|
|
|