# Reunite / core.py
# Commit 7e710af by Shoraky — "Complete backend folder" (16.4 kB)
"""
core.py — Deep Learning Engine for REFIND
Pipeline: InsightFace (buffalo_l / ArcFace) ➜ 512-d embedding ➜ FAISS IndexFlatIP
No database. All state lives in Storage/ (registry.csv + embeddings_map.json, plus photo.jpg and embedding.npy per person).
"""
from __future__ import annotations
import json
import io
import os
import shutil
import uuid
import warnings
from datetime import datetime
from pathlib import Path
import time
import logging
import contextlib
from typing import Optional
import cv2
import faiss
import numpy as np
import pandas as pd
# ─────────────────────────────────────────────────────────────────────────────
# Storage Layout (relative to the directory containing this file)
#   Storage/
#     persons/{ID}/photo.jpg      — uploaded photo, re-encoded as JPEG
#     persons/{ID}/embedding.npy  — 512-d L2-norm ArcFace vector
#     registry.csv                — metadata for all persons
#     embeddings_map.json         — {id: abs_path_to_embedding}
#   Weights/                      — InsightFace / ArcFace model cache
# ─────────────────────────────────────────────────────────────────────────────
BASE_DIR = Path(__file__).parent
STORAGE_DIR = BASE_DIR / "Storage"
WEIGHTS_DIR = BASE_DIR / "Weights"
PERSONS_DIR = STORAGE_DIR / "persons"
REGISTRY_CSV = STORAGE_DIR / "registry.csv"
EMB_MAP_PATH = STORAGE_DIR / "embeddings_map.json"

# Length of every ArcFace embedding vector.
EMBEDDING_DIM = 512

# Cosine-similarity cut-offs. Stored vectors are L2-normalised, so the FAISS
# inner product equals cosine similarity (higher = more alike).
THRESH_VERY_HIGH = 0.68
THRESH_HIGH = 0.52
THRESH_MEDIUM = 0.38  # default search cut-off

# Column order of registry.csv.
CSV_COLUMNS = (
    "id name age gender "
    "last_seen_date last_seen_location "
    "phone_contact address national_id "
    "description registered_at status"
).split()
# ─────────────────────────────────────────────────────────────────────────────
# Directory bootstrap
# ─────────────────────────────────────────────────────────────────────────────
def ensure_dirs() -> None:
    """Create the storage, persons and model-cache directories if missing.

    Safe to call repeatedly: directories that already exist are left alone.
    """
    required = (STORAGE_DIR, PERSONS_DIR, WEIGHTS_DIR)
    for directory in required:
        directory.mkdir(parents=True, exist_ok=True)
# ─────────────────────────────────────────────────────────────────────────────
# InsightFace / ArcFace — lazy singleton
# ─────────────────────────────────────────────────────────────────────────────
_face_app = None  # prepared FaceAnalysis instance; None until first successful init


def get_face_app():
    """Return the shared InsightFace ``FaceAnalysis`` instance, creating it on first use.

    The ``buffalo_l`` model pack is loaded with the CPU execution provider only
    (``ctx_id=-1``). There is no GPU/CUDA path, even if a GPU is present.
    Model weights are cached under ``Weights/``, so the first call may be slow.

    Python-level stdout/stderr produced while importing and preparing InsightFace
    is captured and discarded. If initialisation fails, the captured text is
    logged, the exception propagates, and the cache stays empty. A later call
    then retries instead of returning a half-initialised object.

    Returns:
        The prepared ``insightface.app.FaceAnalysis`` instance.
    """
    global _face_app
    if _face_app is not None:
        return _face_app

    # Quieten ONNX Runtime / InsightFace; setdefault keeps any value the
    # operator has already configured in the environment.
    os.environ.setdefault("ORT_LOG_SEVERITY_LEVEL", "3")
    os.environ.setdefault("INSIGHTFACE_LOG_LEVEL", "ERROR")
    logging.getLogger("onnxruntime").setLevel(logging.ERROR)
    logging.getLogger("insightface").setLevel(logging.ERROR)
    warnings.filterwarnings(
        "ignore",
        message=r"`rcond` parameter will change to the default.*",
        category=FutureWarning,
    )

    sink = io.StringIO()
    try:
        with contextlib.redirect_stdout(sink), contextlib.redirect_stderr(sink):
            from insightface.app import FaceAnalysis

            app = FaceAnalysis(
                name="buffalo_l",
                root=str(WEIGHTS_DIR),
                providers=["CPUExecutionProvider"],
            )
            # ctx_id=-1 forces CPU mode
            app.prepare(ctx_id=-1, det_size=(640, 640))
    except Exception:
        # Don't lose the suppressed output when something goes wrong.
        logging.getLogger(__name__).error(
            "InsightFace initialisation failed; captured output:\n%s",
            sink.getvalue(),
        )
        raise

    # Publish the singleton only once prepare() has succeeded.
    _face_app = app
    return _face_app
# ─────────────────────────────────────────────────────────────────────────────
# Embedding Extraction
# ─────────────────────────────────────────────────────────────────────────────
def extract_embedding(image_bytes: bytes) -> Optional[np.ndarray]:
    """Return the ArcFace embedding of the largest face in an encoded image.

    Pipeline: decode ``image_bytes`` with OpenCV (colour) → detect all faces →
    pick the face with the largest bounding-box area → return its
    ``normed_embedding`` (already L2-normalised by InsightFace) as float32.

    Args:
        image_bytes: Raw bytes of an encoded image file (e.g. a JPEG/PNG upload).

    Returns:
        The embedding vector, or ``None`` if the bytes are empty, cannot be
        decoded as an image, or contain no detectable face.
    """
    # cv2.imdecode raises (instead of returning None) on an empty buffer, so
    # reject empty input up front to honour the "None on bad input" contract.
    if not image_bytes:
        return None
    arr = np.frombuffer(image_bytes, np.uint8)
    try:
        img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    except cv2.error:
        return None
    if img is None:
        return None

    faces = get_face_app().get(img)
    if not faces:
        return None

    def _area(face) -> float:
        # bbox is (x1, y1, x2, y2)
        x1, y1, x2, y2 = face.bbox[:4]
        return float((x2 - x1) * (y2 - y1))

    # The largest face is presumably the photo's subject.
    best = max(faces, key=_area)
    return best.normed_embedding.astype(np.float32)
# ─────────────────────────────────────────────────────────────────────────────
# CSV Registry helpers
# ─────────────────────────────────────────────────────────────────────────────
def load_registry() -> pd.DataFrame:
    """Load ``registry.csv`` as a DataFrame of strings.

    If the file does not exist yet, an empty registry with ``CSV_COLUMNS`` is
    written to disk and returned. For an existing file, every column is read
    as ``str`` and missing cells become ``""``. Any column from ``CSV_COLUMNS``
    absent from the file (e.g. a registry written before that column existed)
    is added filled with ``""``, so callers can index every expected column.
    """
    if not REGISTRY_CSV.exists():
        df = pd.DataFrame(columns=CSV_COLUMNS)
        df.to_csv(REGISTRY_CSV, index=False)
        return df
    df = pd.read_csv(REGISTRY_CSV, dtype=str).fillna("")
    for column in CSV_COLUMNS:
        if column not in df.columns:
            df[column] = ""
    return df
def save_registry(df: pd.DataFrame) -> None:
    """Write ``df`` (without its index) to ``registry.csv``, replacing the file atomically.

    The CSV goes to a sibling ``.tmp`` file first and is then moved over the
    real file with ``os.replace``. An interrupted write therefore never leaves
    a truncated registry behind. The rename is atomic on POSIX because both
    files share a directory.
    """
    tmp_path = REGISTRY_CSV.with_name(REGISTRY_CSV.name + ".tmp")
    try:
        df.to_csv(tmp_path, index=False)
        os.replace(tmp_path, REGISTRY_CSV)
    except BaseException:
        # Best-effort cleanup of the partial temp file; keep the original error.
        with contextlib.suppress(OSError):
            tmp_path.unlink()
        raise
# ─────────────────────────────────────────────────────────────────────────────
# Embeddings map helpers {person_id: abs_path_to_npy}
# ─────────────────────────────────────────────────────────────────────────────
def load_emb_map() -> dict:
    """Read ``embeddings_map.json`` and return the parsed mapping.

    Returns an empty dict when the file has not been created yet.
    """
    if EMB_MAP_PATH.exists():
        with open(EMB_MAP_PATH) as handle:
            return json.load(handle)
    return {}
def save_emb_map(mapping: dict) -> None:
    """Write ``mapping`` as JSON to ``embeddings_map.json``, replacing the file atomically.

    The JSON goes to a sibling ``.tmp`` file and is then moved into place with
    ``os.replace``. A crash mid-write therefore cannot leave a truncated,
    unparseable map, which would break every later search.
    """
    tmp_path = EMB_MAP_PATH.with_name(EMB_MAP_PATH.name + ".tmp")
    try:
        # json.dump escapes non-ASCII by default, so the file is plain ASCII.
        with open(tmp_path, "w", encoding="utf-8") as fh:
            json.dump(mapping, fh)
        os.replace(tmp_path, EMB_MAP_PATH)
    except BaseException:
        # Best-effort cleanup of the partial temp file; keep the original error.
        with contextlib.suppress(OSError):
            tmp_path.unlink()
        raise
# ─────────────────────────────────────────────────────────────────────────────
# FAISS index — rebuilt on each search call
# Every call re-reads one .npy file per registered person, so its cost grows
# linearly with the registry and is likely dominated by file I/O.
# For larger datasets, cache the index in memory, or swap IndexFlatIP for
# IndexIVFFlat + periodic training.
# ─────────────────────────────────────────────────────────────────────────────
def build_faiss_index() -> tuple[faiss.Index, list[str]]:
    """Load all stored embeddings into a fresh FAISS ``IndexFlatIP``.

    Inner product on L2-normalised vectors == cosine similarity.

    Each embedding is read from the path recorded in ``embeddings_map.json``.
    If that path no longer exists, the canonical
    ``Storage/persons/{ID}/embedding.npy`` is tried instead. Persons with no
    readable file are skipped. Vectors whose size is not ``EMBEDDING_DIM`` are
    skipped with a warning, so one bad file does not break the whole search.

    Returns:
        ``(index, ordered_ids)`` where ``ordered_ids[i]`` is the person ID for
        index row ``i``.
    """
    log = logging.getLogger(__name__)
    ids: list[str] = []
    vecs: list[np.ndarray] = []
    for pid, emb_path in load_emb_map().items():
        path = Path(emb_path)
        if not path.exists():
            # The map stores absolute paths, which go stale if the project
            # directory is moved; fall back to the canonical location.
            path = PERSONS_DIR / pid / "embedding.npy"
            if not path.exists():
                continue
        vec = np.load(str(path)).astype(np.float32).reshape(-1)
        if vec.shape[0] != EMBEDDING_DIM:
            log.warning(
                "Skipping embedding for %s: expected %d values, got %d",
                pid, EMBEDDING_DIM, vec.shape[0],
            )
            continue
        vecs.append(vec)
        ids.append(pid)

    index = faiss.IndexFlatIP(EMBEDDING_DIM)
    if vecs:
        # FAISS requires a C-contiguous float32 (n, d) matrix.
        index.add(np.ascontiguousarray(np.stack(vecs), dtype=np.float32))
    return index, ids
# ─────────────────────────────────────────────────────────────────────────────
# Confidence label helper
# ─────────────────────────────────────────────────────────────────────────────
def confidence_label(sim: float) -> str:
    """Map a cosine similarity to a human-readable confidence bucket.

    ``sim >= THRESH_VERY_HIGH`` gives "Very High" and ``sim >= THRESH_HIGH``
    gives "High". Anything lower, including NaN, gives "Medium".
    """
    buckets = ((THRESH_VERY_HIGH, "Very High"), (THRESH_HIGH, "High"))
    for cutoff, label in buckets:
        if sim >= cutoff:
            return label
    return "Medium"
# ─────────────────────────────────────────────────────────────────────────────
# Public API
# ─────────────────────────────────────────────────────────────────────────────
def _allocate_person_dir() -> tuple[Optional[str], Optional[Path]]:
    """Create a brand-new person directory; return ``(id, dir)`` or ``(None, None)``."""
    # IDs are only 8 hex chars (32 bits), so collisions are possible in large
    # registries; exist_ok=False makes a collision detectable instead of
    # silently overwriting another person's files.
    for _ in range(16):
        person_id = uuid.uuid4().hex[:8].upper()
        person_dir = PERSONS_DIR / person_id
        try:
            person_dir.mkdir(parents=True, exist_ok=False)
        except FileExistsError:
            continue
        return person_id, person_dir
    return None, None


def _write_person_files(person_dir: Path, image_bytes: bytes, emb: np.ndarray) -> Optional[str]:
    """Write photo.jpg and embedding.npy; return the embedding path, or None if the photo write failed."""
    img = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
    if not cv2.imwrite(str(person_dir / "photo.jpg"), img, [cv2.IMWRITE_JPEG_QUALITY, 92]):
        return None
    emb_path = str(person_dir / "embedding.npy")
    np.save(emb_path, emb)
    return emb_path


def register_missing_person(image_bytes: bytes, details: dict) -> dict:
    """
    Full registration pipeline
    ──────────────────────────
    1. Decode + detect face → extract ArcFace embedding
    2. Allocate a new 8-char uppercase hex ID (prefix of a UUID4), retrying
       if a person directory with that ID already exists
    3. Write photo.jpg (JPEG, quality 92) to Storage/persons/{ID}/
    4. Write embedding.npy to Storage/persons/{ID}/
    5. Append row to registry.csv (status "missing")
    6. Update embeddings_map.json (used by build_faiss_index)

    Args:
        image_bytes: Encoded image upload.
        details: Person metadata. Keys read: name, age, gender,
            last_seen_date, last_seen_location, phone_contact, address,
            national_id, description. Missing keys become "", and a missing,
            None or blank name becomes "Unknown".

    Returns:
        ``{"success": True, "id": <ID>, "timing_ms": {...}}`` on success, or
        ``{"success": False, "error": <message>}`` when no face is found, no
        unique ID could be allocated, or the photo could not be written.

    Raises:
        Errors from writing the per-person files propagate after the
        partially written person directory has been removed.
    """
    ensure_dirs()
    t0 = time.perf_counter()

    t_emb0 = time.perf_counter()
    emb = extract_embedding(image_bytes)
    t_emb1 = time.perf_counter()
    if emb is None:
        return {
            "success": False,
            "error": (
                "No face detected. Please upload a clear, well-lit photo "
                "showing the person's face without occlusion."
            ),
        }

    t_io0 = time.perf_counter()
    person_id, person_dir = _allocate_person_dir()
    if person_id is None:
        return {
            "success": False,
            "error": "Could not allocate a unique person ID. Please try again.",
        }

    # ── Photo + embedding: on failure don't leave an orphan directory ─────
    try:
        emb_path = _write_person_files(person_dir, image_bytes, emb)
    except BaseException:
        shutil.rmtree(person_dir, ignore_errors=True)
        raise
    if emb_path is None:
        shutil.rmtree(person_dir, ignore_errors=True)
        return {"success": False, "error": "Failed to save the uploaded photo."}

    # ── Registry CSV ───────────────────────────────────────────────────────
    raw_name = details.get("name")
    name = "" if raw_name is None else str(raw_name).strip()
    row = {
        "id": person_id,
        "name": name or "Unknown",
        "age": details.get("age", ""),
        "gender": details.get("gender", ""),
        "last_seen_date": details.get("last_seen_date", ""),
        "last_seen_location": details.get("last_seen_location", ""),
        "phone_contact": details.get("phone_contact", ""),
        "address": details.get("address", ""),
        "national_id": details.get("national_id", ""),
        "description": details.get("description", ""),
        "registered_at": datetime.now().strftime("%Y-%m-%d %H:%M"),
        "status": "missing",
    }
    df = load_registry()
    df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
    save_registry(df)

    # ── Embeddings map ─────────────────────────────────────────────────────
    emb_map = load_emb_map()
    emb_map[person_id] = emb_path
    save_emb_map(emb_map)
    t_io1 = time.perf_counter()

    t1 = time.perf_counter()
    timing_ms = {
        "embedding_ms": round((t_emb1 - t_emb0) * 1000, 1),
        "io_ms": round((t_io1 - t_io0) * 1000, 1),
        "total_ms": round((t1 - t0) * 1000, 1),
    }
    return {"success": True, "id": person_id, "timing_ms": timing_ms}
def search_person(
    image_bytes: bytes,
    top_k: int = 5,
    threshold: float = THRESH_MEDIUM,
) -> dict:
    """
    Search pipeline
    ───────────────
    1. Extract ArcFace embedding from query image
    2. Build FAISS IndexFlatIP from all stored embeddings
    3. k-NN inner-product search (cosine similarity)
    4. Filter results below threshold, rank descending, enrich with metadata

    Args:
        image_bytes: Encoded query image.
        top_k: Maximum number of nearest neighbours to consider; values < 1
            yield no matches.
        threshold: Minimum cosine similarity on the raw scale (not percent),
            e.g. 0.38, that a candidate must reach.

    Returns:
        ``{"success": True, "matches": [...]}`` (plus ``"message"`` when the
        registry is empty), or ``{"success": False, "error": ...}`` when no face
        is found in the query image. Each match holds id, similarity (cosine × 100,
        rounded to 0.1), confidence label and the person's registry metadata
        (national_id is not included).
    """
    ensure_dirs()
    q_emb = extract_embedding(image_bytes)
    if q_emb is None:
        return {"success": False, "error": "No face detected in the search image."}

    index, ordered_ids = build_faiss_index()
    if index.ntotal == 0:
        return {"success": True, "matches": [], "message": "Registry is currently empty."}
    # FAISS rejects k < 1, so a non-positive top_k simply means "no candidates".
    if top_k < 1:
        return {"success": True, "matches": []}

    k = min(top_k, index.ntotal)
    query = np.ascontiguousarray(q_emb.reshape(1, -1), dtype=np.float32)
    scores, indices = index.search(query, k)

    # id -> first registry row with that id, built once instead of filtering
    # the DataFrame for every hit.
    rows_by_id: dict = {}
    for record in load_registry().to_dict(orient="records"):
        rows_by_id.setdefault(record["id"], record)

    matches = []
    for score, idx in zip(scores[0], indices[0]):
        sim = float(score)
        # idx == -1 is FAISS's "no result" marker.
        if idx < 0 or sim < threshold:
            continue
        pid = ordered_ids[idx]
        p = rows_by_id.get(pid)
        if p is None:
            continue  # embedding without a registry row
        matches.append({
            "id": pid,
            "similarity": round(sim * 100, 1),  # as percentage
            "confidence": confidence_label(sim),
            "name": p["name"],
            "age": p["age"],
            "gender": p["gender"],
            "last_seen_date": p["last_seen_date"],
            "last_seen_location": p["last_seen_location"],
            "phone_contact": p["phone_contact"],
            "address": p["address"],
            "description": p["description"],
            "registered_at": p["registered_at"],
            "status": p.get("status", "missing"),
        })
    matches.sort(key=lambda x: x["similarity"], reverse=True)
    return {"success": True, "matches": matches}
def get_all_persons() -> list[dict]:
    """Return every registry row as a dict keyed by column name."""
    ensure_dirs()
    registry = load_registry()
    return registry.to_dict(orient="records")
def delete_person(person_id: str) -> dict:
    """Remove a person from the registry, the embeddings map and disk.

    If ``person_id`` is not in the registry nothing is touched and
    ``{"success": False, "error": ...}`` is returned. Otherwise the matching
    registry row(s), the embeddings-map entry and ``Storage/persons/{ID}`` are
    removed, and ``{"success": True, "message": ...}`` is returned.
    """
    ensure_dirs()
    registry = load_registry()
    # Only IDs already present in the registry reach the filesystem path below.
    if person_id not in registry["id"].values:
        return {"success": False, "error": "Person not found in registry."}

    save_registry(registry.loc[registry["id"] != person_id])

    mapping = load_emb_map()
    mapping.pop(person_id, None)
    save_emb_map(mapping)

    target = PERSONS_DIR / person_id
    if target.exists():
        shutil.rmtree(str(target))
    return {"success": True, "message": f"Person {person_id} removed from registry."}
def update_person_status(person_id: str, status: str) -> dict:
    """Set the ``status`` field of a registered person.

    Expected values are 'missing' or 'found' (registration stores 'missing').
    NOTE(review): ``status`` is not validated here; confirm callers restrict it.

    Returns:
        ``{"success": True}``, or ``{"success": False, "error": ...}`` if
        ``person_id`` is not in the registry.
    """
    # Like the other public entry points, create Storage/ first: on a fresh
    # install load_registry() writes a new registry.csv, which fails if the
    # directory does not exist yet.
    ensure_dirs()
    df = load_registry()
    if person_id not in df["id"].values:
        return {"success": False, "error": "Person not found."}
    df.loc[df["id"] == person_id, "status"] = status
    save_registry(df)
    return {"success": True}