"""Relationship discovery between database tables.
Detects relationships via:
1. Explicit foreign-key constraints
2. Matching column names across tables
3. ID-like suffix patterns (*_id, *_key)
4. Fuzzy name matching (cust_id ≈ customer_id)
"""
from dataclasses import dataclass
from difflib import SequenceMatcher
from sqlalchemy import text
from db.connection import get_engine
from db.schema import get_schema
@dataclass
class Relationship:
    """A discovered link between one column in each of two tables.

    The declaration order below is also the positional order of the
    generated constructor, so it must not be rearranged.
    """

    # First endpoint: table name and the column inside it.
    table_a: str
    column_a: str

    # Second endpoint: table name and the column inside it.
    table_b: str
    column_b: str

    # Strength of the match, from 0.0 to 1.0.
    confidence: float

    # How the link was found: "fk", "exact_match", "id_pattern" or "fuzzy".
    source: str
def discover_relationships() -> list[Relationship]:
    """Gather explicit and implicit relationships and deduplicate them.

    Returns:
        The result of ``_deduplicate`` applied to the foreign-key
        relationships followed by the implicit ones.
    """
    # FK relationships come first so the collection order matches the
    # order in which the two sources are queried.
    candidates: list[Relationship] = [
        *_fk_relationships(),
        *_implicit_relationships(),
    ]
    return _deduplicate(candidates)
# ── Explicit FK relationships ───────────────────────────────────────────────
def _fk_relationships() -> list[Relationship]:
    """Read the foreign keys declared in the ``public`` schema.

    Every row returned by the query becomes a ``Relationship`` with
    confidence 1.0 and source ``"fk"``. ``table_a``/``column_a`` hold the
    constrained side (from ``key_column_usage``) and ``table_b``/``column_b``
    the side reported by ``constraint_column_usage``.

    NOTE(review): the joins match on constraint name and schema only, so a
    multi-column foreign key presumably pairs every source column with every
    target column -- confirm against a database with composite keys.

    Returns:
        One ``Relationship`` per row of the query result.
    """
    # The SQL text is kept flush-left so the statement sent to the database
    # stays exactly as it was.
    fk_sql = text("""
SELECT
tc.table_name AS source_table,
kcu.column_name AS source_column,
ccu.table_name AS target_table,
ccu.column_name AS target_column
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
JOIN information_schema.constraint_column_usage ccu
ON ccu.constraint_name = tc.constraint_name
AND ccu.table_schema = tc.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY'
AND tc.table_schema = 'public'
""")
    with get_engine().connect() as conn:
        return [
            Relationship(
                table_a=fk_row[0], column_a=fk_row[1],
                table_b=fk_row[2], column_b=fk_row[3],
                confidence=1.0, source="fk",
            )
            for fk_row in conn.execute(fk_sql).fetchall()
        ]
# ── Implicit relationships ──────────────────────────────────────────────────
# Suffixes that mark a column as a reference to another table's key.
_ID_SUFFIXES = ("_id", "_key")
# Minimum SequenceMatcher ratio for two column names to count as a fuzzy match.
_FUZZY_THRESHOLD = 0.75


def _is_id_like(column: str) -> bool:
    """Return True for a bare ``id`` column or a ``*_id`` / ``*_key`` column."""
    return column == "id" or column.endswith(_ID_SUFFIXES)


def _id_base(column: str) -> str:
    """Return *column* without its last ``_``-separated part (``a_b_id`` -> ``a_b``)."""
    return column.rsplit("_", 1)[0] if "_" in column else column


def _table_stems(table: str) -> set[str]:
    """Return the lower-cased names a ``<stem>_id`` column may use for *table*.

    Includes the table name itself plus naive singular forms, so that
    ``customers`` is matched by ``customer_id`` and ``categories`` by
    ``category_id``. Stems that are not real words are harmless: they simply
    never match a column.
    """
    name = table.lower()
    stems = {name}
    if len(name) > 3 and name.endswith("ies"):
        stems.add(name[:-3] + "y")
    if len(name) > 2 and name.endswith("es"):
        stems.add(name[:-2])
    if len(name) > 1 and name.endswith("s"):
        stems.add(name[:-1])
    return stems


def _id_pattern_match(
    c1: str, stems1: set[str], c2: str, stems2: set[str]
) -> bool:
    """Return True when two distinct ID-like columns look like the same key.

    A bare ``id`` column matches a ``<table>_id`` / ``<table>_key`` column
    that names its table; otherwise the two columns must share the same base
    (``customer_id`` <-> ``customer_key``).
    """
    if c1 == "id":
        return _id_base(c2).lower() in stems1
    if c2 == "id":
        return _id_base(c1).lower() in stems2
    return _id_base(c1) == _id_base(c2)


def _implicit_relationships() -> list[Relationship]:
    """Infer relationships from column names for every pair of tables.

    Reads the schema from ``get_schema()``, which is used here as a mapping
    of table name to an iterable of dicts carrying a ``"column_name"`` key.
    For each unordered pair of distinct tables, every pair of columns is
    checked for:

    1. an exact name match        -> confidence 0.85, source "exact_match"
    2. an ID-pattern match        -> confidence 0.75, source "id_pattern"
    3. a fuzzy name match (ratio >= 0.75)
                                  -> confidence round(ratio * 0.8, 2),
                                     source "fuzzy"

    A column pair may be reported by both step 2 and step 3; the caller
    (``discover_relationships``) deduplicates the combined list.

    Returns:
        The inferred relationships, with ``table_a`` always the table that
        appears earlier in the schema. Columns are visited in sorted order so
        the result is deterministic for a given schema.
    """
    schema = get_schema()
    tables = list(schema.keys())

    # Build per-table data once instead of once per table pair.
    columns = {
        table: sorted({col["column_name"] for col in schema[table]})
        for table in tables
    }
    stems = {table: _table_stems(table) for table in tables}

    rels: list[Relationship] = []
    for i, t1 in enumerate(tables):
        cols1 = columns[t1]
        for t2 in tables[i + 1:]:
            cols2 = columns[t2]
            for c1 in cols1:
                c1_is_id = _is_id_like(c1)
                for c2 in cols2:
                    # 1. Exact column-name match; nothing else applies.
                    if c1 == c2:
                        rels.append(Relationship(
                            table_a=t1, column_a=c1,
                            table_b=t2, column_b=c2,
                            confidence=0.85, source="exact_match",
                        ))
                        continue

                    # 2. ID-pattern match (e.g. "id" in t1 <-> "t1_id" in t2).
                    if (
                        c1_is_id
                        and _is_id_like(c2)
                        and _id_pattern_match(c1, stems[t1], c2, stems[t2])
                    ):
                        rels.append(Relationship(
                            table_a=t1, column_a=c1,
                            table_b=t2, column_b=c2,
                            confidence=0.75, source="id_pattern",
                        ))

                    # 3. Fuzzy match. The two quick ratios are cheap upper
                    # bounds on ratio(), so they can only rule pairs out.
                    matcher = SequenceMatcher(None, c1, c2)
                    if (
                        matcher.real_quick_ratio() < _FUZZY_THRESHOLD
                        or matcher.quick_ratio() < _FUZZY_THRESHOLD
                    ):
                        continue
                    ratio = matcher.ratio()
                    if ratio >= _FUZZY_THRESHOLD:
                        rels.append(Relationship(
                            table_a=t1, column_a=c1,
                            table_b=t2, column_b=c2,
                            confidence=round(ratio * 0.8, 2),
                            source="fuzzy",
                        ))
    return rels
def _deduplicate(rels: list[Relationship]) -> list[Relationship]:
    """Collapse relationships that connect the same two columns.

    Direction is ignored: (A.x, B.y) and (B.y, A.x) count as the same pair.
    For each pair the relationship with the highest confidence is kept; on a
    tie the one seen first wins.

    Args:
        rels: Candidate relationships, possibly containing duplicates.

    Returns:
        One relationship per column pair, ordered by where each pair first
        appeared in ``rels``.
    """
    strongest: dict[tuple, Relationship] = {}
    for candidate in rels:
        endpoints = [
            (candidate.table_a, candidate.column_a),
            (candidate.table_b, candidate.column_b),
        ]
        # Sorting the two endpoints makes the key independent of direction.
        endpoints.sort()
        pair_key = tuple(endpoints)

        current = strongest.get(pair_key)
        if current is None or candidate.confidence > current.confidence:
            strongest[pair_key] = candidate
    return list(strongest.values())
def format_relationships(rels: list[Relationship] | None = None) -> str:
    """Render relationships as text, one per line, strongest first.

    Args:
        rels: Relationships to render. When ``None``, they are obtained from
            ``discover_relationships()``.

    Returns:
        A newline-joined string with one
        ``table.column <-> table.column (confidence: N%, source: ...)`` line
        per relationship, or a fixed message when there are none.
    """
    found = discover_relationships() if rels is None else rels
    if not found:
        return "No explicit or inferred relationships found between tables."

    # The sort is stable, so equally confident entries keep their input order.
    strongest_first = sorted(
        found, key=lambda rel: rel.confidence, reverse=True
    )
    return "\n".join(
        f"{rel.table_a}.{rel.column_a} <-> {rel.table_b}.{rel.column_b} "
        f"(confidence: {rel.confidence:.0%}, source: {rel.source})"
        for rel in strongest_first
    )