from __future__ import annotations from pathlib import Path from typing import Any import joblib import numpy as np from PIL import Image from .compare_models import load_best_model_record from .data_discovery import CANONICAL_LABELS from .preprocessing import load_pil_image from .utils import get_logger LOGGER = get_logger(__name__) def model_record_from_file(path: str | Path) -> dict[str, Any]: path = Path(path) if path.suffix == ".joblib": bundle = joblib.load(path) meta = bundle["metadata"] return { "model_name": meta["model_name"], "model_type": "classical", "feature_type": meta["feature_type"], "model_path": str(path), } if path.suffix == ".pt": from .dl_models import load_torch_checkpoint ckpt = load_torch_checkpoint(path, map_location="cpu") return { "model_name": ckpt.get("model_name", ckpt["model_key"]), "model_type": "deep_learning", "model_key": ckpt["model_key"], "family": ckpt.get("family", "cnn"), "model_path": str(path), } raise ValueError(f"Unsupported model file: {path}") def list_available_model_records(config: dict[str, Any]) -> list[dict[str, Any]]: model_dir = Path(config["paths"]["model_dir"]) records: list[dict[str, Any]] = [] for path in sorted(model_dir.glob("*.joblib")) + sorted(model_dir.glob("*.pt")): try: records.append(model_record_from_file(path)) except Exception as exc: LOGGER.warning("Could not load model metadata for %s: %s", path, exc) return records class EggDamagePredictor: def __init__(self, record: dict[str, Any], config: dict[str, Any]) -> None: self.record = record self.config = config self.model_name = record["model_name"] self.model_type = record["model_type"] self.model_path = Path(record["model_path"]) self.class_names = list(CANONICAL_LABELS) self.device = None if self.model_type == "classical": bundle = joblib.load(self.model_path) self.pipeline = bundle["pipeline"] self.metadata = bundle["metadata"] self.feature_type = self.metadata["feature_type"] self.model = None elif self.model_type == "deep_learning": import torch from .augmentations import build_eval_transform from .dl_models import create_model, load_torch_checkpoint checkpoint = load_torch_checkpoint(self.model_path, map_location="cpu") self.metadata = checkpoint self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = create_model(checkpoint["model_key"], checkpoint.get("config", config), pretrained=False) self.model.load_state_dict(checkpoint["state_dict"]) self.model.to(self.device) self.model.eval() self.transform = build_eval_transform(checkpoint.get("config", config)) self.pipeline = None self.feature_type = None else: raise ValueError(f"Unsupported model type: {self.model_type}") def predict_proba(self, image: str | Path | Image.Image | np.ndarray) -> np.ndarray: pil = load_pil_image(Image.fromarray(image) if isinstance(image, np.ndarray) else image, mode="RGB") if self.model_type == "classical": from .classical_features import extract_single_feature feature = extract_single_feature(pil, self.feature_type, self.metadata.get("config", self.config)) return self.pipeline.predict_proba(feature.reshape(1, -1))[0] import torch assert self.model is not None and self.device is not None tensor = self.transform(pil).unsqueeze(0).to(self.device) with torch.no_grad(): logits = self.model(tensor) probs = torch.softmax(logits, dim=1).detach().cpu().numpy()[0] return probs def predict(self, image: str | Path | Image.Image | np.ndarray) -> dict[str, Any]: probs = self.predict_proba(image) pred_idx = int(np.argmax(probs)) confidence = float(probs[pred_idx]) return { "model_name": self.model_name, "model_type": self.model_type, "predicted_label": self.class_names[pred_idx], "predicted_index": pred_idx, "confidence": confidence, "probabilities": {self.class_names[i]: float(probs[i]) for i in range(len(self.class_names))}, "prob_damaged": float(probs[1]), "low_confidence": confidence < float(self.config["gradio"].get("low_confidence_threshold", 0.65)), } def load_best_predictor(config: dict[str, Any]) -> EggDamagePredictor: return EggDamagePredictor(load_best_model_record(config), config)