Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| from __future__ import annotations | |
| import json | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Iterable | |
| import click | |
| import httpx | |
| from eval import common | |
@dataclass
class Config:
    """Settings for one evaluation-matrix run.

    Instances are built with keyword arguments (see ``cli``), so the class
    must be a dataclass: without the decorator the annotations below do not
    generate an ``__init__`` and ``Config(api=...)`` raises ``TypeError``.
    """

    # Base URL of the service; used as the httpx client's ``base_url`` and
    # passed to ``common.resolve_out_dir``.
    api: str
    # Label-set files; each one is uploaded and evaluated against every image.
    label_sets: list[Path]
    # Image inputs handed to ``common.iter_images``.
    images: list[Path]
    # Forwarded unchanged to ``common.classify_one``.
    domain_top_n: int
    top_k: int
    # Requested output directory; ``None`` leaves the choice to
    # ``common.resolve_out_dir``.
    out_dir: Path | None
    # When true, a per-label-set timing summary CSV is written as well.
    summary: bool
    # Selection limits forwarded to ``common.select_hits`` as ``max_n``.
    select_domain_n: int | None
    select_label_n: int | None
    # Score thresholds forwarded to ``common.select_hits`` as ``min_score``.
    min_domain_score: float | None
    min_label_score: float | None
def expand_label_sets(paths: Iterable[str]) -> list[Path]:
    """Expand label-set arguments into a list of existing files.

    Each entry is either a literal path or a glob pattern (recognised by the
    presence of ``*``, ``?`` or ``[``). Patterns are expanded and their
    matches added in sorted order; literal paths are added as given. Entries
    keep the order in which they were supplied.

    Args:
        paths: Raw path or pattern strings, relative or absolute.

    Returns:
        The expanded paths, restricted to those that are regular files.
        Missing paths and directories are dropped silently.
    """
    out: list[Path] = []
    for raw in paths:
        p = Path(raw)
        if not any(ch in raw for ch in "*?["):
            out.append(p)
            continue
        if p.is_absolute():
            # Path.glob() refuses non-relative patterns, so glob from the
            # pattern's anchor (filesystem root / drive) with the remainder.
            anchor = Path(p.anchor)
            out.extend(sorted(anchor.glob(str(p.relative_to(anchor)))))
        else:
            out.extend(sorted(Path().glob(raw)))
    return [p for p in out if p.is_file()]
def to_row(
    label_set: Path,
    image: Path,
    data: dict,
    *,
    select_domain_n: int | None,
    select_label_n: int | None,
    min_domain_score: float | None,
    min_label_score: float | None,
) -> dict[str, str]:
    """Flatten one classification response into a CSV row.

    List-valued fields are joined with ``|``. Keys missing from ``data``
    become empty strings (or empty lists before joining).

    Args:
        label_set: Label-set file the response belongs to; only its file
            name is recorded.
        image: Image that was classified; recorded via ``str(image)``.
        data: Response dictionary for this label set and image.
        select_domain_n: ``max_n`` passed to ``common.select_hits`` for the
            domain hits.
        select_label_n: ``max_n`` passed to ``common.select_hits`` for the
            label hits.
        min_domain_score: ``min_score`` passed for the domain hits.
        min_label_score: ``min_score`` passed for the label hits.

    Returns:
        A mapping from column name to cell value.
    """
    domains = data.get("domain_hits", [])
    labels = data.get("label_hits", [])
    picked_domains = common.select_hits(
        domains, max_n=select_domain_n, min_score=min_domain_score
    )
    picked_labels = common.select_hits(
        labels, max_n=select_label_n, min_score=min_label_score
    )

    row: dict[str, str] = {"label_set": label_set.name, "image": str(image)}
    # Identification fields are copied through as-is.
    for key in ("label_set_hash", "model_id"):
        row[key] = data.get(key, "")
    row["chosen_domains"] = "|".join(data.get("chosen_domains", []))
    row["selected_domains"] = "|".join(picked_domains)
    row["selected_labels"] = "|".join(picked_labels)
    row["domain_hits"] = "|".join(map(common.fmt_hit, domains))
    row["label_hits"] = "|".join(map(common.fmt_hit, labels))
    # Timings are stringified whatever type the response used.
    for key in ("elapsed_ms", "elapsed_domain_ms", "elapsed_labels_ms"):
        row[key] = str(data.get(key, ""))
    return row
def summarize_by_label_set(rows: list[dict[str, str]]) -> list[dict[str, str]]:
    """Aggregate per-request latency by label set.

    Rows are grouped by their ``label_set`` value. A row whose
    ``elapsed_ms`` is missing or is not an integer string (empty, or
    float-formatted such as ``"12.5"``) is left out of the statistics.

    Args:
        rows: Result rows; each must have a ``label_set`` key.

    Returns:
        One row per label set that had at least one usable timing, holding
        the sample count, the truncated mean, and the 50th and 95th
        percentiles as computed by ``common.percentile``. All values are
        strings.

    Raises:
        KeyError: If a row has no ``label_set`` key.
    """
    timings: dict[str, list[int]] = {}
    for row in rows:
        label = row["label_set"]
        try:
            elapsed = int(row["elapsed_ms"])
        except (KeyError, TypeError, ValueError):
            # No usable timing for this row; skip it rather than fail the
            # whole summary.
            continue
        timings.setdefault(label, []).append(elapsed)

    # Every list holds at least one sample, so the mean needs no zero guard.
    return [
        {
            "label_set": label,
            "count": str(len(times)),
            "avg_elapsed_ms": str(int(sum(times) / len(times))),
            "p50_elapsed_ms": str(common.percentile(times, 0.50)),
            "p95_elapsed_ms": str(common.percentile(times, 0.95)),
        }
        for label, times in timings.items()
    ]
def run(cfg: Config) -> None:
    """Classify every image against every label set and write the results.

    Each label set is uploaded once via ``common.upload_label_set``; every
    image is then classified against it with ``common.classify_one``. Each
    raw response is printed to stdout as one JSON line and also flattened
    into a row of the matrix CSV. When ``cfg.summary`` is set, a second CSV
    with per-label-set timing statistics is written next to it.

    Args:
        cfg: Run settings.

    Raises:
        SystemExit: If no images or no label sets are available.
    """
    images = list(common.iter_images(cfg.images))
    if not images:
        raise SystemExit("No images found.")
    if not cfg.label_sets:
        raise SystemExit("No label sets found.")

    rows: list[dict[str, str]] = []
    with httpx.Client(base_url=cfg.api, timeout=30) as client:
        for label_set in cfg.label_sets:
            label_set_hash = common.upload_label_set(client, label_set)
            for image in images:
                data = common.classify_one(
                    client,
                    label_set_hash,
                    image_b64=common.encode_image_b64(image),
                    domain_top_n=cfg.domain_top_n,
                    top_k=cfg.top_k,
                )
                print(json.dumps({"label_set": label_set.name, "image": str(image), "result": data}))
                rows.append(
                    to_row(
                        label_set,
                        image,
                        data,
                        select_domain_n=cfg.select_domain_n,
                        select_label_n=cfg.select_label_n,
                        min_domain_score=cfg.min_domain_score,
                        min_label_score=cfg.min_label_score,
                    )
                )

    fieldnames = [
        "label_set",
        "image",
        "label_set_hash",
        "model_id",
        "chosen_domains",
        "selected_domains",
        "selected_labels",
        "domain_hits",
        "label_hits",
        "elapsed_ms",
        "elapsed_domain_ms",
        "elapsed_labels_ms",
    ]
    out_dir = common.resolve_out_dir(cfg.api, cfg.out_dir)
    # Take the timestamp once so the matrix and its summary always carry the
    # same stamp and can be paired up by file name.
    stamp = common.timestamp()
    common.write_csv(out_dir / f"eval_matrix_{stamp}.csv", rows, fieldnames)
    if cfg.summary:
        summary_rows = summarize_by_label_set(rows)
        summary_path = out_dir / f"eval_matrix_summary_{stamp}.csv"
        common.write_csv(
            summary_path,
            summary_rows,
            ["label_set", "count", "avg_elapsed_ms", "p50_elapsed_ms", "p95_elapsed_ms"],
        )
def cli(
    api: str,
    label_sets_raw: tuple[str, ...],
    images: tuple[Path, ...],
    domain_top_n: int,
    top_k: int,
    out_dir: Path | None,
    summary: bool,
    select_domain_n: int | None,
    select_label_n: int | None,
    min_domain_score: float | None,
    min_label_score: float | None,
) -> None:
    """Command-line entry point: build a ``Config`` and hand it to ``run``.

    ``label_sets_raw`` may contain glob patterns; they are expanded with
    ``expand_label_sets`` before the configuration is built. All other
    arguments are stored on the ``Config`` unchanged (``images`` is
    converted from a tuple to a list).
    """
    # NOTE(review): the script calls cli() with no arguments, so this
    # function is presumably wrapped by click command/option decorators that
    # are not visible here -- confirm against the original file.
    run(
        Config(
            api=api,
            label_sets=expand_label_sets(label_sets_raw),
            images=list(images),
            domain_top_n=domain_top_n,
            top_k=top_k,
            out_dir=out_dir,
            summary=summary,
            select_domain_n=select_domain_n,
            select_label_n=select_label_n,
            min_domain_score=min_domain_score,
            min_label_score=min_label_score,
        )
    )
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    cli()