Spaces:
Build error
Build error
| import os | |
| import cv2 | |
| import numpy as np | |
| import pickle | |
| import faiss | |
| from insightface.app import FaceAnalysis | |
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from io import BytesIO | |
| # 1. FastAPI ์ ํ๋ฆฌ์ผ์ด์ ์ธ์คํด์ค ์์ฑ | |
| app = FastAPI( | |
| title="InsightFace Face Recognition API", | |
| description="InsightFace (buffalo_l)๋ฅผ ์ฌ์ฉํ ์ผ๊ตด ์ธ์ ๋ฐ FAISS ์ธ๋ฑ์ค ๊ฒ์ API" | |
| ) | |
| # ๋ชจ๋ธ ๋ฐ ์ธ๋ฑ์ค ์ ์ญ ๋ณ์ | |
| model = None | |
| index = None | |
| labels = None | |
| # ๐ ํ๊ฒฝ ์ค์ : Docker ํ๊ฒฝ์์๋ ๋ชจ๋ธ/์ธ๋ฑ์ค ํ์ผ์ ํ๋ก์ ํธ ๋๋ ํ ๋ฆฌ์ ๋ฐฐ์นํฉ๋๋ค. | |
| # ์ธ๋ฑ์ค ํ์ผ์ /app/data ํด๋์ ์๋ค๊ณ ๊ฐ์ ํฉ๋๋ค. | |
| LOAD_DIR = "data" | |
| FAISS_INDEX_FILE = os.path.join(LOAD_DIR, "faiss_index_v2.index") | |
| LABELS_FILE = os.path.join(LOAD_DIR, "faiss_labels_v2.pkl") | |
| # โ 1. ๋ชจ๋ธ ๋ฐ ์ธ๋ฑ์ค ์ค๋น (์๋ฒ ์์ ์ ํ ๋ฒ๋ง ์คํ) | |
| async def startup_event(): | |
| global model, index, labels | |
| print("๐ ์๋ฒ ์์: InsightFace ๋ชจ๋ธ ๋ฐ FAISS ์ธ๋ฑ์ค ๋ก๋ฉ ์ค...") | |
| # InsightFace ArcFace ๋ชจ๋ธ ์ค๋น | |
| try: | |
| # CPUExecutionProvider ์ฌ์ฉ (GPU๊ฐ ์๋ ํ๊ฒฝ/Docker์ ์ ํฉ) | |
| model = FaceAnalysis(name='buffalo_l', providers=['CPUExecutionProvider']) | |
| model.prepare(ctx_id=0) | |
| print("โ InsightFace model (buffalo_l) ๋ก๋ฉ ์๋ฃ.") | |
| except Exception as e: | |
| print(f"โ InsightFace ๋ชจ๋ธ ๋ก๋ฉ ์คํจ: {e}") | |
| raise HTTPException(status_code=500, detail=f"๋ชจ๋ธ ๋ก๋ฉ ์คํจ: {e}") | |
| # ์ธ๋ฑ์ค & ๋ผ๋ฒจ ๋ก๋ฉ | |
| try: | |
| if not os.path.exists(FAISS_INDEX_FILE) or not os.path.exists(LABELS_FILE): | |
| print(f"โ FAISS ํ์ผ์ด ์์ต๋๋ค. ๊ฒฝ๋ก ํ์ธ: {LOAD_DIR}") | |
| raise FileNotFoundError(f"ํ์ํ FAISS ์ธ๋ฑ์ค ํ์ผ ํน์ ๋ผ๋ฒจ ํ์ผ์ด ์์ต๋๋ค.") | |
| index = faiss.read_index(FAISS_INDEX_FILE) | |
| with open(LABELS_FILE, "rb") as f: | |
| labels = pickle.load(f) | |
| print(f"โ FAISS ์ธ๋ฑ์ค ๋ก๋ฉ ์๋ฃ. ์ด {index.ntotal}๊ฐ์ ์๋ฒ ๋ฉ ๋ก๋.") | |
| except Exception as e: | |
| print(f"โ FAISS ์ธ๋ฑ์ค ๋ก๋ฉ ์คํจ: {e}") | |
| raise HTTPException(status_code=500, detail=f"FAISS ๋ก๋ฉ ์คํจ: {e}") | |
| # ๐ ์ผ๊ตด ์๋ฒ ๋ฉ ์ถ์ถ ํจ์ (์๋ณธ ์๋ฒ ๋ฉ๋ง ์ถ์ถ) | |
| def get_face_embedding(img_np): | |
| """ | |
| Numpy ๋ฐฐ์ด ํํ์ ์ด๋ฏธ์ง์์ ์ผ๊ตด ์๋ฒ ๋ฉ์ ์ถ์ถํฉ๋๋ค. | |
| """ | |
| global model | |
| if model is None: | |
| raise HTTPException(status_code=500, detail="๋ชจ๋ธ์ด ์ด๊ธฐํ๋์ง ์์์ต๋๋ค.") | |
| # BGR์ RGB๋ก ๋ณํ (InsightFace ๋ชจ๋ธ์ RGB๋ฅผ ์ ํธ) | |
| img = cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB) | |
| faces = model.get(img) | |
| if faces: | |
| return faces[0].embedding | |
| else: | |
| return None | |
| # --- API ์๋ํฌ์ธํธ ์ ์ --- | |
| # 3. ๋ฃจํธ ์๋ํฌ์ธํธ (GET /) | |
| def read_root(): | |
| result={"success":True,"data":None,"msg":""} | |
| try: | |
| result["data"]="ok" | |
| return result | |
| except Exception as e: | |
| result["success"] = False | |
| result["msg"]=f"server error. {e!r}" | |
| return result | |
| # ๐ ์ผ๊ตด ์์ธก API ์๋ํฌ์ธํธ | |
| async def predict_person( | |
| image: UploadFile = File(..., description="๋ถ์ํ ์ผ๊ตด ์ด๋ฏธ์ง ํ์ผ"), | |
| top_k: int = 1 | |
| ): | |
| """ | |
| ์ ๋ก๋๋ ์ด๋ฏธ์ง์์ ์ผ๊ตด์ ์ธ์ํ๊ณ , FAISS ์ธ๋ฑ์ค๋ฅผ ์ฌ์ฉํ์ฌ ๊ฐ์ฅ ์ ์ฌํ ์ธ๋ฌผ์ ์์ธกํฉ๋๋ค. | |
| """ | |
| global index, labels | |
| # 1. ์ด๋ฏธ์ง ํ์ผ ์ฝ๊ธฐ | |
| content = await image.read() | |
| np_array = np.frombuffer(content, np.uint8) | |
| img_np = cv2.imdecode(np_array, cv2.IMREAD_COLOR) | |
| if img_np is None: | |
| raise HTTPException(status_code=400, detail="์ด๋ฏธ์ง ํ์ผ์ ์ฝ์ ์ ์์ต๋๋ค.") | |
| # 2. ์ผ๊ตด ์๋ฒ ๋ฉ ์ถ์ถ | |
| embedding = get_face_embedding(img_np) | |
| if embedding is None: | |
| raise HTTPException(status_code=404, detail="์ด๋ฏธ์ง์์ ์ผ๊ตด์ ์ธ์ํ ์ ์์ต๋๋ค.") | |
| # 3. ์๋ฒ ๋ฉ ์ ๊ทํ | |
| embedding = embedding.astype('float32') | |
| embedding /= np.linalg.norm(embedding) | |
| # ์ฟผ๋ฆฌ ํ์์ ๋ง๊ฒ [1, D] ํํ๋ก ๋ณํ | |
| query_vector = np.array([embedding]) | |
| # 4. ์ ์ฌ๋ ๊ฒ์ | |
| # top_k๋ ์ต๋ ์ธ๋ฑ์ค ํฌ๊ธฐ(index.ntotal)๋ฅผ ์ด๊ณผํ ์ ์์ต๋๋ค. | |
| k = min(top_k, index.ntotal) | |
| scores, indices = index.search(query_vector, k) | |
| # 5. ๊ฒฐ๊ณผ ํฌ๋งทํ | |
| results = [] | |
| for idx, score in zip(indices[0], scores[0]): | |
| # labels[idx]๋ ์ธ๋ฑ์ค์ ์ ์ฅ๋ ์๋ฒ ๋ฉ ์ค ๊ฐ์ฅ ์ ์ฌํ ์๋ฒ ๋ฉ์ ๋ผ๋ฒจ | |
| results.append({ | |
| "rank": len(results) + 1, | |
| "person_id": labels[idx], | |
| "similarity_score": float(f"{score:.4f}") | |
| }) | |
| return {"filename": image.filename, "predictions": results} | |
| if __name__ == "__main__": | |
| # --reload ์ต์ ์ ์ถ๊ฐํ์ฌ ์ฝ๋๊ฐ ๋ณ๊ฒฝ๋ ๋๋ง๋ค ์๋ ์ฌ์์๋๊ฒ ์ค์ ํฉ๋๋ค. | |
| uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True) |