# DaoManhDuc2004
# Deploy DATN face AI server
# b5d3a91
# built-in dependencies
import os
import json
import hashlib
import struct
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Union
# 3rd party dependencies
import numpy as np
# project dependencies
from deepface.modules.database.types import Database
from deepface.modules.exceptions import DuplicateEntryError
from deepface.commons.logger import Logger
# Module-level logger used by MongoDbClient for its debug and warning messages.
logger = Logger()
# pylint: disable=too-many-positional-arguments, too-many-instance-attributes
class MongoDbClient(Database):
    """
    MongoDB equivalent of PostgresClient for DeepFace embeddings storage.

    Three collections of the selected database are used:
        - ``embeddings``: one document per extracted face and its embedding.
        - ``embeddings_index``: one serialized index blob per configuration.
        - ``counters``: holds the ``embedding_id`` counter that supplies the
          integer ``sequence`` value stored on every embedding document.
    """

    # MongoDB duplicate-key error codes (11000 plus the legacy 11001 / 12582).
    _DUPLICATE_KEY_CODES = frozenset({11000, 11001, 12582})

    def __init__(
        self,
        connection_details: Optional[Union[str, Dict[str, Any]]] = None,
        connection: Any = None,
        db_name: str = "deepface",
    ) -> None:
        """
        Connect to MongoDB and make sure the required indexes exist.
        Args:
            connection_details (str or dict): Connection URI, or keyword arguments
                forwarded to ``MongoClient``. Falls back to the
                ``DEEPFACE_MONGO_URI`` environment variable when omitted.
            connection (Any): An already created client. When given it is used
                as-is and ``connection_details`` is ignored.
            db_name (str): Name of the database holding the collections.
        Raises:
            ValueError: If pymongo is not installed, or if neither a connection
                nor any connection information is available.
        """
        # pymongo is optional, so it is imported lazily and the names needed by
        # the other methods are kept on the instance.
        try:
            from pymongo import MongoClient, ASCENDING, ReturnDocument
            from pymongo.errors import DuplicateKeyError, BulkWriteError
            from bson import Binary
        except (ModuleNotFoundError, ImportError) as e:
            raise ValueError(
                "pymongo is an optional dependency. Please install it as `pip install pymongo`"
            ) from e

        self.MongoClient = MongoClient
        self.ASCENDING = ASCENDING
        self.ReturnDocument = ReturnDocument
        self.DuplicateKeyError = DuplicateKeyError
        self.BulkWriteError = BulkWriteError
        self.Binary = Binary

        # Always defined, even when a ready-made connection is injected.
        self.conn_details: Optional[Union[str, Dict[str, Any]]] = None

        if connection is not None:
            self.client = connection
        else:
            self.conn_details = connection_details or os.environ.get("DEEPFACE_MONGO_URI")
            if not self.conn_details:
                raise ValueError(
                    "MongoDB connection information not found. "
                    "Please provide connection_details or set DEEPFACE_MONGO_URI"
                )
            if isinstance(self.conn_details, str):
                self.client = MongoClient(self.conn_details)
            else:
                self.client = MongoClient(**self.conn_details)

        self.db = self.client[db_name]
        self.embeddings = self.db.embeddings
        self.embeddings_index = self.db.embeddings_index
        self.counters = self.db.counters

        self.initialize_database()

    def close(self) -> None:
        """Close MongoDB connection."""
        self.client.close()

    def initialize_database(self, **kwargs: Any) -> None:
        """
        Ensure required MongoDB indexes and the id counter document exist.
        Args:
            **kwargs: Accepted but not used by this implementation.
        """
        # Unique constraint for embeddings
        self.embeddings.create_index(
            [("face_hash", self.ASCENDING), ("embedding_hash", self.ASCENDING)],
            unique=True,
            name="uniq_face_embedding",
        )

        # Unique constraint for embeddings_index
        self.embeddings_index.create_index(
            [
                ("model_name", self.ASCENDING),
                ("detector_backend", self.ASCENDING),
                ("align", self.ASCENDING),
                ("l2_normalized", self.ASCENDING),
            ],
            unique=True,
            name="uniq_index_config",
        )

        # counters collection for auto-incrementing IDs. A single upsert is used
        # instead of find-then-insert so that clients starting at the same time
        # do not both try to insert the counter document. $setOnInsert leaves an
        # existing counter value untouched.
        self.counters.update_one(
            {"_id": "embedding_id"},
            {"$setOnInsert": {"seq": 0}},
            upsert=True,
        )

        logger.debug("MongoDB indexes ensured.")

    def upsert_embeddings_index(
        self,
        model_name: str,
        detector_backend: str,
        aligned: bool,
        l2_normalized: bool,
        index_data: bytes,
    ) -> None:
        """
        Upsert embeddings index into MongoDB.
        Args:
            model_name (str): Name of the model.
            detector_backend (str): Name of the detector backend.
            aligned (bool): Whether the embeddings are aligned.
            l2_normalized (bool): Whether the embeddings are L2 normalized.
            index_data (bytes): Serialized index data.
        """
        # One timestamp so created_at and updated_at agree on first insert.
        now = datetime.now(timezone.utc)

        # NOTE: this collection stores the flag under "align" (matching the
        # unique index), whereas the embeddings collection uses "aligned".
        config_filter = {
            "model_name": model_name,
            "detector_backend": detector_backend,
            "align": aligned,
            "l2_normalized": l2_normalized,
        }

        self.embeddings_index.update_one(
            config_filter,
            {
                "$set": {
                    "index_data": self.Binary(index_data),
                    "updated_at": now,
                },
                "$setOnInsert": {
                    "created_at": now,
                },
            },
            upsert=True,
        )

    def get_embeddings_index(
        self,
        model_name: str,
        detector_backend: str,
        aligned: bool,
        l2_normalized: bool,
    ) -> bytes:
        """
        Retrieve embeddings index from MongoDB.
        Args:
            model_name (str): Name of the model.
            detector_backend (str): Name of the detector backend.
            aligned (bool): Whether the embeddings are aligned.
            l2_normalized (bool): Whether the embeddings are L2 normalized.
        Returns:
            bytes: Serialized index data.
        Raises:
            ValueError: If no index is stored for the given configuration.
        """
        doc = self.embeddings_index.find_one(
            {
                "model_name": model_name,
                "detector_backend": detector_backend,
                "align": aligned,
                "l2_normalized": l2_normalized,
            },
            {"index_data": 1},
        )
        if not doc:
            raise ValueError(
                "No Embeddings index found for the specified parameters "
                f"{model_name=}, {detector_backend=}, "
                f"{aligned=}, {l2_normalized=}. "
                "You must run build_index first."
            )
        return bytes(doc["index_data"])

    def _build_embedding_document(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Convert one embedding record into a document; "sequence" is filled in later."""
        face = record["face"]
        embedding = record["embedding"]

        # The two hashes back the unique (face_hash, embedding_hash) index.
        embedding_bytes = struct.pack(f"{len(embedding)}d", *embedding)
        face_hash = hashlib.sha256(json.dumps(face.tolist()).encode()).hexdigest()
        embedding_hash = hashlib.sha256(embedding_bytes).hexdigest()

        return {
            "sequence": None,
            "img_name": record["img_name"],
            "face": self.Binary(face.astype(np.float32).tobytes()),
            "face_shape": list(face.shape),
            "model_name": record["model_name"],
            "detector_backend": record["detector_backend"],
            "aligned": record["aligned"],
            "l2_normalized": record["l2_normalized"],
            "embedding": embedding,
            "face_hash": face_hash,
            "embedding_hash": embedding_hash,
            "created_at": datetime.now(timezone.utc),
        }

    def _is_duplicate_error(self, error: Exception) -> bool:
        """Return True when a write failure consists solely of duplicate-key errors."""
        if isinstance(error, self.DuplicateKeyError):
            return True
        details = getattr(error, "details", None) or {}
        write_errors = details.get("writeErrors") or []
        return bool(write_errors) and all(
            write_error.get("code") in self._DUPLICATE_KEY_CODES for write_error in write_errors
        )

    def insert_embeddings(self, embeddings: List[Dict[str, Any]], batch_size: int = 100) -> int:
        """
        Insert embeddings into MongoDB.
        Args:
            embeddings (List[Dict[str, Any]]): List of embedding records to insert.
                Each record provides the keys "face" (numpy array), "embedding",
                "img_name", "model_name", "detector_backend", "aligned" and
                "l2_normalized".
            batch_size (int): Number of records to insert in each batch.
        Returns:
            int: Number of embeddings successfully inserted.
        Raises:
            ValueError: If ``embeddings`` is empty or ``batch_size`` is below 1.
            DuplicateEntryError: If a duplicate is detected while inserting more
                than one record. For a single record a warning is logged instead.
        """
        if not embeddings:
            raise ValueError("No embeddings to insert.")
        if batch_size < 1:
            # A negative step would yield an empty range and silently insert nothing.
            raise ValueError("batch_size must be a positive integer.")

        # Build every document first so a malformed record consumes no ids.
        docs = [self._build_embedding_document(record) for record in embeddings]

        # Reserve a contiguous block of ids with one atomic increment instead of
        # one database round trip per record.
        last_id = self.counters.find_one_and_update(
            {"_id": "embedding_id"},
            {"$inc": {"seq": len(docs)}},
            upsert=True,
            return_document=self.ReturnDocument.AFTER,
        )["seq"]
        first_id = last_id - len(docs) + 1
        for offset, doc in enumerate(docs):
            doc["sequence"] = first_id + offset

        inserted = 0
        for batch_no, start in enumerate(range(0, len(docs), batch_size)):
            batch = docs[start : start + batch_size]
            try:
                result = self.embeddings.insert_many(batch, ordered=False)
            except (self.DuplicateKeyError, self.BulkWriteError) as err:
                if not self._is_duplicate_error(err):
                    # Not a duplicate: do not mislabel or swallow the failure.
                    raise
                if len(docs) == 1:
                    logger.warn("Duplicate detected for extracted face and embedding.")
                    return inserted
                raise DuplicateEntryError(
                    f"Duplicate detected for extracted face and embedding in {batch_no}-th batch"
                ) from err
            inserted += len(result.inserted_ids)

        return inserted

    def fetch_all_embeddings(
        self,
        model_name: str,
        detector_backend: str,
        aligned: bool,
        l2_normalized: bool,
        batch_size: int = 1000,
    ) -> List[Dict[str, Any]]:
        """
        Fetch all embeddings from MongoDB based on specified parameters.
        Args:
            model_name (str): Name of the model.
            detector_backend (str): Name of the detector backend.
            aligned (bool): Whether the embeddings are aligned.
            l2_normalized (bool): Whether the embeddings are L2 normalized.
            batch_size (int): Number of records to fetch in each batch.
        Returns:
            List[Dict[str, Any]]: List of embedding records ordered by ascending
                "sequence". Each record holds "_id" (as a string), "id" (the
                sequence value), "img_name", "embedding" and the four
                configuration values passed in.
        """
        cursor = self.embeddings.find(
            {
                "model_name": model_name,
                "detector_backend": detector_backend,
                "aligned": aligned,
                "l2_normalized": l2_normalized,
            },
            {
                "_id": 1,
                "sequence": 1,
                "img_name": 1,
                "embedding": 1,
            },
            batch_size=batch_size,
        ).sort("sequence", self.ASCENDING)

        return [
            {
                "_id": str(doc["_id"]),
                "id": doc["sequence"],
                "img_name": doc["img_name"],
                "embedding": doc["embedding"],
                "model_name": model_name,
                "detector_backend": detector_backend,
                "aligned": aligned,
                "l2_normalized": l2_normalized,
            }
            for doc in cursor
        ]

    def search_by_id(
        self,
        ids: Union[List[str], List[int]],
    ) -> List[Dict[str, Any]]:
        """
        Search records by their IDs.
        Args:
            ids (List[str] or List[int]): Values matched against the "sequence"
                field of the stored embeddings.
        Returns:
            List[Dict[str, Any]]: One record per matching document with the keys
                "_id" (as a string), "id" (the sequence value) and "img_name".
        """
        # BSON cannot encode numpy scalar types (e.g. numpy.int64), so they are
        # converted to native Python values before the query is built.
        plain_ids = [value.item() if isinstance(value, np.generic) else value for value in ids]

        cursor = self.embeddings.find(
            {"sequence": {"$in": plain_ids}},
            {
                "_id": 1,
                "sequence": 1,
                "img_name": 1,
            },
        )

        return [
            {
                "_id": str(doc["_id"]),
                "id": doc["sequence"],
                "img_name": doc["img_name"],
            }
            for doc in cursor
        ]