"""
SN44 number plate detection miner - single-element chute for
manak0/Detect-number-plates-1-0.
Adapted from the auto-generated detect-person-reference miner with four
substantive changes:
1. Class set is the single class ``numberplate`` (the validator's exact
label string).
2. Lower confidence threshold (now 0.12, vs the standard 0.25) because the
validator's plates are tiny - 5-92 px wide on a 1408 px frame, median ~30 px.
At standard 0.25 most true positives get filtered before NMS.
3. Standard NMS replaced with Gaussian Soft-NMS (sigma=0.3, tightened from 0.5). Soft-NMS
decays scores of overlapping boxes instead of suppressing them
outright, which helps on plate-dense frames (parking lot, car
carrier, gas station forecourt) where standard NMS over-suppresses
adjacent plates.
4. CUDA library preload at import time so onnxruntime-gpu finds
libcudnn / libcublas from the nvidia-* pip wheels even when
LD_LIBRARY_PATH is not set (the chute container ships these wheels
but does not export them).
Soft-NMS is inlined here rather than imported from /home/miner/utils
because the chute platform sandbox restricts non-stdlib imports beyond
the deps declared in chute_config.yml. The implementation is a
specialised single-class version of soft_nms_yolo from
/home/miner/utils/soft_nms.py - see that file for the full
multi-class / multi-backend version.
"""
import ctypes
import glob as _glob
import logging as _logging
import os
_cuda_log = _logging.getLogger(__name__)

# nvidia-* pip wheel packages probed for a bundled ``lib/`` directory.
_NVIDIA_WHEEL_MODULES = (
    "nvidia.cudnn",
    "nvidia.cublas",
    "nvidia.cuda_runtime",
    "nvidia.cufft",
    "nvidia.curand",
    "nvidia.cusolver",
    "nvidia.cusparse",
    "nvidia.nvjitlink",
)


def _nvidia_lib_dirs() -> list[str]:
    """Return the existing ``<package>/lib`` dirs of importable nvidia wheels.

    Packages that cannot be imported are skipped. A package without a
    ``__file__`` (namespace package) is located through ``__path__`` instead,
    so it no longer aborts the scan of the remaining packages.
    """
    lib_dirs: list[str] = []
    for mod_name in _NVIDIA_WHEEL_MODULES:
        try:
            mod = __import__(mod_name, fromlist=["__file__"])
        except ImportError:
            continue
        mod_file = getattr(mod, "__file__", None)
        if mod_file:
            pkg_dirs = [os.path.dirname(mod_file)]
        else:
            # Namespace packages have no __file__; fall back to __path__.
            pkg_dirs = list(getattr(mod, "__path__", []))
        for pkg_dir in pkg_dirs:
            lib_dir = os.path.join(pkg_dir, "lib")
            if os.path.isdir(lib_dir) and lib_dir not in lib_dirs:
                lib_dirs.append(lib_dir)
    return lib_dirs


def _preload_cuda_libs() -> None:
    """Pre-load CUDA + cuDNN + cuBLAS shared libs from nvidia-* pip wheels.

    Without this, onnxruntime-gpu's CUDAExecutionProvider silently falls
    back to CPU because it can't dlopen libcudnn.so.9 - the nvidia
    wheels ship the library inside `nvidia/cudnn/lib/` but do NOT add
    that directory to the loader path. We import the wheel modules to
    locate their lib dirs, prepend them to LD_LIBRARY_PATH for any
    child processes, and ctypes.CDLL the .so files with RTLD_GLOBAL so
    onnxruntime's dlopen sees them.

    Best effort: every failure is logged and swallowed, never raised.
    Calling it more than once does not add duplicate LD_LIBRARY_PATH
    entries.
    """
    try:
        lib_dirs = _nvidia_lib_dirs()
        if not lib_dirs:
            _cuda_log.warning("no nvidia-* lib dirs found; ORT GPU may fall back to CPU")
            return

        # Update LD_LIBRARY_PATH for any child processes / dlopen fallbacks.
        # Entries we are about to prepend are dropped from the inherited
        # value so repeated calls leave the variable unchanged.
        existing = os.environ.get("LD_LIBRARY_PATH", "")
        inherited = (
            [p for p in existing.split(":") if p not in lib_dirs] if existing else []
        )
        os.environ["LD_LIBRARY_PATH"] = ":".join(lib_dirs + inherited)

        # ctypes.CDLL each .so so the symbols are globally visible to ORT.
        # A library can fail on the first pass only because one of its
        # dependencies sorts later; give the failures one more attempt.
        pending = [
            so
            for lib_dir in lib_dirs
            for so in sorted(_glob.glob(os.path.join(lib_dir, "lib*.so*")))
        ]
        still_failing: list[tuple[str, OSError]] = []
        for _attempt in range(2):
            still_failing = []
            for so in pending:
                try:
                    ctypes.CDLL(so, mode=ctypes.RTLD_GLOBAL)
                except OSError as exc:
                    still_failing.append((so, exc))
            pending = [so for so, _exc in still_failing]
            if not pending:
                break
        for so, exc in still_failing:
            _cuda_log.debug("could not preload %s: %s", so, exc)
    except Exception as e:  # pragma: no cover - best effort
        _cuda_log.warning("CUDA preload failed: %s", e)


# Must run before ``import onnxruntime`` further down the module.
_preload_cuda_libs()
from pathlib import Path
import math
import cv2
import numpy as np
import onnxruntime as ort
from numpy import ndarray
from pydantic import BaseModel
class BoundingBox(BaseModel):
    """One detection box, as emitted by ``Miner._infer_single``."""

    # Corner coordinates in original-image pixels; the miner clamps them to
    # the frame and rounds outward (floor for x1/y1, ceil for x2/y2).
    x1: int
    y1: int
    x2: int
    y2: int
    # Class index reported by the model for this box.
    cls_id: int
    # Detection score; the miner clamps it to [0.0, 1.0] before constructing.
    conf: float
class TVFrameResult(BaseModel):
    """Per-frame result returned by ``Miner.predict_batch``."""

    # ``offset + index`` of the frame within the submitted batch.
    frame_id: int
    # Detections for this frame.
    boxes: list[BoundingBox]
    # ``predict_batch`` fills this with ``n_keypoints`` placeholder (0, 0) pairs.
    keypoints: list[tuple[int, int]]
class Miner:
    """
    Single-element ONNX miner for the manak0/Detect-number-plates-1-0
    element. Auto-loaded by the chute platform; the platform passes the
    snapshot path of the HF repo containing ``numberplate_weights.onnx`` as
    ``path_hf_repo`` and calls ``predict_batch(batch_images, offset,
    n_keypoints)`` for each request.
    """

    # Quad-4 tile overlap, in pixels past the frame midline on each axis.
    # 2026-05-01: bumped to X=55, Y=32 after bench sweep on 33-task archive:
    # archive HIGH R 0.8770->0.8934 (+1.64pp, +2 plates), phantom rate
    # 11.7%->9.3% (-2.4pp). cid 62115 plate-1 IoU vs winner 0.520->0.849
    # (y-seam fix). starter R unchanged (saturated 0.9333).
    _TILE_OVERLAP_X = 55  # was 35; +1.64pp archive R from x-seam plate recovery
    _TILE_OVERLAP_Y = 32  # was 19; cures bbox-regression on plates spanning y-seam

    def __init__(self, path_hf_repo) -> None:
        """Load the ONNX session, set the tuning parameters and warm up.

        Args:
            path_hf_repo: directory containing ``numberplate_weights.onnx``.
        """
        self.path_hf_repo = Path(path_hf_repo)
        self.class_names = ['numberplate']
        self.session = ort.InferenceSession(
            str(self.path_hf_repo / "numberplate_weights.onnx"),
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        )
        model_input = self.session.get_inputs()[0]
        self.input_name = model_input.name
        # expected [N, C, H, W]; dynamic-export ONNX has string placeholders
        # for spatial dims. We always run inference at 1408 (the validator's
        # native frame width); the ONNX accepts variable shapes via dynamic
        # axes, and inference at 1408 gives substantially better small-plate
        # recall than the model's training resolution (verified on the 7
        # starter assets: 43% recall at 960 vs 60% at 1408).
        input_shape = model_input.shape

        # Hard-pin to the validator's native 1408x768 (rectangular). This
        # is half the pixel count of a 1408x1408 square pad and matches
        # the validator's exact frame shape, eliminating wasted padding
        # rows. yolo11s strides are 32, both 1408 (44*32) and 768 (24*32)
        # are valid.
        self.input_h = 768
        self.input_w = 1408
        # Record what the ONNX *declared*, for diagnostic logging only
        # (None when the dimension is dynamic or missing).
        self._onnx_declared_h = self._declared_dim(input_shape, 2)
        self._onnx_declared_w = self._declared_dim(input_shape, 3)
        pinned = (self.input_h, self.input_w)
        declared = (self._onnx_declared_h, self._onnx_declared_w)
        if any(d is not None and d != p for d, p in zip(declared, pinned)):
            _cuda_log.warning(
                "ONNX declares static input %sx%s but miner is pinned to %dx%d; "
                "session.run may reject the pinned shape",
                self._onnx_declared_h,
                self._onnx_declared_w,
                self.input_h,
                self.input_w,
            )

        # Pre-NMS confidence threshold. Top-of-leaderboard miners (Cargile,
        # alfred8995) run 0.16-0.21 with a LOW post-NMS floor (0.01) and TTA.
        # Bench on 12 archived validator tasks (12-task consensus set) at
        # conf=0.16, score_threshold=0.01 with our v3 ONNX: HIGH recall
        # 57% -> 66% (+3 plates) at the cost of +2 singletons. Net positive
        # given 0.6 weight on map50 vs 0.4 on false_positive in composite.
        # 2026-04-30: lowered to 0.12 after bench sweep on 33-task archive +
        # 30-frame starter showed +1.6pp HIGH recall (0.8607->0.8770) for only
        # +2 phantoms (+0.061/frame), 1:1 hit-to-phantom ratio.
        self.conf_threshold = 0.12
        # Soft-NMS hyperparameters (Gaussian variant).
        # Tightened 0.5 -> 0.3: sharper decay collapses near-duplicates from SAHI
        # tile-seam overlaps faster, dropping more below the 0.01 score floor.
        self.soft_nms_sigma = 0.3
        # Final score floor after Soft-NMS decay. Was 0.20 - the raised threshold
        # killed decayed real plates (e.g. plate adjacent to a higher-conf
        # detection gets decayed below 0.20 and dropped). Matches competitor
        # 0.01 floor; Soft-NMS still prevents wild duplicates via decay.
        self.score_threshold = 0.01
        # Horizontal-flip TTA. Doubles inference cost (~101ms -> ~200ms at
        # batch=1) but we have ~10s budget per-frame, massive headroom. Both
        # top miners (Cargile, alfred8995) use TTA - the extra view helps
        # catch plates the model is directionally biased against.
        self.use_tta = True
        # Dual-threshold TTA verification gate (hermes-style, seen in the
        # hermestech00/numberplate0 HF repo). Final-output gate:
        #   - conf >= conf_high                   -> pass unconditionally
        #   - conf in [conf_threshold, conf_high) -> must have a flip-view
        #                                            match with IoU >=
        #                                            tta_match_iou to survive
        # Uses TTA as a cross-view VERIFIER, not just a recall booster.
        # NOTE: the gate itself is currently not applied in _infer_single
        # (tried and reverted 2026-04-21); the parameters are kept for a
        # possible re-enable.
        self.conf_high = 0.90
        self.tta_match_iou = 0.01

        # GPU warmup - force ORT / CUDA / cuDNN kernel compilation and pull
        # the 4090 out of low-power idle state so the first real validator
        # frame doesn't pay a ~20 ms DVFS spin-up tax. SCOREVISION_WARMUP_CALLS
        # at the chute level defaults to 3, which is not enough to reach
        # steady-state on this tiled inference path (measured: 3 calls -> 52
        # ms p95 on the first few frames vs 31 ms steady). 10 full pipeline
        # runs on a synthetic frame gets us to the fast regime before the
        # platform warmup even starts.
        warmup_frame = np.zeros((self.input_h, self.input_w, 3), dtype=np.uint8)
        for _ in range(10):
            try:
                self._infer_single(warmup_frame)
            except Exception as exc:  # pragma: no cover - best effort
                # Still best effort, but no longer silent: a failure here
                # usually means real requests will fail the same way.
                _cuda_log.warning("warmup inference failed: %s", exc)
                break

    @staticmethod
    def _declared_dim(shape, axis):
        """Return ``shape[axis]`` as an int, or None if dynamic or missing."""
        try:
            return int(shape[axis])
        except (TypeError, ValueError, IndexError):
            return None

    def __repr__(self) -> str:
        return (
            f"NumberplateMiner session={type(self.session).__name__} "
            f"input={self.input_h}x{self.input_w} classes={len(self.class_names)}"
        )

    # ---------------------------------------------------------------- preproc
    def _preprocess(self, image_bgr: ndarray):
        """Letterbox the BGR image to (input_h, input_w), preserving aspect.

        Returns the float32 NCHW tensor plus the metadata needed to undo
        the letterbox during decode: (orig_h, orig_w, scale, dx, dy).
        Not called by the tiled pipeline in this class, which resizes each
        tile anisotropically instead.
        """
        h, w = image_bgr.shape[:2]
        scale = min(self.input_h / h, self.input_w / w)
        nh, nw = int(round(h * scale)), int(round(w * scale))
        resized = cv2.resize(image_bgr, (nw, nh))
        # Pad to (input_h, input_w) with grey (114) - ultralytics default
        canvas = np.full((self.input_h, self.input_w, 3), 114, dtype=np.uint8)
        dy = (self.input_h - nh) // 2
        dx = (self.input_w - nw) // 2
        canvas[dy:dy + nh, dx:dx + nw] = resized
        rgb = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
        x = rgb.astype(np.float32) / 255.0
        x = np.transpose(x, (2, 0, 1))[None, ...]
        return x, (h, w, scale, dx, dy)

    # ---------------------------------------------------------------- decode
    def _normalize_predictions(self, raw: np.ndarray) -> np.ndarray:
        """Handle both common ultralytics export shapes ([1,C,N] and [1,N,C]).

        Returns the first batch item as a 2-D array with the longer axis
        first (rows = anchors, columns = channels).

        Raises:
            ValueError: if the first batch item is not 2-D.
        """
        pred = raw[0]
        if pred.ndim != 2:
            raise ValueError(f"Unexpected prediction shape: {raw.shape}")
        if pred.shape[0] < pred.shape[1]:
            pred = pred.transpose(1, 0)
        return pred

    # ---------------------------------------------------------------- soft NMS
    def _soft_nms(
        self,
        dets: list[tuple[float, float, float, float, float, int]],
    ) -> list[tuple[float, float, float, float, float, int]]:
        """Gaussian Soft-NMS for a single class.

        Decays each remaining box's score by ``exp(-iou^2 / sigma)`` against
        the highest-scoring picked box, then drops anything below
        ``self.score_threshold``. Returns detections in descending decayed
        score order. Class ids are carried through but not used for grouping.
        """
        if not dets:
            return []
        boxes = np.asarray([[d[0], d[1], d[2], d[3]] for d in dets], dtype=np.float32)
        scores = np.asarray([d[4] for d in dets], dtype=np.float32)
        cls_ids = [int(d[5]) for d in dets]
        keep_idx: list[int] = []
        keep_scores: list[float] = []
        active = np.ones(len(dets), dtype=bool)
        while True:
            valid_mask = active & (scores >= self.score_threshold)
            if not valid_mask.any():
                break
            valid_idx = np.where(valid_mask)[0]
            best = valid_idx[int(np.argmax(scores[valid_idx]))]
            keep_idx.append(int(best))
            keep_scores.append(float(scores[best]))
            active[best] = False
            # IoU of the picked box against all still-active boxes
            others = np.where(active)[0]
            if others.size == 0:
                break
            ax1 = np.maximum(boxes[best, 0], boxes[others, 0])
            ay1 = np.maximum(boxes[best, 1], boxes[others, 1])
            ax2 = np.minimum(boxes[best, 2], boxes[others, 2])
            ay2 = np.minimum(boxes[best, 3], boxes[others, 3])
            inter_w = np.clip(ax2 - ax1, a_min=0.0, a_max=None)
            inter_h = np.clip(ay2 - ay1, a_min=0.0, a_max=None)
            inter = inter_w * inter_h
            area_m = max(0.0, (boxes[best, 2] - boxes[best, 0])) * \
                max(0.0, (boxes[best, 3] - boxes[best, 1]))
            area_o = (
                np.clip(boxes[others, 2] - boxes[others, 0], a_min=0.0, a_max=None) *
                np.clip(boxes[others, 3] - boxes[others, 1], a_min=0.0, a_max=None)
            )
            union = area_m + area_o - inter
            iou = np.where(union > 0.0, inter / union, 0.0)
            decay = np.exp(-(iou * iou) / self.soft_nms_sigma)
            scores[others] = scores[others] * decay
        return [
            (
                float(boxes[i, 0]),
                float(boxes[i, 1]),
                float(boxes[i, 2]),
                float(boxes[i, 3]),
                float(s),
                cls_ids[i],
            )
            for i, s in zip(keep_idx, keep_scores)
        ]

    # ---------------------------------------------------------------- inference
    def _infer_tile(
        self,
        image_bgr: ndarray,
        x0: int,
        y0: int,
        x1: int,
        y1: int,
    ) -> list[tuple[float, float, float, float, float, int]]:
        """Run one inference pass on ``image_bgr[y0:y1, x0:x1]`` resized
        anisotropically to ``(input_h, input_w)`` and return raw detections
        (pre-Soft-NMS) mapped back to ORIGINAL-image coordinates.

        Anisotropic resize is intentional: the tile aspect ratio differs
        from the model input, and we want the tile pixels to magnify up to
        the detector's stride-8 feature footprint. For the quadrant tiles
        of a 1408x768 frame (759x416 px) this is roughly 1.85x on both
        axes, which is what pushes tiny plates above the stride-8
        threshold.

        Returns an empty list for an empty crop or when nothing passes
        ``self.conf_threshold``.
        """
        crop = image_bgr[y0:y1, x0:x1]
        ch, cw = crop.shape[:2]
        if ch == 0 or cw == 0:
            return []
        resized = cv2.resize(crop, (self.input_w, self.input_h))
        rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
        x = np.transpose(rgb.astype(np.float32) / 255.0, (2, 0, 1))[None, ...]
        out = self.session.run(None, {self.input_name: x})[0]
        # Scale factors from model-input space -> crop -> original image coords.
        sx = cw / self.input_w
        sy = ch / self.input_h

        # Shape-dispatch: detect end2end export format (YOLO26 family: [1, N, 6]
        # with N<=300, per-row [x1, y1, x2, y2, conf, cls_id] already NMS'd) vs
        # raw YOLO11/v8 export ([1, C, anchors] or [1, anchors, C] with cx/cy/w/h
        # + per-class scores, pre-NMS).
        if out.ndim == 3 and out.shape[-1] == 6:
            rows = out[0]  # [N, 6]
            rows = rows[rows[:, 4] >= self.conf_threshold]
            return [
                (
                    x1m * sx + x0,
                    y1m * sy + y0,
                    x2m * sx + x0,
                    y2m * sy + y0,
                    float(conf),
                    int(cls_id),
                )
                for x1m, y1m, x2m, y2m, conf, cls_id in rows.tolist()
            ]

        pred = self._normalize_predictions(out)
        if pred.shape[1] < 5:
            # Need 4 box columns plus at least one class-score column.
            return []
        boxes_m = pred[:, :4]
        cls_scores = pred[:, 4:]
        cls_ids = np.argmax(cls_scores, axis=1)
        confs = np.max(cls_scores, axis=1)
        keep = confs >= self.conf_threshold
        dets: list[tuple[float, float, float, float, float, int]] = []
        for (cx, cy, bw, bh), conf, cls_id in zip(
            boxes_m[keep].tolist(), confs[keep].tolist(), cls_ids[keep].tolist()
        ):
            # Centre/size in model-input space -> corners in original image.
            dets.append(
                (
                    (cx - bw / 2.0) * sx + x0,
                    (cy - bh / 2.0) * sy + y0,
                    (cx + bw / 2.0) * sx + x0,
                    (cy + bh / 2.0) * sy + y0,
                    float(conf),
                    int(cls_id),
                )
            )
        return dets

    def _cluster_dedup(
        self,
        dets: list[tuple[float, float, float, float, float, int]],
        iou_thresh: float = 0.5,
    ) -> list[tuple[float, float, float, float, float, int]]:
        """Greedy near-duplicate suppression - for any pair with IoU >=
        ``iou_thresh``, keep only the higher-conf detection.

        Purpose: collapse TTA-induced duplicates of the same plate before
        Soft-NMS, which would otherwise decay (but not kill) the lower-conf
        copy, leaving multiple boxes per plate past our low score_threshold.
        Mirrors the TTA-cluster-merge step in alfred8995/arabic000's miner.py.
        Applied on *every* call (not just TTA) because the quad-4 overlap
        band can also produce near-duplicate detections near tile seams.

        The default threshold of 0.5 is loose enough that
        adjacent-but-distinct plates (IoU < 0.5) stay separate; tight enough
        that same-plate variants (IoU > 0.9 in practice) collapse.
        ``_infer_single`` passes a tighter 0.3.

        Returns the survivors in descending confidence order.
        """
        if not dets:
            return []
        ranked = sorted(dets, key=lambda d: -d[4])
        kept: list[tuple[float, float, float, float, float, int]] = []
        suppressed = [False] * len(ranked)
        for i, top in enumerate(ranked):
            if suppressed[i]:
                continue
            kept.append(top)
            x1i, y1i, x2i, y2i = top[:4]
            area_i = max(0.0, x2i - x1i) * max(0.0, y2i - y1i)
            for j in range(i + 1, len(ranked)):
                if suppressed[j]:
                    continue
                x1j, y1j, x2j, y2j = ranked[j][:4]
                iw = max(0.0, min(x2i, x2j) - max(x1i, x1j))
                ih = max(0.0, min(y2i, y2j) - max(y1i, y1j))
                inter = iw * ih
                area_j = max(0.0, x2j - x1j) * max(0.0, y2j - y1j)
                union = area_i + area_j - inter
                if union > 0 and inter / union >= iou_thresh:
                    suppressed[j] = True
        return kept

    @classmethod
    def _quad4_tiles(cls, orig_w: int, orig_h: int) -> list[tuple[int, int, int, int]]:
        """Return the four overlapping quadrant tiles as (x0, y0, x1, y1)."""
        mx = orig_w // 2
        my = orig_h // 2
        left = max(0, mx - cls._TILE_OVERLAP_X)
        right = min(orig_w, mx + cls._TILE_OVERLAP_X)
        top = max(0, my - cls._TILE_OVERLAP_Y)
        bottom = min(orig_h, my + cls._TILE_OVERLAP_Y)
        return [
            (0, 0, right, bottom),        # TL
            (left, 0, orig_w, bottom),    # TR
            (0, top, right, orig_h),      # BL
            (left, top, orig_w, orig_h),  # BR
        ]

    def _quad4_raw_dets(
        self,
        image_bgr: ndarray,
    ) -> list[tuple[float, float, float, float, float, int]]:
        """Run the quad-4 tile pipeline and return RAW (pre-Soft-NMS)
        detections in original-image coordinates."""
        orig_h, orig_w = image_bgr.shape[:2]
        all_dets: list[tuple[float, float, float, float, float, int]] = []
        for x0, y0, x1, y1 in self._quad4_tiles(orig_w, orig_h):
            all_dets.extend(self._infer_tile(image_bgr, x0, y0, x1, y1))
        return all_dets

    @staticmethod
    def _clamp_box(x1, y1, x2, y2, orig_w, orig_h):
        """Clamp a float box to the frame, rounding outward to ints.

        Returns ``(ix1, iy1, ix2, iy2)``, or None when a coordinate is not
        finite or the clamped box has no area (e.g. it lay entirely outside
        the frame).
        """
        if not all(math.isfinite(v) for v in (x1, y1, x2, y2)):
            return None
        ix1 = max(0, min(orig_w, math.floor(x1)))
        iy1 = max(0, min(orig_h, math.floor(y1)))
        ix2 = max(0, min(orig_w, math.ceil(x2)))
        iy2 = max(0, min(orig_h, math.ceil(y2)))
        if ix2 <= ix1 or iy2 <= iy1:
            return None
        return ix1, iy1, ix2, iy2

    @staticmethod
    def _passes_geometry_filter(bw, bh, conf) -> bool:
        """Post-filter: return False for non-plate geometry."""
        # F1a: oversized boxes (banners/text overlays at frame edges)
        if max(bw, bh) > 150:
            return False
        # F1b: portrait-aspect boxes below confidence threshold -
        # real plates are wider than tall; portrait boxes at low conf
        # are vertical artifacts (posts, signs). High-conf portrait
        # plates (e.g. vertically mounted) are preserved.
        if bw < bh * 0.8 and conf < 0.5:
            return False
        return True

    @staticmethod
    def _make_box(clamped, conf, cls_id) -> "BoundingBox":
        """Build a BoundingBox from clamped corners, clipping conf to [0, 1]."""
        ix1, iy1, ix2, iy2 = clamped
        return BoundingBox(
            x1=ix1,
            y1=iy1,
            x2=ix2,
            y2=iy2,
            cls_id=cls_id,
            conf=max(0.0, min(1.0, conf)),
        )

    def _infer_single(self, image_bgr: ndarray) -> "list[BoundingBox]":
        """Quad-4 (2x2 quadrant) SAHI inference with optional horizontal-flip TTA.

        Splits the frame into four overlapping quadrants, each
        anisotropically resized to ``(input_h, input_w)`` for ~2x
        magnification in both axes. Overlap is ~10% on each axis.
        All tile detections are de-duplicated, then merged via Soft-NMS.

        With ``self.use_tta=True``: additionally runs the same quad-4 pass
        on a horizontally flipped copy and un-flips the x-coordinates back
        into original space. Dedup + Soft-NMS then merge across both views,
        preferring the higher-confidence one for any paired detection.

        Measured (quad-4 without TTA) on 7 starter frames vs TB-2:
            mAP@50   0.406 -> 0.489
            recall   0.433 -> 0.500
            wall p95 55 ms -> 98 ms
        TTA roughly doubles inference cost (budget: 10 s).

        Boxes that collapse to zero width or height after clamping to the
        frame are never emitted.
        """
        orig_h, orig_w = image_bgr.shape[:2]
        all_dets = self._quad4_raw_dets(image_bgr)
        # Adaptive conf fallback removed Apr 26: ep 29 has higher recall than
        # v3 so the empty-first-pass case is rarer, and when it did fire the
        # conf=0.10 retry generated phantoms (FP drag) AND added a ~3s
        # inference pass (latency gate trigger). See mining_history.md.
        if self.use_tta:
            flipped = cv2.flip(image_bgr, 1)  # horizontal flip (mirror)
            # Un-flip x-coordinates: x_orig = W - x_flipped. Merged into
            # all_dets so dedup + NMS see both views.
            all_dets.extend(
                (orig_w - x2f, y1, orig_w - x1f, y2, conf, cls_id)
                for x1f, y1, x2f, y2, conf, cls_id in self._quad4_raw_dets(flipped)
            )

        # TTA-aware cluster-dedup: collapse near-duplicate detections of the
        # same plate (e.g. original + unflipped TTA view) BEFORE Soft-NMS,
        # which would otherwise decay but not kill the lower-conf copy at
        # our low score_threshold=0.01. Without this step the deployed miner
        # emitted 2-3 outputs per plate (verified on validator task 57820).
        pre_nms_count = len(all_dets)
        dets = self._soft_nms(self._cluster_dedup(all_dets, iou_thresh=0.3))
        # (Dual-threshold TTA gate tried here and reverted 2026-04-21: on our
        # YOLO11s ONNX the gate cost -0.037 map50-proxy to save only +0.023 FP,
        # net -0.013 composite on 20 post-jump archive tasks. Pattern is the
        # right one for hermes's YOLO26s (higher recall, more conf >=0.90 boxes)
        # but hurts YOLO11s. Keep self.conf_high + self.tta_match_iou params in
        # __init__ in case v7/v8 training closes the recall gap and makes the
        # gate net-positive - can re-add this block then.)

        out_boxes: list = []
        for x1, y1, x2, y2, conf, cls_id in dets:
            clamped = self._clamp_box(x1, y1, x2, y2, orig_w, orig_h)
            if clamped is None:
                continue
            ix1, iy1, ix2, iy2 = clamped
            if not self._passes_geometry_filter(ix2 - ix1, iy2 - iy1, conf):
                continue
            out_boxes.append(self._make_box(clamped, conf, cls_id))

        # Silent-empty-submission guard: if the pipeline found raw detections
        # but every one was filtered to nothing, bypass F1a/F1b and emit the
        # post-NMS detections above score_threshold. Accepts a potential FP
        # over a guaranteed zero - which scored 0.000-0.010 on validator
        # tasks 57803/57836/57848 even though the model had clear plate
        # signal in the tiles.
        if pre_nms_count > 0 and not out_boxes:
            _cuda_log.warning(
                "empty-submission guard: %d raw dets -> 0 filtered; emitting raw",
                pre_nms_count,
            )
            for x1, y1, x2, y2, conf, cls_id in dets:
                if conf < self.score_threshold:
                    continue
                clamped = self._clamp_box(x1, y1, x2, y2, orig_w, orig_h)
                if clamped is None:
                    continue
                out_boxes.append(self._make_box(clamped, conf, cls_id))
        return out_boxes

    # ---------------------------------------------------------------- entry
    def predict_batch(
        self,
        batch_images: list[ndarray],
        offset: int,
        n_keypoints: int,
    ) -> "list[TVFrameResult]":
        """Run detection on every image of the batch.

        Args:
            batch_images: BGR frames, one result is produced per frame.
            offset: frame id of the first image; frame ``i`` gets
                ``offset + i``.
            n_keypoints: number of placeholder ``(0, 0)`` keypoints to
                attach to each result (negative values are treated as 0).
        """
        n_placeholders = max(0, int(n_keypoints))
        results: list = []
        for idx, image in enumerate(batch_images):
            results.append(
                TVFrameResult(
                    frame_id=offset + idx,
                    boxes=self._infer_single(image),
                    keypoints=[(0, 0)] * n_placeholders,
                )
            )
        return results
|