Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import os | |
| from dataclasses import dataclass | |
| from functools import lru_cache | |
| from pathlib import Path | |
| from typing import Any | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as T | |
| from PIL import Image | |
| from torchvision import models | |
| from protopnet import build_ppnet | |
# Absolute directory of this module (symlinks resolved); checkpoint paths in
# the model registry are built relative to it.
BASE_DIR = Path(__file__).resolve().parent

# Output labels indexed by logit position: predictions map the argmax index
# straight into this list, so the order must match the trained models.
CLASS_NAMES = [
    "no_person",
    "person",
]
@dataclass(frozen=True)
class ModelConfig:
    """Static description of one deployable model and its input preprocessing.

    Declared as a dataclass so that instances can be built with keyword
    arguments (``ModelConfig(name=..., backend=..., ...)``). Frozen because
    instances are shared through a module-level registry, where an accidental
    mutation would silently affect every later lookup.
    """

    name: str  # registry key; echoed back in predictions as "model_name"
    backend: str  # architecture family used to rebuild the model, e.g. "resnet18" or "ppnet"
    model_path: Path  # checkpoint file to load
    image_size: int  # inputs are resized to image_size x image_size
    normalize_mean: tuple[float, float, float]  # per-channel mean for normalization
    normalize_std: tuple[float, float, float]  # per-channel std for normalization
# Registry of selectable models, keyed by each config's own ``name`` so the
# key and the name reported in predictions cannot drift apart.
MODEL_CONFIGS: dict[str, ModelConfig] = {
    cfg.name: cfg
    for cfg in (
        ModelConfig(
            name="resnet18_presence",
            backend="resnet18",
            model_path=BASE_DIR / "best_global_model_presence.pt",
            image_size=224,
            # NOTE(review): these look like the standard ImageNet channel
            # statistics — confirm they match this model's training pipeline.
            normalize_mean=(0.485, 0.456, 0.406),
            normalize_std=(0.229, 0.224, 0.225),
        ),
        ModelConfig(
            name="ppnet_baseline",
            backend="ppnet",
            model_path=BASE_DIR / "baseline_40_model.pt.tar",
            image_size=128,
            # NOTE(review): these look like CIFAR-10 channel statistics —
            # confirm they match this model's training pipeline.
            normalize_mean=(0.4914, 0.4822, 0.4465),
            normalize_std=(0.2023, 0.1994, 0.2010),
        ),
    )
}

# Registry key used when no model name is supplied; the SECUREML_MODEL
# environment variable (read once, at import time) overrides the built-in default.
DEFAULT_MODEL_NAME = os.getenv("SECUREML_MODEL", "ppnet_baseline")
def build_resnet18(num_classes: int = 2) -> nn.Module:
    """Create a randomly initialised ResNet-18 with a ``num_classes``-way head.

    ``weights=None`` asks torchvision for an untrained network; the caller is
    expected to load a checkpoint afterwards. Only the final fully connected
    layer is replaced, keeping its input width.

    Args:
        num_classes: Number of output logits.

    Returns:
        The ResNet-18 module with the new classification layer.
    """
    backbone = models.resnet18(weights=None)
    backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
    return backbone
| def _normalize_prototype_shape(raw_value: Any) -> tuple[int, int, int, int]: | |
| if isinstance(raw_value, tuple): | |
| return raw_value | |
| if isinstance(raw_value, list): | |
| return tuple(raw_value) | |
| raise ValueError(f"Unsupported prototype_shape value: {raw_value!r}") | |
def get_model_config(name: str | None = None) -> ModelConfig:
    """Look up a registered model configuration.

    Args:
        name: Registry key. ``None`` or an empty string selects
            ``DEFAULT_MODEL_NAME``.

    Returns:
        The matching ``ModelConfig`` from ``MODEL_CONFIGS``.

    Raises:
        ValueError: If no configuration is registered under the resolved name.
            The message lists the available names, and the underlying
            ``KeyError`` is chained as the cause.
    """
    resolved = name if name else DEFAULT_MODEL_NAME
    try:
        config = MODEL_CONFIGS[resolved]
    except KeyError as missing:
        known = ", ".join(sorted(MODEL_CONFIGS))
        raise ValueError(f"Unknown model '{resolved}'. Available: {known}") from missing
    return config
class PresenceModelService:
    """Load one presence-detection model and classify single PIL images with it.

    Weights are read from ``config.model_path`` at construction time. The
    model is moved to CUDA when available (otherwise CPU) and switched to
    eval mode.
    """

    def __init__(self, config: ModelConfig):
        """Load the model described by ``config`` and build its preprocessing pipeline.

        Raises:
            FileNotFoundError: If ``config.model_path`` is not an existing file.
            ValueError: If the backend is unsupported or the checkpoint is malformed.
        """
        # is_file() rather than exists(): a directory at this path would
        # otherwise pass the check and only fail later inside torch.load.
        if not config.model_path.is_file():
            raise FileNotFoundError(f"Model not found: {config.model_path}")
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self._load_model().to(self.device)
        self.model.eval()
        self.transform = T.Compose(
            [
                T.Resize((config.image_size, config.image_size)),
                T.ToTensor(),
                T.Normalize(config.normalize_mean, config.normalize_std),
            ]
        )

    def _load_model(self) -> nn.Module:
        """Dispatch to the loader for ``self.config.backend``; weights are loaded on CPU."""
        loaders = {
            "resnet18": self._load_resnet18,
            "ppnet": self._load_ppnet,
        }
        loader = loaders.get(self.config.backend)
        if loader is None:
            raise ValueError(f"Unsupported backend: {self.config.backend}")
        return loader()

    def _load_resnet18(self) -> nn.Module:
        """Load a ResNet-18 from a checkpoint that is a bare ``state_dict`` (strict key match)."""
        model = build_resnet18(num_classes=len(CLASS_NAMES))
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state = torch.load(self.config.model_path, map_location="cpu")
        model.load_state_dict(state, strict=True)
        return model

    def _load_ppnet(self) -> nn.Module:
        """Rebuild a ProtoPNet from hyperparameters stored in its checkpoint, then load its weights.

        The checkpoint must be a dict with a ``"state_dict"`` entry. An optional
        ``"params_dict"`` entry overrides the architecture defaults used below.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(self.config.model_path, map_location="cpu")
        # Check the type before calling .get(): a checkpoint saved as a bare
        # module or tensor would otherwise surface as an AttributeError.
        if not isinstance(checkpoint, dict):
            raise ValueError(
                f"Invalid PPNet checkpoint: expected a dict, got {type(checkpoint).__name__}."
            )
        state_dict = checkpoint.get("state_dict")
        if not isinstance(state_dict, dict):
            raise ValueError("Invalid PPNet checkpoint: missing state_dict.")
        # `or {}` also covers a params_dict explicitly stored as None.
        params = checkpoint.get("params_dict") or {}
        if not isinstance(params, dict):
            raise ValueError("Invalid PPNet checkpoint: params_dict must be a dict.")
        model = build_ppnet(
            base_architecture=str(params.get("base_architecture", "vgg19")),
            img_size=int(params.get("img_size", self.config.image_size)),
            prototype_shape=_normalize_prototype_shape(
                params.get("prototype_shape", (40, 128, 1, 1))
            ),
            num_classes=int(params.get("num_classes", len(CLASS_NAMES))),
            prototype_activation_function=str(
                params.get("prototype_activation_function", "log")
            ),
            add_on_layers_type=str(params.get("add_on_layers_type", "regular")),
        )
        model.load_state_dict(state_dict, strict=True)
        return model

    def predict_image(self, image: Image.Image) -> dict[str, Any]:
        """Classify ``image`` and return the predicted label with per-class probabilities.

        Args:
            image: Any PIL image. Non-RGB modes (grayscale, palette, RGBA, ...)
                are converted to RGB first; RGBA loses its alpha channel.

        Returns:
            A dict with ``label``, ``prediction_index``, ``probabilities``
            (class name -> softmax probability rounded to 6 places),
            ``model_name``, ``model_backend`` and ``model_path`` (file name only).

        Raises:
            ValueError: If the model emits a different number of class scores
                than ``CLASS_NAMES`` has entries.
        """
        # ToTensor yields 1 channel for "L" images and 4 for "RGBA", but
        # Normalize is configured with 3-channel statistics, so those inputs
        # would crash there without this conversion.
        if image.mode != "RGB":
            image = image.convert("RGB")
        batch = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            outputs = self.model(batch)
        # Some backends return a tuple/list; its first element is taken as the logits.
        logits = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
        probs = torch.softmax(logits, dim=-1)[0]
        if probs.numel() != len(CLASS_NAMES):
            raise ValueError(
                f"Model produced {probs.numel()} class scores, expected {len(CLASS_NAMES)}."
            )
        pred_idx = int(torch.argmax(probs).item())
        probabilities = {
            label: round(float(p), 6) for label, p in zip(CLASS_NAMES, probs.tolist())
        }
        return {
            "label": CLASS_NAMES[pred_idx],
            "prediction_index": pred_idx,
            "probabilities": probabilities,
            "model_name": self.config.name,
            "model_backend": self.config.backend,
            "model_path": self.config.model_path.name,
        }
def get_model_service(model_name: str | None = None) -> PresenceModelService:
    """Return the shared ``PresenceModelService`` for ``model_name``.

    The service, and therefore the model weights, is built once per registry
    name and then reused. Repeated calls do not reload the checkpoint from
    disk. ``None`` or an empty string selects ``DEFAULT_MODEL_NAME``, and
    shares one instance with an explicit request for that name.

    Raises:
        ValueError: If the name is not registered.
        FileNotFoundError: If the checkpoint file is missing.
        Failures are not cached, so a later call retries the load.
    """
    return _build_model_service(model_name or DEFAULT_MODEL_NAME)


# Unbounded maxsize is safe here: unknown names raise inside the call, and
# lru_cache does not cache exceptions, so at most one entry per registry key is stored.
@lru_cache(maxsize=None)
def _build_model_service(model_name: str) -> PresenceModelService:
    """Construct the service for an already-resolved registry name (memoized)."""
    return PresenceModelService(get_model_config(model_name))