DaoManhDuc2004
Deploy DATN face AI server
b5d3a91
# built-in dependencies
import os
import json
import hashlib
import struct
import base64
import uuid
import math
from typing import Any, Dict, Optional, List, Union
# project dependencies
from deepface.modules.database.types import Database
from deepface.commons.logger import Logger
# Module-level logger used by the Weaviate client for debug and warning output.
logger = Logger()
# Process-wide memo consulted by WeaviateClient.initialize_database so that a
# schema which has already been handled is not processed again. Entries map a
# string key to True and are never removed.
_SCHEMA_CHECKED: Dict[str, bool] = {}
# pylint: disable=too-many-positional-arguments
class WeaviateClient(Database):
    """
    Weaviate client for storing and retrieving face embeddings and indices.

    Every combination of model name, detector backend, alignment flag and
    normalization flag is stored in its own Weaviate class; the class name is
    derived by ``__generate_class_name``.
    """

    def __init__(
        self,
        connection_details: Optional[Union[str, Dict[str, Any]]] = None,
        connection: Any = None,
    ):
        """
        Build the client from an existing connection or from connection details.

        Args:
            connection_details: Weaviate URL as a string, or a dict holding a
                ``url`` key and optionally an ``api_key`` key. When omitted,
                the ``DEEPFACE_WEAVIATE_URL`` environment variable is used.
                The API key falls back to ``WEAVIATE_API_KEY``.
            connection: Already constructed Weaviate client. When provided,
                ``connection_details`` is ignored.

        Raises:
            ValueError: If weaviate-client is not installed, if the connection
                details are neither a string nor a dict, or if no URL can be
                resolved from them.
        """
        try:
            import weaviate
        except (ModuleNotFoundError, ImportError) as e:
            raise ValueError(
                "weaviate-client is an optional dependency. "
                "Install with 'pip install weaviate-client'"
            ) from e
        self.weaviate = weaviate
        if connection is not None:
            self.client = connection
            # The URL is part of the _SCHEMA_CHECKED key; fall back to the
            # object identity when the supplied client exposes no URL.
            self.url = getattr(connection, "url", str(id(connection)))
        else:
            self.conn_details = connection_details or os.environ.get("DEEPFACE_WEAVIATE_URL")
            if isinstance(self.conn_details, str):
                self.url = self.conn_details
                self.api_key = os.getenv("WEAVIATE_API_KEY")
            elif isinstance(self.conn_details, dict):
                self.url = self.conn_details.get("url")
                self.api_key = self.conn_details.get("api_key") or os.getenv("WEAVIATE_API_KEY")
            else:
                raise ValueError("connection_details must be a string or dict with 'url'.")
            if not self.url:
                raise ValueError("Weaviate URL not provided in connection_details.")
            client_config = {"url": self.url}
            if getattr(self, "api_key", None):
                client_config["auth_client_secret"] = self.weaviate.AuthApiKey(api_key=self.api_key)
            self.client = self.weaviate.Client(**client_config)

    def _schema_cache_key(self, class_name: str) -> str:
        """
        Return the ``_SCHEMA_CHECKED`` key for a class on this client's server.

        The key combines the client URL with the class name so that two
        clients pointing at different Weaviate instances do not share entries.
        """
        return f"{getattr(self, 'url', '')}::{class_name}"

    def initialize_database(self, **kwargs: Any) -> None:
        """
        Ensure the Weaviate class for the given embedding configuration exists.

        The class uses cosine distance when ``l2_normalized`` is true and
        l2-squared distance otherwise. The outcome is memoized in
        ``_SCHEMA_CHECKED`` so later calls for the same server and class do
        not contact Weaviate again.

        Args:
            **kwargs: Optional ``model_name`` (default ``"VGG-Face"``),
                ``detector_backend`` (default ``"opencv"``), ``aligned``
                (default ``True``) and ``l2_normalized`` (default ``False``).
        """
        model_name = kwargs.get("model_name", "VGG-Face")
        detector_backend = kwargs.get("detector_backend", "opencv")
        aligned = kwargs.get("aligned", True)
        l2_normalized = kwargs.get("l2_normalized", False)
        class_name = self.__generate_class_name(
            model_name=model_name,
            detector_backend=detector_backend,
            aligned=aligned,
            l2_normalized=l2_normalized,
        )
        cache_key = self._schema_cache_key(class_name)
        # Consult the memo before talking to the server; fetching the schema
        # first would make the memo pointless.
        if _SCHEMA_CHECKED.get(cache_key):
            logger.debug("Weaviate schema already checked, skipping.")
            return
        existing_schema = self.client.schema.get()
        existing_classes = {c["class"] for c in existing_schema.get("classes") or []}
        if class_name in existing_classes:
            logger.debug(f"Weaviate class {class_name} already exists.")
        else:
            self.client.schema.create_class(
                {
                    "class": class_name,
                    "vectorIndexType": "hnsw",
                    "vectorizer": "none",
                    "vectorIndexConfig": {
                        # NOTE(review): Weaviate documents the HNSW setting as
                        # `maxConnections`; confirm that `M` is honoured.
                        "M": int(os.getenv("WEAVIATE_HNSW_M", "16")),
                        "distance": "cosine" if l2_normalized else "l2-squared",
                    },
                    "properties": [
                        {"name": "img_name", "dataType": ["text"]},
                        {"name": "face", "dataType": ["blob"]},
                        {"name": "face_shape", "dataType": ["int[]"]},
                        {"name": "model_name", "dataType": ["text"]},
                        {"name": "detector_backend", "dataType": ["text"]},
                        {"name": "aligned", "dataType": ["boolean"]},
                        {"name": "l2_normalized", "dataType": ["boolean"]},
                        {"name": "face_hash", "dataType": ["text"]},
                        {"name": "embedding_hash", "dataType": ["text"]},
                        # embedding property is optional since we pass it as vector
                        {"name": "embedding", "dataType": ["number[]"]},
                    ],
                }
            )
            logger.debug(f"Weaviate class {class_name} created successfully.")
        # Remember the result whether the class was found or freshly created.
        _SCHEMA_CHECKED[cache_key] = True

    def insert_embeddings(self, embeddings: List[Dict[str, Any]], batch_size: int = 100) -> int:
        """
        Insert multiple embeddings into Weaviate using the batch API.

        The target class is derived from the first item. An embedding whose
        hash is already stored, or that appears twice in the same call, is
        skipped with a warning.

        Args:
            embeddings: Items with the keys ``img_name``, ``face``,
                ``embedding``, ``model_name``, ``detector_backend``,
                ``aligned`` and ``l2_normalized``. ``face`` must expose
                ``tolist()``, ``tobytes()`` and ``shape`` (presumably a numpy
                array - confirm against callers).
            batch_size: Number of objects per batch request.

        Returns:
            Number of embeddings actually queued for insertion; skipped
            duplicates are not counted.

        Raises:
            ValueError: If ``embeddings`` is empty.
        """
        if not embeddings:
            raise ValueError("No embeddings to insert.")
        first = embeddings[0]
        class_params = {
            "model_name": first["model_name"],
            "detector_backend": first["detector_backend"],
            "aligned": first["aligned"],
            "l2_normalized": first["l2_normalized"],
        }
        self.initialize_database(**class_params)
        class_name = self.__generate_class_name(**class_params)
        inserted = 0
        # Objects still queued in the batch cannot be found by a query, so
        # hashes handled during this call are tracked locally as well.
        seen_hashes = set()
        with self.client.batch as batcher:
            batcher.batch_size = batch_size
            batcher.timeout_retries = 3
            for e in embeddings:
                face_json = json.dumps(e["face"].tolist())
                face_hash = hashlib.sha256(face_json.encode()).hexdigest()
                embedding_bytes = struct.pack(f'{len(e["embedding"])}d', *e["embedding"])
                embedding_hash = hashlib.sha256(embedding_bytes).hexdigest()
                if embedding_hash in seen_hashes:
                    logger.warn(
                        f"Embedding with hash {embedding_hash} is duplicated in this request."
                    )
                    continue
                seen_hashes.add(embedding_hash)
                # Check if embedding already exists
                query = (
                    self.client.query.get(class_name, ["embedding_hash"])
                    .with_where(
                        {
                            "path": ["embedding_hash"],
                            "operator": "Equal",
                            "valueText": embedding_hash,
                        }
                    )
                    .with_limit(1)
                    .do()
                )
                existing = ((query.get("data") or {}).get("Get") or {}).get(class_name) or []
                if existing:
                    logger.warn(
                        f"Embedding with hash {embedding_hash} already exists in {class_name}."
                    )
                    continue
                uid = str(uuid.uuid4())
                properties = {
                    "img_name": e["img_name"],
                    "face": base64.b64encode(e["face"].tobytes()).decode("utf-8"),
                    "face_shape": list(e["face"].shape),
                    "model_name": e["model_name"],
                    "detector_backend": e["detector_backend"],
                    "aligned": e["aligned"],
                    "l2_normalized": e["l2_normalized"],
                    "embedding": e["embedding"],  # optional
                    "face_hash": face_hash,
                    "embedding_hash": embedding_hash,
                }
                batcher.add_data_object(properties, class_name, vector=e["embedding"], uuid=uid)
                inserted += 1
        return inserted

    def fetch_all_embeddings(
        self,
        model_name: str,
        detector_backend: str,
        aligned: bool,
        l2_normalized: bool,
        batch_size: int = 1000,
    ) -> List[Dict[str, Any]]:
        """
        Fetch every stored embedding for one embedding configuration.

        Results are read page by page with Weaviate's cursor (``after``) API.
        A single query without a limit only returns the server's default
        number of objects.

        Args:
            model_name: Face recognition model name.
            detector_backend: Face detector backend name.
            aligned: Whether faces were aligned.
            l2_normalized: Whether embeddings were L2-normalized.
            batch_size: Number of objects requested per page; values below 1
                are treated as 1.

        Returns:
            One dict per stored object with ``id``, ``img_name``,
            ``embedding`` and the four configuration values passed in.
        """
        class_name = self.__generate_class_name(
            model_name=model_name,
            detector_backend=detector_backend,
            aligned=aligned,
            l2_normalized=l2_normalized,
        )
        self.initialize_database(
            model_name=model_name,
            detector_backend=detector_backend,
            aligned=aligned,
            l2_normalized=l2_normalized,
        )
        page_size = max(1, batch_size)
        embeddings: List[Dict[str, Any]] = []
        cursor = None
        while True:
            query = (
                self.client.query.get(class_name, ["img_name", "embedding"])
                .with_additional(["id"])
                .with_limit(page_size)
            )
            if cursor is not None:
                query = query.with_after(cursor)
            results = query.do()
            page = ((results.get("data") or {}).get("Get") or {}).get(class_name) or []
            if not page:
                break
            for r in page:
                embeddings.append(
                    {
                        "id": (r.get("_additional") or {}).get("id"),
                        "img_name": r["img_name"],
                        "embedding": r["embedding"],
                        "model_name": model_name,
                        "detector_backend": detector_backend,
                        "aligned": aligned,
                        "l2_normalized": l2_normalized,
                    }
                )
            # The id of the last object is the cursor for the next page.
            cursor = (page[-1].get("_additional") or {}).get("id")
            if cursor is None:
                break
        return embeddings

    def search_by_vector(
        self,
        vector: List[float],
        model_name: str = "VGG-Face",
        detector_backend: str = "opencv",
        aligned: bool = True,
        l2_normalized: bool = False,
        limit: int = 10,
    ) -> List[Dict[str, Any]]:
        """
        ANN search using the main vector (embedding).

        Args:
            vector: Query embedding.
            model_name: Face recognition model name.
            detector_backend: Face detector backend name.
            aligned: Whether faces were aligned.
            l2_normalized: Whether embeddings were L2-normalized.
            limit: Maximum number of matches to return.

        Returns:
            One dict per match with ``id``, ``img_name``, ``embedding`` and
            ``distance``. For normalized embeddings the distance is the value
            reported by Weaviate; otherwise it is the square root of the
            reported l2-squared value. ``distance`` is ``None`` when Weaviate
            reports none.
        """
        class_name = self.__generate_class_name(
            model_name=model_name,
            detector_backend=detector_backend,
            aligned=aligned,
            l2_normalized=l2_normalized,
        )
        self.initialize_database(
            model_name=model_name,
            detector_backend=detector_backend,
            aligned=aligned,
            l2_normalized=l2_normalized,
        )
        query = self.client.query.get(class_name, ["img_name", "embedding"])
        query = (
            query.with_near_vector({"vector": vector})
            .with_limit(limit)
            .with_additional(["id", "distance"])
        )
        results = query.do()
        data = ((results.get("data") or {}).get("Get") or {}).get(class_name) or []
        matches = []
        for r in data:
            additional = r.get("_additional") or {}
            distance = additional.get("distance")
            if distance is not None and not l2_normalized:
                # Clamp at zero so rounding noise cannot push sqrt out of its domain.
                distance = math.sqrt(max(distance, 0.0))
            matches.append(
                {
                    "id": additional.get("id"),
                    "img_name": r["img_name"],
                    "embedding": r["embedding"],
                    "distance": distance,
                }
            )
        return matches

    def close(self) -> None:
        """
        Close the Weaviate client connection.

        Does nothing when the underlying client object has no ``close``
        method, so a client without one does not raise here.
        """
        close_fn = getattr(self.client, "close", None)
        if callable(close_fn):
            close_fn()

    @staticmethod
    def __generate_class_name(
        model_name: str,
        detector_backend: str,
        aligned: bool,
        l2_normalized: bool,
    ) -> str:
        """
        Generate Weaviate class name based on parameters.

        Hyphens are removed from the model name, the four parts are joined
        with underscores and lower-cased, and the result is prefixed with
        ``Embeddings_``; for example ``Embeddings_vggface_opencv_aligned_raw``.
        """
        class_name_attributes = [
            model_name.replace("-", ""),
            detector_backend,
            "Aligned" if aligned else "Unaligned",
            "Norm" if l2_normalized else "Raw",
        ]
        return "Embeddings_" + "_".join(class_name_attributes).lower()