# Source: src/egg_damage/inference.py, uploaded by budijuarto (commit c91c838, verified).
from __future__ import annotations
from pathlib import Path
from typing import Any
import joblib
import numpy as np
from PIL import Image
from .compare_models import load_best_model_record
from .data_discovery import CANONICAL_LABELS
from .preprocessing import load_pil_image
from .utils import get_logger
# Module-level logger obtained from the project's ``get_logger`` helper, named after this module.
LOGGER = get_logger(__name__)
def model_record_from_file(path: str | Path) -> dict[str, Any]:
    """Build a lightweight model record by inspecting a saved model file.

    The file suffix decides how the file is read (matching is exact and case-sensitive):

    * ``.joblib`` -- a classical bundle; ``model_name`` and ``feature_type`` come from
      its ``"metadata"`` entry.
    * ``.pt`` -- a torch checkpoint; ``model_key`` is required, while ``model_name``
      (defaults to ``model_key``) and ``family`` (defaults to ``"cnn"``) are optional.

    Args:
        path: Location of the model file.

    Returns:
        A dict with at least ``model_name``, ``model_type`` and ``model_path``.

    Raises:
        ValueError: If the suffix is neither ``.joblib`` nor ``.pt``.
    """
    model_file = Path(path)
    suffix = model_file.suffix

    if suffix == ".joblib":
        # NOTE: joblib.load unpickles the file -- only point this at trusted model files.
        metadata = joblib.load(model_file)["metadata"]
        return {
            "model_name": metadata["model_name"],
            "model_type": "classical",
            "feature_type": metadata["feature_type"],
            "model_path": str(model_file),
        }

    if suffix != ".pt":
        raise ValueError(f"Unsupported model file: {model_file}")

    # Imported lazily, presumably so the deep-learning stack is only needed for .pt files.
    from .dl_models import load_torch_checkpoint

    checkpoint = load_torch_checkpoint(model_file, map_location="cpu")
    model_key = checkpoint["model_key"]
    return {
        "model_name": checkpoint.get("model_name", model_key),
        "model_type": "deep_learning",
        "model_key": model_key,
        "family": checkpoint.get("family", "cnn"),
        "model_path": str(model_file),
    }
def list_available_model_records(config: dict[str, Any]) -> list[dict[str, Any]]:
    """Collect a record for every saved model in ``config["paths"]["model_dir"]``.

    All ``*.joblib`` bundles come first, then all ``*.pt`` checkpoints, each group
    sorted by path. A file whose metadata cannot be read is logged as a warning and
    skipped, so one broken file does not hide the remaining models.

    Args:
        config: Project configuration containing ``paths.model_dir``.

    Returns:
        The records produced by ``model_record_from_file``, in the order described above.
    """
    model_dir = Path(config["paths"]["model_dir"])
    candidates = [*sorted(model_dir.glob("*.joblib")), *sorted(model_dir.glob("*.pt"))]

    found: list[dict[str, Any]] = []
    for candidate in candidates:
        try:
            record = model_record_from_file(candidate)
        except Exception as exc:  # deliberate best effort: log and move on
            LOGGER.warning("Could not load model metadata for %s: %s", candidate, exc)
            continue
        found.append(record)
    return found
class EggDamagePredictor:
    """Wrap a trained egg-damage classifier (classical or deep learning) for inference.

    The model is loaded eagerly from ``record["model_path"]`` when the object is built.
    ``record["model_type"]`` selects the backend:

    * ``"classical"`` -- a joblib bundle holding a ``"pipeline"`` (anything exposing
      ``predict_proba``) and a ``"metadata"`` dict. Images are converted to features
      with ``extract_single_feature`` before prediction.
    * ``"deep_learning"`` -- a torch checkpoint rebuilt with ``create_model`` and run on
      CUDA when available, otherwise on CPU.

    Args:
        record: Model record, e.g. from ``model_record_from_file``; must contain
            ``"model_name"``, ``"model_type"`` and ``"model_path"``.
        config: Project configuration. It is used as a fallback when the saved model
            carries no ``"config"`` of its own. It is also read for the optional
            ``config["gradio"]["low_confidence_threshold"]`` (default 0.65).

    Raises:
        ValueError: If ``record["model_type"]`` is not a supported backend.
    """

    def __init__(self, record: dict[str, Any], config: dict[str, Any]) -> None:
        self.record = record
        self.config = config
        self.model_name = record["model_name"]
        self.model_type = record["model_type"]
        self.model_path = Path(record["model_path"])
        # Output columns are assumed to follow CANONICAL_LABELS order -- TODO confirm
        # this matches the label encoding used at training time.
        self.class_names = list(CANONICAL_LABELS)
        self.device = None
        if self.model_type == "classical":
            self._load_classical()
        elif self.model_type == "deep_learning":
            self._load_deep_learning(config)
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")

    def _load_classical(self) -> None:
        """Load the fitted pipeline and its metadata from a joblib bundle."""
        # NOTE(security): joblib.load unpickles the file -- only load trusted model files.
        bundle = joblib.load(self.model_path)
        self.pipeline = bundle["pipeline"]
        self.metadata = bundle["metadata"]
        self.feature_type = self.metadata["feature_type"]
        self.model = None

    def _load_deep_learning(self, config: dict[str, Any]) -> None:
        """Rebuild the torch network from its checkpoint and put it in eval mode."""
        import torch

        from .augmentations import build_eval_transform
        from .dl_models import create_model, load_torch_checkpoint

        checkpoint = load_torch_checkpoint(self.model_path, map_location="cpu")
        # Prefer the config stored in the checkpoint (fall back to the caller's). Both the
        # architecture and the eval transform are built from this same config.
        model_config = checkpoint.get("config", config)
        self.metadata = checkpoint
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = create_model(checkpoint["model_key"], model_config, pretrained=False)
        self.model.load_state_dict(checkpoint["state_dict"])
        self.model.to(self.device)
        self.model.eval()
        self.transform = build_eval_transform(model_config)
        self.pipeline = None
        self.feature_type = None

    def predict_proba(self, image: str | Path | Image.Image | np.ndarray) -> np.ndarray:
        """Return the per-class probability vector for a single image.

        Args:
            image: A file path, a PIL image, or a NumPy array (converted with
                ``Image.fromarray``). It is loaded as RGB before inference.

        Returns:
            A 1-D array with one probability per model output class.

        Raises:
            RuntimeError: If a deep-learning predictor has no loaded model or device.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        pil = load_pil_image(image, mode="RGB")

        if self.model_type == "classical":
            from .classical_features import extract_single_feature

            feature = extract_single_feature(pil, self.feature_type, self.metadata.get("config", self.config))
            return self.pipeline.predict_proba(feature.reshape(1, -1))[0]

        import torch

        # Explicit check rather than ``assert``: assertions are stripped under ``python -O``.
        if self.model is None or self.device is None:
            raise RuntimeError(f"Deep-learning model {self.model_name!r} is not loaded")
        tensor = self.transform(pil).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(tensor)
        return torch.softmax(logits, dim=1).detach().cpu().numpy()[0]

    def predict(self, image: str | Path | Image.Image | np.ndarray) -> dict[str, Any]:
        """Classify one image and return a JSON-friendly summary.

        Args:
            image: Anything accepted by ``predict_proba``.

        Returns:
            A dict containing:

            * ``model_name`` and ``model_type``.
            * ``predicted_label`` and ``predicted_index``.
            * ``confidence``: the top-class probability.
            * ``probabilities``: a label -> probability mapping.
            * ``prob_damaged``.
            * ``low_confidence``: True when ``confidence`` is below
              ``config["gradio"]["low_confidence_threshold"]``, default 0.65.

        Raises:
            ValueError: If the model returns a different number of probabilities than
                there are class names.
        """
        probs = self.predict_proba(image)
        if len(probs) != len(self.class_names):
            raise ValueError(
                f"Model {self.model_name!r} returned {len(probs)} probabilities "
                f"but {len(self.class_names)} class names are defined"
            )
        pred_idx = int(np.argmax(probs))
        confidence = float(probs[pred_idx])
        # A missing "gradio" section falls back to the default instead of raising KeyError.
        gradio_cfg = self.config.get("gradio") or {}
        threshold = float(gradio_cfg.get("low_confidence_threshold", 0.65))
        return {
            "model_name": self.model_name,
            "model_type": self.model_type,
            "predicted_label": self.class_names[pred_idx],
            "predicted_index": pred_idx,
            "confidence": confidence,
            "probabilities": {name: float(p) for name, p in zip(self.class_names, probs)},
            # NOTE(review): assumes index 1 of CANONICAL_LABELS is the "damaged" class -- confirm.
            "prob_damaged": float(probs[1]),
            "low_confidence": confidence < threshold,
        }
def load_best_predictor(config: dict[str, Any]) -> EggDamagePredictor:
    """Build an ``EggDamagePredictor`` for the record chosen by ``load_best_model_record``.

    Args:
        config: Project configuration, passed both to the record lookup and to the predictor.

    Returns:
        A ready-to-use predictor for the selected model.
    """
    best_record = load_best_model_record(config)
    return EggDamagePredictor(best_record, config)