DaoManhDuc2004
Deploy DATN face AI server
b5d3a91
# built-in dependencies
import os
import json
import hashlib
import struct
from typing import Any, Dict, Optional, List, Union
from urllib.parse import urlparse
# project dependencies
from deepface.modules.database.types import Database
from deepface.modules.modeling import build_model
from deepface.modules.verification import find_cosine_distance, find_euclidean_distance
from deepface.commons.logger import Logger
# Module-level logger used by Neo4jClient for debug and error messages.
logger = Logger()
# Process-wide cache of node labels whose vector index and uniqueness constraint
# have already been ensured by Neo4jClient.initialize_database. It is keyed by
# node label only, so it does not distinguish between different connections.
_SCHEMA_CHECKED: Dict[str, bool] = {}
# pylint: disable=too-many-positional-arguments
class Neo4jClient(Database):
    """
    Neo4j-backed implementation of the Database interface for face embeddings.

    Embeddings are stored as nodes whose label encodes the model name, detector
    backend, alignment flag and normalization flag. Each label gets its own
    vector index on ``embedding`` and a uniqueness constraint on
    ``(face_hash, embedding_hash)``.
    """

    def __init__(
        self,
        connection_details: Optional[Union[Dict[str, Any], str]] = None,
        connection: Any = None,
    ) -> None:
        """
        Create a client from an existing driver or from a connection URI.

        Args:
            connection_details: connection URI such as
                ``bolt://user:password@host:7687``. Falls back to the
                DEEPFACE_NEO4J_URI environment variable when omitted. Only
                string values are accepted.
            connection: an already constructed Neo4j driver; when given,
                connection_details is ignored.

        Raises:
            ValueError: if the neo4j package is missing, no usable connection
                information is available, or the GDS plugin check fails.
        """
        # Import here to avoid mandatory dependency
        try:
            from neo4j import GraphDatabase
        except (ModuleNotFoundError, ImportError) as e:
            raise ValueError(
                "neo4j is an optional dependency, ensure the library is installed. "
                "Please install using 'pip install neo4j' "
            ) from e

        self.GraphDatabase = GraphDatabase

        # Track ownership so a driver created here can be released on failure,
        # while a caller-supplied driver is never closed behind their back.
        owns_connection = connection is None

        if connection is not None:
            self.conn = connection
        else:
            self.conn_details = connection_details or os.environ.get("DEEPFACE_NEO4J_URI")
            if not self.conn_details:
                raise ValueError(
                    "Neo4j connection information not found. "
                    "Please provide connection_details or set the DEEPFACE_NEO4J_URI"
                    " environment variable."
                )
            if not isinstance(self.conn_details, str):
                raise ValueError("connection_details must be a string.")

            parsed = urlparse(self.conn_details)
            if not parsed.scheme or not parsed.hostname:
                raise ValueError(
                    "Neo4j connection URI must include a scheme and a host, "
                    "e.g. 'bolt://user:password@localhost:7687'."
                )

            # urlparse strips the brackets from IPv6 literals; restore them.
            host = f"[{parsed.hostname}]" if ":" in parsed.hostname else parsed.hostname
            uri = f"{parsed.scheme}://{host}"
            # Only append the port when the URI actually carries one, otherwise
            # the driver would receive the literal text "None" as the port.
            if parsed.port is not None:
                uri = f"{uri}:{parsed.port}"

            self.conn = self.GraphDatabase.driver(uri, auth=(parsed.username, parsed.password))

        if not self.__is_gds_installed():
            if owns_connection:
                # Do not leak the driver we just opened.
                self.conn.close()
            raise ValueError(
                "Neo4j Graph Data Science (GDS) plugin is not installed. "
                "Please install the GDS plugin to use Neo4j as a database backend."
            )

    def close(self) -> None:
        """
        Close the Neo4j database connection.
        """
        if self.conn:
            self.conn.close()
            logger.debug("Neo4j connection closed.")

    def initialize_database(self, **kwargs: Any) -> None:
        """
        Ensure Neo4j database has the necessary constraints and indexes for storing embeddings.

        Keyword Args:
            model_name (str): facial recognition model, default "VGG-Face".
            detector_backend (str): face detector, default "opencv".
            aligned (bool): whether faces were aligned, default True.
            l2_normalized (bool): whether embeddings are L2-normalized,
                default False. Selects the index similarity function
                ("cosine" when True, "euclidean" otherwise).
        """
        model_name = kwargs.get("model_name", "VGG-Face")
        detector_backend = kwargs.get("detector_backend", "opencv")
        aligned = kwargs.get("aligned", True)
        l2_normalized = kwargs.get("l2_normalized", False)

        node_label = self.__generate_node_label(
            model_name=model_name,
            detector_backend=detector_backend,
            aligned=aligned,
            l2_normalized=l2_normalized,
        )

        # Consult the cache first: the model is only needed to learn the
        # embedding dimension when the index actually has to be created.
        if _SCHEMA_CHECKED.get(node_label):
            logger.debug(f"Neo4j index {node_label} already exists, skipping creation.")
            return

        model = build_model(task="facial_recognition", model_name=model_name)
        dimensions = model.output_shape
        similarity_function = "cosine" if l2_normalized else "euclidean"

        index_query = f"""
            CREATE VECTOR INDEX {node_label}_embedding_idx IF NOT EXISTS
            FOR (d:{node_label})
            ON (d.embedding)
            OPTIONS {{
                indexConfig: {{
                    `vector.dimensions`: {dimensions},
                    `vector.similarity_function`: '{similarity_function}'
                }}
            }};
        """

        uniq_query = f"""
            CREATE CONSTRAINT {node_label}_unique IF NOT EXISTS
            FOR (n:{node_label})
            REQUIRE (n.face_hash, n.embedding_hash) IS UNIQUE;
        """

        with self.conn.session() as session:
            session.execute_write(lambda tx: tx.run(index_query))
            session.execute_write(lambda tx: tx.run(uniq_query))

        # Only mark the label as ready after both statements succeeded.
        _SCHEMA_CHECKED[node_label] = True
        logger.debug(f"Neo4j index {node_label} ensured.")

    def insert_embeddings(self, embeddings: List[Dict[str, Any]], batch_size: int = 100) -> int:
        """
        Insert embeddings into Neo4j database in batches.

        The node label (and therefore the index) is derived from the first
        embedding only; every item is written under that label.

        Args:
            embeddings: items with the keys "face", "embedding", "img_name",
                "model_name", "detector_backend", "aligned" and
                "l2_normalized". "face" must expose ``tolist()``.
            batch_size: number of rows sent per write transaction.

        Returns:
            Total number of rows processed by the MERGE statements. Rows that
            matched an existing node are counted as well.

        Raises:
            ValueError: if embeddings is empty or batch_size is not positive.
        """
        if not embeddings:
            raise ValueError("No embeddings to insert.")
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer.")

        first = embeddings[0]
        label_params = {
            "model_name": first["model_name"],
            "detector_backend": first["detector_backend"],
            "aligned": first["aligned"],
            "l2_normalized": first["l2_normalized"],
        }
        self.initialize_database(**label_params)
        node_label = self.__generate_node_label(**label_params)

        # MERGE on the two hashes makes re-inserting the same face a no-op.
        # Rows carry no "face" key, so r.face evaluates to null in Cypher and
        # no face property is written.
        query = f"""
            UNWIND $rows AS r
            MERGE (n:{node_label} {{face_hash: r.face_hash, embedding_hash: r.embedding_hash}})
            ON CREATE SET
                n.img_name = r.img_name,
                n.embedding = r.embedding,
                n.face = r.face,
                n.model_name = r.model_name,
                n.detector_backend = r.detector_backend,
                n.aligned = r.aligned,
                n.l2_normalized = r.l2_normalized
            RETURN count(*) AS processed
        """

        total = 0
        with self.conn.session() as session:
            for start in range(0, len(embeddings), batch_size):
                rows = [self.__build_row(e) for e in embeddings[start : start + batch_size]]
                # Bind query/rows as defaults so the transaction function does
                # not depend on loop variables if the driver retries it.
                processed = session.execute_write(
                    lambda tx, q=query, r=rows: int(tx.run(q, rows=r).single()["processed"])
                )
                total += processed

        return total

    def fetch_all_embeddings(
        self,
        model_name: str,
        detector_backend: str,
        aligned: bool,
        l2_normalized: bool,
        batch_size: int = 1000,
    ) -> List[Dict[str, Any]]:
        """
        Fetch all embeddings from Neo4j database in batches.

        Pages through the nodes of the matching label ordered by elementId,
        using the last seen elementId as the cursor for the next page. The
        complete result is accumulated in memory before being returned.

        Args:
            model_name: facial recognition model the embeddings belong to.
            detector_backend: face detector the embeddings belong to.
            aligned: whether the faces were aligned.
            l2_normalized: whether the embeddings are L2-normalized.
            batch_size: maximum number of nodes requested per query.

        Returns:
            One dict per node with "id", "img_name", "embedding" and the four
            configuration values passed in.
        """
        node_label = self.__generate_node_label(
            model_name=model_name,
            detector_backend=detector_backend,
            aligned=aligned,
            l2_normalized=l2_normalized,
        )

        query = f"""
            MATCH (n:{node_label})
            WHERE n.embedding IS NOT NULL
            AND ($last_eid IS NULL OR elementId(n) > $last_eid)
            RETURN
                elementId(n) AS cursor,
                coalesce(n.id, elementId(n)) AS id,
                n.img_name AS img_name,
                n.embedding AS embedding
            ORDER BY cursor ASC
            LIMIT $limit
        """

        out: List[Dict[str, Any]] = []
        last_eid: Optional[str] = None

        with self.conn.session() as session:
            while True:
                rows = list(session.run(query, last_eid=last_eid, limit=batch_size))
                if not rows:
                    break

                out.extend(
                    {
                        "id": r["id"],
                        "img_name": r["img_name"],
                        "embedding": r["embedding"],
                        "model_name": model_name,
                        "detector_backend": detector_backend,
                        "aligned": aligned,
                        "l2_normalized": l2_normalized,
                    }
                    for r in rows
                )

                # advance cursor using elementId
                last_eid = rows[-1]["cursor"]

        return out

    def search_by_vector(
        self,
        vector: List[float],
        model_name: str = "VGG-Face",
        detector_backend: str = "opencv",
        aligned: bool = True,
        l2_normalized: bool = False,
        limit: int = 10,
    ) -> List[Dict[str, Any]]:
        """
        ANN search using the main vector (embedding).

        Args:
            vector: query embedding.
            model_name: facial recognition model, selects the index.
            detector_backend: face detector, selects the index.
            aligned: whether the faces were aligned, selects the index.
            l2_normalized: whether embeddings are L2-normalized; selects the
                index and the distance metric (cosine when True, euclidean
                otherwise).
            limit: number of nearest neighbours requested from the index.

        Returns:
            One dict per hit, in the order returned by the query (index score
            descending), with "id", "img_name", "face_hash", "embedding_hash",
            the four configuration values and the computed "distance".
        """
        self.initialize_database(
            model_name=model_name,
            detector_backend=detector_backend,
            aligned=aligned,
            l2_normalized=l2_normalized,
        )

        node_label = self.__generate_node_label(
            model_name=model_name,
            detector_backend=detector_backend,
            aligned=aligned,
            l2_normalized=l2_normalized,
        )
        index_name = f"{node_label}_embedding_idx"

        query = """
            CALL db.index.vector.queryNodes($index_name, $limit, $vector)
            YIELD node, score
            RETURN
                elementId(node) AS id,
                node.img_name AS img_name,
                node.face_hash AS face_hash,
                node.embedding AS embedding,
                node.embedding_hash AS embedding_hash,
                score AS score
            ORDER BY score DESC
        """

        # The distance is recomputed from the stored embedding instead of being
        # derived from the index score; pick the metric once, outside the loop.
        distance_fn = find_cosine_distance if l2_normalized else find_euclidean_distance

        with self.conn.session() as session:
            result = session.run(
                query,
                index_name=index_name,
                limit=limit,
                vector=vector,
            )

            # Records are consumed while the session is still open.
            out: List[Dict[str, Any]] = [
                {
                    "id": r.get("id"),
                    "img_name": r.get("img_name"),
                    "face_hash": r.get("face_hash"),
                    "embedding_hash": r.get("embedding_hash"),
                    "model_name": model_name,
                    "detector_backend": detector_backend,
                    "aligned": aligned,
                    "l2_normalized": l2_normalized,
                    "distance": distance_fn(vector, r.get("embedding")),
                }
                for r in result
            ]

        return out

    def __is_gds_installed(self) -> bool:
        """
        Check if the Graph Data Science (GDS) plugin is installed in the Neo4j database.

        Returns:
            True when ``RETURN gds.version()`` succeeds; False when the query
            fails for any reason (the error is logged, not raised).
        """
        query = "RETURN gds.version() AS version"
        try:
            with self.conn.session() as session:
                result = session.run(query).single()
                logger.debug(f"GDS version: {result['version']}")
                return True
        except Exception as e:  # pylint: disable=broad-except
            # Best-effort probe: any failure is reported as "not installed".
            logger.error(f"GDS plugin not installed or error occurred: {e}")
            return False

    @staticmethod
    def __build_row(e: Dict[str, Any]) -> Dict[str, Any]:
        """
        Convert one embedding item into the parameter map used by the insert query.

        face_hash is the SHA-256 of the JSON-serialized face list, and
        embedding_hash is the SHA-256 of the embedding packed as doubles.
        """
        face_json = json.dumps(e["face"].tolist())
        face_hash = hashlib.sha256(face_json.encode()).hexdigest()

        embedding_bytes = struct.pack(f'{len(e["embedding"])}d', *e["embedding"])
        embedding_hash = hashlib.sha256(embedding_bytes).hexdigest()

        return {
            "face_hash": face_hash,
            "embedding_hash": embedding_hash,
            "img_name": e["img_name"],
            "embedding": e["embedding"],
            "model_name": e.get("model_name"),
            "detector_backend": e.get("detector_backend"),
            "aligned": bool(e.get("aligned", True)),
            "l2_normalized": bool(e.get("l2_normalized", False)),
        }

    @staticmethod
    def __generate_node_label(
        model_name: str,
        detector_backend: str,
        aligned: bool,
        l2_normalized: bool,
    ) -> str:
        """
        Generate a Neo4j node label based on model and preprocessing parameters.

        Example: ("VGG-Face", "opencv", True, False) -> "Vgg_faceOpencvAlignedRaw".
        """
        label_parts = [
            model_name.replace("-", "_").capitalize(),
            detector_backend.capitalize(),
            "Aligned" if aligned else "Unaligned",
            "Norm" if l2_normalized else "Raw",
        ]
        return "".join(label_parts)