| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import Any |
|
|
| import joblib |
| import numpy as np |
| from PIL import Image |
|
|
| from .compare_models import load_best_model_record |
| from .data_discovery import CANONICAL_LABELS |
| from .preprocessing import load_pil_image |
| from .utils import get_logger |
|
|
|
|
# Module-level logger, obtained from the project's ``get_logger`` helper and
# keyed by this module's import name.
LOGGER = get_logger(__name__)
|
|
|
|
def model_record_from_file(path: str | Path) -> dict[str, Any]:
    """Describe a saved model file as a plain record dictionary.

    The file extension selects how the file is read:

    * ``.joblib`` -- loaded with ``joblib.load``; the record is built from the
      bundle's ``"metadata"`` entry and tagged ``"classical"``.
    * ``.pt`` -- loaded with ``load_torch_checkpoint`` (``map_location="cpu"``);
      the record is tagged ``"deep_learning"``.

    Args:
        path: Location of the model file.

    Returns:
        A dict that always contains ``model_name``, ``model_type`` and
        ``model_path`` (as a string), plus ``feature_type`` for classical
        models or ``model_key`` and ``family`` for deep-learning models.

    Raises:
        ValueError: If the suffix is neither ``.joblib`` nor ``.pt``.
    """
    model_file = Path(path)
    extension = model_file.suffix

    # Reject unknown formats before touching any loader.
    if extension not in (".joblib", ".pt"):
        raise ValueError(f"Unsupported model file: {model_file}")

    if extension == ".pt":
        # Imported lazily so the deep-learning stack is only needed for .pt files.
        from .dl_models import load_torch_checkpoint

        checkpoint = load_torch_checkpoint(model_file, map_location="cpu")
        # "model_key" is mandatory; it also serves as the fallback display name.
        model_key = checkpoint["model_key"]
        return {
            "model_name": checkpoint.get("model_name", model_key),
            "model_type": "deep_learning",
            "model_key": model_key,
            "family": checkpoint.get("family", "cnn"),
            "model_path": str(model_file),
        }

    # NOTE: joblib.load unpickles the file, so only trusted files should be
    # passed here.
    metadata = joblib.load(model_file)["metadata"]
    return {
        "model_name": metadata["model_name"],
        "model_type": "classical",
        "feature_type": metadata["feature_type"],
        "model_path": str(model_file),
    }
|
|
|
|
def list_available_model_records(config: dict[str, Any]) -> list[dict[str, Any]]:
    """Collect a record for every readable model file in the model directory.

    The directory is taken from ``config["paths"]["model_dir"]``. ``*.joblib``
    files are listed first (sorted), followed by ``*.pt`` files (sorted).

    Args:
        config: Project configuration containing ``paths.model_dir``.

    Returns:
        One record per file whose metadata could be read. Files that fail to
        load are skipped after a warning is logged.
    """
    root = Path(config["paths"]["model_dir"])
    candidates = [*sorted(root.glob("*.joblib")), *sorted(root.glob("*.pt"))]

    found: list[dict[str, Any]] = []
    for candidate in candidates:
        try:
            record = model_record_from_file(candidate)
        except Exception as exc:
            # Best effort: one unreadable file must not hide the others.
            LOGGER.warning("Could not load model metadata for %s: %s", candidate, exc)
            continue
        found.append(record)
    return found
|
|
|
|
class EggDamagePredictor:
    """Single prediction interface over a saved classical or deep-learning model.

    Attributes set on every instance: ``record``, ``config``, ``model_name``,
    ``model_type``, ``model_path``, ``class_names`` (a copy of
    ``CANONICAL_LABELS``), ``metadata``, ``pipeline``, ``feature_type``,
    ``model``, ``transform`` and ``device``. The ones that do not apply to the
    loaded model type are ``None``.
    """

    def __init__(self, record: dict[str, Any], config: dict[str, Any]) -> None:
        """Load the model described by ``record``.

        Args:
            record: Model description providing ``model_name``, ``model_type``
                (``"classical"`` or ``"deep_learning"``) and ``model_path``.
            config: Project configuration. It is used as a fallback wherever
                the saved artefact does not carry its own ``"config"`` entry.

        Raises:
            ValueError: If ``record["model_type"]`` is not a supported type.
        """
        self.record = record
        self.config = config
        self.model_name = record["model_name"]
        self.model_type = record["model_type"]
        self.model_path = Path(record["model_path"])
        self.class_names = list(CANONICAL_LABELS)
        self.device = None
        # Always defined so both model types expose the same attribute set.
        self.transform = None

        if self.model_type == "classical":
            # NOTE: joblib.load unpickles the file; only load trusted bundles.
            bundle = joblib.load(self.model_path)
            self.pipeline = bundle["pipeline"]
            self.metadata = bundle["metadata"]
            self.feature_type = self.metadata["feature_type"]
            self.model = None
        elif self.model_type == "deep_learning":
            # Deep-learning dependencies are imported only when needed.
            import torch

            from .augmentations import build_eval_transform
            from .dl_models import create_model, load_torch_checkpoint

            checkpoint = load_torch_checkpoint(self.model_path, map_location="cpu")
            self.metadata = checkpoint
            # Prefer the configuration stored in the checkpoint, if any.
            model_config = checkpoint.get("config", config)
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.model = create_model(checkpoint["model_key"], model_config, pretrained=False)
            self.model.load_state_dict(checkpoint["state_dict"])
            self.model.to(self.device)
            self.model.eval()
            self.transform = build_eval_transform(model_config)
            self.pipeline = None
            self.feature_type = None
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")

    def predict_proba(self, image: str | Path | Image.Image | np.ndarray) -> np.ndarray:
        """Return the class probabilities for a single image.

        Args:
            image: A path, a PIL image, or a NumPy array. Arrays are converted
                with ``Image.fromarray`` before being passed to
                ``load_pil_image`` with ``mode="RGB"``.

        Returns:
            The probability vector for this one image (the first row of the
            model's batched output).
        """
        source = Image.fromarray(image) if isinstance(image, np.ndarray) else image
        pil = load_pil_image(source, mode="RGB")

        if self.model_type == "classical":
            from .classical_features import extract_single_feature

            feature_config = self.metadata.get("config", self.config)
            feature = extract_single_feature(pil, self.feature_type, feature_config)
            # The pipeline expects a 2-D batch; feed one row and take it back out.
            return self.pipeline.predict_proba(feature.reshape(1, -1))[0]

        import torch

        assert self.model is not None and self.device is not None
        tensor = self.transform(pil).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(tensor)
            probs = torch.softmax(logits, dim=1).detach().cpu().numpy()[0]
        return probs

    def _low_confidence_threshold(self) -> float:
        """Return the configured low-confidence threshold.

        Reads ``config["gradio"]["low_confidence_threshold"]`` and falls back
        to ``0.65`` when the key, or the whole ``gradio`` section, is missing
        or empty.
        """
        gradio_section = self.config.get("gradio") or {}
        return float(gradio_section.get("low_confidence_threshold", 0.65))

    def predict(self, image: str | Path | Image.Image | np.ndarray) -> dict[str, Any]:
        """Classify one image and return a summary of the result.

        Args:
            image: Anything accepted by :meth:`predict_proba`.

        Returns:
            A dict with ``model_name``, ``model_type``, ``predicted_label``,
            ``predicted_index``, ``confidence`` (probability of the predicted
            class), ``probabilities`` (label -> probability), ``prob_damaged``
            and ``low_confidence`` (``True`` when ``confidence`` is below the
            configured threshold).
        """
        probs = self.predict_proba(image)
        pred_idx = int(np.argmax(probs))
        confidence = float(probs[pred_idx])
        return {
            "model_name": self.model_name,
            "model_type": self.model_type,
            "predicted_label": self.class_names[pred_idx],
            "predicted_index": pred_idx,
            "confidence": confidence,
            "probabilities": {name: float(probs[i]) for i, name in enumerate(self.class_names)},
            # NOTE(review): assumes index 1 of the label order is the
            # "damaged" class -- confirm against CANONICAL_LABELS.
            "prob_damaged": float(probs[1]),
            "low_confidence": confidence < self._low_confidence_threshold(),
        }
|
|
|
|
def load_best_predictor(config: dict[str, Any]) -> EggDamagePredictor:
    """Build a predictor for the record chosen by ``load_best_model_record``.

    Args:
        config: Project configuration. It is passed both to
            ``load_best_model_record`` and to the predictor itself.

    Returns:
        An ``EggDamagePredictor`` wrapping the selected model record.
    """
    best_record = load_best_model_record(config)
    predictor = EggDamagePredictor(best_record, config)
    return predictor
|
|