# Source: aefrs-space / app.py
# Last commit 333b323 (verified) by midokhaled927: "Update app.py"
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import gradio as gr
import numpy as np
import cv2
import os
import onnxruntime as ort
from PIL import Image
# ================= System setup =================
app = FastAPI(title="AEFRS Face Recognition System")

# Allow requests from any origin so the API can be called from external front ends.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# NOTE(review): relative path, resolved against the current working directory.
MODEL_PATH = "artifacts/models"
DETECTION_MODEL = os.path.join(MODEL_PATH, "retinaface.onnx")
EMBEDDING_MODEL = os.path.join(MODEL_PATH, "arcface_iresnet100.onnx")


def _load_session(model_path):
    """Create an ONNX Runtime session for *model_path*, or return None if the file is missing.

    ``providers`` is passed explicitly: since onnxruntime 1.9, builds that
    expose more than one execution provider (e.g. the GPU package) raise
    ValueError when an InferenceSession is created without it.
    """
    if not os.path.exists(model_path):
        return None
    return ort.InferenceSession(model_path, providers=ort.get_available_providers())


detection_session = _load_session(DETECTION_MODEL)
embedding_session = _load_session(EMBEDDING_MODEL)

# In-memory store (not persisted): identity_id -> {"name": name, "embedding": embedding}
database = {}
# ================= Helper functions =================
def read_image(file):
    """Decode an image into a 3-channel BGR ``numpy`` array for OpenCV.

    Args:
        file: Anything ``PIL.Image.open`` accepts, e.g. a path or a binary
            file-like object such as FastAPI's ``UploadFile.file``.

    Returns:
        A ``(H, W, 3)`` BGR array.

    Raises:
        OSError: if the data is not a readable image (Pillow raises
            ``UnidentifiedImageError``, an ``OSError`` subclass).
    """
    with Image.open(file) as image:
        # Normalise to RGB first: RGBA, palette ("P") and grayscale images
        # would otherwise yield 4- or 1-channel arrays that
        # cv2.COLOR_RGB2BGR rejects.
        rgb = np.array(image.convert("RGB"))
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
def detect_face(image):
    """Return the image region to embed, plus an error message or None.

    NOTE(review): placeholder. The RetinaFace session is only checked for
    presence and is never run, so the input image is returned unchanged as
    the "face". Real detection and alignment still need to be implemented.

    Args:
        image: Image array from ``read_image`` or a Gradio handler.

    Returns:
        ``(image, error)``. ``error`` is ``"Detection model not loaded"`` when
        the detection model file was not found at startup, otherwise None.
    """
    error = "Detection model not loaded" if detection_session is None else None
    return image, error
def get_embedding(face):
    """Compute an L2-normalised face embedding with the ArcFace ONNX model.

    Args:
        face: BGR uint8 image array, as returned by ``read_image`` /
            ``detect_face``.

    Returns:
        ``(embedding, error)``. On success, ``embedding`` is a 1-D float32
        array with unit L2 norm and ``error`` is None. On failure,
        ``embedding`` is None and ``error`` is a message. Callers must check
        ``error`` first.

    NOTE(review): preprocessing follows the usual insightface ArcFace
    convention. The image is converted to RGB, scaled as (x - 127.5) / 127.5,
    laid out as NCHW, and resized to 112x112 unless the model declares a fixed
    size. Confirm this against the exported model. ``detect_face`` currently
    passes the whole image rather than an aligned crop, so embeddings will be
    much less discriminative than with real alignment.
    """
    if embedding_session is None:
        return None, "Embedding model not loaded"

    model_input = embedding_session.get_inputs()[0]
    # Use the model's declared spatial size (assumed NCHW) when it is fixed.
    # Dynamic dims are reported as strings/None, so fall back to 112x112.
    height, width = 112, 112
    spatial = model_input.shape[2:4]
    if len(spatial) == 2 and all(isinstance(dim, int) for dim in spatial):
        height, width = spatial

    try:
        # blobFromImage computes (x - mean) * scalefactor, swaps BGR->RGB and
        # returns a float32 NCHW batch of one.
        blob = cv2.dnn.blobFromImage(
            face, 1.0 / 127.5, (width, height), (127.5, 127.5, 127.5), swapRB=True
        )
        outputs = embedding_session.run(None, {model_input.name: blob})
    except Exception as exc:  # boundary: surface inference failures as an error string
        return None, f"Embedding inference failed: {exc}"

    embedding = np.asarray(outputs[0], dtype=np.float32).reshape(-1)
    norm = np.linalg.norm(embedding)
    if norm > 0:
        embedding = embedding / norm
    return embedding, None
# ================= FastAPI endpoints =================
@app.post("/enroll")
async def api_enroll(identity_id: str, name: str, file: UploadFile = File(...)):
    """Enroll (or re-enroll) an identity from an uploaded face image.

    Args:
        identity_id: Database key. An existing entry with the same key is
            overwritten. Must not be blank.
        name: Display name stored alongside the embedding.
        file: Uploaded image.

    Returns:
        ``{"message": ...}`` on success. On failure, a 400 JSONResponse with an
        ``"error"`` field: blank ID, undecodable image, or a model step that
        reported an error.

    NOTE(review): decoding and inference are CPU-bound and run on the event
    loop inside this ``async def``. A plain ``def`` endpoint would let FastAPI
    run it in its threadpool instead.
    """
    if not identity_id.strip():
        return JSONResponse({"error": "identity_id must not be empty"}, status_code=400)
    try:
        img = read_image(file.file)
    except (OSError, cv2.error):
        # Not a readable image (Pillow's UnidentifiedImageError is an OSError).
        # Report a client error instead of an unhandled 500.
        return JSONResponse({"error": "Invalid image file"}, status_code=400)
    face, error = detect_face(img)
    if error:
        return JSONResponse({"error": error}, status_code=400)
    embedding, error = get_embedding(face)
    if error:
        return JSONResponse({"error": error}, status_code=400)
    database[identity_id] = {"name": name, "embedding": embedding}
    return {"message": f"{name} تم تسجيله بنجاح"}
@app.post("/search")
async def api_search(file: UploadFile = File(...)):
    """Return the enrolled identities most similar to the uploaded face.

    Returns:
        ``{"matches": [{"identity_id", "name", "score"}, ...]}``. This holds at
        most 5 entries, sorted by cosine similarity (descending, rounded to 3
        decimals). On failure, a 400 JSONResponse with an ``"error"`` field.

    NOTE(review): CPU-bound work runs on the event loop inside this
    ``async def``. See the same note on ``api_enroll``.
    """
    try:
        img = read_image(file.file)
    except (OSError, cv2.error):
        return JSONResponse({"error": "Invalid image file"}, status_code=400)
    face, error = detect_face(img)
    if error:
        return JSONResponse({"error": error}, status_code=400)
    embedding, error = get_embedding(face)
    if error:
        return JSONResponse({"error": error}, status_code=400)

    query_norm = np.linalg.norm(embedding)
    matches = []
    for id_, data in database.items():
        denom = query_norm * np.linalg.norm(data["embedding"])
        # A zero-norm vector would give a NaN score. Starlette renders JSON
        # with allow_nan=False, so that would become a 500. Score it 0 instead.
        score = float(np.dot(embedding, data["embedding"]) / denom) if denom else 0.0
        matches.append({"identity_id": id_, "name": data["name"], "score": round(score, 3)})
    matches.sort(key=lambda match: match["score"], reverse=True)
    return {"matches": matches[:5]}
# ================= Gradio interface =================
def gr_enroll(identity_id, name, image):
    """Gradio handler for the enrollment tab.

    Args:
        identity_id: Text from the "Identity ID" box. It is used as the
            database key, and an existing entry with the same key is overwritten.
        name: Text from the "Name" box.
        image: PIL image from ``gr.Image(type="pil")``, or None if nothing was
            uploaded.

    Returns:
        A status string for the output Textbox.
    """
    # Without these guards, a missing photo raises inside cv2.cvtColor, and a
    # blank ID silently enrolls under the key "".
    if image is None:
        return "⚠️ Please upload an image"
    if not (identity_id or "").strip():
        return "⚠️ Please enter an Identity ID"
    img_array = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    face, error = detect_face(img_array)
    if error:
        return f"⚠️ {error}"
    embedding, error = get_embedding(face)
    if error:
        return f"⚠️ {error}"
    database[identity_id] = {"name": name, "embedding": embedding}
    return f"✅ {name} تم تسجيله بنجاح!"
def gr_search(image):
    """Gradio handler for the search tab.

    Args:
        image: PIL image from ``gr.Image(type="pil")``, or None if nothing was
            uploaded.

    Returns:
        ``(status, rows)``. ``rows`` is a list of ``[identity_id, name, score]``
        lists (at most 5, highest cosine similarity first), matching the
        ID/Name/Score columns of the results Dataframe. ``gr.Dataframe``
        expects row lists, not dicts.
    """
    if image is None:
        return "⚠️ Please upload an image", []
    img_array = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    face, error = detect_face(img_array)
    if error:
        return f"⚠️ {error}", []
    embedding, error = get_embedding(face)
    if error:
        return f"⚠️ {error}", []

    query_norm = np.linalg.norm(embedding)
    rows = []
    # Snapshot the dict: Gradio runs sync handlers in worker threads, so an
    # enrollment may insert into `database` while this search iterates it.
    for id_, data in list(database.items()):
        denom = query_norm * np.linalg.norm(data["embedding"])
        # A zero-norm vector has no direction. Score it 0 instead of producing NaN.
        score = float(np.dot(embedding, data["embedding"]) / denom) if denom else 0.0
        rows.append([id_, data["name"], round(score, 3)])
    rows.sort(key=lambda row: row[2], reverse=True)
    return "✅ البحث اكتمل", rows[:5]
with gr.Blocks() as demo:
    gr.Markdown("## AEFRS Face Recognition System")

    # Tab "Register a new person": ID + name + photo -> gr_enroll -> status text.
    with gr.Tab("تسجيل شخص جديد"):
        enroll_id_box = gr.Textbox(label="Identity ID")
        enroll_name_box = gr.Textbox(label="Name")
        enroll_photo = gr.Image(type="pil")
        enroll_button = gr.Button("تسجيل")  # "Register"
        enroll_status = gr.Textbox()
        enroll_button.click(
            fn=gr_enroll,
            inputs=[enroll_id_box, enroll_name_box, enroll_photo],
            outputs=enroll_status,
        )

    # Tab "Search for a face": photo -> gr_search -> (status text, results table).
    with gr.Tab("بحث عن وجه"):
        query_photo = gr.Image(type="pil")
        search_button = gr.Button("بحث")  # "Search"
        search_status_box = gr.Textbox()
        search_table = gr.Dataframe(headers=["ID", "Name", "Score"])
        search_button.click(
            fn=gr_search,
            inputs=[query_photo],
            outputs=[search_status_box, search_table],
        )
# ================= Run Gradio =================
if __name__ == "__main__":
    # Launch in the main thread, where demo.launch() blocks and keeps the
    # process alive. Do not move this into a daemon thread: the main thread
    # would reach the end of the script, the interpreter would exit, and the
    # server would be killed before serving anything.
    # NOTE(review): this entry point serves only the Gradio UI. The FastAPI
    # routes registered on `app` are not served by it. To expose both, mount
    # the UI with gr.mount_gradio_app(app, demo, path="/") and run `app` under
    # an ASGI server.
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))