yenslife committed on
Commit
896740b
·
1 Parent(s): 8ffefcd

feat: integrate ppnet inference backend


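Add a PPNet baseline inference pipeline and integrate it into the FastAPI service.

Support switching between the ppnet and resnet18 models via configuration, and add a local inference script and README documentation.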

Files changed (7)
  1. README.md +119 -1
  2. app.py +5 -2
  3. baseline_40_model.pt.tar +3 -0
  4. inference.py +46 -0
  5. main.py +5 -2
  6. model_service.py +103 -18
  7. protopnet.py +315 -0
README.md CHANGED
@@ -10,4 +10,122 @@ license: mit
  short_description: Used by the NCKU cybersecurity project
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # SecureMLAPI
+
+ This project provides a FastAPI service that determines whether an image contains a person.
+
+ Two inference backends are currently integrated:
+
+ - `ppnet_baseline`: uses `people_detection_baseline/baseline_40_model.pt.tar`
+ - `resnet18_presence`: uses `best_global_model_presence.pt`
+
+ The default model is `ppnet_baseline`.
+
+ ## Development environment
+
+ Use `uv` to install dependencies and run commands.
+
+ ```bash
+ uv sync
+ ```
+
+ ## Starting the service
+
+ ```bash
+ uv run uvicorn app:app --host 0.0.0.0 --port 8000 --reload
+ ```
+
+ Once the service is running, the following paths are available:
+
+ - `/docs`: Swagger UI
+ - `/health`: health check
+ - `/predict`: upload an image and receive the inference result as JSON
+ - `/demo`: a simple web page for testing
+
+ ## Switching models
+
+ Switching models from the HTML interface is not currently needed; configure the model in code instead.
+
+ ### Option 1: switch with an environment variable
+
+ ```bash
+ SECUREML_MODEL=ppnet_baseline uv run uvicorn app:app --host 0.0.0.0 --port 8000
+ ```
+
+ ```bash
+ SECUREML_MODEL=resnet18_presence uv run uvicorn app:app --host 0.0.0.0 --port 8000
+ ```
+
+ ### Option 2: change the default value
+
+ Edit the following line in `model_service.py`:
+
+ ```python
+ DEFAULT_MODEL_NAME = os.getenv("SECUREML_MODEL", "ppnet_baseline")
+ ```
+
+ along with the corresponding model's entry in `MODEL_CONFIGS`.
+
+ ## Local inference
+
+ The project provides `inference.py`, which runs inference on a single image:
+
+ ```bash
+ uv run python inference.py --image person.jpg
+ ```
+
+ To specify a model:
+
+ ```bash
+ uv run python inference.py --image person.jpg --model ppnet_baseline
+ ```
+
+ ```bash
+ uv run python inference.py --image person.jpg --model resnet18_presence
+ ```
+
+ ## Using the API
+
+ Call `/predict` with `curl`:
+
+ ```bash
+ curl -X POST \
+   -F "file=@person.jpg" \
+   http://127.0.0.1:8000/predict
+ ```
+
+ Example response:
+
+ ```json
+ {
+   "label": "person",
+   "prediction_index": 1,
+   "probabilities": {
+     "no_person": 0.0,
+     "person": 1.0
+   },
+   "model_name": "ppnet_baseline",
+   "model_backend": "ppnet",
+   "model_path": "baseline_40_model.pt.tar",
+   "filename": "person.jpg",
+   "content_type": "image/jpeg"
+ }
+ ```
+
+ ## Where the model settings live
+
+ Model switching and configuration are centralized in `model_service.py`:
+
+ - `MODEL_CONFIGS`: defines the available models
+ - `DEFAULT_MODEL_NAME`: defines the default model
+ - `get_model_service()`: builds the corresponding inference service
+
+ To add a model later, add an entry to `MODEL_CONFIGS` and extend `_load_model()` with the loading logic for its backend.
+
+ ## Verification
+
+ Start with a basic syntax check:
+
+ ```bash
+ uv run python -m py_compile app.py main.py inference.py model_service.py protopnet.py
+ ```
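The response format documented in the README can be consumed with a few lines of client code. The sketch below is illustrative only: `summarize_prediction` is a hypothetical helper that is not part of this commit, and it assumes the payload has exactly the shape shown in the README's sample response.

```python
import json
import math


def summarize_prediction(payload: dict) -> tuple[str, float]:
    """Return (label, confidence) from a /predict response payload.

    Hypothetical helper: assumes the keys shown in the README example.
    """
    probabilities = payload["probabilities"]
    # The service applies softmax, so the class probabilities should sum to ~1.
    total = sum(probabilities.values())
    if not math.isclose(total, 1.0, abs_tol=1e-3):
        raise ValueError(f"probabilities sum to {total}, expected 1.0")
    label = payload["label"]
    return label, float(probabilities[label])


# Example payload copied from the README's sample response.
example = json.loads(
    """
    {
      "label": "person",
      "prediction_index": 1,
      "probabilities": {"no_person": 0.0, "person": 1.0},
      "model_name": "ppnet_baseline",
      "model_backend": "ppnet",
      "model_path": "baseline_40_model.pt.tar"
    }
    """
)
label, confidence = summarize_prediction(example)
print(f"{label}: {confidence:.3f}")
```

The `model_name` and `model_backend` fields added in this commit let a client confirm which backend produced a given result.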
app.py CHANGED
@@ -8,10 +8,11 @@ from fastapi.responses import HTMLResponse
  from fastapi.templating import Jinja2Templates
  from PIL import Image, UnidentifiedImageError

- from model_service import MODEL_PATH, get_model_service
+ from model_service import get_model_config, get_model_service

  BASE_DIR = Path(__file__).resolve().parent
  templates = Jinja2Templates(directory=str(BASE_DIR / "templates"))
+ ACTIVE_MODEL_CONFIG = get_model_config()


  @asynccontextmanager
@@ -67,7 +68,9 @@ def root():
      return {
          "message": "Presence Detection API",
          "docs": "/docs",
-         "model_path": str(MODEL_PATH.name),
+         "model_name": ACTIVE_MODEL_CONFIG.name,
+         "model_backend": ACTIVE_MODEL_CONFIG.backend,
+         "model_path": str(ACTIVE_MODEL_CONFIG.model_path.name),
      }


baseline_40_model.pt.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:555c304d21f6db8d41b53ff06b7c9bd9a7fe78a104b3ae69150cc0061532d94a
+ size 80485030
inference.py ADDED
@@ -0,0 +1,46 @@
+ from __future__ import annotations
+
+ import argparse
+ from pathlib import Path
+
+ from PIL import Image
+
+ from model_service import get_model_config, get_model_service
+
+
+ def build_parser() -> argparse.ArgumentParser:
+     parser = argparse.ArgumentParser(description="Run local inference with the configured model.")
+     parser.add_argument("--image", type=Path, default=Path("person.jpg"), help="Input image path.")
+     parser.add_argument(
+         "--model",
+         type=str,
+         default=None,
+         help="Optional model name override. Defaults to SECUREML_MODEL or the project default.",
+     )
+     return parser
+
+
+ def main() -> None:
+     args = build_parser().parse_args()
+     if not args.image.exists():
+         raise SystemExit(f"Image not found: {args.image}")
+
+     config = get_model_config(args.model)
+     service = get_model_service(args.model)
+     image = Image.open(args.image).convert("RGB")
+     result = service.predict_image(image)
+
+     print(f"[INFO] device={service.device}")
+     print(f"[INFO] model_name={config.name}")
+     print(f"[INFO] model_backend={config.backend}")
+     print(f"[INFO] model_path={config.model_path}")
+     print(f"[INFO] image={args.image}")
+     print("========== RESULT ==========")
+     print(f"prediction: {result['prediction_index']} ({result['label']})")
+     for label, prob in result["probabilities"].items():
+         print(f"P({label}) = {prob:.6f}")
+     print("============================")
+
+
+ if __name__ == "__main__":
+     main()
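The `--model` help text implies a three-level precedence: an explicit CLI value, then `SECUREML_MODEL`, then the project default. A minimal sketch of that order follows. `resolve_model_name` and `FALLBACK_MODEL` are hypothetical names used only for illustration. In the commit itself the environment variable is read once, when `model_service.py` is imported, and the fallback happens inside `get_model_config`.

```python
from __future__ import annotations

from typing import Mapping

# Assumed to mirror the fallback value used in model_service.py.
FALLBACK_MODEL = "ppnet_baseline"


def resolve_model_name(cli_value: str | None, env: Mapping[str, str]) -> str:
    """Hypothetical helper showing the documented precedence.

    1. an explicit --model value,
    2. the SECUREML_MODEL environment variable,
    3. the project default.
    """
    if cli_value:
        return cli_value
    return env.get("SECUREML_MODEL", FALLBACK_MODEL)


print(resolve_model_name(None, {}))
print(resolve_model_name(None, {"SECUREML_MODEL": "resnet18_presence"}))
print(resolve_model_name("ppnet_baseline", {"SECUREML_MODEL": "resnet18_presence"}))
```

Passing the environment as a mapping keeps the helper easy to test without touching `os.environ`.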
main.py CHANGED
@@ -1,7 +1,7 @@
  from pathlib import Path
  from PIL import Image

- from model_service import MODEL_PATH, get_model_service
+ from model_service import get_model_config, get_model_service


  IMAGE_PATH = Path("person.jpg")
@@ -13,8 +13,11 @@ def main():
          raise SystemExit(f"Image not found: {IMAGE_PATH}")

      service = get_model_service()
+     config = get_model_config()
      print(f"[INFO] device={service.device}")
-     print(f"[INFO] model={MODEL_PATH}")
+     print(f"[INFO] model_name={config.name}")
+     print(f"[INFO] model_backend={config.backend}")
+     print(f"[INFO] model_path={config.model_path}")
      print(f"[INFO] image={IMAGE_PATH}")

      img = Image.open(IMAGE_PATH).convert("RGB")
model_service.py CHANGED
@@ -1,5 +1,10 @@
+ from __future__ import annotations
+
+ import os
+ from dataclasses import dataclass
  from functools import lru_cache
  from pathlib import Path
+ from typing import Any

  import torch
  import torch.nn as nn
@@ -7,46 +12,123 @@ import torchvision.transforms as T
  from PIL import Image
  from torchvision import models

+ from protopnet import build_ppnet
+

  BASE_DIR = Path(__file__).resolve().parent
- MODEL_PATH = BASE_DIR / "best_global_model_presence.pt"
  CLASS_NAMES = ["no_person", "person"]


+ @dataclass(frozen=True)
+ class ModelConfig:
+     name: str
+     backend: str
+     model_path: Path
+     image_size: int
+     normalize_mean: tuple[float, float, float]
+     normalize_std: tuple[float, float, float]
+
+
+ MODEL_CONFIGS: dict[str, ModelConfig] = {
+     "resnet18_presence": ModelConfig(
+         name="resnet18_presence",
+         backend="resnet18",
+         model_path=BASE_DIR / "best_global_model_presence.pt",
+         image_size=224,
+         normalize_mean=(0.485, 0.456, 0.406),
+         normalize_std=(0.229, 0.224, 0.225),
+     ),
+     "ppnet_baseline": ModelConfig(
+         name="ppnet_baseline",
+         backend="ppnet",
+         model_path=BASE_DIR / "baseline_40_model.pt.tar",
+         image_size=128,
+         normalize_mean=(0.4914, 0.4822, 0.4465),
+         normalize_std=(0.2023, 0.1994, 0.2010),
+     ),
+ }
+
+ DEFAULT_MODEL_NAME = os.getenv("SECUREML_MODEL", "ppnet_baseline")
+
+
  def build_resnet18(num_classes: int = 2) -> nn.Module:
-     # We load task-specific weights from `best_global_model_presence.pt`, so no
-     # pretrained backbone download is needed at runtime.
      model = models.resnet18(weights=None)
      in_features = model.fc.in_features
      model.fc = nn.Linear(in_features, num_classes)
      return model


+ def _normalize_prototype_shape(raw_value: Any) -> tuple[int, int, int, int]:
+     if isinstance(raw_value, tuple):
+         return raw_value
+     if isinstance(raw_value, list):
+         return tuple(raw_value)
+     raise ValueError(f"Unsupported prototype_shape value: {raw_value!r}")
+
+
+ def get_model_config(name: str | None = None) -> ModelConfig:
+     model_name = name or DEFAULT_MODEL_NAME
+     try:
+         return MODEL_CONFIGS[model_name]
+     except KeyError as exc:
+         available = ", ".join(sorted(MODEL_CONFIGS))
+         raise ValueError(f"Unknown model '{model_name}'. Available: {available}") from exc
+
+
  class PresenceModelService:
-     def __init__(self, model_path: Path):
-         if not model_path.exists():
-             raise FileNotFoundError(f"Model not found: {model_path}")
+     def __init__(self, config: ModelConfig):
+         if not config.model_path.exists():
+             raise FileNotFoundError(f"Model not found: {config.model_path}")

+         self.config = config
          self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         self.model = build_resnet18(num_classes=2).to(self.device)
-
-         state = torch.load(model_path, map_location="cpu")
-         self.model.load_state_dict(state, strict=True)
+         self.model = self._load_model().to(self.device)
          self.model.eval()
-
          self.transform = T.Compose(
              [
-                 T.Resize((224, 224)),
+                 T.Resize((config.image_size, config.image_size)),
                  T.ToTensor(),
-                 T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+                 T.Normalize(config.normalize_mean, config.normalize_std),
              ]
          )

-     def predict_image(self, image: Image.Image) -> dict:
+     def _load_model(self) -> nn.Module:
+         if self.config.backend == "resnet18":
+             model = build_resnet18(num_classes=len(CLASS_NAMES))
+             state = torch.load(self.config.model_path, map_location="cpu")
+             model.load_state_dict(state, strict=True)
+             return model
+
+         if self.config.backend == "ppnet":
+             checkpoint = torch.load(self.config.model_path, map_location="cpu")
+             state_dict = checkpoint.get("state_dict")
+             if not isinstance(state_dict, dict):
+                 raise ValueError("Invalid PPNet checkpoint: missing state_dict.")
+
+             params = checkpoint.get("params_dict", {})
+             model = build_ppnet(
+                 base_architecture=str(params.get("base_architecture", "vgg19")),
+                 img_size=int(params.get("img_size", self.config.image_size)),
+                 prototype_shape=_normalize_prototype_shape(
+                     params.get("prototype_shape", (40, 128, 1, 1))
+                 ),
+                 num_classes=int(params.get("num_classes", len(CLASS_NAMES))),
+                 prototype_activation_function=str(
+                     params.get("prototype_activation_function", "log")
+                 ),
+                 add_on_layers_type=str(params.get("add_on_layers_type", "regular")),
+             )
+             model.load_state_dict(state_dict, strict=True)
+             return model
+
+         raise ValueError(f"Unsupported backend: {self.config.backend}")
+
+     def predict_image(self, image: Image.Image) -> dict[str, Any]:
          x = self.transform(image).unsqueeze(0).to(self.device)

          with torch.no_grad():
-             logits = self.model(x)
+             outputs = self.model(x)
+             logits = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
              probs = torch.softmax(logits, dim=-1)[0]
              pred_idx = int(torch.argmax(probs).item())

@@ -57,9 +139,12 @@ class PresenceModelService:
              "label": CLASS_NAMES[pred_idx],
              "prediction_index": pred_idx,
              "probabilities": probabilities,
+             "model_name": self.config.name,
+             "model_backend": self.config.backend,
+             "model_path": self.config.model_path.name,
          }


- @lru_cache(maxsize=1)
- def get_model_service() -> PresenceModelService:
-     return PresenceModelService(MODEL_PATH)
+ @lru_cache(maxsize=None)
+ def get_model_service(model_name: str | None = None) -> PresenceModelService:
+     return PresenceModelService(get_model_config(model_name))
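One property of `functools.lru_cache` is worth keeping in mind here: it keys on the arguments as passed, so `get_model_service()`, `get_model_service(None)`, and `get_model_service("ppnet_baseline")` would each occupy a separate cache entry even when they resolve to the same configuration. The sketch below shows one possible way to share a single instance by resolving the name before the cached call. `FakeService`, `_service_for`, and `get_service` are hypothetical stand-ins, not code from this commit.

```python
from __future__ import annotations

from functools import lru_cache

DEFAULT_MODEL_NAME = "ppnet_baseline"  # stand-in for the module-level default
load_count = 0  # counts how many times the stand-in "model" is built


class FakeService:
    """Stand-in for PresenceModelService; no weights are loaded here."""

    def __init__(self, name: str):
        global load_count
        load_count += 1
        self.name = name


@lru_cache(maxsize=None)
def _service_for(resolved_name: str) -> FakeService:
    # Cached on the resolved name only, so equivalent requests share one entry.
    return FakeService(resolved_name)


def get_service(model_name: str | None = None) -> FakeService:
    # Resolve the default before hitting the cache.
    return _service_for(model_name or DEFAULT_MODEL_NAME)


a = get_service()
b = get_service(None)
c = get_service("ppnet_baseline")
print(a is b is c, load_count)
```

Whether this matters in practice depends on how the callers mix the call styles; within this commit each entry point uses one style consistently.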
protopnet.py ADDED
@@ -0,0 +1,315 @@
+ from __future__ import annotations
+
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ _VGG_CFGS = {
+     "vgg11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
+     "vgg13": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
+     "vgg16": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
+     "vgg19": [
+         64,
+         64,
+         "M",
+         128,
+         128,
+         "M",
+         256,
+         256,
+         256,
+         256,
+         "M",
+         512,
+         512,
+         512,
+         512,
+         "M",
+         512,
+         512,
+         512,
+         512,
+         "M",
+     ],
+ }
+
+
+ class VGGFeatures(nn.Module):
+     def __init__(self, cfg: list[int | str], batch_norm: bool = False, init_weights: bool = True):
+         super().__init__()
+         self.batch_norm = batch_norm
+         self.kernel_sizes: list[int] = []
+         self.strides: list[int] = []
+         self.paddings: list[int] = []
+         self.features = self._make_layers(cfg, batch_norm)
+
+         if init_weights:
+             self._initialize_weights()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.features(x)
+
+     def _make_layers(self, cfg: list[int | str], batch_norm: bool) -> nn.Sequential:
+         layers: list[nn.Module] = []
+         in_channels = 3
+         self.n_layers = 0
+
+         for item in cfg:
+             if item == "M":
+                 layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
+                 self.kernel_sizes.append(2)
+                 self.strides.append(2)
+                 self.paddings.append(0)
+                 continue
+
+             conv2d = nn.Conv2d(in_channels, item, kernel_size=3, padding=1)
+             if batch_norm:
+                 layers.extend([conv2d, nn.BatchNorm2d(item), nn.ReLU(inplace=True)])
+             else:
+                 layers.extend([conv2d, nn.ReLU(inplace=True)])
+
+             self.n_layers += 1
+             self.kernel_sizes.append(3)
+             self.strides.append(1)
+             self.paddings.append(1)
+             in_channels = item
+
+         return nn.Sequential(*layers)
+
+     def _initialize_weights(self) -> None:
+         for module in self.modules():
+             if isinstance(module, nn.Conv2d):
+                 nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
+                 if module.bias is not None:
+                     nn.init.constant_(module.bias, 0)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nn.init.constant_(module.weight, 1)
+                 nn.init.constant_(module.bias, 0)
+             elif isinstance(module, nn.Linear):
+                 nn.init.normal_(module.weight, 0, 0.01)
+                 nn.init.constant_(module.bias, 0)
+
+     def conv_info(self) -> tuple[list[int], list[int], list[int]]:
+         return self.kernel_sizes, self.strides, self.paddings
+
+     def __repr__(self) -> str:
+         return f"VGG{self.n_layers + 3}, batch_norm={self.batch_norm}"
+
+
+ def build_vgg_features(name: str) -> VGGFeatures:
+     if name not in _VGG_CFGS:
+         raise ValueError(f"Unsupported VGG architecture: {name}")
+     return VGGFeatures(_VGG_CFGS[name], batch_norm=name.endswith("_bn"))
+
+
+ def compute_layer_rf_info(
+     layer_filter_size: int,
+     layer_stride: int,
+     layer_padding: int | str,
+     previous_layer_rf_info: list[float],
+ ) -> list[float]:
+     n_in, j_in, r_in, start_in = previous_layer_rf_info
+
+     if layer_padding == "SAME":
+         n_out = math.ceil(float(n_in) / float(layer_stride))
+         if n_in % layer_stride == 0:
+             pad = max(layer_filter_size - layer_stride, 0)
+         else:
+             pad = max(layer_filter_size - (n_in % layer_stride), 0)
+     elif layer_padding == "VALID":
+         n_out = math.ceil(float(n_in - layer_filter_size + 1) / float(layer_stride))
+         pad = 0
+     else:
+         pad = layer_padding * 2
+         n_out = math.floor((n_in - layer_filter_size + pad) / layer_stride) + 1
+
+     pad_left = math.floor(pad / 2)
+     j_out = j_in * layer_stride
+     r_out = r_in + (layer_filter_size - 1) * j_in
+     start_out = start_in + ((layer_filter_size - 1) / 2 - pad_left) * j_in
+     return [n_out, j_out, r_out, start_out]
+
+
+ def compute_proto_layer_rf_info_v2(
+     img_size: int,
+     layer_filter_sizes: list[int],
+     layer_strides: list[int],
+     layer_paddings: list[int],
+     prototype_kernel_size: int,
+ ) -> list[float]:
+     if not (
+         len(layer_filter_sizes) == len(layer_strides) == len(layer_paddings)
+     ):
+         raise ValueError("Layer metadata length mismatch.")
+
+     rf_info: list[float] = [img_size, 1, 1, 0.5]
+     for filter_size, stride_size, padding_size in zip(
+         layer_filter_sizes, layer_strides, layer_paddings, strict=True
+     ):
+         rf_info = compute_layer_rf_info(
+             layer_filter_size=int(filter_size),
+             layer_stride=stride_size,
+             layer_padding=padding_size,
+             previous_layer_rf_info=rf_info,
+         )
+
+     return compute_layer_rf_info(
+         layer_filter_size=prototype_kernel_size,
+         layer_stride=1,
+         layer_padding="VALID",
+         previous_layer_rf_info=rf_info,
+     )
+
+
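To make the receptive-field recurrence concrete, here is a worked example for the `ppnet_baseline` defaults visible in this commit (VGG19 features, 128-pixel input, 1x1 prototypes). It is a simplified re-implementation for illustration, not the commit's function: the hypothetical `step` helper handles only explicit integer padding, which is enough because a 1x1 "VALID" layer with stride 1 gives the same result as padding 0.

```python
import math

# VGG19 layout: numbers are conv output channels, "M" is a 2x2 max pool.
VGG19_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
             512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]


def step(rf, kernel, stride, padding):
    """One layer of the recurrence, for explicit integer padding only."""
    n_in, j_in, r_in, start_in = rf
    n_out = math.floor((n_in - kernel + 2 * padding) / stride) + 1
    j_out = j_in * stride                      # jump between adjacent outputs
    r_out = r_in + (kernel - 1) * j_in         # receptive-field size
    start_out = start_in + ((kernel - 1) / 2 - padding) * j_in  # first center
    return [n_out, j_out, r_out, start_out]


rf = [128, 1, 1, 0.5]  # [feature-map size, jump, receptive field, start]
for item in VGG19_CFG:
    # Pools are 2x2 with stride 2; convolutions are 3x3 with stride 1, padding 1.
    rf = step(rf, 2, 2, 0) if item == "M" else step(rf, 3, 1, 1)

# A 1x1 prototype layer with stride 1 and no padding leaves the values unchanged.
rf = step(rf, 1, 1, 0)
print(rf)  # [4, 32, 268, 16.0]
```

Under these assumptions each prototype is compared against a 4x4 grid of feature locations, and each location has a nominal receptive field of 268 pixels, which is larger than the 128-pixel input.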
+ class PPNet(nn.Module):
+     def __init__(
+         self,
+         features: nn.Module,
+         img_size: int,
+         prototype_shape: tuple[int, int, int, int],
+         proto_layer_rf_info: list[float],
+         num_classes: int,
+         init_weights: bool = True,
+         prototype_activation_function: str = "log",
+         add_on_layers_type: str = "bottleneck",
+     ):
+         super().__init__()
+         self.img_size = img_size
+         self.prototype_shape = prototype_shape
+         self.num_prototypes = prototype_shape[0]
+         self.num_classes = num_classes
+         self.epsilon = 1e-4
+         self.prototype_activation_function = prototype_activation_function
+         self.proto_layer_rf_info = proto_layer_rf_info
+         self.features = features
+
+         if self.num_prototypes % self.num_classes != 0:
+             raise ValueError("Number of prototypes must be divisible by num_classes.")
+
+         self.prototype_class_identity = torch.zeros(self.num_prototypes, self.num_classes)
+         num_prototypes_per_class = self.num_prototypes // self.num_classes
+         for idx in range(self.num_prototypes):
+             self.prototype_class_identity[idx, idx // num_prototypes_per_class] = 1
+
+         features_name = str(self.features).upper()
+         if features_name.startswith("VGG") or features_name.startswith("RES"):
+             in_channels = [m for m in features.modules() if isinstance(m, nn.Conv2d)][-1].out_channels
+         elif features_name.startswith("DENSE"):
+             in_channels = [m for m in features.modules() if isinstance(m, nn.BatchNorm2d)][-1].num_features
+         else:
+             raise ValueError("Unsupported base architecture.")
+
+         if add_on_layers_type == "bottleneck":
+             add_on_layers: list[nn.Module] = []
+             current_in_channels = in_channels
+             while current_in_channels > self.prototype_shape[1] or not add_on_layers:
+                 current_out_channels = max(self.prototype_shape[1], current_in_channels // 2)
+                 add_on_layers.append(
+                     nn.Conv2d(current_in_channels, current_out_channels, kernel_size=1)
+                 )
+                 add_on_layers.append(nn.ReLU())
+                 add_on_layers.append(
+                     nn.Conv2d(current_out_channels, current_out_channels, kernel_size=1)
+                 )
+                 if current_out_channels > self.prototype_shape[1]:
+                     add_on_layers.append(nn.ReLU())
+                 else:
+                     add_on_layers.append(nn.Sigmoid())
+                 current_in_channels //= 2
+             self.add_on_layers = nn.Sequential(*add_on_layers)
+         else:
+             self.add_on_layers = nn.Sequential(
+                 nn.Conv2d(in_channels, self.prototype_shape[1], kernel_size=1),
+                 nn.ReLU(),
+                 nn.Conv2d(self.prototype_shape[1], self.prototype_shape[1], kernel_size=1),
+                 nn.Sigmoid(),
+             )
+
+         self.prototype_vectors = nn.Parameter(torch.rand(self.prototype_shape), requires_grad=True)
+         self.ones = nn.Parameter(torch.ones(self.prototype_shape), requires_grad=False)
+         self.last_layer = nn.Linear(self.num_prototypes, self.num_classes, bias=False)
+
+         if init_weights:
+             self._initialize_weights()
+
+     def conv_features(self, x: torch.Tensor) -> torch.Tensor:
+         return self.add_on_layers(self.features(x))
+
+     def _l2_convolution(self, x: torch.Tensor) -> torch.Tensor:
+         x2_patch_sum = F.conv2d(input=x**2, weight=self.ones)
+         p2 = torch.sum(self.prototype_vectors**2, dim=(1, 2, 3)).view(-1, 1, 1)
+         xp = F.conv2d(input=x, weight=self.prototype_vectors)
+         distances = F.relu(x2_patch_sum - 2 * xp + p2)
+         return distances
+
+     def prototype_distances(self, x: torch.Tensor) -> torch.Tensor:
+         return self._l2_convolution(self.conv_features(x))
+
+     def distance_2_similarity(self, distances: torch.Tensor) -> torch.Tensor:
+         if self.prototype_activation_function == "log":
+             return torch.log((distances + 1) / (distances + self.epsilon))
+         if self.prototype_activation_function == "linear":
+             return -distances
+         raise ValueError(
+             f"Unsupported prototype activation function: {self.prototype_activation_function}"
+         )
+
+     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         distances = self.prototype_distances(x)
+         min_distances = -F.max_pool2d(
+             -distances, kernel_size=(distances.size(2), distances.size(3))
+         )
+         min_distances = min_distances.view(-1, self.num_prototypes)
+         prototype_activations = self.distance_2_similarity(min_distances)
+         logits = self.last_layer(prototype_activations)
+         return logits, min_distances
+
+     def set_last_layer_incorrect_connection(self, incorrect_strength: float) -> None:
+         positive_locs = torch.t(self.prototype_class_identity)
+         negative_locs = 1 - positive_locs
+         self.last_layer.weight.data.copy_(positive_locs + incorrect_strength * negative_locs)
+
+     def _initialize_weights(self) -> None:
+         for module in self.add_on_layers.modules():
+             if isinstance(module, nn.Conv2d):
+                 nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
+                 if module.bias is not None:
+                     nn.init.constant_(module.bias, 0)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nn.init.constant_(module.weight, 1)
+                 nn.init.constant_(module.bias, 0)
+
+         self.set_last_layer_incorrect_connection(incorrect_strength=-0.5)
+
+
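`_l2_convolution` relies on the expansion of the squared distance into three terms, which lets two convolutions replace an explicit subtraction at every spatial location. The sketch below checks that identity and the `"log"` similarity on a single location using plain Python lists. It is a toy illustration with made-up values, not the tensor code from the commit, and the helper names are hypothetical.

```python
import math

EPSILON = 1e-4  # same constant as PPNet.epsilon


def squared_distance_expanded(x, p):
    """||x - p||^2 computed as ||x||^2 - 2 x.p + ||p||^2, clamped at zero."""
    x2 = sum(v * v for v in x)
    p2 = sum(v * v for v in p)
    xp = sum(a * b for a, b in zip(x, p))
    # The clamp mirrors F.relu, which guards against tiny negative round-off.
    return max(x2 - 2 * xp + p2, 0.0)


def log_similarity(distance):
    """Large when the distance is near zero, approaching zero as it grows."""
    return math.log((distance + 1) / (distance + EPSILON))


patch = [0.2, 0.9, 0.4]       # one spatial location of the feature map
prototype = [0.1, 0.7, 0.4]   # one 1x1 prototype vector

direct = sum((a - b) ** 2 for a, b in zip(patch, prototype))
expanded = squared_distance_expanded(patch, prototype)
print(math.isclose(direct, expanded, abs_tol=1e-12))
print(log_similarity(0.0) > log_similarity(expanded) > 0.0)
```

In `forward`, the minimum distance over all locations is taken by max-pooling the negated distances, so each prototype contributes the similarity of its best-matching patch.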
+ def build_ppnet(
+     *,
+     base_architecture: str,
+     img_size: int,
+     prototype_shape: tuple[int, int, int, int],
+     num_classes: int,
+     prototype_activation_function: str,
+     add_on_layers_type: str,
+ ) -> PPNet:
+     features = build_vgg_features(base_architecture)
+     layer_filter_sizes, layer_strides, layer_paddings = features.conv_info()
+     proto_layer_rf_info = compute_proto_layer_rf_info_v2(
+         img_size=img_size,
+         layer_filter_sizes=layer_filter_sizes,
+         layer_strides=layer_strides,
+         layer_paddings=layer_paddings,
+         prototype_kernel_size=prototype_shape[2],
+     )
+     return PPNet(
+         features=features,
+         img_size=img_size,
+         prototype_shape=prototype_shape,
+         proto_layer_rf_info=proto_layer_rf_info,
+         num_classes=num_classes,
+         init_weights=True,
+         prototype_activation_function=prototype_activation_function,
+         add_on_layers_type=add_on_layers_type,
+     )
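Because `build_ppnet` always passes `init_weights=True`, the last layer starts from the pattern written by `set_last_layer_incorrect_connection(-0.5)` before `load_state_dict` in `model_service.py` replaces it with the checkpoint values. The sketch below reproduces that starting pattern with plain lists for the default shape of 40 prototypes and 2 classes. `initial_last_layer` is a hypothetical helper written for illustration.

```python
def initial_last_layer(num_prototypes: int, num_classes: int,
                       incorrect_strength: float = -0.5) -> list[list[float]]:
    """Weights of shape [num_classes][num_prototypes], as plain lists."""
    if num_prototypes % num_classes != 0:
        raise ValueError("Number of prototypes must be divisible by num_classes.")
    per_class = num_prototypes // num_classes
    return [
        # 1.0 where the prototype belongs to the class, incorrect_strength elsewhere.
        [1.0 if proto // per_class == cls else incorrect_strength
         for proto in range(num_prototypes)]
        for cls in range(num_classes)
    ]


weights = initial_last_layer(num_prototypes=40, num_classes=2)
print(weights[0][:3], weights[0][-3:])
print(weights[1][:3], weights[1][-3:])
```

With this layout the first 20 prototypes are assigned to class index 0 (`no_person`) and the last 20 to class index 1 (`person`), following the order of `CLASS_NAMES`.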