Spaces:
Sleeping
Sleeping
File size: 5,714 Bytes
"""
vision/reid.py - Cross-Camera Person Re-Identification using ViT + FAISS
"""
import os
import time
import numpy as np
import faiss
import torch
import torch.nn.functional as F
from PIL import Image
from typing import List, Dict, Optional, Tuple
from transformers import ViTImageProcessor, ViTModel
from loguru import logger
from config import settings, DEVICE, FAISS_DIR
class PersonReID:
    """
    Person re-identification gallery built on ViT embeddings and FAISS.

    Embeddings come from the Hugging Face model named by
    ``settings.REID_MODEL`` (the original notes name
    google/vit-base-patch16-224) and are stored in a FAISS IndexFlatIP.
    Because every stored and queried vector is L2-normalized, the inner
    product returned by the index equals cosine similarity.

    ``self.meta`` is a list kept parallel to the index: entry ``i`` describes
    the vector stored at FAISS row ``i``.
    """

    INDEX_FILE = str(FAISS_DIR / "reid_index.faiss")
    META_FILE = str(FAISS_DIR / "reid_meta.npy")

    def __init__(self):
        """Load the ViT model and restore (or create) the index and metadata."""
        logger.info(f"Loading ReID model: {settings.REID_MODEL}")
        self.processor = ViTImageProcessor.from_pretrained(settings.REID_MODEL)
        self.model = ViTModel.from_pretrained(settings.REID_MODEL)
        self.model.to(DEVICE)
        self.model.eval()
        self.dim = settings.REID_EMBEDDING_DIM
        self.index = self._load_or_create_index()
        # meta list: maps faiss internal id (row index) →
        # {"person_id": str, "camera_id": str, "faiss_id": int}
        self.meta: List[Dict] = self._load_meta()
        if len(self.meta) != self.index.ntotal:
            # The two files are written separately, so one can be lost or
            # stale; surface that here instead of failing later in a search.
            logger.warning(
                f"ReID metadata has {len(self.meta)} entries but the FAISS "
                f"index holds {self.index.ntotal} vectors; hits without "
                f"metadata will be skipped."
            )
        logger.info(f"✅ ReID ready. FAISS index size: {self.index.ntotal}")

    # ── Index Management ─────────────────────────────────────────────────────

    def _load_or_create_index(self) -> faiss.IndexFlatIP:
        """Return the index stored at INDEX_FILE, or a new empty IndexFlatIP."""
        if os.path.exists(self.INDEX_FILE):
            logger.info("Loading existing FAISS ReID index.")
            return faiss.read_index(self.INDEX_FILE)
        logger.info("Creating new FAISS ReID index (IndexFlatIP).")
        return faiss.IndexFlatIP(self.dim)

    def _load_meta(self) -> List[Dict]:
        """Return the metadata list stored at META_FILE, or [] if it is absent."""
        if os.path.exists(self.META_FILE):
            # SECURITY NOTE(review): allow_pickle=True unpickles arbitrary
            # objects, so META_FILE must only ever be a file this service
            # wrote itself. Consider a JSON sidecar if that cannot be
            # guaranteed.
            data = np.load(self.META_FILE, allow_pickle=True)
            return list(data)
        return []

    def save(self):
        """Write the FAISS index and the metadata list to disk."""
        # Make sure the target directories exist; a fresh deployment may not
        # have created them yet.
        for path in (self.INDEX_FILE, self.META_FILE):
            directory = os.path.dirname(path)
            if directory:
                os.makedirs(directory, exist_ok=True)
        faiss.write_index(self.index, self.INDEX_FILE)
        np.save(self.META_FILE, np.array(self.meta, dtype=object))
        logger.debug("FAISS ReID index saved.")

    # ── Embedding Extraction ─────────────────────────────────────────────────

    @torch.inference_mode()
    def extract_embedding(self, image: Image.Image) -> np.ndarray:
        """
        Extract the L2-normalized ViT CLS-token embedding of a person crop.

        Args:
            image: cropped person image; non-RGB modes are converted to RGB.

        Returns:
            float32 array of shape (1, hidden_size) — 768 for a ViT-base model.
        """
        # The processor normalizes with per-channel statistics; presumably
        # 3-channel input is required, so RGBA / grayscale / palette crops are
        # converted first. RGB images pass through untouched.
        if image.mode != "RGB":
            image = image.convert("RGB")
        inputs = self.processor(images=image, return_tensors="pt")
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        outputs = self.model(**inputs)
        # CLS token → (1, hidden_size)
        cls = outputs.last_hidden_state[:, 0, :]
        # L2 normalize for cosine similarity via inner product
        embedding = F.normalize(cls, p=2, dim=-1).cpu().numpy().astype(np.float32)
        return embedding

    # ── Gallery Operations ───────────────────────────────────────────────────

    def add_person(self, image: Image.Image, person_id: str, camera_id: str) -> int:
        """
        Add a person embedding to the gallery and persist it.

        Args:
            image: cropped person image.
            person_id: identifier stored alongside the embedding.
            camera_id: identifier of the camera that produced the crop.

        Returns:
            The FAISS row id assigned to the new embedding.
        """
        embedding = self.extract_embedding(image)
        # IndexFlatIP assigns sequential ids, so the next id is the current size.
        faiss_id = self.index.ntotal
        self.index.add(embedding)
        self.meta.append({"person_id": person_id, "camera_id": camera_id, "faiss_id": faiss_id})
        self.save()
        return faiss_id

    def search(
        self,
        image: Image.Image,
        top_k: int = 5,
        similarity_threshold: float = 0.85,
    ) -> List[Dict]:
        """
        Search the gallery for persons matching an image.

        Args:
            image: cropped person image to look up.
            top_k: maximum number of neighbours to retrieve.
            similarity_threshold: minimum cosine similarity for a match.

        Returns:
            list of {"person_id": str, "camera_id": str, "similarity": float, "faiss_id": int}
        """
        # Skip the (expensive) model forward pass when nothing can match.
        if self.index.ntotal == 0 or top_k <= 0:
            return []
        t0 = time.perf_counter()
        query = self.extract_embedding(image)
        results = self.search_by_embedding(query, top_k, similarity_threshold)
        latency = (time.perf_counter() - t0) * 1000
        logger.debug(f"ReID search: {len(results)} matches in {latency:.1f}ms")
        return results

    def search_by_embedding(
        self,
        embedding: np.ndarray,
        top_k: int = 5,
        similarity_threshold: float = 0.85,
    ) -> List[Dict]:
        """
        Search the gallery with a precomputed embedding.

        Args:
            embedding: query vector of shape (dim,) or (1, dim). It is expected
                to be L2-normalized already; any float dtype is accepted.
            top_k: maximum number of neighbours to retrieve.
            similarity_threshold: minimum cosine similarity for a match.

        Returns:
            list of {"person_id": str, "camera_id": str, "similarity": float, "faiss_id": int}
        """
        if self.index.ntotal == 0 or top_k <= 0:
            return []
        # FAISS needs a 2-D, C-contiguous float32 array.
        query = np.ascontiguousarray(np.atleast_2d(embedding), dtype=np.float32)
        k = min(top_k, self.index.ntotal)
        distances, indices = self.index.search(query, k)
        results = []
        # Only the first query row is reported, as before.
        for score, idx in zip(distances[0], indices[0]):
            idx = int(idx)
            similarity = float(score)
            # FAISS pads missing neighbours with id -1.
            if idx == -1 or similarity < similarity_threshold:
                continue
            if idx >= len(self.meta):
                logger.warning(f"ReID hit {idx} has no metadata entry; skipping.")
                continue
            meta = self.meta[idx]
            results.append({
                "person_id": meta["person_id"],
                "camera_id": meta["camera_id"],
                "similarity": round(similarity, 4),
                "faiss_id": idx,
            })
        return results
|