|
|
import os |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import onnxruntime as ort |
|
|
from ultralytics import YOLO |
|
|
from huggingface_hub import snapshot_download |
|
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
from PIL import Image |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face Hub repositories holding the two pretrained models used below:
# a YOLOv8 face detector (searched for a ".pt" file) and an ArcFace
# face-embedding model (searched for a ".onnx" file).
YOLO_REPO = "arnabdhar/YOLOv8-Face-Detection"
ARCFACE_ONNX_REPO = "garavv/arcface-onnx"

# Fetch each repository snapshot; snapshot_download returns the local folder
# that holds the downloaded files, which find_model() then searches.
yolo_dir = snapshot_download(repo_id=YOLO_REPO)
arcface_dir = snapshot_download(repo_id=ARCFACE_ONNX_REPO)
|
|
|
|
|
def find_model(folder, ext):
    """Return the path of the first file under *folder* whose name ends with *ext*.

    The tree is walked top-down (files in a directory are checked before its
    subdirectories). Directory and file names are visited in sorted order, so
    the same folder always yields the same file even when it contains several
    candidates; the raw ``os.walk`` listing order is filesystem-dependent.

    Args:
        folder: Root directory to search.
        ext: Suffix to match, e.g. ``".pt"``. A tuple of suffixes also works
            because it is passed straight to ``str.endswith``.

    Returns:
        Full path of the first matching file, or ``None`` if nothing matches.
    """
    for root, dirs, files in os.walk(folder):
        # Sorting in place steers the order in which os.walk descends.
        dirs.sort()
        for filename in sorted(files):
            if filename.endswith(ext):
                return os.path.join(root, filename)
    return None
|
|
|
|
|
yolo_model_file = find_model(yolo_dir, ".pt")
arcface_file = find_model(arcface_dir, ".onnx")

# find_model() returns None when no file matches. Fail fast here with a clear
# message instead of handing None to YOLO()/InferenceSession(), which would
# fail later with a much less obvious error.
if yolo_model_file is None:
    raise FileNotFoundError(f"No '.pt' YOLO weights found under {yolo_dir!r}")
if arcface_file is None:
    raise FileNotFoundError(f"No '.onnx' ArcFace model found under {arcface_dir!r}")

yolo = YOLO(yolo_model_file)

# ArcFace runs through ONNX Runtime, restricted to the CPU execution provider.
arcface_sess = ort.InferenceSession(arcface_file, providers=["CPUExecutionProvider"])

# Name of the model's first input; get_embedding() feeds its tensor under this key.
arcface_input = arcface_sess.get_inputs()[0].name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_embedding(face):
    """Compute an L2-normalised ArcFace embedding for a BGR face image.

    The image is resized to 112x112 and converted from BGR to RGB. Pixels are
    scaled as ``(x - 127.5) / 128``, cast to float32 and given a leading batch
    axis, giving the NHWC layout (1, 112, 112, 3) for a colour input.

    Args:
        face: Image array in OpenCV BGR channel order, of any size.

    Returns:
        The model's first output for the single batch item, divided by its
        Euclidean norm. The result contains NaN if that norm is zero.
    """
    rgb = cv2.cvtColor(cv2.resize(face, (112, 112)), cv2.COLOR_BGR2RGB)
    batch = ((rgb - 127.5) / 128.0).astype(np.float32)[np.newaxis, ...]

    outputs = arcface_sess.run(None, {arcface_input: batch})
    vector = outputs[0][0]
    return vector / np.linalg.norm(vector)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# In-memory registry filled by register_face(): person name -> normalised
# embedding from get_embedding(). It is not persisted anywhere in this file.
known_faces = dict()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _largest_face_crop(img, conf=0.35):
    """Return the largest YOLO-detected face region of BGR *img*, or None if no face is found."""
    height, width = img.shape[:2]
    best_crop, best_area = None, 0
    for result in yolo.predict(source=img, conf=conf, verbose=False):
        for box in result.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
            # Clamp to the frame so negative coordinates cannot wrap slices.
            x1, y1 = max(x1, 0), max(y1, 0)
            x2, y2 = min(x2, width), min(y2, height)
            w, h = x2 - x1, y2 - y1
            if w <= 0 or h <= 0:
                continue
            if w * h > best_area:
                best_crop, best_area = img[y1:y2, x1:x2], w * h
    return best_crop


def register_face(name, image):
    """Register the face in *image* under *name* in ``known_faces``, with debug logging.

    The image is converted to OpenCV BGR order. The largest face found by the
    YOLO detector (same confidence as detect_and_recognize) is cropped, so the
    stored embedding is comparable to the face crops embedded at recognition
    time. If no face is detected, the whole image is embedded instead, so
    photos that are already tightly cropped still register.

    Args:
        name: Person label. Surrounding whitespace is stripped before use.
        image: Image from the Gradio ``gr.Image`` input (greyscale 2-D or
            RGB channel order), or ``None``.

    Returns:
        A status string for the UI. On an unexpected exception, the traceback
        is printed to stdout and also included in the returned string.
    """
    import traceback

    if not name or not name.strip():
        return "❌ Please enter a name before uploading an image."
    if image is None:
        return "❌ Please upload a valid face image."
    # Normalise the key so " Bob" and "Bob" are the same person.
    name = name.strip()

    try:
        img = np.array(image)
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        else:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        # Embed the detected face rather than the whole picture; recognition
        # compares against face crops, so background must not dominate here.
        face = _largest_face_crop(img)
        emb = get_embedding(face if face is not None else img)
        if emb is None or np.isnan(emb).any():
            return f"❌ Failed to compute embedding for {name}. Try another image."

        known_faces[name] = emb
        return f"✅ Registered face for **{name}**. Total known faces: {len(known_faces)}"

    except Exception as e:
        tb = traceback.format_exc()
        print("---- ERROR DURING REGISTER FACE ----")
        print(tb)
        return f"⚠️ Internal error: {str(e)}\n\n{tb}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def detect_and_recognize(image):
    """Detect faces in *image* and label each one with the best-matching registered name.

    Faces are found with YOLO (confidence 0.35). Each crop is embedded with
    ArcFace and compared by cosine similarity against every entry in
    ``known_faces``. A face whose best similarity is below 0.45, or whose
    embedding is not finite, is labelled "Unknown".

    Args:
        image: Image from the Gradio ``gr.Image`` input (greyscale 2-D or
            RGB channel order), or ``None``.

    Returns:
        ``(status_text, annotated_image)``. The annotated image is a PIL RGB
        image with a green box and "name (score)" label per face. On invalid
        input, an empty registry or an unexpected error, the second element
        is ``None``; on error the traceback is printed and also returned.
    """
    import traceback

    if image is None:
        return "❌ Please upload an image.", None

    # Snapshot the registry so a concurrent register_face() call cannot change
    # the dict while it is being read below.
    gallery = list(known_faces.items())
    if not gallery:
        return "⚠️ No known faces registered yet!", None

    try:
        img = np.array(image)
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        else:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        # Draw on a copy: crops must come from the untouched image, otherwise
        # boxes and labels drawn for earlier faces leak into later crops.
        annotated = img.copy()
        height, width = img.shape[:2]

        ref_names = [ref_name for ref_name, _ in gallery]
        ref_matrix = np.stack([ref_emb for _, ref_emb in gallery])

        results = yolo.predict(source=img, conf=0.35, verbose=False)
        names_found = []

        for r in results:
            for box in r.boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
                # Clamp so an out-of-frame box cannot wrap via negative slicing.
                x1, y1 = max(x1, 0), max(y1, 0)
                x2, y2 = min(x2, width), min(y2, height)
                crop = img[y1:y2, x1:x2]
                if crop.size == 0:
                    continue

                emb = get_embedding(crop)
                best_name, best_score = "Unknown", 0
                # A zero-norm embedding is NaN; cosine_similarity would raise on
                # it and abort the whole request, so such a face stays Unknown.
                if emb is not None and np.isfinite(emb).all():
                    scores = cosine_similarity([emb], ref_matrix)[0]
                    # argmax returns the first maximum, matching registry order.
                    idx = int(np.argmax(scores))
                    if scores[idx] > best_score:
                        best_name, best_score = ref_names[idx], float(scores[idx])

                # Minimum cosine similarity required to accept a match.
                if best_score < 0.45:
                    best_name = "Unknown"

                cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 255, 0), 2)
                # Keep the label on-canvas for faces touching the top edge.
                cv2.putText(annotated, f"{best_name} ({best_score:.2f})",
                            (x1, max(y1 - 10, 15)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
                names_found.append(best_name)

        result_img = Image.fromarray(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
        return f"Detected faces: {names_found}", result_img

    except Exception as e:
        tb = traceback.format_exc()
        print("---- ERROR DURING DETECT & RECOGNIZE ----")
        print(tb)
        return f"⚠️ Internal error: {str(e)}\n\n{tb}", None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Two-tab Gradio UI: one tab registers a named face, the other runs
# detection and recognition against every face registered so far.
with gr.Blocks() as demo:
    gr.Markdown("# 👤 Face Detection & Recognition (YOLO + ArcFace)")

    with gr.Tab("Register Face"):
        name_input = gr.Textbox(label="Person Name")
        face_input = gr.Image(label="Upload Face Image")
        register_btn = gr.Button("Register Face")
        register_output = gr.Textbox(label="Status")
        register_btn.click(
            fn=register_face,
            inputs=[name_input, face_input],
            outputs=register_output,
        )

    with gr.Tab("Detect & Recognize"):
        img_input = gr.Image(label="Upload Test Image")
        detect_btn = gr.Button("Detect Faces")
        text_output = gr.Textbox(label="Results")
        img_output = gr.Image(label="Output Image")
        detect_btn.click(
            fn=detect_and_recognize,
            inputs=img_input,
            outputs=[text_output, img_output],
        )

demo.launch()
|
|
|