afshin-dini's picture
transfer the repo
623606b
"""This is the code for training the YOLO model for egg segmentation."""
import logging
from pathlib import Path
from dataclasses import dataclass, field
from typing import Any, Optional, Mapping, List
from collections import Counter
from ultralytics import YOLO
import numpy as np
import cv2
# Module-level logger named after this module; used by the trainer and inference classes below.
logger = logging.getLogger(__name__)
@dataclass
class EggSegmentorTrainer:
    """Train, validate and export a YOLO segmentation model for eggs.

    Attributes:
        conf: Dataset YAML passed to ``YOLO.train`` as ``data``.
        epochs: Number of training epochs.
        img_size: Training image size, passed as ``imgsz``.
        batch_size: Training batch size, passed as ``batch``.
        device: Device string passed to ultralytics (e.g. ``"cuda"``, ``"cpu"``).
        base_model: Weights/config the YOLO model is initialised from in ``train``.
        model: The ``YOLO`` instance; ``None`` until ``train`` has been called.
    """

    conf: str = field(default="src/egg_segmentation_size/data/data.yaml")
    epochs: int = field(default=100)
    img_size: int = field(default=640)
    batch_size: int = field(default=16)
    device: str = field(default="cuda")
    base_model: str = field(default="yolov8n-seg.pt")
    # Not an __init__ argument: populated by train(). The None default keeps
    # repr() working and lets the other methods detect "not trained yet".
    model: Any = field(init=False, default=None)

    def _require_model(self) -> Any:
        """Return the trained model, raising if ``train`` has not been run yet."""
        if self.model is None:
            raise RuntimeError("The model has not been trained yet; call train() first.")
        return self.model

    def train(self) -> None:
        """Initialise ``self.model`` from ``base_model`` and train it with this instance's settings."""
        logger.info("Start training the YOLO model for egg segmentation.")
        self.model = YOLO(self.base_model)
        self.model.train(
            data=self.conf,
            epochs=self.epochs,
            imgsz=self.img_size,
            batch=self.batch_size,
            device=self.device,
        )

    def validation(self) -> Any:
        """Validate the trained model and return the result of ``model.val()``.

        Raises:
            RuntimeError: if ``train`` has not been called.
        """
        model = self._require_model()
        logger.info("Validating the YOLO model for egg segmentation.")
        return model.val()

    def model_export(self, export_format: str = "onnx") -> Any:
        """Export the trained model.

        Args:
            export_format: ultralytics export format name; defaults to ``"onnx"``.

        Returns:
            Whatever ``model.export`` returns (per the ultralytics docs, the
            path of the exported model).

        Raises:
            RuntimeError: if ``train`` has not been called.
        """
        model = self._require_model()
        logger.info("Exporting the YOLO model for egg segmentation.")
        return model.export(format=export_format)
@dataclass
class EggSegmentorInference:
    """Run a trained YOLO segmentation model on egg images and post-process its results.

    Attributes:
        model_path: Path (``str`` or ``pathlib.Path``) to the trained weights;
            checked for existence in ``__post_init__``.
        result_path: Directory (``project``) under which ultralytics saves the
            annotated outputs; when empty/None nothing is saved.
        scale_factor: Divisor converting fitted ellipse axis lengths from pixels
            to the length unit of the volume formula (presumably pixels per mm,
            given the /1000 cm3 conversion below -- TODO confirm).
    """

    model_path: Optional[Any] = field(default=None)
    result_path: Optional[str] = field(default=None)
    scale_factor: float = field(default=11.61)

    def __post_init__(self) -> None:
        """Validate ``model_path``.

        Raises:
            ValueError: if ``model_path`` is None or does not exist.
        """
        # Wrap in Path so plain string paths are accepted as well as Path objects.
        if self.model_path is None or not Path(self.model_path).exists():
            raise ValueError("Model does not exist or the path is not correct.")

    def load_model(self) -> Any:
        """Load the trained YOLO model for egg segmentation from ``model_path``."""
        logger.info("Loading the trained model for egg segmentation.")
        return YOLO(self.model_path)

    def inference(self, data_path: str) -> Any:
        """Run egg segmentation on ``data_path`` and return the ultralytics results.

        A non-existent local path is only logged, not raised; the model is still
        called (presumably so non-filesystem sources such as URLs keep working --
        confirm). Outputs are saved under ``result_path/detections`` only when
        ``result_path`` is set.
        """
        if not Path(data_path).exists():
            logger.error("Data path does not exist or the path is not correct.")
        model = self.load_model()
        return model(
            data_path,
            save=bool(self.result_path),
            project=self.result_path,
            name="detections",
        )

    @staticmethod
    def _shoelace_area(polygon: Any) -> float:
        """Return the area enclosed by a polygon, via the shoelace formula.

        Args:
            polygon: array-like of shape (N, 2) holding x, y vertex coordinates.

        Returns:
            Absolute area in squared input units; 0.0 for fewer than 3 vertices.
        """
        # Accumulate in float64: callers pass float32 arrays (cv2 requires them),
        # and float32 loses precision when summing large coordinate products.
        points = np.asarray(polygon, dtype=np.float64)
        if len(points) < 3:
            return 0.0
        x, y = points[:, 0], points[:, 1]
        return float(0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1))))

    @staticmethod
    def number_of_eggs(detections: Any) -> Mapping[str, Any]:
        """Count detections per class for every result.

        Returns:
            Mapping from image file name to a list of ``{"class", "count"}``
            dicts, one per detected class (empty list when nothing was found).
        """
        counts = {}
        for result in detections:
            per_class = Counter(int(box.cls.item()) for box in result.boxes)
            counts[str(Path(result.path).name)] = [
                {"class": result.names[class_id], "count": count}
                for class_id, count in per_class.items()
            ]
        return counts

    def _egg_volume(self, polygon: Any, circularity_thr: float = 15) -> float:
        """Estimate an egg's volume from its segmentation polygon.

        An ellipse is fitted to the polygon and its axes are divided by
        ``scale_factor``. If the axes differ by more than ``circularity_thr``
        the egg is treated as a prolate spheroid, otherwise as a sphere whose
        diameter is the mean of the two axes. The result is divided by 1000
        (presumably mm^3 -> cm^3; assumes ``scale_factor`` is pixels per mm).

        Returns:
            The estimated volume, or 0.0 when the polygon has fewer than the
            5 points ``cv2.fitEllipse`` requires.
        """
        points = np.asarray(polygon, dtype=np.float32).reshape((-1, 1, 2))
        # cv2.fitEllipse raises for < 5 points; tiny or empty masks must not
        # abort processing of the whole batch.
        if len(points) < 5:
            return 0.0
        ellipse = cv2.fitEllipse(points)  # pylint: disable=E1101
        # ellipse[1] holds the two full axis lengths; sort so minor <= major.
        minor_axis, major_axis = sorted(axis / self.scale_factor for axis in ellipse[1])
        if (major_axis - minor_axis) > circularity_thr:
            # Prolate spheroid: 4/3 * pi * a * b^2 (a, b semi-axes), then /1000.
            return float(4 * np.pi * (major_axis / 2) * ((minor_axis / 2) ** 2) / 3000)
        # Sphere with radius = mean axis / 2, then /1000.
        return float(4 * np.pi * (((major_axis + minor_axis) / 4) ** 3) / 3000)

    def results_detail(self, detections: Any) -> Mapping[str, Any]:
        """Describe every segmented egg: class, confidence, pixel area and volume.

        Returns:
            Mapping from image file name to a list of dicts with keys
            ``"class"``, ``"confidence"``, ``"areas in pixel"`` and
            ``"volume in cm3"``; images without masks map to an empty list.
        """
        results = {}
        for result in detections:
            eggs = []
            if result.masks is not None:
                boxes = result.boxes
                # masks.xy and boxes are indexed in parallel, one entry per detection.
                for i, mask in enumerate(result.masks.xy):
                    polygon = np.array(mask, dtype=np.float32)
                    eggs.append(
                        {
                            "class": result.names[int(boxes.cls[i].item())],
                            "confidence": boxes.conf[i].item(),
                            "areas in pixel": self._shoelace_area(polygon),
                            "volume in cm3": self._egg_volume(polygon),
                        }
                    )
            results[str(Path(result.path).name)] = eggs
        return results

    @staticmethod
    def result_images(detections: Any) -> List[Any]:
        """Return one annotated image per result with its channel order reversed.

        The reversal presumably converts the BGR array from ``result.plot()``
        to RGB for display -- confirm against the consumer.
        """
        return [np.array(result.plot())[:, :, [2, 1, 0]] for result in detections]