afshin-dini's picture
Add a single inference for demo
bc55a23
"""This is the code for training the YOLO model for package detection."""
import logging
from pathlib import Path
from dataclasses import dataclass, field
from typing import Any, Optional, Mapping
from collections import Counter
from ultralytics import YOLO # type: ignore
import cv2
# Module-level logger named after this module; every class below logs through it.
logger = logging.getLogger(__name__)
@dataclass
class PackageDetectorTrainer:
    """Train, validate and export a YOLO model for package detection.

    Attributes:
        conf: Path of the dataset YAML file passed to the trainer as ``data``.
        epochs: Number of training epochs.
        img_size: Image size passed to the trainer as ``imgsz``.
        batch_size: Batch size passed to the trainer as ``batch``.
        device: Device identifier passed to the trainer (default ``"cuda"``).
        model: The YOLO model object; ``None`` until :meth:`train` has run.
        weights: Checkpoint name or path handed to ``YOLO(...)`` when training
            starts.
    """

    conf: str = field(default="src/deep_package_detection/data/data.yaml")
    epochs: int = field(default=100)
    img_size: int = field(default=640)
    batch_size: int = field(default=16)
    device: str = field(default="cuda")
    # A default is required here: without one the attribute does not exist
    # until train() runs, and the generated __repr__/__eq__ raise
    # AttributeError on a fresh instance. It is kept out of the repr because
    # printing a loaded model is not useful in log output.
    model: Any = field(init=False, default=None, repr=False)
    # Declared last so existing positional constructor calls keep their meaning.
    weights: str = field(default="yolov8x-seg.pt")

    def _require_model(self) -> Any:
        """Return the model created by :meth:`train` or fail with a clear error.

        Raises:
            RuntimeError: If :meth:`train` has not been called yet.
        """
        if self.model is None:
            raise RuntimeError(
                "No model is available; call train() before validating or exporting."
            )
        return self.model

    def train(self) -> None:
        """Create the YOLO model from ``weights`` and train it on ``conf``."""
        logger.info("Start training the YOLO model for package detection.")
        self.model = YOLO(self.weights)
        self.model.train(
            data=self.conf,
            epochs=self.epochs,
            imgsz=self.img_size,
            batch=self.batch_size,
            device=self.device,
        )

    def validation(self) -> Any:
        """Validate the trained model and return whatever ``model.val()`` returns.

        Raises:
            RuntimeError: If :meth:`train` has not been called yet.
        """
        model = self._require_model()
        logger.info("Validating the YOLO model for package detection.")
        return model.val()

    def model_export(self) -> None:
        """Export the trained model in ONNX format.

        Raises:
            RuntimeError: If :meth:`train` has not been called yet.
        """
        model = self._require_model()
        logger.info("Exporting the YOLO model for package detection.")
        model.export(format="onnx")
@dataclass
class PackageDetectorInference:
    """Run a trained YOLO model on images and post-process its detections.

    Attributes:
        model_path: Location of the trained weights, given as ``str`` or
            ``Path``. It is normalised to a ``Path`` and must exist when the
            instance is created.
        result_path: Directory in which annotated images are written;
            ``None`` disables saving.
        confidence_threshold: Minimum confidence (inclusive) a box needs in
            order to be counted or drawn.
    """

    model_path: Optional[Any] = field(default=None)
    result_path: Optional[str] = field(default=None)
    confidence_threshold: float = field(default=0.6)

    def __post_init__(self) -> None:
        """Normalise ``model_path`` to a ``Path`` and check that it exists.

        Raises:
            ValueError: If ``model_path`` is missing, empty, or does not exist.
        """
        # An empty string is rejected explicitly: Path("") is the current
        # directory, which would pass the existence check below.
        if not self.model_path:
            raise ValueError("Model does not exist or the path is not correct.")
        self.model_path = Path(self.model_path)
        if not self.model_path.exists():
            raise ValueError("Model does not exist or the path is not correct.")

    def load_model(self) -> Any:
        """Load and return the YOLO model stored at ``model_path``."""
        logger.info("Loading the trained model for package detection.")
        return YOLO(self.model_path)  # type: ignore

    def inference(self, data_path: str) -> Any:
        """Run package detection on ``data_path`` and return the raw results.

        Args:
            data_path: Existing file or directory to run the model on.

        Returns:
            Whatever the loaded model returns when called on ``data_path``.

        Raises:
            FileNotFoundError: If ``data_path`` does not exist.
        """
        # Fail before the (expensive) model load instead of logging the
        # problem and running the model on a path known to be missing.
        if not Path(data_path).exists():
            raise FileNotFoundError(
                f"Data path does not exist or the path is not correct: {data_path}"
            )
        model = self.load_model()
        return model(
            data_path,
            save=False,
            project=self.result_path,
            name="detections",
        )

    def _confident_boxes(self, result: Any) -> list:
        """Return ``(class_id, confidence, box)`` for boxes at or above the threshold.

        This is the single place where the threshold is applied, so counting,
        drawing and the decision to save an image always agree.
        """
        boxes = result.boxes if result.boxes is not None else ()
        kept = []
        for box in boxes:
            conf = float(box.conf.item())
            if conf >= self.confidence_threshold:
                kept.append((int(box.cls.item()), conf, box))
        return kept

    @staticmethod
    def _draw_boxes(img: Any, names: Any, kept: list) -> None:
        """Draw a green rectangle and a ``"<class> <conf>"`` label per kept box, in place."""
        for cls_id, conf, box in kept:
            label = f"{names[cls_id]} {conf:.2f}"
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            cv2.rectangle(  # pylint: disable=E1101
                img, (x1, y1), (x2, y2), (0, 255, 0), 2
            )
            cv2.putText(  # pylint: disable=E1101
                img,
                label,
                # Keep the label inside the image when the box touches the top.
                (x1, max(y1 - 10, 0)),
                cv2.FONT_HERSHEY_SIMPLEX,  # pylint: disable=E1101
                0.6,
                (0, 255, 0),
                1,
                cv2.LINE_AA,  # pylint: disable=E1101
            )

    def count_packages(self, detections: Any) -> Mapping[str, Any]:
        """Count confident detections per class for every result.

        Args:
            detections: Iterable of result objects exposing ``boxes``,
                ``names`` and ``path``.

        Returns:
            A mapping from the image file name to a list of
            ``{"class": <class name>, "count": <int>}`` entries. Results that
            share a file name overwrite each other, because only the final
            path component is used as the key.
        """
        counts = {}
        for result in detections:
            per_class = Counter(cls_id for cls_id, _, _ in self._confident_boxes(result))
            counts[Path(result.path).name] = [
                {"class": result.names[cls_id], "count": count}
                for cls_id, count in per_class.items()
            ]
        return counts

    def plot_and_save_results(self, detections: Any) -> None:
        """Save an annotated copy of every image that has confident detections.

        Nothing happens when ``result_path`` is ``None``. Images are written to
        ``result_path`` as ``<original stem>_detections.jpg``; images without a
        box at or above the threshold are skipped.
        """
        if self.result_path is None:
            return

        output_dir = Path(self.result_path)
        output_dir.mkdir(parents=True, exist_ok=True)
        logger.info("Saving high-confidence detection images to: %s", output_dir)

        for result in detections:
            kept = self._confident_boxes(result)
            # Decide before touching the disk: an image that will not be
            # saved does not need to be read at all.
            if not kept:
                continue

            img = cv2.imread(str(result.path))
            if img is None:
                logger.warning("Could not read image: %s", result.path)
                continue

            self._draw_boxes(img, result.names, kept)

            output_path = output_dir / f"{Path(result.path).stem}_detections.jpg"
            # imwrite reports failure through its return value, not an exception.
            if cv2.imwrite(str(output_path), img):
                logger.info("Saved high-confidence detections to %s", output_path)
            else:
                logger.warning("Could not write image: %s", output_path)

    def single_inference(self, detections: Any) -> Optional[Any]:
        """Return the first result's image, annotated, for the application demo.

        Args:
            detections: Sequence of result objects; only the first is used.

        Returns:
            The image converted from BGR to RGB with confident detections
            drawn on it, or ``None`` when ``detections`` is empty or the image
            cannot be read. An image without confident boxes is returned
            without annotations.
        """
        if not detections:
            logger.warning("No detections were provided for single inference.")
            return None

        result = detections[0]
        img = cv2.imread(str(result.path))
        if img is None:
            logger.warning("Could not read image: %s", result.path)
            return None

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # pylint: disable=E1101
        self._draw_boxes(img, result.names, self._confident_boxes(result))
        return img