Spaces:
No application file
No application file
File size: 5,139 Bytes
9d411a7 86cbe3c 9d411a7 86cbe3c 9d411a7 86cbe3c 9d411a7 86cbe3c 9d411a7 86cbe3c 9d411a7 86cbe3c a0eb181 9d411a7 86cbe3c 9d411a7 86cbe3c 9d411a7 86cbe3c 9d411a7 86cbe3c 9d411a7 86cbe3c 9d411a7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 | from neo4j import GraphDatabase, Driver
import logging
from typing import List, Dict, Any
from . import config
logger = logging.getLogger(__name__)

# Process-wide Neo4j driver singleton: None until get_graph_driver() creates it,
# and reset to None by close_graph_driver(). The annotation is quoted so the
# `X | None` form is never evaluated at import time (safe on Python < 3.10).
_driver: "Driver | None" = None
def get_graph_driver() -> Driver:
    """Return the shared Neo4j driver, creating it on first use.

    The driver is built from ``config.NEO4J_URI``, ``config.NEO4J_USER`` and
    ``config.NEO4J_PASSWORD``. ``_ensure_constraints`` runs on it before it is
    stored in the module-level singleton.

    Returns:
        The shared ``neo4j.Driver`` instance.

    Raises:
        Any exception raised by ``GraphDatabase.driver`` or
        ``_ensure_constraints``. In that case nothing is cached and the new
        driver is closed, so the next call retries from scratch.
    """
    global _driver
    if _driver is None:
        logger.info("Initializing Neo4j driver...")
        driver = GraphDatabase.driver(
            config.NEO4J_URI, auth=(config.NEO4J_USER, config.NEO4J_PASSWORD)
        )
        try:
            _ensure_constraints(driver)
        except Exception:
            # Publish the driver only after setup succeeds; otherwise a
            # half-initialised driver would be cached (and never retried)
            # and its connection pool would leak.
            driver.close()
            raise
        _driver = driver
    return _driver
def close_graph_driver():
    """Close the shared Neo4j driver if one has been created.

    A no-op when no driver is cached, so it is safe to call repeatedly.
    The module-level reference is cleared even if ``close()`` raises, so a
    later ``get_graph_driver()`` call builds a fresh driver instead of
    handing out a possibly half-closed one.
    """
    global _driver
    if _driver is not None:
        logger.info("Closing Neo4j driver.")
        try:
            _driver.close()
        finally:
            _driver = None
def _ensure_constraints(driver: Driver):
    """Create the uniqueness constraints the schema graph relies on.

    Adds one constraint per node label: ``Database.name``,
    ``Table.unique_name`` and ``Column.unique_name``. ``IF NOT EXISTS``
    makes each statement a no-op when the constraint is already present.

    Args:
        driver: Open Neo4j driver to run the statements through.

    Raises:
        Any error the server reports for a constraint statement.
    """
    statements = (
        "CREATE CONSTRAINT IF NOT EXISTS FOR (d:Database) REQUIRE d.name IS UNIQUE",
        "CREATE CONSTRAINT IF NOT EXISTS FOR (t:Table) REQUIRE t.unique_name IS UNIQUE",
        "CREATE CONSTRAINT IF NOT EXISTS FOR (c:Column) REQUIRE c.unique_name IS UNIQUE",
    )
    with driver.session() as session:
        for cypher in statements:
            # consume() waits for the server's summary, so a failing statement
            # raises here instead of being deferred to session close.
            session.run(cypher).consume()
    logger.info("Neo4j constraints ensured.")
def _build_schema_rows(db_name: str, tables: List[Dict[str, Any]]):
    """Flatten a schema's tables into UNWIND parameter rows (tables, columns)."""
    table_rows: List[Dict[str, Any]] = []
    column_rows: List[Dict[str, Any]] = []
    for table in tables:
        table_unique_name = f"{db_name}.{table['name']}"
        table_rows.append({
            "unique_name": table_unique_name,
            "props": {
                "name": table['name'],
                "unique_name": table_unique_name,
            },
        })
        for column in table['columns']:
            column_unique_name = f"{table_unique_name}.{column['name']}"
            column_rows.append({
                "table_unique_name": table_unique_name,
                "unique_name": column_unique_name,
                "props": {
                    "name": column['name'],
                    "unique_name": column_unique_name,
                    "type": column['type'],
                },
            })
    return table_rows, column_rows


def import_schema(schema_data: dict):
    """
    Imports a discovered database schema into the Neo4j graph.

    Merges one ``Database`` node, one ``Table`` node per table linked by
    ``HAS_TABLE``, and one ``Column`` node per column linked by
    ``HAS_COLUMN``. Nodes are keyed by ``unique_name`` (``"<db>.<table>"``
    and ``"<db>.<table>.<column>"``), so re-importing a schema updates the
    properties instead of creating duplicates.

    Args:
        schema_data: Mapping with ``'database_name'`` and ``'tables'``. Each
            table needs ``'name'`` and ``'columns'``; each column needs
            ``'name'`` and ``'type'``.

    Raises:
        KeyError: if one of the keys listed above is missing.
    """
    db_name = schema_data['database_name']
    table_rows, column_rows = _build_schema_rows(db_name, schema_data['tables'])

    table_query = """
    MATCH (d:Database {name: $db_name})
    UNWIND $rows AS row
    MERGE (t:Table {unique_name: row.unique_name})
    SET t += row.props
    MERGE (d)-[:HAS_TABLE]->(t)
    """
    column_query = """
    UNWIND $rows AS row
    MATCH (t:Table {unique_name: row.table_unique_name})
    MERGE (c:Column {unique_name: row.unique_name})
    SET c += row.props
    MERGE (t)-[:HAS_COLUMN]->(c)
    """

    driver = get_graph_driver()
    with driver.session() as session:
        # One explicit transaction with three batched (UNWIND) statements
        # instead of one auto-commit round trip per table and per column:
        # far fewer network round trips, and a failure rolls back everything
        # rather than leaving a partially imported schema.
        with session.begin_transaction() as tx:
            tx.run("MERGE (d:Database {name: $db_name})", db_name=db_name)
            tx.run(table_query, db_name=db_name, rows=table_rows)
            tx.run(column_query, rows=column_rows)
            tx.commit()
    logger.info("Successfully imported schema for database: %s", db_name)
def _keyword_search(keyword: str, limit: int = 5) -> List[Dict[str, Any]]:
    """Find tables whose name contains ``keyword`` (case-sensitive).

    Results are ordered deterministically so that callers taking the first
    hit get a stable answer: an exact name match first, then shorter names,
    then alphabetically by database name and table name.

    Args:
        keyword: Substring to look for in ``Table.name``.
        limit: Maximum number of matches to return (default 5).

    Returns:
        Up to ``limit`` dicts of the form ``{"database": ..., "table": ...}``.
    """
    driver = get_graph_driver()
    query = """
    MATCH (d:Database)-[:HAS_TABLE]->(t:Table)
    WHERE t.name CONTAINS $keyword
    WITH d, t
    ORDER BY (t.name = $keyword) DESC, size(t.name), d.name, t.name
    LIMIT $limit
    RETURN d.name AS database, t.name AS table
    """
    with driver.session() as session:
        result = session.run(query, keyword=keyword, limit=limit)
        return [record.data() for record in result]
def find_join_path(table1_name: str, table2_name: str) -> str:
    """
    Finds a human-readable join path between two tables using the graph's schema.

    Each name is resolved with ``_keyword_search`` and its first match is
    used. The path is the shortest undirected route of at most 5
    ``HAS_COLUMN`` / ``REFERENCES`` hops between the two ``Table`` nodes,
    rendered as the non-empty ``name`` of each node along it.

    NOTE(review): ``import_schema`` in this file only creates ``HAS_TABLE``
    and ``HAS_COLUMN`` edges. Unless ``REFERENCES`` edges are added
    elsewhere, no path between two distinct tables can exist; confirm.

    Args:
        table1_name: Keyword identifying the first table.
        table2_name: Keyword identifying the second table.

    Returns:
        ``"Found path: a -> b -> ..."`` on success, otherwise a message
        explaining which table was not found or that no path exists.
    """
    t1_nodes = _keyword_search(table1_name)
    t2_nodes = _keyword_search(table2_name)
    if not t1_nodes:
        return f"Could not find a table matching '{table1_name}'."
    if not t2_nodes:
        return f"Could not find a table matching '{table2_name}'."

    t1_unique_name = f"{t1_nodes[0]['database']}.{t1_nodes[0]['table']}"
    t2_unique_name = f"{t2_nodes[0]['database']}.{t2_nodes[0]['table']}"
    if t1_unique_name == t2_unique_name:
        # Cypher's shortestPath() rejects identical start and end nodes;
        # the trivial path is just the table itself.
        return f"Found path: {t1_nodes[0]['table']}"

    # Built-in shortestPath over an undirected pattern replaces the
    # non-standard APOC call. The list comprehension replaces FILTER(),
    # which was removed from Cypher in Neo4j 4.0.
    path_query = """
    MATCH (src:Table {unique_name: $start_name}), (dst:Table {unique_name: $end_name})
    MATCH p = shortestPath((src)-[:HAS_COLUMN|REFERENCES*..5]-(dst))
    RETURN [n IN nodes(p) WHERE coalesce(n.name, '') <> '' | n.name] AS path
    LIMIT 1
    """
    driver = get_graph_driver()
    with driver.session() as session:
        result = session.run(path_query, start_name=t1_unique_name, end_name=t2_unique_name)
        record = result.single()
    if not record or not record["path"]:
        return f"No join path found between {table1_name} and {table2_name}."
    path_str = " -> ".join(record["path"])
    return f"Found path: {path_str}"
|