File size: 7,659 Bytes
dc59b01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
"""
Simple schema linking for Spider-style Text-to-SQL.

Goal:
- Given (question, db_id), select a small set of relevant tables/columns
  to include in the prompt (RAG-style schema retrieval).

Design constraints:
- Pure Python (no heavy external deps).
- Robust to missing/odd schemas: never crash.
"""

from __future__ import annotations

import json
import os
import re
import sqlite3
from contextlib import closing
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Sequence, Tuple


# Matches maximal runs of ASCII letters and digits; any other character acts as a separator.
_ALNUM_RE = re.compile(r"[A-Za-z0-9]+")
# Matches a lowercase ASCII letter immediately followed by an uppercase one,
# capturing both so a separator can be inserted between them.
_CAMEL_RE = re.compile(r"([a-z])([A-Z])")


def _normalize_identifier(text: str) -> str:
    """
    Turn an identifier (or free text) into lowercase, space-separated words.

    Steps:
    - falsy input (None, "") becomes the empty string; anything else is str()-ed
    - every underscore becomes a space
    - a space is inserted at each lowercase->uppercase boundary ("albumTitle")
    - the result is lowercased
    """
    raw = str(text) if text else ""
    spaced = _CAMEL_RE.sub(r"\1 \2", raw.replace("_", " "))
    return spaced.lower()


def _tokenize(text: str) -> List[str]:
    """Normalize *text* as an identifier and return its alphanumeric tokens in order."""
    return _ALNUM_RE.findall(_normalize_identifier(text))


@dataclass(frozen=True)
class TableSchema:
    """Frozen dataclass pairing one table's name with the names of its columns."""

    # Name of the table.
    table_name: str
    # Column names of the table, held as a tuple.
    columns: Tuple[str, ...]


class SchemaLinker:
    """
    Token-overlap schema linker for Spider-style databases.

    Table/column names are loaded from a Spider `tables.json` file when the
    object is constructed. When `db_root` is given, the real SQLite schema is
    additionally read (lazily, cached per db_id) from
    `<db_root>/<db_id>/<db_id>.sqlite`.

    Tables are ranked against a question by counting tokens shared between the
    question and the table name / column names.
    """

    def __init__(self, tables_json_path: str, db_root: Optional[str] = None):
        """
        Args:
            tables_json_path: Path to a Spider-format `tables.json` file.
            db_root: Optional directory holding `<db_id>/<db_id>.sqlite` files.
                When falsy, no SQLite file is ever opened.

        Raises:
            OSError: If `tables_json_path` cannot be opened.
            ValueError: If the file is not valid UTF-8 JSON
                (`json.JSONDecodeError` / `UnicodeDecodeError`).
        """
        self.tables_json_path = tables_json_path
        self.db_root = db_root
        self._tables_by_db: Dict[str, List[TableSchema]] = {}
        self._sqlite_schema_cache: Dict[str, Dict[str, List[str]]] = {}
        self._load_tables_json()

    def _load_tables_json(self) -> None:
        """
        Parse `tables.json` into `self._tables_by_db` (db_id -> table schemas).

        Malformed pieces (non-dict entries, entries without `db_id`, column
        records that are not `[table_idx, col_name]` pairs, or whose table
        index is not a known integer index) are skipped instead of raising.
        """
        # JSON is UTF-8 by specification; do not depend on the platform locale.
        with open(self.tables_json_path, encoding="utf-8") as f:
            entries = json.load(f)

        if not isinstance(entries, list):
            entries = []

        tables_by_db: Dict[str, List[TableSchema]] = {}
        for entry in entries:
            if not isinstance(entry, dict):
                continue
            db_id = entry.get("db_id")
            if db_id is None:
                continue

            table_names: List[str] = entry.get("table_names_original") or entry.get("table_names") or []
            col_names: List[Sequence] = entry.get("column_names_original") or entry.get("column_names") or []

            columns_by_table_idx: Dict[int, List[str]] = {i: [] for i in range(len(table_names))}
            for col in col_names:
                # Spider format: [table_idx, col_name]
                if not isinstance(col, (list, tuple)) or len(col) < 2:
                    continue
                table_idx, col_name = col[0], col[1]
                # Rejects None / non-integers, the "*" pseudo-column (index -1)
                # and indices that point past the table list.
                if not isinstance(table_idx, int) or table_idx not in columns_by_table_idx:
                    continue
                columns_by_table_idx[table_idx].append(str(col_name))

            tables_by_db[db_id] = [
                TableSchema(table_name=str(tname), columns=tuple(columns_by_table_idx.get(i, [])))
                for i, tname in enumerate(table_names)
            ]

        self._tables_by_db = tables_by_db

    def _db_path(self, db_id: str) -> Optional[str]:
        """
        Return `<db_root>/<db_id>/<db_id>.sqlite` if that path exists, else None.
        Also returns None when no `db_root` was configured.
        """
        if not self.db_root:
            return None
        path = os.path.join(self.db_root, db_id, f"{db_id}.sqlite")
        return path if os.path.exists(path) else None

    def _load_sqlite_schema(self, db_id: str) -> Dict[str, List[str]]:
        """
        Load actual SQLite schema (table -> columns). Cached per db_id.

        Returns an empty dict (which is cached too) when there is no database
        file for `db_id` or when SQLite reports an error while reading it.
        Internal `sqlite_*` tables are not included.
        """
        if db_id in self._sqlite_schema_cache:
            return self._sqlite_schema_cache[db_id]

        schema: Dict[str, List[str]] = {}
        db_path = self._db_path(db_id)
        if db_path:
            try:
                with closing(sqlite3.connect(db_path)) as conn:
                    cursor = conn.cursor()
                    rows = cursor.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall()
                    for (table_name,) in rows:
                        name = str(table_name)
                        # Names starting with "sqlite_" are reserved for SQLite's
                        # own bookkeeping tables (e.g. sqlite_sequence).
                        if name.startswith("sqlite_"):
                            continue
                        # PRAGMA arguments cannot be bound as parameters, so quote
                        # the identifier (doubling embedded quotes). Without this,
                        # names such as `order` or `line item` are syntax errors.
                        quoted = '"' + name.replace('"', '""') + '"'
                        columns = cursor.execute(f"PRAGMA table_info({quoted});").fetchall()
                        # table_info rows are (cid, name, type, ...); index 1 is the column name.
                        schema[name] = [str(col[1]) for col in columns]
            except sqlite3.Error:
                # Best effort: an unreadable database is treated as "no schema".
                schema = {}

        self._sqlite_schema_cache[db_id] = schema
        return schema

    def get_schema(self, db_id: str) -> List[TableSchema]:
        """
        Returns a list of table schemas for this db.
        Prefers `tables.json` (Spider canonical), but can fallback to SQLite if needed.

        The returned list is a fresh list, so callers may modify it freely.
        Unknown db_ids with no SQLite file yield an empty list.
        """
        tables = self._tables_by_db.get(db_id)
        if tables:
            return list(tables)

        sqlite_schema = self._load_sqlite_schema(db_id)
        return [TableSchema(table_name=t, columns=tuple(cols)) for t, cols in sqlite_schema.items()]

    def score_tables(self, question: str, db_id: str) -> List[Tuple[float, TableSchema]]:
        """
        Score each table using token overlap:
        - table token overlap (higher weight)
        - column token overlap (lower weight)

        Returns `(score, table)` pairs sorted by score, highest first; ties are
        ordered by table name, also descending.
        """
        q_tokens = set(_tokenize(question))
        # Loop-invariant: normalized question text used for the substring boost.
        q_text = _normalize_identifier(question)

        scored: List[Tuple[float, TableSchema]] = []
        for t in self.get_schema(db_id):
            table_tokens = set(_tokenize(t.table_name))
            col_tokens = {tok for c in t.columns for tok in _tokenize(c)}

            # Simple weighted overlap (tuned to bias table matches).
            score = 3.0 * len(q_tokens & table_tokens) + 1.0 * len(q_tokens & col_tokens)

            # Small boost for substring mentions (helps e.g. "album" vs "albums").
            if t.table_name and _normalize_identifier(t.table_name) in q_text:
                score += 0.5

            scored.append((score, t))

        scored.sort(key=lambda x: (x[0], x[1].table_name), reverse=True)
        return scored

    def select_top_tables(self, question: str, db_id: str, top_k: int = 4) -> List[TableSchema]:
        """
        Return up to `top_k` tables for the question (`top_k` is clamped to >= 1).

        If the best score is not positive, the first `top_k` tables in schema
        order are returned instead, so the selection stays stable.
        An empty schema yields an empty list.
        """
        scored = self.score_tables(question, db_id)
        if not scored:
            return []
        top_k = max(1, int(top_k))

        # If everything scores 0, still return a stable selection.
        if scored[0][0] <= 0:
            return self.get_schema(db_id)[:top_k]

        return [t for _, t in scored[:top_k]]

    def columns_for_selected_tables(self, db_id: str, selected_tables: Iterable[TableSchema]) -> Dict[str, List[str]]:
        """
        Returns only columns belonging to selected tables.
        Prefer SQLite columns (actual DB) if available; fallback to tables.json.

        Every value is a new list, so mutating the result never touches the
        cached SQLite schema.
        """
        sqlite_schema = self._load_sqlite_schema(db_id)
        out: Dict[str, List[str]] = {}
        for t in selected_tables:
            sqlite_cols = sqlite_schema.get(t.table_name)
            # Copy in both branches: the SQLite list lives in the cache.
            out[t.table_name] = list(sqlite_cols) if sqlite_cols else list(t.columns)
        return out

    def format_relevant_schema(self, question: str, db_id: str, top_k: int = 4) -> Tuple[List[str], Dict[str, List[str]]]:
        """
        Returns:
        - lines: ["table(col1, col2)", ...]
        - selected: {table: [cols...], ...}
        """
        selected_tables = self.select_top_tables(question, db_id, top_k=top_k)
        selected = self.columns_for_selected_tables(db_id, selected_tables)
        lines = [f"{table_name}({', '.join(cols)})" for table_name, cols in selected.items()]
        return lines, selected