# Commit 2d788b3 by DariusGiannoli — "fix: batch of bug fixes and missing features"
"""
src/models.py β€” Central Model Registry
=========================================
Downloads backbone weights **once** from the internet (PyTorch Hub / timm),
freezes every feature-extraction layer, and caches the result in RAM with
Streamlit's ``@st.cache_resource``.
Strategy
--------
1. **Freeze the Backbone** β†’ ``requires_grad = False`` on every parameter.
The backbone is a pure feature extractor β€” no gradient updates, ever.
2. **Cache the Resource** β†’ ``@st.cache_resource`` keeps the heavy model
in RAM even when you switch pages.
3. **Define the Head** β†’ ``RecognitionHead``: a tiny sklearn
LogisticRegression that takes the backbone's feature vector and
produces a recognition score. Lives only in ``st.session_state``.
"""
import streamlit as st
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import timm
import cv2
import numpy as np
# ---------------------------------------------------------------------------
# Device selection (MPS > CUDA > CPU)
# ---------------------------------------------------------------------------
# Prefer Apple-Silicon MPS, then CUDA, then plain CPU.
DEVICE = (
    "mps" if torch.backends.mps.is_available() else
    "cuda" if torch.cuda.is_available() else
    "cpu"
)
# ---------------------------------------------------------------------------
# Shared ImageNet preprocessing
# ---------------------------------------------------------------------------
# Callers convert BGR->RGB first; this pipeline takes an RGB uint8 ndarray
# and emits a normalised 3x224x224 float tensor.
_IMAGENET_TRANSFORM = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet channel means
                         std=[0.229, 0.224, 0.225]),  # ImageNet channel stds
])
# ===================================================================
# Base class
# ===================================================================
class _FrozenBackbone:
"""Shared helpers: freeze, normalise activation maps."""
DIM: int = 0 # overridden by subclasses
# --- freeze every parameter ---
def _freeze(self, model: nn.Module) -> nn.Module:
model.eval()
for p in model.parameters():
p.requires_grad = False
return model.to(DEVICE)
# --- public interface ---
def get_features(self, img_bgr: np.ndarray) -> np.ndarray:
"""Return a 1-D float32 feature vector for *img_bgr* (BGR uint8)."""
raise NotImplementedError
def get_activation_maps(self, img_bgr: np.ndarray,
n_maps: int = 6) -> list[np.ndarray]:
"""Return *n_maps* normalised float32 spatial activation maps."""
raise NotImplementedError
@staticmethod
def _norm(m: np.ndarray) -> np.ndarray:
lo, hi = m.min(), m.max()
return ((m - lo) / (hi - lo + 1e-5)).astype(np.float32)
# ===================================================================
# ResNet-18
# ===================================================================
class ResNet18Backbone(_FrozenBackbone):
    """ResNet-18 downloaded from PyTorch Hub, frozen, classifier removed."""

    DIM = 512

    def __init__(self):
        full = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.backbone = self._freeze(full)
        # All children except the final fc layer; ends at the global avgpool,
        # so the extractor outputs a (1, 512, 1, 1) map.
        self.extractor = nn.Sequential(*list(full.children())[:-1]).to(DEVICE)
        self.transform = _IMAGENET_TRANSFORM

    def get_features(self, img_bgr):
        """Return the 512-D embedding for *img_bgr* (BGR uint8)."""
        t = self.transform(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
        with torch.no_grad():
            return self.extractor(t.unsqueeze(0).to(DEVICE)).cpu().numpy().flatten()

    def get_activation_maps(self, img_bgr, n_maps=6):
        """Return up to *n_maps* normalised channel maps captured at layer4."""
        cap = {}
        hook = self.backbone.layer4.register_forward_hook(
            lambda _m, _i, o: cap.update(feat=o))
        try:
            t = self.transform(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
            with torch.no_grad():
                self.backbone(t.unsqueeze(0).to(DEVICE))
        finally:
            # Always detach the hook, even if preprocessing or the forward
            # pass raises, so repeated calls don't accumulate hooks.
            hook.remove()
        acts = cap["feat"][0].cpu().numpy()
        return [self._norm(acts[i]) for i in range(min(n_maps, acts.shape[0]))]
# ===================================================================
# MobileNetV3-Small
# ===================================================================
class MobileNetV3Backbone(_FrozenBackbone):
    """MobileNetV3-Small from PyTorch Hub, frozen, classifier = Identity."""

    DIM = 576

    def __init__(self):
        self.backbone = models.mobilenet_v3_small(
            weights=models.MobileNet_V3_Small_Weights.DEFAULT)
        # Identity head makes the forward pass return the pooled 576-D
        # feature vector instead of class logits.
        self.backbone.classifier = nn.Identity()
        self._freeze(self.backbone)
        self.transform = _IMAGENET_TRANSFORM

    def get_features(self, img_bgr):
        """Return the 576-D embedding for *img_bgr* (BGR uint8)."""
        t = self.transform(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
        with torch.no_grad():
            return self.backbone(t.unsqueeze(0).to(DEVICE)).cpu().numpy().flatten()

    def get_activation_maps(self, img_bgr, n_maps=6):
        """Return up to *n_maps* normalised channel maps from features[-1]."""
        cap = {}
        hook = self.backbone.features[-1].register_forward_hook(
            lambda _m, _i, o: cap.update(feat=o))
        try:
            t = self.transform(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
            with torch.no_grad():
                self.backbone(t.unsqueeze(0).to(DEVICE))
        finally:
            # Always detach the hook, even if preprocessing or the forward
            # pass raises, so repeated calls don't accumulate hooks.
            hook.remove()
        acts = cap["feat"][0].cpu().numpy()
        return [self._norm(acts[i]) for i in range(min(n_maps, acts.shape[0]))]
# ===================================================================
# MobileViT-XXS
# ===================================================================
class MobileViTBackbone(_FrozenBackbone):
    """MobileViT-XXS from timm (Apple Research), frozen."""

    DIM = 320

    def __init__(self):
        self.backbone = timm.create_model(
            "mobilevit_xxs.cvnets_in1k", pretrained=True, num_classes=0)
        self._freeze(self.backbone)
        # Use the model's own preprocessing recipe (resize/crop/normalise).
        cfg = timm.data.resolve_model_data_config(self.backbone)
        self.transform = timm.data.create_transform(**cfg, is_training=False)

    def _to_tensor(self, img_bgr):
        """Convert a BGR uint8 image to a batched tensor on DEVICE."""
        from PIL import Image
        pil = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
        return self.transform(pil).unsqueeze(0).to(DEVICE)

    def get_features(self, img_bgr):
        """Return the 320-D pooled embedding (``num_classes=0`` drops the head)."""
        with torch.no_grad():
            return self.backbone(self._to_tensor(img_bgr)).cpu().numpy().flatten()

    def get_activation_maps(self, img_bgr, n_maps=6):
        """Return up to *n_maps* normalised channel maps from stages[-1]."""
        cap = {}
        hook = self.backbone.stages[-1].register_forward_hook(
            lambda _m, _i, o: cap.update(feat=o))
        try:
            with torch.no_grad():
                self.backbone(self._to_tensor(img_bgr))
        finally:
            # Always detach the hook, even if preprocessing or the forward
            # pass raises, so repeated calls don't accumulate hooks.
            hook.remove()
        acts = cap["feat"][0].cpu().numpy()
        return [self._norm(acts[i]) for i in range(min(n_maps, acts.shape[0]))]
# ===================================================================
# YOLOv8n Backbone (feature extractor only — no detection head)
# ===================================================================
class YOLOv8Backbone(_FrozenBackbone):
    """YOLOv8-Nano backbone (layers 0-9), frozen, used as feature extractor.

    Produces a 256-D embedding via global average pooling after SPPF.
    """

    DIM = 256

    def __init__(self):
        from ultralytics import YOLO as _YOLO
        detector = _YOLO("models/yolov8n.pt")
        # Layers 0-9 = Conv/C2f backbone + SPPF (before the detection neck)
        layers = list(detector.model.model[:10])
        self.backbone = nn.Sequential(*layers)
        self._freeze(self.backbone)
        # NOTE(review): reuses the ImageNet mean/std transform; YOLO models
        # are presumably trained with plain 0-1 scaling — confirm this
        # normalisation choice is intentional.
        self.transform = _IMAGENET_TRANSFORM
        self.pool = nn.AdaptiveAvgPool2d(1)

    def _embed(self, img_bgr):
        """Run the frozen backbone on *img_bgr* and return the SPPF map."""
        rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        batch = self.transform(rgb).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            return self.backbone(batch)

    def get_features(self, img_bgr):
        """256-D embedding: SPPF output, globally average-pooled."""
        return self.pool(self._embed(img_bgr)).cpu().numpy().flatten()

    def get_activation_maps(self, img_bgr, n_maps=6):
        """First *n_maps* SPPF channel maps, min-max normalised."""
        maps = self._embed(img_bgr)[0].cpu().numpy()
        return [self._norm(maps[i]) for i in range(min(n_maps, maps.shape[0]))]
# ===================================================================
# Lightweight Head (lives in session state, never on disk)
# ===================================================================
class RecognitionHead:
    """
    A tiny trainable layer on top of a frozen backbone.

    Wraps sklearn ``LogisticRegression`` for binary classification.
    Stored in ``st.session_state`` and optionally persisted to disk.
    """

    def __init__(self, C: float = 1.0, max_iter: int = 1000):
        # Imported lazily so the module can be imported without sklearn.
        from sklearn.linear_model import LogisticRegression
        self.model = LogisticRegression(C=C, max_iter=max_iter)
        self.is_trained = False

    def fit(self, X, y):
        """Fit on feature matrix *X* and labels *y*; returns ``self`` (fluent)."""
        self.model.fit(X, y)
        self.is_trained = True
        return self

    def predict(self, features: np.ndarray):
        """Return *(label, confidence)* for a single 1-D feature vector."""
        probs = self.model.predict_proba([features])[0]
        idx = int(np.argmax(probs))
        return self.model.classes_[idx], probs[idx]

    def predict_proba(self, X):
        """Return the class-probability matrix for feature matrix *X*."""
        return self.model.predict_proba(X)

    @property
    def classes_(self):
        """Labels known to the underlying classifier."""
        return self.model.classes_

    def save(self, path: str):
        """Persist the trained head to *path* via joblib."""
        import joblib
        from pathlib import Path
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        joblib.dump({"model": self.model, "is_trained": self.is_trained}, path)

    @classmethod
    def load(cls, path: str) -> "RecognitionHead":
        """Load a previously saved head from *path*.

        Tolerates payloads saved without ``is_trained`` by defaulting it to
        ``False`` instead of raising ``KeyError``.
        """
        import joblib
        data = joblib.load(path)
        head = cls.__new__(cls)  # skip __init__: model comes from disk
        head.model = data["model"]
        head.is_trained = data.get("is_trained", False)
        return head
# ===================================================================
# Cached loaders — @st.cache_resource keeps models in RAM
# ===================================================================
@st.cache_resource
def get_resnet() -> ResNet18Backbone:
    """Download & freeze ResNet-18 (512-D features).

    ``@st.cache_resource`` keeps the instance in RAM across page switches,
    so the weights are fetched at most once per server process.
    """
    return ResNet18Backbone()
@st.cache_resource
def get_mobilenet() -> MobileNetV3Backbone:
    """Download & freeze MobileNetV3-Small (576-D features).

    Cached in RAM by Streamlit, so construction happens at most once
    per server process.
    """
    return MobileNetV3Backbone()
@st.cache_resource
def get_mobilevit() -> MobileViTBackbone:
    """Download & freeze MobileViT-XXS (320-D features).

    Cached in RAM by Streamlit, so construction happens at most once
    per server process.
    """
    return MobileViTBackbone()
@st.cache_resource
def get_yolov8() -> YOLOv8Backbone:
    """Load & freeze the YOLOv8n backbone (256-D features).

    Reads the local checkpoint ``models/yolov8n.pt``; cached in RAM by
    Streamlit so the model is built at most once per server process.
    """
    return YOLOv8Backbone()
# ===================================================================
# BACKBONES — The Registry Dict
# ===================================================================
# Registry mapping a display name to:
#   loader     -- cached factory returning the frozen backbone instance
#   dim        -- feature-vector length (mirrors the class's DIM constant)
#   hook_layer -- human-readable label of the layer tapped for activation maps
BACKBONES = {
    "ResNet-18": {
        "loader": get_resnet,
        "dim": ResNet18Backbone.DIM,
        "hook_layer": "layer4 (last conv block)",
    },
    "MobileNetV3": {
        "loader": get_mobilenet,
        "dim": MobileNetV3Backbone.DIM,
        "hook_layer": "features[-1] (last features block)",
    },
    "MobileViT-XXS": {
        "loader": get_mobilevit,
        "dim": MobileViTBackbone.DIM,
        "hook_layer": "stages[-1] (last transformer stage)",
    },
    "YOLOv8n": {
        "loader": get_yolov8,
        "dim": YOLOv8Backbone.DIM,
        "hook_layer": "SPPF (layer 9, end of backbone)",
    },
}