from neo4j import GraphDatabase
import logging
import time
from langchain_community.graphs import Neo4jGraph
import os
from src.shared.common_fn import load_embedding_model

# Cypher used to (re)build the "entities" full-text index, which covers every
# label except the ones listed in FILTER_LABELS.
DROP_INDEX_QUERY = "DROP INDEX entities IF EXISTS;"
LABELS_QUERY = "CALL db.labels()"
FULL_TEXT_QUERY = "CREATE FULLTEXT INDEX entities FOR (n{labels_str}) ON EACH [n.id, n.description];"
FILTER_LABELS = ["Chunk", "Document"]

# Cypher used to (re)build the "keyword" full-text index over Chunk text.
HYBRID_SEARCH_INDEX_DROP_QUERY = "DROP INDEX keyword IF EXISTS;"
HYBRID_SEARCH_FULL_TEXT_QUERY = "CREATE FULLTEXT INDEX keyword FOR (n:Chunk) ON EACH [n.text]"

# Number of entity rows embedded and written back per database round trip.
EMBEDDING_BATCH_SIZE = 1000


def create_fulltext(uri, username, password, database, type):
    """Drop and recreate a full-text index in the given Neo4j database.

    When ``type`` is ``"entities"`` the ``entities`` index is rebuilt over the
    ``id`` and ``description`` properties of every label returned by
    ``db.labels()`` except those in ``FILTER_LABELS``. For any other value the
    ``keyword`` index over ``Chunk.text`` is rebuilt instead.

    Failures are logged and swallowed; the function always returns ``None``.

    Args:
        uri: Connection URI passed to ``GraphDatabase.driver``.
        username: User name for basic authentication.
        password: Password for basic authentication.
        database: Database name passed to the driver configuration.
        type: Index selector; ``"entities"`` or anything else for the
            ``keyword`` index. (Shadows the builtin, kept for caller
            compatibility.)
    """
    start_time = time.time()
    logging.info("Starting the process of creating a full-text index.")

    driver = None
    try:
        driver = GraphDatabase.driver(uri, auth=(username, password), database=database)
        driver.verify_connectivity()
        logging.info("Database connectivity verified.")
    except Exception as e:
        logging.error(f"Failed to create a database driver or verify connectivity: {e}")
        # The driver object may exist even though connectivity failed; release
        # it so its connection pool is not leaked.
        if driver is not None:
            try:
                driver.close()
            except Exception as close_error:
                logging.warning(f"Failed to close driver after connectivity error: {close_error}")
        return

    is_entities = type == "entities"
    try:
        with driver.session() as session:
            try:
                start_step = time.time()
                drop_query = DROP_INDEX_QUERY if is_entities else HYBRID_SEARCH_INDEX_DROP_QUERY
                session.run(drop_query)
                logging.info(f"Dropped existing index (if any) in {time.time() - start_step:.2f} seconds.")
            except Exception as e:
                logging.error(f"Failed to drop index: {e}")
                return

            labels_str = ""
            try:
                if is_entities:
                    start_step = time.time()
                    result = session.run(LABELS_QUERY)
                    labels = [
                        label
                        for label in (record["label"] for record in result)
                        if label not in FILTER_LABELS
                    ]
                    if not labels:
                        # An empty label list would render "FOR (n:)", which
                        # is not valid Cypher, so there is nothing to index.
                        logging.info("Full text index is not created as labels are empty")
                        return
                    # Backticks inside a quoted identifier are escaped by
                    # doubling them.
                    quoted = ["`" + label.replace("`", "``") + "`" for label in labels]
                    labels_str = ":" + "|".join(quoted)
                    logging.info(f"Fetched labels in {time.time() - start_step:.2f} seconds.")
            except Exception as e:
                logging.error(f"Failed to fetch labels: {e}")
                return

            try:
                start_step = time.time()
                if is_entities:
                    fulltext_query = FULL_TEXT_QUERY.format(labels_str=labels_str)
                else:
                    fulltext_query = HYBRID_SEARCH_FULL_TEXT_QUERY
                session.run(fulltext_query)
                logging.info(f"Created full-text index in {time.time() - start_step:.2f} seconds.")
            except Exception as e:
                logging.error(f"Failed to create full-text index: {e}")
                return
    except Exception as e:
        logging.error(f"An error occurred during the session: {e}")
    finally:
        driver.close()
        logging.info("Driver closed.")
        logging.info(f"Process completed in {time.time() - start_time:.2f} seconds.")


def create_entity_embedding(graph: Neo4jGraph):
    """Embed every entity node that has an ``id`` but no ``embedding`` yet.

    Rows are fetched once and written back in batches of
    ``EMBEDDING_BATCH_SIZE``. The embedding model is loaded a single time and
    shared across all batches; nothing is loaded when there is no work to do.

    Args:
        graph: Graph connection used for both the read and the write queries.
    """
    rows = fetch_entities_for_embedding(graph)
    if not rows:
        return

    embeddings, _dimension = load_embedding_model(os.getenv('EMBEDDING_MODEL'))
    for start in range(0, len(rows), EMBEDDING_BATCH_SIZE):
        update_embeddings(rows[start:start + EMBEDDING_BATCH_SIZE], graph, embeddings=embeddings)


def fetch_entities_for_embedding(graph):
    """Return the entity nodes that still need an embedding.

    Selects nodes that are neither ``Chunk`` nor ``Document``, have a non-null
    ``id`` and a null ``embedding``.

    Args:
        graph: Object exposing ``query(cypher)`` returning an iterable of
            records indexable by column name.

    Returns:
        A list of dicts with keys ``"elementId"`` and ``"text"``, where text is
        the node id, a space, and the description (empty string if missing).
    """
    query = """
        MATCH (e)
        WHERE NOT (e:Chunk OR e:Document) AND e.embedding IS NULL AND e.id IS NOT NULL
        RETURN elementId(e) AS elementId, e.id + " " + coalesce(e.description, "") AS text
        """
    result = graph.query(query)
    return [{"elementId": record["elementId"], "text": record["text"]} for record in result]


def update_embeddings(rows, graph, embeddings=None):
    """Compute embeddings for ``rows`` and store them on the matching nodes.

    Each row dict is mutated in place: an ``"embedding"`` key is added holding
    the result of ``embeddings.embed_query(row["text"])``.

    Args:
        rows: List of dicts with ``"elementId"`` and ``"text"`` keys.
        graph: Object exposing ``query(cypher, params=...)``.
        embeddings: Optional pre-loaded embedding object with an
            ``embed_query`` method. When omitted, the model named by the
            ``EMBEDDING_MODEL`` environment variable is loaded.

    Returns:
        Whatever ``graph.query`` returns for the write statement.
    """
    if embeddings is None:
        embedding_model = os.getenv('EMBEDDING_MODEL')
        embeddings, _dimension = load_embedding_model(embedding_model)

    logging.info("update embedding for entities")
    for row in rows:
        row['embedding'] = embeddings.embed_query(row['text'])

    query = """
        UNWIND $rows AS row
        MATCH (e) WHERE elementId(e) = row.elementId
        CALL db.create.setNodeVectorProperty(e, "embedding", row.embedding)
        """
    return graph.query(query, params={'rows': rows})