| from neo4j import GraphDatabase | |
| import logging | |
| import time | |
| from langchain_community.graphs import Neo4jGraph | |
| import os | |
| from src.shared.common_fn import load_embedding_model | |
# Cypher used to (re)build the "entities" full-text index over entity nodes.
DROP_INDEX_QUERY = "DROP INDEX entities IF EXISTS;"
# Lists every label present in the database (db.labels procedure).
LABELS_QUERY = "CALL db.labels()"
# {labels_str} is filled with a ":`Label1`|`Label2`|..." union of entity labels.
FULL_TEXT_QUERY = "CREATE FULLTEXT INDEX entities FOR (n{labels_str}) ON EACH [n.id, n.description];"
# Infrastructure labels that must never be part of the entity index.
FILTER_LABELS = ["Chunk", "Document"]
# Cypher used to (re)build the "keyword" index on Chunk.text for hybrid search.
HYBRID_SEARCH_INDEX_DROP_QUERY = "DROP INDEX keyword IF EXISTS;"
HYBRID_SEARCH_FULL_TEXT_QUERY = "CREATE FULLTEXT INDEX keyword FOR (n:Chunk) ON EACH [n.text]"
def create_fulltext(uri, username, password, database, type):
    """Create (or recreate) a full-text index in Neo4j.

    Depending on *type*, this rebuilds either the ``entities`` index over all
    non-infrastructure labels, or the ``keyword`` index over ``Chunk.text``
    used for hybrid search.

    Args:
        uri: Bolt URI of the Neo4j server.
        username: Database user name.
        password: Database password.
        database: Target database name.
        type: ``"entities"`` selects the entity index; any other value
            selects the hybrid-search keyword index.

    Returns:
        None. All errors are logged and swallowed; the function never raises.
    """
    start_time = time.time()
    logging.info("Starting the process of creating a full-text index.")

    try:
        driver = GraphDatabase.driver(uri, auth=(username, password), database=database)
        driver.verify_connectivity()
        logging.info("Database connectivity verified.")
    except Exception as e:
        logging.error(f"Failed to create a database driver or verify connectivity: {e}")
        return

    try:
        with driver.session() as session:
            # Step 1: drop any pre-existing index of the requested kind.
            try:
                start_step = time.time()
                drop_query = DROP_INDEX_QUERY if type == "entities" else HYBRID_SEARCH_INDEX_DROP_QUERY
                session.run(drop_query)
                logging.info(f"Dropped existing index (if any) in {time.time() - start_step:.2f} seconds.")
            except Exception as e:
                logging.error(f"Failed to drop index: {e}")
                return

            # Step 2 (entities only): collect the labels the index must cover.
            labels_str = ""
            try:
                if type == "entities":
                    start_step = time.time()
                    result = session.run(LABELS_QUERY)
                    # Exclude infrastructure labels in one pass instead of
                    # mutating the list while looping over FILTER_LABELS.
                    labels = [record["label"] for record in result
                              if record["label"] not in FILTER_LABELS]
                    if not labels:
                        # "FOR (n:)" would be invalid Cypher - nothing to index.
                        logging.info("No entity labels found; skipping full-text index creation.")
                        return
                    # Back-tick each label so names with special characters
                    # remain valid in the generated Cypher.
                    labels_str = ":" + "|".join(f"`{label}`" for label in labels)
                    logging.info(f"Fetched labels in {time.time() - start_step:.2f} seconds.")
            except Exception as e:
                logging.error(f"Failed to fetch labels: {e}")
                return

            # Step 3: create the requested full-text index.
            try:
                start_step = time.time()
                if type == "entities":
                    fulltext_query = FULL_TEXT_QUERY.format(labels_str=labels_str)
                else:
                    fulltext_query = HYBRID_SEARCH_FULL_TEXT_QUERY
                session.run(fulltext_query)
                logging.info(f"Created full-text index in {time.time() - start_step:.2f} seconds.")
            except Exception as e:
                logging.error(f"Failed to create full-text index: {e}")
                return
    except Exception as e:
        logging.error(f"An error occurred during the session: {e}")
    finally:
        driver.close()
        logging.info("Driver closed.")

    logging.info(f"Process completed in {time.time() - start_time:.2f} seconds.")
def create_entity_embedding(graph: Neo4jGraph, batch_size: int = 1000):
    """Compute and store embeddings for all entity nodes lacking one.

    Entities are fetched once and then written back in batches of
    *batch_size* to bound memory use and transaction size.

    Args:
        graph: Neo4jGraph connection used for both the read and the writes.
        batch_size: Number of entities embedded and persisted per batch;
            defaults to 1000 (the previously hard-coded value).
    """
    rows = fetch_entities_for_embedding(graph)
    for start in range(0, len(rows), batch_size):
        update_embeddings(rows[start:start + batch_size], graph)
def fetch_entities_for_embedding(graph):
    """Return entity nodes that still need an embedding.

    Selects every node that is neither a Chunk nor a Document, has a
    non-null id, and no embedding yet. Each returned dict pairs the node's
    element id with the text (id plus optional description) to be embedded.
    """
    query = """
    MATCH (e)
    WHERE NOT (e:Chunk OR e:Document) AND e.embedding IS NULL AND e.id IS NOT NULL
    RETURN elementId(e) AS elementId, e.id + " " + coalesce(e.description, "") AS text
    """
    rows = []
    for record in graph.query(query):
        rows.append({"elementId": record["elementId"], "text": record["text"]})
    return rows
def update_embeddings(rows, graph):
    """Embed each row's text and persist it onto the corresponding node.

    Args:
        rows: Dicts with "elementId" and "text" keys. Each dict is mutated
            in place to gain an "embedding" key before being sent to Neo4j.
        graph: Neo4jGraph used to run the batched write query.

    Returns:
        The result of the batched write query.
    """
    embedding_model = os.getenv('EMBEDDING_MODEL')
    # NOTE(review): the model is (re)loaded on every call; callers invoking
    # this per batch may want to hoist the load. The returned dimension is
    # not needed here, so it is discarded.
    embeddings, _ = load_embedding_model(embedding_model)
    logging.info("update embedding for entities")
    for row in rows:
        row['embedding'] = embeddings.embed_query(row['text'])
    query = """
    UNWIND $rows AS row
    MATCH (e) WHERE elementId(e) = row.elementId
    CALL db.create.setNodeVectorProperty(e, "embedding", row.embedding)
    """
    return graph.query(query, params={'rows': rows})