Spaces:
Sleeping
Sleeping
| """ | |
| Pet-ID Test Framework | |
| Automatisiertes Testing mit Metriken, Threshold-Sweep und Per-Model-Analyse. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import shutil | |
| import tempfile | |
| import zipfile | |
| from dataclasses import dataclass, field, asdict | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Callable, Optional | |
| from logic import PetIdentifier, SIMILARITY_THRESHOLD, ENSEMBLE_WEIGHTS | |
# Module-wide logger; every function in this file logs through the "pet_id" logger.
logger = logging.getLogger("pet_id")
# Folder next to this file that holds one sub-folder per pet; result JSON files
# are written to and listed from this same folder.
TEST_DATA_DIR = Path(__file__).parent / "test_data"
# File suffixes treated as images during discovery (compared against the
# lower-cased suffix, so matching is case-insensitive).
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
| # --- Dataclasses --- | |
@dataclass
class ImageResult:
    """Outcome of running identification on a single test image.

    Must be a dataclass: instances are built with keyword arguments and
    serialised with ``dataclasses.asdict``.
    """

    path: str
    pet_name: str  # ground-truth pet name (taken from the folder name)
    matched_name: Optional[str]  # recognised name, or None when nothing matched
    matched_id: Optional[str]
    ensemble_score: float
    mega_score: float
    dino_score: float
    classification: str  # "TP", "FN", "FP"
@dataclass
class PetTestResult:
    """Per-pet tally of classification outcomes.

    Must be a dataclass so that ``field(default_factory=list)`` gives every
    instance its own ``scores`` list and ``dataclasses.asdict`` works on it.
    Instances are mutated in place while results are aggregated.
    """

    name: str
    total: int  # number of test images seen for this pet
    tp: int
    fn: int
    fp: int
    scores: list[float] = field(default_factory=list)  # ensemble score per test image
@dataclass
class ThresholdPoint:
    """Metrics measured at one threshold value of the sweep.

    Must be a dataclass: instances are built with keyword arguments and
    serialised with ``dataclasses.asdict``.
    """

    threshold: float
    tp: int
    fn: int
    fp: int
    precision: float
    recall: float
    f1: float
@dataclass
class ModelAnalysis:
    """Score statistics for one model over the TP and FN images.

    Must be a dataclass: instances are built with keyword arguments and
    serialised with ``dataclasses.asdict``.
    """

    model_key: str  # "mega", "dino", "ensemble"
    avg_tp_score: float
    avg_fn_score: float
    min_tp_score: float
    max_fn_score: float
@dataclass
class TestRunResult:
    """Aggregated outcome of one complete test run.

    Must be a dataclass: it is built with keyword arguments, relies on
    ``field(default_factory=list)`` defaults and is serialised with
    ``dataclasses.asdict`` before being written as JSON.
    """

    # The class name starts with "Test"; tell pytest it is not a test class so
    # importing it into a test module does not trigger collection warnings.
    # (Un-annotated, so it is a plain class attribute, not a dataclass field.)
    __test__ = False

    timestamp: str
    threshold: float  # similarity threshold that was active during the run
    num_registration: int  # images per pet used for registration
    total_images: int
    total_test_images: int
    tp: int
    fn: int
    fp: int
    precision: float
    recall: float
    f1: float
    # The list fields hold already-serialised (dict) versions of the
    # per-pet / per-image / sweep / model records.
    pet_results: list[dict] = field(default_factory=list)
    image_results: list[dict] = field(default_factory=list)
    threshold_sweep: list[dict] = field(default_factory=list)
    model_analysis: list[dict] = field(default_factory=list)
    optimal_threshold: float = 0.0
    optimal_f1: float = 0.0
| # --- Discovery --- | |
def discover_test_pets() -> dict[str, list[Path]]:
    """Scan ``TEST_DATA_DIR`` for pet folders and collect their images.

    Every direct sub-folder of ``TEST_DATA_DIR`` is treated as one pet, the
    folder name being the pet name. ``*.zip`` files lying directly in a pet
    folder are extracted into that folder first (on every call), then images
    are collected recursively.

    Returns:
        ``{pet_name: [image_paths]}`` with paths sorted; pets without any
        image are omitted. An empty dict if ``TEST_DATA_DIR`` does not exist.
    """
    if not TEST_DATA_DIR.exists():
        return {}

    pets: dict[str, list[Path]] = {}
    for pet_dir in sorted(TEST_DATA_DIR.iterdir()):
        if not pet_dir.is_dir():
            continue
        name = pet_dir.name

        # Unpack archives automatically. A corrupt archive is skipped with a
        # warning so that one bad file does not abort the whole discovery.
        for zf in pet_dir.glob("*.zip"):
            logger.info("Entpacke %s ...", zf.name)
            try:
                with zipfile.ZipFile(zf, "r") as z:
                    z.extractall(pet_dir)
                    # Count the archive's file members rather than guessing
                    # from the directory listing.
                    extracted = sum(1 for info in z.infolist() if not info.is_dir())
            except zipfile.BadZipFile:
                logger.warning("Ungueltiges Zip uebersprungen: %s", zf)
                continue
            logger.info("Zip entpackt: %d Dateien", extracted)

        # Collect images recursively, since archives may contain sub-folders.
        # The __MACOSX check looks at path components below TEST_DATA_DIR only,
        # so the location of the project itself can never exclude images.
        images = sorted(
            p
            for p in pet_dir.rglob("*")
            if p.is_file()
            and p.suffix.lower() in IMAGE_EXTENSIONS
            and not p.name.startswith(".")
            and "__MACOSX" not in p.relative_to(TEST_DATA_DIR).parts
        )
        if images:
            pets[name] = images
            logger.info("Test-Daten: %s -> %d Bilder", name, len(images))
    return pets
| # --- Test-Kernfunktionen --- | |
def _test_single_image(
    identifier: PetIdentifier,
    image_path: Path,
    ground_truth_name: str,
    registered_names: set[str],
) -> ImageResult:
    """Identify one image and label the outcome as TP, FP or FN.

    Args:
        identifier: object whose ``identify(path)`` returns
            ``(pet_id, name, score, details)``.
        image_path: image to identify.
        ground_truth_name: expected pet name; compared case-insensitively.
        registered_names: accepted for the caller's convenience; not read here.

    Returns:
        An ``ImageResult`` carrying the match, the ensemble score returned by
        ``identify`` and the per-model scores taken from ``details["scores"]``.
    """
    pet_id, name, score, details = identifier.identify(str(image_path))
    expected = ground_truth_name.lower()

    # Per-model scores: prefer the first entry that belongs to the expected
    # pet or to the pet that was actually matched.
    mega_score = dino_score = 0.0
    candidates = details.get("scores")
    if candidates:
        hit = next(
            (
                c
                for c in candidates
                if c["name"].lower() == expected or (pet_id and c["id"] == pet_id)
            ),
            None,
        )
        if hit is not None:
            mega_score = hit["mega_score"]
            dino_score = hit["dino_score"]
        # Nothing specific found (both still zero): use the entry with the
        # highest ensemble score instead.
        if mega_score == 0.0 and dino_score == 0.0:
            top = max(candidates, key=lambda c: c["ensemble_score"])
            mega_score = top["mega_score"]
            dino_score = top["dino_score"]

    # No match at all -> FN; right name -> TP; any other name -> FP.
    if not (pet_id and name):
        classification = "FN"
    elif name.lower() == expected:
        classification = "TP"
    else:
        classification = "FP"

    return ImageResult(
        path=str(image_path),
        pet_name=ground_truth_name,
        matched_name=name,
        matched_id=pet_id,
        ensemble_score=score,
        mega_score=mega_score,
        dino_score=dino_score,
        classification=classification,
    )
def _threshold_sweep(
    image_results: list[ImageResult],
    start: float = 0.20,
    end: float = 0.80,
    step: float = 0.01,
) -> tuple[list[ThresholdPoint], float, float]:
    """Re-evaluate the recorded results at a range of thresholds.

    For each threshold an image counts as TP when its ensemble score reaches
    the threshold and the matched name equals the ground truth
    (case-insensitive), as FP when the score reaches the threshold but the
    name differs, and as FN otherwise. Note that an image without a matched
    name is therefore counted as FP once its score clears the threshold.

    Args:
        image_results: per-image results of a test run.
        start: first threshold (inclusive).
        end: last threshold (inclusive, with a 1e-9 tolerance).
        step: distance between thresholds; must be positive.

    Returns:
        ``(points, optimal_threshold, optimal_f1)`` where the optimum is the
        first threshold with the highest F1. With no F1 above zero the optimum
        is ``(start, 0.0)``.

    Raises:
        ValueError: if ``step`` is not positive (the sweep would never end).
    """
    if step <= 0:
        raise ValueError("step muss groesser als 0 sein")

    # Whether the matched name is correct does not depend on the threshold,
    # so compute it once per image instead of once per threshold.
    scored = [
        (r.ensemble_score, r.pet_name.lower() == (r.matched_name or "").lower())
        for r in image_results
    ]

    # Derive each threshold from an integer index instead of accumulating
    # `thresh += step`, which drifts (e.g. 0.30000000000000004) and then
    # misjudges scores lying exactly on a threshold.
    span = end - start + 1e-9
    last_index = int(span / step) if span >= 0 else -1

    points = []
    best_f1 = 0.0
    best_thresh = start
    for index in range(last_index + 1):
        # Rounding to 10 places removes the binary representation noise.
        thresh = round(start + index * step, 10)

        tp = fn = fp = 0
        for score, is_correct in scored:
            if score < thresh:
                fn += 1
            elif is_correct:
                tp += 1
            else:
                fp += 1

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

        points.append(ThresholdPoint(
            threshold=round(thresh, 4),
            tp=tp, fn=fn, fp=fp,
            precision=round(precision, 4),
            recall=round(recall, 4),
            f1=round(f1, 4),
        ))
        # Strict '>' keeps the lowest threshold among equally good ones.
        if f1 > best_f1:
            best_f1 = f1
            best_thresh = round(thresh, 4)

    return points, best_thresh, best_f1
def _per_model_analysis(image_results: list[ImageResult]) -> list[ModelAnalysis]:
    """Summarise the scores of each model separately.

    For "mega", "dino" and "ensemble" (in that order) this reports the mean
    and minimum score over TP images and the mean and maximum score over FN
    images, each rounded to 4 places. A statistic over an empty group is 0.0.
    """
    def _mean(values):
        return round(sum(values) / len(values), 4) if values else 0.0

    hits = [r for r in image_results if r.classification == "TP"]
    misses = [r for r in image_results if r.classification == "FN"]

    report = []
    for model_key in ("mega", "dino", "ensemble"):
        attr = model_key + "_score"
        hit_scores = [getattr(r, attr) for r in hits]
        miss_scores = [getattr(r, attr) for r in misses]

        lowest_hit = round(min(hit_scores), 4) if hit_scores else 0.0
        highest_miss = round(max(miss_scores), 4) if miss_scores else 0.0

        report.append(ModelAnalysis(
            model_key=model_key,
            avg_tp_score=_mean(hit_scores),
            avg_fn_score=_mean(miss_scores),
            min_tp_score=lowest_hit,
            max_fn_score=highest_miss,
        ))
    return report
def run_test(
    source_identifier: PetIdentifier,
    num_registration: int = 3,
    progress_callback: Optional[Callable[[float, str], None]] = None,
) -> TestRunResult:
    """Main entry point: run the complete test.

    - Creates a test identifier whose DB file and storage directory live in a
      fresh temporary directory (original author's note: the production DB is
      NOT touched; this relies on ``PetIdentifier.create_with_shared_models``)
    - Registers the first N images per pet and tests the rest
    - Computes all metrics including the threshold sweep

    Args:
        source_identifier: identifier whose models are loaded and shared with
            the temporary test identifier.
        num_registration: number of leading images per pet used for
            registration; pets with no images left over are skipped.
        progress_callback: optional ``callback(fraction, message)`` invoked on
            every progress step.

    Returns:
        The aggregated ``TestRunResult``; it is also written to disk via
        ``save_results`` before returning.

    Raises:
        ValueError: if no test data is found, or if no test images remain
            after registration.
    """
    # Load the test data
    test_pets = discover_test_pets()
    if not test_pets:
        raise ValueError("Keine Test-Daten gefunden in test_data/")
    # Temporary directory for the test DB
    tmp_dir = Path(tempfile.mkdtemp(prefix="petid_test_"))
    tmp_db = tmp_dir / "test_db.json"
    tmp_storage = tmp_dir / "storage"

    def _progress(frac: float, msg: str):
        # Forward to the optional callback and always mirror to the log.
        if progress_callback:
            progress_callback(frac, msg)
        logger.info("TEST [%d%%] %s", int(frac * 100), msg)

    try:
        # Make sure the models are loaded
        _progress(0.0, "Lade Modelle...")
        source_identifier.load_models()
        # Create the test identifier that shares the loaded models
        test_id = PetIdentifier.create_with_shared_models(
            source=source_identifier,
            db_path=tmp_db,
            storage_dir=tmp_storage,
        )
        # Phase 1: registration
        # NOTE: total_images counts every discovered image, including those of
        # pets that are skipped below.
        total_images = sum(len(imgs) for imgs in test_pets.values())
        registered_names = set()
        test_images = []  # (path, ground_truth_name)
        for pet_name, images in test_pets.items():
            reg_images = images[:num_registration]
            test_imgs = images[num_registration:]
            if not test_imgs:
                # Nothing would be left to test for this pet, so skip it.
                logger.warning("Tier '%s' hat nur %d Bilder, ueberspringe (brauche >%d)",
                               pet_name, len(images), num_registration)
                continue
            _progress(0.05, f"Registriere {pet_name} ({len(reg_images)} Bilder)...")
            reg_paths = [str(p) for p in reg_images]
            test_id.register(pet_name, reg_paths)
            registered_names.add(pet_name)
            for img in test_imgs:
                test_images.append((img, pet_name))
        if not test_images:
            raise ValueError("Keine Test-Bilder uebrig nach Registrierung")
        # Phase 2: test
        image_results = []
        total_test = len(test_images)
        for i, (img_path, ground_truth) in enumerate(test_images):
            # The test phase covers the 10%..90% range of the progress bar.
            frac = 0.10 + 0.80 * (i / total_test)
            _progress(frac, f"Teste {img_path.name} ({i + 1}/{total_test})")
            result = _test_single_image(test_id, img_path, ground_truth, registered_names)
            image_results.append(result)
        # Phase 3: metrics
        _progress(0.92, "Berechne Metriken...")
        tp = sum(1 for r in image_results if r.classification == "TP")
        fn = sum(1 for r in image_results if r.classification == "FN")
        fp = sum(1 for r in image_results if r.classification == "FP")
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        # Per-pet results
        pet_results = {}
        for r in image_results:
            if r.pet_name not in pet_results:
                pet_results[r.pet_name] = PetTestResult(
                    name=r.pet_name, total=0, tp=0, fn=0, fp=0, scores=[])
            pr = pet_results[r.pet_name]
            pr.total += 1
            pr.scores.append(r.ensemble_score)
            if r.classification == "TP":
                pr.tp += 1
            elif r.classification == "FN":
                pr.fn += 1
            elif r.classification == "FP":
                pr.fp += 1
        # Threshold sweep
        _progress(0.95, "Threshold-Sweep...")
        sweep_points, opt_thresh, opt_f1 = _threshold_sweep(image_results)
        # Per-model analysis
        model_analysis = _per_model_analysis(image_results)
        _progress(0.98, "Erstelle Ergebnis...")
        run_result = TestRunResult(
            timestamp=datetime.now().isoformat(),
            threshold=SIMILARITY_THRESHOLD,
            num_registration=num_registration,
            total_images=total_images,
            total_test_images=total_test,
            tp=tp, fn=fn, fp=fp,
            precision=round(precision, 4),
            recall=round(recall, 4),
            f1=round(f1, 4),
            pet_results=[asdict(pr) for pr in pet_results.values()],
            image_results=[asdict(r) for r in image_results],
            threshold_sweep=[asdict(p) for p in sweep_points],
            model_analysis=[asdict(a) for a in model_analysis],
            optimal_threshold=opt_thresh,
            optimal_f1=opt_f1,
        )
        # Save the result (note: 100% is reported before the file is written)
        _progress(1.0, "Fertig!")
        save_results(run_result)
        return run_result
    finally:
        # Clean up the temporary directory, on success and on failure alike
        shutil.rmtree(tmp_dir, ignore_errors=True)
| # --- Persistenz --- | |
def save_results(result: TestRunResult) -> Path:
    """Write a test result as JSON into ``TEST_DATA_DIR``.

    The file is named ``results_<YYYYmmdd_HHMMSS>.json`` after the current
    local time, so two saves within the same second target the same file.

    Args:
        result: the run result to serialise (via ``dataclasses.asdict``).

    Returns:
        Path of the written file.
    """
    # parents=True so a missing intermediate directory does not make saving fail.
    TEST_DATA_DIR.mkdir(parents=True, exist_ok=True)
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    path = TEST_DATA_DIR / f"results_{ts}.json"
    # ensure_ascii=False emits non-ASCII characters (e.g. in pet names)
    # verbatim, so the file must be opened as UTF-8 explicitly instead of
    # relying on the platform's locale encoding.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(asdict(result), f, indent=2, ensure_ascii=False)
    logger.info("Test-Ergebnisse gespeichert: %s", path)
    return path
def load_results(path: Path) -> dict:
    """Load a stored test result from a JSON file.

    Args:
        path: JSON file to read.

    Returns:
        The parsed JSON content (a dict for files holding a test result).
    """
    # Read as UTF-8 explicitly so that non-ASCII content is decoded the same
    # way on every platform, independent of the locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
def list_result_files() -> list[str]:
    """Return the names of all stored result files, in descending path order.

    Looks for ``results_*.json`` directly inside ``TEST_DATA_DIR``; yields an
    empty list when that directory does not exist.
    """
    if not TEST_DATA_DIR.exists():
        return []
    result_paths = list(TEST_DATA_DIR.glob("results_*.json"))
    # Sort the paths (not the bare names) in reverse order.
    result_paths.sort(reverse=True)
    return [entry.name for entry in result_paths]