"""Training and inference utilities for the YOLO egg-segmentation model."""
| |
|
| | import logging |
| | from pathlib import Path |
| | from dataclasses import dataclass, field |
| | from typing import Any, Optional, Mapping, List |
| |
|
| | from collections import Counter |
| | from ultralytics import YOLO |
| | import numpy as np |
| | import cv2 |
| |
|
| |
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
| |
|
| |
|
@dataclass
class EggSegmentorTrainer:
    """Train, validate and export a YOLO segmentation model for eggs.

    Attributes:
        conf: Dataset YAML passed to ``YOLO.train`` as ``data``.
        epochs: Number of training epochs.
        img_size: Training image size (forwarded as ``imgsz``).
        batch_size: Training batch size (forwarded as ``batch``).
        device: Device string forwarded to ``YOLO.train``.
        base_model: Pretrained weights used to initialise training.
        model: YOLO model created by :meth:`train`; ``None`` until then.
    """

    conf: str = field(default="src/egg_segmentation_size/data/data.yaml")
    epochs: int = field(default=100)
    img_size: int = field(default=640)
    batch_size: int = field(default=16)
    device: str = field(default="cuda")
    base_model: str = field(default="yolov8n-seg.pt")
    # Set by train(). It has a default so repr()/== work before training (without a
    # default they raised AttributeError). It is excluded from repr/eq because the
    # model object is large and not part of the configuration.
    model: Any = field(default=None, init=False, repr=False, compare=False)

    def _require_model(self) -> Any:
        """Return the trained model, raising if :meth:`train` has not run yet."""
        if self.model is None:
            raise RuntimeError(
                "No trained model available; call train() before validation/export."
            )
        return self.model

    def train(self) -> None:
        """Train the YOLO model for egg segmentation.

        Loads ``base_model`` weights and trains on the dataset in ``conf``
        using the configured epochs, image size, batch size and device.
        The resulting model is stored on ``self.model``.
        """
        logger.info("Start training the YOLO model for egg segmentation.")
        self.model = YOLO(self.base_model)
        self.model.train(
            data=self.conf,
            epochs=self.epochs,
            imgsz=self.img_size,
            batch=self.batch_size,
            device=self.device,
        )

    def validation(self) -> Any:
        """Validate the trained model and return the metrics from ``YOLO.val``.

        Raises:
            RuntimeError: If :meth:`train` has not been called yet.
        """
        model = self._require_model()
        logger.info("Validating the YOLO model for egg segmentation.")
        return model.val()

    def model_export(self, export_format: str = "onnx") -> Any:
        """Export the trained model.

        Args:
            export_format: Target format forwarded to ``YOLO.export``
                (default ``"onnx"``, as before).

        Returns:
            Whatever ``YOLO.export`` returns (the exported file path in recent
            Ultralytics releases -- confirm for the pinned version).

        Raises:
            RuntimeError: If :meth:`train` has not been called yet.
        """
        model = self._require_model()
        logger.info("Exporting the YOLO model for egg segmentation.")
        return model.export(format=export_format)
| |
|
| |
|
@dataclass
class EggSegmentorInference:
    """Run a trained YOLO segmentation model on egg images and post-process results.

    Attributes:
        model_path: Path to the trained weights (``str`` or ``pathlib.Path``);
            must exist.
        result_path: Passed to Ultralytics as ``project`` (run name
            ``"detections"``). Predictions are saved only when it is set.
        scale_factor: Divisor that converts pixel lengths into the physical
            units used by the volume estimate (presumably pixels per millimetre
            -- confirm against the camera calibration).
    """

    model_path: Optional[Any] = field(default=None)
    result_path: Optional[str] = field(default=None)
    scale_factor: float = field(default=11.61)
    # Model loaded lazily by inference() and reused across calls, so weights are
    # not reloaded from disk for every prediction.
    _model: Any = field(default=None, init=False, repr=False, compare=False)

    def __post_init__(self) -> None:
        """Validate that ``model_path`` points at an existing path.

        Raises:
            ValueError: If ``model_path`` is ``None`` or does not exist.
        """
        # Wrap in Path so plain strings are accepted (str has no .exists()).
        if self.model_path is None or not Path(self.model_path).exists():
            raise ValueError("Model does not exist or the path is not correct.")

    def load_model(self) -> Any:
        """Load and return a fresh YOLO model from ``model_path``."""
        logger.info("Loading the trained model for egg segmentation.")
        return YOLO(self.model_path)

    def inference(self, data_path: str) -> Any:
        """Run egg segmentation on ``data_path`` and return the Ultralytics results.

        A missing local path is only logged, not raised. Presumably this is so
        that non-file sources Ultralytics accepts (URLs, globs) still work --
        TODO confirm intent.
        """
        if not Path(data_path).exists():
            logger.error("Data path does not exist or the path is not correct.")
        if self._model is None:
            self._model = self.load_model()
        return self._model(
            data_path,
            save=bool(self.result_path),
            project=self.result_path,
            name="detections",
        )

    @staticmethod
    def _shoelace_area(polygon: Any) -> float:
        """Return the area of an ``(N, 2)`` polygon via the shoelace formula.

        The result is in squared input units (pixels here). An empty polygon
        yields ``0.0``.
        """
        x, y = polygon[:, 0], polygon[:, 1]
        return float(0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1))))

    @staticmethod
    def number_of_eggs(detections: Any) -> Mapping[str, Any]:
        """Count detections per class for each result, keyed by image file name.

        Returns:
            ``{file_name: [{"class": name, "count": n}, ...]}``. Results that
            share a file name overwrite each other.
        """
        counts = {}
        for result in detections:
            class_count = Counter(int(box.cls.item()) for box in result.boxes)
            counts[str(Path(result.path).name)] = [
                {"class": result.names[cls_id], "count": count}
                for cls_id, count in class_count.items()
            ]
        return counts

    def _egg_volume(self, polygon: Any, circularity_thr: float = 15) -> float:
        """Estimate an egg's volume in cm^3 from its segmentation polygon.

        The method fits an ellipse to the polygon and divides its axes by
        ``scale_factor``. If the axes differ by more than ``circularity_thr``
        (in those scaled units), the egg is modelled as a prolate spheroid,
        ``4/3 * pi * a * b**2`` (a = semi-major, b = semi-minor). Otherwise it
        is modelled as a sphere whose diameter is the mean of the two axes.
        The final division by 1000 (the ``3000`` = 3 * 1000 denominator)
        converts mm^3 to cm^3, assuming the scaled units are millimetres.

        Returns ``0.0`` when the polygon has fewer than 5 points, because
        ``cv2.fitEllipse`` cannot fit those.
        """
        points = np.asarray(polygon, dtype=np.float32).reshape((-1, 1, 2))
        # Tiny masks can produce fewer than five points (or none), which would
        # make cv2.fitEllipse raise and abort the whole results_detail() call.
        if len(points) < 5:
            return 0.0
        _, axes, _ = cv2.fitEllipse(points)
        # Sort explicitly rather than relying on OpenCV's width/height ordering.
        minor_axis, major_axis = (float(a) / self.scale_factor for a in sorted(axes))

        if (major_axis - minor_axis) > circularity_thr:
            return 4 * np.pi * (major_axis / 2) * ((minor_axis / 2) ** 2) / 3000
        return 4 * np.pi * (((major_axis + minor_axis) / 4) ** 3) / 3000

    def results_detail(self, detections: Any) -> Mapping[str, Any]:
        """Collect per-mask details (class, confidence, pixel area, volume) per image.

        Returns:
            ``{file_name: [{"class", "confidence", "areas in pixel",
            "volume in cm3"}, ...]}``. Images without masks map to ``[]``.
        """
        results = {}
        for result in detections:
            details = []
            if result.masks is not None:
                boxes = result.boxes
                for i, mask in enumerate(result.masks.xy):
                    polygon = np.array(mask, dtype=np.float32)
                    details.append(
                        {
                            "class": result.names[int(boxes.cls[i].item())],
                            "confidence": boxes.conf[i].item(),
                            "areas in pixel": self._shoelace_area(polygon),
                            "volume in cm3": self._egg_volume(polygon),
                        }
                    )
            results[str(Path(result.path).name)] = details
        return results

    @staticmethod
    def result_images(detections: Any) -> List[Any]:
        """Return each result's annotated plot with its channel order reversed.

        This is presumably a BGR -> RGB conversion of ``result.plot()``.
        """
        return [np.array(result.plot())[:, :, [2, 1, 0]] for result in detections]
| |
|