File size: 4,474 Bytes
30e149a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from __future__ import annotations

import os
import re
import sqlite3
from contextlib import closing
from typing import Dict, Optional

import torch

# Keep for compatibility with existing imports. Schema linking is disabled for
# SFT/RL alignment in this project version (full schema, deterministic order).
USE_SCHEMA_LINKING = False

PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DB_ROOT = os.path.join(PROJECT_ROOT, "data", "database")

SCHEMA_CACHE: Dict[str, str] = {}


def get_schema_text(db_id: str) -> str:
    """
    Return a deterministic textual schema for *db_id*:

        table(col1, col2, ...)

    one line per table, tables sorted alphabetically, columns kept in
    PRAGMA order. Results are memoized in SCHEMA_CACHE; on any failure
    (missing database, corrupt file) an empty string is returned and cached.
    """
    if db_id in SCHEMA_CACHE:
        return SCHEMA_CACHE[db_id]

    db_path = os.path.join(DB_ROOT, db_id, f"{db_id}.sqlite")
    schema_lines = []
    try:
        # Open read-only via a URI so a missing or mistyped db_id raises
        # instead of making sqlite3 silently create an empty .sqlite file
        # on disk (which a plain connect(db_path) would do).
        with closing(sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)) as conn:
            cur = conn.cursor()
            tables = cur.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
            ).fetchall()
            table_names = sorted([t[0] for t in tables if t and isinstance(t[0], str)])
            for tname in table_names:
                cols = cur.execute(f'PRAGMA table_info("{tname}")').fetchall()
                col_names = [c[1] for c in cols if c and isinstance(c[1], str)]
                schema_lines.append(f"{tname}({', '.join(col_names)})")
    except Exception:
        # Best-effort: any sqlite error yields an empty schema rather than
        # crashing prompt construction.
        schema_lines = []

    schema_text = "\n".join(schema_lines).strip()
    SCHEMA_CACHE[db_id] = schema_text
    return schema_text


def clean_gold_sql(sql: str) -> str:
    """
    Normalize a gold SQL string: lowercase it and, where unambiguous, strip
    Spider-style T1/T2 table aliases.

    If the same table is introduced more than once in FROM/JOIN clauses,
    alias resolution would be ambiguous, so the query is only lowercased.
    Non-string input yields the empty string.
    """
    if not isinstance(sql, str):
        return ""
    stripped = sql.strip().rstrip(";").strip()
    if not stripped:
        return ""

    # Map each alias (t1, t2, ...) to the table it names, while counting how
    # often each table is introduced in FROM/JOIN clauses.
    aliases: Dict[str, str] = {}
    introduced: Dict[str, int] = {}
    decl_pattern = re.compile(
        r"\b(from|join)\s+([a-zA-Z_][\w$]*)\s+(?:as\s+)?(t\d+)\b", re.I
    )
    for match in decl_pattern.finditer(stripped):
        tbl = match.group(2)
        alias_key = match.group(3).lower()
        tbl_key = tbl.lower()
        introduced[tbl_key] = introduced.get(tbl_key, 0) + 1
        aliases[alias_key] = tbl

    # A table introduced twice makes alias removal ambiguous -> bail out.
    if max(introduced.values(), default=0) > 1:
        return stripped.lower()

    rewritten = stripped
    # Step 1: qualify columns with the real table ("t1.col" -> "table.col").
    for alias_key, tbl in aliases.items():
        rewritten = re.sub(
            rf"\b{re.escape(alias_key)}\.", f"{tbl}.", rewritten, flags=re.I
        )
    # Step 2: drop the declarations themselves ("table AS t1" / "table t1").
    for alias_key, tbl in aliases.items():
        rewritten = re.sub(
            rf"\b{re.escape(tbl)}\s+as\s+{re.escape(alias_key)}\b",
            tbl,
            rewritten,
            flags=re.I,
        )
        rewritten = re.sub(
            rf"\b{re.escape(tbl)}\s+{re.escape(alias_key)}\b",
            tbl,
            rewritten,
            flags=re.I,
        )

    return rewritten.lower().strip()


def build_prompt(
    question: str,
    db_id: str,
    *,
    schema_text: str,
    training_sql: Optional[str] = None,
) -> str:
    """
    Assemble the fixed prompt layout used by this project:

        You are a SQLite expert.

        Database: <db_id>

        Schema:
        <table>(col1, col2, ...)
        ...

        Question:
        <question>

        SQL:
        <gold sql>   (only when training_sql is given)

    The prompt always ends with "SQL:"; in training mode the gold SQL is
    appended on the next line.
    """
    sections = [
        "You are a SQLite expert.",
        f"Database: {db_id}",
        f"Schema:\n{schema_text}",
        f"Question:\n{question}",
        "SQL:",
    ]
    prompt = "\n\n".join(sections)
    if training_sql is not None:
        prompt = f"{prompt}\n{training_sql}"
    return prompt


def encode_prompt(
    tokenizer,
    question: str,
    db_id: str,
    *,
    device: str,
    max_input_tokens: int = 512,
    training_sql: Optional[str] = None,
) -> torch.Tensor:
    """
    Tokenize the text-to-SQL prompt for *db_id* / *question*.

    Without `training_sql` the prompt ends at "SQL:" (inference mode); with
    it, the gold SQL is appended (training mode). The tokenizer truncates
    the encoded prompt to at most `max_input_tokens` tokens.

    Returns a 1-D tensor of input ids moved to *device*.
    """
    full_prompt = build_prompt(
        question,
        db_id,
        schema_text=get_schema_text(db_id),
        training_sql=training_sql,
    )
    encoded = tokenizer(
        full_prompt,
        truncation=True,
        max_length=max_input_tokens,
        padding=False,
        return_tensors="pt",
    )
    return encoded.input_ids[0].to(device)