"""
src/models.py — Central Model Registry
=========================================
Downloads backbone weights **once** from the internet (PyTorch Hub / timm),
freezes every feature-extraction layer, and caches the result in RAM with
Streamlit's ``@st.cache_resource``.
Strategy
--------
1. **Freeze the Backbone** — ``requires_grad = False`` on every parameter.
   The backbone is a pure feature extractor — no gradient updates, ever.
2. **Cache the Resource** — ``@st.cache_resource`` keeps the heavy model
   in RAM even when you switch pages.
3. **Define the Head** — ``RecognitionHead``: a tiny sklearn
   LogisticRegression that takes the backbone's feature vector and
   produces a recognition score. Lives only in ``st.session_state``.
"""
import streamlit as st
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import timm
import cv2
import numpy as np
# ---------------------------------------------------------------------------
# Device selection (MPS > CUDA > CPU)
# ---------------------------------------------------------------------------
# Pick the best available accelerator, preferring Apple's Metal backend,
# then CUDA, and falling back to plain CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# ---------------------------------------------------------------------------
# Shared ImageNet preprocessing
# ---------------------------------------------------------------------------
# Standard ImageNet preprocessing shared by the torchvision/YOLO backbones:
# RGB uint8 ndarray -> PIL -> 224x224 float tensor, normalised with the
# ImageNet channel statistics the backbones were pretrained with.
_IMAGENET_TRANSFORM = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
# ===================================================================
# Base class
# ===================================================================
class _FrozenBackbone:
"""Shared helpers: freeze, normalise activation maps."""
DIM: int = 0 # overridden by subclasses
# --- freeze every parameter ---
def _freeze(self, model: nn.Module) -> nn.Module:
model.eval()
for p in model.parameters():
p.requires_grad = False
return model.to(DEVICE)
# --- public interface ---
def get_features(self, img_bgr: np.ndarray) -> np.ndarray:
"""Return a 1-D float32 feature vector for *img_bgr* (BGR uint8)."""
raise NotImplementedError
def get_activation_maps(self, img_bgr: np.ndarray,
n_maps: int = 6) -> list[np.ndarray]:
"""Return *n_maps* normalised float32 spatial activation maps."""
raise NotImplementedError
@staticmethod
def _norm(m: np.ndarray) -> np.ndarray:
lo, hi = m.min(), m.max()
return ((m - lo) / (hi - lo + 1e-5)).astype(np.float32)
# ===================================================================
# ResNet-18
# ===================================================================
class ResNet18Backbone(_FrozenBackbone):
    """ResNet-18 downloaded from PyTorch Hub, frozen, classifier removed.

    ``backbone`` keeps the full network (needed for the ``layer4`` forward
    hook); ``extractor`` reuses the same frozen modules minus the final FC
    layer and yields a 512-D embedding after global average pooling.
    """
    DIM = 512

    def __init__(self):
        full = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.backbone = self._freeze(full)
        # children()[:-1] drops only the FC classifier; avgpool is kept.
        self.extractor = nn.Sequential(*list(full.children())[:-1]).to(DEVICE)
        self.transform = _IMAGENET_TRANSFORM

    def get_features(self, img_bgr):
        """Return a flat 512-D float32 embedding for a BGR uint8 image."""
        t = self.transform(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
        with torch.no_grad():
            return self.extractor(t.unsqueeze(0).to(DEVICE)).cpu().numpy().flatten()

    def get_activation_maps(self, img_bgr, n_maps=6):
        """Return up to *n_maps* normalised activation maps from ``layer4``."""
        cap = {}
        hook = self.backbone.layer4.register_forward_hook(
            lambda _m, _i, o: cap.update(feat=o))
        try:
            t = self.transform(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
            with torch.no_grad():
                self.backbone(t.unsqueeze(0).to(DEVICE))
        finally:
            # Always detach the hook — even if preprocessing or the forward
            # pass raises — so repeated calls never accumulate stale hooks.
            hook.remove()
        acts = cap["feat"][0].cpu().numpy()
        return [self._norm(acts[i]) for i in range(min(n_maps, acts.shape[0]))]
# ===================================================================
# MobileNetV3-Small
# ===================================================================
class MobileNetV3Backbone(_FrozenBackbone):
    """MobileNetV3-Small from PyTorch Hub, frozen, classifier = Identity."""
    DIM = 576

    def __init__(self):
        net = models.mobilenet_v3_small(
            weights=models.MobileNet_V3_Small_Weights.DEFAULT)
        # Identity classifier => forward() returns the 576-D pooled features.
        net.classifier = nn.Identity()
        # Keep _freeze's return value (consistent with ResNet18Backbone).
        self.backbone = self._freeze(net)
        self.transform = _IMAGENET_TRANSFORM

    def get_features(self, img_bgr):
        """Return a flat 576-D float32 embedding for a BGR uint8 image."""
        t = self.transform(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
        with torch.no_grad():
            return self.backbone(t.unsqueeze(0).to(DEVICE)).cpu().numpy().flatten()

    def get_activation_maps(self, img_bgr, n_maps=6):
        """Return up to *n_maps* normalised maps from the last features block."""
        cap = {}
        hook = self.backbone.features[-1].register_forward_hook(
            lambda _m, _i, o: cap.update(feat=o))
        try:
            t = self.transform(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
            with torch.no_grad():
                self.backbone(t.unsqueeze(0).to(DEVICE))
        finally:
            # Always detach the hook — even if preprocessing or the forward
            # pass raises — so repeated calls never accumulate stale hooks.
            hook.remove()
        acts = cap["feat"][0].cpu().numpy()
        return [self._norm(acts[i]) for i in range(min(n_maps, acts.shape[0]))]
# ===================================================================
# MobileViT-XXS
# ===================================================================
class MobileViTBackbone(_FrozenBackbone):
    """MobileViT-XXS from timm (Apple Research), frozen."""
    DIM = 320

    def __init__(self):
        # num_classes=0 strips the classifier so forward() returns embeddings.
        self.backbone = timm.create_model(
            "mobilevit_xxs.cvnets_in1k", pretrained=True, num_classes=0)
        self._freeze(self.backbone)
        # Use the model's own pretraining transform (resolution, mean/std).
        cfg = timm.data.resolve_model_data_config(self.backbone)
        self.transform = timm.data.create_transform(**cfg, is_training=False)

    def _to_tensor(self, img_bgr):
        """BGR uint8 ndarray -> normalised 4-D batch tensor on DEVICE."""
        from PIL import Image
        pil = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
        return self.transform(pil).unsqueeze(0).to(DEVICE)

    def get_features(self, img_bgr):
        """Return a flat 320-D float32 embedding for a BGR uint8 image."""
        with torch.no_grad():
            return self.backbone(self._to_tensor(img_bgr)).cpu().numpy().flatten()

    def get_activation_maps(self, img_bgr, n_maps=6):
        """Return up to *n_maps* normalised maps from the last stage."""
        cap = {}
        hook = self.backbone.stages[-1].register_forward_hook(
            lambda _m, _i, o: cap.update(feat=o))
        try:
            with torch.no_grad():
                self.backbone(self._to_tensor(img_bgr))
        finally:
            # Always detach the hook — even if preprocessing or the forward
            # pass raises — so repeated calls never accumulate stale hooks.
            hook.remove()
        acts = cap["feat"][0].cpu().numpy()
        return [self._norm(acts[i]) for i in range(min(n_maps, acts.shape[0]))]
# ===================================================================
# YOLOv8n Backbone (feature extractor only — no detection head)
# ===================================================================
class YOLOv8Backbone(_FrozenBackbone):
    """YOLOv8-Nano backbone (layers 0-9), frozen, used as a feature extractor.

    Global-average-pools the SPPF output into a 256-D embedding.
    """
    DIM = 256

    def __init__(self):
        from ultralytics import YOLO as _YOLO
        detector = _YOLO("models/yolov8n.pt")
        # Layers 0-9 = Conv/C2f backbone + SPPF (everything before the neck).
        trunk = list(detector.model.model[:10])
        self.backbone = nn.Sequential(*trunk)
        self._freeze(self.backbone)
        self.transform = _IMAGENET_TRANSFORM
        self.pool = nn.AdaptiveAvgPool2d(1)

    def _forward(self, img_bgr):
        """Run the frozen trunk on a BGR image; returns the SPPF feature map."""
        rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        batch = self.transform(rgb).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            return self.backbone(batch)

    def get_features(self, img_bgr):
        """Return a flat 256-D float32 embedding for a BGR uint8 image."""
        feat = self._forward(img_bgr)
        return self.pool(feat).cpu().numpy().flatten()

    def get_activation_maps(self, img_bgr, n_maps=6):
        """Return up to *n_maps* normalised SPPF activation maps."""
        acts = self._forward(img_bgr)[0].cpu().numpy()
        count = min(n_maps, acts.shape[0])
        return [self._norm(acts[i]) for i in range(count)]
# ===================================================================
# Lightweight Head (lives in session state, never on disk)
# ===================================================================
class RecognitionHead:
    """
    A tiny trainable layer on top of a frozen backbone.
    Wraps sklearn ``LogisticRegression`` for binary classification.
    Stored in ``st.session_state`` and optionally persisted to disk.
    """

    def __init__(self, C: float = 1.0, max_iter: int = 1000):
        from sklearn.linear_model import LogisticRegression
        self.model = LogisticRegression(C=C, max_iter=max_iter)
        self.is_trained = False

    def fit(self, X, y):
        """Fit on feature matrix *X* and labels *y*; returns self (fluent)."""
        self.model.fit(X, y)
        self.is_trained = True
        return self

    def predict(self, features: np.ndarray):
        """Return *(label, confidence)* for a single feature vector."""
        probabilities = self.model.predict_proba([features])[0]
        best = int(np.argmax(probabilities))
        return self.model.classes_[best], probabilities[best]

    def predict_proba(self, X):
        """Class probabilities for a batch of feature vectors."""
        return self.model.predict_proba(X)

    @property
    def classes_(self):
        # Expose the underlying estimator's class labels directly.
        return self.model.classes_

    def save(self, path: str):
        """Persist the trained head to *path* via joblib."""
        import joblib
        from pathlib import Path
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        joblib.dump({"model": self.model, "is_trained": self.is_trained}, path)

    @classmethod
    def load(cls, path: str) -> "RecognitionHead":
        """Load a previously saved head from *path*."""
        import joblib
        payload = joblib.load(path)
        head = cls.__new__(cls)  # skip __init__: no fresh estimator needed
        head.model = payload["model"]
        head.is_trained = payload["is_trained"]
        return head
# ===================================================================
# Cached loaders — @st.cache_resource keeps models in RAM
# ===================================================================
@st.cache_resource
def get_resnet() -> ResNet18Backbone:
    """Download & freeze ResNet-18. Stays in RAM across page switches."""
    # cache_resource memoises the instance process-wide, so the hub
    # download and weight load happen only once per server session.
    return ResNet18Backbone()
@st.cache_resource
def get_mobilenet() -> MobileNetV3Backbone:
    """Download & freeze MobileNetV3-Small. Stays in RAM."""
    # Cached: construction/freeze cost is paid once per server session.
    return MobileNetV3Backbone()
@st.cache_resource
def get_mobilevit() -> MobileViTBackbone:
    """Download & freeze MobileViT-XXS. Stays in RAM."""
    # Cached: timm download and transform resolution happen only once.
    return MobileViTBackbone()
@st.cache_resource
def get_yolov8() -> YOLOv8Backbone:
    """Load & freeze YOLOv8n backbone. Stays in RAM."""
    # Cached: reads models/yolov8n.pt from disk only once per session.
    return YOLOv8Backbone()
# ===================================================================
# BACKBONES β The Registry Dict
# ===================================================================
# Registry consumed by the UI: display name -> cached loader, embedding
# size, and a human-readable description of the hooked layer.
BACKBONES = {
    name: {"loader": loader, "dim": cls.DIM, "hook_layer": hook}
    for name, loader, cls, hook in (
        ("ResNet-18", get_resnet, ResNet18Backbone,
         "layer4 (last conv block)"),
        ("MobileNetV3", get_mobilenet, MobileNetV3Backbone,
         "features[-1] (last features block)"),
        ("MobileViT-XXS", get_mobilevit, MobileViTBackbone,
         "stages[-1] (last transformer stage)"),
        ("YOLOv8n", get_yolov8, YOLOv8Backbone,
         "SPPF (layer 9, end of backbone)"),
    )
}