photo-classification / src /eval /classify_dataset.py
esandorfi's picture
Change to API default directory
d547fdb
#!/usr/bin/env python3
from __future__ import annotations

import json
from dataclasses import dataclass
from itertools import islice
from pathlib import Path

import click
import httpx

from eval import common
@dataclass(frozen=True)
class Config:
    """Immutable settings for one dataset classification run.

    Instances are built by ``cli`` from the command-line options and consumed
    by ``run``.
    """

    # Base URL of the classification API; used as the HTTP client's base_url
    # and passed to common.resolve_out_dir.
    api: str
    # Label-set path handed to common.upload_label_set; its stem prefixes the
    # generated output file names.
    label_set: Path
    # Image paths handed to common.iter_images to enumerate the images to classify.
    images: list[Path]
    # Forwarded unchanged to common.classify_one as domain_top_n.
    domain_top_n: int
    # Forwarded unchanged to common.classify_one as top_k.
    top_k: int
    # When true, the uploaded label set is activated before classifying.
    activate: bool
    # Maximum number of images to classify; values <= 0 mean no cap.
    limit: int
    # Output directory override passed to common.resolve_out_dir (may be None).
    out_dir: Path | None
    # Explicit results CSV path; when None a timestamped file is written in
    # the resolved output directory instead.
    csv_path: Path | None
    # When true, a latency summary CSV is written in addition to the results.
    summary: bool
    # Forwarded to common.select_hits as max_n for the domain hits.
    select_domain_n: int | None
    # Forwarded to common.select_hits as max_n for the label hits.
    select_label_n: int | None
    # Forwarded to common.select_hits as min_score for the domain hits.
    min_domain_score: float | None
    # Forwarded to common.select_hits as min_score for the label hits.
    min_label_score: float | None
def activate_label_set(client: httpx.Client, label_set_hash: str) -> None:
    """Ask the API to activate the label set identified by ``label_set_hash``.

    Sends a POST to the label set's ``activate`` endpoint (relative to the
    client's base URL) and lets ``raise_for_status`` raise if the response
    carries an error status.
    """
    endpoint = f"/api/v1/label-sets/{label_set_hash}/activate"
    client.post(endpoint).raise_for_status()
def to_row(
    image: Path,
    data: dict,
    *,
    select_domain_n: int | None,
    select_label_n: int | None,
    min_domain_score: float | None,
    min_label_score: float | None,
) -> dict[str, str]:
    """Flatten one classification response into a CSV-ready row.

    Args:
        image: Path of the classified image; stored as its string form.
        data: Response dictionary for that image. Missing keys fall back to
            empty lists / empty strings.
        select_domain_n: ``max_n`` passed to ``common.select_hits`` for the
            domain hits.
        select_label_n: ``max_n`` passed to ``common.select_hits`` for the
            label hits.
        min_domain_score: ``min_score`` passed to ``common.select_hits`` for
            the domain hits.
        min_label_score: ``min_score`` passed to ``common.select_hits`` for
            the label hits.

    Returns:
        A dictionary keyed by CSV column name. Multi-valued columns are joined
        with ``|``.
    """
    domains = data.get("domain_hits", [])
    labels = data.get("label_hits", [])
    picked_domains = common.select_hits(domains, max_n=select_domain_n, min_score=min_domain_score)
    picked_labels = common.select_hits(labels, max_n=select_label_n, min_score=min_label_score)

    # Keys are inserted in the same order as the CSV columns.
    row: dict[str, str] = {"image": str(image)}
    for key in ("label_set_hash", "model_id"):
        row[key] = data.get(key, "")
    row["chosen_domains"] = "|".join(data.get("chosen_domains", []))
    row["selected_domains"] = "|".join(picked_domains)
    row["selected_labels"] = "|".join(picked_labels)
    row["domain_hits"] = "|".join(map(common.fmt_hit, domains))
    row["label_hits"] = "|".join(map(common.fmt_hit, labels))
    # Timings are stringified so an absent value becomes an empty cell.
    for key in ("elapsed_ms", "elapsed_domain_ms", "elapsed_labels_ms"):
        row[key] = str(data.get(key, ""))
    return row
def run(cfg: Config) -> int:
    """Classify every configured image and write the results as CSV.

    The function:

    1. collects the images from ``common.iter_images``, stopping early once
       ``cfg.limit`` images have been gathered (when the limit is positive);
    2. uploads the label set and, if ``cfg.activate`` is set, activates it;
    3. classifies each image, echoes the raw response to stdout as one JSON
       line, and converts it to a CSV row with ``to_row``;
    4. writes the rows to ``cfg.csv_path`` or, when that is ``None``, to a
       timestamped file in the resolved output directory;
    5. optionally writes a latency summary CSV next to it.

    Args:
        cfg: Settings for this run.

    Returns:
        ``0`` once all output has been written.

    Raises:
        SystemExit: If no images were found.
    """
    candidates = common.iter_images(cfg.images)
    if cfg.limit > 0:
        # Stop consuming the iterator at the cap instead of enumerating the
        # whole dataset and slicing afterwards.
        images = list(islice(candidates, cfg.limit))
    else:
        images = list(candidates)
    if not images:
        raise SystemExit("No images found.")

    rows: list[dict[str, str]] = []
    with httpx.Client(base_url=cfg.api, timeout=30) as client:
        label_set_hash = common.upload_label_set(client, cfg.label_set)
        if cfg.activate:
            activate_label_set(client, label_set_hash)
        for image in images:
            data = common.classify_one(
                client,
                label_set_hash,
                image_b64=common.encode_image_b64(image),
                domain_top_n=cfg.domain_top_n,
                top_k=cfg.top_k,
            )
            # Raw per-image result, one JSON document per line on stdout.
            print(json.dumps({"image": str(image), "result": data}, ensure_ascii=True))
            rows.append(
                to_row(
                    image,
                    data,
                    select_domain_n=cfg.select_domain_n,
                    select_label_n=cfg.select_label_n,
                    min_domain_score=cfg.min_domain_score,
                    min_label_score=cfg.min_label_score,
                )
            )

    # Column order of the results CSV; matches the keys produced by to_row.
    fieldnames = [
        "image",
        "label_set_hash",
        "model_id",
        "chosen_domains",
        "selected_domains",
        "selected_labels",
        "domain_hits",
        "label_hits",
        "elapsed_ms",
        "elapsed_domain_ms",
        "elapsed_labels_ms",
    ]
    out_dir = common.resolve_out_dir(cfg.api, cfg.out_dir)
    # A single stamp is shared by the results file and the summary file so the
    # two outputs of one run can be paired by name.
    stamp = common.timestamp()

    # `rows` cannot be empty here: `images` is non-empty and every loop
    # iteration either appends a row or raises.
    if cfg.csv_path is not None:
        results_path = cfg.csv_path
    else:
        results_path = out_dir / f"{cfg.label_set.stem}_{stamp}.csv"
    common.write_csv(results_path, rows, fieldnames)

    if cfg.summary:
        summary = common.summarize_latency(rows)
        summary_path = out_dir / f"{cfg.label_set.stem}_summary_{stamp}.csv"
        common.write_csv(summary_path, [summary], list(summary.keys()))
    return 0
# Command-line entry point: parses the options below, bundles them into a
# Config, runs the classification and exits with run()'s return code.
# Documented with comments rather than a docstring because click would show a
# docstring as the command's --help text.
@click.command()
@click.option("--api", default="http://localhost:7860", show_default=True)
@click.option("--label-set", "label_set", required=True, type=click.Path(path_type=Path))
@click.option("--images", multiple=True, required=True, type=click.Path(path_type=Path))
@click.option("--domain-top-n", default=2, show_default=True, type=int)
@click.option("--top-k", default=5, show_default=True, type=int)
@click.option("--activate", is_flag=True, default=False)
@click.option("--limit", default=0, show_default=True, type=int)
@click.option("--out-dir", type=click.Path(path_type=Path))
@click.option("--csv", "csv_path", type=click.Path(path_type=Path))
@click.option("--summary", is_flag=True, default=False)
@click.option("--select-domain-n", type=int, default=None)
@click.option("--select-label-n", type=int, default=None)
@click.option("--min-domain-score", type=float, default=None)
@click.option("--min-label-score", type=float, default=None)
def cli(
    api: str,
    label_set: Path,
    images: tuple[Path, ...],
    domain_top_n: int,
    top_k: int,
    activate: bool,
    limit: int,
    out_dir: Path | None,
    csv_path: Path | None,
    summary: bool,
    select_domain_n: int | None,
    select_label_n: int | None,
    min_domain_score: float | None,
    min_label_score: float | None,
) -> None:
    exit_code = run(
        Config(
            api=api,
            label_set=label_set,
            # --images is a repeatable option and arrives as a tuple.
            images=list(images),
            domain_top_n=domain_top_n,
            top_k=top_k,
            activate=activate,
            limit=limit,
            out_dir=out_dir,
            csv_path=csv_path,
            summary=summary,
            select_domain_n=select_domain_n,
            select_label_n=select_label_n,
            min_domain_score=min_domain_score,
            min_label_score=min_label_score,
        )
    )
    raise SystemExit(exit_code)
# Invoke the click command only when the file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    cli()