"""Train a YOLO model for package detection and run inference with it."""

import logging
from collections import Counter
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Mapping, Optional

import cv2
from ultralytics import YOLO  # type: ignore

logger = logging.getLogger(__name__)

# Drawing style shared by every annotated image (green is (0, 255, 0) in
# both BGR and RGB channel order, so the colour is the same either way).
_BOX_COLOR = (0, 255, 0)
_BOX_THICKNESS = 2
_LABEL_SCALE = 0.6
_LABEL_THICKNESS = 1


@dataclass
class PackageDetectorTrainer:
    """Train, validate and export a YOLO model for package detection.

    Attributes:
        conf: Dataset configuration handed to ``model.train`` as ``data``.
        epochs: Value handed to ``model.train`` as ``epochs``.
        img_size: Value handed to ``model.train`` as ``imgsz``.
        batch_size: Value handed to ``model.train`` as ``batch``.
        device: Value handed to ``model.train`` as ``device``.
        model: The YOLO object created by :meth:`train`; ``None`` until then.
        weights: Checkpoint name or path given to ``YOLO()`` before training.
    """

    conf: str = field(default="src/deep_package_detection/data/data.yaml")
    epochs: int = field(default=100)
    img_size: int = field(default=640)
    batch_size: int = field(default=16)
    device: str = field(default="cuda")
    # A default is required here: without one, repr()/== on an instance that
    # has not been trained yet raises AttributeError.
    model: Any = field(init=False, default=None)
    # Declared last so the positional order of the existing fields is kept.
    weights: str = field(default="yolov8x-seg.pt")

    def _require_model(self) -> Any:
        """Return the current model or fail with an explicit message."""
        if self.model is None:
            raise RuntimeError(
                "No model available: call train() before validating or exporting."
            )
        return self.model

    def train(self) -> None:
        """Create the YOLO model from ``weights`` and train it."""
        logger.info("Start training the YOLO model for package detection.")
        self.model = YOLO(self.weights)
        self.model.train(
            data=self.conf,
            epochs=self.epochs,
            imgsz=self.img_size,
            batch=self.batch_size,
            device=self.device,
        )

    def validation(self) -> Any:
        """Validate the trained model and return whatever ``model.val()`` returns.

        Raises:
            RuntimeError: If no model has been created yet.
        """
        logger.info("Validating the YOLO model for package detection.")
        return self._require_model().val()

    def model_export(self) -> None:
        """Export the trained model in ONNX format.

        Raises:
            RuntimeError: If no model has been created yet.
        """
        logger.info("Exporting the YOLO model for package detection.")
        self._require_model().export(format="onnx")


@dataclass
class PackageDetectorInference:
    """Run a trained YOLO model and post-process its detections.

    Attributes:
        model_path: Location of the trained model (``str`` or ``Path``).
        result_path: Directory for annotated images; ``None`` disables saving.
        confidence_threshold: Detections with a confidence at or above this
            value are counted and drawn; the rest are ignored.
    """

    model_path: Optional[Any] = field(default=None)
    result_path: Optional[str] = field(default=None)
    confidence_threshold: float = field(default=0.6)

    def __post_init__(self) -> None:
        """Check that ``model_path`` points to an existing file system entry.

        Raises:
            ValueError: If ``model_path`` is ``None`` or does not exist.
        """
        # Wrap in Path so that plain string paths are accepted as well.
        if self.model_path is None or not Path(self.model_path).exists():
            raise ValueError("Model does not exist or the path is not correct.")

    def load_model(self) -> Any:
        """Load and return the YOLO model stored at ``model_path``."""
        logger.info("Loading the trained model for package detection.")
        return YOLO(self.model_path)  # type: ignore

    def inference(self, data_path: str) -> Any:
        """Run package detection on ``data_path`` and return the raw results.

        Args:
            data_path: Source handed to the model as its first argument.

        Returns:
            Whatever the model call returns for ``data_path``.
        """
        # Only logged, not raised: the model decides how to handle the source.
        # NOTE(review): presumably the source may also be something that is
        # not a local path (e.g. a URL or glob) - confirm before raising here.
        if not Path(data_path).exists():
            logger.error("Data path does not exist or the path is not correct.")
        model = self.load_model()
        return model(
            data_path,
            save=False,
            project=self.result_path,
            name="detections",
        )

    def _confident_boxes(self, result: Any) -> list:
        """Return the boxes of ``result`` at or above the confidence threshold.

        Single source of truth for the threshold test, so counting, drawing
        and saving always agree on which detections qualify.
        """
        boxes = result.boxes
        if boxes is None:  # defensive: nothing to filter
            return []
        return [
            box
            for box in boxes
            if float(box.conf.item()) >= self.confidence_threshold
        ]

    def _draw_detections(self, img: Any, result: Any) -> int:
        """Draw box and label for every confident detection onto ``img``.

        ``img`` is modified in place.

        Returns:
            The number of detections that were drawn.
        """
        confident = self._confident_boxes(result)
        for box in confident:
            conf = float(box.conf.item())
            cls_id = int(box.cls.item())
            label = f"{result.names[cls_id]} {conf:.2f}"
            # The first row of xyxy is unpacked as the two rectangle corners.
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            cv2.rectangle(  # pylint: disable=E1101
                img, (x1, y1), (x2, y2), _BOX_COLOR, _BOX_THICKNESS
            )
            cv2.putText(  # pylint: disable=E1101
                img,
                label,
                # Clamp so the label stays inside the image for boxes at the top.
                (x1, max(y1 - 10, 0)),
                cv2.FONT_HERSHEY_SIMPLEX,  # pylint: disable=E1101
                _LABEL_SCALE,
                _BOX_COLOR,
                _LABEL_THICKNESS,
                cv2.LINE_AA,  # pylint: disable=E1101
            )
        return len(confident)

    def count_packages(self, detections: Any) -> Mapping[str, Any]:
        """Count the confident detections per class for every result.

        Args:
            detections: Iterable of results as returned by :meth:`inference`.

        Returns:
            A mapping from the file name of each result (without directory,
            so equal names in different folders overwrite each other) to a
            list of ``{"class": <class name>, "count": <int>}`` entries.
        """
        counts: dict = {}
        for result in detections:
            class_count = Counter(
                int(box.cls.item()) for box in self._confident_boxes(result)
            )
            counts[str(Path(result.path).name)] = [
                {"class": result.names[cls_id], "count": count}
                for cls_id, count in class_count.items()
            ]
        return counts

    def plot_and_save_results(self, detections: Any) -> None:
        """Save an annotated copy of every image with confident detections.

        Images are written to ``result_path`` as ``<stem>_detections.jpg``.
        Nothing happens when ``result_path`` is ``None``; images that cannot
        be read or that have no confident detection are skipped.

        Args:
            detections: Iterable of results as returned by :meth:`inference`.
        """
        if self.result_path is None:
            return

        output_dir = Path(self.result_path)
        output_dir.mkdir(parents=True, exist_ok=True)
        logger.info("Saving high-confidence detection images to: %s", output_dir)

        for result in detections:
            img = cv2.imread(str(result.path))
            if img is None:
                logger.warning("Could not read image: %s", result.path)
                continue

            # Nothing drawn means nothing worth saving.
            if self._draw_detections(img, result) == 0:
                continue

            output_path = output_dir / f"{Path(result.path).stem}_detections.jpg"
            # imwrite reports failure through its return value, not an exception.
            if cv2.imwrite(str(output_path), img):
                logger.info("Saved high-confidence detections to %s", output_path)
            else:
                logger.warning("Could not write image: %s", output_path)

    def single_inference(self, detections: Any) -> Optional[Any]:
        """Return the first result's image in RGB with confident boxes drawn.

        Intended for the application demo, where a single image is shown.

        Args:
            detections: Sequence of results as returned by :meth:`inference`;
                only the first element is used.

        Returns:
            The annotated image converted from BGR to RGB, or ``None`` when
            ``detections`` is empty or the image cannot be read.
        """
        if not detections:
            logger.warning("No detections available for single inference.")
            return None

        result = detections[0]
        img = cv2.imread(str(result.path))
        if img is None:
            logger.warning("Could not read image: %s", result.path)
            return None

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # pylint: disable=E1101
        self._draw_detections(img, result)
        return img