davanstrien's picture
davanstrien HF Staff
Upload folder using huggingface_hub
1118181 verified
"""Judge backends — API-based (HF Inference Providers, OpenAI-compatible)."""
from __future__ import annotations
import abc
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any
import stamina
import structlog
from huggingface_hub import InferenceClient
from openai import OpenAI
from ocr_bench.judge import JUDGE_SCHEMA, Comparison, parse_judge_output
# Module-level structlog logger shared by every backend in this file.
logger = structlog.get_logger()
# Retry on these exception types with exponential backoff + jitter.
# Passed as ``on=`` to the ``stamina.retry`` decorators on each backend's
# ``_call_single``.
# NOTE(review): ``Exception`` matches every error, so non-transient failures
# (bad credentials, malformed requests, programming errors) are retried too —
# consider narrowing to the clients' timeout / rate-limit / connection errors.
_RETRYABLE = (Exception,)
class JudgeBackend(abc.ABC):
    """Base class for judge backends.

    Subclasses implement :meth:`_call_single`; :meth:`judge` applies it to a
    batch of comparisons, sequentially or on a thread pool.
    """

    # Backend identifier used in log events; subclasses set it in ``__init__``.
    name: str
    # Maximum number of ``_call_single`` calls in flight; values <= 1 mean sequential.
    concurrency: int = 1

    @abc.abstractmethod
    def _call_single(self, comp: Comparison) -> dict[str, str]:
        """Run the judge on a single comparison."""

    def _call_single_safe(self, idx: int, comp: Comparison) -> dict[str, str]:
        """Call :meth:`_call_single`, turning any failure into an empty result.

        Args:
            idx: Position of *comp* in the batch (only used for logging).
            comp: The comparison to judge.

        Returns:
            The parsed result, or ``{}`` if ``_call_single`` raised.
        """
        try:
            return self._call_single(comp)
        except Exception as exc:
            # Batch boundary: one bad comparison must not abort the whole run.
            logger.warning("judge_call_failed", idx=idx, error=str(exc))
            return {}

    def judge(self, comparisons: list[Comparison]) -> list[dict[str, str]]:
        """Run the judge on a list of comparisons (concurrently if supported).

        Failures are handled identically on the sequential and the concurrent
        path: the error is logged as ``judge_call_failed`` and that slot
        becomes an empty dict.

        Args:
            comparisons: The comparisons to judge.

        Returns:
            A list of parsed results, one per comparison and in the same
            order. Each result is a dict with ``winner`` and ``reason`` keys,
            or an empty dict on failure.
        """
        if self.concurrency <= 1 or len(comparisons) <= 1:
            return [
                self._call_single_safe(idx, comp)
                for idx, comp in enumerate(comparisons)
            ]

        # ``Executor.map`` yields results in submission order, so no manual
        # index bookkeeping (or shared placeholder dicts) is needed.
        with ThreadPoolExecutor(max_workers=self.concurrency) as pool:
            return list(
                pool.map(self._call_single_safe, range(len(comparisons)), comparisons)
            )
# Default ``max_tokens`` for a single judge completion; used as the default by
# both backend constructors and by ``parse_judge_spec``.
DEFAULT_MAX_TOKENS = 1024
class InferenceProviderJudge(JudgeBackend):
    """HF Inference Providers backend (Novita, Together, etc.)."""

    def __init__(
        self, model: str, provider: str | None = None, max_tokens: int = DEFAULT_MAX_TOKENS,
    ):
        """Create a judge backed by ``huggingface_hub.InferenceClient``.

        Args:
            model: Model id passed to the client.
            provider: Optional provider name; ``None`` passes no provider.
            max_tokens: Completion-token limit for each judge call.
        """
        # Display name is "provider:model" when a provider is given, else the model id.
        self.name = f"{provider + ':' if provider else ''}{model}"
        self.model = model
        self.max_tokens = max_tokens
        self.client = InferenceClient(model=model, provider=provider)  # type: ignore[invalid-argument-type]

    @stamina.retry(on=_RETRYABLE, attempts=6)
    def _call_single(self, comp: Comparison) -> dict[str, str]:
        """Send one comparison to the model and parse its verdict.

        Args:
            comp: Comparison whose ``messages`` form the chat prompt.

        Returns:
            The result of ``parse_judge_output`` on the reply text. A falsy
            result is logged as ``empty_parse`` and returned unchanged.
        """
        response = self.client.chat_completion(  # type: ignore[no-matching-overload]
            messages=comp.messages,
            max_tokens=self.max_tokens,
            temperature=0.0,
            response_format={"type": "json_object"},
            extra_body={"chat_template_kwargs": {"enable_thinking": False}},
        )
        # ``content`` may be None; calling ``.strip()`` on it would raise
        # AttributeError, which the retry decorator would then re-send.
        # Treat it as an empty reply instead.
        # NOTE(review): assumes parse_judge_output("") yields a falsy result — confirm.
        content = response.choices[0].message.content
        raw = (content or "").strip()
        result = parse_judge_output(raw)
        if not result:
            logger.warning("empty_parse", backend=self.name, sample=comp.sample_idx)
        return result
class OpenAICompatibleJudge(JudgeBackend):
    """OpenAI-compatible endpoint (local vLLM server, Ollama, HF IE, etc.)."""

    def __init__(
        self,
        base_url: str,
        model: str = "default",
        max_tokens: int = DEFAULT_MAX_TOKENS,
        api_key: str = "not-needed",
        extra_body: dict | None = None,
        temperature: float = 0.0,
        concurrency: int = 1,
    ):
        """Create a judge that talks to an OpenAI-compatible chat endpoint.

        Args:
            base_url: Base URL handed to the ``OpenAI`` client.
            model: Model name sent with each request.
            max_tokens: Completion-token limit for each judge call.
            api_key: API key handed to the ``OpenAI`` client.
            extra_body: Extra request-body fields; when ``None``, defaults to
                ``{"guided_json": JUDGE_SCHEMA}``.
            temperature: Sampling temperature for each request.
            concurrency: Maximum number of requests ``judge`` runs in parallel.
        """
        # With the placeholder model name, identify the backend by its URL instead.
        self.name = model if model != "default" else f"openai@{base_url}"
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature
        # An explicitly passed empty dict is respected; only ``None`` gets the default.
        self.extra_body = extra_body if extra_body is not None else {"guided_json": JUDGE_SCHEMA}
        self.concurrency = concurrency
        self.client = OpenAI(base_url=base_url, api_key=api_key)

    @stamina.retry(on=_RETRYABLE, attempts=3)
    def _call_single(self, comp: Comparison) -> dict[str, str]:
        """Send one comparison to the endpoint and parse its verdict.

        Args:
            comp: Comparison whose ``messages`` form the chat prompt.

        Returns:
            The result of ``parse_judge_output`` on the reply text. A falsy
            result is logged as ``empty_parse`` and returned unchanged.
        """
        response = self.client.chat.completions.create(
            model=self.model,
            messages=comp.messages,  # type: ignore[invalid-argument-type]
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            extra_body=self.extra_body,
        )
        # ``content`` is optional in the chat-completions schema; calling
        # ``.strip()`` on None would raise AttributeError, which the retry
        # decorator would then re-send. Treat it as an empty reply instead.
        # NOTE(review): assumes parse_judge_output("") yields a falsy result — confirm.
        content = response.choices[0].message.content
        raw = (content or "").strip()
        result = parse_judge_output(raw)
        if not result:
            logger.warning("empty_parse", backend=self.name, sample=comp.sample_idx)
        return result
# ---------------------------------------------------------------------------
# Spec parsing
# ---------------------------------------------------------------------------
# Judge spec in the ``provider:org/model`` form accepted by ``parse_judge_spec``.
# NOTE(review): not referenced in the visible code — presumably the caller's
# fallback when no judge is specified; confirm.
DEFAULT_JUDGE = "novita:moonshotai/Kimi-K2.5"
def parse_judge_spec(
    spec: str, max_tokens: int = DEFAULT_MAX_TOKENS, concurrency: int = 1,
) -> JudgeBackend:
    """Parse a judge specification string into a backend.

    Formats:
    - ``"https://xxx.endpoints.huggingface.cloud"`` → :class:`OpenAICompatibleJudge`
      (HF Inference Endpoints, OpenAI-compatible with HF token auth)
    - ``"http://..."`` or ``"https://..."`` (other) → :class:`OpenAICompatibleJudge`
    - ``"provider:org/model"`` (colon before first ``/``) → :class:`InferenceProviderJudge`
    - anything else → :class:`InferenceProviderJudge` (no provider)

    A URL may carry a model name as ``<url>/v1/:<model>``; without it the
    model name is ``"default"``.

    Args:
        spec: The judge specification string.
        max_tokens: Completion-token limit forwarded to the backend.
        concurrency: Parallel request limit. Only forwarded to
            :class:`OpenAICompatibleJudge`; provider-based judges do not
            receive it.

    Returns:
        The configured backend instance.
    """
    if spec.startswith(("http://", "https://")):
        # Check for url:model format (e.g. https://...cloud/v1/:org/model)
        url_part, marker, model_name = spec.partition("/v1/:")
        if marker:
            url_part += "/v1"
        else:
            model_name = "default"

        # Decide "is this an HF Inference Endpoint" from the hostname only.
        # Matching the whole URL would also hit a path or query string, and
        # the HF token would then be sent to an unrelated host.
        from urllib.parse import urlsplit

        try:
            host = urlsplit(url_part).hostname or ""
        except ValueError:
            # Malformed netloc (e.g. unbalanced IPv6 brackets): not an HF endpoint.
            host = ""

        # HF Inference Endpoints — OpenAI-compatible, auth via HF token
        if ".endpoints.huggingface." in host:
            from huggingface_hub import get_token

            base_url = url_part.rstrip("/")
            if not base_url.endswith("/v1"):
                base_url += "/v1"
            token = get_token() or "not-needed"
            return OpenAICompatibleJudge(
                base_url=base_url,
                model=model_name,
                api_key=token,
                max_tokens=max_tokens,
                # NOTE(review): 0.7 differs from the 0.0 default used by the
                # other backends — confirm this is intentional.
                temperature=0.7,
                extra_body={"chat_template_kwargs": {"enable_thinking": False}},
                concurrency=concurrency,
            )

        return OpenAICompatibleJudge(
            base_url=url_part, model=model_name, max_tokens=max_tokens,
            concurrency=concurrency,
        )

    colon_idx = spec.find(":")
    if colon_idx != -1:
        # provider:model format — colon must come before first slash
        slash_idx = spec.find("/")
        if slash_idx == -1 or colon_idx < slash_idx:
            provider, model = spec.split(":", 1)
            return InferenceProviderJudge(model=model, provider=provider, max_tokens=max_tokens)

    return InferenceProviderJudge(model=spec, max_tokens=max_tokens)
# ---------------------------------------------------------------------------
# Jury aggregation
# ---------------------------------------------------------------------------
def aggregate_jury_votes(
    all_results: list[list[dict[str, str]]],
    judge_names: list[str],
) -> list[dict[str, Any]]:
    """Aggregate votes from multiple judges using majority voting.

    Results with a missing or empty ``winner`` are not counted as votes. When
    two or more options share the highest vote count, the outcome is
    ``"tie"`` rather than whichever option the earliest judge picked.

    Args:
        all_results: List of result lists, one per judge. Each inner list
            has one dict per comparison. Shorter lists are treated as having
            no vote for the missing comparisons.
        judge_names: Names of the judges (same order as *all_results*). A
            judge without a name is labelled ``judge_<index>``.

    Returns:
        One dict per comparison with ``winner``, ``reason`` (the voting
        judges' reasons joined by ``"; "``) and ``agreement``
        (``"<top count>/<valid votes>"``) fields.
    """
    if not all_results:
        return []

    # Size by the longest list so a short first judge cannot drop comparisons.
    n_comparisons = max(len(judge_results) for judge_results in all_results)

    aggregated: list[dict[str, Any]] = []
    for i in range(n_comparisons):
        votes: list[str] = []
        reasons: list[str] = []
        for j, judge_results in enumerate(all_results):
            result = judge_results[i] if i < len(judge_results) else {}
            winner = result.get("winner", "")
            if not winner:
                continue
            label = judge_names[j] if j < len(judge_names) else f"judge_{j}"
            votes.append(winner)
            reasons.append(f"{label}: {result.get('reason', '')}")

        if not votes:
            aggregated.append({"winner": "tie", "reason": "no valid votes", "agreement": "0/0"})
            continue

        ranked = Counter(votes).most_common()
        top_winner, top_count = ranked[0]
        # ``most_common`` breaks equal counts by insertion order, which would
        # let judge order decide a split vote; report it as a tie instead.
        if len(ranked) > 1 and ranked[1][1] == top_count:
            top_winner = "tie"

        aggregated.append({
            "winner": top_winner,
            "reason": "; ".join(reasons),
            "agreement": f"{top_count}/{len(votes)}",
        })
    return aggregated