# Update app.py — commit 979affc (verified), uploaded by kuldeep0204
import os
import cv2
import numpy as np
import onnxruntime as ort
from ultralytics import YOLO
from huggingface_hub import snapshot_download
from sklearn.metrics.pairwise import cosine_similarity
from PIL import Image
import gradio as gr
# ------------------------------
# MODEL REPOSITORIES
# ------------------------------
YOLO_REPO = "arnabdhar/YOLOv8-Face-Detection" # Face detection model
ARCFACE_ONNX_REPO = "garavv/arcface-onnx" # ArcFace ONNX model
# ------------------------------
# DOWNLOAD MODELS FROM HUGGING FACE
# ------------------------------
# snapshot_download fetches the full repo contents (or reuses the local
# Hugging Face cache) and returns the local directory path.
yolo_dir = snapshot_download(YOLO_REPO)
arcface_dir = snapshot_download(ARCFACE_ONNX_REPO)
def find_model(folder, ext):
    """Return the path of the first file under *folder* ending with *ext*.

    Walks the directory tree top-down (os.walk order) and returns the
    first match, or None when no file with that extension exists.
    """
    candidates = (
        os.path.join(dirpath, filename)
        for dirpath, _, filenames in os.walk(folder)
        for filename in filenames
        if filename.endswith(ext)
    )
    return next(candidates, None)
# Locate the concrete weight files inside the downloaded snapshots.
yolo_model_file = find_model(yolo_dir, ".pt")
arcface_file = find_model(arcface_dir, ".onnx")
# ------------------------------
# LOAD MODELS
# ------------------------------
yolo = YOLO(yolo_model_file)
# CPU provider only — presumably targeting a GPU-less host; confirm if GPU is available.
arcface_sess = ort.InferenceSession(arcface_file, providers=["CPUExecutionProvider"])
# Name of the model's single input tensor, needed to build the run() feed dict.
arcface_input = arcface_sess.get_inputs()[0].name
# ------------------------------
# HELPER FUNCTIONS
# ------------------------------
def get_embedding(face):
    """Compute a unit-length ArcFace embedding for a cropped face image.

    The input (BGR, any size) is resized to 112x112, converted to RGB,
    scaled to roughly [-1, 1], and fed to the ONNX session in NHWC layout
    (1, 112, 112, 3). The raw embedding is L2-normalized before returning.
    """
    rgb = cv2.cvtColor(cv2.resize(face, (112, 112)), cv2.COLOR_BGR2RGB)
    scaled = (rgb - 127.5) / 128.0
    batch = scaled.astype(np.float32)[np.newaxis, ...]  # NHWC: (1, 112, 112, 3)
    raw = arcface_sess.run(None, {arcface_input: batch})[0][0]
    return raw / np.linalg.norm(raw)
# ------------------------------
# KNOWN FACES DATABASE (in memory)
# ------------------------------
# Maps person name -> L2-normalized ArcFace embedding.
# Process-local only: all registrations are lost on restart.
known_faces = {}
# ------------------------------
# REGISTER FACE FUNCTION (with debug logging)
# ------------------------------
def register_face(name, image):
    """Register a face under *name* in the in-memory known_faces database.

    Converts the uploaded image to BGR (handling grayscale and RGBA
    uploads), crops the most confident YOLO-detected face — falling back
    to the whole image when none is found — so that registration
    embeddings are computed on the same kind of crop used at recognition
    time, then stores the L2-normalized ArcFace embedding.

    Returns a human-readable status string; never raises to the UI.
    """
    import traceback
    if not name or name.strip() == "":
        return "❌ Please enter a name before uploading an image."
    if image is None:
        return "❌ Please upload a valid face image."
    try:
        # Convert PIL/Gradio image to OpenCV BGR.
        img = np.array(image)
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:
            # RGBA uploads (e.g. PNG with alpha) would crash COLOR_RGB2BGR.
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
        else:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        # Crop the best-confidence face so the stored embedding matches the
        # face crops used by detect_and_recognize (whole-image embeddings
        # compare poorly against crop embeddings).
        face = img
        results = yolo.predict(source=img, conf=0.35, verbose=False)
        boxes = [b for r in results for b in r.boxes]
        if boxes:
            best = max(boxes, key=lambda b: float(b.conf[0]))
            x1, y1, x2, y2 = map(int, best.xyxy[0].tolist())
            x1, y1 = max(0, x1), max(0, y1)  # guard against negative coords
            crop = img[y1:y2, x1:x2]
            if crop.size:
                face = crop
        # Compute embedding
        emb = get_embedding(face)
        if emb is None or np.isnan(emb).any():
            return f"❌ Failed to compute embedding for {name}. Try another image."
        # Strip the key so "Bob " and "Bob" don't become separate entries.
        known_faces[name.strip()] = emb
        return f"✅ Registered face for **{name}**. Total known faces: {len(known_faces)}"
    except Exception as e:
        tb = traceback.format_exc()
        print("---- ERROR DURING REGISTER FACE ----")
        print(tb)
        return f"⚠️ Internal error: {str(e)}\n\n{tb}"
# ------------------------------
# DETECT + RECOGNIZE FUNCTION
# ------------------------------
def detect_and_recognize(image):
    """Detect faces with YOLO and label each with the best-matching known name.

    For every detected face crop, computes an ArcFace embedding and picks
    the registered name with the highest cosine similarity; matches below
    0.45 are labeled "Unknown". Returns a (status string, annotated PIL
    image) tuple; never raises to the UI.
    """
    import traceback
    if image is None:
        return "❌ Please upload an image.", None
    if not known_faces:
        return "⚠️ No known faces registered yet!", None
    try:
        # Convert to OpenCV BGR, handling grayscale and RGBA uploads.
        img = np.array(image)
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:
            # RGBA uploads (e.g. PNG with alpha) would crash COLOR_RGB2BGR.
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
        else:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        results = yolo.predict(source=img, conf=0.35, verbose=False)
        names_found = []
        for r in results:
            for box in r.boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
                # Clamp to the image: negative coords would silently yield
                # wrong or empty numpy slices.
                x1, y1 = max(0, x1), max(0, y1)
                crop = img[y1:y2, x1:x2]
                if crop.size == 0:
                    continue
                emb = get_embedding(crop)
                # Linear scan over registered embeddings for the best match.
                best_name, best_score = "Unknown", 0
                for name, ref_emb in known_faces.items():
                    score = cosine_similarity([emb], [ref_emb])[0][0]
                    if score > best_score:
                        best_name, best_score = name, score
                if best_score < 0.45:  # reject weak matches
                    best_name = "Unknown"
                cv2.rectangle(img, (x1, y1), (x2, y2), (0,255,0), 2)
                cv2.putText(img, f"{best_name} ({best_score:.2f})", (x1, y1-10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
                names_found.append(best_name)
        result_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        return f"Detected faces: {names_found}", result_img
    except Exception as e:
        tb = traceback.format_exc()
        print("---- ERROR DURING DETECT & RECOGNIZE ----")
        print(tb)
        return f"⚠️ Internal error: {str(e)}\n\n{tb}", None
# ------------------------------
# GRADIO INTERFACE
# ------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# 👤 Face Detection & Recognition (YOLO + ArcFace)")

    # Tab 1: enroll a person into the in-memory database.
    with gr.Tab("Register Face"):
        person_name = gr.Textbox(label="Person Name")
        person_image = gr.Image(label="Upload Face Image")
        enroll_button = gr.Button("Register Face")
        enroll_status = gr.Textbox(label="Status")
        enroll_button.click(register_face, [person_name, person_image], enroll_status)

    # Tab 2: run detection + recognition on an arbitrary test image.
    with gr.Tab("Detect & Recognize"):
        test_image = gr.Image(label="Upload Test Image")
        run_button = gr.Button("Detect Faces")
        result_text = gr.Textbox(label="Results")
        result_image = gr.Image(label="Output Image")
        run_button.click(detect_and_recognize, test_image, [result_text, result_image])

demo.launch()