# SecureMLAPI / model_service.py
# Last commit 896740b by yenslife: "feat: integrate ppnet inference backend"
from __future__ import annotations
import os
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from torchvision import models
from protopnet import build_ppnet
# Absolute directory containing this module; the checkpoint paths in
# MODEL_CONFIGS are resolved against it, independent of the working directory.
BASE_DIR = Path(__file__).resolve().parents[0]

# Output labels ordered by the classifier's output index: prediction index
# ``i`` is reported as ``CLASS_NAMES[i]``.
CLASS_NAMES = [
    "no_person",
    "person",
]
@dataclass(frozen=True)
class ModelConfig:
    """Immutable description of one deployable model variant.

    Instances are registered by name in ``MODEL_CONFIGS`` and handed to
    ``PresenceModelService``. The service uses them to locate the checkpoint,
    choose the loader for the backend, and build the preprocessing pipeline.
    """

    # Registry key; also reported as ``model_name`` in prediction results.
    name: str
    # Loader selector; PresenceModelService handles "resnet18" and "ppnet".
    backend: str
    # Checkpoint file on disk.
    model_path: Path
    # Side length (pixels) of the square resize applied to input images.
    image_size: int
    # Per-channel mean / std passed to T.Normalize during preprocessing.
    normalize_mean: tuple[float, float, float]
    normalize_std: tuple[float, float, float]
# Registry of every model the service can load, keyed by ``ModelConfig.name``.
# Deriving each key from the config itself keeps key and reported name in sync.
MODEL_CONFIGS: dict[str, ModelConfig] = {
    cfg.name: cfg
    for cfg in (
        ModelConfig(
            name="resnet18_presence",
            backend="resnet18",
            model_path=BASE_DIR / "best_global_model_presence.pt",
            image_size=224,
            # These match the common ImageNet mean/std; presumably what the
            # weights were trained with — confirm against the training code.
            normalize_mean=(0.485, 0.456, 0.406),
            normalize_std=(0.229, 0.224, 0.225),
        ),
        ModelConfig(
            name="ppnet_baseline",
            backend="ppnet",
            model_path=BASE_DIR / "baseline_40_model.pt.tar",
            image_size=128,
            # These match the common CIFAR-10 mean/std; presumably what the
            # weights were trained with — confirm against the training code.
            normalize_mean=(0.4914, 0.4822, 0.4465),
            normalize_std=(0.2023, 0.1994, 0.2010),
        ),
    )
}

# Model used when callers do not name one; override with SECUREML_MODEL.
DEFAULT_MODEL_NAME = os.environ.get("SECUREML_MODEL", "ppnet_baseline")
def build_resnet18(num_classes: int = 2) -> nn.Module:
    """Return a randomly initialised ResNet-18 with a ``num_classes``-way head.

    No pretrained weights are fetched (``weights=None``); in this module
    PresenceModelService loads the trained state dict on top. Only the final
    fully connected layer is swapped out, keeping its original input width.
    """
    backbone = models.resnet18(weights=None)
    backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
    return backbone
def _normalize_prototype_shape(raw_value: Any) -> tuple[int, int, int, int]:
if isinstance(raw_value, tuple):
return raw_value
if isinstance(raw_value, list):
return tuple(raw_value)
raise ValueError(f"Unsupported prototype_shape value: {raw_value!r}")
def get_model_config(name: str | None = None) -> ModelConfig:
    """Look up a registered ``ModelConfig`` by name.

    Args:
        name: Registry key. ``None`` or an empty string selects
            ``DEFAULT_MODEL_NAME``.

    Raises:
        ValueError: if the resolved name is not in ``MODEL_CONFIGS``; the
            message lists the available names.
    """
    resolved = name if name else DEFAULT_MODEL_NAME
    try:
        config = MODEL_CONFIGS[resolved]
    except KeyError as err:
        choices = ", ".join(sorted(MODEL_CONFIGS))
        raise ValueError(f"Unknown model '{resolved}'. Available: {choices}") from err
    return config
class PresenceModelService:
    """Load one configured model and run single-image presence classification.

    The model is loaded once in ``__init__``, placed on CUDA when available
    (otherwise CPU), switched to eval mode, and reused by ``predict_image``.
    """

    def __init__(self, config: ModelConfig):
        """Load the checkpoint described by ``config`` and build preprocessing.

        Raises:
            FileNotFoundError: if ``config.model_path`` does not exist.
            ValueError: if the backend is unsupported or the checkpoint is
                malformed.
        """
        if not config.model_path.exists():
            raise FileNotFoundError(f"Model not found: {config.model_path}")
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self._load_model().to(self.device)
        self.model.eval()
        # Square resize, tensor conversion, then the per-model channel stats.
        self.transform = T.Compose(
            [
                T.Resize((config.image_size, config.image_size)),
                T.ToTensor(),
                T.Normalize(config.normalize_mean, config.normalize_std),
            ]
        )

    def _load_model(self) -> nn.Module:
        """Dispatch to the loader for ``self.config.backend``; weights load on CPU."""
        if self.config.backend == "resnet18":
            return self._load_resnet18()
        if self.config.backend == "ppnet":
            return self._load_ppnet()
        raise ValueError(f"Unsupported backend: {self.config.backend}")

    def _load_resnet18(self) -> nn.Module:
        """Build a ResNet-18 and strictly load the state dict stored in the file."""
        model = build_resnet18(num_classes=len(CLASS_NAMES))
        # NOTE(review): torch.load unpickles the file; model_path must be trusted.
        state = torch.load(self.config.model_path, map_location="cpu")
        model.load_state_dict(state, strict=True)
        return model

    def _load_ppnet(self) -> nn.Module:
        """Rebuild a PPNet from the checkpoint's ``params_dict`` and load ``state_dict``.

        Missing hyper-parameters fall back to the defaults below.

        Raises:
            ValueError: if the checkpoint is not a dict, lacks a ``state_dict``,
                or declares a class count different from ``CLASS_NAMES``.
        """
        # NOTE(review): torch.load unpickles the file; model_path must be trusted.
        checkpoint = torch.load(self.config.model_path, map_location="cpu")
        # A non-dict payload (e.g. a whole pickled model) has no .get(); say so.
        if not isinstance(checkpoint, dict):
            raise ValueError("Invalid PPNet checkpoint: expected a dict.")
        state_dict = checkpoint.get("state_dict")
        if not isinstance(state_dict, dict):
            raise ValueError("Invalid PPNet checkpoint: missing state_dict.")
        # Treat an explicit None the same as a missing params_dict.
        params = checkpoint.get("params_dict") or {}
        num_classes = int(params.get("num_classes", len(CLASS_NAMES)))
        # predict_image labels outputs through CLASS_NAMES, so the head must match.
        if num_classes != len(CLASS_NAMES):
            raise ValueError(
                f"Invalid PPNet checkpoint: num_classes={num_classes}, "
                f"expected {len(CLASS_NAMES)} ({', '.join(CLASS_NAMES)})."
            )
        model = build_ppnet(
            base_architecture=str(params.get("base_architecture", "vgg19")),
            img_size=int(params.get("img_size", self.config.image_size)),
            prototype_shape=_normalize_prototype_shape(
                params.get("prototype_shape", (40, 128, 1, 1))
            ),
            num_classes=num_classes,
            prototype_activation_function=str(
                params.get("prototype_activation_function", "log")
            ),
            add_on_layers_type=str(params.get("add_on_layers_type", "regular")),
        )
        model.load_state_dict(state_dict, strict=True)
        return model

    def predict_image(self, image: Image.Image) -> dict[str, Any]:
        """Classify a single PIL image.

        Non-RGB images (RGBA, grayscale "L", palette "P", CMYK, ...) are
        converted to RGB first. Their 1- or 4-channel tensors would not
        broadcast against the three-channel normalization statistics.

        Returns:
            A dict with the predicted ``label`` and ``prediction_index``,
            per-class ``probabilities`` rounded to 6 decimals, and the model's
            name, backend and checkpoint file name.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")
        x = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            outputs = self.model(x)
            # Some backends (presumably PPNet) return a tuple/list whose first
            # element is the logits.
            logits = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
            probs = torch.softmax(logits, dim=-1)[0]
        pred_idx = int(torch.argmax(probs).item())
        probabilities = {
            label: round(float(p), 6) for label, p in zip(CLASS_NAMES, probs.tolist())
        }
        return {
            "label": CLASS_NAMES[pred_idx],
            "prediction_index": pred_idx,
            "probabilities": probabilities,
            "model_name": self.config.name,
            "model_backend": self.config.backend,
            "model_path": self.config.model_path.name,
        }
@lru_cache(maxsize=None)
def _cached_model_service(config: ModelConfig) -> PresenceModelService:
    """Build the service for ``config`` at most once; ModelConfig is frozen, hence hashable."""
    return PresenceModelService(config)


def get_model_service(model_name: str | None = None) -> PresenceModelService:
    """Return the shared ``PresenceModelService`` for ``model_name``.

    ``None`` or an empty name selects ``DEFAULT_MODEL_NAME``. The name is
    resolved *before* the cache lookup. ``functools.lru_cache`` keys on the
    literal call arguments, so caching this function directly made
    ``get_model_service()``, ``get_model_service(None)``,
    ``get_model_service(model_name=None)`` and an explicit default name each
    load the same checkpoint separately.

    Raises:
        ValueError: for an unknown model name or a malformed checkpoint.
        FileNotFoundError: if the configured checkpoint file is missing.
    """
    return _cached_model_service(get_model_config(model_name))


# Preserve the cache-management API this function exposed when it was the
# lru_cache-decorated callable itself.
get_model_service.cache_clear = _cached_model_service.cache_clear  # type: ignore[attr-defined]
get_model_service.cache_info = _cached_model_service.cache_info  # type: ignore[attr-defined]