from __future__ import annotations

import base64
import csv
import json
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Iterable
from urllib.parse import urlparse

import httpx


@dataclass(frozen=True)
class EvalConfig:
    """Run configuration for an evaluation pass against the classify API."""

    api: str
    images: list[Path]
    domain_top_n: int
    top_k: int
    out_dir: Path
    summary: bool


def select_hits(
    hits: list[dict],
    *,
    max_n: int | None = None,
    min_score: float | None = None,
) -> list[str]:
    """Return hit ids, optionally filtered by a minimum score and capped at max_n."""
    out: list[str] = []
    for hit in hits:
        if min_score is not None:
            try:
                if float(hit.get("score", 0.0)) < min_score:
                    continue
            except Exception:
                # Unparseable score: skip the hit rather than fail the run.
                continue
        out.append(str(hit.get("id")))
        if max_n is not None and len(out) >= max_n:
            break
    return out


def iter_images(paths: Iterable[Path]) -> Iterable[Path]:
    """Yield image files from the given paths, recursing into directories."""
    exts = {".jpg", ".jpeg", ".png", ".webp"}
    for path in paths:
        if path.is_dir():
            for p in sorted(path.rglob("*")):
                if p.is_file() and p.suffix.lower() in exts:
                    yield p
        elif path.is_file() and path.suffix.lower() in exts:
            yield path


def upload_label_set(client: httpx.Client, label_set: Path) -> str:
    """Upload a label-set JSON file and return the hash assigned by the server."""
    payload = json.loads(label_set.read_text())
    r = client.post("/api/v1/label-sets", json=payload)
    r.raise_for_status()
    return r.json()["label_set_hash"]


def classify_one(
    client: httpx.Client,
    label_set_hash: str,
    image_b64: str,
    domain_top_n: int,
    top_k: int,
) -> dict:
    """Classify a single base64-encoded image against an uploaded label set."""
    payload = {
        "image_base64": image_b64,
        "domain_top_n": domain_top_n,
        "top_k": top_k,
    }
    r = client.post(f"/api/v1/classify?label_set_hash={label_set_hash}", json=payload)
    r.raise_for_status()
    return r.json()


def encode_image_b64(path: Path) -> str:
    return base64.b64encode(path.read_bytes()).decode("utf-8")


def fmt_hit(hit: dict) -> str:
    """Format a hit as 'id:score' with four decimals; blank score if unparseable."""
    score = hit.get("score")
    try:
        score_str = f"{float(score):.4f}"
    except Exception:
        score_str = ""
    return f"{hit.get('id')}:{score_str}"


def percentile(values: list[int], q: float) -> int:
    """Nearest-rank percentile of integer values; returns 0 for an empty list."""
    if not values:
        return 0
    values = sorted(values)
    idx = int(round((len(values) - 1) * q))
    return values[idx]


def timestamp() -> str:
    return datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")


def api_slug(api: str) -> str:
    """Turn an API URL into a filesystem-friendly slug for output directories."""
    parsed = urlparse(api)
    host = parsed.netloc or parsed.path
    host = host.replace("http://", "").replace("https://", "")
    host = host.strip("/")
    if host in {"localhost:7860", "localhost", "127.0.0.1:7860", "127.0.0.1"}:
        return "local"
    return "".join(ch if ch.isalnum() or ch in {"-", "."} else "-" for ch in host)


def resolve_out_dir(api: str, out_dir: Path | None) -> Path:
    if out_dir is not None:
        return out_dir
    return Path("data_results") / api_slug(api)


def write_csv(path: Path, rows: list[dict[str, str]], fieldnames: list[str]) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def summarize_latency(rows: list[dict[str, str]]) -> dict[str, str]:
    """Aggregate per-request latency rows into count, mean, and percentile stats."""
    times: list[int] = []
    for row in rows:
        try:
            times.append(int(row["elapsed_ms"]))
        except Exception:
            continue
    return {
        "count": str(len(times)),
        "avg_elapsed_ms": str(int(sum(times) / max(1, len(times)))),
        "p50_elapsed_ms": str(percentile(times, 0.50)),
        "p90_elapsed_ms": str(percentile(times, 0.90)),
        "p95_elapsed_ms": str(percentile(times, 0.95)),
        "p99_elapsed_ms": str(percentile(times, 0.99)),
    }
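

# Minimal end-to-end usage sketch tying the helpers above together. It assumes
# the service is reachable at http://localhost:7860 and that the /api/v1/classify
# response contains a "hits" list of {"id", "score"} entries; the label-set path,
# the image directory, and the "hits" field name are illustrative assumptions,
# not guaranteed by this module.
if __name__ == "__main__":
    import time

    api = "http://localhost:7860"  # assumed local instance
    rows: list[dict[str, str]] = []
    with httpx.Client(base_url=api, timeout=30.0) as client:
        # Hypothetical label-set file and image directory for illustration only.
        label_set_hash = upload_label_set(client, Path("label_set.json"))
        for image in iter_images([Path("images")]):
            started = time.monotonic()
            result = classify_one(
                client, label_set_hash, encode_image_b64(image), domain_top_n=3, top_k=5
            )
            # Latency is measured client-side rather than read from the response.
            elapsed_ms = int((time.monotonic() - started) * 1000)
            rows.append(
                {
                    "image": str(image),
                    "hits": " ".join(fmt_hit(h) for h in result.get("hits", [])),
                    "elapsed_ms": str(elapsed_ms),
                }
            )
    out_dir = resolve_out_dir(api, None)
    write_csv(out_dir / f"run_{timestamp()}.csv", rows, ["image", "hits", "elapsed_ms"])
    print(summarize_latency(rows))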