File size: 6,679 Bytes
1108dca bc55a23 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
"""This is the code for training the YOLO model for package detection."""
import logging
from pathlib import Path
from dataclasses import dataclass, field
from typing import Any, Optional, Mapping
from collections import Counter
from ultralytics import YOLO # type: ignore
import cv2
# Module-level logger named after this module (``__name__``); every class
# below reports progress and problems through it.
logger = logging.getLogger(__name__)
@dataclass
class PackageDetectorTrainer:
    """Train, validate and export a YOLO model for package detection.

    Typical use is ``train()`` followed by ``validation()`` and/or
    ``model_export()``. The latter two need a model, which is normally
    created by ``train()`` (assigning ``model`` manually also works).
    """

    # Dataset description file, forwarded to ``model.train`` as ``data``.
    conf: str = field(default="src/deep_package_detection/data/data.yaml")
    # Number of epochs, forwarded to ``model.train`` as ``epochs``.
    epochs: int = field(default=100)
    # Image size, forwarded to ``model.train`` as ``imgsz``.
    img_size: int = field(default=640)
    # Batch size, forwarded to ``model.train`` as ``batch``.
    batch_size: int = field(default=16)
    # Device specifier, forwarded to ``model.train`` as ``device``.
    device: str = field(default="cuda")
    # The YOLO model instance; ``None`` until ``train()`` has been called.
    # A default is required so that ``repr()``/``==`` on a fresh instance do
    # not fail on a missing attribute; it is left out of ``repr`` to keep the
    # trainer's representation short.
    model: Any = field(init=False, default=None, repr=False)
    # Weights/checkpoint name handed to ``YOLO(...)`` when training starts.
    # Declared last so existing positional construction keeps working.
    weights: str = field(default="yolov8x-seg.pt")

    def _require_model(self) -> Any:
        """Return the current model or fail with an explicit error.

        Raises:
            RuntimeError: If no model has been created yet.
        """
        if self.model is None:
            raise RuntimeError(
                "No model available: call train() before validation or export."
            )
        return self.model

    def train(self) -> None:
        """Create a YOLO model from ``weights`` and train it on ``conf``."""
        logger.info("Start training the YOLO model for package detection.")
        self.model = YOLO(self.weights)
        self.model.train(
            data=self.conf,
            epochs=self.epochs,
            imgsz=self.img_size,
            batch=self.batch_size,
            device=self.device,
        )

    def validation(self) -> Any:
        """Validate the current model.

        Returns:
            Whatever ``model.val()`` returns.

        Raises:
            RuntimeError: If no model has been created yet.
        """
        # Check first so the failure is explicit rather than an AttributeError.
        model = self._require_model()
        logger.info("Validating the YOLO model for package detection.")
        return model.val()

    def model_export(self, export_format: str = "onnx") -> None:
        """Export the current model.

        Args:
            export_format: Target format forwarded to ``model.export`` as
                ``format``. Defaults to ``"onnx"``.

        Raises:
            RuntimeError: If no model has been created yet.
        """
        model = self._require_model()
        logger.info("Exporting the YOLO model for package detection.")
        model.export(format=export_format)
@dataclass
class PackageDetectorInference:
    """Run a trained YOLO model for package detection and post-process results.

    A detection is treated as "high-confidence" when its confidence is greater
    than or equal to ``confidence_threshold``; the same rule is applied when
    counting, drawing and deciding whether to save an image.
    """

    # Location of the trained model; a ``str`` or a ``pathlib.Path``.
    model_path: Optional[Any] = field(default=None)
    # Directory for annotated images; ``None`` disables saving.
    result_path: Optional[str] = field(default=None)
    # Inclusive lower bound on detection confidence.
    confidence_threshold: float = field(default=0.6)

    def __post_init__(self) -> None:
        """Validate ``model_path`` right after construction.

        Raises:
            ValueError: If ``model_path`` is ``None`` or does not exist.
        """
        # Accept plain string paths as well: they have no ``.exists()``.
        if isinstance(self.model_path, str):
            self.model_path = Path(self.model_path)
        if self.model_path is None or not self.model_path.exists():
            raise ValueError("Model does not exist or the path is not correct.")

    def load_model(self) -> Any:
        """Load the YOLO model stored at ``model_path`` and return it."""
        logger.info("Loading the trained model for package detection.")
        return YOLO(self.model_path)  # type: ignore

    def inference(self, data_path: str) -> Any:
        """Run package detection on ``data_path``.

        Args:
            data_path: Input source handed to the model.

        Returns:
            The results object produced by calling the loaded model.
        """
        if not Path(data_path).exists():
            # NOTE(review): execution continues after logging, as in the
            # original code -- presumably so sources that are not local
            # filesystem paths still reach the model; confirm, otherwise
            # raise here.
            logger.error("Data path does not exist or the path is not correct.")
        model = self.load_model()
        results = model(
            data_path,
            save=False,
            project=self.result_path,
            name="detections",
        )
        return results

    def _confident_boxes(self, result: Any) -> list:
        """Return ``(box, confidence)`` pairs at or above the threshold."""
        selected = []
        for box in result.boxes:
            conf = float(box.conf.item())
            if conf >= self.confidence_threshold:
                selected.append((box, conf))
        return selected

    def _draw_detections(self, img: Any, result: Any) -> int:
        """Draw high-confidence boxes and labels onto ``img`` in place.

        Returns:
            The number of boxes drawn.
        """
        confident = self._confident_boxes(result)
        for box, conf in confident:
            cls_id = int(box.cls.item())
            label = f"{result.names[cls_id]} {conf:.2f}"
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            cv2.rectangle(  # pylint: disable=E1101
                img, (x1, y1), (x2, y2), (0, 255, 0), 2
            )
            cv2.putText(  # pylint: disable=E1101
                img,
                label,
                # Keep the label inside the image when the box touches the top.
                (x1, max(y1 - 10, 0)),
                cv2.FONT_HERSHEY_SIMPLEX,  # pylint: disable=E1101
                0.6,
                (0, 255, 0),
                1,
                cv2.LINE_AA,  # pylint: disable=E1101
            )
        return len(confident)

    def count_packages(self, detections: Any) -> Mapping[str, Any]:
        """Count high-confidence detections per class for every result.

        Args:
            detections: Iterable of result objects exposing ``boxes``,
                ``names`` and ``path``.

        Returns:
            A mapping from the image file name (``Path(result.path).name``) to
            a list of ``{"class": <name>, "count": <int>}`` entries. Results
            sharing a file name overwrite each other; the last one wins.
        """
        counts = {}
        for result in detections:
            class_count = Counter(
                int(box.cls.item()) for box, _ in self._confident_boxes(result)
            )
            counts[Path(result.path).name] = [
                {"class": result.names[cls_id], "count": count}
                for cls_id, count in class_count.items()
            ]
        return counts

    def plot_and_save_results(self, detections: Any) -> None:
        """Save annotated copies of images that have high-confidence detections.

        Does nothing when ``result_path`` is ``None``. Images that cannot be
        read, or that have no detection at or above the threshold, are skipped.
        Output files are named ``<original stem>_detections.jpg`` inside
        ``result_path``, which is created if missing.
        """
        if self.result_path is None:
            return
        output_dir = Path(self.result_path)
        output_dir.mkdir(parents=True, exist_ok=True)
        logger.info("Saving high-confidence detection images to: %s", output_dir)
        for result in detections:
            img = cv2.imread(str(result.path))
            if img is None:
                logger.warning("Could not read image: %s", result.path)
                continue
            # Nothing drawn means nothing worth saving.
            if self._draw_detections(img, result) == 0:
                continue
            output_path = output_dir / f"{Path(result.path).stem}_detections.jpg"
            # ``imwrite`` reports failure through its return value.
            if not cv2.imwrite(str(output_path), img):
                logger.warning("Could not write image: %s", output_path)
                continue
            logger.info("Saved high-confidence detections to %s", output_path)

    def single_inference(self, detections: Any) -> Optional[Any]:
        """Return the first result's image with high-confidence boxes drawn.

        The image is read from ``result.path`` and converted from BGR to RGB
        before drawing.

        Returns:
            The annotated image, or ``None`` when ``detections`` is empty or
            the image cannot be read.
        """
        if not detections:
            logger.warning("No detections available for single inference.")
            return None
        result = detections[0]
        img = cv2.imread(str(result.path))
        if img is None:
            logger.warning("Could not read image: %s", result.path)
            return None
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # pylint: disable=E1101
        self._draw_detections(img, result)
        return img
|