Minerv4 / labels.py
Imaginethat's picture
Upload 4 files
0e20813 verified
raw
history blame contribute delete
883 Bytes
from __future__ import annotations
from typing import Sequence, Tuple
import numpy as np
# Default label returned by dominant_label_gated when no label passes the gate.
UNKNOWN_LABEL = "Unknown"


def dominant_label_gated(
    scores: Sequence[float] | None,
    labels: Sequence[str] | None,
    min_score: float,
    *,
    unknown_label: str = UNKNOWN_LABEL,
) -> Tuple[str, float]:
    """
    Return the dominant label only if it's confidently supported.

    Rule:
      - If max(score) >= min_score: return (label, max_score)
      - Else: return (unknown_label, max_score)

    Edge cases:
      - ``scores`` or ``labels`` is None/empty, or every score is NaN:
        return (unknown_label, 0.0).
      - NaN scores are ignored when picking the dominant entry.
      - If the dominant score has no label at its index (``scores`` is
        longer than ``labels``): return (unknown_label, max_score).
      - A NaN ``min_score`` never passes the gate.

    Args:
        scores: Per-label scores; any sequence of numbers, including a
            NumPy array.
        labels: Label names aligned by index with ``scores``.
        min_score: Inclusive threshold the best score must reach.
        unknown_label: Label returned when the gate is not passed.

    Returns:
        A ``(label, max_score)`` tuple.
    """
    # Explicit None/len checks: the truthiness of a multi-element NumPy
    # array raises ValueError, so `not scores` is unsafe for array input.
    if scores is None or labels is None or len(labels) == 0:
        return unknown_label, 0.0
    arr = np.asarray(list(scores), dtype=float)
    if arr.size == 0:
        return unknown_label, 0.0
    if np.isnan(arr).all():
        # No usable score at all; treat like an empty input.
        return unknown_label, 0.0
    # nanargmax skips NaNs; plain argmax would return the first NaN's index.
    idx = int(np.nanargmax(arr))
    max_score = float(arr[idx])
    # `not (a >= b)` also rejects a NaN threshold (fails closed), whereas
    # `a < b` would let every score through when min_score is NaN.
    if idx >= len(labels) or not (max_score >= float(min_score)):
        return unknown_label, max_score
    return str(labels[idx]), max_score