Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| from __future__ import annotations | |
| import json | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| import click | |
| import httpx | |
| from eval import common | |
@dataclass
class Config:
    """Settings for one classification run.

    ``cli`` builds an instance with keyword arguments and ``run`` reads it, so
    the class must be a dataclass: without the decorator the annotated fields
    produce no ``__init__`` and keyword construction raises ``TypeError``.
    """

    api: str  # used as the httpx client's base_url
    label_set: Path  # label-set file to upload; its stem names the output CSVs
    images: list[Path]  # image inputs, expanded by common.iter_images
    domain_top_n: int  # forwarded to common.classify_one
    top_k: int  # forwarded to common.classify_one
    activate: bool  # when true, activate the label set after uploading it
    limit: int  # values > 0 cap how many images are processed
    out_dir: Path | None  # passed to common.resolve_out_dir
    csv_path: Path | None  # explicit CSV destination; otherwise a timestamped file is used
    summary: bool  # when true, also write a latency summary CSV
    select_domain_n: int | None  # max_n for selecting domain hits
    select_label_n: int | None  # max_n for selecting label hits
    min_domain_score: float | None  # min_score for selecting domain hits
    min_label_score: float | None  # min_score for selecting label hits
def activate_label_set(client: httpx.Client, label_set_hash: str) -> None:
    """POST to the label set's ``activate`` endpoint.

    Calls ``raise_for_status()`` on the response, so an HTTP error status
    surfaces as an exception instead of being ignored.
    """
    endpoint = f"/api/v1/label-sets/{label_set_hash}/activate"
    response = client.post(endpoint)
    response.raise_for_status()
def to_row(
    image: Path,
    data: dict,
    *,
    select_domain_n: int | None,
    select_label_n: int | None,
    min_domain_score: float | None,
    min_label_score: float | None,
) -> dict[str, str]:
    """Flatten one classification response into a flat CSV row.

    Multi-valued fields are joined with ``|``. Keys missing from ``data``
    fall back to an empty list or an empty string.

    Args:
        image: Path of the classified image; stored as a string.
        data: Response mapping for that image.
        select_domain_n: ``max_n`` passed to ``common.select_hits`` for domains.
        select_label_n: ``max_n`` passed to ``common.select_hits`` for labels.
        min_domain_score: ``min_score`` passed for domain hits.
        min_label_score: ``min_score`` passed for label hits.

    Returns:
        A dict with one entry per CSV column.
    """
    domains = data.get("domain_hits", [])
    labels = data.get("label_hits", [])

    picked_domains = common.select_hits(
        domains, max_n=select_domain_n, min_score=min_domain_score
    )
    picked_labels = common.select_hits(
        labels, max_n=select_label_n, min_score=min_label_score
    )

    # The three timing fields share the same "stringify or empty" treatment.
    timings = {
        key: str(data.get(key, ""))
        for key in ("elapsed_ms", "elapsed_domain_ms", "elapsed_labels_ms")
    }

    return {
        "image": str(image),
        "label_set_hash": data.get("label_set_hash", ""),
        "model_id": data.get("model_id", ""),
        "chosen_domains": "|".join(data.get("chosen_domains", [])),
        "selected_domains": "|".join(picked_domains),
        "selected_labels": "|".join(picked_labels),
        "domain_hits": "|".join(map(common.fmt_hit, domains)),
        "label_hits": "|".join(map(common.fmt_hit, labels)),
        **timings,
    }
def run(cfg: Config) -> int:
    """Classify every image through the API and write the results as CSV.

    Steps: collect the images (capped by ``cfg.limit`` when it is positive),
    upload the label set, optionally activate it, classify each image while
    printing one JSON line per result, then write the rows to a CSV file and,
    when ``cfg.summary`` is set, a latency summary CSV.

    Args:
        cfg: Run settings.

    Returns:
        Always ``0``.

    Raises:
        SystemExit: If no images are found.
    """
    columns = [
        "image",
        "label_set_hash",
        "model_id",
        "chosen_domains",
        "selected_domains",
        "selected_labels",
        "domain_hits",
        "label_hits",
        "elapsed_ms",
        "elapsed_domain_ms",
        "elapsed_labels_ms",
    ]
    # Selection settings are the same for every image, so bundle them once.
    selection = {
        "select_domain_n": cfg.select_domain_n,
        "select_label_n": cfg.select_label_n,
        "min_domain_score": cfg.min_domain_score,
        "min_label_score": cfg.min_label_score,
    }

    image_paths = list(common.iter_images(cfg.images))
    if cfg.limit > 0:
        image_paths = image_paths[: cfg.limit]
    if not image_paths:
        raise SystemExit("No images found.")

    rows: list[dict[str, str]] = []
    with httpx.Client(base_url=cfg.api, timeout=30) as client:
        label_set_hash = common.upload_label_set(client, cfg.label_set)
        if cfg.activate:
            activate_label_set(client, label_set_hash)

        for path in image_paths:
            result = common.classify_one(
                client,
                label_set_hash,
                image_b64=common.encode_image_b64(path),
                domain_top_n=cfg.domain_top_n,
                top_k=cfg.top_k,
            )
            # One JSON line per image on stdout, in addition to the CSV row.
            print(json.dumps({"image": str(path), "result": result}, ensure_ascii=True))
            rows.append(to_row(path, result, **selection))

    out_dir = common.resolve_out_dir(cfg.api, cfg.out_dir)

    if rows:
        # An explicit --csv path wins; otherwise use a timestamped file in out_dir.
        if cfg.csv_path:
            csv_target = cfg.csv_path
        else:
            csv_target = out_dir / f"{cfg.label_set.stem}_{common.timestamp()}.csv"
        common.write_csv(csv_target, rows, columns)

    if cfg.summary:
        latency = common.summarize_latency(rows)
        summary_target = out_dir / f"{cfg.label_set.stem}_summary_{common.timestamp()}.csv"
        common.write_csv(summary_target, [latency], list(latency.keys()))

    return 0
def cli(
    api: str,
    label_set: Path,
    images: tuple[Path, ...],
    domain_top_n: int,
    top_k: int,
    activate: bool,
    limit: int,
    out_dir: Path | None,
    csv_path: Path | None,
    summary: bool,
    select_domain_n: int | None,
    select_label_n: int | None,
    min_domain_score: float | None,
    min_label_score: float | None,
) -> None:
    """Bundle the command-line values into a ``Config`` and run the evaluation.

    Never returns normally: it always raises ``SystemExit`` carrying the
    status code returned by ``run``.

    NOTE(review): the script entry point calls ``cli()`` with no arguments,
    so this function is presumably wrapped by click command/option decorators
    that are not visible here - confirm they are present.
    """
    config = Config(
        api=api,
        label_set=label_set,
        images=list(images),  # Config stores a list; the caller supplies a tuple
        domain_top_n=domain_top_n,
        top_k=top_k,
        activate=activate,
        limit=limit,
        out_dir=out_dir,
        csv_path=csv_path,
        summary=summary,
        select_domain_n=select_domain_n,
        select_label_n=select_label_n,
        min_domain_score=min_domain_score,
        min_label_score=min_label_score,
    )
    exit_code = run(config)
    raise SystemExit(exit_code)
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    cli()