File size: 5,139 Bytes
9d411a7
86cbe3c
 
9d411a7
86cbe3c
 
 
 
9d411a7
86cbe3c
9d411a7
 
 
 
 
 
 
 
86cbe3c
9d411a7
 
 
 
 
 
 
86cbe3c
9d411a7
 
 
 
 
 
 
86cbe3c
a0eb181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d411a7
 
 
 
 
 
 
 
 
 
 
 
86cbe3c
9d411a7
 
 
 
 
86cbe3c
9d411a7
 
86cbe3c
9d411a7
 
 
 
 
86cbe3c
9d411a7
 
 
 
 
 
 
 
 
 
 
 
 
86cbe3c
9d411a7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import logging
from typing import Any, Dict, List, Optional

from neo4j import GraphDatabase, Driver

from . import config

logger = logging.getLogger(__name__)

# Module-wide cached driver: None until get_graph_driver() creates it, and set
# back to None by close_graph_driver(), hence Optional.
_driver: Optional[Driver] = None

def get_graph_driver() -> Driver:
    """Return the module-wide Neo4j driver, creating it on first use.

    On the first call a driver is built from ``config.NEO4J_URI`` with
    ``(config.NEO4J_USER, config.NEO4J_PASSWORD)`` as auth, and
    ``_ensure_constraints`` is run against it. The driver is only cached once
    constraint setup succeeds. If setup raises, the new driver is closed and
    the exception propagates. The next call then retries from scratch instead
    of returning a driver whose constraints were never created.

    NOTE(review): initialisation is not guarded by a lock; concurrent first
    calls could each create a driver. Confirm callers are single-threaded
    at startup.

    Returns:
        The cached ``neo4j.Driver`` instance.
    """
    global _driver
    if _driver is None:
        logger.info("Initializing Neo4j driver...")
        driver = GraphDatabase.driver(
            config.NEO4J_URI, auth=(config.NEO4J_USER, config.NEO4J_PASSWORD)
        )
        try:
            _ensure_constraints(driver)
        except Exception:
            # Don't leak the connection pool or cache a half-initialised driver.
            driver.close()
            raise
        _driver = driver
    return _driver

def close_graph_driver():
    """Close the cached Neo4j driver, if any, and forget it.

    The cached reference is cleared *before* ``close()`` is called. Even if
    closing raises, a later ``get_graph_driver()`` builds a fresh driver
    instead of returning the closed one. Calling this when no driver exists
    is a no-op.
    """
    global _driver
    driver, _driver = _driver, None
    if driver is not None:
        logger.info("Closing Neo4j driver.")
        driver.close()

def _ensure_constraints(driver: Driver):
    """Create the uniqueness constraints the schema import relies on.

    Database nodes are unique by ``name``; Table and Column nodes are unique
    by ``unique_name``. Every statement uses ``IF NOT EXISTS``, so calling
    this repeatedly is harmless.
    """
    constraint_statements = (
        "CREATE CONSTRAINT IF NOT EXISTS FOR (d:Database) REQUIRE d.name IS UNIQUE",
        "CREATE CONSTRAINT IF NOT EXISTS FOR (t:Table) REQUIRE t.unique_name IS UNIQUE",
        "CREATE CONSTRAINT IF NOT EXISTS FOR (c:Column) REQUIRE c.unique_name IS UNIQUE",
    )
    with driver.session() as neo_session:
        for statement in constraint_statements:
            neo_session.run(statement)
    logger.info("Neo4j constraints ensured.")

def import_schema(schema_data: dict):
    """
    Imports a discovered database schema into the Neo4j graph.

    Only these keys of ``schema_data`` are read::

        {"database_name": ..., "tables": [
            {"name": ..., "columns": [{"name": ..., "type": ...}, ...]}, ...]}

    Nodes are upserted with MERGE keyed on ``unique_name``: ``"<db>.<table>"``
    for tables and ``"<db>.<table>.<column>"`` for columns. Re-importing the
    same schema therefore updates properties rather than duplicating nodes.

    All writes run in one explicit transaction as three queries: the Database
    node, one UNWIND batch for every table, and one UNWIND batch for every
    column. Previously each table and each column cost its own round-trip.
    The transaction is committed only after every query has run, so a failure
    part-way never commits a partial schema.

    Raises:
        KeyError: if any of the keys above is missing. This is detected while
            building the rows, before anything is written.
    """
    db_name = schema_data['database_name']

    # Flatten the nested schema into plain row lists first, so malformed input
    # fails here rather than after some writes have already happened.
    table_rows = []
    column_rows = []
    for table in schema_data['tables']:
        table_unique_name = f"{db_name}.{table['name']}"
        table_rows.append({
            "name": table['name'],
            "unique_name": table_unique_name,
        })
        for column in table['columns']:
            column_rows.append({
                "table_unique_name": table_unique_name,
                "props": {
                    "name": column['name'],
                    "unique_name": f"{table_unique_name}.{column['name']}",
                    "type": column['type'],
                },
            })

    # Create/refresh every Table node and link it to its Database.
    # `SET +=` applies on both create and match, replacing the redundant
    # ON CREATE / ON MATCH pair.
    table_query = """
    MATCH (d:Database {name: $db_name})
    UNWIND $tables AS row
    MERGE (t:Table {unique_name: row.unique_name})
    SET t += row
    MERGE (d)-[:HAS_TABLE]->(t)
    """
    # Create/refresh every Column node and link it to its (just merged) Table.
    column_query = """
    UNWIND $columns AS row
    MATCH (t:Table {unique_name: row.table_unique_name})
    MERGE (c:Column {unique_name: row.props.unique_name})
    SET c += row.props
    MERGE (t)-[:HAS_COLUMN]->(c)
    """

    driver = get_graph_driver()
    with driver.session() as session:
        with session.begin_transaction() as tx:
            tx.run("MERGE (d:Database {name: $db_name})", db_name=db_name)
            tx.run(table_query, db_name=db_name, tables=table_rows)
            tx.run(column_query, columns=column_rows)
            tx.commit()
    logger.info("Successfully imported schema for database: %s", db_name)

def _keyword_search(keyword: str, limit: int = 5) -> List[Dict[str, Any]]:
    """Find tables whose name contains ``keyword``, ignoring case.

    Args:
        keyword: substring to look for in ``Table.name``.
        limit: maximum number of rows to return (default 5, the previous
            hard-coded cap).

    Returns:
        A list of ``{"database": <Database.name>, "table": <Table.name>}``
        dicts in a deterministic ranking:
        - an exact (case-insensitive) name match first;
        - then shorter names;
        - then by database and table name.

        ``find_join_path`` uses element 0, and without an ORDER BY the LIMIT
        would return an arbitrary subset in arbitrary order.
    """
    driver = get_graph_driver()
    query = """
    MATCH (d:Database)-[:HAS_TABLE]->(t:Table)
    WHERE toLower(t.name) CONTAINS toLower($keyword)
    RETURN d.name as database, t.name as table
    ORDER BY (toLower(t.name) = toLower($keyword)) DESC, size(t.name), d.name, t.name
    LIMIT $limit
    """
    with driver.session() as session:
        result = session.run(query, keyword=keyword, limit=int(limit))
        return [record.data() for record in result]

def find_join_path(table1_name: str, table2_name: str) -> str:
    """
    Finds a human-readable join path between two tables using the graph's schema.

    Each argument is resolved with ``_keyword_search`` and its first match is
    used. The path is the shortest undirected route of at most 5
    ``HAS_COLUMN``/``REFERENCES`` hops between the two Table nodes. It is
    rendered as the non-empty ``name`` of every node along the way, joined
    with " -> ".

    Returns one message string in every case:
    - "Could not find a table matching ..." when a name doesn't resolve;
    - "No join path found between ..." when no route exists;
    - "Found path: ..." otherwise.
    """
    t1_nodes = _keyword_search(table1_name)
    if not t1_nodes:
        return f"Could not find a table matching '{table1_name}'."
    t2_nodes = _keyword_search(table2_name)
    if not t2_nodes:
        return f"Could not find a table matching '{table2_name}'."

    t1_unique_name = f"{t1_nodes[0]['database']}.{t1_nodes[0]['table']}"
    t2_unique_name = f"{t2_nodes[0]['database']}.{t2_nodes[0]['table']}"

    # shortestPath() raises when start and end are the same node (with a
    # minimum length of 1). A table trivially "joins" to itself.
    if t1_unique_name == t2_unique_name:
        return f"Found path: {t1_nodes[0]['table']}"

    # Native shortestPath replaces the nonexistent apoc.path.shortestPath.
    # The list comprehension replaces FILTER(), which was removed in Neo4j 4.0.
    # `end` is a Cypher keyword, so the endpoints are named src/dst.
    path_query = """
    MATCH (src:Table {unique_name: $start_name}), (dst:Table {unique_name: $end_name})
    MATCH p = shortestPath((src)-[:HAS_COLUMN|REFERENCES*..5]-(dst))
    RETURN [n IN nodes(p) WHERE coalesce(n.name, '') <> '' | n.name] AS path
    LIMIT 1
    """
    driver = get_graph_driver()
    with driver.session() as session:
        result = session.run(path_query, start_name=t1_unique_name, end_name=t2_unique_name)
        record = result.single()

        if not record or not record["path"]:
            return f"No join path found between {table1_name} and {table2_name}."

        path_str = " -> ".join(record["path"])
        return f"Found path: {path_str}"