# FaceRecog_hf/routers/predict.py (Hugging Face Space, commit 02a7bf9, author tjrlgns09)
from fastapi import APIRouter, File, UploadFile
import numpy as np
import cv2
import faiss
import pickle
import os
import torch
import insightface
from insightface.utils import face_align
import sys
# โœ… Dockerfile์—์„œ ํด๋ก ํ•œ AdaFace ๊ฒฝ๋กœ ์ถ”๊ฐ€
sys.path.append('/app/AdaFace')
import net
router = APIRouter()
# --- ์„ค์ • ๋ฐ ๊ฒฝ๋กœ ---
faiss_index_name = "face_faiss_index_v2.index"
faiss_label_name = "face_faiss_labels_v2.pkl"
load_path = os.path.abspath("embedding/person") # ์‹ค์ œ FAISS ํŒŒ์ผ ์œ„์น˜๋กœ ๋ณ€๊ฒฝ ํ•„์š”
threshold = 45.0 # Unknown ํŒ๋ณ„ ์ž„๊ณ„๊ฐ’
# โœ… Hugging Face ๋ฌด๋ฃŒ CPU ๊ฐ•์ œ ์„ค์ •
device = torch.device('cpu')
# --- 1. InsightFace (ํƒ์ง€๊ธฐ) ๋กœ๋“œ ---
detector = insightface.app.FaceAnalysis(name='buffalo_l', providers=['CPUExecutionProvider'], allowed_modules=['detection'])
detector.prepare(ctx_id=0, det_size=(640, 640))
# --- 2. AdaFace (์ธ์‹๊ธฐ) ๋กœ๋“œ ---
model_path = "/app/adaface_ir101_webface12m.ckpt"
adaface_model = net.build_model('ir_101')
statedict = torch.load(model_path, map_location=device)["state_dict"]
model_statedict = {key[6:]: val for key, val in statedict.items() if key.startswith("model.")}
adaface_model.load_state_dict(model_statedict)
adaface_model.to(device)
adaface_model.eval()
# --- 3. FAISS ๋กœ๋“œ ---
index = faiss.read_index(os.path.join(load_path, faiss_index_name))
with open(os.path.join(load_path, faiss_label_name), "rb") as f:
labels = pickle.load(f)
# โœ… AdaFace ์ž„๋ฒ ๋”ฉ ์ถ”์ถœ ํ•จ์ˆ˜
def extract_adaface_embedding(img_bgr, face_kps):
    """Align one detected face and return its AdaFace embedding vector.

    img_bgr: full BGR image (as decoded by OpenCV).
    face_kps: 5-point facial landmarks from the InsightFace detector.
    Returns a 1-D numpy array (the embedding for the single face).
    """
    # Warp the face to the canonical 112x112 alignment using the landmarks.
    aligned = face_align.norm_crop(img_bgr, landmark=face_kps, image_size=112)
    # Map pixel values from [0, 255] to [-1, 1].
    scaled = (aligned / 255.0 - 0.5) / 0.5
    # HWC -> CHW, prepend a batch dimension, and move to the configured device.
    batch = torch.tensor(scaled.transpose(2, 0, 1)).float().unsqueeze(0).to(device)
    with torch.no_grad():
        feature, _ = adaface_model(batch)
    return feature.cpu().numpy()[0]
@router.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Detect every face in an uploaded image and identify each against FAISS.

    Returns {"success": bool, "results": [...], "message": str}; each result
    holds the predicted label, a similarity score in percent, and the face's
    bounding box so the frontend can draw it. success=False when the image
    cannot be decoded or no face is detected.
    """
    contents = await file.read()
    nparr = np.frombuffer(contents, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if img is None:
        return {"success": False, "message": "โŒ ์ด๋ฏธ์ง€๋ฅผ ์ฝ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."}
    faces = detector.get(img)
    if not faces:
        return {"success": False, "message": "โŒ ์–ผ๊ตด์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."}
    results = []
    for face in faces:
        # Extract the AdaFace embedding and L2-normalize it so the FAISS
        # inner-product search behaves like cosine similarity.
        emb = extract_adaface_embedding(img, face.kps)
        emb = np.array([emb], dtype='float32')
        emb /= np.linalg.norm(emb, axis=1, keepdims=True)
        # Nearest-neighbor lookup (k=1).
        distances, indices = index.search(emb, k=1)
        best_match_idx = int(indices[0][0])
        similarity_score = float(distances[0][0])
        # Clamp negative similarities to 0 before converting to percent.
        score_percent = max(0.0, similarity_score) * 100
        # FAISS returns -1 for missing neighbors (e.g. an empty index);
        # without this guard labels[-1] would silently pick the LAST label.
        if best_match_idx >= 0 and score_percent >= threshold:
            predicted_name = labels[best_match_idx]
        else:
            predicted_name = "Unknown"
        box = face.bbox.astype(int).tolist()
        results.append({
            "label": predicted_name,
            "score": float(score_percent),
            "bbox": box  # box coordinates for the frontend to draw
        })
    return {"success": True, "results": results, "message": f"โœ… ์ด {len(faces)}๋ช…์˜ ์–ผ๊ตด์„ ์ฒ˜๋ฆฌํ–ˆ์Šต๋‹ˆ๋‹ค."}