afshin-dini commited on
Commit
1108dca
·
1 Parent(s): a36d904

Add the detector and inference classes

Browse files
src/deep_package_detection/detector.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This is the code for training the YOLO model for package detection."""
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, Optional, Mapping
7
+
8
+ from collections import Counter
9
+ from ultralytics import YOLO # type: ignore
10
+ import cv2
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ @dataclass
16
+ class PackageDetectorTrainer:
17
+ """Class to train YOLO model for package detection."""
18
+
19
+ conf: str = field(default="src/deep_package_detection/data/data.yaml")
20
+ epochs: int = field(default=100)
21
+ img_size: int = field(default=640)
22
+ batch_size: int = field(default=16)
23
+ device: str = field(default="cuda")
24
+ model: Any = field(init=False)
25
+
26
+ def train(self) -> None:
27
+ """Train the YOLO model for package detection."""
28
+ logger.info("Start training the YOLO model for package detection.")
29
+ self.model = YOLO("yolov8x-seg.pt")
30
+ self.model.train(
31
+ data=self.conf,
32
+ epochs=self.epochs,
33
+ imgsz=self.img_size,
34
+ batch=self.batch_size,
35
+ device=self.device,
36
+ )
37
+
38
+ def validation(self) -> Any:
39
+ """Validate the YOLO model for package detection."""
40
+ logger.info("Validating the YOLO model for package detection.")
41
+ return self.model.val()
42
+
43
+ def model_export(self) -> None:
44
+ """Export the YOLO model for package detection."""
45
+ logger.info("Exporting the YOLO model for package detection.")
46
+ self.model.export(format="onnx")
47
+
48
+
49
+ @dataclass
50
+ class PackageDetectorInference:
51
+ """Class to test package detection using a trained YOLO model."""
52
+
53
+ model_path: Optional[Any] = field(default=None)
54
+ result_path: Optional[str] = field(default=None)
55
+ confidence_threshold: float = field(default=0.6)
56
+
57
+ def __post_init__(self) -> None:
58
+ """Post-initialization method for PackageDetectorInference."""
59
+ if self.model_path is None or not self.model_path.exists():
60
+ raise ValueError("Model does not exist or the path is not correct.")
61
+
62
+ def load_model(self) -> Any:
63
+ """Load the YOLO model for package detection."""
64
+ logger.info("Loading the trained model for package detection.")
65
+ return YOLO(self.model_path) # type: ignore
66
+
67
+ def inference(self, data_path: str) -> Any:
68
+ """Inference code for egg detection"""
69
+ if not Path(data_path).exists():
70
+ logger.error("Data path does not exist or the path is not correct.")
71
+ model = self.load_model()
72
+ results = model(
73
+ data_path,
74
+ save=False,
75
+ project=self.result_path,
76
+ name="detections",
77
+ )
78
+ return results
79
+
80
+ def count_packages(self, detections: Any) -> Mapping[str, Any]:
81
+ """Count the number of packages detected."""
82
+ counts = {}
83
+ for result in detections:
84
+ class_count = Counter(
85
+ int(box.cls.item())
86
+ for box in result.boxes
87
+ if box.conf.item() > self.confidence_threshold
88
+ )
89
+ temp = []
90
+ for name, count in class_count.items():
91
+ temp.append({"class": result.names[name], "count": count})
92
+ file_name = Path(result.path).name
93
+ counts[str(file_name)] = temp
94
+ return counts
95
+
96
+ def plot_and_save_results(self, detections: Any) -> None:
97
+ """Plot and save images with only high-confidence detected objects."""
98
+ if self.result_path is None:
99
+ return
100
+ output_dir = Path(self.result_path)
101
+ output_dir.mkdir(parents=True, exist_ok=True)
102
+
103
+ logger.info("Saving high-confidence detection images to: %s", output_dir)
104
+
105
+ for result in detections:
106
+ # Read original image
107
+ img = cv2.imread(str(result.path))
108
+ if img is None:
109
+ logger.warning("Could not read image: %s", result.path)
110
+ continue
111
+
112
+ # Iterate through boxes and draw only high-confidence detections
113
+ for box in result.boxes:
114
+ conf = float(box.conf.item())
115
+ if conf < self.confidence_threshold:
116
+ continue
117
+
118
+ cls_id = int(box.cls.item())
119
+ label = f"{result.names[cls_id]} {conf:.2f}"
120
+ x1, y1, x2, y2 = map(int, box.xyxy[0])
121
+
122
+ # Draw rectangle and label
123
+ cv2.rectangle( # pylint: disable=E1101
124
+ img, (x1, y1), (x2, y2), (0, 255, 0), 2
125
+ )
126
+ cv2.putText( # pylint: disable=E1101
127
+ img,
128
+ label,
129
+ (x1, max(y1 - 10, 0)),
130
+ cv2.FONT_HERSHEY_SIMPLEX, # pylint: disable=E1101
131
+ 0.6,
132
+ (0, 255, 0),
133
+ 1,
134
+ cv2.LINE_AA, # pylint: disable=E1101
135
+ )
136
+
137
+ # Skip saving if no boxes above threshold
138
+ if not any(
139
+ box.conf.item() > self.confidence_threshold for box in result.boxes
140
+ ):
141
+ continue
142
+
143
+ # Save the result
144
+ output_path = output_dir / f"{Path(result.path).stem}_detections.jpg"
145
+ cv2.imwrite(str(output_path), img)
146
+ logger.info("Saved high-confidence detections to %s", output_path)