#!/usr/bin/env python3
"""Classify images against an uploaded label set and record the results.

For every image yielded by ``common.iter_images`` the script calls the
classify API, prints the raw response as one JSON line on stdout, and collects
a flattened row. The rows are written to CSV, optionally followed by a summary
CSV produced by ``common.summarize_latency``.
"""

from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path

import click
import httpx

from eval import common

# Column order of the per-image CSV. ``to_row`` emits exactly these keys, so the
# header and the row contents are defined in one place and cannot drift apart.
FIELDNAMES: tuple[str, ...] = (
    "image",
    "label_set_hash",
    "model_id",
    "chosen_domains",
    "selected_domains",
    "selected_labels",
    "domain_hits",
    "label_hits",
    "elapsed_ms",
    "elapsed_domain_ms",
    "elapsed_labels_ms",
)

# Default per-request HTTP timeout in seconds (the value previously hard-coded).
DEFAULT_TIMEOUT_S = 30.0


@dataclass(frozen=True)
class Config:
    """Resolved command-line options for one evaluation run."""

    api: str  # base URL of the classification service
    label_set: Path  # label-set file uploaded via common.upload_label_set
    images: list[Path]  # paths handed to common.iter_images
    domain_top_n: int  # forwarded to common.classify_one
    top_k: int  # forwarded to common.classify_one
    activate: bool  # activate the uploaded label set before classifying
    limit: int  # process at most this many images; <= 0 means no limit
    out_dir: Path | None  # passed to common.resolve_out_dir
    csv_path: Path | None  # explicit rows-CSV path; replaces the auto-named file
    summary: bool  # also write the summary CSV
    select_domain_n: int | None  # max_n for common.select_hits on domain hits
    select_label_n: int | None  # max_n for common.select_hits on label hits
    min_domain_score: float | None  # min_score for common.select_hits on domain hits
    min_label_score: float | None  # min_score for common.select_hits on label hits
    timeout: float = DEFAULT_TIMEOUT_S  # per-request HTTP timeout in seconds


def activate_label_set(client: httpx.Client, label_set_hash: str) -> None:
    """Activate a previously uploaded label set on the server.

    Raises:
        httpx.HTTPStatusError: if the server answers with an error status.
    """
    r = client.post(f"/api/v1/label-sets/{label_set_hash}/activate")
    r.raise_for_status()


def _text(value: object) -> str:
    """Render an optional scalar as CSV text; ``None`` becomes ``""``, not ``"None"``."""
    return "" if value is None else str(value)


def to_row(
    image: Path,
    data: dict,
    *,
    select_domain_n: int | None,
    select_label_n: int | None,
    min_domain_score: float | None,
    min_label_score: float | None,
) -> dict[str, str]:
    """Flatten one classify response into a CSV row keyed by ``FIELDNAMES``.

    List-valued fields are joined with ``|``. The ``selected_*`` columns hold
    the output of ``common.select_hits`` applied to the raw hit lists with the
    given ``max_n`` / ``min_score`` arguments.

    Keys that are missing from *data* or present but ``null`` produce empty
    strings, and missing/``null`` hit lists are treated as empty.
    """
    # ``or []`` also covers keys that are present but null in the JSON payload.
    domain_hits = data.get("domain_hits") or []
    label_hits = data.get("label_hits") or []
    selected_domains = common.select_hits(
        domain_hits, max_n=select_domain_n, min_score=min_domain_score
    )
    selected_labels = common.select_hits(
        label_hits, max_n=select_label_n, min_score=min_label_score
    )
    return {
        "image": str(image),
        "label_set_hash": _text(data.get("label_set_hash")),
        "model_id": _text(data.get("model_id")),
        "chosen_domains": "|".join(data.get("chosen_domains") or []),
        "selected_domains": "|".join(selected_domains),
        "selected_labels": "|".join(selected_labels),
        "domain_hits": "|".join(common.fmt_hit(hit) for hit in domain_hits),
        "label_hits": "|".join(common.fmt_hit(hit) for hit in label_hits),
        "elapsed_ms": _text(data.get("elapsed_ms")),
        "elapsed_domain_ms": _text(data.get("elapsed_domain_ms")),
        "elapsed_labels_ms": _text(data.get("elapsed_labels_ms")),
    }


def _classify_all(
    client: httpx.Client,
    cfg: Config,
    label_set_hash: str,
    images: list[Path],
) -> list[dict[str, str]]:
    """Classify each image, echo the raw response as JSON, and return CSV rows."""
    rows: list[dict[str, str]] = []
    for image in images:
        data = common.classify_one(
            client,
            label_set_hash,
            image_b64=common.encode_image_b64(image),
            domain_top_n=cfg.domain_top_n,
            top_k=cfg.top_k,
        )
        # One JSON object per line so stdout can be consumed as JSON Lines.
        print(json.dumps({"image": str(image), "result": data}, ensure_ascii=True))
        rows.append(
            to_row(
                image,
                data,
                select_domain_n=cfg.select_domain_n,
                select_label_n=cfg.select_label_n,
                min_domain_score=cfg.min_domain_score,
                min_label_score=cfg.min_label_score,
            )
        )
    return rows


def _write_outputs(cfg: Config, rows: list[dict[str, str]]) -> None:
    """Write the per-image CSV and, when requested, the summary CSV."""
    out_dir = common.resolve_out_dir(cfg.api, cfg.out_dir)
    # A single timestamp per run so the rows file and its summary share the
    # same suffix and can be paired up afterwards.
    stamp = common.timestamp()
    if rows:
        rows_path = (
            cfg.csv_path
            if cfg.csv_path is not None
            else out_dir / f"{cfg.label_set.stem}_{stamp}.csv"
        )
        common.write_csv(rows_path, rows, list(FIELDNAMES))
    if cfg.summary:
        summary = common.summarize_latency(rows)
        summary_path = out_dir / f"{cfg.label_set.stem}_summary_{stamp}.csv"
        common.write_csv(summary_path, [summary], list(summary.keys()))


def run(cfg: Config) -> int:
    """Run the evaluation described by *cfg* and return the process exit code.

    Uploads ``cfg.label_set``, optionally activates it, classifies every image
    (honouring ``cfg.limit``), then writes the CSV outputs.

    Raises:
        SystemExit: if no images are found.
        httpx.HTTPStatusError: if activating the label set fails.
    """
    images = list(common.iter_images(cfg.images))
    if cfg.limit > 0:
        images = images[: cfg.limit]
    if not images:
        raise SystemExit("No images found.")

    with httpx.Client(base_url=cfg.api, timeout=cfg.timeout) as client:
        label_set_hash = common.upload_label_set(client, cfg.label_set)
        if cfg.activate:
            activate_label_set(client, label_set_hash)
        rows = _classify_all(client, cfg, label_set_hash, images)

    _write_outputs(cfg, rows)
    return 0


@click.command()
@click.option("--api", default="http://localhost:7860", show_default=True)
@click.option("--label-set", "label_set", required=True, type=click.Path(path_type=Path))
@click.option("--images", multiple=True, required=True, type=click.Path(path_type=Path))
@click.option("--domain-top-n", default=2, show_default=True, type=int)
@click.option("--top-k", default=5, show_default=True, type=int)
@click.option(
    "--activate",
    is_flag=True,
    default=False,
    help="Activate the uploaded label set before classifying.",
)
@click.option(
    "--limit",
    default=0,
    show_default=True,
    type=int,
    help="Process at most this many images (0 = no limit).",
)
@click.option("--out-dir", type=click.Path(path_type=Path))
@click.option(
    "--csv",
    "csv_path",
    type=click.Path(path_type=Path),
    help="Write per-image rows here instead of an auto-named file in the output dir.",
)
@click.option("--summary", is_flag=True, default=False, help="Also write a summary CSV.")
@click.option("--select-domain-n", type=int, default=None)
@click.option("--select-label-n", type=int, default=None)
@click.option("--min-domain-score", type=float, default=None)
@click.option("--min-label-score", type=float, default=None)
@click.option(
    "--timeout",
    default=DEFAULT_TIMEOUT_S,
    show_default=True,
    type=float,
    help="Per-request HTTP timeout in seconds.",
)
def cli(
    api: str,
    label_set: Path,
    images: tuple[Path, ...],
    domain_top_n: int,
    top_k: int,
    activate: bool,
    limit: int,
    out_dir: Path | None,
    csv_path: Path | None,
    summary: bool,
    select_domain_n: int | None,
    select_label_n: int | None,
    min_domain_score: float | None,
    min_label_score: float | None,
    timeout: float = DEFAULT_TIMEOUT_S,
) -> None:
    """Classify images against a label set and write the results to CSV."""
    cfg = Config(
        api=api,
        label_set=label_set,
        images=list(images),
        domain_top_n=domain_top_n,
        top_k=top_k,
        activate=activate,
        limit=limit,
        out_dir=out_dir,
        csv_path=csv_path,
        summary=summary,
        select_domain_n=select_domain_n,
        select_label_n=select_label_n,
        min_domain_score=min_domain_score,
        min_label_score=min_label_score,
        timeout=timeout,
    )
    raise SystemExit(run(cfg))


if __name__ == "__main__":
    cli()