Spaces:
Running
Running
| import base64 | |
| import cv2 | |
| import os | |
| from datetime import datetime | |
| import numpy as np | |
| import torch | |
| from facenet_pytorch import MTCNN, InceptionResnetV1 | |
| from ultralytics import YOLO | |
| from helpers.db import get_chroma | |
| import uuid | |
| from helpers.Augmentions import FaceAugmentor | |
class EmbeddingController:
    """Detect faces, compute FaceNet embeddings, and store/query them in Chroma.

    Every embedding is L2-normalized before it is written to or queried
    against the collection. Each enrolled face is stored together with the
    augmented variants produced by ``FaceAugmentor.generate``.
    """

    def __init__(self, DETECTION_MODEL: str, YOLOFACE_MODEL_PATH=None):
        """Initialise the vector store, the face detector and the embedding model.

        Args:
            DETECTION_MODEL: ``"yoloface"`` selects a YOLO detector loaded from
                ``YOLOFACE_MODEL_PATH``; any other value selects MTCNN.
            YOLOFACE_MODEL_PATH: Path to the YOLO face weights. Required when
                ``DETECTION_MODEL == "yoloface"``.

        Raises:
            ValueError: If ``"yoloface"`` is requested without a model path.
        """
        # Fail fast, before opening the database, instead of letting YOLO
        # fail obscurely on a ``None`` model path.
        if DETECTION_MODEL == "yoloface" and YOLOFACE_MODEL_PATH is None:
            raise ValueError(
                "YOLOFACE_MODEL_PATH is required when DETECTION_MODEL is 'yoloface'"
            )
        self.client, self.collection = get_chroma()
        self.detection_model = DETECTION_MODEL
        if DETECTION_MODEL == "yoloface":
            self.detector = YOLO(model=YOLOFACE_MODEL_PATH)
        else:
            # image_size / margin / post_process only affect MTCNN's own crop
            # extraction (calling the module); detect_faces() uses .detect(),
            # which returns raw boxes.
            self.detector = MTCNN(
                image_size=160,
                margin=10,
                min_face_size=20,  # same as the facenet_pytorch default
                # Stricter final-stage threshold than the library default
                # [0.6, 0.7, 0.7]: fewer false positives, slightly lower recall.
                thresholds=[0.6, 0.7, 0.8],
                factor=0.709,
                post_process=True,
                keep_all=True,
                device=torch.device("cpu"),
            )
        self.facenet = InceptionResnetV1(pretrained="vggface2").eval().to("cpu")
        self.augmentor = FaceAugmentor()

    def detect_faces(self, image):
        """Detect faces in ``image`` and return the cropped face regions.

        Args:
            image: Image as a NumPy array indexed ``[row, col, ...]`` (it is
                sliced directly). Presumably BGR as loaded by OpenCV; note that
                MTCNN is usually fed RGB — TODO confirm with callers.

        Returns:
            A list of non-empty crops (NumPy views into ``image``), one per
            detected box; an empty list if nothing was detected.
        """
        if isinstance(self.detector, YOLO):
            results = self.detector(image, verbose=False)
            boxes = results[0].boxes.xyxy.cpu().numpy()
        else:
            boxes, _ = self.detector.detect(image)
        if boxes is None:
            return []

        height, width = image.shape[:2]
        faces = []
        for box in boxes:
            x1, y1, x2, y2 = map(int, box)
            # Detectors can return boxes that extend past the image border.
            # A negative start index would wrap around in NumPy slicing and
            # yield an empty or wrong crop, so clamp to the image bounds.
            x1, y1 = max(x1, 0), max(y1, 0)
            x2, y2 = min(x2, width), min(y2, height)
            face = image[y1:y2, x1:x2]
            if face.size > 0:
                faces.append(face)
        return faces

    def get_embedding(self, face):
        """Compute the FaceNet embedding of a single face crop.

        Args:
            face: Face crop as a NumPy array: BGR/BGRA colour, or a single-
                channel grayscale image.

        Returns:
            The flattened embedding as a 1-D NumPy array (not normalized).
        """
        if face.ndim == 2 or face.shape[2] == 1:
            conversion = cv2.COLOR_GRAY2RGB
        else:
            conversion = cv2.COLOR_BGR2RGB
        try:
            face_rgb = cv2.cvtColor(face, conversion)
        except cv2.error:
            # e.g. a dtype cvtColor does not support: use the pixels as-is.
            face_rgb = face
        face_resized = cv2.resize(face_rgb, (160, 160))
        # NOTE(review): pixels are scaled to [0, 1], whereas facenet_pytorch
        # models are trained on fixed_image_standardization ((x - 127.5) / 128).
        # Left unchanged because embeddings already stored in the collection
        # were computed with this scaling.
        face_tensor = (
            torch.tensor(face_resized).permute(2, 0, 1).unsqueeze(0).float() / 255.0
        )
        with torch.no_grad():
            embedding = self.facenet(face_tensor.to("cpu")).cpu().numpy()
        return embedding.flatten()

    def face_to_base64(self, face):
        """Encode ``face`` as JPEG and return it as a base64 UTF-8 string.

        Raises:
            ValueError: If OpenCV fails to encode the image.
        """
        ok, buffer = cv2.imencode(".jpg", face)
        if not ok:
            raise ValueError("Failed to JPEG-encode face crop")
        return base64.b64encode(buffer).decode("utf-8")

    def save_cropped_face(self, face, user_id: str = None, idx: int = 0):
        """Write ``face`` to ``<cwd>/static/crops`` as a JPEG (debug helper).

        The file name is ``{user_id or 'unknown'}_{detection_model}_{idx}_{timestamp}.jpg``
        with a one-second timestamp resolution.

        Returns:
            The written file path, or ``None`` if writing failed for any reason.
        """
        try:
            out_dir = os.path.join(os.getcwd(), "static", "crops")
            os.makedirs(out_dir, exist_ok=True)
            ts = datetime.now().strftime("%Y%m%d_%H%M%S")
            user_part = user_id if user_id else "unknown"
            filename = f"{user_part}_{self.detection_model}_{idx}_{ts}.jpg"
            path = os.path.join(out_dir, filename)
            # cv2.imwrite signals failure through its return value, not by raising.
            if not cv2.imwrite(path, face):
                return None
            return path
        except Exception:
            # Best-effort helper: a failed save must never break the caller.
            return None

    @staticmethod
    def _normalize(vector):
        """Return ``vector`` as a NumPy array scaled to unit L2 norm.

        Raises:
            ValueError: If the vector has zero norm (it cannot be normalized,
                and storing/querying NaNs would corrupt similarity results).
        """
        vec = np.asarray(vector)
        norm = np.linalg.norm(vec)
        if norm == 0:
            raise ValueError("Cannot normalize a zero-norm embedding")
        return vec / norm

    def _add_record(self, record_id, face, embedding, metadata):
        """Normalize ``embedding`` and insert one record (face stored as base64)."""
        self.collection.add(
            ids=[record_id],
            embeddings=[self._normalize(embedding).tolist()],
            documents=[self.face_to_base64(face)],
            metadatas=[metadata],
        )

    def _add_augmented_records(self, face, metadata, id_prefix):
        """Embed and store every augmented variant of ``face``.

        Each record reuses ``metadata`` plus ``"augmented": True`` and gets the
        id ``f"{id_prefix}_{i}_{uuid4 hex}"``.
        """
        for i, aug_face in enumerate(self.augmentor.generate(face)):
            aug_metadata = {**metadata, "augmented": True}
            aug_id = f"{id_prefix}_{i}_{uuid.uuid4().hex}"
            self._add_record(aug_id, aug_face, self.get_embedding(aug_face), aug_metadata)

    def add_embedding(self, face, embedding, metadata: dict):
        """Store one face embedding plus embeddings of its augmented variants.

        Args:
            face: Face crop stored (base64 JPEG) as the record document.
            embedding: Embedding of ``face``; normalized before storage.
            metadata: Record metadata; must contain ``"user_id"``.

        Raises:
            KeyError: If ``metadata`` has no ``"user_id"``.
            ValueError: If an embedding has zero norm or a face cannot be encoded.
        """
        user_id = metadata["user_id"]
        self._add_record(f"{user_id}_{uuid.uuid4().hex}", face, embedding, metadata)
        self._add_augmented_records(face, metadata, id_prefix=f"{user_id}_aug")

    def update_embeddings(self, user_id: str, faces: list, embeddings: list, metadata: dict = None):
        """Replace all stored records of ``user_id`` with the given faces.

        Args:
            user_id: User whose records are replaced; written into each record's
                metadata.
            faces: Face crops, paired index-wise with ``embeddings``.
            embeddings: One embedding per face.
            metadata: Optional extra metadata copied into every record.

        Raises:
            ValueError: If ``faces`` and ``embeddings`` differ in length (checked
                before anything is deleted), or an embedding has zero norm.
        """
        if len(faces) != len(embeddings):
            raise ValueError(
                f"faces and embeddings must have the same length "
                f"({len(faces)} != {len(embeddings)})"
            )
        # Best effort: a failed delete is reported by delete_embeddings_by_user
        # and does not abort the update.
        self.delete_embeddings_by_user(user_id)
        for idx, (face, emb) in enumerate(zip(faces, embeddings)):
            meta = metadata.copy() if metadata else {}
            meta["user_id"] = user_id
            record_id = f"{user_id}_{idx}_{datetime.now().timestamp()}"
            self._add_record(record_id, face, emb, meta)
            self._add_augmented_records(face, meta, id_prefix=f"{user_id}_upd_aug")

    def delete_embeddings_by_user(self, user_id: str):
        """Delete every record whose metadata ``user_id`` matches.

        Returns:
            ``True`` if the delete call succeeded, ``False`` if it raised (the
            error is printed, not propagated).
        """
        try:
            self.collection.delete(where={"user_id": user_id})
            return True
        except Exception as e:
            print("Deletion error:", e)
            return False

    def query_embedding(self, embedding, n_results=5, threshold=0.6):
        """Find the closest stored embedding and decide whether it is a match.

        Args:
            embedding: Query embedding; normalized before querying.
            n_results: Number of nearest neighbours requested from Chroma.
            threshold: Minimum ``1 - distance`` value accepted as a match.

        Returns:
            ``{"match": True, "user_id", "similarity", "metadata"}`` on a match;
            ``{"match": False, "similarity"}`` when the best hit is below
            ``threshold``; ``{"match": False, "reason"}`` when nothing came back.

        NOTE(review): ``1 - distance`` equals cosine similarity only if the
        collection uses Chroma's ``"cosine"`` space. Chroma's default ``"l2"``
        space returns squared L2 distance (``2 - 2*cos`` for unit vectors);
        confirm the space configured in ``get_chroma``.
        """
        query = self._normalize(embedding)
        results = self.collection.query(
            query_embeddings=[query.tolist()],
            n_results=n_results,
        )
        if not results or not results.get("distances"):
            return {
                "match": False,
                "reason": "No results from database",
            }
        distances = results["distances"][0]
        metadatas = (results.get("metadatas") or [[]])[0]
        if not distances or not metadatas:
            return {
                "match": False,
                "reason": "Empty results from database",
            }
        best_distance = min(distances)
        best_index = distances.index(best_distance)
        best_metadata = metadatas[best_index]
        similarity = 1 - best_distance
        if similarity >= threshold:
            return {
                "match": True,
                "user_id": best_metadata.get("user_id"),
                "similarity": round(similarity, 5),
                "metadata": best_metadata,
            }
        return {
            "match": False,
            "similarity": round(similarity, 5),
        }