# built-in dependencies
import os
import json
import hashlib
import struct
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Union

# 3rd party dependencies
import numpy as np

# project dependencies
from deepface.modules.database.types import Database
from deepface.modules.exceptions import DuplicateEntryError
from deepface.commons.logger import Logger

logger = Logger()

# MongoDB server error code for a unique-index violation ("E11000 duplicate key error").
_DUPLICATE_KEY_CODE = 11000


# pylint: disable=too-many-positional-arguments, too-many-instance-attributes
class MongoDbClient(Database):
    """
    MongoDB equivalent of PostgresClient for DeepFace embeddings storage.

    Uses three collections of the configured database:
    - ``embeddings``: one document per extracted face / embedding record.
    - ``embeddings_index``: serialized index blobs keyed by model configuration.
    - ``counters``: holds the ``embedding_id`` sequence used to assign integer ids.
    """

    def __init__(
        self,
        connection_details: Optional[Union[str, Dict[str, Any]]] = None,
        connection: Any = None,
        db_name: str = "deepface",
    ) -> None:
        """
        Connect to MongoDB and ensure the required indexes exist.

        Args:
            connection_details (str or dict, optional): A MongoDB URI string, or a dict of
                keyword arguments forwarded to ``pymongo.MongoClient``. Falls back to the
                ``DEEPFACE_MONGO_URI`` environment variable when omitted.
            connection (Any, optional): An already-constructed client. When given,
                ``connection_details`` is ignored.
            db_name (str): Name of the database to use.

        Raises:
            ValueError: If pymongo is not installed or no connection information is found.
        """
        try:
            # pymongo is optional: import lazily and keep references on the instance.
            from pymongo import MongoClient, ASCENDING, ReturnDocument
            from pymongo.errors import DuplicateKeyError, BulkWriteError
            from bson import Binary
        except (ModuleNotFoundError, ImportError) as e:
            raise ValueError(
                "pymongo is an optional dependency. Please install it as `pip install pymongo`"
            ) from e

        self.MongoClient = MongoClient
        self.ASCENDING = ASCENDING
        self.ReturnDocument = ReturnDocument
        self.DuplicateKeyError = DuplicateKeyError
        self.BulkWriteError = BulkWriteError
        self.Binary = Binary

        if connection is not None:
            self.client = connection
        else:
            self.conn_details = connection_details or os.environ.get("DEEPFACE_MONGO_URI")
            if not self.conn_details:
                raise ValueError(
                    "MongoDB connection information not found. "
                    "Please provide connection_details or set DEEPFACE_MONGO_URI"
                )
            if isinstance(self.conn_details, str):
                self.client = MongoClient(self.conn_details)
            else:
                self.client = MongoClient(**self.conn_details)

        self.db = self.client[db_name]
        self.embeddings = self.db.embeddings
        self.embeddings_index = self.db.embeddings_index
        self.counters = self.db.counters

        self.initialize_database()

    def close(self) -> None:
        """Close MongoDB connection."""
        self.client.close()

    def initialize_database(self, **kwargs: Any) -> None:
        """
        Ensure required MongoDB indexes exist and seed the id counter.

        Safe to call repeatedly: creating an existing index is a no-op, and the
        counter is only seeded when its document does not exist yet.

        Args:
            **kwargs: Unused by this implementation.
        """
        # Unique constraint for embeddings
        self.embeddings.create_index(
            [("face_hash", self.ASCENDING), ("embedding_hash", self.ASCENDING)],
            unique=True,
            name="uniq_face_embedding",
        )

        # search_by_id filters on `sequence` and fetch_all_embeddings sorts by it.
        # The default name (sequence_1) is kept so an identical, manually created
        # index does not trigger a name conflict.
        self.embeddings.create_index([("sequence", self.ASCENDING)])

        # Unique constraint for embeddings_index
        self.embeddings_index.create_index(
            [
                ("model_name", self.ASCENDING),
                ("detector_backend", self.ASCENDING),
                ("align", self.ASCENDING),
                ("l2_normalized", self.ASCENDING),
            ],
            unique=True,
            name="uniq_index_config",
        )

        # Seed the auto-increment counter atomically. A find_one + insert_one pair can
        # race with another process and fail on the duplicate _id.
        self.counters.update_one(
            {"_id": "embedding_id"},
            {"$setOnInsert": {"seq": 0}},
            upsert=True,
        )

        logger.debug("MongoDB indexes ensured.")

    def upsert_embeddings_index(
        self,
        model_name: str,
        detector_backend: str,
        aligned: bool,
        l2_normalized: bool,
        index_data: bytes,
    ) -> None:
        """
        Upsert embeddings index into MongoDB.

        Args:
            model_name (str): Name of the model.
            detector_backend (str): Name of the detector backend.
            aligned (bool): Whether the embeddings are aligned.
            l2_normalized (bool): Whether the embeddings are L2 normalized.
            index_data (bytes): Serialized index data.
        """
        # One timestamp so created_at == updated_at on the first insert.
        now = datetime.now(timezone.utc)
        self.embeddings_index.update_one(
            {
                "model_name": model_name,
                "detector_backend": detector_backend,
                "align": aligned,
                "l2_normalized": l2_normalized,
            },
            {
                "$set": {
                    "index_data": self.Binary(index_data),
                    "updated_at": now,
                },
                "$setOnInsert": {
                    "created_at": now,
                },
            },
            upsert=True,
        )

    def get_embeddings_index(
        self,
        model_name: str,
        detector_backend: str,
        aligned: bool,
        l2_normalized: bool,
    ) -> bytes:
        """
        Retrieve embeddings index from MongoDB.

        Args:
            model_name (str): Name of the model.
            detector_backend (str): Name of the detector backend.
            aligned (bool): Whether the embeddings are aligned.
            l2_normalized (bool): Whether the embeddings are L2 normalized.

        Returns:
            bytes: Serialized index data.

        Raises:
            ValueError: If no index is stored for the given configuration.
        """
        doc = self.embeddings_index.find_one(
            {
                "model_name": model_name,
                "detector_backend": detector_backend,
                "align": aligned,
                "l2_normalized": l2_normalized,
            },
            {"index_data": 1},
        )

        if not doc:
            raise ValueError(
                "No Embeddings index found for the specified parameters "
                f"{model_name=}, {detector_backend=}, "
                f"{aligned=}, {l2_normalized=}. "
                "You must run build_index first."
            )

        return bytes(doc["index_data"])

    def insert_embeddings(self, embeddings: List[Dict[str, Any]], batch_size: int = 100) -> int:
        """
        Insert embeddings into MongoDB.

        Each record gets a fresh integer id, stored as ``sequence``, from the
        ``embedding_id`` counter. Records are written in unordered batches.

        Args:
            embeddings (List[Dict[str, Any]]): Records with keys ``face`` (array exposing
                ``shape``/``astype``/``tolist``), ``embedding`` (sequence of floats),
                ``img_name``, ``model_name``, ``detector_backend``, ``aligned`` and
                ``l2_normalized``.
            batch_size (int): Number of records to insert in each batch; must be positive.

        Returns:
            int: Number of embeddings successfully inserted.

        Raises:
            ValueError: If ``embeddings`` is empty or ``batch_size`` is not positive.
            DuplicateEntryError: If a multi-record insert hits an existing
                (face_hash, embedding_hash) pair. Earlier batches stay inserted.
        """
        if not embeddings:
            raise ValueError("No embeddings to insert.")
        if batch_size <= 0:
            # With a negative step, range() yields nothing and every record would be skipped silently.
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")

        # Build and hash every document before touching the counter, so malformed
        # input does not consume ids.
        prepared = [self._build_embedding_doc(record) for record in embeddings]

        # Reserve one contiguous block of ids with a single atomic $inc instead of
        # one round trip per record.
        first_id = self._reserve_ids(len(prepared))
        docs = [{"sequence": first_id + offset, **doc} for offset, doc in enumerate(prepared)]

        inserted = 0
        for start in range(0, len(docs), batch_size):
            try:
                result = self.embeddings.insert_many(
                    docs[start : start + batch_size], ordered=False
                )
            except (self.DuplicateKeyError, self.BulkWriteError) as e:
                if not self._is_duplicate_error(e):
                    # Not a uniqueness violation, so do not mislabel it as a duplicate.
                    raise
                if len(docs) == 1:
                    logger.warn("Duplicate detected for extracted face and embedding.")
                    return inserted
                raise DuplicateEntryError(
                    "Duplicate detected for extracted face and embedding in "
                    f"{start // batch_size}-th batch"
                ) from e
            inserted += len(result.inserted_ids)

        return inserted

    def _build_embedding_doc(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Convert one input record into an ``embeddings`` document, without its id."""
        face = record["face"]
        embedding = record["embedding"]
        embedding_bytes = struct.pack(f"{len(embedding)}d", *embedding)
        return {
            "img_name": record["img_name"],
            "face": self.Binary(face.astype(np.float32).tobytes()),
            "face_shape": list(face.shape),
            "model_name": record["model_name"],
            "detector_backend": record["detector_backend"],
            "aligned": record["aligned"],
            "l2_normalized": record["l2_normalized"],
            "embedding": embedding,
            # These hashes back the uniq_face_embedding index. Keep the formula stable,
            # or previously stored records will no longer be detected as duplicates.
            "face_hash": hashlib.sha256(json.dumps(face.tolist()).encode()).hexdigest(),
            "embedding_hash": hashlib.sha256(embedding_bytes).hexdigest(),
            "created_at": datetime.now(timezone.utc),
        }

    def _reserve_ids(self, count: int) -> int:
        """Atomically advance the ``embedding_id`` counter by ``count``; return the first id."""
        last_id = self.counters.find_one_and_update(
            {"_id": "embedding_id"},
            {"$inc": {"seq": count}},
            upsert=True,
            return_document=self.ReturnDocument.AFTER,
        )["seq"]
        return last_id - count + 1

    def _is_duplicate_error(self, error: Exception) -> bool:
        """Return True when ``error`` reports only unique-index (duplicate key) violations."""
        if isinstance(error, self.DuplicateKeyError):
            return True
        details = getattr(error, "details", None) or {}
        if details.get("writeConcernErrors"):
            return False
        write_errors = details.get("writeErrors") or []
        return bool(write_errors) and all(
            err.get("code") == _DUPLICATE_KEY_CODE for err in write_errors
        )

    def fetch_all_embeddings(
        self,
        model_name: str,
        detector_backend: str,
        aligned: bool,
        l2_normalized: bool,
        batch_size: int = 1000,
    ) -> List[Dict[str, Any]]:
        """
        Fetch all embeddings from MongoDB based on specified parameters.

        Args:
            model_name (str): Name of the model.
            detector_backend (str): Name of the detector backend.
            aligned (bool): Whether the embeddings are aligned.
            l2_normalized (bool): Whether the embeddings are L2 normalized.
            batch_size (int): Cursor batch size (documents fetched per server round trip).

        Returns:
            List[Dict[str, Any]]: Matching records ordered by ascending ``id``. Each record
                carries ``_id`` (stringified ObjectId), ``id``, ``img_name``, ``embedding``
                and the four filter values.
        """
        cursor = self.embeddings.find(
            {
                "model_name": model_name,
                "detector_backend": detector_backend,
                "aligned": aligned,
                "l2_normalized": l2_normalized,
            },
            {
                "_id": 1,
                "sequence": 1,
                "img_name": 1,
                "embedding": 1,
            },
            batch_size=batch_size,
        ).sort("sequence", self.ASCENDING)

        return [
            {
                "_id": str(doc["_id"]),
                "id": doc["sequence"],
                "img_name": doc["img_name"],
                "embedding": doc["embedding"],
                "model_name": model_name,
                "detector_backend": detector_backend,
                "aligned": aligned,
                "l2_normalized": l2_normalized,
            }
            for doc in cursor
        ]

    def search_by_id(
        self,
        ids: Union[List[str], List[int]],
    ) -> List[Dict[str, Any]]:
        """
        Search records by their IDs.

        Args:
            ids: Values matched against the ``sequence`` field. insert_embeddings stores
                integers there, so string ids will not match those records.

        Returns:
            List[Dict[str, Any]]: One dict per match with ``_id`` (stringified ObjectId),
                ``id`` and ``img_name``. The order is whatever the server returns.
        """
        cursor = self.embeddings.find(
            {"sequence": {"$in": ids}},
            {
                "_id": 1,
                "sequence": 1,
                "img_name": 1,
            },
        )

        return [
            {
                "_id": str(doc["_id"]),
                "id": doc["sequence"],
                "img_name": doc["img_name"],
            }
            for doc in cursor
        ]