"""Neo4j-backed graph of discovered database schemas.

Maintains a lazily created module-wide Neo4j driver, imports schema
descriptions (databases, tables, columns) as graph nodes, and searches
for join paths between tables.
"""

import logging
from typing import Any, Dict, List, Optional

from neo4j import Driver, GraphDatabase

from . import config

logger = logging.getLogger(__name__)

# Created on first use by get_graph_driver(); reset by close_graph_driver().
_driver: Optional[Driver] = None

# Uniqueness constraints applied once, when the driver is first created.
_CONSTRAINT_STATEMENTS = (
    "CREATE CONSTRAINT IF NOT EXISTS FOR (d:Database) REQUIRE d.name IS UNIQUE",
    "CREATE CONSTRAINT IF NOT EXISTS FOR (t:Table) REQUIRE t.unique_name IS UNIQUE",
    "CREATE CONSTRAINT IF NOT EXISTS FOR (c:Column) REQUIRE c.unique_name IS UNIQUE",
)


def get_graph_driver() -> Driver:
    """Return the module-wide Neo4j driver, creating it on first use.

    Connection settings come from ``config.NEO4J_URI``, ``config.NEO4J_USER``
    and ``config.NEO4J_PASSWORD``. On creation the uniqueness constraints are
    ensured; if that fails, the new driver is closed and the exception is
    re-raised, so nothing is cached and the next call retries from scratch.
    No locking is performed around initialisation.
    """
    global _driver
    if _driver is None:
        logger.info("Initializing Neo4j driver...")
        driver = GraphDatabase.driver(
            config.NEO4J_URI, auth=(config.NEO4J_USER, config.NEO4J_PASSWORD)
        )
        try:
            _ensure_constraints(driver)
        except Exception:
            # Don't cache (or leak) a driver whose setup failed.
            driver.close()
            raise
        _driver = driver
    return _driver


def close_graph_driver() -> None:
    """Close the module-wide Neo4j driver, if one has been created."""
    global _driver
    if _driver is not None:
        logger.info("Closing Neo4j driver.")
        _driver.close()
        _driver = None


def _ensure_constraints(driver: Driver) -> None:
    """Create the Database/Table/Column uniqueness constraints if missing."""
    with driver.session() as session:
        for statement in _CONSTRAINT_STATEMENTS:
            session.run(statement)
    logger.info("Neo4j constraints ensured.")


def import_schema(schema_data: dict) -> None:
    """Import a discovered database schema into the Neo4j graph.

    All writes happen in a single transaction, so a failure part-way through
    leaves no partial import behind. Nodes are MERGEd on their unique keys,
    so re-importing the same schema updates properties instead of duplicating.

    Args:
        schema_data: Mapping with ``'database_name'`` (str) and ``'tables'``,
            a list of dicts each holding ``'name'`` and ``'columns'``; every
            column dict must hold ``'name'`` and ``'type'``.

    Note:
        Only ``HAS_TABLE`` and ``HAS_COLUMN`` relationships are created here;
        ``find_join_path`` can only connect two tables once ``REFERENCES``
        edges between columns exist (they are not created in this module).
    """
    driver = get_graph_driver()
    db_name = schema_data['database_name']

    with driver.session() as session:
        with session.begin_transaction() as tx:
            tx.run("MERGE (d:Database {name: $db_name})", db_name=db_name)

            for table in schema_data['tables']:
                table_unique_name = f"{db_name}.{table['name']}"
                tx.run(
                    """
                    MATCH (d:Database {name: $db_name})
                    MERGE (t:Table {unique_name: $unique_name})
                    SET t += $props
                    MERGE (d)-[:HAS_TABLE]->(t)
                    """,
                    db_name=db_name,
                    unique_name=table_unique_name,
                    props={"name": table['name'], "unique_name": table_unique_name},
                )

                columns = [
                    {
                        "name": column['name'],
                        "unique_name": f"{table_unique_name}.{column['name']}",
                        "type": column['type'],
                    }
                    for column in table['columns']
                ]
                if columns:
                    # One round trip per table instead of one per column.
                    tx.run(
                        """
                        MATCH (t:Table {unique_name: $table_unique_name})
                        UNWIND $columns AS col
                        MERGE (c:Column {unique_name: col.unique_name})
                        SET c += col
                        MERGE (t)-[:HAS_COLUMN]->(c)
                        """,
                        table_unique_name=table_unique_name,
                        columns=columns,
                    )

            tx.commit()

    logger.info("Successfully imported schema for database: %s", db_name)


def _keyword_search(keyword: str, limit: int = 5) -> List[Dict[str, Any]]:
    """Find Table nodes whose name contains *keyword* (case-insensitive).

    Results are ranked deterministically: an exact (case-insensitive) name
    match first, then shorter names, then by ``unique_name``.

    Returns:
        Up to *limit* dicts with keys ``'database'``, ``'table'`` and
        ``'unique_name'``.
    """
    driver = get_graph_driver()
    query = """
    MATCH (d:Database)-[:HAS_TABLE]->(t:Table)
    WHERE toLower(t.name) CONTAINS toLower($keyword)
    RETURN d.name AS database, t.name AS table, t.unique_name AS unique_name
    ORDER BY toLower(t.name) = toLower($keyword) DESC, size(t.name), t.unique_name
    LIMIT $limit
    """
    with driver.session() as session:
        result = session.run(query, keyword=keyword, limit=limit)
        return [record.data() for record in result]


def find_join_path(table1_name: str, table2_name: str) -> str:
    """Describe the shortest join path between two tables as readable text.

    Each name is resolved with ``_keyword_search`` and the best-ranked match
    is used. The path is searched over ``HAS_COLUMN`` and ``REFERENCES``
    relationships in either direction.

    Returns:
        ``"Found path: a -> b -> ..."`` listing node unique names, or a
        message explaining why no path could be produced.
    """
    driver = get_graph_driver()

    t1_nodes = _keyword_search(table1_name)
    t2_nodes = _keyword_search(table2_name)
    if not t1_nodes:
        return f"Could not find a table matching '{table1_name}'."
    if not t2_nodes:
        return f"Could not find a table matching '{table2_name}'."

    t1_unique_name = t1_nodes[0]['unique_name']
    t2_unique_name = t2_nodes[0]['unique_name']
    if t1_unique_name == t2_unique_name:
        # shortestPath() raises an error when start and end are the same node.
        return (
            f"'{table1_name}' and '{table2_name}' both resolve to table "
            f"{t1_unique_name}; no join is needed."
        )

    # `start`/`end` are avoided as variable names because END is a Cypher keyword.
    path_query = """
    MATCH (src:Table {unique_name: $start_name}), (dst:Table {unique_name: $end_name})
    MATCH p = shortestPath((src)-[:HAS_COLUMN|REFERENCES*]-(dst))
    RETURN [n IN nodes(p) | coalesce(n.unique_name, n.name)] AS path
    """
    with driver.session() as session:
        record = session.run(
            path_query, start_name=t1_unique_name, end_name=t2_unique_name
        ).single()

    if not record or not record["path"]:
        return f"No join path found between {table1_name} and {table2_name}."

    path_str = " -> ".join(str(name) for name in record["path"])
    return f"Found path: {path_str}"