Spaces:
Running
Running
| from fastapi import FastAPI, UploadFile, File | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| import gradio as gr | |
| import numpy as np | |
| import cv2 | |
| import os | |
| import onnxruntime as ort | |
| from PIL import Image | |
# ================= System setup =================
app = FastAPI(title="AEFRS Face Recognition System")

# Allow any origin to call the API (needed when it is consumed from another domain).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

MODEL_PATH = "artifacts/models"
DETECTION_MODEL = os.path.join(MODEL_PATH, "retinaface.onnx")
EMBEDDING_MODEL = os.path.join(MODEL_PATH, "arcface_iresnet100.onnx")


def _load_session(model_file):
    """Return an ONNX Runtime session for *model_file*, or None if the file is missing."""
    if not os.path.exists(model_file):
        return None
    # Pass providers explicitly: since ONNX Runtime 1.9, builds that include
    # non-CPU execution providers (e.g. onnxruntime-gpu) raise ValueError when
    # `providers` is omitted. CPU works on every build.
    return ort.InferenceSession(model_file, providers=["CPUExecutionProvider"])


detection_session = _load_session(DETECTION_MODEL)
embedding_session = _load_session(EMBEDDING_MODEL)

# Plain in-memory gallery (nothing in this view persists it).
# key: identity_id, value: {"name": name, "embedding": embedding}
database = {}
# ================= Helper functions =================
def read_image(file):
    """Decode an image from a path or binary file object into a BGR numpy array.

    The FastAPI endpoints pass ``UploadFile.file`` here.

    The image is converted to 3-channel RGB before the BGR conversion.
    ``cv2.cvtColor(..., COLOR_RGB2BGR)`` only accepts 3- or 4-channel
    arrays, so without this step grayscale (L), palette (P), LA and 16-bit
    images fail. CMYK images would otherwise be misread as RGBA.

    Raises:
        OSError: if Pillow cannot identify or fully decode the image
            (``PIL.UnidentifiedImageError`` is a subclass of ``OSError``).
    """
    # `with` releases Pillow's resources. A caller-supplied file object is
    # not closed by Pillow, so the upload stream stays usable for the caller.
    with Image.open(file) as image:
        rgb = np.array(image.convert("RGB"))
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
def detect_face(image):
    """Return the face region of *image* together with an error message.

    Placeholder: the RetinaFace session is only checked for presence. The
    input image is returned unchanged, so no cropping or alignment happens yet.

    Returns:
        tuple: ``(image, error)``. ``error`` is ``None`` when a detection
        session exists, otherwise ``"Detection model not loaded"``.
    """
    error = "Detection model not loaded" if detection_session is None else None
    return image, error
def get_embedding(face, input_mean=127.5, input_std=127.5):
    """Compute a unit-length face embedding with the ArcFace ONNX session.

    The previous version returned random noise even when the model was
    loaded, which made every similarity score meaningless.

    Args:
        face: BGR uint8 image, as produced by ``read_image`` or the Gradio
            callbacks. It is resized to the model's spatial input size.
        input_mean, input_std: pixel normalisation, applied as
            ``(pixel - input_mean) / input_std``. NOTE(review): 127.5/127.5
            is InsightFace's usual setting. Models that normalise inside the
            graph (e.g. older ONNX-zoo ArcFace exports) need 0.0/1.0 instead.
            Confirm against the shipped model.

    Returns:
        tuple: ``(embedding, error)``. On success ``embedding`` is a 1-D
        float32 array with unit L2 norm and ``error`` is ``None``. On failure
        ``embedding`` is ``None`` and ``error`` is a message string.
    """
    if embedding_session is None:
        return None, "Embedding model not loaded"

    model_input = embedding_session.get_inputs()[0]
    # Assumes an NCHW input (TODO confirm). Dynamic dims come back as
    # strings/None, so fall back to 112x112, the usual ArcFace input size.
    _, _, height, width = model_input.shape
    height = height if isinstance(height, int) else 112
    width = width if isinstance(width, int) else 112

    # blobFromImage resizes, converts BGR->RGB (swapRB), computes
    # (pixel - mean) * scalefactor and returns a float32 NCHW batch of one.
    blob = cv2.dnn.blobFromImage(
        face,
        scalefactor=1.0 / input_std,
        size=(width, height),
        mean=(input_mean, input_mean, input_mean),
        swapRB=True,
    )
    output = embedding_session.run(None, {model_input.name: blob})[0]
    embedding = np.asarray(output, dtype=np.float32).reshape(-1)

    norm = np.linalg.norm(embedding)
    if norm == 0:
        return None, "Embedding model returned a zero vector"
    # Unit length makes the dot product equal to cosine similarity.
    return embedding / norm, None
# ================= FastAPI Endpoints =================
# NOTE(review): the original function was never registered on `app`. The route
# path below is a guess; adjust it if clients expect a different URL.
# `app.post` returns the function unchanged, so direct calls still work.
@app.post("/enroll")
async def api_enroll(identity_id: str, name: str, file: UploadFile = File(...)):
    """Enroll *name* under *identity_id* from an uploaded face image.

    Returns ``{"message": ...}`` on success. Returns a 400 ``JSONResponse``
    with an ``"error"`` key when the upload cannot be decoded as an image or
    a detection/embedding step reports an error. Re-enrolling an existing
    ``identity_id`` overwrites its entry in ``database``.
    """
    try:
        img = read_image(file.file)
    except (OSError, cv2.error):
        # Pillow raises UnidentifiedImageError (an OSError) for non-images and
        # OSError for truncated files. OpenCV raises cv2.error for unsupported
        # channel layouts. Report a client error instead of a 500.
        return JSONResponse({"error": "Invalid image file"}, status_code=400)

    face, error = detect_face(img)
    if error:
        return JSONResponse({"error": error}, status_code=400)

    embedding, error = get_embedding(face)
    if error:
        return JSONResponse({"error": error}, status_code=400)

    database[identity_id] = {"name": name, "embedding": embedding}
    return {"message": f"{name} تم تسجيله بنجاح"}
# NOTE(review): the original function was never registered on `app`. The route
# path below is a guess; adjust it if clients expect a different URL.
@app.post("/search")
async def api_search(file: UploadFile = File(...)):
    """Rank enrolled identities by cosine similarity to an uploaded face.

    Returns ``{"matches": [...]}`` with at most 5 entries
    ``{"identity_id", "name", "score"}``, best first. The score is rounded to
    3 decimals. Returns a 400 ``JSONResponse`` with an ``"error"`` key when the
    upload is not a decodable image or a model step reports an error.
    """
    try:
        img = read_image(file.file)
    except (OSError, cv2.error):
        return JSONResponse({"error": "Invalid image file"}, status_code=400)

    face, error = detect_face(img)
    if error:
        return JSONResponse({"error": error}, status_code=400)

    embedding, error = get_embedding(face)
    if error:
        return JSONResponse({"error": error}, status_code=400)

    query_norm = np.linalg.norm(embedding)  # loop-invariant
    matches = []
    for id_, data in database.items():
        denom = query_norm * np.linalg.norm(data["embedding"])
        # A zero-norm vector would produce NaN, which cannot be serialised as JSON.
        score = float(np.dot(embedding, data["embedding"]) / denom) if denom else 0.0
        matches.append({"identity_id": id_, "name": data["name"], "score": round(score, 3)})

    matches.sort(key=lambda match: match["score"], reverse=True)
    return {"matches": matches[:5]}
# ================= Gradio interface =================
def gr_enroll(identity_id, name, image):
    """Gradio callback: enroll *name* under *identity_id* from an uploaded image.

    *image* is the value of a ``gr.Image(type="pil")`` component, or ``None``
    when nothing was uploaded. Returns a status string for the output textbox:
    ``"⚠️ ..."`` on failure and ``"✅ ..."`` on success. Re-enrolling an
    existing ID overwrites its entry in ``database``.
    """
    # Gradio passes None when the button is clicked with no image uploaded.
    # np.array(None) would then crash inside cv2.cvtColor.
    if image is None:
        return "⚠️ Please upload an image"
    # An empty textbox would otherwise store the person under the key "".
    if not identity_id or not identity_id.strip():
        return "⚠️ Identity ID is required"

    # The PIL image is RGB; the rest of the pipeline works in OpenCV's BGR order.
    img_array = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    face, error = detect_face(img_array)
    if error:
        return f"⚠️ {error}"

    embedding, error = get_embedding(face)
    if error:
        return f"⚠️ {error}"

    database[identity_id] = {"name": name, "embedding": embedding}
    return f"✅ {name} تم تسجيله بنجاح!"
def gr_search(image):
    """Gradio callback: rank enrolled identities by similarity to *image*.

    Returns:
        tuple: ``(status, rows)``. ``rows`` is a list of
        ``[identity_id, name, score]`` lists (at most 5, best first, score
        rounded to 3 decimals). Row lists are what ``gr.Dataframe`` expects
        for its ``["ID", "Name", "Score"]`` headers; the previous list of
        dicts did not fit that format.
    """
    # Gradio passes None when the button is clicked with no image uploaded.
    if image is None:
        return "⚠️ Please upload an image", []

    img_array = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    face, error = detect_face(img_array)
    if error:
        return f"⚠️ {error}", []

    embedding, error = get_embedding(face)
    if error:
        return f"⚠️ {error}", []

    query_norm = np.linalg.norm(embedding)  # loop-invariant
    rows = []
    for id_, data in database.items():
        denom = query_norm * np.linalg.norm(data["embedding"])
        # Guard against NaN from a zero-norm vector.
        score = float(np.dot(embedding, data["embedding"]) / denom) if denom else 0.0
        rows.append([id_, data["name"], round(score, 3)])

    rows.sort(key=lambda row: row[2], reverse=True)
    return "✅ البحث اكتمل", rows[:5]
# Two-tab Gradio UI: enrollment and search. Components render in the order
# they are created inside the `with` contexts, so statement order here is the
# page layout.
# Note: `identity_id`, `name` and `img` below are module-level component
# objects, distinct from the same-named parameters of the callbacks.
with gr.Blocks() as demo:
    gr.Markdown("## AEFRS Face Recognition System")
    # Tab "Register a new person".
    with gr.Tab("تسجيل شخص جديد"):
        identity_id = gr.Textbox(label="Identity ID")
        name = gr.Textbox(label="Name")
        img = gr.Image(type="pil")
        enroll_btn = gr.Button("تسجيل")  # button label: "Register"
        enroll_output = gr.Textbox()
        enroll_btn.click(gr_enroll, inputs=[identity_id, name, img], outputs=enroll_output)
    # Tab "Search for a face".
    with gr.Tab("بحث عن وجه"):
        search_img = gr.Image(type="pil")
        search_btn = gr.Button("بحث")  # button label: "Search"
        search_status = gr.Textbox()
        # gr_search's second return value fills this table.
        search_results = gr.Dataframe(headers=["ID", "Name", "Score"])
        search_btn.click(gr_search, inputs=[search_img], outputs=[search_status, search_results])
# ================= Run Gradio =================
if __name__ == "__main__":
    import threading

    # Run Gradio in a separate thread, on 0.0.0.0 and the port given by the
    # PORT environment variable (default 7860).
    def launch_gradio():
        demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))

    # NOTE(review): daemon threads are killed when the main thread exits. This
    # only keeps Gradio running if code after this point (not visible here)
    # keeps the main thread alive, e.g. by serving the FastAPI `app`. If
    # nothing follows, the process exits immediately; confirm.
    threading.Thread(target=launch_gradio, daemon=True).start()