# attendance-insightface / api/websocket.py
# Author: vrfefavr
# Last commit: "Update api/websocket.py" (68e4117, verified)
import json
import base64
import cv2
import numpy as np
import asyncio
from fastapi import WebSocket, WebSocketDisconnect
from services.vision import process_frame_synchronous, get_face_embedding
from services.attendance import mark_attendance
from services.faiss_manager import db
# In-memory state for registrations that are currently in progress, keyed by
# the client's WebSocket object. Each value has the shape
#     {"name": <person being registered>, "embeddings": [<captured embeddings>]}
# Nothing here is persisted; the WebSocket handler removes an entry when the
# registration finishes or the client disconnects.
registration_sessions = dict()
def _decode_frame(image_field):
    """Decode a base64-encoded image sent by the client into an OpenCV image.

    Accepts either a data URL (``"data:image/jpeg;base64,<payload>"``) or a
    bare base64 string.

    Returns:
        The decoded image (as returned by ``cv2.imdecode`` with
        ``IMREAD_COLOR``), or ``None`` if the field is missing or not a
        string, is not valid base64, or does not decode to an image.
    """
    if not isinstance(image_field, str):
        return None
    # Data URLs carry a "<mime>;base64," prefix; keep only the payload after
    # the comma. A string without a comma is treated as bare base64.
    encoded = image_field.rpartition(",")[2]
    try:
        raw = base64.b64decode(encoded)
    except (ValueError, TypeError):  # binascii.Error subclasses ValueError
        return None
    if not raw:
        # Nothing to decode; guard before handing an empty buffer to OpenCV.
        return None
    try:
        # imdecode returns None (not an exception) for undecodable data.
        return cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
    except cv2.error:
        return None


async def _start_registration(websocket, payload):
    """Open (or restart) a registration session for ``payload["name"]``."""
    name = payload.get("name")
    if not isinstance(name, str) or not name.strip():
        # Refuse to start a session that would later store a nameless identity.
        await websocket.send_json({
            "type": "registration_error",
            "message": "Failed. A name is required to start registration.",
        })
        return
    registration_sessions[websocket] = {"name": name, "embeddings": []}
    await websocket.send_json({
        "type": "registration_status",
        "message": f"Started scanning for {name}",
        "progress": 0,
    })


async def _handle_register_frame(websocket, payload):
    """Extract an embedding from one registration frame and report progress."""
    session = registration_sessions.get(websocket)
    if session is None:
        # Frames without a preceding "start_registration" are ignored.
        return
    frame = _decode_frame(payload.get("image"))
    emb = None
    if frame is not None:
        # get_face_embedding returns None if the frame is blurry or has no face.
        emb = await asyncio.to_thread(get_face_embedding, frame)
    embeddings = session["embeddings"]
    if emb is not None:
        embeddings.append(emb)
        message = f"Captured {len(embeddings)} high-quality frames"
    else:
        message = "Face not found or too blurry. Please hold still."
    await websocket.send_json({
        "type": "registration_status",
        "message": message,
        "progress": len(embeddings),
    })


async def _finish_registration(websocket):
    """Average the captured embeddings, store them in FAISS and close the session."""
    # Pop first so the session is released even if storing or sending fails.
    session = registration_sessions.pop(websocket, None)
    if session is None:
        return
    embs = session["embeddings"]
    if not embs:
        await websocket.send_json({
            "type": "registration_error",
            "message": "Failed. No valid frames captured.",
        })
        return
    # MULTI-ANGLE REGISTRATION: average all captured embeddings, then
    # re-normalize to unit length so cosine similarity works correctly.
    avg_emb = np.mean(embs, axis=0)
    norm = np.linalg.norm(avg_emb)
    if not np.isfinite(norm) or norm == 0:
        # A zero/NaN norm would put NaNs into the index.
        await websocket.send_json({
            "type": "registration_error",
            "message": "Failed. Captured frames could not be combined.",
        })
        return
    avg_emb = avg_emb / norm
    db.add_identity(session["name"], avg_emb)
    await websocket.send_json({
        "type": "registration_success",
        "message": f"Successfully registered {session['name']}!",
    })


async def _handle_recognition_frame(websocket, payload):
    """Recognize faces in one frame, mark attendance for matches, reply "ready"."""
    frame = _decode_frame(payload.get("image"))
    if frame is None:
        # Still reply "ready" so a client waiting on it is not left hanging.
        await websocket.send_json({
            "type": "ready",
            "debug": "Invalid frame received.",
            "faces": [],
        })
        return

    report, stats = await asyncio.to_thread(process_frame_synchronous, frame)
    results_summary = []
    client_faces = []
    for face in report:
        name = face["name"]
        score = face["score"]
        status = face["status"]
        client_faces.append({
            "name": name,
            "score": score,
            "status": status,
            "box": face["box"],
            "crop": face["crop_b64"],
        })
        if status == "match":
            # --- SELF-HEALING AI (Continuous Learning) ---
            # When highly confident (score >= 95.0, presumably a percentage),
            # add this new angle/lighting to the person's profile.
            # NOTE(review): this fires on every confident frame, so the index
            # grows by one vector per such frame; consider rate-limiting.
            if score >= 95.0 and "embedding" in face:
                db.add_identity(name, face["embedding"])
            _, time_str = mark_attendance(name)
            results_summary.append(f"✅ {name} ({score}%)")
            await websocket.send_json({
                "type": "attendance",
                "name": name,
                "time": time_str or "Just Now",
                "status": "success",
            })
        elif status == "verifying":
            results_summary.append(f"⏳ {name}")
        else:
            results_summary.append(f"❌ {name} ({score}%)")

    debug_msg = f"Faces: {stats['detected']} | "
    debug_msg += " | ".join(results_summary) if results_summary else "No faces found."
    await websocket.send_json({
        "type": "ready",
        "debug": debug_msg,
        "faces": client_faces,
    })


async def websocket_endpoint(websocket: WebSocket):
    """Serve one client connection for face registration and attendance.

    Incoming JSON messages are dispatched on their ``"type"`` field:

    * ``"heartbeat"``: reply ``{"type": "heartbeat_ack"}``.
    * ``"start_registration"``: open a session for ``payload["name"]``.
    * ``"register_frame"``: add one frame's embedding to the open session.
    * ``"finish_registration"``: average the session's embeddings and store
      them in the FAISS db.
    * ``"frame"``: run recognition, send an ``"attendance"`` message per
      match, then a ``"ready"`` message with per-face results.

    Messages that are not valid JSON objects are ignored rather than closing
    the connection. Any registration state for this socket is discarded when
    the handler exits, however it exits.
    """
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            try:
                payload = json.loads(data)
            except json.JSONDecodeError:
                continue  # malformed message: skip it, keep the connection
            if not isinstance(payload, dict):
                continue

            msg_type = payload.get("type")
            if msg_type == "heartbeat":
                await websocket.send_json({"type": "heartbeat_ack"})
            # --- REGISTRATION FLOW ---
            elif msg_type == "start_registration":
                await _start_registration(websocket, payload)
            elif msg_type == "register_frame":
                await _handle_register_frame(websocket, payload)
            elif msg_type == "finish_registration":
                await _finish_registration(websocket)
            # --- STANDARD RECOGNITION FLOW ---
            elif msg_type == "frame":
                await _handle_recognition_frame(websocket, payload)
    except WebSocketDisconnect:
        pass  # normal client disconnect; cleanup happens in finally
    except Exception as e:
        print(f"WebSocket Error: {e}")
    finally:
        # Previously only done on WebSocketDisconnect; any other error leaked
        # the session and kept a reference to the dead socket.
        registration_sessions.pop(websocket, None)