photo-classification / src /eval /eval_matrix.py
esandorfi's picture
Change to API default directory
d547fdb
#!/usr/bin/env python3
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable
import click
import httpx
from eval import common
@dataclass(frozen=True)
class Config:
    """Immutable settings for one evaluation-matrix run (built by ``cli``, consumed by ``run``)."""

    # Base URL of the classification API; used as the httpx client's base_url.
    api: str
    # Label-set files; each is uploaded once and evaluated against every image.
    label_sets: list[Path]
    # Image paths handed to common.iter_images for expansion into the image list.
    images: list[Path]
    # Forwarded to common.classify_one as domain_top_n on every request.
    domain_top_n: int
    # Forwarded to common.classify_one as top_k on every request.
    top_k: int
    # Output directory passed to common.resolve_out_dir; None lets that helper
    # pick a default (presumably derived from `api` — confirm in common).
    out_dir: Path | None
    # When True, also write the per-label-set latency summary CSV.
    summary: bool
    # max_n for common.select_hits on domain hits (selected_domains column);
    # None presumably means "no limit" — verify against common.select_hits.
    select_domain_n: int | None
    # max_n for common.select_hits on label hits (selected_labels column).
    select_label_n: int | None
    # min_score for common.select_hits on domain hits.
    min_domain_score: float | None
    # min_score for common.select_hits on label hits.
    min_label_score: float | None
def expand_label_sets(paths: Iterable[str]) -> list[Path]:
    """Expand raw ``--label-sets`` arguments into a list of existing files.

    Arguments containing a glob metacharacter (``*``, ``?`` or ``[``) are
    expanded, and each argument's matches are sorted; other arguments are taken
    literally. Both relative and absolute patterns are supported. Only entries
    that exist as regular files are kept, so missing paths and directories are
    dropped silently. Results follow argument order, and duplicates are kept.

    Args:
        paths: Raw label-set arguments as given on the command line.

    Returns:
        The resolved label-set files.
    """
    out: list[Path] = []
    for raw in paths:
        path = Path(raw)
        if not any(ch in raw for ch in "*?["):
            out.append(path)
            continue
        if path.is_absolute():
            # Path.glob() raises NotImplementedError for non-relative patterns
            # such as "/data/labels/*.json", so glob the remainder from the anchor.
            base = Path(path.anchor)
            pattern = str(path.relative_to(base))
        else:
            base, pattern = Path(), raw
        out.extend(sorted(base.glob(pattern)))
    return [p for p in out if p.is_file()]
def to_row(
    label_set: Path,
    image: Path,
    data: dict,
    *,
    select_domain_n: int | None,
    select_label_n: int | None,
    min_domain_score: float | None,
    min_label_score: float | None,
) -> dict[str, str]:
    """Flatten one classification response into a CSV row of strings.

    Hit lists are rendered with ``common.fmt_hit`` and joined with ``|``. The
    ``selected_*`` columns hold whatever ``common.select_hits`` keeps under the
    given max-count / min-score limits. Keys absent from ``data`` become empty
    strings (or empty joins).

    Args:
        label_set: Label-set file the request was made with; only its name is stored.
        image: Image that was classified.
        data: Decoded API response.
        select_domain_n: ``max_n`` for selecting domain hits.
        select_label_n: ``max_n`` for selecting label hits.
        min_domain_score: ``min_score`` for selecting domain hits.
        min_label_score: ``min_score`` for selecting label hits.

    Returns:
        A mapping from CSV column name to cell text.
    """
    def joined(items: Iterable[str]) -> str:
        return "|".join(items)

    domains = data.get("domain_hits", [])
    labels = data.get("label_hits", [])
    kept_domains = common.select_hits(domains, max_n=select_domain_n, min_score=min_domain_score)
    kept_labels = common.select_hits(labels, max_n=select_label_n, min_score=min_label_score)

    row = {
        "label_set": label_set.name,
        "image": str(image),
        "label_set_hash": data.get("label_set_hash", ""),
        "model_id": data.get("model_id", ""),
        "chosen_domains": joined(data.get("chosen_domains", [])),
        "selected_domains": joined(kept_domains),
        "selected_labels": joined(kept_labels),
        "domain_hits": joined(map(common.fmt_hit, domains)),
        "label_hits": joined(map(common.fmt_hit, labels)),
    }
    # Timing fields are stringified as-is; a missing key yields "".
    for timing_key in ("elapsed_ms", "elapsed_domain_ms", "elapsed_labels_ms"):
        row[timing_key] = str(data.get(timing_key, ""))
    return row
def summarize_by_label_set(rows: list[dict[str, str]]) -> list[dict[str, str]]:
    """Aggregate per-request latency into one summary row per label set.

    A row is skipped when its ``elapsed_ms`` cell is missing, empty, or not a
    finite number. Integer and decimal strings are both accepted, and a decimal
    is truncated to whole milliseconds (``"12.7"`` counts as 12).

    Args:
        rows: Row dicts as produced by ``to_row``.

    Returns:
        One dict per label set, in first-seen order, with string values for
        ``label_set``, ``count``, ``avg_elapsed_ms`` (truncated integer mean),
        and ``p50_elapsed_ms`` / ``p95_elapsed_ms`` (from ``common.percentile``).
    """
    times_by_set: dict[str, list[int]] = {}
    for row in rows:
        label = row["label_set"]
        try:
            # Parse through float() so decimal timings such as "12.7" are kept;
            # int("12.7") raises ValueError and would silently drop the row.
            elapsed = int(float(row.get("elapsed_ms", "")))
        except (TypeError, ValueError, OverflowError):
            # Missing, empty, non-numeric, NaN or infinite timing: nothing to aggregate.
            continue
        times_by_set.setdefault(label, []).append(elapsed)

    # Each list holds at least one value, because lists are only created on append.
    return [
        {
            "label_set": label,
            "count": str(len(times)),
            "avg_elapsed_ms": str(int(sum(times) / len(times))),
            "p50_elapsed_ms": str(common.percentile(times, 0.50)),
            "p95_elapsed_ms": str(common.percentile(times, 0.95)),
        }
        for label, times in times_by_set.items()
    ]
def run(cfg: Config) -> None:
    """Classify every image against every label set and write CSV reports.

    Each label set is uploaded once with ``common.upload_label_set``, and every
    image is classified against the returned hash. Each raw result is echoed to
    stdout as one JSON line. All rows are then written to
    ``eval_matrix_<timestamp>.csv`` in the resolved output directory. When
    ``cfg.summary`` is set, a per-label-set latency summary is written as
    ``eval_matrix_summary_<timestamp>.csv`` with the same timestamp.

    Args:
        cfg: Run configuration.

    Raises:
        SystemExit: If no images or no label sets are available.
    """
    images = list(common.iter_images(cfg.images))
    if not images:
        raise SystemExit("No images found.")
    if not cfg.label_sets:
        raise SystemExit("No label sets found.")

    rows: list[dict[str, str]] = []
    with httpx.Client(base_url=cfg.api, timeout=30) as client:
        for label_set in cfg.label_sets:
            label_set_hash = common.upload_label_set(client, label_set)
            for image in images:
                # NOTE(review): each image is re-encoded once per label set; caching the
                # encodings would trade memory for CPU on large matrices.
                data = common.classify_one(
                    client,
                    label_set_hash,
                    image_b64=common.encode_image_b64(image),
                    domain_top_n=cfg.domain_top_n,
                    top_k=cfg.top_k,
                )
                print(json.dumps({"label_set": label_set.name, "image": str(image), "result": data}))
                rows.append(
                    to_row(
                        label_set,
                        image,
                        data,
                        select_domain_n=cfg.select_domain_n,
                        select_label_n=cfg.select_label_n,
                        min_domain_score=cfg.min_domain_score,
                        min_label_score=cfg.min_label_score,
                    )
                )

    # Column order of the detail CSV; these must match the keys produced by to_row.
    fieldnames = [
        "label_set",
        "image",
        "label_set_hash",
        "model_id",
        "chosen_domains",
        "selected_domains",
        "selected_labels",
        "domain_hits",
        "label_hits",
        "elapsed_ms",
        "elapsed_domain_ms",
        "elapsed_labels_ms",
    ]
    out_dir = common.resolve_out_dir(cfg.api, cfg.out_dir)
    # Take the timestamp once so both files of a run share the same suffix.
    # Calling timestamp() separately for each file could yield mismatched names.
    stamp = common.timestamp()
    common.write_csv(out_dir / f"eval_matrix_{stamp}.csv", rows, fieldnames)
    if cfg.summary:
        common.write_csv(
            out_dir / f"eval_matrix_summary_{stamp}.csv",
            summarize_by_label_set(rows),
            ["label_set", "count", "avg_elapsed_ms", "p50_elapsed_ms", "p95_elapsed_ms"],
        )
@click.command()
@click.option("--api", default="http://localhost:7860", show_default=True,
              help="Base URL of the classification API.")
@click.option("--label-sets", "label_sets_raw", multiple=True, required=True,
              help="Label-set file or glob pattern (repeatable).")
@click.option("--images", multiple=True, required=True, type=click.Path(path_type=Path),
              help="Image path to evaluate (repeatable).")
@click.option("--domain-top-n", default=2, show_default=True, type=int,
              help="domain_top_n value sent with each classification request.")
@click.option("--top-k", default=5, show_default=True, type=int,
              help="top_k value sent with each classification request.")
@click.option("--out-dir", type=click.Path(path_type=Path),
              help="Directory for the CSV outputs; if omitted, a default based on --api is used.")
@click.option("--summary", is_flag=True, default=False,
              help="Also write a per-label-set latency summary CSV.")
@click.option("--select-domain-n", type=int, default=None,
              help="Keep at most N domain hits in the selected_domains column.")
@click.option("--select-label-n", type=int, default=None,
              help="Keep at most N label hits in the selected_labels column.")
@click.option("--min-domain-score", type=float, default=None,
              help="Minimum score for a domain hit to appear in selected_domains.")
@click.option("--min-label-score", type=float, default=None,
              help="Minimum score for a label hit to appear in selected_labels.")
def cli(
    api: str,
    label_sets_raw: tuple[str, ...],
    images: tuple[Path, ...],
    domain_top_n: int,
    top_k: int,
    out_dir: Path | None,
    summary: bool,
    select_domain_n: int | None,
    select_label_n: int | None,
    min_domain_score: float | None,
    min_label_score: float | None,
) -> None:
    """Classify every image against every label set and write CSV reports."""
    # Glob patterns are expanded here. Entries that are not existing files are
    # dropped, and run() exits with "No label sets found." if none remain.
    run(
        Config(
            api=api,
            label_sets=expand_label_sets(label_sets_raw),
            images=list(images),
            domain_top_n=domain_top_n,
            top_k=top_k,
            out_dir=out_dir,
            summary=summary,
            select_domain_n=select_domain_n,
            select_label_n=select_label_n,
            min_domain_score=min_domain_score,
            min_label_score=min_label_score,
        )
    )
if __name__ == "__main__":
    # Run the click command when this file is executed as a script.
    cli()