Oliver Nitsche Claude Sonnet 4.6 commited on
Commit
6d06d8a
·
1 Parent(s): ee7b214

Fix face recognition: BGR→RGB, robust zip extraction, schema versioning

Browse files

Two bugs made recognition impossible:

1. BGR→RGB: MobileFaceNet is trained on RGB images. _embed() was feeding
raw OpenCV BGR arrays, producing systematically wrong embeddings.
Fixed by adding cv2.cvtColor(BGR2RGB) before normalisation.

2. Zip entry path: w600k_mbf.onnx lives inside a buffalo_sc/ subdirectory
in the release zip, not at the root. The hardcoded entry name caused
extraction to fail (KeyError), so the model was never written to disk.
Fixed by searching the zip namelist for any entry ending in the filename.

Additional improvements:
- SCHEMA_VERSION (=3) in face_db.py: load() auto-clears the DB when the
embedding pipeline changes, so stale BGR embeddings from older code are
discarded automatically on the first run after upgrade.
- Threshold raised back to 0.35 (correct RGB embeddings are more consistent).
- Detection more lenient: minNeighbors 4→3, minSize 60→40.
- INFO/DEBUG logging of similarity scores for diagnosability.
- POST /clear_db endpoint + "Clear face database" button in the UI so the
user can force re-enrollment without SSH access.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

recognizer/face_db.py CHANGED
@@ -8,7 +8,8 @@ Alignment : eye-centre similarity transform to the InsightFace 112×112
8
  Matching : cosine similarity on L2-normalised 512-D embeddings.
9
  Storage : recognizer/face_db.json (gitignored).
10
 
11
- No compilation required onnxruntime ships pre-built ARM64 wheels.
 
12
  """
13
 
14
  import json
@@ -24,6 +25,10 @@ import onnxruntime as ort
24
 
25
  logger = logging.getLogger(__name__)
26
 
 
 
 
 
27
  DB_PATH = Path(__file__).parent / "face_db.json"
28
  MODEL_DIR = Path(__file__).parent / "models"
29
  MODEL_FILE = MODEL_DIR / "w600k_mbf.onnx"
@@ -31,7 +36,6 @@ MODEL_URL = (
31
  "https://github.com/deepinsight/insightface"
32
  "/releases/download/v0.7/buffalo_sc.zip"
33
  )
34
- _REC_ENTRY = "w600k_mbf.onnx" # path inside the zip (root-level since buffalo_sc v0.7)
35
 
36
  _CASCADE = cv2.CascadeClassifier(
37
  cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
@@ -58,13 +62,20 @@ def _ensure_model() -> None:
58
  return
59
  MODEL_DIR.mkdir(exist_ok=True)
60
  zip_path = MODEL_DIR / "buffalo_sc.zip"
61
- logger.info("Downloading face recognition model (~17 MB) — one-time setup...")
62
  urllib.request.urlretrieve(MODEL_URL, zip_path)
63
  with zipfile.ZipFile(zip_path) as zf:
64
- with zf.open(_REC_ENTRY) as src, open(MODEL_FILE, "wb") as dst:
 
 
 
 
 
 
 
65
  dst.write(src.read())
66
  zip_path.unlink()
67
- logger.info("Model ready at %s", MODEL_FILE)
68
 
69
 
70
  def _get_session() -> ort.InferenceSession:
@@ -80,22 +91,16 @@ def _get_session() -> ort.InferenceSession:
80
  def _detect(frame_bgr: np.ndarray) -> list[tuple[int, int, int, int]]:
81
  gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
82
  boxes = _CASCADE.detectMultiScale(
83
- gray, scaleFactor=1.1, minNeighbors=4, minSize=(60, 60)
84
  )
85
  return [tuple(b) for b in boxes] if len(boxes) > 0 else []
86
 
87
 
88
  def _align(face_bgr: np.ndarray) -> np.ndarray:
89
- """Return a 112×112 crop aligned on eye centres; plain resize as fallback.
90
-
91
- MobileFaceNet is trained on faces warped to a canonical eye position.
92
- Without this step, embeddings from different frames of the same person
93
- can be too dissimilar for reliable matching.
94
- """
95
  gray = cv2.cvtColor(face_bgr, cv2.COLOR_BGR2GRAY)
96
  eyes = _EYE_CASCADE.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=3)
97
  if len(eyes) >= 2:
98
- # Pick the two largest detections and sort left-to-right
99
  eyes = sorted(eyes, key=lambda e: e[2] * e[3], reverse=True)[:2]
100
  eyes = sorted(eyes, key=lambda e: e[0])
101
  src = np.float32([
@@ -109,12 +114,16 @@ def _align(face_bgr: np.ndarray) -> np.ndarray:
109
 
110
 
111
  def _embed(face_bgr: np.ndarray) -> np.ndarray:
112
- img = _align(face_bgr).astype(np.float32)
113
- img = (img - 127.5) / 127.5
114
- inp = np.transpose(img, (2, 0, 1))[np.newaxis] # NCHW
 
 
 
 
115
  sess = _get_session()
116
  emb = sess.run(None, {sess.get_inputs()[0].name: inp})[0][0]
117
- return emb / np.linalg.norm(emb) # L2-normalise
118
 
119
 
120
  # ---------------------------------------------------------------------------
@@ -122,25 +131,46 @@ def _embed(face_bgr: np.ndarray) -> np.ndarray:
122
  # ---------------------------------------------------------------------------
123
 
124
  def load() -> dict[str, list[list[float]]]:
125
- """Load face DB from disk and warm up the ONNX session."""
126
- _get_session() # triggers one-time model download
127
- if DB_PATH.exists():
128
- return json.loads(DB_PATH.read_text())
129
- return {}
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
 
132
  def save(db: dict[str, list[list[float]]]) -> None:
133
- DB_PATH.write_text(json.dumps(db, indent=2))
 
 
 
 
 
 
 
 
 
134
 
135
 
136
  def find_match(
137
  frame_bgr: np.ndarray,
138
  db: dict[str, list[list[float]]],
139
- threshold: float = 0.25,
140
  ) -> Optional[str]:
141
  """Return matched name if recognised, None if face present but unknown.
142
 
143
- Raises NoFaceDetected if no face appears in the image at all.
144
  """
145
  boxes = _detect(frame_bgr)
146
  if not boxes:
@@ -156,9 +186,13 @@ def find_match(
156
  if sim > best_sim:
157
  best_sim, best_name = sim, name
158
 
 
159
  if best_name is not None and best_sim >= threshold:
 
160
  return best_name
161
- return None # face present but not recognised (or DB is empty)
 
 
162
 
163
 
164
  def add_face(
 
8
  Matching : cosine similarity on L2-normalised 512-D embeddings.
9
  Storage : recognizer/face_db.json (gitignored).
10
 
11
+ Bump SCHEMA_VERSION whenever the embedding pipeline changes so that stale
12
+ DB entries from older code are automatically discarded on load.
13
  """
14
 
15
  import json
 
25
 
26
  logger = logging.getLogger(__name__)
27
 
28
+ # Bump this whenever the embedding pipeline changes (alignment, colour space,
29
+ # model weights, normalisation, …). Mismatched DBs are auto-cleared on load.
30
+ SCHEMA_VERSION = 3 # 1=plain-resize BGR 2=aligned BGR 3=aligned RGB
31
+
32
  DB_PATH = Path(__file__).parent / "face_db.json"
33
  MODEL_DIR = Path(__file__).parent / "models"
34
  MODEL_FILE = MODEL_DIR / "w600k_mbf.onnx"
 
36
  "https://github.com/deepinsight/insightface"
37
  "/releases/download/v0.7/buffalo_sc.zip"
38
  )
 
39
 
40
  _CASCADE = cv2.CascadeClassifier(
41
  cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
 
62
  return
63
  MODEL_DIR.mkdir(exist_ok=True)
64
  zip_path = MODEL_DIR / "buffalo_sc.zip"
65
+ logger.info("Downloading face recognition model (~17 MB) — one-time setup")
66
  urllib.request.urlretrieve(MODEL_URL, zip_path)
67
  with zipfile.ZipFile(zip_path) as zf:
68
+ # The file may live at root or inside a named subdirectory (e.g. buffalo_sc/).
69
+ matches = [n for n in zf.namelist() if n.endswith("w600k_mbf.onnx")]
70
+ if not matches:
71
+ raise RuntimeError(
72
+ f"w600k_mbf.onnx not found in downloaded zip. "
73
+ f"Available entries: {zf.namelist()}"
74
+ )
75
+ with zf.open(matches[0]) as src, open(MODEL_FILE, "wb") as dst:
76
  dst.write(src.read())
77
  zip_path.unlink()
78
+ logger.info("Model saved to %s", MODEL_FILE)
79
 
80
 
81
  def _get_session() -> ort.InferenceSession:
 
91
  def _detect(frame_bgr: np.ndarray) -> list[tuple[int, int, int, int]]:
92
  gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
93
  boxes = _CASCADE.detectMultiScale(
94
+ gray, scaleFactor=1.1, minNeighbors=3, minSize=(40, 40)
95
  )
96
  return [tuple(b) for b in boxes] if len(boxes) > 0 else []
97
 
98
 
99
  def _align(face_bgr: np.ndarray) -> np.ndarray:
100
+ """Return a 112×112 crop aligned on eye centres; plain resize as fallback."""
 
 
 
 
 
101
  gray = cv2.cvtColor(face_bgr, cv2.COLOR_BGR2GRAY)
102
  eyes = _EYE_CASCADE.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=3)
103
  if len(eyes) >= 2:
 
104
  eyes = sorted(eyes, key=lambda e: e[2] * e[3], reverse=True)[:2]
105
  eyes = sorted(eyes, key=lambda e: e[0])
106
  src = np.float32([
 
114
 
115
 
116
  def _embed(face_bgr: np.ndarray) -> np.ndarray:
117
+ """Return an L2-normalised 512-D embedding for face_bgr."""
118
+ face_112 = _align(face_bgr)
119
+ # MobileFaceNet (InsightFace) is trained on RGB — convert from OpenCV BGR.
120
+ face_rgb = cv2.cvtColor(face_112, cv2.COLOR_BGR2RGB)
121
+ img = face_rgb.astype(np.float32)
122
+ img = (img - 127.5) / 127.5 # normalise to [-1, 1]
123
+ inp = np.transpose(img, (2, 0, 1))[np.newaxis] # NCHW
124
  sess = _get_session()
125
  emb = sess.run(None, {sess.get_inputs()[0].name: inp})[0][0]
126
+ return emb / np.linalg.norm(emb) # L2-normalise
127
 
128
 
129
  # ---------------------------------------------------------------------------
 
131
  # ---------------------------------------------------------------------------
132
 
133
  def load() -> dict[str, list[list[float]]]:
134
+ """Load face DB from disk and warm up the ONNX session.
135
+
136
+ Returns an empty dict if the DB is missing or was produced by an older
137
+ embedding pipeline (schema mismatch → auto-clear).
138
+ """
139
+ _get_session()
140
+ if not DB_PATH.exists():
141
+ return {}
142
+ raw = json.loads(DB_PATH.read_text())
143
+ if raw.get("_schema") != SCHEMA_VERSION:
144
+ logger.warning(
145
+ "face_db schema mismatch (file=%s expected=%s) — clearing stale embeddings",
146
+ raw.get("_schema"), SCHEMA_VERSION,
147
+ )
148
+ DB_PATH.unlink()
149
+ return {}
150
+ return {k: v for k, v in raw.items() if not k.startswith("_")}
151
 
152
 
153
  def save(db: dict[str, list[list[float]]]) -> None:
154
+ out: dict = {"_schema": SCHEMA_VERSION}
155
+ out.update(db)
156
+ DB_PATH.write_text(json.dumps(out, indent=2))
157
+
158
+
159
+ def wipe() -> None:
160
+ """Delete all enrolled faces from disk."""
161
+ if DB_PATH.exists():
162
+ DB_PATH.unlink()
163
+ logger.info("Face database cleared")
164
 
165
 
166
  def find_match(
167
  frame_bgr: np.ndarray,
168
  db: dict[str, list[list[float]]],
169
+ threshold: float = 0.35,
170
  ) -> Optional[str]:
171
  """Return matched name if recognised, None if face present but unknown.
172
 
173
+ Raises NoFaceDetected if no face is detected in the frame at all.
174
  """
175
  boxes = _detect(frame_bgr)
176
  if not boxes:
 
186
  if sim > best_sim:
187
  best_sim, best_name = sim, name
188
 
189
+ logger.debug("Best match: %s sim=%.3f threshold=%.2f", best_name, best_sim, threshold)
190
  if best_name is not None and best_sim >= threshold:
191
+ logger.info("Recognised: %s (sim=%.3f)", best_name, best_sim)
192
  return best_name
193
+ if best_name is not None:
194
+ logger.info("Face detected but not recognised (best sim=%.3f < %.2f)", best_sim, threshold)
195
+ return None
196
 
197
 
198
  def add_face(
recognizer/main.py CHANGED
@@ -16,7 +16,7 @@ import numpy as np
16
  from pydantic import BaseModel
17
  from reachy_mini import ReachyMini, ReachyMiniApp
18
 
19
- from recognizer.face_db import NoFaceDetected, add_face, find_match
20
  from recognizer.face_db import load as load_face_db
21
  from recognizer.tts import speak
22
 
@@ -60,6 +60,12 @@ class Recognizer(ReachyMiniApp):
60
  _shared["pending_name"] = payload.name.strip()
61
  return {"ok": True}
62
 
 
 
 
 
 
 
63
  @self.settings_app.get("/status")
64
  def get_status():
65
  with _lock:
 
16
  from pydantic import BaseModel
17
  from reachy_mini import ReachyMini, ReachyMiniApp
18
 
19
+ from recognizer.face_db import NoFaceDetected, add_face, find_match, wipe as wipe_face_db
20
  from recognizer.face_db import load as load_face_db
21
  from recognizer.tts import speak
22
 
 
60
  _shared["pending_name"] = payload.name.strip()
61
  return {"ok": True}
62
 
63
+ @self.settings_app.post("/clear_db")
64
+ def clear_db():
65
+ wipe_face_db()
66
+ face_db.clear()
67
+ return {"ok": True}
68
+
69
  @self.settings_app.get("/status")
70
  def get_status():
71
  with _lock:
recognizer/static/index.html CHANGED
@@ -28,6 +28,11 @@
28
  <div id="enroll-status"></div>
29
  </div>
30
 
 
 
 
 
 
31
  <script src="/static/main.js"></script>
32
  </body>
33
 
 
28
  <div id="enroll-status"></div>
29
  </div>
30
 
31
+ <div id="admin-section">
32
+ <button id="clear-db-btn" class="danger">🗑 Clear face database</button>
33
+ <div id="clear-db-status"></div>
34
+ </div>
35
+
36
  <script src="/static/main.js"></script>
37
  </body>
38
 
recognizer/static/main.js CHANGED
@@ -72,6 +72,20 @@ document.getElementById("name-input").addEventListener("keydown", (e) => {
72
  if (e.key === "Enter") submitName();
73
  });
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  // Poll every second
76
  setInterval(pollStatus, 1000);
77
  pollStatus();
 
72
  if (e.key === "Enter") submitName();
73
  });
74
 
75
+ async function clearDb() {
76
+ if (!confirm("Delete all enrolled faces? The robot will not recognise anyone until they enroll again.")) return;
77
+ try {
78
+ const resp = await fetch("/clear_db", { method: "POST" });
79
+ const data = await resp.json();
80
+ document.getElementById("clear-db-status").textContent =
81
+ data.ok ? "✓ Database cleared – please re-enroll." : "Error clearing database.";
82
+ } catch (e) {
83
+ document.getElementById("clear-db-status").textContent = "Error clearing database.";
84
+ }
85
+ }
86
+
87
+ document.getElementById("clear-db-btn").addEventListener("click", clearDb);
88
+
89
  // Poll every second
90
  setInterval(pollStatus, 1000);
91
  pollStatus();
recognizer/static/style.css CHANGED
@@ -81,6 +81,21 @@ button {
81
 
82
  button:hover { background: #1558b0; }
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  #enroll-status {
85
  margin-top: 0.75rem;
86
  font-size: 0.9rem;
 
81
 
82
  button:hover { background: #1558b0; }
83
 
84
+ button.danger { background: #c62828; }
85
+ button.danger:hover { background: #8e0000; }
86
+
87
+ #admin-section {
88
+ margin-top: 2rem;
89
+ padding-top: 1rem;
90
+ border-top: 1px solid #ddd;
91
+ }
92
+
93
+ #clear-db-status {
94
+ margin-top: 0.6rem;
95
+ font-size: 0.9rem;
96
+ color: #c62828;
97
+ }
98
+
99
  #enroll-status {
100
  margin-top: 0.75rem;
101
  font-size: 0.9rem;