Spaces:
Running
Fix face recognition: BGR→RGB, robust zip extraction, schema versioning
Browse files
Two bugs made recognition impossible:
1. BGR→RGB: MobileFaceNet is trained on RGB images. _embed() was feeding
raw OpenCV BGR arrays, producing systematically wrong embeddings.
Fixed by adding cv2.cvtColor(BGR2RGB) before normalisation.
2. Zip entry path: w600k_mbf.onnx lives inside a buffalo_sc/ subdirectory
in the release zip, not at the root. The hardcoded entry name caused
extraction to fail (KeyError), so the model was never written to disk.
Fixed by searching the zip namelist for any entry ending in the filename.
Additional improvements:
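- SCHEMA_VERSION (=3) in face_db.py: load() auto-clears the DB when the
schema version stored in the file differs from the current one (the version
is bumped whenever the embedding pipeline changes), so stale BGR embeddings
from older code are discarded automatically on the first run after upgrade.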
- Threshold raised back to 0.35 (correct RGB embeddings are more consistent).
- Detection more lenient: minNeighbors 4→3, minSize 60→40.
- INFO/DEBUG logging of similarity scores for diagnosability.
- POST /clear_db endpoint + "Clear face database" button in the UI so the
user can force re-enrollment without SSH access.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- recognizer/face_db.py +60 -26
- recognizer/main.py +7 -1
- recognizer/static/index.html +5 -0
- recognizer/static/main.js +14 -0
- recognizer/static/style.css +15 -0
|
@@ -8,7 +8,8 @@ Alignment : eye-centre similarity transform to the InsightFace 112×112
|
|
| 8 |
Matching : cosine similarity on L2-normalised 512-D embeddings.
|
| 9 |
Storage : recognizer/face_db.json (gitignored).
|
| 10 |
|
| 11 |
-
|
|
|
|
| 12 |
"""
|
| 13 |
|
| 14 |
import json
|
|
@@ -24,6 +25,10 @@ import onnxruntime as ort
|
|
| 24 |
|
| 25 |
logger = logging.getLogger(__name__)
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
DB_PATH = Path(__file__).parent / "face_db.json"
|
| 28 |
MODEL_DIR = Path(__file__).parent / "models"
|
| 29 |
MODEL_FILE = MODEL_DIR / "w600k_mbf.onnx"
|
|
@@ -31,7 +36,6 @@ MODEL_URL = (
|
|
| 31 |
"https://github.com/deepinsight/insightface"
|
| 32 |
"/releases/download/v0.7/buffalo_sc.zip"
|
| 33 |
)
|
| 34 |
-
_REC_ENTRY = "w600k_mbf.onnx" # path inside the zip (root-level since buffalo_sc v0.7)
|
| 35 |
|
| 36 |
_CASCADE = cv2.CascadeClassifier(
|
| 37 |
cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
|
|
@@ -58,13 +62,20 @@ def _ensure_model() -> None:
|
|
| 58 |
return
|
| 59 |
MODEL_DIR.mkdir(exist_ok=True)
|
| 60 |
zip_path = MODEL_DIR / "buffalo_sc.zip"
|
| 61 |
-
logger.info("Downloading face recognition model (~17 MB) — one-time setup
|
| 62 |
urllib.request.urlretrieve(MODEL_URL, zip_path)
|
| 63 |
with zipfile.ZipFile(zip_path) as zf:
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
dst.write(src.read())
|
| 66 |
zip_path.unlink()
|
| 67 |
-
logger.info("Model
|
| 68 |
|
| 69 |
|
| 70 |
def _get_session() -> ort.InferenceSession:
|
|
@@ -80,22 +91,16 @@ def _get_session() -> ort.InferenceSession:
|
|
| 80 |
def _detect(frame_bgr: np.ndarray) -> list[tuple[int, int, int, int]]:
|
| 81 |
gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
|
| 82 |
boxes = _CASCADE.detectMultiScale(
|
| 83 |
-
gray, scaleFactor=1.1, minNeighbors=
|
| 84 |
)
|
| 85 |
return [tuple(b) for b in boxes] if len(boxes) > 0 else []
|
| 86 |
|
| 87 |
|
| 88 |
def _align(face_bgr: np.ndarray) -> np.ndarray:
|
| 89 |
-
"""Return a 112×112 crop aligned on eye centres; plain resize as fallback.
|
| 90 |
-
|
| 91 |
-
MobileFaceNet is trained on faces warped to a canonical eye position.
|
| 92 |
-
Without this step, embeddings from different frames of the same person
|
| 93 |
-
can be too dissimilar for reliable matching.
|
| 94 |
-
"""
|
| 95 |
gray = cv2.cvtColor(face_bgr, cv2.COLOR_BGR2GRAY)
|
| 96 |
eyes = _EYE_CASCADE.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=3)
|
| 97 |
if len(eyes) >= 2:
|
| 98 |
-
# Pick the two largest detections and sort left-to-right
|
| 99 |
eyes = sorted(eyes, key=lambda e: e[2] * e[3], reverse=True)[:2]
|
| 100 |
eyes = sorted(eyes, key=lambda e: e[0])
|
| 101 |
src = np.float32([
|
|
@@ -109,12 +114,16 @@ def _align(face_bgr: np.ndarray) -> np.ndarray:
|
|
| 109 |
|
| 110 |
|
| 111 |
def _embed(face_bgr: np.ndarray) -> np.ndarray:
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
sess = _get_session()
|
| 116 |
emb = sess.run(None, {sess.get_inputs()[0].name: inp})[0][0]
|
| 117 |
-
return emb / np.linalg.norm(emb)
|
| 118 |
|
| 119 |
|
| 120 |
# ---------------------------------------------------------------------------
|
|
@@ -122,25 +131,46 @@ def _embed(face_bgr: np.ndarray) -> np.ndarray:
|
|
| 122 |
# ---------------------------------------------------------------------------
|
| 123 |
|
| 124 |
def load() -> dict[str, list[list[float]]]:
|
| 125 |
-
"""Load face DB from disk and warm up the ONNX session.
|
| 126 |
-
|
| 127 |
-
if
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
|
| 132 |
def save(db: dict[str, list[list[float]]]) -> None:
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
|
| 136 |
def find_match(
|
| 137 |
frame_bgr: np.ndarray,
|
| 138 |
db: dict[str, list[list[float]]],
|
| 139 |
-
threshold: float = 0.
|
| 140 |
) -> Optional[str]:
|
| 141 |
"""Return matched name if recognised, None if face present but unknown.
|
| 142 |
|
| 143 |
-
Raises NoFaceDetected if no face
|
| 144 |
"""
|
| 145 |
boxes = _detect(frame_bgr)
|
| 146 |
if not boxes:
|
|
@@ -156,9 +186,13 @@ def find_match(
|
|
| 156 |
if sim > best_sim:
|
| 157 |
best_sim, best_name = sim, name
|
| 158 |
|
|
|
|
| 159 |
if best_name is not None and best_sim >= threshold:
|
|
|
|
| 160 |
return best_name
|
| 161 |
-
|
|
|
|
|
|
|
| 162 |
|
| 163 |
|
| 164 |
def add_face(
|
|
|
|
| 8 |
Matching : cosine similarity on L2-normalised 512-D embeddings.
|
| 9 |
Storage : recognizer/face_db.json (gitignored).
|
| 10 |
|
| 11 |
+
Bump SCHEMA_VERSION whenever the embedding pipeline changes so that stale
|
| 12 |
+
DB entries from older code are automatically discarded on load.
|
| 13 |
"""
|
| 14 |
|
| 15 |
import json
|
|
|
|
| 25 |
|
| 26 |
logger = logging.getLogger(__name__)
|
| 27 |
|
| 28 |
+
# Bump this whenever the embedding pipeline changes (alignment, colour space,
|
| 29 |
+
# model weights, normalisation, …). Mismatched DBs are auto-cleared on load.
|
| 30 |
+
SCHEMA_VERSION = 3 # 1=plain-resize BGR 2=aligned BGR 3=aligned RGB
|
| 31 |
+
|
| 32 |
DB_PATH = Path(__file__).parent / "face_db.json"
|
| 33 |
MODEL_DIR = Path(__file__).parent / "models"
|
| 34 |
MODEL_FILE = MODEL_DIR / "w600k_mbf.onnx"
|
|
|
|
| 36 |
"https://github.com/deepinsight/insightface"
|
| 37 |
"/releases/download/v0.7/buffalo_sc.zip"
|
| 38 |
)
|
|
|
|
| 39 |
|
| 40 |
_CASCADE = cv2.CascadeClassifier(
|
| 41 |
cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
|
|
|
|
| 62 |
return
|
| 63 |
MODEL_DIR.mkdir(exist_ok=True)
|
| 64 |
zip_path = MODEL_DIR / "buffalo_sc.zip"
|
| 65 |
+
logger.info("Downloading face recognition model (~17 MB) — one-time setup…")
|
| 66 |
urllib.request.urlretrieve(MODEL_URL, zip_path)
|
| 67 |
with zipfile.ZipFile(zip_path) as zf:
|
| 68 |
+
# The file may live at root or inside a named subdirectory (e.g. buffalo_sc/).
|
| 69 |
+
matches = [n for n in zf.namelist() if n.endswith("w600k_mbf.onnx")]
|
| 70 |
+
if not matches:
|
| 71 |
+
raise RuntimeError(
|
| 72 |
+
f"w600k_mbf.onnx not found in downloaded zip. "
|
| 73 |
+
f"Available entries: {zf.namelist()}"
|
| 74 |
+
)
|
| 75 |
+
with zf.open(matches[0]) as src, open(MODEL_FILE, "wb") as dst:
|
| 76 |
dst.write(src.read())
|
| 77 |
zip_path.unlink()
|
| 78 |
+
logger.info("Model saved to %s", MODEL_FILE)
|
| 79 |
|
| 80 |
|
| 81 |
def _get_session() -> ort.InferenceSession:
|
|
|
|
| 91 |
def _detect(frame_bgr: np.ndarray) -> list[tuple[int, int, int, int]]:
    """Detect frontal faces in a BGR frame with the Haar cascade.

    Args:
        frame_bgr: Image in OpenCV BGR channel order.

    Returns:
        One ``(x, y, w, h)`` bounding box of plain Python ints per detected
        face, or an empty list when nothing is detected.
    """
    gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
    boxes = _CASCADE.detectMultiScale(
        gray, scaleFactor=1.1, minNeighbors=3, minSize=(40, 40)
    )
    # Cast each coordinate so the result matches the annotated return type
    # (the cascade hands back NumPy integer scalars). Iterating an empty
    # result yields [], so no separate "no detections" branch is needed.
    return [tuple(int(v) for v in box) for box in boxes]
|
| 97 |
|
| 98 |
|
| 99 |
def _align(face_bgr: np.ndarray) -> np.ndarray:
|
| 100 |
+
"""Return a 112×112 crop aligned on eye centres; plain resize as fallback."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
gray = cv2.cvtColor(face_bgr, cv2.COLOR_BGR2GRAY)
|
| 102 |
eyes = _EYE_CASCADE.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=3)
|
| 103 |
if len(eyes) >= 2:
|
|
|
|
| 104 |
eyes = sorted(eyes, key=lambda e: e[2] * e[3], reverse=True)[:2]
|
| 105 |
eyes = sorted(eyes, key=lambda e: e[0])
|
| 106 |
src = np.float32([
|
|
|
|
| 114 |
|
| 115 |
|
| 116 |
def _embed(face_bgr: np.ndarray) -> np.ndarray:
    """Compute the unit-length embedding of a BGR face crop.

    The crop is aligned, converted to RGB, scaled to [-1, 1], laid out as a
    single-image NCHW batch and passed through the ONNX session; the first
    output row is returned after division by its Euclidean norm.
    """
    aligned = _align(face_bgr)
    # The recognition model expects RGB input, while OpenCV images are BGR.
    rgb = cv2.cvtColor(aligned, cv2.COLOR_BGR2RGB)
    pixels = (rgb.astype(np.float32) - 127.5) / 127.5
    batch = np.transpose(pixels, (2, 0, 1))[np.newaxis]  # HWC -> 1xCxHxW
    session = _get_session()
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: batch})
    vector = outputs[0][0]
    return vector / np.linalg.norm(vector)
|
| 127 |
|
| 128 |
|
| 129 |
# ---------------------------------------------------------------------------
|
|
|
|
| 131 |
# ---------------------------------------------------------------------------
|
| 132 |
|
| 133 |
def load() -> dict[str, list[list[float]]]:
    """Load face DB from disk and warm up the ONNX session.

    Returns:
        Mapping of enrolled name to its stored embeddings, with metadata
        keys (those starting with ``_``) removed. An empty dict is returned
        when the DB file is missing, is not a valid JSON object, or carries a
        ``_schema`` value different from ``SCHEMA_VERSION``. Only in the
        schema-mismatch case is the file deleted (auto-clear); an unparsable
        file is left in place for inspection.
    """
    _get_session()  # build the inference session before the first request
    try:
        # EAFP: a single read instead of exists() followed by read_text().
        raw = json.loads(DB_PATH.read_text(encoding="utf-8"))
    except FileNotFoundError:
        return {}
    except ValueError as exc:  # covers JSONDecodeError and UnicodeDecodeError
        logger.warning(
            "face_db is not valid JSON (%s) — starting with an empty DB", exc
        )
        return {}
    if not isinstance(raw, dict):
        logger.warning(
            "face_db does not contain a JSON object — starting with an empty DB"
        )
        return {}
    if raw.get("_schema") != SCHEMA_VERSION:
        logger.warning(
            "face_db schema mismatch (file=%s expected=%s) — clearing stale embeddings",
            raw.get("_schema"), SCHEMA_VERSION,
        )
        # missing_ok: the file may already have been removed by someone else.
        DB_PATH.unlink(missing_ok=True)
        return {}
    return {k: v for k, v in raw.items() if not k.startswith("_")}
|
| 151 |
|
| 152 |
|
| 153 |
def save(db: dict[str, list[list[float]]]) -> None:
    """Write the face DB to ``DB_PATH`` as JSON, tagged with ``SCHEMA_VERSION``.

    Args:
        db: Mapping of enrolled name to embeddings. Names starting with
            ``_`` are reserved for metadata and are not written: ``load()``
            discards such keys, and a face named ``_schema`` would otherwise
            overwrite the version tag and get the whole DB cleared on the
            next load.
    """
    reserved = [name for name in db if name.startswith("_")]
    if reserved:
        logger.warning("Not saving faces with reserved names: %s", reserved)

    out: dict = {"_schema": SCHEMA_VERSION}
    out.update(
        (name, embeddings)
        for name, embeddings in db.items()
        if not name.startswith("_")
    )

    # Write to a sibling temp file and rename it over the target so that an
    # interrupted write cannot leave a truncated DB behind.
    tmp_path = DB_PATH.with_name(DB_PATH.name + ".tmp")
    tmp_path.write_text(json.dumps(out, indent=2))
    tmp_path.replace(DB_PATH)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def wipe() -> None:
    """Delete all enrolled faces from disk.

    Does nothing (beyond logging) when the DB file does not exist.
    """
    # A single unlink(missing_ok=True) replaces the exists()/unlink() pair,
    # which could raise FileNotFoundError if the file vanished in between.
    DB_PATH.unlink(missing_ok=True)
    logger.info("Face database cleared")
|
| 164 |
|
| 165 |
|
| 166 |
def find_match(
|
| 167 |
frame_bgr: np.ndarray,
|
| 168 |
db: dict[str, list[list[float]]],
|
| 169 |
+
threshold: float = 0.35,
|
| 170 |
) -> Optional[str]:
|
| 171 |
"""Return matched name if recognised, None if face present but unknown.
|
| 172 |
|
| 173 |
+
Raises NoFaceDetected if no face is detected in the frame at all.
|
| 174 |
"""
|
| 175 |
boxes = _detect(frame_bgr)
|
| 176 |
if not boxes:
|
|
|
|
| 186 |
if sim > best_sim:
|
| 187 |
best_sim, best_name = sim, name
|
| 188 |
|
| 189 |
+
logger.debug("Best match: %s sim=%.3f threshold=%.2f", best_name, best_sim, threshold)
|
| 190 |
if best_name is not None and best_sim >= threshold:
|
| 191 |
+
logger.info("Recognised: %s (sim=%.3f)", best_name, best_sim)
|
| 192 |
return best_name
|
| 193 |
+
if best_name is not None:
|
| 194 |
+
logger.info("Face detected but not recognised (best sim=%.3f < %.2f)", best_sim, threshold)
|
| 195 |
+
return None
|
| 196 |
|
| 197 |
|
| 198 |
def add_face(
|
|
@@ -16,7 +16,7 @@ import numpy as np
|
|
| 16 |
from pydantic import BaseModel
|
| 17 |
from reachy_mini import ReachyMini, ReachyMiniApp
|
| 18 |
|
| 19 |
-
from recognizer.face_db import NoFaceDetected, add_face, find_match
|
| 20 |
from recognizer.face_db import load as load_face_db
|
| 21 |
from recognizer.tts import speak
|
| 22 |
|
|
@@ -60,6 +60,12 @@ class Recognizer(ReachyMiniApp):
|
|
| 60 |
_shared["pending_name"] = payload.name.strip()
|
| 61 |
return {"ok": True}
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
@self.settings_app.get("/status")
|
| 64 |
def get_status():
|
| 65 |
with _lock:
|
|
|
|
| 16 |
from pydantic import BaseModel
|
| 17 |
from reachy_mini import ReachyMini, ReachyMiniApp
|
| 18 |
|
| 19 |
+
from recognizer.face_db import NoFaceDetected, add_face, find_match, wipe as wipe_face_db
|
| 20 |
from recognizer.face_db import load as load_face_db
|
| 21 |
from recognizer.tts import speak
|
| 22 |
|
|
|
|
| 60 |
_shared["pending_name"] = payload.name.strip()
|
| 61 |
return {"ok": True}
|
| 62 |
|
| 63 |
+
@self.settings_app.post("/clear_db")
|
| 64 |
+
def clear_db():
|
| 65 |
+
wipe_face_db()
|
| 66 |
+
face_db.clear()
|
| 67 |
+
return {"ok": True}
|
| 68 |
+
|
| 69 |
@self.settings_app.get("/status")
|
| 70 |
def get_status():
|
| 71 |
with _lock:
|
|
@@ -28,6 +28,11 @@
|
|
| 28 |
<div id="enroll-status"></div>
|
| 29 |
</div>
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
<script src="/static/main.js"></script>
|
| 32 |
</body>
|
| 33 |
|
|
|
|
| 28 |
<div id="enroll-status"></div>
|
| 29 |
</div>
|
| 30 |
|
| 31 |
+
<div id="admin-section">
|
| 32 |
+
<button id="clear-db-btn" class="danger">🗑 Clear face database</button>
|
| 33 |
+
<div id="clear-db-status"></div>
|
| 34 |
+
</div>
|
| 35 |
+
|
| 36 |
<script src="/static/main.js"></script>
|
| 37 |
</body>
|
| 38 |
|
|
@@ -72,6 +72,20 @@ document.getElementById("name-input").addEventListener("keydown", (e) => {
|
|
| 72 |
if (e.key === "Enter") submitName();
|
| 73 |
});
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
// Poll every second
|
| 76 |
setInterval(pollStatus, 1000);
|
| 77 |
pollStatus();
|
|
|
|
| 72 |
if (e.key === "Enter") submitName();
|
| 73 |
});
|
| 74 |
|
| 75 |
+
// Ask for confirmation, POST to /clear_db, and show the outcome in the
// #clear-db-status element.
async function clearDb() {
  const proceed = confirm("Delete all enrolled faces? The robot will not recognise anyone until they enroll again.");
  if (!proceed) {
    return;
  }

  // The status element is looked up at the moment a message is written.
  const report = (message) => {
    document.getElementById("clear-db-status").textContent = message;
  };

  try {
    const response = await fetch("/clear_db", { method: "POST" });
    const payload = await response.json();
    if (payload.ok) {
      report("✓ Database cleared – please re-enroll.");
    } else {
      report("Error clearing database.");
    }
  } catch (err) {
    // Network failure or a non-JSON response body ends up here.
    report("Error clearing database.");
  }
}
|
| 86 |
+
|
| 87 |
+
document.getElementById("clear-db-btn").addEventListener("click", clearDb);
|
| 88 |
+
|
| 89 |
// Poll every second
|
| 90 |
setInterval(pollStatus, 1000);
|
| 91 |
pollStatus();
|
|
@@ -81,6 +81,21 @@ button {
|
|
| 81 |
|
| 82 |
button:hover { background: #1558b0; }
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
#enroll-status {
|
| 85 |
margin-top: 0.75rem;
|
| 86 |
font-size: 0.9rem;
|
|
|
|
| 81 |
|
| 82 |
button:hover { background: #1558b0; }
|
| 83 |
|
| 84 |
+
button.danger { background: #c62828; }
|
| 85 |
+
button.danger:hover { background: #8e0000; }
|
| 86 |
+
|
| 87 |
+
#admin-section {
|
| 88 |
+
margin-top: 2rem;
|
| 89 |
+
padding-top: 1rem;
|
| 90 |
+
border-top: 1px solid #ddd;
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
#clear-db-status {
|
| 94 |
+
margin-top: 0.6rem;
|
| 95 |
+
font-size: 0.9rem;
|
| 96 |
+
color: #c62828;
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
#enroll-status {
|
| 100 |
margin-top: 0.75rem;
|
| 101 |
font-size: 0.9rem;
|