File size: 4,929 Bytes
c91c838
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from __future__ import annotations

from pathlib import Path
from typing import Any

import joblib
import numpy as np
from PIL import Image

from .compare_models import load_best_model_record
from .data_discovery import CANONICAL_LABELS
from .preprocessing import load_pil_image
from .utils import get_logger


# Module-level logger; list_available_model_records uses it to report model
# files whose metadata could not be read.
LOGGER = get_logger(__name__)


def model_record_from_file(path: str | Path) -> dict[str, Any]:
    path = Path(path)
    if path.suffix == ".joblib":
        bundle = joblib.load(path)
        meta = bundle["metadata"]
        return {
            "model_name": meta["model_name"],
            "model_type": "classical",
            "feature_type": meta["feature_type"],
            "model_path": str(path),
        }
    if path.suffix == ".pt":
        from .dl_models import load_torch_checkpoint

        ckpt = load_torch_checkpoint(path, map_location="cpu")
        return {
            "model_name": ckpt.get("model_name", ckpt["model_key"]),
            "model_type": "deep_learning",
            "model_key": ckpt["model_key"],
            "family": ckpt.get("family", "cnn"),
            "model_path": str(path),
        }
    raise ValueError(f"Unsupported model file: {path}")


def list_available_model_records(config: dict[str, Any]) -> list[dict[str, Any]]:
    """Collect records for every model artifact in the configured model dir.

    Looks in ``config["paths"]["model_dir"]``. All ``*.joblib`` bundles come
    first, then all ``*.pt`` checkpoints, each group sorted by path.

    Files whose metadata cannot be read are skipped with a warning rather
    than aborting the listing.

    Args:
        config: Project configuration containing ``paths.model_dir``.

    Returns:
        The records of all readable model files, in the order described above.
    """
    model_dir = Path(config["paths"]["model_dir"])
    candidates = [
        candidate
        for pattern in ("*.joblib", "*.pt")
        for candidate in sorted(model_dir.glob(pattern))
    ]

    found: list[dict[str, Any]] = []
    for candidate in candidates:
        try:
            record = model_record_from_file(candidate)
        except Exception as exc:  # best-effort: one bad file must not hide the rest
            LOGGER.warning("Could not load model metadata for %s: %s", candidate, exc)
            continue
        found.append(record)
    return found


class EggDamagePredictor:
    """Single-image egg-damage classifier wrapping a saved model artifact.

    The back end is selected by ``record["model_type"]``:

    * ``"classical"`` -- a joblib bundle holding a fitted ``pipeline`` and
      ``metadata`` (including ``feature_type``). Images are converted to a
      feature vector with ``extract_single_feature`` before scoring.
    * ``"deep_learning"`` -- a torch checkpoint. Its ``state_dict`` is loaded
      into a freshly created model, placed on CUDA when available and on CPU
      otherwise. Images go through the evaluation transform built from the
      checkpoint's config (or the project config).

    Predictions are labelled using ``CANONICAL_LABELS`` order.
    """

    def __init__(self, record: dict[str, Any], config: dict[str, Any]) -> None:
        """Load the model described by *record*.

        Args:
            record: Model record with ``model_name``, ``model_type`` and
                ``model_path`` keys (the shape ``model_record_from_file``
                produces).
            config: Project configuration. It is the fallback when the
                artifact carries no ``"config"`` of its own, and the source
                of ``gradio.low_confidence_threshold``.

        Raises:
            ValueError: If ``record["model_type"]`` is not supported.
        """
        self.record = record
        self.config = config
        self.model_name = record["model_name"]
        self.model_type = record["model_type"]
        self.model_path = Path(record["model_path"])
        self.class_names = list(CANONICAL_LABELS)
        self.device = None
        if self.model_type == "classical":
            # NOTE: joblib.load unpickles the file -- only load trusted artifacts.
            bundle = joblib.load(self.model_path)
            self.pipeline = bundle["pipeline"]
            self.metadata = bundle["metadata"]
            self.feature_type = self.metadata["feature_type"]
            self.model = None
        elif self.model_type == "deep_learning":
            import torch

            from .augmentations import build_eval_transform
            from .dl_models import create_model, load_torch_checkpoint

            checkpoint = load_torch_checkpoint(self.model_path, map_location="cpu")
            self.metadata = checkpoint
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # pretrained=False: the weights come from the checkpoint's state_dict.
            self.model = create_model(checkpoint["model_key"], checkpoint.get("config", config), pretrained=False)
            self.model.load_state_dict(checkpoint["state_dict"])
            self.model.to(self.device)
            self.model.eval()
            self.transform = build_eval_transform(checkpoint.get("config", config))
            self.pipeline = None
            self.feature_type = None
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")

    def predict_proba(self, image: str | Path | Image.Image | np.ndarray) -> np.ndarray:
        """Return the class-probability vector for a single image.

        Args:
            image: A file path, a PIL image, or a numpy array. Arrays are
                converted with ``Image.fromarray``. The input is then
                normalized to an RGB PIL image.

        Returns:
            The first (and only) row of the model's probability output: one
            entry per class.

        Raises:
            RuntimeError: If a deep-learning predictor has no loaded model.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        pil = load_pil_image(image, mode="RGB")

        if self.model_type == "classical":
            from .classical_features import extract_single_feature

            feature = extract_single_feature(pil, self.feature_type, self.metadata.get("config", self.config))
            # Score a one-row 2-D batch and unwrap the single result row.
            return self.pipeline.predict_proba(feature.reshape(1, -1))[0]

        import torch

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if self.model is None or self.device is None:
            raise RuntimeError(f"Deep-learning model for {self.model_name!r} is not loaded")
        tensor = self.transform(pil).unsqueeze(0).to(self.device)  # add batch dim
        with torch.no_grad():
            logits = self.model(tensor)
            probs = torch.softmax(logits, dim=1).detach().cpu().numpy()[0]
        return probs

    def predict(self, image: str | Path | Image.Image | np.ndarray) -> dict[str, Any]:
        """Classify *image* and return a plain-Python summary of the result.

        Args:
            image: Anything accepted by :meth:`predict_proba`.

        Returns:
            A dict with these keys:

            * the model identity (``model_name``, ``model_type``);
            * the argmax label and index (``predicted_label``,
              ``predicted_index``);
            * its ``confidence``;
            * per-class ``probabilities`` keyed by label;
            * ``prob_damaged``;
            * ``low_confidence``, which is True when confidence is below
              ``config["gradio"]["low_confidence_threshold"]`` (default
              0.65).
        """
        probs = self.predict_proba(image)
        pred_idx = int(np.argmax(probs))
        confidence = float(probs[pred_idx])
        # A missing (or empty) "gradio" config section falls back to the
        # default threshold instead of raising KeyError on every prediction.
        gradio_cfg = self.config.get("gradio") or {}
        threshold = float(gradio_cfg.get("low_confidence_threshold", 0.65))
        return {
            "model_name": self.model_name,
            "model_type": self.model_type,
            "predicted_label": self.class_names[pred_idx],
            "predicted_index": pred_idx,
            "confidence": confidence,
            "probabilities": {name: float(probs[i]) for i, name in enumerate(self.class_names)},
            # NOTE(review): assumes the "damaged" class is index 1 of
            # CANONICAL_LABELS -- confirm against data_discovery.
            "prob_damaged": float(probs[1]),
            "low_confidence": confidence < threshold,
        }


def load_best_predictor(config: dict[str, Any]) -> EggDamagePredictor:
    """Instantiate a predictor for the best model record.

    The record is obtained from ``load_best_model_record(config)``.

    Args:
        config: Project configuration, passed to both the record lookup and
            the predictor.

    Returns:
        An :class:`EggDamagePredictor` ready for inference.
    """
    best_record = load_best_model_record(config)
    return EggDamagePredictor(best_record, config)