budijuarto committed on
Commit
c91c838
·
verified ·
1 Parent(s): e14b469

Upload src/egg_damage/inference.py

Browse files
Files changed (1) hide show
  1. src/egg_damage/inference.py +122 -0
src/egg_damage/inference.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Any
5
+
6
+ import joblib
7
+ import numpy as np
8
+ from PIL import Image
9
+
10
+ from .compare_models import load_best_model_record
11
+ from .data_discovery import CANONICAL_LABELS
12
+ from .preprocessing import load_pil_image
13
+ from .utils import get_logger
14
+
15
+
16
+ LOGGER = get_logger(__name__)
17
+
18
+
19
def model_record_from_file(path: str | Path) -> dict[str, Any]:
    """Build a model record by reading the metadata stored in a model file.

    Args:
        path: Location of a ``.joblib`` bundle or a ``.pt`` checkpoint.

    Returns:
        A dict with ``model_name``, ``model_type`` and ``model_path``; classical
        records also carry ``feature_type``, deep-learning records carry
        ``model_key`` and ``family``.

    Raises:
        ValueError: If the file suffix is neither ``.joblib`` nor ``.pt``.
    """
    model_file = Path(path)
    extension = model_file.suffix

    # Reject unknown formats up front so the two loaders below stay flat.
    if extension not in (".joblib", ".pt"):
        raise ValueError(f"Unsupported model file: {model_file}")

    if extension == ".joblib":
        # NOTE(review): joblib.load unpickles the file; only point this at trusted artifacts.
        metadata = joblib.load(model_file)["metadata"]
        record: dict[str, Any] = {}
        record["model_name"] = metadata["model_name"]
        record["model_type"] = "classical"
        record["feature_type"] = metadata["feature_type"]
        record["model_path"] = str(model_file)
        return record

    # Imported lazily so listing classical models does not import the deep-learning module.
    from .dl_models import load_torch_checkpoint

    checkpoint = load_torch_checkpoint(model_file, map_location="cpu")
    model_key = checkpoint["model_key"]
    return {
        "model_name": checkpoint.get("model_name", model_key),
        "model_type": "deep_learning",
        "model_key": model_key,
        "family": checkpoint.get("family", "cnn"),
        "model_path": str(model_file),
    }
42
+
43
+
44
def list_available_model_records(config: dict[str, Any]) -> list[dict[str, Any]]:
    """Collect a record for every readable model file in the model directory.

    Args:
        config: Project configuration; ``config["paths"]["model_dir"]`` is the
            directory that is scanned.

    Returns:
        Records for all ``*.joblib`` files (sorted) followed by all ``*.pt``
        files (sorted). Files whose metadata cannot be read are logged with a
        warning and left out.
    """
    root = Path(config["paths"]["model_dir"])
    candidates = [*sorted(root.glob("*.joblib")), *sorted(root.glob("*.pt"))]

    found: list[dict[str, Any]] = []
    for candidate in candidates:
        try:
            record = model_record_from_file(candidate)
        except Exception as exc:
            # Best effort: one unreadable file must not hide the remaining models.
            LOGGER.warning("Could not load model metadata for %s: %s", candidate, exc)
        else:
            found.append(record)
    return found
53
+
54
+
55
class EggDamagePredictor:
    """Single-image predictor backed by a saved classical or deep-learning model.

    The model is loaded once in ``__init__``; ``predict_proba`` returns the raw
    class probabilities and ``predict`` wraps them in a result dictionary.
    """

    def __init__(self, record: dict[str, Any], config: dict[str, Any]) -> None:
        """Load the model described by *record*.

        Args:
            record: Must provide ``model_name``, ``model_type`` (``"classical"``
                or ``"deep_learning"``) and ``model_path``.
            config: Project configuration, used whenever the saved artifact does
                not carry a ``config`` entry of its own.

        Raises:
            ValueError: If ``record["model_type"]`` is not a supported type.
        """
        self.record = record
        self.config = config
        self.model_name = record["model_name"]
        self.model_type = record["model_type"]
        self.model_path = Path(record["model_path"])
        self.class_names = list(CANONICAL_LABELS)
        self.device = None
        if self.model_type == "classical":
            # NOTE(review): joblib.load unpickles the file; only load trusted artifacts.
            bundle = joblib.load(self.model_path)
            self.pipeline = bundle["pipeline"]
            self.metadata = bundle["metadata"]
            self.feature_type = self.metadata["feature_type"]
            self.model = None
            # Defined on both branches so the attribute set does not depend on model type.
            self.transform = None
        elif self.model_type == "deep_learning":
            import torch

            from .augmentations import build_eval_transform
            from .dl_models import create_model, load_torch_checkpoint

            checkpoint = load_torch_checkpoint(self.model_path, map_location="cpu")
            self.metadata = checkpoint
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # Prefer the configuration stored with the checkpoint over the caller's.
            model_config = checkpoint.get("config", config)
            self.model = create_model(checkpoint["model_key"], model_config, pretrained=False)
            self.model.load_state_dict(checkpoint["state_dict"])
            self.model.to(self.device)
            self.model.eval()
            self.transform = build_eval_transform(model_config)
            self.pipeline = None
            self.feature_type = None
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")

    def predict_proba(self, image: str | Path | Image.Image | np.ndarray) -> np.ndarray:
        """Return the class probabilities for one image.

        Args:
            image: A path, a PIL image, or a NumPy array (converted with
                ``Image.fromarray``).

        Returns:
            The probability vector for the single input image.

        Raises:
            RuntimeError: If a non-classical predictor has no model or device.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        pil = load_pil_image(image, mode="RGB")

        if self.model_type == "classical":
            from .classical_features import extract_single_feature

            feature_config = self.metadata.get("config", self.config)
            feature = extract_single_feature(pil, self.feature_type, feature_config)
            return self.pipeline.predict_proba(feature.reshape(1, -1))[0]

        import torch

        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if self.model is None or self.device is None:
            raise RuntimeError("Deep-learning predictor has no loaded model or device")
        tensor = self.transform(pil).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(tensor)
        probs = torch.softmax(logits, dim=1).detach().cpu().numpy()[0]
        return probs

    def predict(self, image: str | Path | Image.Image | np.ndarray) -> dict[str, Any]:
        """Classify one image and return the label with supporting details.

        Args:
            image: Anything accepted by ``predict_proba``.

        Returns:
            A dict with ``model_name``, ``model_type``, ``predicted_label``,
            ``predicted_index``, ``confidence``, per-class ``probabilities``,
            ``prob_damaged`` and a ``low_confidence`` flag.
        """
        probs = self.predict_proba(image)
        pred_idx = int(np.argmax(probs))
        confidence = float(probs[pred_idx])

        # The predictor may be used without a "gradio" config section (or with
        # an empty one); fall back to the 0.65 default instead of raising KeyError.
        gradio_config = self.config.get("gradio") or {}
        threshold = float(gradio_config.get("low_confidence_threshold", 0.65))

        return {
            "model_name": self.model_name,
            "model_type": self.model_type,
            "predicted_label": self.class_names[pred_idx],
            "predicted_index": pred_idx,
            "confidence": confidence,
            "probabilities": {name: float(probs[i]) for i, name in enumerate(self.class_names)},
            # NOTE(review): assumes index 1 is the "damaged" class - confirm against CANONICAL_LABELS.
            "prob_damaged": float(probs[1]),
            "low_confidence": confidence < threshold,
        }
119
+
120
+
121
def load_best_predictor(config: dict[str, Any]) -> EggDamagePredictor:
    """Create a predictor for the record returned by ``load_best_model_record``.

    Args:
        config: Project configuration, passed both to the record lookup and to
            the predictor.

    Returns:
        An ``EggDamagePredictor`` built from that record.
    """
    best_record = load_best_model_record(config)
    predictor = EggDamagePredictor(best_record, config)
    return predictor