import logging
from typing import Any, Dict, List, Optional

from neo4j import Driver, GraphDatabase

from . import config


logger = logging.getLogger(__name__)

# Process-wide Neo4j driver. It is created lazily by get_graph_driver() and
# reset by close_graph_driver(); None means "not initialized yet", hence the
# Optional annotation.
_driver: Optional[Driver] = None


def get_graph_driver() -> Driver:
    """Return the singleton Neo4j driver, creating it on first use.

    The driver is built from ``config.NEO4J_URI``, ``config.NEO4J_USER`` and
    ``config.NEO4J_PASSWORD``. The schema uniqueness constraints are ensured
    before the driver is cached.

    Raises:
        Whatever ``GraphDatabase.driver`` or ``_ensure_constraints`` raises.
        When constraint creation fails, the new driver is closed and nothing
        is cached, so the next call retries from scratch instead of handing
        out a driver whose constraints were never created.
    """
    global _driver
    if _driver is None:
        logger.info("Initializing Neo4j driver...")
        driver = GraphDatabase.driver(
            config.NEO4J_URI, auth=(config.NEO4J_USER, config.NEO4J_PASSWORD)
        )
        try:
            _ensure_constraints(driver)
        except Exception:
            # Do not leak the connection pool or cache a half-initialized driver.
            driver.close()
            raise
        _driver = driver
    return _driver


def close_graph_driver():
    """Close the cached Neo4j driver and forget it.

    Does nothing when no driver is currently cached, so it is safe to call
    more than once. After closing, the next get_graph_driver() call builds a
    fresh driver.
    """
    global _driver
    if not _driver:
        return
    logger.info("Closing Neo4j driver.")
    _driver.close()
    _driver = None


def _ensure_constraints(driver: Driver):
    """Create the uniqueness constraints the schema graph relies on.

    ``Database`` nodes are keyed by ``name``; ``Table`` and ``Column`` nodes
    are keyed by ``unique_name``. Every statement uses ``IF NOT EXISTS``, so
    running this again against an already-prepared database is harmless.
    """
    constraint_statements = (
        "CREATE CONSTRAINT IF NOT EXISTS FOR (d:Database) REQUIRE d.name IS UNIQUE",
        "CREATE CONSTRAINT IF NOT EXISTS FOR (t:Table) REQUIRE t.unique_name IS UNIQUE",
        "CREATE CONSTRAINT IF NOT EXISTS FOR (c:Column) REQUIRE c.unique_name IS UNIQUE",
    )
    with driver.session() as session:
        for statement in constraint_statements:
            session.run(statement)
    logger.info("Neo4j constraints ensured.")


def import_schema(schema_data: dict):
    """
    Import a discovered database schema into the Neo4j graph.

    ``schema_data`` is read as::

        {"database_name": str,
         "tables": [{"name": str,
                     "columns": [{"name": str, "type": ...}, ...]},
                    ...]}

    Only these keys are used; a missing key raises ``KeyError``.

    Every node is MERGEd on its unique key, so importing the same schema twice
    does not create duplicates. The listed properties are overwritten on
    existing nodes (``SET +=``), and any other properties are left untouched.

    Graph written:
      * ``(:Database {name})``
      * ``(:Table {name, unique_name})`` with ``unique_name = "<db>.<table>"``
      * ``(:Column {name, unique_name, type})`` with
        ``unique_name = "<db>.<table>.<column>"``
      * ``(Database)-[:HAS_TABLE]->(Table)-[:HAS_COLUMN]->(Column)``

    Each table is written together with all of its columns in one statement,
    which means one round trip and one auto-commit transaction per table
    rather than one statement per column.
    """
    # Upsert one table, link it to its database, then UNWIND its columns and
    # upsert/link each one. With an empty column list the UNWIND yields no rows,
    # but the table MERGE above it has already taken effect.
    upsert_table_query = """
    MATCH (d:Database {name: $db_name})
    MERGE (t:Table {unique_name: $unique_name})
    SET t += $props
    MERGE (d)-[:HAS_TABLE]->(t)
    WITH t
    UNWIND $columns AS col
    MERGE (c:Column {unique_name: col.unique_name})
    SET c += col
    MERGE (t)-[:HAS_COLUMN]->(c)
    """

    driver = get_graph_driver()
    db_name = schema_data['database_name']

    with driver.session() as session:
        session.run("MERGE (d:Database {name: $db_name})", db_name=db_name)

        for table in schema_data['tables']:
            table_unique_name = f"{db_name}.{table['name']}"
            table_properties = {
                "name": table['name'],
                "unique_name": table_unique_name,
            }
            column_properties = [
                {
                    "name": column['name'],
                    "unique_name": f"{table_unique_name}.{column['name']}",
                    "type": column['type'],
                }
                for column in table['columns']
            ]
            session.run(
                upsert_table_query,
                db_name=db_name,
                unique_name=table_unique_name,
                props=table_properties,
                columns=column_properties,
            )
    logger.info("Successfully imported schema for database: %s", db_name)


def _keyword_search(keyword: str, limit: int = 5) -> List[Dict[str, Any]]:
    """Find table nodes whose name contains ``keyword``, best matches first.

    Matching is case-insensitive, and surrounding whitespace in ``keyword`` is
    ignored. Results are ordered deterministically:
      1. an exact (case-insensitive) name match first,
      2. then shorter names,
      3. then by database name and table name.
    Callers that take the first row therefore get the closest match rather
    than an arbitrary one.

    Args:
        keyword: Substring to look for in ``Table.name``.
        limit: Maximum number of rows to return (default 5).

    Returns:
        Up to ``limit`` dicts of the form ``{"database": ..., "table": ...}``.
        An empty or whitespace-only ``keyword``, or a non-positive ``limit``,
        returns ``[]`` without querying Neo4j. An empty substring would
        otherwise match every table.
    """
    needle = keyword.strip()
    if not needle or limit <= 0:
        return []

    driver = get_graph_driver()
    query = """
    MATCH (d:Database)-[:HAS_TABLE]->(t:Table)
    WHERE toLower(t.name) CONTAINS toLower($keyword)
    RETURN d.name as database, t.name as table
    ORDER BY toLower(t.name) = toLower($keyword) DESC, size(t.name), d.name, t.name
    LIMIT $limit
    """
    with driver.session() as session:
        result = session.run(query, keyword=needle, limit=limit)
        # Materialize inside the session: the result is unusable once it closes.
        return [record.data() for record in result]


def find_join_path(table1_name: str, table2_name: str, max_hops: int = 5) -> str:
    """
    Find a human-readable join path between two tables using the graph's schema.

    Each argument is resolved to a ``Table`` node with ``_keyword_search``, and
    its first hit is used. The shortest chain of ``HAS_COLUMN`` /
    ``REFERENCES`` relationships between the two tables is then searched for.
    Relationships are traversed in either direction, and the chain may be at
    most ``max_hops`` relationships long. The result lists the ``name`` of
    every node on the path, e.g.
    ``"Found path: orders -> customer_id -> id -> customers"``.

    Args:
        table1_name: Keyword identifying the first table.
        table2_name: Keyword identifying the second table.
        max_hops: Maximum path length in relationships (default 5). A direct
            join ``Table-Column-Column-Table`` is 3 hops, and every
            intermediate table adds 3 more, so multi-table joins need a larger
            value.

    Returns:
        ``"Found path: ..."`` on success. Otherwise, a message naming the table
        that could not be found, or stating that no path exists.

    Raises:
        ValueError: if ``max_hops`` is not a positive ``int``.
    """
    # max_hops is interpolated into the query text, because a variable-length
    # bound cannot be a Cypher parameter. Only accept a real positive int
    # (bool is an int subclass and is rejected explicitly).
    if isinstance(max_hops, bool) or not isinstance(max_hops, int) or max_hops < 1:
        raise ValueError(f"max_hops must be a positive int, got {max_hops!r}")

    t1_nodes = _keyword_search(table1_name)
    t2_nodes = _keyword_search(table2_name)

    if not t1_nodes:
        return f"Could not find a table matching '{table1_name}'."
    if not t2_nodes:
        return f"Could not find a table matching '{table2_name}'."

    t1_unique_name = f"{t1_nodes[0]['database']}.{t1_nodes[0]['table']}"
    t2_unique_name = f"{t2_nodes[0]['database']}.{t2_nodes[0]['table']}"

    # Both names resolved to the same table. The trivial path is the table
    # itself, and Neo4j's shortestPath() rejects identical start/end nodes by
    # default.
    if t1_unique_name == t2_unique_name:
        return f"Found path: {t1_nodes[0]['table']}"

    # Native shortestPath replaces a non-existent APOC call. The list
    # comprehension replaces FILTER(), which was removed in Neo4j 4.0.
    # NOTE(review): import_schema() in this module only creates HAS_TABLE and
    # HAS_COLUMN edges. REFERENCES edges are presumably written elsewhere;
    # confirm this, otherwise every search here will report "no path".
    path_query = (
        "MATCH (source:Table {unique_name: $start_name}), "
        "(target:Table {unique_name: $end_name}) "
        f"MATCH p = shortestPath((source)-[:HAS_COLUMN|REFERENCES*..{max_hops}]-(target)) "
        "RETURN [n IN nodes(p) WHERE coalesce(n.name, '') <> '' | n.name] AS path "
        "LIMIT 1"
    )

    driver = get_graph_driver()
    with driver.session() as session:
        record = session.run(
            path_query, start_name=t1_unique_name, end_name=t2_unique_name
        ).single()

    if not record or not record["path"]:
        return f"No join path found between {table1_name} and {table2_name}."

    path_str = " -> ".join(record["path"])
    return f"Found path: {path_str}"

