nexu02 commited on
Commit
98903c8
·
verified ·
1 Parent(s): b1f40fb

R11 backup: miner.py

Browse files
Files changed (1) hide show
  1. miner.py +501 -0
miner.py ADDED
@@ -0,0 +1,501 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Build: 2026-05-29 23:30 UTC R11 redeploy (force new revision)
2
+ from pathlib import Path
3
+ import math
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import onnxruntime as ort
8
+ from numpy import ndarray
9
+ from pydantic import BaseModel
10
+
11
+
12
+ class BoundingBox(BaseModel):
13
+ x1: int
14
+ y1: int
15
+ x2: int
16
+ y2: int
17
+ cls_id: int
18
+ conf: float
19
+
20
+
21
+ class TVFrameResult(BaseModel):
22
+ frame_id: int
23
+ boxes: list[BoundingBox]
24
+ keypoints: list[tuple[int, int]]
25
+
26
+
27
+ class Miner:
28
+ """ONNX Runtime miner. Hard global NMS + sanity filter + dedup + flip TTA, with per-class rescue bonus."""
29
+
30
+ class_names = ["cup", "bottle", "can"]
31
+ input_size = 1280
32
+ iou_thres = 0.4
33
+ cross_iou_thresh = 0.7
34
+ min_side = 8.0
35
+ min_box_area = 100.0
36
+ max_aspect_ratio = 10.0
37
+ max_det = 300
38
+ _conf_thres_array = np.array([0.6, 0.45, 0.5], dtype=np.float32)
39
+ _bonus_array = np.array([0.0, 0.0, 0.2], dtype=np.float32)
40
+
41
+ def __init__(self, path_hf_repo: Path) -> None:
42
+ model_path = path_hf_repo / "weights.onnx"
43
+ print("ORT version:", ort.__version__)
44
+
45
+ try:
46
+ ort.preload_dlls()
47
+ print("preload_dlls success")
48
+ except Exception as e:
49
+ print(f"preload_dlls failed: {e}")
50
+
51
+ print("ORT available providers BEFORE session:", ort.get_available_providers())
52
+
53
+ sess_options = ort.SessionOptions()
54
+ sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
55
+
56
+ try:
57
+ self.session = ort.InferenceSession(
58
+ str(model_path),
59
+ sess_options=sess_options,
60
+ providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
61
+ )
62
+ print("Created ORT session with preferred CUDA provider list")
63
+ except Exception as e:
64
+ print(f"CUDA session creation failed, falling back to CPU: {e}")
65
+ self.session = ort.InferenceSession(
66
+ str(model_path),
67
+ sess_options=sess_options,
68
+ providers=["CPUExecutionProvider"],
69
+ )
70
+
71
+ print("ORT session providers:", self.session.get_providers())
72
+
73
+ for inp in self.session.get_inputs():
74
+ print("INPUT:", inp.name, inp.shape, inp.type)
75
+ for out in self.session.get_outputs():
76
+ print("OUTPUT:", out.name, out.shape, out.type)
77
+
78
+ self.input_name = self.session.get_inputs()[0].name
79
+ self.output_names = [output.name for output in self.session.get_outputs()]
80
+ self.input_shape = self.session.get_inputs()[0].shape
81
+
82
+ self.input_height = self._safe_dim(self.input_shape[2], default=self.input_size)
83
+ self.input_width = self._safe_dim(self.input_shape[3], default=self.input_size)
84
+
85
+ print(f"ONNX model loaded from: {model_path}")
86
+ print(f"ONNX providers: {self.session.get_providers()}")
87
+ print(f"ONNX input: name={self.input_name}, shape={self.input_shape}")
88
+
89
+ def __repr__(self) -> str:
90
+ return (
91
+ f"ONNXRuntime(session={type(self.session).__name__}, "
92
+ f"providers={self.session.get_providers()})"
93
+ )
94
+
95
+ @staticmethod
96
+ def _safe_dim(value, default: int) -> int:
97
+ return value if isinstance(value, int) and value > 0 else default
98
+
99
+ def _letterbox(self, image: ndarray, new_shape: tuple[int, int],
100
+ color=(114, 114, 114)
101
+ ) -> tuple[ndarray, float, tuple[float, float]]:
102
+ h, w = image.shape[:2]
103
+ new_w, new_h = new_shape
104
+ ratio = min(new_w / w, new_h / h)
105
+ resized_w = int(round(w * ratio))
106
+ resized_h = int(round(h * ratio))
107
+ if (resized_w, resized_h) != (w, h):
108
+ interp = cv2.INTER_CUBIC if ratio > 1.0 else cv2.INTER_LINEAR
109
+ image = cv2.resize(image, (resized_w, resized_h), interpolation=interp)
110
+ dw = (new_w - resized_w) / 2.0
111
+ dh = (new_h - resized_h) / 2.0
112
+ left = int(round(dw - 0.1))
113
+ right = int(round(dw + 0.1))
114
+ top = int(round(dh - 0.1))
115
+ bottom = int(round(dh + 0.1))
116
+ padded = cv2.copyMakeBorder(image, top, bottom, left, right,
117
+ borderType=cv2.BORDER_CONSTANT, value=color)
118
+ return padded, ratio, (dw, dh)
119
+
120
+ def _preprocess(self, image: ndarray
121
+ ) -> tuple[np.ndarray, float, tuple[float, float],
122
+ tuple[int, int]]:
123
+ orig_h, orig_w = image.shape[:2]
124
+ img, ratio, pad = self._letterbox(image, (self.input_width, self.input_height))
125
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
126
+ img = img.astype(np.float32) / 255.0
127
+ img = np.transpose(img, (2, 0, 1))[None, ...]
128
+ img = np.ascontiguousarray(img, dtype=np.float32)
129
+ return img, ratio, pad, (orig_w, orig_h)
130
+
131
+ @staticmethod
132
+ def _clip_boxes(boxes: np.ndarray, image_size: tuple[int, int]) -> np.ndarray:
133
+ w, h = image_size
134
+ boxes[:, 0] = np.clip(boxes[:, 0], 0, w - 1)
135
+ boxes[:, 1] = np.clip(boxes[:, 1], 0, h - 1)
136
+ boxes[:, 2] = np.clip(boxes[:, 2], 0, w - 1)
137
+ boxes[:, 3] = np.clip(boxes[:, 3], 0, h - 1)
138
+ return boxes
139
+
140
+ @staticmethod
141
+ def _xywh_to_xyxy(boxes: np.ndarray) -> np.ndarray:
142
+ out = np.empty_like(boxes)
143
+ out[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.0
144
+ out[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.0
145
+ out[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.0
146
+ out[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.0
147
+ return out
148
+
149
+ @staticmethod
150
+ def _hard_nms(boxes: np.ndarray, scores: np.ndarray,
151
+ iou_thresh: float) -> np.ndarray:
152
+ n = len(boxes)
153
+ if n == 0:
154
+ return np.array([], dtype=np.intp)
155
+ order = np.argsort(-scores)
156
+ keep: list[int] = []
157
+ while len(order) > 0:
158
+ i = int(order[0])
159
+ keep.append(i)
160
+ if len(order) == 1:
161
+ break
162
+ rest = order[1:]
163
+ xx1 = np.maximum(boxes[i, 0], boxes[rest, 0])
164
+ yy1 = np.maximum(boxes[i, 1], boxes[rest, 1])
165
+ xx2 = np.minimum(boxes[i, 2], boxes[rest, 2])
166
+ yy2 = np.minimum(boxes[i, 3], boxes[rest, 3])
167
+ inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
168
+ a_i = (max(0.0, boxes[i, 2] - boxes[i, 0]) *
169
+ max(0.0, boxes[i, 3] - boxes[i, 1]))
170
+ a_r = (np.maximum(0.0, boxes[rest, 2] - boxes[rest, 0]) *
171
+ np.maximum(0.0, boxes[rest, 3] - boxes[rest, 1]))
172
+ iou = inter / (a_i + a_r - inter + 1e-7)
173
+ order = rest[iou <= iou_thresh]
174
+ return np.array(keep, dtype=np.intp)
175
+
176
+ def _per_class_hard_nms(self, boxes: np.ndarray, scores: np.ndarray,
177
+ cls_ids: np.ndarray, iou_thresh: float
178
+ ) -> np.ndarray:
179
+ if len(boxes) == 0:
180
+ return np.array([], dtype=np.intp)
181
+ all_keep: list[int] = []
182
+ for c in np.unique(cls_ids):
183
+ mask = cls_ids == c
184
+ indices = np.where(mask)[0]
185
+ keep = self._hard_nms(boxes[mask], scores[mask], iou_thresh)
186
+ all_keep.extend(indices[keep].tolist())
187
+ all_keep.sort()
188
+ return np.array(all_keep, dtype=np.intp)
189
+
190
+ def _cross_class_dedup_op(self, boxes: np.ndarray, scores: np.ndarray,
191
+ cls_ids: np.ndarray, iou_thresh: float
192
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
193
+ n = len(boxes)
194
+ if n <= 1:
195
+ return boxes, scores, cls_ids
196
+ boxes = np.asarray(boxes, dtype=np.float32)
197
+ scores = np.asarray(scores, dtype=np.float32)
198
+ cls_ids = np.asarray(cls_ids, dtype=np.int32)
199
+ areas = (np.maximum(0.0, boxes[:, 2] - boxes[:, 0]) *
200
+ np.maximum(0.0, boxes[:, 3] - boxes[:, 1]))
201
+ margins = scores - self._conf_thres_array[cls_ids]
202
+ order = np.lexsort((-areas, -margins))
203
+ suppressed = np.zeros(n, dtype=bool)
204
+ keep: list[int] = []
205
+ for i in order:
206
+ if suppressed[i]:
207
+ continue
208
+ keep.append(int(i))
209
+ bi = boxes[i]
210
+ xx1 = np.maximum(bi[0], boxes[:, 0])
211
+ yy1 = np.maximum(bi[1], boxes[:, 1])
212
+ xx2 = np.minimum(bi[2], boxes[:, 2])
213
+ yy2 = np.minimum(bi[3], boxes[:, 3])
214
+ inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
215
+ a_i = max(1e-7, float((bi[2] - bi[0]) * (bi[3] - bi[1])))
216
+ iou = inter / (a_i + areas - inter + 1e-7)
217
+ dup = iou > iou_thresh
218
+ dup[i] = False
219
+ suppressed |= dup
220
+ keep_idx = np.array(keep, dtype=np.intp)
221
+ return boxes[keep_idx], scores[keep_idx], cls_ids[keep_idx]
222
+
223
+ def _filter_sane_boxes(self, boxes: np.ndarray, scores: np.ndarray,
224
+ cls_ids: np.ndarray, orig_size: tuple[int, int]
225
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
226
+ if len(boxes) == 0:
227
+ return boxes, scores, cls_ids
228
+ orig_w, orig_h = orig_size
229
+ image_area = float(orig_w * orig_h)
230
+ bw = np.maximum(0.0, boxes[:, 2] - boxes[:, 0])
231
+ bh = np.maximum(0.0, boxes[:, 3] - boxes[:, 1])
232
+ area = bw * bh
233
+ ar = np.where(
234
+ (bw > 0) & (bh > 0),
235
+ np.maximum(bw / np.maximum(bh, 1e-6), bh / np.maximum(bw, 1e-6)),
236
+ np.inf,
237
+ )
238
+ keep = (
239
+ (bw >= self.min_side) & (bh >= self.min_side) &
240
+ (area >= self.min_box_area) &
241
+ (area <= 0.95 * image_area) &
242
+ (ar <= self.max_aspect_ratio)
243
+ )
244
+ return boxes[keep], scores[keep], cls_ids[keep]
245
+
246
+ def _max_score_per_cluster(self, post_boxes: np.ndarray,
247
+ post_cls: np.ndarray,
248
+ full_boxes: np.ndarray,
249
+ full_scores: np.ndarray,
250
+ full_cls: np.ndarray,
251
+ iou_thresh: float) -> np.ndarray:
252
+ n = len(post_boxes)
253
+ if n == 0:
254
+ return np.empty(0, dtype=np.float32)
255
+ full_areas = (np.maximum(0.0, full_boxes[:, 2] - full_boxes[:, 0]) *
256
+ np.maximum(0.0, full_boxes[:, 3] - full_boxes[:, 1]))
257
+ out = np.empty(n, dtype=np.float32)
258
+ for i in range(n):
259
+ bi = post_boxes[i]
260
+ xx1 = np.maximum(bi[0], full_boxes[:, 0])
261
+ yy1 = np.maximum(bi[1], full_boxes[:, 1])
262
+ xx2 = np.minimum(bi[2], full_boxes[:, 2])
263
+ yy2 = np.minimum(bi[3], full_boxes[:, 3])
264
+ inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
265
+ a_i = max(0.0, float((bi[2] - bi[0]) * (bi[3] - bi[1])))
266
+ iou = inter / (a_i + full_areas - inter + 1e-7)
267
+ cluster = (iou >= iou_thresh) & (full_cls == post_cls[i])
268
+ out[i] = float(np.max(full_scores[cluster])) if np.any(cluster) else 0.0
269
+ return out
270
+
271
+ def _conf_filter_mask(self, scores: np.ndarray,
272
+ cls_ids: np.ndarray) -> np.ndarray:
273
+ """Boolean keep-mask: score >= per-class threshold, with a per-class
274
+ rescue — if a class has zero boxes passing, admit its top-1 candidate
275
+ when its score >= (per-class threshold - per-class bonus)."""
276
+ if len(scores) == 0:
277
+ return np.zeros(0, dtype=bool)
278
+ thr = self._conf_thres_array[cls_ids]
279
+ keep = scores >= thr
280
+ for c in np.unique(cls_ids):
281
+ b = float(self._bonus_array[c])
282
+ if b <= 0.0:
283
+ continue
284
+ cm = cls_ids == c
285
+ if keep[cm].any():
286
+ continue
287
+ idx = np.where(cm)[0]
288
+ top = int(idx[int(np.argmax(scores[idx]))])
289
+ if scores[top] >= self._conf_thres_array[c] - b:
290
+ keep[top] = True
291
+ return keep
292
+
293
+ def _per_view_pipeline(self, boxes: np.ndarray, scores: np.ndarray,
294
+ cls_ids: np.ndarray, orig_size: tuple[int, int]
295
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
296
+ boxes, scores, cls_ids = self._filter_sane_boxes(
297
+ boxes, scores, cls_ids, orig_size
298
+ )
299
+ if len(boxes) == 0:
300
+ return boxes, scores, cls_ids
301
+ if len(boxes) > 1:
302
+ keep = self._hard_nms(boxes, scores, self.iou_thres)
303
+ boxes, scores, cls_ids = boxes[keep], scores[keep], cls_ids[keep]
304
+ if len(scores) > self.max_det:
305
+ top = np.argsort(-scores)[: self.max_det]
306
+ boxes, scores, cls_ids = boxes[top], scores[top], cls_ids[top]
307
+ if len(boxes) > 1:
308
+ boxes, scores, cls_ids = self._cross_class_dedup_op(
309
+ boxes, scores, cls_ids, self.cross_iou_thresh
310
+ )
311
+ return boxes, scores, cls_ids
312
+
313
+ def _decode_final_dets(self, preds: np.ndarray, ratio: float,
314
+ pad: tuple[float, float],
315
+ orig_size: tuple[int, int]) -> list[BoundingBox]:
316
+ if preds.ndim == 3 and preds.shape[0] == 1:
317
+ preds = preds[0]
318
+ if preds.ndim != 2 or preds.shape[1] < 6:
319
+ raise ValueError(f"Unexpected ONNX final-det output shape: {preds.shape}")
320
+
321
+ boxes = preds[:, :4].astype(np.float32)
322
+ scores = preds[:, 4].astype(np.float32)
323
+ cls_ids = preds[:, 5].astype(np.int32)
324
+
325
+ keep = self._conf_filter_mask(scores, cls_ids)
326
+ boxes = boxes[keep]
327
+ scores = scores[keep]
328
+ cls_ids = cls_ids[keep]
329
+ if len(boxes) == 0:
330
+ return []
331
+
332
+ pad_w, pad_h = pad
333
+ boxes[:, [0, 2]] -= pad_w
334
+ boxes[:, [1, 3]] -= pad_h
335
+ boxes /= ratio
336
+ boxes = self._clip_boxes(boxes, orig_size)
337
+
338
+ boxes, scores, cls_ids = self._per_view_pipeline(
339
+ boxes, scores, cls_ids, orig_size
340
+ )
341
+ return self._build_results(boxes, scores, cls_ids)
342
+
343
+ def _decode_raw_yolo(self, preds: np.ndarray, ratio: float,
344
+ pad: tuple[float, float],
345
+ orig_size: tuple[int, int]) -> list[BoundingBox]:
346
+ if preds.ndim != 3 or preds.shape[0] != 1:
347
+ raise ValueError(f"Unexpected raw ONNX output shape: {preds.shape}")
348
+ preds = preds[0]
349
+ if preds.shape[0] <= 16 and preds.shape[1] > preds.shape[0]:
350
+ preds = preds.T
351
+ if preds.ndim != 2 or preds.shape[1] < 5:
352
+ raise ValueError(f"Unexpected raw output shape: {preds.shape}")
353
+
354
+ boxes_xywh = preds[:, :4].astype(np.float32)
355
+ cls_part = preds[:, 4:].astype(np.float32)
356
+ if cls_part.shape[1] == 1:
357
+ scores = cls_part[:, 0]
358
+ cls_ids = np.zeros(len(scores), dtype=np.int32)
359
+ else:
360
+ cls_ids = np.argmax(cls_part, axis=1).astype(np.int32)
361
+ scores = cls_part[np.arange(len(cls_part)), cls_ids]
362
+
363
+ keep = self._conf_filter_mask(scores, cls_ids)
364
+ boxes_xywh = boxes_xywh[keep]
365
+ scores = scores[keep]
366
+ cls_ids = cls_ids[keep]
367
+ if len(boxes_xywh) == 0:
368
+ return []
369
+ boxes = self._xywh_to_xyxy(boxes_xywh)
370
+
371
+ pad_w, pad_h = pad
372
+ boxes[:, [0, 2]] -= pad_w
373
+ boxes[:, [1, 3]] -= pad_h
374
+ boxes /= ratio
375
+ boxes = self._clip_boxes(boxes, orig_size)
376
+
377
+ boxes, scores, cls_ids = self._per_view_pipeline(
378
+ boxes, scores, cls_ids, orig_size
379
+ )
380
+ return self._build_results(boxes, scores, cls_ids)
381
+
382
+ @staticmethod
383
+ def _build_results(boxes: np.ndarray, scores: np.ndarray,
384
+ cls_ids: np.ndarray) -> list[BoundingBox]:
385
+ results: list[BoundingBox] = []
386
+ for box, conf, cls_id in zip(boxes, scores, cls_ids):
387
+ x1, y1, x2, y2 = box.tolist()
388
+ if x2 <= x1 or y2 <= y1:
389
+ continue
390
+ results.append(
391
+ BoundingBox(
392
+ x1=int(math.floor(x1)),
393
+ y1=int(math.floor(y1)),
394
+ x2=int(math.ceil(x2)),
395
+ y2=int(math.ceil(y2)),
396
+ cls_id=int(cls_id),
397
+ conf=float(conf),
398
+ )
399
+ )
400
+ return results
401
+
402
+ def _postprocess(self, output: np.ndarray, ratio: float,
403
+ pad: tuple[float, float],
404
+ orig_size: tuple[int, int]) -> list[BoundingBox]:
405
+ if output.ndim == 2 and output.shape[1] >= 6:
406
+ return self._decode_final_dets(output, ratio, pad, orig_size)
407
+ if output.ndim == 3 and output.shape[0] == 1 and output.shape[2] == 6:
408
+ return self._decode_final_dets(output, ratio, pad, orig_size)
409
+ return self._decode_raw_yolo(output, ratio, pad, orig_size)
410
+
411
+ def _predict_single(self, image: np.ndarray) -> list[BoundingBox]:
412
+ if image is None:
413
+ raise ValueError("Input image is None")
414
+ if not isinstance(image, np.ndarray):
415
+ raise TypeError(f"Input is not numpy array: {type(image)}")
416
+ if image.ndim != 3:
417
+ raise ValueError(f"Expected HWC image, got shape={image.shape}")
418
+ if image.shape[2] != 3:
419
+ raise ValueError(f"Expected 3 channels, got shape={image.shape}")
420
+ if image.dtype != np.uint8:
421
+ image = image.astype(np.uint8)
422
+
423
+ input_tensor, ratio, pad, orig_size = self._preprocess(image)
424
+ expected = (1, 3, self.input_height, self.input_width)
425
+ if input_tensor.shape != expected:
426
+ raise ValueError(
427
+ f"Bad input tensor shape={input_tensor.shape}, expected={expected}"
428
+ )
429
+
430
+ outputs = self.session.run(self.output_names, {self.input_name: input_tensor})
431
+ return self._postprocess(outputs[0], ratio, pad, orig_size)
432
+
433
+ def _predict_tta(self, image: np.ndarray) -> list[BoundingBox]:
434
+ boxes_orig = self._predict_single(image)
435
+ flipped = cv2.flip(image, 1)
436
+ boxes_flip = self._predict_single(flipped)
437
+ w = image.shape[1]
438
+ boxes_flip = [
439
+ BoundingBox(
440
+ x1=w - b.x2, y1=b.y1, x2=w - b.x1, y2=b.y2,
441
+ cls_id=b.cls_id, conf=b.conf,
442
+ )
443
+ for b in boxes_flip
444
+ ]
445
+ all_boxes = boxes_orig + boxes_flip
446
+ if not all_boxes:
447
+ return []
448
+
449
+ coords = np.array(
450
+ [[b.x1, b.y1, b.x2, b.y2] for b in all_boxes], dtype=np.float32
451
+ )
452
+ scores = np.array([b.conf for b in all_boxes], dtype=np.float32)
453
+ cls_ids = np.array([b.cls_id for b in all_boxes], dtype=np.int32)
454
+
455
+ hard_keep = self._per_class_hard_nms(coords, scores, cls_ids, self.iou_thres)
456
+ if len(hard_keep) == 0:
457
+ return []
458
+ if len(hard_keep) > self.max_det:
459
+ top = np.argsort(-scores[hard_keep])[: self.max_det]
460
+ hard_keep = hard_keep[top]
461
+ boosted = self._max_score_per_cluster(
462
+ coords[hard_keep], cls_ids[hard_keep],
463
+ coords, scores, cls_ids, self.iou_thres,
464
+ )
465
+
466
+ kept_coords = coords[hard_keep]
467
+ kept_cls = cls_ids[hard_keep]
468
+ if len(kept_coords) > 1:
469
+ kept_coords, boosted, kept_cls = self._cross_class_dedup_op(
470
+ kept_coords, boosted, kept_cls, self.cross_iou_thresh
471
+ )
472
+
473
+ return [
474
+ BoundingBox(
475
+ x1=int(math.floor(kept_coords[j, 0])),
476
+ y1=int(math.floor(kept_coords[j, 1])),
477
+ x2=int(math.ceil(kept_coords[j, 2])),
478
+ y2=int(math.ceil(kept_coords[j, 3])),
479
+ cls_id=int(kept_cls[j]),
480
+ conf=float(boosted[j]),
481
+ )
482
+ for j in range(len(kept_coords))
483
+ ]
484
+
485
+ def predict_batch(self, batch_images: list[ndarray], offset: int,
486
+ n_keypoints: int) -> list[TVFrameResult]:
487
+ results: list[TVFrameResult] = []
488
+ for frame_number_in_batch, image in enumerate(batch_images):
489
+ try:
490
+ boxes = self._predict_tta(image)
491
+ except Exception as e:
492
+ print(f"Inference failed for frame {offset + frame_number_in_batch}: {e}")
493
+ boxes = []
494
+ results.append(
495
+ TVFrameResult(
496
+ frame_id=offset + frame_number_in_batch,
497
+ boxes=boxes,
498
+ keypoints=[(0, 0) for _ in range(max(0, int(n_keypoints)))],
499
+ )
500
+ )
501
+ return results