Spaces:
No application file
No application file
| from neo4j import GraphDatabase, Driver | |
| import logging | |
| from typing import List, Dict, Any | |
| from . import config | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Process-wide Neo4j driver singleton. It is None until get_graph_driver()
# creates it, and None again after close_graph_driver() resets it.
# The annotation is a string literal so the `X | None` union is never
# evaluated at import time (keeps older interpreters working).
_driver: "Driver | None" = None
def get_graph_driver() -> Driver:
    """Return the shared Neo4j driver, creating it on first use.

    On the first call the driver is built from ``config.NEO4J_URI``,
    ``config.NEO4J_USER`` and ``config.NEO4J_PASSWORD`` and the uniqueness
    constraints are ensured. The driver is stored as the module singleton
    only after that setup succeeds, so a failed setup is retried on the next
    call instead of handing out a driver whose constraints were never
    created.

    Returns:
        The singleton ``Driver`` instance.

    Raises:
        Exception: Whatever driver creation or ``_ensure_constraints``
            raises. A driver that was created but could not be set up is
            closed before the error propagates.
    """
    global _driver
    if _driver is None:
        logger.info("Initializing Neo4j driver...")
        new_driver = GraphDatabase.driver(
            config.NEO4J_URI,
            auth=(config.NEO4J_USER, config.NEO4J_PASSWORD),
        )
        try:
            _ensure_constraints(new_driver)
        except Exception:
            # Do not leak the connection pool of a half-initialised driver.
            new_driver.close()
            raise
        _driver = new_driver
    return _driver
def close_graph_driver():
    """Close the shared Neo4j driver, if one is open, and forget it.

    The module-level reference is cleared even when ``close()`` raises, so a
    later ``get_graph_driver()`` call builds a fresh driver rather than
    returning one whose shutdown failed part-way. Calling this when no driver
    is open is a no-op.
    """
    global _driver
    if _driver is None:
        return
    logger.info("Closing Neo4j driver.")
    try:
        _driver.close()
    finally:
        _driver = None
def _ensure_constraints(driver: Driver):
    """Create the uniqueness constraints used by the schema graph.

    Runs three ``CREATE CONSTRAINT IF NOT EXISTS`` statements in one session:
    ``Database.name``, ``Table.unique_name`` and ``Column.unique_name`` must
    each be unique.

    Args:
        driver: Driver used to open the session the statements run in.
    """
    constraint_statements = (
        "CREATE CONSTRAINT IF NOT EXISTS FOR (d:Database) REQUIRE d.name IS UNIQUE",
        "CREATE CONSTRAINT IF NOT EXISTS FOR (t:Table) REQUIRE t.unique_name IS UNIQUE",
        "CREATE CONSTRAINT IF NOT EXISTS FOR (c:Column) REQUIRE c.unique_name IS UNIQUE",
    )
    with driver.session() as neo_session:
        for statement in constraint_statements:
            neo_session.run(statement)
    logger.info("Neo4j constraints ensured.")
def import_schema(schema_data: dict):
    """Import a discovered database schema into the Neo4j graph.

    Creates (or updates) one ``Database`` node, one ``Table`` node per table
    and one ``Column`` node per column, linked by ``HAS_TABLE`` and
    ``HAS_COLUMN`` relationships. Tables and columns are merged on
    ``unique_name`` (``<db>.<table>`` and ``<db>.<table>.<column>``), so
    importing the same schema again updates the existing nodes.

    Args:
        schema_data: Mapping with a ``database_name`` entry and a ``tables``
            list. Each table is a mapping with ``name`` and, optionally,
            ``columns``; each column is a mapping with ``name`` and ``type``.

    Raises:
        KeyError: If ``database_name``, ``tables``, a table ``name`` or a
            column ``name``/``type`` is missing.
    """
    driver = get_graph_driver()
    db_name = schema_data['database_name']

    with driver.session() as session:
        # Create Database node
        session.run("MERGE (d:Database {name: $db_name})", db_name=db_name)

        for table in schema_data['tables']:
            table_unique_name = f"{db_name}.{table['name']}"
            table_properties = {
                "name": table['name'],
                "unique_name": table_unique_name,
            }

            # Create Table node and HAS_TABLE relationship. A plain SET
            # covers both the "created" and the "matched" case.
            session.run(
                """
                MATCH (d:Database {name: $db_name})
                MERGE (t:Table {unique_name: $unique_name})
                SET t += $props
                MERGE (d)-[:HAS_TABLE]->(t)
                """,
                db_name=db_name,
                unique_name=table_unique_name,
                props=table_properties,
            )

            # A table with no (or a null) column list simply has no columns.
            column_rows = [
                {
                    "name": column['name'],
                    "unique_name": f"{table_unique_name}.{column['name']}",
                    "type": column['type'],
                }
                for column in (table.get('columns') or [])
            ]
            if not column_rows:
                continue

            # Create all Column nodes and HAS_COLUMN relationships of this
            # table in one round trip instead of one query per column.
            session.run(
                """
                MATCH (t:Table {unique_name: $table_unique_name})
                UNWIND $columns AS col
                MERGE (c:Column {unique_name: col.unique_name})
                SET c += col
                MERGE (t)-[:HAS_COLUMN]->(c)
                """,
                table_unique_name=table_unique_name,
                columns=column_rows,
            )

    logger.info("Successfully imported schema for database: %s", db_name)
def _keyword_search(keyword: str) -> List[Dict[str, Any]]:
    """Look up tables whose name contains ``keyword``.

    Args:
        keyword: Substring to look for in ``Table.name``.

    Returns:
        At most five dicts, each with a ``database`` and a ``table`` key,
        one per matching table.
    """
    cypher = """
    MATCH (d:Database)-[:HAS_TABLE]->(t:Table)
    WHERE t.name CONTAINS $keyword
    RETURN d.name as database, t.name as table
    LIMIT 5
    """
    with get_graph_driver().session() as neo_session:
        # Materialise the rows while the session is still open.
        return [row.data() for row in neo_session.run(cypher, keyword=keyword)]
def find_join_path(table1_name: str, table2_name: str) -> str:
    """Describe the shortest join path between two tables in the schema graph.

    Each argument is resolved to a table through ``_keyword_search``
    (substring match on the table name). Among the returned candidates an
    exact name match is preferred; otherwise the first candidate is used.
    The path is searched over ``HAS_COLUMN`` and ``REFERENCES``
    relationships in either direction and is at most five hops long.

    Args:
        table1_name: Name (or part of the name) of the first table.
        table2_name: Name (or part of the name) of the second table.

    Returns:
        A human-readable message: the node names along the path joined by
        ``" -> "``, or an explanation of why no path could be reported.
    """
    t1_nodes = _keyword_search(table1_name)
    if not t1_nodes:
        return f"Could not find a table matching '{table1_name}'."
    t2_nodes = _keyword_search(table2_name)
    if not t2_nodes:
        return f"Could not find a table matching '{table2_name}'."

    # The search is a substring match, so prefer the candidate whose name is
    # exactly what the caller asked for before falling back to the first hit.
    t1 = next((n for n in t1_nodes if n['table'] == table1_name), t1_nodes[0])
    t2 = next((n for n in t2_nodes if n['table'] == table2_name), t2_nodes[0])

    t1_unique_name = f"{t1['database']}.{t1['table']}"
    t2_unique_name = f"{t2['database']}.{t2['table']}"

    # shortestPath() rejects identical start and end nodes by default, and
    # the path from a table to itself is just that table.
    if t1_unique_name == t2_unique_name:
        return f"Found path: {t1['table']}"

    # Native shortestPath plus a list comprehension: the FILTER() function
    # was removed from Cypher in Neo4j 4.0, and the constraint syntax used
    # elsewhere in this module already requires a newer server than that.
    path_query = """
    MATCH (source:Table {unique_name: $start_name}), (target:Table {unique_name: $end_name})
    MATCH p = shortestPath((source)-[:HAS_COLUMN|REFERENCES*..5]-(target))
    RETURN [n IN nodes(p) WHERE n.name IS NOT NULL AND n.name <> '' | n.name] AS names
    LIMIT 1
    """

    driver = get_graph_driver()
    with driver.session() as session:
        result = session.run(path_query, start_name=t1_unique_name, end_name=t2_unique_name)
        record = result.single()

        if not record or not record["names"]:
            return f"No join path found between {table1_name} and {table2_name}."

        path_str = " -> ".join(record["names"])
        return f"Found path: {path_str}"