# ohmygaugh's picture
# demo working
# a0eb181
from neo4j import GraphDatabase, Driver
import logging
from typing import List, Dict, Any
from . import config
# Module-level logger for this file.
logger = logging.getLogger(__name__)
# Cached Neo4j driver: created lazily by get_graph_driver() and reset to None
# by close_graph_driver(). None means "not created yet / already closed".
_driver: "Driver | None" = None
def get_graph_driver() -> Driver:
    """Return the shared Neo4j driver, creating it on first use.

    The driver is built from ``config.NEO4J_URI``, ``config.NEO4J_USER`` and
    ``config.NEO4J_PASSWORD``. The uniqueness constraints are applied *before*
    the driver is cached. If that step fails, no half-initialised driver is
    left behind for later callers; the next call retries from scratch.

    Returns:
        The cached ``neo4j.Driver`` instance.

    Raises:
        Whatever ``GraphDatabase.driver`` or ``_ensure_constraints`` raise. In
        the latter case the freshly created driver is closed before re-raising.
    """
    global _driver
    if _driver is None:
        logger.info("Initializing Neo4j driver...")
        driver = GraphDatabase.driver(
            config.NEO4J_URI, auth=(config.NEO4J_USER, config.NEO4J_PASSWORD)
        )
        try:
            _ensure_constraints(driver)
        except Exception:
            # Don't cache (or leak) a driver whose setup failed.
            driver.close()
            raise
        _driver = driver
    return _driver
def close_graph_driver():
    """Close the shared Neo4j driver, if one has been created.

    When no driver is cached this is a no-op, so it is safe to call repeatedly.
    The cached reference is cleared even if ``close()`` raises. A later
    ``get_graph_driver()`` call therefore builds a fresh driver instead of
    returning a half-closed one.
    """
    global _driver
    if _driver is not None:
        logger.info("Closing Neo4j driver.")
        try:
            _driver.close()
        finally:
            _driver = None
def _ensure_constraints(driver: Driver):
    """Create the uniqueness constraints the schema graph depends on.

    Keys made unique: ``Database.name``, ``Table.unique_name`` and
    ``Column.unique_name``. Every statement is ``IF NOT EXISTS``, so running
    this again against an already-initialised database is harmless.
    """
    constraint_statements = (
        "CREATE CONSTRAINT IF NOT EXISTS FOR (d:Database) REQUIRE d.name IS UNIQUE",
        "CREATE CONSTRAINT IF NOT EXISTS FOR (t:Table) REQUIRE t.unique_name IS UNIQUE",
        "CREATE CONSTRAINT IF NOT EXISTS FOR (c:Column) REQUIRE c.unique_name IS UNIQUE",
    )
    with driver.session() as session:
        for statement in constraint_statements:
            session.run(statement)
    logger.info("Neo4j constraints ensured.")
# Upserts one Table (linked to its Database) and all of its Columns in a single
# round trip. Writes before WITH still happen when $columns is empty.
_IMPORT_TABLE_QUERY = """
MATCH (d:Database {name: $db_name})
MERGE (t:Table {unique_name: $unique_name})
SET t += $props
MERGE (d)-[:HAS_TABLE]->(t)
WITH t
UNWIND $columns AS col
MERGE (c:Column {unique_name: col.unique_name})
SET c += col
MERGE (t)-[:HAS_COLUMN]->(c)
"""


def _schema_table_rows(db_name: str, tables: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Turn ``schema_data['tables']`` into per-table parameter maps for the upsert query."""
    rows = []
    for table in tables:
        table_unique_name = f"{db_name}.{table['name']}"
        columns = [
            {
                "name": column['name'],
                "unique_name": f"{table_unique_name}.{column['name']}",
                "type": column['type'],
            }
            for column in table['columns']
        ]
        rows.append({
            "unique_name": table_unique_name,
            "props": {"name": table['name'], "unique_name": table_unique_name},
            "columns": columns,
        })
    return rows


def import_schema(schema_data: dict):
    """Import a discovered database schema into the Neo4j graph.

    Creates or updates these nodes and relationships:

    - one ``Database`` node keyed by ``name``;
    - one ``Table`` node per table, keyed ``"<db>.<table>"``;
    - one ``Column`` node per column, keyed ``"<db>.<table>.<column>"``;
    - ``HAS_TABLE`` and ``HAS_COLUMN`` relationships between them.

    Args:
        schema_data: A dict with ``'database_name'`` and ``'tables'``. Each
            table has ``'name'`` and ``'columns'``; each column has ``'name'``
            and ``'type'``.

    Raises:
        KeyError: If a required key is missing. All parameters are built
            before anything is written, so malformed input writes nothing.
    """
    db_name = schema_data['database_name']
    table_rows = _schema_table_rows(db_name, schema_data['tables'])

    driver = get_graph_driver()
    with driver.session() as session:
        # One explicit transaction: either the whole schema lands or none of it
        # (leaving the block via an exception rolls back).
        with session.begin_transaction() as tx:
            tx.run("MERGE (d:Database {name: $db_name})", db_name=db_name)
            for row in table_rows:
                tx.run(_IMPORT_TABLE_QUERY, db_name=db_name, **row)
            tx.commit()
    logger.info("Successfully imported schema for database: %s", db_name)
def _keyword_search(keyword: str) -> List[Dict[str, Any]]:
    """Find up to five tables whose name contains ``keyword``.

    Matching is a case-sensitive substring test (Cypher ``CONTAINS``). Rows
    come back in a deterministic order:

    1. an exact name match first;
    2. then shorter names;
    3. then alphabetically by database and table name.

    Without an ``ORDER BY``, ``LIMIT`` returns an arbitrary subset. Callers
    such as ``find_join_path`` use the first row as "the" match, so that
    would make their result unpredictable.

    Returns:
        A list of ``{"database": <db name>, "table": <table name>}`` dicts.
    """
    driver = get_graph_driver()
    query = """
    MATCH (d:Database)-[:HAS_TABLE]->(t:Table)
    WHERE t.name CONTAINS $keyword
    RETURN d.name as database, t.name as table
    ORDER BY (t.name = $keyword) DESC, size(t.name), d.name, t.name
    LIMIT 5
    """
    with driver.session() as session:
        result = session.run(query, keyword=keyword)
        return [record.data() for record in result]
def find_join_path(table1_name: str, table2_name: str) -> str:
    """Describe the shortest join path between two tables in the schema graph.

    Each name is resolved with ``_keyword_search``, and its first hit is used.
    The path is searched over ``HAS_COLUMN`` and ``REFERENCES`` relationships
    in either direction, up to 5 hops. It is rendered as the non-empty
    ``name`` of every node along it.

    Args:
        table1_name: Keyword identifying the first table.
        table2_name: Keyword identifying the second table.

    Returns:
        ``"Found path: a -> b -> ..."`` on success. Otherwise, a human-readable
        explanation of why no path could be produced.
    """
    t1_nodes = _keyword_search(table1_name)
    t2_nodes = _keyword_search(table2_name)
    if not t1_nodes:
        return f"Could not find a table matching '{table1_name}'."
    if not t2_nodes:
        return f"Could not find a table matching '{table2_name}'."

    t1_unique_name = f"{t1_nodes[0]['database']}.{t1_nodes[0]['table']}"
    t2_unique_name = f"{t2_nodes[0]['database']}.{t2_nodes[0]['table']}"
    if t1_unique_name == t2_unique_name:
        # shortestPath() rejects identical start and end nodes; report it instead.
        return (
            f"Both '{table1_name}' and '{table2_name}' resolve to the same "
            f"table: {t1_unique_name}."
        )

    # Notes on this query:
    # - It uses native shortestPath instead of apoc.path.shortestPath, which
    #   is not an APOC procedure.
    # - An undirected pattern covers HAS_COLUMN in both directions, as the
    #   old relationship filter intended.
    # - The list comprehension replaces FILTER(), which Neo4j 4.0 removed.
    # - `src`/`dst` avoid the reserved keyword `end`.
    path_query = """
    MATCH (src:Table {unique_name: $start_name}), (dst:Table {unique_name: $end_name})
    MATCH path = shortestPath((src)-[:HAS_COLUMN|REFERENCES*..5]-(dst))
    RETURN [n IN nodes(path) WHERE coalesce(n.name, '') <> '' | n.name] AS path
    LIMIT 1
    """
    driver = get_graph_driver()
    with driver.session() as session:
        result = session.run(path_query, start_name=t1_unique_name, end_name=t2_unique_name)
        record = result.single()
        if not record or not record["path"]:
            return f"No join path found between {table1_name} and {table2_name}."
        path_str = " -> ".join(record["path"])
        return f"Found path: {path_str}"