"""Model loading and single-image inference for the presence classifiers.

Exposes a registry of model configurations (``MODEL_CONFIGS``), a service
class that loads one configured model and classifies PIL images into
``CLASS_NAMES``, and a cached accessor (``get_model_service``).
"""

from __future__ import annotations

import os
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from torchvision import models

from protopnet import build_ppnet

# Checkpoint paths in MODEL_CONFIGS are resolved relative to this file.
BASE_DIR = Path(__file__).resolve().parent

# Output labels; the position in this list is the class index.
CLASS_NAMES = ["no_person", "person"]


@dataclass(frozen=True)
class ModelConfig:
    """Immutable description of one loadable model."""

    # Registry key; echoed back in prediction results as "model_name".
    name: str
    # Loader selector; PresenceModelService handles "resnet18" and "ppnet".
    backend: str
    # Location of the checkpoint file on disk.
    model_path: Path
    # Images are resized to (image_size, image_size) before inference.
    image_size: int
    # Per-channel mean / std handed to torchvision's Normalize.
    normalize_mean: tuple[float, float, float]
    normalize_std: tuple[float, float, float]


# NOTE(review): the normalisation constants look like the usual ImageNet
# (resnet18) and CIFAR-10 (ppnet) channel statistics -- confirm against the
# training code before changing them.
MODEL_CONFIGS: dict[str, ModelConfig] = {
    "resnet18_presence": ModelConfig(
        name="resnet18_presence",
        backend="resnet18",
        model_path=BASE_DIR / "best_global_model_presence.pt",
        image_size=224,
        normalize_mean=(0.485, 0.456, 0.406),
        normalize_std=(0.229, 0.224, 0.225),
    ),
    "ppnet_baseline": ModelConfig(
        name="ppnet_baseline",
        backend="ppnet",
        model_path=BASE_DIR / "baseline_40_model.pt.tar",
        image_size=128,
        normalize_mean=(0.4914, 0.4822, 0.4465),
        normalize_std=(0.2023, 0.1994, 0.2010),
    ),
}

# Model used when callers do not name one; overridable via SECUREML_MODEL.
DEFAULT_MODEL_NAME = os.getenv("SECUREML_MODEL", "ppnet_baseline")


def build_resnet18(num_classes: int = 2) -> nn.Module:
    """Return an uninitialised ResNet-18 with a ``num_classes``-way head.

    No pretrained weights are requested (``weights=None``); the caller is
    expected to load a state dict afterwards.
    """
    model = models.resnet18(weights=None)
    in_features = model.fc.in_features
    # Swap the stock classifier for one sized to our label set.
    model.fc = nn.Linear(in_features, num_classes)
    return model


def _normalize_prototype_shape(raw_value: Any) -> tuple[int, int, int, int]:
    """Coerce a checkpoint's ``prototype_shape`` into a 4-tuple of ints.

    Args:
        raw_value: A tuple or list with four integer-like entries.

    Returns:
        The shape as ``(int, int, int, int)``.

    Raises:
        ValueError: If ``raw_value`` is not a tuple/list, does not have
            exactly four entries, or holds entries that cannot be converted
            to ``int``.
    """
    if not isinstance(raw_value, (tuple, list)) or len(raw_value) != 4:
        raise ValueError(f"Unsupported prototype_shape value: {raw_value!r}")
    try:
        first, second, third, fourth = (int(dim) for dim in raw_value)
    except (TypeError, ValueError) as exc:
        raise ValueError(
            f"Unsupported prototype_shape value: {raw_value!r}"
        ) from exc
    return (first, second, third, fourth)


def get_model_config(name: str | None = None) -> ModelConfig:
    """Look up a registered model configuration.

    Args:
        name: Registry key. ``None`` or an empty string selects
            ``DEFAULT_MODEL_NAME``.

    Returns:
        The matching ``ModelConfig``.

    Raises:
        ValueError: If the resolved name is not in ``MODEL_CONFIGS``; the
            message lists the available names.
    """
    model_name = name or DEFAULT_MODEL_NAME
    try:
        return MODEL_CONFIGS[model_name]
    except KeyError as exc:
        available = ", ".join(sorted(MODEL_CONFIGS))
        raise ValueError(
            f"Unknown model '{model_name}'. Available: {available}"
        ) from exc


class PresenceModelService:
    """Loads one configured model and classifies single PIL images."""

    def __init__(self, config: ModelConfig):
        """Load the model described by ``config`` and prepare preprocessing.

        Args:
            config: Which checkpoint to load and how to preprocess inputs.

        Raises:
            FileNotFoundError: If ``config.model_path`` does not exist.
            ValueError: If the backend is unsupported or the checkpoint is
                malformed (see ``_load_model``).
        """
        if not config.model_path.exists():
            raise FileNotFoundError(f"Model not found: {config.model_path}")
        self.config = config
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        # Weights are read onto the CPU first, then moved to the target device.
        self.model = self._load_model().to(self.device)
        self.model.eval()
        self.transform = T.Compose(
            [
                T.Resize((config.image_size, config.image_size)),
                T.ToTensor(),
                T.Normalize(config.normalize_mean, config.normalize_std),
            ]
        )

    def _load_model(self) -> nn.Module:
        """Build the network for ``self.config.backend`` and load its weights.

        Returns:
            The model with weights loaded, still on the CPU.

        Raises:
            ValueError: If the backend is not "resnet18" or "ppnet", or the
                PPNet checkpoint has no usable ``state_dict``.
        """
        backend = self.config.backend
        if backend == "resnet18":
            return self._load_resnet18()
        if backend == "ppnet":
            return self._load_ppnet()
        raise ValueError(f"Unsupported backend: {self.config.backend}")

    def _load_resnet18(self) -> nn.Module:
        """Load a ResNet-18 whose checkpoint file is a bare state dict."""
        model = build_resnet18(num_classes=len(CLASS_NAMES))
        # NOTE(security): torch.load unpickles the file -- only load
        # checkpoints from trusted sources.
        state = torch.load(self.config.model_path, map_location="cpu")
        model.load_state_dict(state, strict=True)
        return model

    def _load_ppnet(self) -> nn.Module:
        """Load a PPNet from a checkpoint dict.

        The checkpoint must provide ``state_dict``; ``params_dict`` is
        optional and supplies the constructor arguments, with the defaults
        below used for any missing key.
        """
        # NOTE(security): torch.load unpickles the file -- only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(self.config.model_path, map_location="cpu")
        # A checkpoint that is not a mapping cannot carry a state_dict; report
        # it with the same error instead of failing on ``.get``.
        state_dict = (
            checkpoint.get("state_dict")
            if isinstance(checkpoint, dict)
            else None
        )
        if not isinstance(state_dict, dict):
            raise ValueError("Invalid PPNet checkpoint: missing state_dict.")
        # Tolerate a params_dict that is absent or stored as None.
        params = checkpoint.get("params_dict") or {}
        model = build_ppnet(
            base_architecture=str(params.get("base_architecture", "vgg19")),
            img_size=int(params.get("img_size", self.config.image_size)),
            prototype_shape=_normalize_prototype_shape(
                params.get("prototype_shape", (40, 128, 1, 1))
            ),
            num_classes=int(params.get("num_classes", len(CLASS_NAMES))),
            prototype_activation_function=str(
                params.get("prototype_activation_function", "log")
            ),
            add_on_layers_type=str(params.get("add_on_layers_type", "regular")),
        )
        model.load_state_dict(state_dict, strict=True)
        return model

    def predict_image(self, image: Image.Image) -> dict[str, Any]:
        """Classify one image.

        Args:
            image: A PIL image in any mode; it is converted to RGB when
                needed so the 3-channel normalisation always applies.

        Returns:
            A dict with keys ``label``, ``prediction_index``,
            ``probabilities`` (class name -> probability rounded to six
            decimals), ``model_name``, ``model_backend`` and ``model_path``
            (file name only).
        """
        # Grayscale / RGBA / palette inputs would not match the 3-channel
        # mean/std given to Normalize.
        if image.mode != "RGB":
            image = image.convert("RGB")
        batch = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            outputs = self.model(batch)
        # Some models return a tuple/list; the logits are taken to be its
        # first element (presumably the PPNet case -- confirm against
        # protopnet's forward()).
        logits = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
        probs = torch.softmax(logits, dim=-1)[0]
        pred_idx = int(torch.argmax(probs).item())
        # One host transfer for all probabilities instead of one per class.
        prob_values = probs.tolist()
        probabilities = {
            class_name: round(float(prob_values[index]), 6)
            for index, class_name in enumerate(CLASS_NAMES)
        }
        return {
            "label": CLASS_NAMES[pred_idx],
            "prediction_index": pred_idx,
            "probabilities": probabilities,
            "model_name": self.config.name,
            "model_backend": self.config.backend,
            "model_path": self.config.model_path.name,
        }


@lru_cache(maxsize=None)
def _build_model_service(model_name: str) -> PresenceModelService:
    """Construct (once per resolved name) the service for ``model_name``.

    The cache is unbounded but only successful constructions are stored, and
    those are limited to the names registered in ``MODEL_CONFIGS``.
    """
    return PresenceModelService(get_model_config(model_name))


def get_model_service(model_name: str | None = None) -> PresenceModelService:
    """Return the shared service instance for a model.

    The name is resolved to its default *before* the cache lookup, so
    ``get_model_service()``, ``get_model_service(None)`` and
    ``get_model_service(DEFAULT_MODEL_NAME)`` all share one loaded model
    instead of each loading their own copy.

    Args:
        model_name: Registry key; ``None`` or empty selects
            ``DEFAULT_MODEL_NAME``.

    Returns:
        The cached ``PresenceModelService`` for the resolved name.

    Raises:
        ValueError: If the resolved name is not registered.
        FileNotFoundError: If the model's checkpoint file is missing.
    """
    return _build_model_service(model_name or DEFAULT_MODEL_NAME)


# Keep the cache-management hooks that the lru_cache-decorated function
# exposed, so existing callers of cache_clear()/cache_info() keep working.
get_model_service.cache_clear = _build_model_service.cache_clear  # type: ignore[attr-defined]
get_model_service.cache_info = _build_model_service.cache_info  # type: ignore[attr-defined]