#!/usr/bin/env python3 from __future__ import annotations import json from dataclasses import dataclass from pathlib import Path from typing import Iterable import click import httpx from eval import common @dataclass(frozen=True) class Config: api: str label_sets: list[Path] images: list[Path] domain_top_n: int top_k: int out_dir: Path | None summary: bool select_domain_n: int | None select_label_n: int | None min_domain_score: float | None min_label_score: float | None def expand_label_sets(paths: Iterable[str]) -> list[Path]: out: list[Path] = [] for raw in paths: p = Path(raw) if any(ch in raw for ch in ["*", "?", "["]): out.extend(sorted(Path().glob(raw))) else: out.append(p) return [p for p in out if p.is_file()] def to_row( label_set: Path, image: Path, data: dict, *, select_domain_n: int | None, select_label_n: int | None, min_domain_score: float | None, min_label_score: float | None, ) -> dict[str, str]: domain_hits = data.get("domain_hits", []) label_hits = data.get("label_hits", []) selected_domains = common.select_hits(domain_hits, max_n=select_domain_n, min_score=min_domain_score) selected_labels = common.select_hits(label_hits, max_n=select_label_n, min_score=min_label_score) return { "label_set": label_set.name, "image": str(image), "label_set_hash": data.get("label_set_hash", ""), "model_id": data.get("model_id", ""), "chosen_domains": "|".join(data.get("chosen_domains", [])), "selected_domains": "|".join(selected_domains), "selected_labels": "|".join(selected_labels), "domain_hits": "|".join(common.fmt_hit(d) for d in domain_hits), "label_hits": "|".join(common.fmt_hit(l) for l in label_hits), "elapsed_ms": str(data.get("elapsed_ms", "")), "elapsed_domain_ms": str(data.get("elapsed_domain_ms", "")), "elapsed_labels_ms": str(data.get("elapsed_labels_ms", "")), } def summarize_by_label_set(rows: list[dict[str, str]]) -> list[dict[str, str]]: summary: dict[str, list[int]] = {} for row in rows: label = row["label_set"] try: elapsed = int(row["elapsed_ms"]) except Exception: continue summary.setdefault(label, []).append(elapsed) out_rows: list[dict[str, str]] = [] for label, times in summary.items(): avg = int(sum(times) / max(1, len(times))) out_rows.append( { "label_set": label, "count": str(len(times)), "avg_elapsed_ms": str(avg), "p50_elapsed_ms": str(common.percentile(times, 0.50)), "p95_elapsed_ms": str(common.percentile(times, 0.95)), } ) return out_rows def run(cfg: Config) -> None: images = list(common.iter_images(cfg.images)) if not images: raise SystemExit("No images found.") if not cfg.label_sets: raise SystemExit("No label sets found.") rows: list[dict[str, str]] = [] with httpx.Client(base_url=cfg.api, timeout=30) as client: for label_set in cfg.label_sets: label_set_hash = common.upload_label_set(client, label_set) for image in images: data = common.classify_one( client, label_set_hash, image_b64=common.encode_image_b64(image), domain_top_n=cfg.domain_top_n, top_k=cfg.top_k, ) print(json.dumps({"label_set": label_set.name, "image": str(image), "result": data})) rows.append( to_row( label_set, image, data, select_domain_n=cfg.select_domain_n, select_label_n=cfg.select_label_n, min_domain_score=cfg.min_domain_score, min_label_score=cfg.min_label_score, ) ) fieldnames = [ "label_set", "image", "label_set_hash", "model_id", "chosen_domains", "selected_domains", "selected_labels", "domain_hits", "label_hits", "elapsed_ms", "elapsed_domain_ms", "elapsed_labels_ms", ] out_dir = common.resolve_out_dir(cfg.api, cfg.out_dir) out_path = out_dir / f"eval_matrix_{common.timestamp()}.csv" common.write_csv(out_path, rows, fieldnames) if cfg.summary: summary_rows = summarize_by_label_set(rows) summary_path = out_dir / f"eval_matrix_summary_{common.timestamp()}.csv" common.write_csv(summary_path, summary_rows, ["label_set", "count", "avg_elapsed_ms", "p50_elapsed_ms", "p95_elapsed_ms"]) @click.command() @click.option("--api", default="http://localhost:7860", show_default=True) @click.option("--label-sets", "label_sets_raw", multiple=True, required=True) @click.option("--images", multiple=True, required=True, type=click.Path(path_type=Path)) @click.option("--domain-top-n", default=2, show_default=True, type=int) @click.option("--top-k", default=5, show_default=True, type=int) @click.option("--out-dir", type=click.Path(path_type=Path)) @click.option("--summary", is_flag=True, default=False) @click.option("--select-domain-n", type=int, default=None) @click.option("--select-label-n", type=int, default=None) @click.option("--min-domain-score", type=float, default=None) @click.option("--min-label-score", type=float, default=None) def cli( api: str, label_sets_raw: tuple[str, ...], images: tuple[Path, ...], domain_top_n: int, top_k: int, out_dir: Path | None, summary: bool, select_domain_n: int | None, select_label_n: int | None, min_domain_score: float | None, min_label_score: float | None, ) -> None: label_sets = expand_label_sets(label_sets_raw) cfg = Config( api=api, label_sets=label_sets, images=list(images), domain_top_n=domain_top_n, top_k=top_k, out_dir=out_dir, summary=summary, select_domain_n=select_domain_n, select_label_n=select_label_n, min_domain_score=min_domain_score, min_label_score=min_label_score, ) run(cfg) if __name__ == "__main__": cli()