Spaces:
Sleeping
Sleeping
Eval updates
Browse files- Makefile +13 -11
- README.md +17 -12
- pyproject.toml +3 -0
- src/eval/README.md +41 -0
- src/eval/__init__.py +1 -0
- {scripts → src/eval}/classify_dataset.py +31 -104
- src/eval/cli.py +21 -0
- src/eval/common.py +106 -0
- {scripts → src/eval}/dataset_prep.py +0 -0
- {scripts → src/eval}/eval_matrix.py +45 -78
- tests/test_eval_common.py +16 -0
Makefile
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
.PHONY: help docker-build docker-run local-install local-run local-test local-test-integration eval-photo
|
| 2 |
|
| 3 |
help:
|
| 4 |
@echo "---------------------------------------------------"
|
|
@@ -35,33 +35,35 @@ local-test-integration:
|
|
| 35 |
uv run pytest -q -m integration
|
| 36 |
|
| 37 |
eval-photo:
|
| 38 |
-
uv run
|
| 39 |
--label-set label-dataset/personal-photos-lite-v1.json \
|
| 40 |
--images data_eval/photos/normalized \
|
| 41 |
-
--
|
|
|
|
| 42 |
|
| 43 |
eval-dance:
|
| 44 |
-
uv run
|
| 45 |
--label-set label-dataset/scene-dance-formation-group-v1.json \
|
| 46 |
--images data_eval/dance/normalized \
|
| 47 |
-
--
|
|
|
|
| 48 |
|
| 49 |
eval-photo-matrix:
|
| 50 |
-
uv run
|
| 51 |
--label-sets "label-dataset/personal-photos-*.json" \
|
| 52 |
--images data_eval/photos/normalized \
|
| 53 |
--out-dir data_results \
|
| 54 |
-
--summary
|
| 55 |
|
| 56 |
eval-dance-matrix:
|
| 57 |
-
uv run
|
| 58 |
--label-sets "label-dataset/scene-dance-*.json" \
|
| 59 |
--images data_eval/dance/normalized \
|
| 60 |
--out-dir data_results \
|
| 61 |
-
--summary
|
| 62 |
|
| 63 |
data-photos:
|
| 64 |
-
uv run
|
| 65 |
|
| 66 |
data-dance:
|
| 67 |
-
uv run
|
|
|
|
| 1 |
+
.PHONY: help docker-build docker-run local-install local-run local-test local-test-integration eval-photo eval-dance eval-photo-matrix eval-dance-matrix data-photos data-dance
|
| 2 |
|
| 3 |
help:
|
| 4 |
@echo "---------------------------------------------------"
|
|
|
|
| 35 |
uv run pytest -q -m integration
|
| 36 |
|
| 37 |
eval-photo:
|
| 38 |
+
uv run photo-eval single \
|
| 39 |
--label-set label-dataset/personal-photos-lite-v1.json \
|
| 40 |
--images data_eval/photos/normalized \
|
| 41 |
+
--out-dir data_results \
|
| 42 |
+
--summary
|
| 43 |
|
| 44 |
eval-dance:
|
| 45 |
+
uv run photo-eval single \
|
| 46 |
--label-set label-dataset/scene-dance-formation-group-v1.json \
|
| 47 |
--images data_eval/dance/normalized \
|
| 48 |
+
--out-dir data_results \
|
| 49 |
+
--summary
|
| 50 |
|
| 51 |
eval-photo-matrix:
|
| 52 |
+
uv run photo-eval matrix \
|
| 53 |
--label-sets "label-dataset/personal-photos-*.json" \
|
| 54 |
--images data_eval/photos/normalized \
|
| 55 |
--out-dir data_results \
|
| 56 |
+
--summary
|
| 57 |
|
| 58 |
eval-dance-matrix:
|
| 59 |
+
uv run photo-eval matrix \
|
| 60 |
--label-sets "label-dataset/scene-dance-*.json" \
|
| 61 |
--images data_eval/dance/normalized \
|
| 62 |
--out-dir data_results \
|
| 63 |
+
--summary
|
| 64 |
|
| 65 |
data-photos:
|
| 66 |
+
uv run photo-eval prep --out data_eval --target photos --n 50 --normalize
|
| 67 |
|
| 68 |
data-dance:
|
| 69 |
+
uv run photo-eval prep --out data_eval --target dance --n 50 --normalize
|
README.md
CHANGED
|
@@ -154,36 +154,41 @@ Error handling and HTTP boundaries:
|
|
| 154 |
|
| 155 |
## Eval scripts
|
| 156 |
|
| 157 |
-
Use the lightweight evaluator
|
| 158 |
local images and capture timings:
|
| 159 |
|
| 160 |
```bash
|
| 161 |
-
uv run
|
| 162 |
--label-set label-dataset/personal-photos-lite-v1.json \
|
| 163 |
--images /path/to/images \
|
| 164 |
-
--
|
|
|
|
| 165 |
```
|
| 166 |
|
|
|
|
|
|
|
| 167 |
Makefile shortcuts:
|
| 168 |
|
| 169 |
-
- `make eval-photo
|
| 170 |
-
- `make eval-dance
|
| 171 |
- `make eval-photo-matrix`
|
| 172 |
- `make eval-dance-matrix`
|
| 173 |
|
|
|
|
|
|
|
| 174 |
Matrix eval (multiple label sets against the same images):
|
| 175 |
|
| 176 |
```bash
|
| 177 |
-
uv run
|
| 178 |
--label-sets "label-dataset/personal-photos-*.json" \
|
| 179 |
--images data_eval/photos/normalized \
|
| 180 |
--out-dir data_results \
|
| 181 |
-
--summary
|
| 182 |
```
|
| 183 |
|
| 184 |
## Eval datasets (download schema)
|
| 185 |
|
| 186 |
-
We use a simple, reproducible layout for evaluation datasets created by `
|
| 187 |
|
| 188 |
```
|
| 189 |
data_eval/
|
|
@@ -198,20 +203,20 @@ data_eval/
|
|
| 198 |
Download and normalize (recommended):
|
| 199 |
|
| 200 |
```bash
|
| 201 |
-
uv run
|
| 202 |
-
uv run
|
| 203 |
```
|
| 204 |
|
| 205 |
Reset existing files and start fresh:
|
| 206 |
|
| 207 |
```bash
|
| 208 |
-
uv run
|
| 209 |
```
|
| 210 |
|
| 211 |
Normalize your own folder into the same schema:
|
| 212 |
|
| 213 |
```bash
|
| 214 |
-
uv run
|
| 215 |
```
|
| 216 |
|
| 217 |
## Project layout
|
|
|
|
| 154 |
|
| 155 |
## Eval scripts
|
| 156 |
|
| 157 |
+
Use the lightweight evaluator via `photo-eval` to run a label set against
|
| 158 |
local images and capture timings:
|
| 159 |
|
| 160 |
```bash
|
| 161 |
+
uv run photo-eval single \
|
| 162 |
--label-set label-dataset/personal-photos-lite-v1.json \
|
| 163 |
--images /path/to/images \
|
| 164 |
+
--out-dir data_results \
|
| 165 |
+
--summary
|
| 166 |
```
|
| 167 |
|
| 168 |
+
Output CSV files are timestamped (UTC) in `data_results/`.
|
| 169 |
+
|
| 170 |
Makefile shortcuts:
|
| 171 |
|
| 172 |
+
- `make eval-photo`
|
| 173 |
+
- `make eval-dance`
|
| 174 |
- `make eval-photo-matrix`
|
| 175 |
- `make eval-dance-matrix`
|
| 176 |
|
| 177 |
+
See `src/eval/README.md` for the eval CLI reference and API endpoints.
|
| 178 |
+
|
| 179 |
Matrix eval (multiple label sets against the same images):
|
| 180 |
|
| 181 |
```bash
|
| 182 |
+
uv run photo-eval matrix \
|
| 183 |
--label-sets "label-dataset/personal-photos-*.json" \
|
| 184 |
--images data_eval/photos/normalized \
|
| 185 |
--out-dir data_results \
|
| 186 |
+
--summary
|
| 187 |
```
|
| 188 |
|
| 189 |
## Eval datasets (download schema)
|
| 190 |
|
| 191 |
+
We use a simple, reproducible layout for evaluation datasets created by `photo-eval prep`:
|
| 192 |
|
| 193 |
```
|
| 194 |
data_eval/
|
|
|
|
| 203 |
Download and normalize (recommended):
|
| 204 |
|
| 205 |
```bash
|
| 206 |
+
uv run photo-eval prep --out data_eval --target photos --n 50 --normalize
|
| 207 |
+
uv run photo-eval prep --out data_eval --target dance --n 50 --normalize
|
| 208 |
```
|
| 209 |
|
| 210 |
Reset existing files and start fresh:
|
| 211 |
|
| 212 |
```bash
|
| 213 |
+
uv run photo-eval prep --out data_eval --target photos --n 50 --normalize --reset
|
| 214 |
```
|
| 215 |
|
| 216 |
Normalize your own folder into the same schema:
|
| 217 |
|
| 218 |
```bash
|
| 219 |
+
uv run photo-eval prep --normalize-only --in-dir /path/to/images --out data_eval/photos
|
| 220 |
```
|
| 221 |
|
| 222 |
## Project layout
|
pyproject.toml
CHANGED
|
@@ -27,6 +27,9 @@ dev = [
|
|
| 27 |
"tqdm==4.66.5",
|
| 28 |
]
|
| 29 |
|
|
|
|
|
|
|
|
|
|
| 30 |
[tool.setuptools]
|
| 31 |
package-dir = {"" = "src"}
|
| 32 |
|
|
|
|
| 27 |
"tqdm==4.66.5",
|
| 28 |
]
|
| 29 |
|
| 30 |
+
[project.scripts]
|
| 31 |
+
photo-eval = "eval.cli:main"
|
| 32 |
+
|
| 33 |
[tool.setuptools]
|
| 34 |
package-dir = {"" = "src"}
|
| 35 |
|
src/eval/README.md
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Eval CLI
|
| 2 |
+
|
| 3 |
+
Quick reference for the standalone evaluation tools in `src/eval`.
|
| 4 |
+
|
| 5 |
+
## Commands
|
| 6 |
+
|
| 7 |
+
- `photo-eval prep`
|
| 8 |
+
Download a small evaluation dataset (photos or dance) and optionally normalize.
|
| 9 |
+
|
| 10 |
+
- `photo-eval single`
|
| 11 |
+
Upload one label set, classify a set of images, and write timestamped CSV + optional summary.
|
| 12 |
+
|
| 13 |
+
- `photo-eval matrix`
|
| 14 |
+
Upload multiple label sets and run the same images against each; outputs a matrix CSV + optional summary.
|
| 15 |
+
|
| 16 |
+
## API endpoints used
|
| 17 |
+
|
| 18 |
+
All eval commands call the running API (default `http://localhost:7860`):
|
| 19 |
+
|
| 20 |
+
- `POST /api/v1/label-sets` (upload a label set)
|
| 21 |
+
- `POST /api/v1/label-sets/{label_set_hash}/activate` (optional for `single` via `--activate`)
|
| 22 |
+
- `POST /api/v1/classify` (classify image with optional `label_set_hash`)
|
| 23 |
+
|
| 24 |
+
## Examples
|
| 25 |
+
|
| 26 |
+
```bash
|
| 27 |
+
photo-eval prep --out data_eval --target photos --n 50 --normalize
|
| 28 |
+
photo-eval prep --out data_eval --target dance --n 50 --normalize
|
| 29 |
+
|
| 30 |
+
photo-eval single \
|
| 31 |
+
--label-set label-dataset/personal-photos-lite-v1.json \
|
| 32 |
+
--images data_eval/photos/normalized \
|
| 33 |
+
--out-dir data_results \
|
| 34 |
+
--summary
|
| 35 |
+
|
| 36 |
+
photo-eval matrix \
|
| 37 |
+
--label-sets "label-dataset/personal-photos-*.json" \
|
| 38 |
+
--images data_eval/photos/normalized \
|
| 39 |
+
--out-dir data_results \
|
| 40 |
+
--summary
|
| 41 |
+
```
|
src/eval/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Eval CLI package (autonomous from app)."""
|
{scripts → src/eval}/classify_dataset.py
RENAMED
|
@@ -1,17 +1,15 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
-
import base64
|
| 5 |
-
import csv
|
| 6 |
import json
|
| 7 |
from dataclasses import dataclass
|
| 8 |
-
from datetime import datetime, timezone
|
| 9 |
from pathlib import Path
|
| 10 |
-
from typing import Iterable
|
| 11 |
|
| 12 |
import click
|
| 13 |
import httpx
|
| 14 |
|
|
|
|
|
|
|
| 15 |
|
| 16 |
@dataclass(frozen=True)
|
| 17 |
class Config:
|
|
@@ -27,44 +25,11 @@ class Config:
|
|
| 27 |
summary: bool
|
| 28 |
|
| 29 |
|
| 30 |
-
def iter_images(paths: list[Path]) -> Iterable[Path]:
|
| 31 |
-
exts = {".jpg", ".jpeg", ".png", ".webp"}
|
| 32 |
-
for path in paths:
|
| 33 |
-
if path.is_dir():
|
| 34 |
-
for p in sorted(path.rglob("*")):
|
| 35 |
-
if p.suffix.lower() in exts and p.is_file():
|
| 36 |
-
yield p
|
| 37 |
-
elif path.is_file() and path.suffix.lower() in exts:
|
| 38 |
-
yield path
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def b64_image(path: Path) -> str:
|
| 42 |
-
return base64.b64encode(path.read_bytes()).decode("utf-8")
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def upload_label_set(client: httpx.Client, label_set: Path) -> str:
|
| 46 |
-
payload = json.loads(label_set.read_text())
|
| 47 |
-
r = client.post("/api/v1/label-sets", json=payload)
|
| 48 |
-
r.raise_for_status()
|
| 49 |
-
return r.json()["label_set_hash"]
|
| 50 |
-
|
| 51 |
-
|
| 52 |
def activate_label_set(client: httpx.Client, label_set_hash: str) -> None:
|
| 53 |
r = client.post(f"/api/v1/label-sets/{label_set_hash}/activate")
|
| 54 |
r.raise_for_status()
|
| 55 |
|
| 56 |
|
| 57 |
-
def classify_one(client: httpx.Client, label_set_hash: str, image: Path, domain_top_n: int, top_k: int) -> dict:
|
| 58 |
-
payload = {
|
| 59 |
-
"image_base64": b64_image(image),
|
| 60 |
-
"domain_top_n": domain_top_n,
|
| 61 |
-
"top_k": top_k,
|
| 62 |
-
}
|
| 63 |
-
r = client.post(f"/api/v1/classify?label_set_hash={label_set_hash}", json=payload)
|
| 64 |
-
r.raise_for_status()
|
| 65 |
-
return r.json()
|
| 66 |
-
|
| 67 |
-
|
| 68 |
def to_row(image: Path, data: dict) -> dict[str, str]:
|
| 69 |
domain_hits = data.get("domain_hits", [])
|
| 70 |
label_hits = data.get("label_hits", [])
|
|
@@ -73,71 +38,16 @@ def to_row(image: Path, data: dict) -> dict[str, str]:
|
|
| 73 |
"label_set_hash": data.get("label_set_hash", ""),
|
| 74 |
"model_id": data.get("model_id", ""),
|
| 75 |
"chosen_domains": "|".join(data.get("chosen_domains", [])),
|
| 76 |
-
"domain_hits": "|".join(
|
| 77 |
-
"label_hits": "|".join(
|
| 78 |
"elapsed_ms": str(data.get("elapsed_ms", "")),
|
| 79 |
"elapsed_domain_ms": str(data.get("elapsed_domain_ms", "")),
|
| 80 |
"elapsed_labels_ms": str(data.get("elapsed_labels_ms", "")),
|
| 81 |
}
|
| 82 |
|
| 83 |
|
| 84 |
-
def write_csv(path: Path, rows: list[dict[str, str]]) -> None:
|
| 85 |
-
path.parent.mkdir(parents=True, exist_ok=True)
|
| 86 |
-
with path.open("w", newline="", encoding="utf-8") as f:
|
| 87 |
-
writer = csv.DictWriter(
|
| 88 |
-
f,
|
| 89 |
-
fieldnames=[
|
| 90 |
-
"image",
|
| 91 |
-
"label_set_hash",
|
| 92 |
-
"model_id",
|
| 93 |
-
"chosen_domains",
|
| 94 |
-
"domain_hits",
|
| 95 |
-
"label_hits",
|
| 96 |
-
"elapsed_ms",
|
| 97 |
-
"elapsed_domain_ms",
|
| 98 |
-
"elapsed_labels_ms",
|
| 99 |
-
],
|
| 100 |
-
)
|
| 101 |
-
writer.writeheader()
|
| 102 |
-
writer.writerows(rows)
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
def _percentile(values: list[int], q: float) -> int:
|
| 106 |
-
if not values:
|
| 107 |
-
return 0
|
| 108 |
-
values = sorted(values)
|
| 109 |
-
idx = int(round((len(values) - 1) * q))
|
| 110 |
-
return values[idx]
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
def write_summary(path: Path, rows: list[dict[str, str]]) -> None:
|
| 114 |
-
times: list[int] = []
|
| 115 |
-
for row in rows:
|
| 116 |
-
try:
|
| 117 |
-
times.append(int(row["elapsed_ms"]))
|
| 118 |
-
except Exception:
|
| 119 |
-
continue
|
| 120 |
-
summary = {
|
| 121 |
-
"count": str(len(times)),
|
| 122 |
-
"avg_elapsed_ms": str(int(sum(times) / max(1, len(times)))),
|
| 123 |
-
"p50_elapsed_ms": str(_percentile(times, 0.50)),
|
| 124 |
-
"p90_elapsed_ms": str(_percentile(times, 0.90)),
|
| 125 |
-
"p95_elapsed_ms": str(_percentile(times, 0.95)),
|
| 126 |
-
"p99_elapsed_ms": str(_percentile(times, 0.99)),
|
| 127 |
-
}
|
| 128 |
-
path.parent.mkdir(parents=True, exist_ok=True)
|
| 129 |
-
with path.open("w", newline="", encoding="utf-8") as f:
|
| 130 |
-
writer = csv.DictWriter(f, fieldnames=list(summary.keys()))
|
| 131 |
-
writer.writeheader()
|
| 132 |
-
writer.writerow(summary)
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
def _timestamp() -> str:
|
| 136 |
-
return datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
|
| 137 |
-
|
| 138 |
-
|
| 139 |
def run(cfg: Config) -> int:
|
| 140 |
-
images = list(iter_images(cfg.images))
|
| 141 |
if cfg.limit > 0:
|
| 142 |
images = images[: cfg.limit]
|
| 143 |
if not images:
|
|
@@ -145,26 +55,43 @@ def run(cfg: Config) -> int:
|
|
| 145 |
|
| 146 |
rows: list[dict[str, str]] = []
|
| 147 |
with httpx.Client(base_url=cfg.api, timeout=30) as client:
|
| 148 |
-
label_set_hash = upload_label_set(client, cfg.label_set)
|
| 149 |
if cfg.activate:
|
| 150 |
activate_label_set(client, label_set_hash)
|
| 151 |
|
| 152 |
for image in images:
|
| 153 |
-
data = classify_one(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
print(json.dumps({"image": str(image), "result": data}, ensure_ascii=True))
|
| 155 |
rows.append(to_row(image, data))
|
| 156 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
if cfg.csv_path and rows:
|
| 158 |
-
write_csv(cfg.csv_path, rows)
|
| 159 |
elif rows:
|
| 160 |
-
|
| 161 |
-
out_path
|
| 162 |
-
write_csv(out_path, rows)
|
| 163 |
|
| 164 |
if cfg.summary:
|
| 165 |
-
|
| 166 |
-
summary_path = cfg.out_dir / f"{cfg.label_set.stem}_summary_{
|
| 167 |
-
|
| 168 |
|
| 169 |
return 0
|
| 170 |
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
from __future__ import annotations
|
| 3 |
|
|
|
|
|
|
|
| 4 |
import json
|
| 5 |
from dataclasses import dataclass
|
|
|
|
| 6 |
from pathlib import Path
|
|
|
|
| 7 |
|
| 8 |
import click
|
| 9 |
import httpx
|
| 10 |
|
| 11 |
+
from eval import common
|
| 12 |
+
|
| 13 |
|
| 14 |
@dataclass(frozen=True)
|
| 15 |
class Config:
|
|
|
|
| 25 |
summary: bool
|
| 26 |
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
def activate_label_set(client: httpx.Client, label_set_hash: str) -> None:
|
| 29 |
r = client.post(f"/api/v1/label-sets/{label_set_hash}/activate")
|
| 30 |
r.raise_for_status()
|
| 31 |
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
def to_row(image: Path, data: dict) -> dict[str, str]:
|
| 34 |
domain_hits = data.get("domain_hits", [])
|
| 35 |
label_hits = data.get("label_hits", [])
|
|
|
|
| 38 |
"label_set_hash": data.get("label_set_hash", ""),
|
| 39 |
"model_id": data.get("model_id", ""),
|
| 40 |
"chosen_domains": "|".join(data.get("chosen_domains", [])),
|
| 41 |
+
"domain_hits": "|".join(common.fmt_hit(d) for d in domain_hits),
|
| 42 |
+
"label_hits": "|".join(common.fmt_hit(l) for l in label_hits),
|
| 43 |
"elapsed_ms": str(data.get("elapsed_ms", "")),
|
| 44 |
"elapsed_domain_ms": str(data.get("elapsed_domain_ms", "")),
|
| 45 |
"elapsed_labels_ms": str(data.get("elapsed_labels_ms", "")),
|
| 46 |
}
|
| 47 |
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
def run(cfg: Config) -> int:
|
| 50 |
+
images = list(common.iter_images(cfg.images))
|
| 51 |
if cfg.limit > 0:
|
| 52 |
images = images[: cfg.limit]
|
| 53 |
if not images:
|
|
|
|
| 55 |
|
| 56 |
rows: list[dict[str, str]] = []
|
| 57 |
with httpx.Client(base_url=cfg.api, timeout=30) as client:
|
| 58 |
+
label_set_hash = common.upload_label_set(client, cfg.label_set)
|
| 59 |
if cfg.activate:
|
| 60 |
activate_label_set(client, label_set_hash)
|
| 61 |
|
| 62 |
for image in images:
|
| 63 |
+
data = common.classify_one(
|
| 64 |
+
client,
|
| 65 |
+
label_set_hash,
|
| 66 |
+
image_b64=common.encode_image_b64(image),
|
| 67 |
+
domain_top_n=cfg.domain_top_n,
|
| 68 |
+
top_k=cfg.top_k,
|
| 69 |
+
)
|
| 70 |
print(json.dumps({"image": str(image), "result": data}, ensure_ascii=True))
|
| 71 |
rows.append(to_row(image, data))
|
| 72 |
|
| 73 |
+
fieldnames = [
|
| 74 |
+
"image",
|
| 75 |
+
"label_set_hash",
|
| 76 |
+
"model_id",
|
| 77 |
+
"chosen_domains",
|
| 78 |
+
"domain_hits",
|
| 79 |
+
"label_hits",
|
| 80 |
+
"elapsed_ms",
|
| 81 |
+
"elapsed_domain_ms",
|
| 82 |
+
"elapsed_labels_ms",
|
| 83 |
+
]
|
| 84 |
+
|
| 85 |
if cfg.csv_path and rows:
|
| 86 |
+
common.write_csv(cfg.csv_path, rows, fieldnames)
|
| 87 |
elif rows:
|
| 88 |
+
out_path = cfg.out_dir / f"{cfg.label_set.stem}_{common.timestamp()}.csv"
|
| 89 |
+
common.write_csv(out_path, rows, fieldnames)
|
|
|
|
| 90 |
|
| 91 |
if cfg.summary:
|
| 92 |
+
summary = common.summarize_latency(rows)
|
| 93 |
+
summary_path = cfg.out_dir / f"{cfg.label_set.stem}_summary_{common.timestamp()}.csv"
|
| 94 |
+
common.write_csv(summary_path, [summary], list(summary.keys()))
|
| 95 |
|
| 96 |
return 0
|
| 97 |
|
src/eval/cli.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import click
|
| 4 |
+
|
| 5 |
+
from eval.classify_dataset import cli as classify_cli
|
| 6 |
+
from eval.dataset_prep import cli as prep_cli
|
| 7 |
+
from eval.eval_matrix import cli as matrix_cli
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@click.group()
|
| 11 |
+
def main() -> None:
|
| 12 |
+
"""Evaluation CLI (standalone from the API)."""
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
main.add_command(prep_cli, name="prep")
|
| 16 |
+
main.add_command(classify_cli, name="single")
|
| 17 |
+
main.add_command(matrix_cli, name="matrix")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if __name__ == "__main__":
|
| 21 |
+
main()
|
src/eval/common.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import base64
|
| 4 |
+
import csv
|
| 5 |
+
import json
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from datetime import datetime, timezone
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Iterable
|
| 10 |
+
|
| 11 |
+
import httpx
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass(frozen=True)
|
| 15 |
+
class EvalConfig:
|
| 16 |
+
api: str
|
| 17 |
+
images: list[Path]
|
| 18 |
+
domain_top_n: int
|
| 19 |
+
top_k: int
|
| 20 |
+
out_dir: Path
|
| 21 |
+
summary: bool
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def iter_images(paths: Iterable[Path]) -> Iterable[Path]:
|
| 25 |
+
exts = {".jpg", ".jpeg", ".png", ".webp"}
|
| 26 |
+
for path in paths:
|
| 27 |
+
if path.is_dir():
|
| 28 |
+
for p in sorted(path.rglob("*")):
|
| 29 |
+
if p.is_file() and p.suffix.lower() in exts:
|
| 30 |
+
yield p
|
| 31 |
+
elif path.is_file() and path.suffix.lower() in exts:
|
| 32 |
+
yield path
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def upload_label_set(client: httpx.Client, label_set: Path) -> str:
|
| 36 |
+
payload = json.loads(label_set.read_text())
|
| 37 |
+
r = client.post("/api/v1/label-sets", json=payload)
|
| 38 |
+
r.raise_for_status()
|
| 39 |
+
return r.json()["label_set_hash"]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def classify_one(
|
| 43 |
+
client: httpx.Client,
|
| 44 |
+
label_set_hash: str,
|
| 45 |
+
image_b64: str,
|
| 46 |
+
domain_top_n: int,
|
| 47 |
+
top_k: int,
|
| 48 |
+
) -> dict:
|
| 49 |
+
payload = {
|
| 50 |
+
"image_base64": image_b64,
|
| 51 |
+
"domain_top_n": domain_top_n,
|
| 52 |
+
"top_k": top_k,
|
| 53 |
+
}
|
| 54 |
+
r = client.post(f"/api/v1/classify?label_set_hash={label_set_hash}", json=payload)
|
| 55 |
+
r.raise_for_status()
|
| 56 |
+
return r.json()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def encode_image_b64(path: Path) -> str:
|
| 60 |
+
return base64.b64encode(path.read_bytes()).decode("utf-8")
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def fmt_hit(hit: dict) -> str:
|
| 64 |
+
score = hit.get("score")
|
| 65 |
+
try:
|
| 66 |
+
score_str = f"{float(score):.4f}"
|
| 67 |
+
except Exception:
|
| 68 |
+
score_str = ""
|
| 69 |
+
return f"{hit.get('id')}:{score_str}"
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def percentile(values: list[int], q: float) -> int:
|
| 73 |
+
if not values:
|
| 74 |
+
return 0
|
| 75 |
+
values = sorted(values)
|
| 76 |
+
idx = int(round((len(values) - 1) * q))
|
| 77 |
+
return values[idx]
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def timestamp() -> str:
|
| 81 |
+
return datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def write_csv(path: Path, rows: list[dict[str, str]], fieldnames: list[str]) -> None:
|
| 85 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 86 |
+
with path.open("w", newline="", encoding="utf-8") as f:
|
| 87 |
+
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
| 88 |
+
writer.writeheader()
|
| 89 |
+
writer.writerows(rows)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def summarize_latency(rows: list[dict[str, str]]) -> dict[str, str]:
|
| 93 |
+
times: list[int] = []
|
| 94 |
+
for row in rows:
|
| 95 |
+
try:
|
| 96 |
+
times.append(int(row["elapsed_ms"]))
|
| 97 |
+
except Exception:
|
| 98 |
+
continue
|
| 99 |
+
return {
|
| 100 |
+
"count": str(len(times)),
|
| 101 |
+
"avg_elapsed_ms": str(int(sum(times) / max(1, len(times)))),
|
| 102 |
+
"p50_elapsed_ms": str(percentile(times, 0.50)),
|
| 103 |
+
"p90_elapsed_ms": str(percentile(times, 0.90)),
|
| 104 |
+
"p95_elapsed_ms": str(percentile(times, 0.95)),
|
| 105 |
+
"p99_elapsed_ms": str(percentile(times, 0.99)),
|
| 106 |
+
}
|
{scripts → src/eval}/dataset_prep.py
RENAMED
|
File without changes
|
{scripts → src/eval}/eval_matrix.py
RENAMED
|
@@ -1,16 +1,16 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
-
import csv
|
| 5 |
import json
|
| 6 |
from dataclasses import dataclass
|
| 7 |
-
import base64
|
| 8 |
from pathlib import Path
|
| 9 |
from typing import Iterable
|
| 10 |
|
| 11 |
import click
|
| 12 |
import httpx
|
| 13 |
|
|
|
|
|
|
|
| 14 |
|
| 15 |
@dataclass(frozen=True)
|
| 16 |
class Config:
|
|
@@ -20,18 +20,7 @@ class Config:
|
|
| 20 |
domain_top_n: int
|
| 21 |
top_k: int
|
| 22 |
out_dir: Path
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def iter_images(paths: Iterable[Path]) -> Iterable[Path]:
|
| 27 |
-
exts = {".jpg", ".jpeg", ".png", ".webp"}
|
| 28 |
-
for path in paths:
|
| 29 |
-
if path.is_dir():
|
| 30 |
-
for p in sorted(path.rglob("*")):
|
| 31 |
-
if p.is_file() and p.suffix.lower() in exts:
|
| 32 |
-
yield p
|
| 33 |
-
elif path.is_file() and path.suffix.lower() in exts:
|
| 34 |
-
yield path
|
| 35 |
|
| 36 |
|
| 37 |
def expand_label_sets(paths: Iterable[str]) -> list[Path]:
|
|
@@ -45,64 +34,24 @@ def expand_label_sets(paths: Iterable[str]) -> list[Path]:
|
|
| 45 |
return [p for p in out if p.is_file()]
|
| 46 |
|
| 47 |
|
| 48 |
-
def upload_label_set(client: httpx.Client, label_set: Path) -> str:
|
| 49 |
-
payload = json.loads(label_set.read_text())
|
| 50 |
-
r = client.post("/api/v1/label-sets", json=payload)
|
| 51 |
-
r.raise_for_status()
|
| 52 |
-
return r.json()["label_set_hash"]
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def classify_one(
|
| 56 |
-
client: httpx.Client,
|
| 57 |
-
label_set_hash: str,
|
| 58 |
-
image_b64: str,
|
| 59 |
-
domain_top_n: int,
|
| 60 |
-
top_k: int,
|
| 61 |
-
) -> dict:
|
| 62 |
-
payload = {
|
| 63 |
-
"image_base64": image_b64,
|
| 64 |
-
"domain_top_n": domain_top_n,
|
| 65 |
-
"top_k": top_k,
|
| 66 |
-
}
|
| 67 |
-
r = client.post(f"/api/v1/classify?label_set_hash={label_set_hash}", json=payload)
|
| 68 |
-
r.raise_for_status()
|
| 69 |
-
return r.json()
|
| 70 |
-
|
| 71 |
-
|
| 72 |
def to_row(label_set: Path, image: Path, data: dict) -> dict[str, str]:
|
|
|
|
|
|
|
| 73 |
return {
|
| 74 |
"label_set": label_set.name,
|
| 75 |
"image": str(image),
|
| 76 |
"label_set_hash": data.get("label_set_hash", ""),
|
| 77 |
"model_id": data.get("model_id", ""),
|
| 78 |
"chosen_domains": "|".join(data.get("chosen_domains", [])),
|
|
|
|
|
|
|
| 79 |
"elapsed_ms": str(data.get("elapsed_ms", "")),
|
| 80 |
"elapsed_domain_ms": str(data.get("elapsed_domain_ms", "")),
|
| 81 |
"elapsed_labels_ms": str(data.get("elapsed_labels_ms", "")),
|
| 82 |
}
|
| 83 |
|
| 84 |
|
| 85 |
-
def
|
| 86 |
-
path.parent.mkdir(parents=True, exist_ok=True)
|
| 87 |
-
with path.open("w", newline="", encoding="utf-8") as f:
|
| 88 |
-
writer = csv.DictWriter(
|
| 89 |
-
f,
|
| 90 |
-
fieldnames=[
|
| 91 |
-
"label_set",
|
| 92 |
-
"image",
|
| 93 |
-
"label_set_hash",
|
| 94 |
-
"model_id",
|
| 95 |
-
"chosen_domains",
|
| 96 |
-
"elapsed_ms",
|
| 97 |
-
"elapsed_domain_ms",
|
| 98 |
-
"elapsed_labels_ms",
|
| 99 |
-
],
|
| 100 |
-
)
|
| 101 |
-
writer.writeheader()
|
| 102 |
-
writer.writerows(rows)
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
def write_summary(path: Path, rows: list[dict[str, str]]) -> None:
|
| 106 |
summary: dict[str, list[int]] = {}
|
| 107 |
for row in rows:
|
| 108 |
label = row["label_set"]
|
|
@@ -115,17 +64,20 @@ def write_summary(path: Path, rows: list[dict[str, str]]) -> None:
|
|
| 115 |
out_rows: list[dict[str, str]] = []
|
| 116 |
for label, times in summary.items():
|
| 117 |
avg = int(sum(times) / max(1, len(times)))
|
| 118 |
-
out_rows.append(
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
|
|
|
|
|
|
|
|
|
| 125 |
|
| 126 |
|
| 127 |
def run(cfg: Config) -> None:
|
| 128 |
-
images = list(iter_images(cfg.images))
|
| 129 |
if not images:
|
| 130 |
raise SystemExit("No images found.")
|
| 131 |
if not cfg.label_sets:
|
|
@@ -134,23 +86,38 @@ def run(cfg: Config) -> None:
|
|
| 134 |
rows: list[dict[str, str]] = []
|
| 135 |
with httpx.Client(base_url=cfg.api, timeout=30) as client:
|
| 136 |
for label_set in cfg.label_sets:
|
| 137 |
-
label_set_hash = upload_label_set(client, label_set)
|
| 138 |
for image in images:
|
| 139 |
-
|
| 140 |
-
data = classify_one(
|
| 141 |
client,
|
| 142 |
label_set_hash,
|
| 143 |
-
image_b64=
|
| 144 |
domain_top_n=cfg.domain_top_n,
|
| 145 |
top_k=cfg.top_k,
|
| 146 |
)
|
| 147 |
print(json.dumps({"label_set": label_set.name, "image": str(image), "result": data}))
|
| 148 |
rows.append(to_row(label_set, image, data))
|
| 149 |
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
|
| 156 |
@click.command()
|
|
@@ -160,7 +127,7 @@ def run(cfg: Config) -> None:
|
|
| 160 |
@click.option("--domain-top-n", default=2, show_default=True, type=int)
|
| 161 |
@click.option("--top-k", default=5, show_default=True, type=int)
|
| 162 |
@click.option("--out-dir", default="data_results", show_default=True, type=click.Path(path_type=Path))
|
| 163 |
-
@click.option("--summary
|
| 164 |
def cli(
|
| 165 |
api: str,
|
| 166 |
label_sets_raw: tuple[str, ...],
|
|
@@ -168,7 +135,7 @@ def cli(
|
|
| 168 |
domain_top_n: int,
|
| 169 |
top_k: int,
|
| 170 |
out_dir: Path,
|
| 171 |
-
|
| 172 |
) -> None:
|
| 173 |
label_sets = expand_label_sets(label_sets_raw)
|
| 174 |
cfg = Config(
|
|
@@ -178,7 +145,7 @@ def cli(
|
|
| 178 |
domain_top_n=domain_top_n,
|
| 179 |
top_k=top_k,
|
| 180 |
out_dir=out_dir,
|
| 181 |
-
|
| 182 |
)
|
| 183 |
run(cfg)
|
| 184 |
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
from __future__ import annotations
|
| 3 |
|
|
|
|
| 4 |
import json
|
| 5 |
from dataclasses import dataclass
|
|
|
|
| 6 |
from pathlib import Path
|
| 7 |
from typing import Iterable
|
| 8 |
|
| 9 |
import click
|
| 10 |
import httpx
|
| 11 |
|
| 12 |
+
from eval import common
|
| 13 |
+
|
| 14 |
|
| 15 |
@dataclass(frozen=True)
|
| 16 |
class Config:
|
|
|
|
| 20 |
domain_top_n: int
|
| 21 |
top_k: int
|
| 22 |
out_dir: Path
|
| 23 |
+
summary: bool
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
def expand_label_sets(paths: Iterable[str]) -> list[Path]:
|
|
|
|
| 34 |
return [p for p in out if p.is_file()]
|
| 35 |
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
def to_row(label_set: Path, image: Path, data: dict) -> dict[str, str]:
|
| 38 |
+
domain_hits = data.get("domain_hits", [])
|
| 39 |
+
label_hits = data.get("label_hits", [])
|
| 40 |
return {
|
| 41 |
"label_set": label_set.name,
|
| 42 |
"image": str(image),
|
| 43 |
"label_set_hash": data.get("label_set_hash", ""),
|
| 44 |
"model_id": data.get("model_id", ""),
|
| 45 |
"chosen_domains": "|".join(data.get("chosen_domains", [])),
|
| 46 |
+
"domain_hits": "|".join(common.fmt_hit(d) for d in domain_hits),
|
| 47 |
+
"label_hits": "|".join(common.fmt_hit(l) for l in label_hits),
|
| 48 |
"elapsed_ms": str(data.get("elapsed_ms", "")),
|
| 49 |
"elapsed_domain_ms": str(data.get("elapsed_domain_ms", "")),
|
| 50 |
"elapsed_labels_ms": str(data.get("elapsed_labels_ms", "")),
|
| 51 |
}
|
| 52 |
|
| 53 |
|
| 54 |
+
def summarize_by_label_set(rows: list[dict[str, str]]) -> list[dict[str, str]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
summary: dict[str, list[int]] = {}
|
| 56 |
for row in rows:
|
| 57 |
label = row["label_set"]
|
|
|
|
| 64 |
out_rows: list[dict[str, str]] = []
|
| 65 |
for label, times in summary.items():
|
| 66 |
avg = int(sum(times) / max(1, len(times)))
|
| 67 |
+
out_rows.append(
|
| 68 |
+
{
|
| 69 |
+
"label_set": label,
|
| 70 |
+
"count": str(len(times)),
|
| 71 |
+
"avg_elapsed_ms": str(avg),
|
| 72 |
+
"p50_elapsed_ms": str(common.percentile(times, 0.50)),
|
| 73 |
+
"p95_elapsed_ms": str(common.percentile(times, 0.95)),
|
| 74 |
+
}
|
| 75 |
+
)
|
| 76 |
+
return out_rows
|
| 77 |
|
| 78 |
|
| 79 |
def run(cfg: Config) -> None:
|
| 80 |
+
images = list(common.iter_images(cfg.images))
|
| 81 |
if not images:
|
| 82 |
raise SystemExit("No images found.")
|
| 83 |
if not cfg.label_sets:
|
|
|
|
| 86 |
rows: list[dict[str, str]] = []
|
| 87 |
with httpx.Client(base_url=cfg.api, timeout=30) as client:
|
| 88 |
for label_set in cfg.label_sets:
|
| 89 |
+
label_set_hash = common.upload_label_set(client, label_set)
|
| 90 |
for image in images:
|
| 91 |
+
data = common.classify_one(
|
|
|
|
| 92 |
client,
|
| 93 |
label_set_hash,
|
| 94 |
+
image_b64=common.encode_image_b64(image),
|
| 95 |
domain_top_n=cfg.domain_top_n,
|
| 96 |
top_k=cfg.top_k,
|
| 97 |
)
|
| 98 |
print(json.dumps({"label_set": label_set.name, "image": str(image), "result": data}))
|
| 99 |
rows.append(to_row(label_set, image, data))
|
| 100 |
|
| 101 |
+
fieldnames = [
|
| 102 |
+
"label_set",
|
| 103 |
+
"image",
|
| 104 |
+
"label_set_hash",
|
| 105 |
+
"model_id",
|
| 106 |
+
"chosen_domains",
|
| 107 |
+
"domain_hits",
|
| 108 |
+
"label_hits",
|
| 109 |
+
"elapsed_ms",
|
| 110 |
+
"elapsed_domain_ms",
|
| 111 |
+
"elapsed_labels_ms",
|
| 112 |
+
]
|
| 113 |
+
|
| 114 |
+
out_path = cfg.out_dir / f"eval_matrix_{common.timestamp()}.csv"
|
| 115 |
+
common.write_csv(out_path, rows, fieldnames)
|
| 116 |
+
|
| 117 |
+
if cfg.summary:
|
| 118 |
+
summary_rows = summarize_by_label_set(rows)
|
| 119 |
+
summary_path = cfg.out_dir / f"eval_matrix_summary_{common.timestamp()}.csv"
|
| 120 |
+
common.write_csv(summary_path, summary_rows, ["label_set", "count", "avg_elapsed_ms", "p50_elapsed_ms", "p95_elapsed_ms"])
|
| 121 |
|
| 122 |
|
| 123 |
@click.command()
|
|
|
|
| 127 |
@click.option("--domain-top-n", default=2, show_default=True, type=int)
|
| 128 |
@click.option("--top-k", default=5, show_default=True, type=int)
|
| 129 |
@click.option("--out-dir", default="data_results", show_default=True, type=click.Path(path_type=Path))
|
| 130 |
+
@click.option("--summary", is_flag=True, default=False)
|
| 131 |
def cli(
|
| 132 |
api: str,
|
| 133 |
label_sets_raw: tuple[str, ...],
|
|
|
|
| 135 |
domain_top_n: int,
|
| 136 |
top_k: int,
|
| 137 |
out_dir: Path,
|
| 138 |
+
summary: bool,
|
| 139 |
) -> None:
|
| 140 |
label_sets = expand_label_sets(label_sets_raw)
|
| 141 |
cfg = Config(
|
|
|
|
| 145 |
domain_top_n=domain_top_n,
|
| 146 |
top_k=top_k,
|
| 147 |
out_dir=out_dir,
|
| 148 |
+
summary=summary,
|
| 149 |
)
|
| 150 |
run(cfg)
|
| 151 |
|
tests/test_eval_common.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from eval import common
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def test_percentile_basic():
|
| 7 |
+
values = [10, 20, 30, 40, 50]
|
| 8 |
+
assert common.percentile(values, 0.0) == 10
|
| 9 |
+
assert common.percentile(values, 0.5) == 30
|
| 10 |
+
assert common.percentile(values, 1.0) == 50
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def test_summarize_latency_empty():
|
| 14 |
+
summary = common.summarize_latency([])
|
| 15 |
+
assert summary["count"] == "0"
|
| 16 |
+
assert summary["avg_elapsed_ms"] == "0"
|