File size: 1,466 Bytes
dc59b01
 
 
 
 
 
 
 
 
29662cd
 
 
 
dc59b01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29662cd
dc59b01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import sqlite3
from contextlib import closing
from pathlib import Path


class SchemaEncoder:
    """Read SQLite schemas and render them as text.

    Each database is expected at ``<db_root>/<db_id>.sqlite``. Two text
    renderings are offered: a compact ``table(col, ...)`` form and a
    sentence-per-table natural-language form.
    """

    def __init__(self, db_root):
        """Store the directory that holds the ``.sqlite`` files.

        Args:
            db_root: Directory containing ``<db_id>.sqlite`` files. Either a
                ``str`` or a path-like object is accepted.
        """
        self.db_root = db_root

    def get_tables_and_columns(self, db_id):
        """Return a mapping of table name to its list of column names.

        Args:
            db_id: Database identifier; the file ``<db_id>.sqlite`` under
                ``db_root`` is opened.

        Returns:
            dict: ``{table_name: [column_name, ...]}`` for every row of
            ``sqlite_master`` whose type is ``'table'``.

        Raises:
            sqlite3.Error: If the database cannot be opened or queried.

        Note:
            ``sqlite3.connect`` creates a new empty database when the file
            does not exist, so an unknown ``db_id`` yields an empty dict
            rather than an error.
        """
        # Wrap in Path so a plain-string db_root works with the "/" join.
        db_path = Path(self.db_root) / f"{db_id}.sqlite"

        # closing() releases the connection even when a query raises.
        with closing(sqlite3.connect(db_path)) as conn:
            cursor = conn.cursor()

            tables = cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table';"
            ).fetchall()

            schema = {}
            for (table,) in tables:
                # PRAGMA arguments cannot be bound as parameters, so quote the
                # identifier by hand (doubling embedded quotes). Without this,
                # names such as ``order`` or ``line item`` are syntax errors.
                quoted = '"' + table.replace('"', '""') + '"'
                cols = cursor.execute(
                    f"PRAGMA table_info({quoted});"
                ).fetchall()
                # table_info rows are (cid, name, type, notnull, dflt, pk).
                schema[table] = [col[1] for col in cols]

        return schema

    # -----------------------------------
    # Strategy 1: Structured
    # -----------------------------------
    def structured_schema(self, db_id):
        """Render the schema as one ``table(col1, col2, ...)`` line per table.

        Args:
            db_id: Database identifier passed to ``get_tables_and_columns``.

        Returns:
            str: Newline-joined lines; the empty string when there are no
            tables.
        """
        schema = self.get_tables_and_columns(db_id)
        return "\n".join(
            f"{table}({', '.join(cols)})" for table, cols in schema.items()
        )

    # -----------------------------------
    # Strategy 2: Natural Language
    # -----------------------------------
    def natural_language_schema(self, db_id):
        """Render the schema as one English sentence per table.

        Args:
            db_id: Database identifier passed to ``get_tables_and_columns``.

        Returns:
            str: Newline-joined sentences; the empty string when there are
            no tables.
        """
        schema = self.get_tables_and_columns(db_id)
        return "\n".join(
            f"The table '{table}' contains the columns: {', '.join(cols)}."
            for table, cols in schema.items()
        )