# metalmind / kg_sys / post_processing.py
# Uploaded by IELTS8 via huggingface_hub ("Upload folder using huggingface_hub"), revision ada3f28.
from neo4j import GraphDatabase
import logging
import time
from langchain_community.graphs import Neo4jGraph
import os
from src.shared.common_fn import load_embedding_model
# Cypher run by create_fulltext when type == "entities": drop the "entities"
# full-text index if it is present.
DROP_INDEX_QUERY = "DROP INDEX entities IF EXISTS;"
# Lists every node label; create_fulltext reads the "label" column of each record.
LABELS_QUERY = "CALL db.labels()"
# Template: create_fulltext fills {labels_str} with a ":`A`|`B`" label
# disjunction; the index covers each matching node's id and description.
FULL_TEXT_QUERY = "CREATE FULLTEXT INDEX entities FOR (n{labels_str}) ON EACH [n.id, n.description];"
# Labels left out of the "entities" index (the same two labels
# fetch_entities_for_embedding excludes from entity nodes).
FILTER_LABELS = ["Chunk", "Document"]
# Cypher run by create_fulltext for any other type: drop and recreate the
# "keyword" full-text index over the text property of Chunk nodes.
HYBRID_SEARCH_INDEX_DROP_QUERY = "DROP INDEX keyword IF EXISTS;"
HYBRID_SEARCH_FULL_TEXT_QUERY = "CREATE FULLTEXT INDEX keyword FOR (n:Chunk) ON EACH [n.text]"
def create_fulltext(uri, username, password, database, type):
start_time = time.time()
logging.info("Starting the process of creating a full-text index.")
try:
driver = GraphDatabase.driver(uri, auth=(username, password), database=database)
driver.verify_connectivity()
logging.info("Database connectivity verified.")
except Exception as e:
logging.error(f"Failed to create a database driver or verify connectivity: {e}")
return
try:
with driver.session() as session:
try:
start_step = time.time()
if type == "entities":
drop_query = DROP_INDEX_QUERY
else:
drop_query = HYBRID_SEARCH_INDEX_DROP_QUERY
session.run(drop_query)
logging.info(f"Dropped existing index (if any) in {time.time() - start_step:.2f} seconds.")
except Exception as e:
logging.error(f"Failed to drop index: {e}")
return
try:
if type == "entities":
start_step = time.time()
result = session.run(LABELS_QUERY)
labels = [record["label"] for record in result]
for label in FILTER_LABELS:
if label in labels:
labels.remove(label)
labels_str = ":" + "|".join([f"`{label}`" for label in labels])
logging.info(f"Fetched labels in {time.time() - start_step:.2f} seconds.")
except Exception as e:
logging.error(f"Failed to fetch labels: {e}")
return
try:
start_step = time.time()
if type == "entities":
fulltext_query = FULL_TEXT_QUERY.format(labels_str=labels_str)
else:
fulltext_query = HYBRID_SEARCH_FULL_TEXT_QUERY
session.run(fulltext_query)
logging.info(f"Created full-text index in {time.time() - start_step:.2f} seconds.")
except Exception as e:
logging.error(f"Failed to create full-text index: {e}")
return
except Exception as e:
logging.error(f"An error occurred during the session: {e}")
finally:
driver.close()
logging.info("Driver closed.")
logging.info(f"Process completed in {time.time() - start_time:.2f} seconds.")
def create_entity_embedding(graph: "Neo4jGraph", *, batch_size: int = 1000):
    """Compute and store embeddings for every entity node that lacks one.

    Rows come from ``fetch_entities_for_embedding``. They are written back
    through ``update_embeddings`` in slices of at most ``batch_size`` rows, so
    each write query carries a bounded parameter list.

    Args:
        graph: the graph to read entities from and write embeddings to.
        batch_size: maximum number of rows per ``update_embeddings`` call.
            Defaults to 1000, the previously hard-coded value; must be >= 1.

    Raises:
        ValueError: if ``batch_size`` is less than 1. This is checked before
            any database access.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    rows = fetch_entities_for_embedding(graph)
    for start in range(0, len(rows), batch_size):
        update_embeddings(rows[start:start + batch_size], graph)
def fetch_entities_for_embedding(graph):
    """Return the entity nodes that do not have an embedding yet.

    A node qualifies when it is neither a ``Chunk`` nor a ``Document``, has a
    non-null ``id`` and has no ``embedding`` property. The text to embed is the
    node's ``id``, a space, then its ``description`` (empty when missing).

    Args:
        graph: object exposing ``query(cypher)`` whose result rows support
            ``row["elementId"]`` and ``row["text"]``.

    Returns:
        A list of ``{"elementId": ..., "text": ...}`` dicts, in result order.
    """
    cypher = """
MATCH (e)
WHERE NOT (e:Chunk OR e:Document) AND e.embedding IS NULL AND e.id IS NOT NULL
RETURN elementId(e) AS elementId, e.id + " " + coalesce(e.description, "") AS text
"""
    wanted_keys = ("elementId", "text")
    return [{key: row[key] for key in wanted_keys} for row in graph.query(cypher)]
def update_embeddings(rows, graph):
    """Embed each row's text and store the vector on the matching node.

    Each row's ``"embedding"`` key is set in place to
    ``embed_query(row["text"])``. The whole batch is then written with a single
    ``UNWIND`` query that calls ``db.create.setNodeVectorProperty`` on the node
    whose ``elementId`` equals ``row["elementId"]``.

    The model name is read from the ``EMBEDDING_MODEL`` environment variable
    and loaded through ``load_embedding_model`` on every call.

    Args:
        rows: list of dicts with ``"elementId"`` and ``"text"`` keys. It is
            mutated in place (an ``"embedding"`` key is added to each).
        graph: object exposing ``query(cypher, params=...)``.

    Returns:
        ``[]`` without loading the model or touching the database when
        ``rows`` is empty. Otherwise, whatever ``graph.query`` returns for the
        write query.
    """
    if not rows:
        # Nothing to write: skip the model load and the database round-trip.
        return []
    # NOTE(review): create_entity_embedding calls this once per batch, so the
    # model is reloaded for every batch; consider loading it once if
    # load_embedding_model turns out to be costly.
    embedding_model = os.getenv('EMBEDDING_MODEL')
    embeddings, _dimension = load_embedding_model(embedding_model)
    logging.info("update embedding for entities")
    for row in rows:
        row['embedding'] = embeddings.embed_query(row['text'])
    query = """
UNWIND $rows AS row
MATCH (e) WHERE elementId(e) = row.elementId
CALL db.create.setNodeVectorProperty(e, "embedding", row.embedding)
"""
    return graph.query(query, params={'rows': rows})