Spaces:
Running
Running
| # built-in dependencies | |
| import os | |
| import json | |
| import hashlib | |
| import struct | |
| import base64 | |
| import uuid | |
| import math | |
| from typing import Any, Dict, Optional, List, Union | |
| # project dependencies | |
| from deepface.modules.database.types import Database | |
| from deepface.commons.logger import Logger | |
# Module-level logger shared by every WeaviateClient instance.
logger = Logger()

# Process-wide memo of Weaviate schema classes already verified or created,
# mapping a cache key to True so the class is not re-checked/re-created.
_SCHEMA_CHECKED: Dict[str, bool] = dict()

# pylint: disable=too-many-positional-arguments
class WeaviateClient(Database):
    """
    Weaviate client for storing and retrieving face embeddings and indices.

    Written against the weaviate-client v3 interface (``weaviate.Client``,
    ``weaviate.AuthApiKey``, ``client.schema``, ``client.query`` and
    ``client.batch``). Each combination of model_name, detector_backend,
    aligned and l2_normalized is stored in its own Weaviate class.
    """

    def __init__(
        self,
        connection_details: Optional[Union[str, Dict[str, Any]]] = None,
        connection: Any = None,
    ):
        """
        Wrap an existing Weaviate client or build a new one.

        Args:
            connection_details: Weaviate URL string, or a dict with a "url" key
                and an optional "api_key" key. When omitted, the
                DEEPFACE_WEAVIATE_URL environment variable is used. The API key
                falls back to the WEAVIATE_API_KEY environment variable.
            connection: an already-constructed Weaviate client. When given,
                connection_details is ignored.

        Raises:
            ValueError: if weaviate-client is not installed, or if no usable
                URL is available.
        """
        try:
            import weaviate
        except (ModuleNotFoundError, ImportError) as e:
            raise ValueError(
                "weaviate-client is an optional dependency. "
                "Install with 'pip install weaviate-client'"
            ) from e

        self.weaviate = weaviate

        if connection is not None:
            self.client = connection
            # Used in the _SCHEMA_CHECKED cache key; fall back to the object id
            # when the supplied client exposes no ``url`` attribute.
            self.url = getattr(connection, "url", str(id(connection)))
            return

        self.conn_details = connection_details or os.environ.get("DEEPFACE_WEAVIATE_URL")
        if isinstance(self.conn_details, str):
            self.url = self.conn_details
            self.api_key = os.getenv("WEAVIATE_API_KEY")
        elif isinstance(self.conn_details, dict):
            self.url = self.conn_details.get("url")
            self.api_key = self.conn_details.get("api_key") or os.getenv("WEAVIATE_API_KEY")
        else:
            raise ValueError("connection_details must be a string or dict with 'url'.")

        if not self.url:
            raise ValueError("Weaviate URL not provided in connection_details.")

        client_config: Dict[str, Any] = {"url": self.url}
        if self.api_key:
            client_config["auth_client_secret"] = self.weaviate.AuthApiKey(api_key=self.api_key)
        self.client = self.weaviate.Client(**client_config)

    def initialize_database(self, **kwargs: Any) -> None:
        """
        Ensure the Weaviate class for one embedding configuration exists.

        A single class is created per configuration. Its HNSW index uses
        "cosine" distance when l2_normalized is true, and "l2-squared"
        otherwise. The HNSW ``M`` parameter comes from the WEAVIATE_HNSW_M
        environment variable (default 16).

        Keyword Args:
            model_name (str): defaults to "VGG-Face".
            detector_backend (str): defaults to "opencv".
            aligned (bool): defaults to True.
            l2_normalized (bool): defaults to False.
        """
        model_name = kwargs.get("model_name", "VGG-Face")
        detector_backend = kwargs.get("detector_backend", "opencv")
        aligned = kwargs.get("aligned", True)
        l2_normalized = kwargs.get("l2_normalized", False)

        class_name = self.__generate_class_name(
            model_name=model_name,
            detector_backend=detector_backend,
            aligned=aligned,
            l2_normalized=l2_normalized,
        )

        # Key by server as well as class, so a class verified on one Weaviate
        # instance is not assumed to exist on another.
        cache_key = f"{getattr(self, 'url', '')}::{class_name}"

        # Consult the cache before any network round trip.
        if _SCHEMA_CHECKED.get(cache_key):
            logger.debug("Weaviate schema already checked, skipping.")
            return

        existing_schema = self.client.schema.get()
        existing_classes = {c["class"] for c in (existing_schema.get("classes") or [])}
        if class_name in existing_classes:
            logger.debug(f"Weaviate class {class_name} already exists.")
            _SCHEMA_CHECKED[cache_key] = True
            return

        self.client.schema.create_class(
            {
                "class": class_name,
                "vectorIndexType": "hnsw",
                "vectorizer": "none",
                "vectorIndexConfig": {
                    "M": int(os.getenv("WEAVIATE_HNSW_M", "16")),
                    "distance": "cosine" if l2_normalized else "l2-squared",
                },
                "properties": [
                    {"name": "img_name", "dataType": ["text"]},
                    {"name": "face", "dataType": ["blob"]},
                    {"name": "face_shape", "dataType": ["int[]"]},
                    {"name": "model_name", "dataType": ["text"]},
                    {"name": "detector_backend", "dataType": ["text"]},
                    {"name": "aligned", "dataType": ["boolean"]},
                    {"name": "l2_normalized", "dataType": ["boolean"]},
                    {"name": "face_hash", "dataType": ["text"]},
                    {"name": "embedding_hash", "dataType": ["text"]},
                    # embedding property is optional since we pass it as vector
                    {"name": "embedding", "dataType": ["number[]"]},
                ],
            }
        )
        logger.debug(f"Weaviate class {class_name} created successfully.")
        _SCHEMA_CHECKED[cache_key] = True

    def insert_embeddings(self, embeddings: List[Dict[str, Any]], batch_size: int = 100) -> int:
        """
        Insert multiple embeddings into Weaviate using the batch API.

        All items are written to the class derived from the FIRST item's
        model_name / detector_backend / aligned / l2_normalized. An item is
        skipped with a warning in two cases:
            - its embedding hash already exists in that class;
            - it repeats an earlier item in this same call.

        Args:
            embeddings: dicts with keys img_name, face, embedding, model_name,
                detector_backend, aligned and l2_normalized. ``face`` must
                expose ``tolist()``, ``tobytes()`` and ``shape``; presumably a
                numpy array. ``embedding`` is an iterable of numbers.
            batch_size: number of objects per batch request.

        Returns:
            Number of objects queued for insertion. Skipped duplicates are
            not counted.

        Raises:
            ValueError: if ``embeddings`` is empty.
            RuntimeError: if the duplicate-check query returns GraphQL errors.
        """
        if not embeddings:
            raise ValueError("No embeddings to insert.")

        first = embeddings[0]
        config = {
            "model_name": first["model_name"],
            "detector_backend": first["detector_backend"],
            "aligned": first["aligned"],
            "l2_normalized": first["l2_normalized"],
        }
        self.initialize_database(**config)
        class_name = self.__generate_class_name(**config)

        inserted = 0
        # Hashes queued in this call; the batch is not flushed yet, so the
        # server-side existence check cannot see them.
        queued_hashes = set()

        with self.client.batch as batcher:
            batcher.batch_size = batch_size
            batcher.timeout_retries = 3
            for e in embeddings:
                # Normalize to plain floats once. This yields a JSON-serializable
                # property, and struct.pack("d") produces the same bytes as
                # packing the original values, so hashes are unchanged.
                vector = [float(v) for v in e["embedding"]]

                face_json = json.dumps(e["face"].tolist())
                face_hash = hashlib.sha256(face_json.encode()).hexdigest()
                embedding_bytes = struct.pack(f"{len(vector)}d", *vector)
                embedding_hash = hashlib.sha256(embedding_bytes).hexdigest()

                if embedding_hash in queued_hashes:
                    logger.warn(
                        f"Duplicate embedding with hash {embedding_hash} in this insert call; "
                        "skipping."
                    )
                    continue
                if self.__embedding_exists(class_name, embedding_hash):
                    logger.warn(
                        f"Embedding with hash {embedding_hash} already exists in {class_name}."
                    )
                    continue

                properties = {
                    "img_name": e["img_name"],
                    "face": base64.b64encode(e["face"].tobytes()).decode("utf-8"),
                    "face_shape": list(e["face"].shape),
                    "model_name": e["model_name"],
                    "detector_backend": e["detector_backend"],
                    "aligned": e["aligned"],
                    "l2_normalized": e["l2_normalized"],
                    "embedding": vector,  # optional
                    "face_hash": face_hash,
                    "embedding_hash": embedding_hash,
                }
                batcher.add_data_object(
                    properties, class_name, vector=vector, uuid=str(uuid.uuid4())
                )
                queued_hashes.add(embedding_hash)
                inserted += 1

        return inserted

    def fetch_all_embeddings(
        self,
        model_name: str,
        detector_backend: str,
        aligned: bool,
        l2_normalized: bool,
        batch_size: int = 1000,
    ) -> List[Dict[str, Any]]:
        """
        Fetch every embedding stored for one configuration.

        The class is read page by page with Weaviate's cursor API
        (``with_after``), ``batch_size`` objects per request. A single
        unbounded ``Get`` would be truncated by the server's default query
        limit. NOTE(review): the cursor API needs a reasonably recent
        Weaviate server (believed >= 1.18) — confirm against deployments.

        Returns:
            A list of dicts with keys id, img_name, embedding, model_name,
            detector_backend, aligned and l2_normalized.

        Raises:
            RuntimeError: if a page query returns GraphQL errors.
        """
        class_name = self.__generate_class_name(
            model_name=model_name,
            detector_backend=detector_backend,
            aligned=aligned,
            l2_normalized=l2_normalized,
        )
        self.initialize_database(
            model_name=model_name,
            detector_backend=detector_backend,
            aligned=aligned,
            l2_normalized=l2_normalized,
        )

        page_size = max(1, int(batch_size))
        embeddings: List[Dict[str, Any]] = []
        after: Optional[str] = None

        while True:
            query = (
                self.client.query.get(class_name, ["img_name", "embedding"])
                .with_additional(["id"])
                .with_limit(page_size)
            )
            if after is not None:
                query = query.with_after(after)

            page = self.__extract_objects(query.do(), class_name)
            for r in page:
                embeddings.append(
                    {
                        "id": (r.get("_additional") or {}).get("id"),
                        "img_name": r["img_name"],
                        "embedding": r["embedding"],
                        "model_name": model_name,
                        "detector_backend": detector_backend,
                        "aligned": aligned,
                        "l2_normalized": l2_normalized,
                    }
                )

            # A short page means the cursor has reached the end.
            if len(page) < page_size:
                break
            after = (page[-1].get("_additional") or {}).get("id")
            if after is None:  # cannot advance the cursor; avoid looping forever
                break

        return embeddings

    def search_by_vector(
        self,
        vector: List[float],
        model_name: str = "VGG-Face",
        detector_backend: str = "opencv",
        aligned: bool = True,
        l2_normalized: bool = False,
        limit: int = 10,
    ) -> List[Dict[str, Any]]:
        """
        Run an approximate nearest-neighbour search on the stored vectors.

        For l2_normalized classes the distance is returned as Weaviate reports
        it; the metric is "cosine". Otherwise the index metric is "l2-squared",
        and its square root is returned as the euclidean distance. A missing
        distance is returned as None.

        Returns:
            Up to ``limit`` dicts with keys id, img_name, embedding and distance.

        Raises:
            RuntimeError: if the query returns GraphQL errors.
        """
        class_name = self.__generate_class_name(
            model_name=model_name,
            detector_backend=detector_backend,
            aligned=aligned,
            l2_normalized=l2_normalized,
        )
        self.initialize_database(
            model_name=model_name,
            detector_backend=detector_backend,
            aligned=aligned,
            l2_normalized=l2_normalized,
        )

        query = (
            self.client.query.get(class_name, ["img_name", "embedding"])
            .with_near_vector({"vector": vector})
            .with_limit(limit)
            .with_additional(["id", "distance"])
        )
        hits = self.__extract_objects(query.do(), class_name)

        matches: List[Dict[str, Any]] = []
        for r in hits:
            additional = r.get("_additional") or {}
            distance = additional.get("distance")
            if distance is not None and not l2_normalized:
                # l2-squared -> euclidean; clamp tiny negative float noise.
                distance = math.sqrt(max(float(distance), 0.0))
            matches.append(
                {
                    "id": additional.get("id"),
                    "img_name": r["img_name"],
                    "embedding": r["embedding"],
                    "distance": distance,
                }
            )
        return matches

    def close(self) -> None:
        """
        Close the Weaviate client connection, if the client supports it.

        NOTE(review): the v3 ``weaviate.Client`` may not define ``close()``;
        in that case this is a no-op rather than an AttributeError.
        """
        closer = getattr(self.client, "close", None)
        if callable(closer):
            closer()

    def __embedding_exists(self, class_name: str, embedding_hash: str) -> bool:
        """Return True if ``class_name`` already holds an object with this embedding hash."""
        response = (
            self.client.query.get(class_name, ["embedding_hash"])
            .with_where(
                {
                    "path": ["embedding_hash"],
                    "operator": "Equal",
                    "valueText": embedding_hash,
                }
            )
            .with_limit(1)
            .do()
        )
        return bool(self.__extract_objects(response, class_name))

    @staticmethod
    def __extract_objects(results: Dict[str, Any], class_name: str) -> List[Dict[str, Any]]:
        """
        Return the objects of ``class_name`` from a GraphQL ``Get`` response.

        Missing or null "data"/"Get"/class entries are treated as empty.

        Raises:
            RuntimeError: if the response carries GraphQL errors.
        """
        errors = results.get("errors")
        if errors:
            raise RuntimeError(f"Weaviate query on {class_name} failed: {errors}")
        return ((results.get("data") or {}).get("Get") or {}).get(class_name) or []

    @staticmethod
    def __generate_class_name(
        model_name: str,
        detector_backend: str,
        aligned: bool,
        l2_normalized: bool,
    ) -> str:
        """
        Generate the Weaviate class name for one embedding configuration.

        Dashes are removed from the model name, the parts are joined with "_"
        and lowercased, and the result is prefixed with "Embeddings_". For
        example, ("VGG-Face", "opencv", True, False) gives
        "Embeddings_vggface_opencv_aligned_raw".

        Must be a staticmethod: callers invoke it as
        ``self.__generate_class_name(model_name=...)``. Without the decorator,
        ``self`` would bind to ``model_name`` and raise TypeError.
        """
        class_name_attributes = [
            model_name.replace("-", ""),
            detector_backend,
            "Aligned" if aligned else "Unaligned",
            "Norm" if l2_normalized else "Raw",
        ]
        return "Embeddings_" + "_".join(class_name_attributes).lower()