File size: 13,953 Bytes
b3f968a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
"""
Improved Beverage Detection Miner
Goal: Beat 5.9% baseline and reach 90% target score

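Key improvements over the baseline:
1. Letterbox preprocessing with [0, 1] normalization and optional CLAHE
   contrast enhancement
2. Per-class confidence thresholds (still to be tuned on a validation set)
3. Class-aware NMS with per-class IoU thresholds
4. Optional horizontal-flip test-time augmentation (disabled by default)
5. Post-processing filters on box area and aspect ratio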
"""

from pathlib import Path
import math
from typing import Optional

import cv2
import numpy as np
import onnxruntime as ort
from numpy import ndarray
from pydantic import BaseModel


class BoundingBox(BaseModel):
    """A single detection: integer box corners, class index and confidence.

    As produced by ``Miner``, coordinates are pixel positions in the original
    frame, clipped to the image bounds, and ``conf`` is clamped to [0, 1].
    """

    x1: int  # left edge (floored)
    y1: int  # top edge (floored)
    x2: int  # right edge (ceiled)
    y2: int  # bottom edge (ceiled)
    cls_id: int  # index into Miner.class_names (e.g. 0 = 'bottle')
    conf: float  # detection confidence


class TVFrameResult(BaseModel):
    """Detections for one frame of a batch, as returned by ``Miner.predict_batch``."""

    frame_id: int  # global frame index: batch offset + position within the batch
    boxes: list[BoundingBox]  # detections for this frame (may be empty)
    keypoints: list[tuple[int, int]]  # Miner fills these with (0, 0) placeholders


class Miner:
    """
    Enhanced beverage detection miner with improved accuracy.

    Loads ``weights.onnx`` from a model directory, runs it on BGR frames and
    converts the raw predictions into ``TVFrameResult`` objects.

    NOTE(review): the decoding below treats each prediction row as
    ``[cx, cy, w, h, score_cls0, score_cls1, ...]`` in network-input pixel
    coordinates, with no separate objectness column. Confirm against the
    model export.
    """

    # Fallback spatial size for models exported with dynamic (symbolic) H/W.
    # NOTE(review): 640 is an assumption (common YOLO export size) - confirm.
    DEFAULT_INPUT_SIZE = 640

    def __init__(self, path_hf_repo: Path) -> None:
        """
        Load the ONNX model and configure thresholds and post-processing filters.

        Args:
            path_hf_repo: Directory containing ``weights.onnx``.
        """
        self.path_hf_repo = path_hf_repo
        self.class_names = ['bottle', 'can', 'cup']

        # Initialize ONNX session with optimizations
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.intra_op_num_threads = 4
        sess_options.inter_op_num_threads = 4

        self.session = ort.InferenceSession(
            str(path_hf_repo / "weights.onnx"),
            sess_options=sess_options,
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        )

        model_input = self.session.get_inputs()[0]
        self.input_name = model_input.name
        input_shape = model_input.shape

        # Expected [N, C, H, W]. ONNX Runtime reports symbolic dims as str/None,
        # which int() rejects, so those fall back to DEFAULT_INPUT_SIZE.
        self.input_h = self._resolve_dim(input_shape[2])
        self.input_w = self._resolve_dim(input_shape[3])

        # Class-specific confidence thresholds.
        # These should be tuned based on validation set performance.
        self.class_conf_thresholds = {
            0: 0.28,  # bottle
            1: 0.25,  # can
            2: 0.30,  # cup
        }

        # Used for known classes missing from class_conf_thresholds.
        self.conf_threshold = 0.25

        # Class-specific IoU thresholds for NMS
        self.class_iou_thresholds = {
            0: 0.45,  # bottle
            1: 0.40,  # can - allow more overlap (cans pack together)
            2: 0.45,  # cup
        }

        # Used for classes missing from class_iou_thresholds.
        self.iou_threshold = 0.45

        # Horizontal-flip test-time augmentation (roughly doubles inference cost).
        self.enable_tta = False

        # Minimum box area filter (remove tiny detections), in pixels squared.
        self.min_box_area = 100

        # Maximum box area as a fraction of the image area.
        self.max_box_area_ratio = 0.8

        # Allowed width/height ratio range for a detection.
        self.min_aspect_ratio = 0.2
        self.max_aspect_ratio = 4.0

    def __repr__(self) -> str:
        return (
            f"Enhanced ONNX Beverage Miner\n"
            f"Session: {type(self.session).__name__}\n"
            f"Classes: {self.class_names}\n"
            f"Input Size: {self.input_w}x{self.input_h}\n"
            f"TTA Enabled: {self.enable_tta}"
        )

    @classmethod
    def _resolve_dim(cls, dim: object) -> int:
        """Return a static ONNX dimension as int, or DEFAULT_INPUT_SIZE if it is symbolic."""
        try:
            value = int(dim)  # type: ignore[arg-type]
        except (TypeError, ValueError):
            return cls.DEFAULT_INPUT_SIZE
        return value if value > 0 else cls.DEFAULT_INPUT_SIZE

    @staticmethod
    def _letterbox_geometry(
        src_w: int, src_h: int, target_w: int, target_h: int
    ) -> tuple[int, int, int, int]:
        """
        Compute how an image of ``src_w x src_h`` is letterboxed into the target.

        Returns:
            ``(new_w, new_h, pad_w, pad_h)``: the resized content size and the
            left/top padding that centers it inside ``target_w x target_h``.
        """
        scale = min(target_w / src_w, target_h / src_h)
        # round() avoids losing a pixel to float error (e.g. 1920 * 640/1920);
        # the clamps keep the content non-empty and inside the target.
        new_w = min(target_w, max(1, round(src_w * scale)))
        new_h = min(target_h, max(1, round(src_h * scale)))
        pad_w = (target_w - new_w) // 2
        pad_h = (target_h - new_h) // 2
        return new_w, new_h, pad_w, pad_h

    def _preprocess(self, image_bgr: ndarray, apply_clahe: bool = False) -> tuple[np.ndarray, tuple[int, int]]:
        """
        Convert a BGR frame into the model's NCHW float32 input.

        Args:
            image_bgr: Frame in OpenCV BGR channel order.
            apply_clahe: If True, apply CLAHE to the L channel in LAB space
                first, to boost local contrast.

        Returns:
            ``(tensor, (orig_h, orig_w))``. The tensor has shape
            ``(1, 3, input_h, input_w)`` with RGB values scaled to [0, 1].
        """
        h, w = image_bgr.shape[:2]

        if apply_clahe:
            lab = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2LAB)
            l, a, b = cv2.split(lab)
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            l = clahe.apply(l)
            lab = cv2.merge([l, a, b])
            image_bgr = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)

        rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

        # Letterbox keeps the aspect ratio; _infer_single inverts the same geometry.
        resized = self._letterbox_resize(rgb, (self.input_w, self.input_h))

        x = resized.astype(np.float32) / 255.0

        # HWC -> NCHW. The transpose yields a strided view, so make it
        # contiguous before handing it to ONNX Runtime.
        x = np.ascontiguousarray(np.transpose(x, (2, 0, 1))[None, ...])

        return x, (h, w)

    def _letterbox_resize(self, image: ndarray, target_size: tuple[int, int]) -> ndarray:
        """
        Resize ``image`` into ``target_size`` (w, h), preserving aspect ratio.

        The content is centered, and the border is filled with gray (114). The
        geometry comes from ``_letterbox_geometry`` so it can be inverted exactly.
        """
        target_w, target_h = target_size
        h, w = image.shape[:2]

        new_w, new_h, pad_w, pad_h = self._letterbox_geometry(w, h, target_w, target_h)

        resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

        padded = np.full((target_h, target_w, 3), 114, dtype=np.uint8)
        padded[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = resized

        return padded

    def _normalize_predictions(self, raw: np.ndarray) -> np.ndarray:
        """
        Take the first batch item of ``raw`` and return it as ``[N, C]`` rows.

        The tensor is transposed when it has fewer rows than columns; this
        assumes there are more candidate boxes than per-box values.

        Raises:
            ValueError: if the first batch item is not 2-D.
        """
        pred = raw[0]
        if pred.ndim != 2:
            raise ValueError(f"Unexpected prediction shape: {raw.shape}")

        if pred.shape[0] < pred.shape[1]:
            pred = pred.transpose(1, 0)

        return pred

    def _nms_class_aware(
        self,
        dets: list[tuple[float, float, float, float, float, int]]
    ) -> list[tuple[float, float, float, float, float, int]]:
        """
        Run NMS independently per class, each with its own IoU threshold.

        Boxes of different classes never suppress each other.
        """
        if not dets:
            return []

        class_dets: dict[int, list[tuple[float, float, float, float, float, int]]] = {}
        for det in dets:
            class_dets.setdefault(det[5], []).append(det)

        final_dets: list[tuple[float, float, float, float, float, int]] = []
        for cls_id, cls_boxes in class_dets.items():
            iou_thresh = self.class_iou_thresholds.get(cls_id, self.iou_threshold)
            final_dets.extend(self._nms(cls_boxes, iou_thresh))

        return final_dets

    def _nms(
        self,
        dets: list[tuple[float, float, float, float, float, int]],
        iou_threshold: Optional[float] = None
    ) -> list[tuple[float, float, float, float, float, int]]:
        """
        Greedy non-maximum suppression on ``(x1, y1, x2, y2, conf, cls)`` tuples.

        Boxes are visited in descending confidence order. Any remaining box
        whose IoU with the kept box exceeds ``iou_threshold`` (default
        ``self.iou_threshold``) is discarded.
        """
        if not dets:
            return []

        if iou_threshold is None:
            iou_threshold = self.iou_threshold

        boxes = np.array([d[:4] for d in dets], dtype=np.float32)
        scores = np.array([d[4] for d in dets], dtype=np.float32)
        order = scores.argsort()[::-1]
        keep = []

        while order.size > 0:
            i = order[0]
            keep.append(i)

            if order.size == 1:
                break

            rest = order[1:]
            xx1 = np.maximum(boxes[i, 0], boxes[rest, 0])
            yy1 = np.maximum(boxes[i, 1], boxes[rest, 1])
            xx2 = np.minimum(boxes[i, 2], boxes[rest, 2])
            yy2 = np.minimum(boxes[i, 3], boxes[rest, 3])

            inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)

            area_i = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
            area_rest = (boxes[rest, 2] - boxes[rest, 0]) * (boxes[rest, 3] - boxes[rest, 1])
            # Floor the union to avoid dividing by zero on degenerate boxes.
            union = np.maximum(area_i + area_rest - inter, 1e-6)
            iou = inter / union

            order = rest[iou <= iou_threshold]

        return [dets[idx] for idx in keep]

    def _filter_boxes(
        self,
        boxes: list[tuple[float, float, float, float, float, int]],
        orig_w: int,
        orig_h: int
    ) -> list[tuple[float, float, float, float, float, int]]:
        """
        Drop detections with implausible size or shape.

        A box is removed when its area is below ``min_box_area``, or above
        ``max_box_area_ratio`` of the image area. It is also removed when its
        width/height ratio falls outside
        ``[min_aspect_ratio, max_aspect_ratio]``.
        """
        max_area = orig_w * orig_h * self.max_box_area_ratio

        filtered = []
        for x1, y1, x2, y2, conf, cls_id in boxes:
            width = x2 - x1
            height = y2 - y1
            area = width * height

            if area < self.min_box_area or area > max_area:
                continue

            # max(height, 1) guards against division by a (near-)zero height.
            aspect_ratio = width / max(height, 1)
            if aspect_ratio < self.min_aspect_ratio or aspect_ratio > self.max_aspect_ratio:
                continue

            filtered.append((x1, y1, x2, y2, conf, cls_id))

        return filtered

    def _infer_single(self, image_bgr: ndarray) -> list[BoundingBox]:
        """
        Run the model on one BGR image and decode its detections.

        The pipeline is:
        1. Letterbox preprocessing and ONNX forward pass.
        2. Per-class confidence threshold.
        3. Undo the letterbox to get original-image pixels.
        4. Area and aspect-ratio filter.
        5. Class-aware NMS.
        6. Clip to the image and convert to ``BoundingBox``.

        Returns:
            Detections in original-image pixel coordinates. Boxes that collapse
            to zero width or height after clipping are dropped.
        """
        inp, (orig_h, orig_w) = self._preprocess(image_bgr)
        out = self.session.run(None, {self.input_name: inp})[0]
        pred = self._normalize_predictions(out)

        # Need 4 box columns plus at least one class-score column.
        if pred.shape[1] < 5:
            return []

        boxes = pred[:, :4].astype(np.float64)
        cls_scores = pred[:, 4:]
        cls_ids = np.argmax(cls_scores, axis=1)
        confs = np.max(cls_scores, axis=1)

        # Per-class confidence thresholds, looked up by each row's best class.
        # Only ids covered by class_names can pass. Extra model classes get an
        # unreachable +inf threshold and are discarded.
        n_scores = cls_scores.shape[1]
        thresholds = np.full(n_scores, np.inf, dtype=confs.dtype)
        for cls_id in range(min(n_scores, len(self.class_names))):
            thresholds[cls_id] = self.class_conf_thresholds.get(cls_id, self.conf_threshold)
        keep = confs >= thresholds[cls_ids]

        boxes = boxes[keep]
        confs = confs[keep]
        cls_ids = cls_ids[keep]

        if boxes.shape[0] == 0:
            return []

        # Undo the letterbox: remove the centering pad, then divide by the
        # per-axis resize gain. A plain orig/input rescale is wrong whenever
        # the frame's aspect ratio differs from the network input's.
        new_w, new_h, pad_w, pad_h = self._letterbox_geometry(
            orig_w, orig_h, self.input_w, self.input_h
        )
        gain_x = new_w / float(orig_w)
        gain_y = new_h / float(orig_h)

        cx, cy, bw, bh = boxes.T
        x1s = (cx - bw / 2.0 - pad_w) / gain_x
        y1s = (cy - bh / 2.0 - pad_h) / gain_y
        x2s = (cx + bw / 2.0 - pad_w) / gain_x
        y2s = (cy + bh / 2.0 - pad_h) / gain_y

        dets: list[tuple[float, float, float, float, float, int]] = list(zip(
            x1s.tolist(), y1s.tolist(), x2s.tolist(), y2s.tolist(),
            confs.tolist(), cls_ids.tolist(),
        ))

        dets = self._filter_boxes(dets, orig_w, orig_h)
        dets = self._nms_class_aware(dets)

        out_boxes: list[BoundingBox] = []
        for x1, y1, x2, y2, conf, cls_id in dets:
            # Floor/ceil so the integer box fully covers the float box, then clip.
            ix1 = max(0, min(orig_w, math.floor(x1)))
            iy1 = max(0, min(orig_h, math.floor(y1)))
            ix2 = max(0, min(orig_w, math.ceil(x2)))
            iy2 = max(0, min(orig_h, math.ceil(y2)))

            # A box lying entirely outside the frame clips to zero size.
            if ix2 <= ix1 or iy2 <= iy1:
                continue

            out_boxes.append(
                BoundingBox(
                    x1=ix1,
                    y1=iy1,
                    x2=ix2,
                    y2=iy2,
                    cls_id=cls_id,
                    conf=max(0.0, min(1.0, conf)),
                )
            )

        return out_boxes

    def _infer_with_tta(self, image_bgr: ndarray) -> list[BoundingBox]:
        """
        Horizontal-flip test-time augmentation.

        Runs inference on the original and the mirrored frame, maps the
        mirrored boxes back, and merges both sets with class-aware NMS.
        """
        boxes_orig = self._infer_single(image_bgr)

        image_flip = cv2.flip(image_bgr, 1)
        boxes_flip = self._infer_single(image_flip)

        # Mirror x back: the flipped left edge becomes the original right edge.
        w = image_bgr.shape[1]
        for box in boxes_flip:
            box.x1, box.x2 = w - box.x2, w - box.x1

        all_dets = [
            (float(box.x1), float(box.y1), float(box.x2), float(box.y2), float(box.conf), int(box.cls_id))
            for box in boxes_orig + boxes_flip
        ]

        final_dets = self._nms_class_aware(all_dets)

        return [
            BoundingBox(
                x1=int(x1), y1=int(y1),
                x2=int(x2), y2=int(y2),
                cls_id=cls_id, conf=conf
            )
            for x1, y1, x2, y2, conf, cls_id in final_dets
        ]

    def predict_batch(
        self,
        batch_images: list[ndarray],
        offset: int,
        n_keypoints: int,
    ) -> list[TVFrameResult]:
        """
        Predict on a batch of BGR images.

        Args:
            batch_images: Frames to process, in order.
            offset: Frame id assigned to the first image; later images get
                ``offset + index``.
            n_keypoints: Number of ``(0, 0)`` placeholder keypoints to attach
                to each result (negative values are treated as 0).

        Returns:
            One ``TVFrameResult`` per input image.
        """
        results: list[TVFrameResult] = []

        for idx, image in enumerate(batch_images):
            if self.enable_tta:
                boxes = self._infer_with_tta(image)
            else:
                boxes = self._infer_single(image)

            # No keypoints for this task
            keypoints = [(0, 0) for _ in range(max(0, int(n_keypoints)))]

            results.append(
                TVFrameResult(
                    frame_id=offset + idx,
                    boxes=boxes,
                    keypoints=keypoints,
                )
            )

        return results