File size: 5,218 Bytes
896740b
 
 
 
dcd4485
 
896740b
dcd4485
 
 
 
 
 
 
896740b
 
dcd4485
 
 
 
 
896740b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcd4485
 
 
 
 
 
 
896740b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcd4485
896740b
 
 
dcd4485
896740b
dcd4485
896740b
dcd4485
 
 
896740b
dcd4485
896740b
dcd4485
 
 
896740b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcd4485
 
 
896740b
 
dcd4485
 
 
 
 
 
 
 
 
 
896740b
 
 
dcd4485
 
 
896740b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from __future__ import annotations

import os
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from torchvision import models

from protopnet import build_ppnet


# Directory containing this module; the checkpoint paths in MODEL_CONFIGS are built from it.
BASE_DIR: Path = Path(__file__).resolve().parent

# Output-index -> label mapping shared by every backend: index 0 is "no_person", index 1 is "person".
CLASS_NAMES: list[str] = [
    "no_person",
    "person",
]


@dataclass(frozen=True)
class ModelConfig:
    """Immutable description of one selectable presence model.

    Positional field order: name, backend, model_path, image_size,
    normalize_mean, normalize_std.
    """

    # Registry key; also reported back as ``model_name`` in predictions.
    name: str
    # Which loader rebuilds the network ("resnet18" or "ppnet" in this module).
    backend: str
    # Checkpoint file on disk.
    model_path: Path
    # Side length, in pixels, of the square input the transform resizes to.
    image_size: int
    # Per-channel mean and standard deviation handed to the Normalize transform.
    normalize_mean: tuple[float, float, float]
    normalize_std: tuple[float, float, float]


# Registry of selectable models, keyed by each config's own ``name`` field.
# Checkpoint paths are resolved relative to BASE_DIR (this module's directory).
MODEL_CONFIGS: dict[str, ModelConfig] = {
    config.name: config
    for config in (
        # ResNet-18 classifier, stored as a plain state_dict checkpoint.
        ModelConfig(
            name="resnet18_presence",
            backend="resnet18",
            model_path=BASE_DIR / "best_global_model_presence.pt",
            image_size=224,
            # NOTE(review): these look like the standard ImageNet channel statistics;
            # presumably they match how this checkpoint was trained — confirm.
            normalize_mean=(0.485, 0.456, 0.406),
            normalize_std=(0.229, 0.224, 0.225),
        ),
        # ProtoPNet baseline checkpoint, fed 128x128 inputs.
        ModelConfig(
            name="ppnet_baseline",
            backend="ppnet",
            model_path=BASE_DIR / "baseline_40_model.pt.tar",
            image_size=128,
            # NOTE(review): these look like the commonly used CIFAR-10 channel
            # statistics; confirm they match the training pipeline.
            normalize_mean=(0.4914, 0.4822, 0.4465),
            normalize_std=(0.2023, 0.1994, 0.2010),
        ),
    )
}

# Model used when callers do not name one; override it with the SECUREML_MODEL
# environment variable (read once, at import time). Surrounding whitespace is
# ignored, and an empty/blank value falls back to "ppnet_baseline" instead of
# becoming an "Unknown model ''" error at lookup time.
DEFAULT_MODEL_NAME = os.getenv("SECUREML_MODEL", "").strip() or "ppnet_baseline"


def build_resnet18(num_classes: int = 2) -> nn.Module:
    """Create an untrained ResNet-18 whose final layer emits ``num_classes`` logits.

    ``weights=None`` means no pretrained weights are fetched; within this module
    the returned network is subsequently filled from a local checkpoint.
    """
    network = models.resnet18(weights=None)
    # Replace the stock classification head with one sized to ``num_classes``.
    network.fc = nn.Linear(network.fc.in_features, num_classes)
    return network


def _normalize_prototype_shape(raw_value: Any) -> tuple[int, int, int, int]:
    if isinstance(raw_value, tuple):
        return raw_value
    if isinstance(raw_value, list):
        return tuple(raw_value)
    raise ValueError(f"Unsupported prototype_shape value: {raw_value!r}")


def get_model_config(name: str | None = None) -> ModelConfig:
    """Return the registered ``ModelConfig`` called ``name``.

    A falsy ``name`` (``None`` or ``""``) selects ``DEFAULT_MODEL_NAME``.

    Raises:
        ValueError: if no model with that name is registered; the message lists
            the registered names in sorted order.
    """
    requested = name if name else DEFAULT_MODEL_NAME
    try:
        config = MODEL_CONFIGS[requested]
    except KeyError as missing:
        choices = ", ".join(sorted(MODEL_CONFIGS))
        raise ValueError(f"Unknown model '{requested}'. Available: {choices}") from missing
    return config


class PresenceModelService:
    """Load one presence-classification checkpoint and run single-image inference.

    ``config.backend`` ("resnet18" or "ppnet") selects how the checkpoint at
    ``config.model_path`` is rebuilt. Weights are loaded on CPU, then the model is
    moved to CUDA when available (CPU otherwise) and switched to eval mode once,
    at construction time.
    """

    def __init__(self, config: ModelConfig):
        """Build the model and preprocessing pipeline described by ``config``.

        Raises:
            FileNotFoundError: if ``config.model_path`` does not exist.
            ValueError: if the backend is unsupported or the checkpoint is malformed.
        """
        if not config.model_path.exists():
            raise FileNotFoundError(f"Model not found: {config.model_path}")

        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self._load_model().to(self.device)
        self.model.eval()
        # Square resize to the configured input size, tensor conversion, then
        # per-channel normalization with the statistics stored in the config.
        self.transform = T.Compose(
            [
                T.Resize((config.image_size, config.image_size)),
                T.ToTensor(),
                T.Normalize(config.normalize_mean, config.normalize_std),
            ]
        )

    def _load_model(self) -> nn.Module:
        """Rebuild the network for ``config.backend`` and load its weights.

        Raises:
            ValueError: for an unsupported backend or a malformed checkpoint.
        """
        backend = self.config.backend
        if backend == "resnet18":
            return self._load_resnet18()
        if backend == "ppnet":
            return self._load_ppnet()
        raise ValueError(f"Unsupported backend: {backend}")

    def _load_resnet18(self) -> nn.Module:
        """Load a checkpoint that is a plain ResNet-18 ``state_dict``."""
        model = build_resnet18(num_classes=len(CLASS_NAMES))
        # NOTE(review): torch.load may unpickle arbitrary objects (depending on the
        # torch version's weights_only default) — only point model_path at trusted files.
        state = torch.load(self.config.model_path, map_location="cpu")
        model.load_state_dict(state, strict=True)
        return model

    def _load_ppnet(self) -> nn.Module:
        """Load a ProtoPNet checkpoint: a dict with ``state_dict`` and optional ``params_dict``.

        Raises:
            ValueError: if the checkpoint is not a dict, lacks a dict ``state_dict``,
                or has a ``params_dict`` that is not a dict.
        """
        # NOTE(review): same torch.load / pickle caveat as in _load_resnet18.
        checkpoint = torch.load(self.config.model_path, map_location="cpu")
        if not isinstance(checkpoint, dict):
            raise ValueError("Invalid PPNet checkpoint: expected a dict.")
        state_dict = checkpoint.get("state_dict")
        if not isinstance(state_dict, dict):
            raise ValueError("Invalid PPNet checkpoint: missing state_dict.")

        # Architecture hyper-parameters saved next to the weights. A missing or None
        # params_dict (and any missing key) falls back to the defaults below.
        params = checkpoint.get("params_dict") or {}
        if not isinstance(params, dict):
            raise ValueError("Invalid PPNet checkpoint: params_dict must be a dict.")

        # NOTE(review): the checkpoint's img_size may differ from config.image_size,
        # which is what the inference transform resizes to — confirm they agree.
        model = build_ppnet(
            base_architecture=str(params.get("base_architecture", "vgg19")),
            img_size=int(params.get("img_size", self.config.image_size)),
            prototype_shape=_normalize_prototype_shape(
                params.get("prototype_shape", (40, 128, 1, 1))
            ),
            num_classes=int(params.get("num_classes", len(CLASS_NAMES))),
            prototype_activation_function=str(
                params.get("prototype_activation_function", "log")
            ),
            add_on_layers_type=str(params.get("add_on_layers_type", "regular")),
        )
        model.load_state_dict(state_dict, strict=True)
        return model

    def predict_image(self, image: Image.Image) -> dict[str, Any]:
        """Classify a single PIL image.

        Images in any mode other than RGB (e.g. RGBA, grayscale, palette) are
        converted to RGB first: the Normalize step carries exactly three channel
        statistics and fails on 1- or 4-channel tensors.

        Returns:
            A dict with the predicted ``label`` and ``prediction_index``, per-class
            ``probabilities`` (softmax, rounded to 6 decimals), and the model's
            ``model_name``, ``model_backend`` and checkpoint file name (``model_path``).
        """
        if image.mode != "RGB":
            image = image.convert("RGB")
        x = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(x)
            # Some backends return a tuple/list; its first element is taken as the logits.
            logits = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
            # Single-image batch: keep row 0 and bring it to CPU once for reading.
            probs = torch.softmax(logits, dim=-1)[0].cpu()

        pred_idx = int(torch.argmax(probs).item())
        probabilities = {
            label: round(float(probs[i].item()), 6) for i, label in enumerate(CLASS_NAMES)
        }
        return {
            "label": CLASS_NAMES[pred_idx],
            "prediction_index": pred_idx,
            "probabilities": probabilities,
            "model_name": self.config.name,
            "model_backend": self.config.backend,
            "model_path": self.config.model_path.name,
        }


@lru_cache(maxsize=None)
def _get_service_for_config(config: ModelConfig) -> PresenceModelService:
    """Build, once per distinct config, the service for an already-resolved ModelConfig."""
    return PresenceModelService(config)


def get_model_service(model_name: str | None = None) -> PresenceModelService:
    """Return the shared ``PresenceModelService`` for ``model_name``.

    The name is resolved through ``get_model_config`` *before* the cache lookup.
    As a result, ``None``, ``""`` and the explicit default name all share a single
    loaded model, instead of each loading (and holding in memory) its own copy.

    Raises:
        ValueError: for an unknown model name (raised by ``get_model_config``).
        FileNotFoundError: if the configured checkpoint file is missing.
    """
    return _get_service_for_config(get_model_config(model_name))


# Keep the functools.lru_cache management API on the public entry point, so callers
# that reset or inspect the cache (e.g. between tests) keep working.
get_model_service.cache_clear = _get_service_for_config.cache_clear  # type: ignore[attr-defined]
get_model_service.cache_info = _get_service_for_config.cache_info  # type: ignore[attr-defined]