Spaces:
Build error
Build error
File size: 5,109 Bytes
6273393 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import os
import cv2
import numpy as np
import pickle
import faiss
from insightface.app import FaceAnalysis
from fastapi import FastAPI, UploadFile, File, HTTPException
from io import BytesIO
# 1. FastAPI ์ ํ๋ฆฌ์ผ์ด์
์ธ์คํด์ค ์์ฑ
app = FastAPI(
title="InsightFace Face Recognition API",
description="InsightFace (buffalo_l)๋ฅผ ์ฌ์ฉํ ์ผ๊ตด ์ธ์ ๋ฐ FAISS ์ธ๋ฑ์ค ๊ฒ์ API"
)
# ๋ชจ๋ธ ๋ฐ ์ธ๋ฑ์ค ์ ์ญ ๋ณ์
model = None
index = None
labels = None
# ๐ ํ๊ฒฝ ์ค์ : Docker ํ๊ฒฝ์์๋ ๋ชจ๋ธ/์ธ๋ฑ์ค ํ์ผ์ ํ๋ก์ ํธ ๋๋ ํ ๋ฆฌ์ ๋ฐฐ์นํฉ๋๋ค.
# ์ธ๋ฑ์ค ํ์ผ์ /app/data ํด๋์ ์๋ค๊ณ ๊ฐ์ ํฉ๋๋ค.
LOAD_DIR = "data"
FAISS_INDEX_FILE = os.path.join(LOAD_DIR, "faiss_index_v2.index")
LABELS_FILE = os.path.join(LOAD_DIR, "faiss_labels_v2.pkl")
# โ
1. ๋ชจ๋ธ ๋ฐ ์ธ๋ฑ์ค ์ค๋น (์๋ฒ ์์ ์ ํ ๋ฒ๋ง ์คํ)
@app.on_event("startup")
async def startup_event():
global model, index, labels
print("๐ ์๋ฒ ์์: InsightFace ๋ชจ๋ธ ๋ฐ FAISS ์ธ๋ฑ์ค ๋ก๋ฉ ์ค...")
# InsightFace ArcFace ๋ชจ๋ธ ์ค๋น
try:
# CPUExecutionProvider ์ฌ์ฉ (GPU๊ฐ ์๋ ํ๊ฒฝ/Docker์ ์ ํฉ)
model = FaceAnalysis(name='buffalo_l', providers=['CPUExecutionProvider'])
model.prepare(ctx_id=0)
print("โ
InsightFace model (buffalo_l) ๋ก๋ฉ ์๋ฃ.")
except Exception as e:
print(f"โ InsightFace ๋ชจ๋ธ ๋ก๋ฉ ์คํจ: {e}")
raise HTTPException(status_code=500, detail=f"๋ชจ๋ธ ๋ก๋ฉ ์คํจ: {e}")
# ์ธ๋ฑ์ค & ๋ผ๋ฒจ ๋ก๋ฉ
try:
if not os.path.exists(FAISS_INDEX_FILE) or not os.path.exists(LABELS_FILE):
print(f"โ FAISS ํ์ผ์ด ์์ต๋๋ค. ๊ฒฝ๋ก ํ์ธ: {LOAD_DIR}")
raise FileNotFoundError(f"ํ์ํ FAISS ์ธ๋ฑ์ค ํ์ผ ํน์ ๋ผ๋ฒจ ํ์ผ์ด ์์ต๋๋ค.")
index = faiss.read_index(FAISS_INDEX_FILE)
with open(LABELS_FILE, "rb") as f:
labels = pickle.load(f)
print(f"โ
FAISS ์ธ๋ฑ์ค ๋ก๋ฉ ์๋ฃ. ์ด {index.ntotal}๊ฐ์ ์๋ฒ ๋ฉ ๋ก๋.")
except Exception as e:
print(f"โ FAISS ์ธ๋ฑ์ค ๋ก๋ฉ ์คํจ: {e}")
raise HTTPException(status_code=500, detail=f"FAISS ๋ก๋ฉ ์คํจ: {e}")
# ๐ ์ผ๊ตด ์๋ฒ ๋ฉ ์ถ์ถ ํจ์ (์๋ณธ ์๋ฒ ๋ฉ๋ง ์ถ์ถ)
def get_face_embedding(img_np):
"""
Numpy ๋ฐฐ์ด ํํ์ ์ด๋ฏธ์ง์์ ์ผ๊ตด ์๋ฒ ๋ฉ์ ์ถ์ถํฉ๋๋ค.
"""
global model
if model is None:
raise HTTPException(status_code=500, detail="๋ชจ๋ธ์ด ์ด๊ธฐํ๋์ง ์์์ต๋๋ค.")
# BGR์ RGB๋ก ๋ณํ (InsightFace ๋ชจ๋ธ์ RGB๋ฅผ ์ ํธ)
img = cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
faces = model.get(img)
if faces:
return faces[0].embedding
else:
return None
# --- API ์๋ํฌ์ธํธ ์ ์ ---
# 3. ๋ฃจํธ ์๋ํฌ์ธํธ (GET /)
@app.get("/")
def read_root():
result={"success":True,"data":None,"msg":""}
try:
result["data"]="ok"
return result
except Exception as e:
result["success"] = False
result["msg"]=f"server error. {e!r}"
return result
# ๐ ์ผ๊ตด ์์ธก API ์๋ํฌ์ธํธ
@app.post("/predict_person/")
async def predict_person(
image: UploadFile = File(..., description="๋ถ์ํ ์ผ๊ตด ์ด๋ฏธ์ง ํ์ผ"),
top_k: int = 1
):
"""
์
๋ก๋๋ ์ด๋ฏธ์ง์์ ์ผ๊ตด์ ์ธ์ํ๊ณ , FAISS ์ธ๋ฑ์ค๋ฅผ ์ฌ์ฉํ์ฌ ๊ฐ์ฅ ์ ์ฌํ ์ธ๋ฌผ์ ์์ธกํฉ๋๋ค.
"""
global index, labels
# 1. ์ด๋ฏธ์ง ํ์ผ ์ฝ๊ธฐ
content = await image.read()
np_array = np.frombuffer(content, np.uint8)
img_np = cv2.imdecode(np_array, cv2.IMREAD_COLOR)
if img_np is None:
raise HTTPException(status_code=400, detail="์ด๋ฏธ์ง ํ์ผ์ ์ฝ์ ์ ์์ต๋๋ค.")
# 2. ์ผ๊ตด ์๋ฒ ๋ฉ ์ถ์ถ
embedding = get_face_embedding(img_np)
if embedding is None:
raise HTTPException(status_code=404, detail="์ด๋ฏธ์ง์์ ์ผ๊ตด์ ์ธ์ํ ์ ์์ต๋๋ค.")
# 3. ์๋ฒ ๋ฉ ์ ๊ทํ
embedding = embedding.astype('float32')
embedding /= np.linalg.norm(embedding)
# ์ฟผ๋ฆฌ ํ์์ ๋ง๊ฒ [1, D] ํํ๋ก ๋ณํ
query_vector = np.array([embedding])
# 4. ์ ์ฌ๋ ๊ฒ์
# top_k๋ ์ต๋ ์ธ๋ฑ์ค ํฌ๊ธฐ(index.ntotal)๋ฅผ ์ด๊ณผํ ์ ์์ต๋๋ค.
k = min(top_k, index.ntotal)
scores, indices = index.search(query_vector, k)
# 5. ๊ฒฐ๊ณผ ํฌ๋งทํ
results = []
for idx, score in zip(indices[0], scores[0]):
# labels[idx]๋ ์ธ๋ฑ์ค์ ์ ์ฅ๋ ์๋ฒ ๋ฉ ์ค ๊ฐ์ฅ ์ ์ฌํ ์๋ฒ ๋ฉ์ ๋ผ๋ฒจ
results.append({
"rank": len(results) + 1,
"person_id": labels[idx],
"similarity_score": float(f"{score:.4f}")
})
return {"filename": image.filename, "predictions": results}
if __name__ == "__main__":
# --reload ์ต์
์ ์ถ๊ฐํ์ฌ ์ฝ๋๊ฐ ๋ณ๊ฒฝ๋ ๋๋ง๋ค ์๋ ์ฌ์์๋๊ฒ ์ค์ ํฉ๋๋ค.
uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True) |