File size: 5,457 Bytes
28035e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""Relationship discovery between database tables.

Detects relationships via:
1. Explicit foreign-key constraints
2. Matching column names across tables
3. ID-like suffix patterns (*_id, *_key)
4. Fuzzy name matching (cust_id ≈ customer_id)
"""

from dataclasses import dataclass
from difflib import SequenceMatcher

from sqlalchemy import text

from db.connection import get_engine
from db.schema import get_schema


@dataclass
class Relationship:
    """A link between ``table_a.column_a`` and ``table_b.column_b``.

    Links come either from declared foreign keys or from name-based
    heuristics; ``confidence`` and ``source`` record how the link was found.
    """

    table_a: str
    column_a: str
    table_b: str
    column_b: str
    # Score between 0.0 and 1.0; higher means more certain.
    confidence: float
    # Which detector produced the link: "fk", "exact_match", "id_pattern" or "fuzzy".
    source: str


def discover_relationships() -> list[Relationship]:
    """Collect declared and inferred relationships between tables.

    Foreign-key constraints are gathered first, then the name-based
    heuristics; when the same column pair is reported more than once, only
    the highest-confidence entry survives.
    """
    combined = [*_fk_relationships(), *_implicit_relationships()]
    return _deduplicate(combined)


# ── Explicit FK relationships ───────────────────────────────────────────────

def _fk_relationships() -> list[Relationship]:
    """Return one Relationship per column pair of every FK in ``public``.

    Each result has ``table_a.column_a`` as the referencing column,
    ``table_b.column_b`` as the referenced column, confidence 1.0 and
    source ``"fk"``. Only FKs whose referencing and referenced tables are
    both in the ``public`` schema are returned.

    The query reads the PostgreSQL ``pg_catalog`` directly rather than
    ``information_schema`` because:

    * ``conkey`` / ``confkey`` are parallel arrays, so unnesting them together
      pairs each referencing column with the column it actually references.
      Joining ``key_column_usage`` to ``constraint_column_usage`` by name
      instead pairs every column of a composite key with every other one.
    * ``information_schema.constraint_column_usage`` only lists tables owned
      by a currently enabled role, so a non-owner connection saw no FKs.
    * Joining by OID avoids ambiguity: PostgreSQL constraint names are only
      unique per table, not per schema.
    """
    query = text("""
        SELECT
            src.relname      AS source_table,
            src_att.attname  AS source_column,
            tgt.relname      AS target_table,
            tgt_att.attname  AS target_column
        FROM pg_catalog.pg_constraint con
        JOIN pg_catalog.pg_class src        ON src.oid    = con.conrelid
        JOIN pg_catalog.pg_namespace src_ns ON src_ns.oid = src.relnamespace
        JOIN pg_catalog.pg_class tgt        ON tgt.oid    = con.confrelid
        JOIN pg_catalog.pg_namespace tgt_ns ON tgt_ns.oid = tgt.relnamespace
        CROSS JOIN LATERAL unnest(con.conkey, con.confkey)
            WITH ORDINALITY AS k(src_attnum, tgt_attnum, pos)
        JOIN pg_catalog.pg_attribute src_att
            ON src_att.attrelid = con.conrelid
           AND src_att.attnum   = k.src_attnum
        JOIN pg_catalog.pg_attribute tgt_att
            ON tgt_att.attrelid = con.confrelid
           AND tgt_att.attnum   = k.tgt_attnum
        WHERE con.contype    = 'f'
          AND src_ns.nspname = 'public'
          AND tgt_ns.nspname = 'public'
        ORDER BY src.relname, con.conname, k.pos
    """)

    with get_engine().connect() as conn:
        rows = conn.execute(query).fetchall()

    return [
        Relationship(
            table_a=source_table, column_a=source_column,
            table_b=target_table, column_b=target_column,
            confidence=1.0, source="fk",
        )
        for source_table, source_column, target_table, target_column in rows
    ]


# ── Implicit relationships ──────────────────────────────────────────────────

def _id_base(column: str) -> str | None:
    """Classify *column* for the ID-pattern heuristic.

    Returns the lower-cased entity stem for ``<stem>_id`` / ``<stem>_key``
    columns (``"customer_id"`` -> ``"customer"``), ``""`` for a bare ``id``
    column, and ``None`` for anything that is not ID-like.  Names that merely
    end in the letters "id" (``paid``, ``valid``, ``uuid``) are not ID-like.
    """
    name = column.lower()
    if name == "id":
        return ""
    for suffix in ("_id", "_key"):
        if name.endswith(suffix) and len(name) > len(suffix):
            return name[: -len(suffix)]
    return None


def _table_stems(table: str) -> set[str]:
    """Return lower-cased spellings a referencing column might use for *table*.

    Includes the name itself plus naive singular forms
    (``customers`` -> ``customer``, ``categories`` -> ``category``,
    ``addresses`` -> ``address``).
    """
    name = table.lower()
    stems = {name}
    if name.endswith("ies") and len(name) > 3:
        stems.add(name[:-3] + "y")
    if name.endswith("es") and len(name) > 2:
        stems.add(name[:-2])
    if name.endswith("s") and len(name) > 1:
        stems.add(name[:-1])
    return stems


def _is_id_pattern(t1: str, c1: str, t2: str, c2: str) -> bool:
    """Return True when ``t1.c1`` and ``t2.c2`` look like key/reference columns."""
    base1, base2 = _id_base(c1), _id_base(c2)
    if base1 is None or base2 is None:
        return False
    if base1 and base2:
        # customer_id <-> customer_key: same entity, different suffix.
        return base1 == base2
    if not base1 and not base2:
        # Two bare "id" columns are two primary keys, not a reference.
        return False
    if not base1:
        # t1.id <-> t2.<t1>_id  (e.g. customers.id <-> orders.customer_id)
        return base2 in _table_stems(t1)
    # t1.<t2>_id <-> t2.id
    return base1 in _table_stems(t2)


def _implicit_relationships(
    schema: dict[str, list[dict]] | None = None,
) -> list[Relationship]:
    """Infer relationships between every pair of tables from column names.

    For each unordered table pair (t1, t2) three heuristics run:

    1. ``exact_match`` (confidence 0.85): identical column names.
    2. ``id_pattern`` (0.75): ``x_id`` <-> ``x_key``, or a bare ``id`` in one
       table <-> ``<table>_id`` in the other.
    3. ``fuzzy`` (0.8 * ratio, rounded to 2 places): differing names whose
       ``SequenceMatcher`` ratio is at least 0.75.

    The same column pair may be reported by more than one heuristic; callers
    are expected to deduplicate.

    Args:
        schema: mapping of table name -> list of column dicts, each with a
            ``"column_name"`` key. Defaults to ``get_schema()``.

    Returns:
        Relationships in a deterministic order (schema order of tables and
        columns).
    """
    if schema is None:
        schema = get_schema()
    tables = list(schema)
    # Build each table's column list once. dict.fromkeys dedupes while keeping
    # schema order, so results do not depend on set iteration order.
    columns = {
        t: list(dict.fromkeys(c["column_name"] for c in schema[t]))
        for t in tables
    }
    rels: list[Relationship] = []

    for i, t1 in enumerate(tables):
        cols1 = columns[t1]
        for t2 in tables[i + 1:]:
            cols2 = columns[t2]
            cols2_set = set(cols2)

            # 1. Exact column-name matches
            for col in cols1:
                if col in cols2_set:
                    rels.append(Relationship(
                        table_a=t1, column_a=col,
                        table_b=t2, column_b=col,
                        confidence=0.85, source="exact_match",
                    ))

            # 2. ID-pattern matching (customer_id <-> customer_key,
            #    customers.id <-> orders.customer_id)
            for c1 in cols1:
                for c2 in cols2:
                    if c1 == c2:
                        continue  # already reported as exact_match
                    if _is_id_pattern(t1, c1, t2, c2):
                        rels.append(Relationship(
                            table_a=t1, column_a=c1,
                            table_b=t2, column_b=c2,
                            confidence=0.75, source="id_pattern",
                        ))

            # 3. Fuzzy matching for remaining column pairs
            for c1 in cols1:
                for c2 in cols2:
                    if c1 == c2:
                        continue
                    ratio = SequenceMatcher(None, c1, c2).ratio()
                    if ratio >= 0.75:
                        rels.append(Relationship(
                            table_a=t1, column_a=c1,
                            table_b=t2, column_b=c2,
                            confidence=round(ratio * 0.8, 2),
                            source="fuzzy",
                        ))

    return rels


def _deduplicate(rels: list[Relationship]) -> list[Relationship]:
    """Collapse entries that describe the same column pair.

    Direction is ignored (``a.x <-> b.y`` equals ``b.y <-> a.x``). The entry
    with the strictly highest confidence wins; on a tie the earliest one is
    kept. Results follow the order in which each pair was first seen.
    """
    winners: dict[tuple, Relationship] = {}
    for rel in rels:
        left = (rel.table_a, rel.column_a)
        right = (rel.table_b, rel.column_b)
        pair_key = (min(left, right), max(left, right))
        current = winners.get(pair_key)
        if current is None or rel.confidence > current.confidence:
            winners[pair_key] = rel
    return list(winners.values())


def format_relationships(rels: list[Relationship] | None = None) -> str:
    """Render relationships as text for inclusion in an LLM prompt.

    One line per relationship, most confident first (ties keep input order).
    When *rels* is None, relationships are discovered on the fly; an empty
    list yields a fixed "nothing found" sentence.
    """
    found = discover_relationships() if rels is None else rels
    if not found:
        return "No explicit or inferred relationships found between tables."

    ranked = sorted(found, key=lambda rel: rel.confidence, reverse=True)
    return "\n".join(
        f"{rel.table_a}.{rel.column_a} <-> {rel.table_b}.{rel.column_b}  "
        f"(confidence: {rel.confidence:.0%}, source: {rel.source})"
        for rel in ranked
    )