Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import base64 | |
| import csv | |
| import json | |
| from dataclasses import dataclass | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from typing import Iterable | |
| from urllib.parse import urlparse | |
| import httpx | |
@dataclass
class EvalConfig:
    """Bundle of settings for one evaluation run.

    Declared as a dataclass so the annotated fields below become
    constructor arguments; without the decorator the annotations alone
    create no ``__init__`` and no instance attributes.
    """

    # NOTE(review): field meanings below are inferred from names and from the
    # matching parameters of the helper functions in this module - confirm
    # against the caller that builds this config.
    api: str  # presumably the API base URL
    images: list[Path]  # image files and/or directories to evaluate
    domain_top_n: int  # same name as the "domain_top_n" request field
    top_k: int  # same name as the "top_k" request field
    out_dir: Path  # presumably where result files are written
    summary: bool  # presumably toggles summary output
def select_hits(
    hits: list[dict],
    *,
    max_n: int | None = None,
    min_score: float | None = None,
) -> list[str]:
    """Return the stringified ``"id"`` of each hit that passes the filters.

    Args:
        hits: Hit dictionaries, visited in the given order.
        max_n: Upper bound on the number of ids returned. ``None`` means no
            limit; zero or a negative value yields an empty list.
        min_score: When given, a hit is kept only if ``float(hit["score"])``
            is not below this value. A missing score counts as ``0.0``; a hit
            whose score cannot be converted or compared is skipped.

    Returns:
        The selected ids as strings (``"None"`` for a hit without an id).
    """
    # Checking the limit only after appending would let max_n=0 return one
    # id, so a non-positive limit is handled before the loop.
    if max_n is not None and max_n <= 0:
        return []

    selected: list[str] = []
    for hit in hits:
        if min_score is not None:
            try:
                too_low = float(hit.get("score", 0.0)) < min_score
            except Exception:
                # Best-effort filter: an unusable score drops the hit
                # instead of aborting the whole selection.
                continue
            if too_low:
                continue
        selected.append(str(hit.get("id")))
        if max_n is not None and len(selected) >= max_n:
            break
    return selected
def iter_images(paths: Iterable[Path]) -> Iterable[Path]:
    """Yield image files found at or beneath the given paths.

    A directory is walked recursively and its entries are visited in sorted
    order; a plain file is considered on its own. Only files whose suffix,
    compared case-insensitively, is .jpg, .jpeg, .png or .webp are yielded.
    Paths that are neither a directory nor a file produce nothing.
    """
    allowed_suffixes = {".jpg", ".jpeg", ".png", ".webp"}
    for entry in paths:
        if entry.is_dir():
            candidates = sorted(entry.rglob("*"))
        elif entry.is_file():
            candidates = [entry]
        else:
            continue
        for candidate in candidates:
            if candidate.is_file() and candidate.suffix.lower() in allowed_suffixes:
                yield candidate
def upload_label_set(client: httpx.Client, label_set: Path) -> str:
    """Upload a label-set JSON file and return the hash the API reports.

    Args:
        client: HTTP client used to POST to ``/api/v1/label-sets``.
        label_set: Path to a JSON file whose parsed content is sent as the
            request body.

    Returns:
        The ``"label_set_hash"`` value from the JSON response.

    Raises:
        Whatever ``raise_for_status()`` raises for an error response, plus
        the usual file-read and JSON-decode errors.
    """
    # Decode explicitly as UTF-8 so non-ASCII labels do not depend on the
    # platform's default text encoding.
    payload = json.loads(label_set.read_text(encoding="utf-8"))
    response = client.post("/api/v1/label-sets", json=payload)
    response.raise_for_status()
    return response.json()["label_set_hash"]
def classify_one(
    client: httpx.Client,
    label_set_hash: str,
    image_b64: str,
    domain_top_n: int,
    top_k: int,
) -> dict:
    """POST one base64-encoded image to the classify endpoint.

    Args:
        client: HTTP client used to POST to ``/api/v1/classify``.
        label_set_hash: Sent as the ``label_set_hash`` query parameter.
        image_b64: Sent as the ``"image_base64"`` body field.
        domain_top_n: Sent as the ``"domain_top_n"`` body field.
        top_k: Sent as the ``"top_k"`` body field.

    Returns:
        The decoded JSON response body.

    Raises:
        Whatever ``raise_for_status()`` raises for an error response.
    """
    payload = {
        "image_base64": image_b64,
        "domain_top_n": domain_top_n,
        "top_k": top_k,
    }
    # Let the client build the query string so the hash is URL-encoded
    # rather than spliced into the path verbatim.
    response = client.post(
        "/api/v1/classify",
        params={"label_set_hash": label_set_hash},
        json=payload,
    )
    response.raise_for_status()
    return response.json()
def encode_image_b64(path: Path) -> str:
    """Read the file at *path* and return its bytes as base64 text."""
    raw_bytes = path.read_bytes()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
def fmt_hit(hit: dict) -> str:
    """Render a hit as ``"<id>:<score>"`` with the score to four decimals.

    The score part is left empty when the hit has no score or the score
    cannot be converted to a float.
    """
    raw_score = hit.get("score")
    try:
        rendered = format(float(raw_score), ".4f")
    except Exception:
        # Best-effort formatting: an unusable score is shown as blank.
        rendered = ""
    return f"{hit.get('id')}:{rendered}"
def percentile(values: list[int], q: float) -> int:
    """Return the element at fraction *q* of the sorted *values*.

    The position is ``round((len(values) - 1) * q)``, so an actual element
    of the input is always returned and nothing is interpolated. Python's
    ``round`` resolves exact halves to the even index.

    Args:
        values: Samples in any order; the caller's list is not modified.
        q: Fraction in the closed range ``[0, 1]``.

    Returns:
        The selected sample, or ``0`` when *values* is empty.

    Raises:
        ValueError: If *values* is non-empty and *q* is outside ``[0, 1]``.
    """
    if not values:
        return 0
    # Without this check a negative q would index from the end of the list
    # and q > 1 would surface as an unexplained IndexError.
    if not 0.0 <= q <= 1.0:
        raise ValueError(f"q must be between 0 and 1, got {q!r}")
    ordered = sorted(values)
    return ordered[round((len(ordered) - 1) * q)]
def timestamp() -> str:
    """Return the current UTC time formatted as ``YYYYMMDD_HHMMSS``."""
    now_utc = datetime.now(tz=timezone.utc)
    return now_utc.strftime("%Y%m%d_%H%M%S")
def api_slug(api: str) -> str:
    """Turn an API address into a short name usable as a directory name.

    The host part of *api* (including any port) is taken; for an address
    without a scheme, any path after the host is kept as well. The local
    addresses ``localhost`` and ``127.0.0.1``, bare or with port 7860, map
    to ``"local"``. Otherwise every character that is not alphanumeric,
    ``-`` or ``.`` is replaced by ``-``.

    Args:
        api: API address, with or without a ``scheme://`` prefix.

    Returns:
        The slug string (empty when *api* is empty).
    """
    parsed = urlparse(api)
    if parsed.netloc:
        host = parsed.netloc
    elif "://" in api:
        host = parsed.path
    else:
        # A schemeless "host:port" is split by urlparse into scheme="host"
        # and path="port", which loses the host. Anchoring the text with
        # "//" makes it parse as an authority instead.
        try:
            anchored = urlparse("//" + api)
        except ValueError:
            host = parsed.path
        else:
            host = anchored.netloc + anchored.path
    # Defensive: drop any scheme text that survived parsing.
    host = host.replace("http://", "").replace("https://", "")
    host = host.strip("/")
    if host in {"localhost:7860", "localhost", "127.0.0.1:7860", "127.0.0.1"}:
        return "local"
    return "".join(ch if ch.isalnum() or ch in {"-", "."} else "-" for ch in host)
def resolve_out_dir(api: str, out_dir: Path | None) -> Path:
    """Return *out_dir* if given, else ``data_results/<slug of api>``."""
    if out_dir is None:
        # No explicit directory: derive one from the API address.
        return Path("data_results") / api_slug(api)
    return out_dir
def write_csv(path: Path, rows: list[dict[str, str]], fieldnames: list[str]) -> None:
    """Write *rows* to a UTF-8 CSV file at *path*, header line first.

    Missing parent directories are created. Columns follow the order of
    *fieldnames*.
    """
    parent_dir = path.parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    with path.open("w", newline="", encoding="utf-8") as handle:
        dict_writer = csv.DictWriter(handle, fieldnames=fieldnames)
        dict_writer.writeheader()
        for record in rows:
            dict_writer.writerow(record)
def summarize_latency(rows: list[dict[str, str]]) -> dict[str, str]:
    """Summarise the ``"elapsed_ms"`` values of *rows* as strings.

    Rows whose ``"elapsed_ms"`` is missing or cannot be converted to an
    integer are left out. The result holds the sample count, the average
    truncated to an integer, and the 50th/90th/95th/99th percentiles, with
    every value stringified. With no usable rows the average is ``"0"``.
    """
    samples: list[int] = []
    for entry in rows:
        try:
            elapsed = int(entry["elapsed_ms"])
        except Exception:
            # Best-effort: skip rows without a usable timing.
            continue
        samples.append(elapsed)

    sample_count = len(samples)
    # max(1, ...) avoids dividing by zero when nothing was usable.
    average = int(sum(samples) / max(1, sample_count))

    summary = {
        "count": str(sample_count),
        "avg_elapsed_ms": str(average),
    }
    quantiles = (
        ("p50_elapsed_ms", 0.50),
        ("p90_elapsed_ms", 0.90),
        ("p95_elapsed_ms", 0.95),
        ("p99_elapsed_ms", 0.99),
    )
    for key, fraction in quantiles:
        summary[key] = str(percentile(samples, fraction))
    return summary