| | """Judge backends — API-based (HF Inference Providers, OpenAI-compatible).""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import abc |
| | from collections import Counter |
| | from concurrent.futures import ThreadPoolExecutor, as_completed |
| | from typing import Any |
| |
|
| | import stamina |
| | import structlog |
| | from huggingface_hub import InferenceClient |
| | from openai import OpenAI |
| |
|
| | from ocr_bench.judge import JUDGE_SCHEMA, Comparison, parse_judge_output |
| |
|
logger = structlog.get_logger()

# Deliberately broad: provider SDKs surface transient failures (rate limits,
# timeouts, 5xx) as many different exception types, so the stamina retry
# decorators below retry on everything and rely on bounded `attempts`.
_RETRYABLE = (Exception,)
|
| |
|
class JudgeBackend(abc.ABC):
    """Base class for judge backends."""

    # Human-readable identifier used in logs/reports; set by subclasses.
    name: str
    # Maximum number of parallel judge calls; 1 means purely sequential.
    concurrency: int = 1

    @abc.abstractmethod
    def _call_single(self, comp: Comparison) -> dict[str, str]:
        """Run the judge on a single comparison.

        Returns a dict with ``winner`` and ``reason`` keys (empty dict when
        the model output cannot be parsed). May raise on transport errors;
        subclasses wrap this with retry decorators.
        """

    def _call_guarded(self, comp: Comparison, idx: int) -> dict[str, str]:
        """Call ``_call_single``, mapping any failure to an empty dict."""
        try:
            return self._call_single(comp)
        except Exception as exc:  # broad by design: "{} on failure" contract
            logger.warning("judge_call_failed", idx=idx, error=str(exc))
            return {}

    def judge(self, comparisons: list[Comparison]) -> list[dict[str, str]]:
        """Run the judge on a list of comparisons (concurrently if supported).

        Returns a list of parsed results (one per comparison).
        Each result is a dict with ``winner`` and ``reason`` keys,
        or an empty dict on failure.
        """
        if self.concurrency <= 1 or len(comparisons) <= 1:
            # Guarded so the sequential path honours the same "{} on
            # failure" contract as the concurrent path below (previously
            # exceptions propagated here but were swallowed there).
            return [
                self._call_guarded(comp, idx)
                for idx, comp in enumerate(comparisons)
            ]

        # One distinct placeholder dict per slot; `[{}] * n` would alias a
        # single shared dict across every position.
        results: list[dict[str, str]] = [{} for _ in comparisons]
        with ThreadPoolExecutor(max_workers=self.concurrency) as pool:
            future_to_idx = {
                pool.submit(self._call_single, comp): i
                for i, comp in enumerate(comparisons)
            }
            for future in as_completed(future_to_idx):
                idx = future_to_idx[future]
                try:
                    results[idx] = future.result()
                except Exception as exc:
                    logger.warning("judge_call_failed", idx=idx, error=str(exc))
                    results[idx] = {}
        return results
| |
|
| |
|
# Default completion budget for judge responses (short JSON verdicts).
DEFAULT_MAX_TOKENS = 1024
| |
|
| |
|
class InferenceProviderJudge(JudgeBackend):
    """HF Inference Providers backend (Novita, Together, etc.)."""

    def __init__(
        self, model: str, provider: str | None = None, max_tokens: int = DEFAULT_MAX_TOKENS,
    ):
        prefix = f"{provider}:" if provider else ""
        self.name = prefix + model
        self.model = model
        self.max_tokens = max_tokens
        self.client = InferenceClient(model=model, provider=provider)

    @stamina.retry(on=_RETRYABLE, attempts=6)
    def _call_single(self, comp: Comparison) -> dict[str, str]:
        """Judge one comparison via chat completion; empty dict if unparseable."""
        response = self.client.chat_completion(
            messages=comp.messages,
            max_tokens=self.max_tokens,
            temperature=0.0,
            response_format={"type": "json_object"},
            extra_body={"chat_template_kwargs": {"enable_thinking": False}},
        )
        content = response.choices[0].message.content
        parsed = parse_judge_output(content.strip())
        if not parsed:
            logger.warning("empty_parse", backend=self.name, sample=comp.sample_idx)
        return parsed
| |
|
| |
|
class OpenAICompatibleJudge(JudgeBackend):
    """OpenAI-compatible endpoint (local vLLM server, Ollama, HF IE, etc.)."""

    def __init__(
        self,
        base_url: str,
        model: str = "default",
        max_tokens: int = DEFAULT_MAX_TOKENS,
        api_key: str = "not-needed",
        extra_body: dict | None = None,
        temperature: float = 0.0,
        concurrency: int = 1,
    ):
        if model != "default":
            self.name = model
        else:
            self.name = f"openai@{base_url}"
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature
        # Unless the caller overrides the request body, ask the server for
        # vLLM-style guided decoding against the judge JSON schema.
        if extra_body is None:
            extra_body = {"guided_json": JUDGE_SCHEMA}
        self.extra_body = extra_body
        self.concurrency = concurrency
        self.client = OpenAI(base_url=base_url, api_key=api_key)

    @stamina.retry(on=_RETRYABLE, attempts=3)
    def _call_single(self, comp: Comparison) -> dict[str, str]:
        """Judge one comparison via the chat completions API; {} if unparseable."""
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=comp.messages,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            extra_body=self.extra_body,
        )
        text = completion.choices[0].message.content.strip()
        parsed = parse_judge_output(text)
        if not parsed:
            logger.warning("empty_parse", backend=self.name, sample=comp.sample_idx)
        return parsed
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Default judge spec in "provider:org/model" form (see parse_judge_spec).
DEFAULT_JUDGE = "novita:moonshotai/Kimi-K2.5"
| |
|
| |
|
def parse_judge_spec(
    spec: str, max_tokens: int = DEFAULT_MAX_TOKENS, concurrency: int = 1,
) -> JudgeBackend:
    """Parse a judge specification string into a backend.

    Formats:
    - ``"https://xxx.endpoints.huggingface.cloud"`` → :class:`OpenAICompatibleJudge`
      (HF Inference Endpoints, OpenAI-compatible with HF token auth)
    - ``"http://..."`` or ``"https://..."`` (other) → :class:`OpenAICompatibleJudge`
    - ``"provider:org/model"`` (colon before first ``/``) → :class:`InferenceProviderJudge`
    - anything else → :class:`InferenceProviderJudge` (no provider)
    """
    is_url = spec.startswith("http://") or spec.startswith("https://")
    if is_url:
        # An optional "/v1/:model" suffix names the model served there.
        url_part, sep, model_name = spec.partition("/v1/:")
        if sep:
            url_part += "/v1"
        else:
            model_name = "default"

        if ".endpoints.huggingface." in url_part:
            from huggingface_hub import get_token

            base_url = url_part.rstrip("/")
            if not base_url.endswith("/v1"):
                base_url = base_url + "/v1"
            return OpenAICompatibleJudge(
                base_url=base_url,
                model=model_name,
                api_key=get_token() or "not-needed",
                max_tokens=max_tokens,
                # NOTE(review): 0.7 here vs the 0.0 default elsewhere looks
                # intentional for HF endpoints — confirm with the author.
                temperature=0.7,
                extra_body={"chat_template_kwargs": {"enable_thinking": False}},
                concurrency=concurrency,
            )
        return OpenAICompatibleJudge(
            base_url=url_part,
            model=model_name,
            max_tokens=max_tokens,
            concurrency=concurrency,
        )

    colon_idx = spec.find(":")
    if colon_idx != -1:
        slash_idx = spec.find("/")
        # "provider:org/model" only when the colon precedes any slash;
        # otherwise the colon belongs to the model id itself.
        if slash_idx == -1 or colon_idx < slash_idx:
            provider = spec[:colon_idx]
            model = spec[colon_idx + 1:]
            return InferenceProviderJudge(model=model, provider=provider, max_tokens=max_tokens)

    return InferenceProviderJudge(model=spec, max_tokens=max_tokens)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def aggregate_jury_votes(
    all_results: list[list[dict[str, str]]],
    judge_names: list[str],
) -> list[dict[str, Any]]:
    """Aggregate votes from multiple judges using majority voting.

    Args:
        all_results: List of result lists, one per judge. Each inner list
            has one dict per comparison.
        judge_names: Names of the judges (same order as *all_results*).

    Returns:
        Aggregated results with ``winner``, ``reason``, and ``agreement``
        fields. ``agreement`` is ``"<top vote count>/<valid votes>"``.
        A comparison with no valid votes — or an exact split between the
        leading options — yields ``winner == "tie"``.
    """
    if not all_results:
        return []

    n_comparisons = len(all_results[0])
    aggregated: list[dict[str, Any]] = []

    for i in range(n_comparisons):
        votes: list[str] = []
        reasons: list[str] = []
        for name, judge_results in zip(judge_names, all_results):
            # A judge may have returned fewer results (or an empty dict on
            # failure); treat missing/empty as "no vote".
            result = judge_results[i] if i < len(judge_results) else {}
            winner = result.get("winner", "")
            if winner:
                votes.append(winner)
                reasons.append(f"{name}: {result.get('reason', '')}")

        if not votes:
            aggregated.append({"winner": "tie", "reason": "no valid votes", "agreement": "0/0"})
            continue

        counter = Counter(votes)
        majority_winner, majority_count = counter.most_common(1)[0]
        # most_common breaks ties by insertion order, which would bias an
        # even split toward whichever judge appears first in all_results;
        # report an honest tie instead.
        if sum(1 for c in counter.values() if c == majority_count) > 1:
            majority_winner = "tie"
        agreement = f"{majority_count}/{len(votes)}"

        aggregated.append({
            "winner": majority_winner,
            "reason": "; ".join(reasons),
            "agreement": agreement,
        })

    return aggregated
| |
|