from __future__ import annotations
import os
import re
import sqlite3
from contextlib import closing
from typing import Dict, Optional
import torch
# Keep for compatibility with existing imports. Schema linking is disabled for
# SFT/RL alignment in this project version (full schema, deterministic order).
USE_SCHEMA_LINKING = False
# Project root: two directories above this file (this module sits in a subpackage).
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Databases are laid out as data/database/<db_id>/<db_id>.sqlite (see get_schema_text).
DB_ROOT = os.path.join(PROJECT_ROOT, "data", "database")
# Memoized db_id -> schema text, filled lazily by get_schema_text.
SCHEMA_CACHE: Dict[str, str] = {}
def get_schema_text(db_id: str) -> str:
    """
    Return a deterministic schema string for *db_id*:

        table(col1, col2, ...)

    Tables are ordered alphabetically; columns keep PRAGMA table_info order.
    Successful results are memoized in SCHEMA_CACHE. On any database error
    an empty string is returned and NOT cached, so a transient failure
    (missing file, locked DB) does not poison the cache for later calls.
    """
    if db_id in SCHEMA_CACHE:
        return SCHEMA_CACHE[db_id]

    db_path = os.path.join(DB_ROOT, db_id, f"{db_id}.sqlite")
    schema_lines = []
    try:
        with closing(sqlite3.connect(db_path)) as conn:
            cur = conn.cursor()
            tables = cur.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
            ).fetchall()
            table_names = sorted([t[0] for t in tables if t and isinstance(t[0], str)])
            for tname in table_names:
                # Double any embedded quotes so identifiers containing '"'
                # remain a valid quoted SQLite identifier.
                quoted = tname.replace('"', '""')
                cols = cur.execute(f'PRAGMA table_info("{quoted}")').fetchall()
                col_names = [c[1] for c in cols if c and isinstance(c[1], str)]
                schema_lines.append(f"{tname}({', '.join(col_names)})")
    except Exception:
        # Best-effort: unreadable/missing DB yields an empty schema, uncached.
        return ""

    schema_text = "\n".join(schema_lines).strip()
    SCHEMA_CACHE[db_id] = schema_text
    return schema_text
def clean_gold_sql(sql: str) -> str:
    """
    Normalize a gold SQL string: lowercase it, drop a trailing semicolon,
    and expand simple Spider-style T1/T2 aliases back to table names.

    When any table is aliased more than once (self-join), expansion would
    be ambiguous, so the query is returned merely lowercased. Non-string
    or blank input yields "".
    """
    if not isinstance(sql, str):
        return ""
    stripped = sql.strip().rstrip(";").strip()
    if not stripped:
        return ""

    # Scan FROM/JOIN clauses for "<table> [AS] tN" bindings and count how
    # often each table gets an alias.
    decl_pattern = r"\b(from|join)\s+([a-zA-Z_][\w$]*)\s+(?:as\s+)?(t\d+)\b"
    aliases = {}
    usage = {}
    for match in re.finditer(decl_pattern, stripped, flags=re.I):
        tbl = match.group(2)
        ali = match.group(3)
        usage[tbl.lower()] = usage.get(tbl.lower(), 0) + 1
        aliases[ali.lower()] = tbl

    # Same table aliased twice → self-join → expansion ambiguous, bail out.
    if any(count > 1 for count in usage.values()):
        return stripped.lower()

    # Phase 1: rewrite alias-qualified references (t1.col -> table.col).
    result = stripped
    for ali, tbl in aliases.items():
        result = re.sub(rf"\b{re.escape(ali)}\.", f"{tbl}.", result, flags=re.I)
    # Phase 2: strip the declarations ("table AS t1" / "table t1" -> "table").
    for ali, tbl in aliases.items():
        result = re.sub(rf"\b{re.escape(tbl)}\s+as\s+{re.escape(ali)}\b", tbl, result, flags=re.I)
        result = re.sub(rf"\b{re.escape(tbl)}\s+{re.escape(ali)}\b", tbl, result, flags=re.I)
    return result.lower().strip()
def build_prompt(
    question: str,
    db_id: str,
    *,
    schema_text: str,
    training_sql: Optional[str] = None,
) -> str:
    """
    Assemble the model prompt.

    Layout (blank line between sections):
        You are a SQLite expert.
        Database: <db_id>
        Schema:
        <schema_text>
        Question:
        <question>
        SQL:
        <training_sql>    # appended only when training_sql is provided
    """
    sections = [
        "You are a SQLite expert.",
        f"Database: {db_id}",
        f"Schema:\n{schema_text}",
        f"Question:\n{question}",
        "SQL:",
    ]
    prompt = "\n\n".join(sections)
    if training_sql is not None:
        prompt += "\n" + training_sql
    return prompt
def encode_prompt(
    tokenizer,
    question: str,
    db_id: str,
    *,
    device: str,
    max_input_tokens: int = 512,
    training_sql: Optional[str] = None,
) -> torch.Tensor:
    """
    Tokenize the full prompt for *question* against database *db_id*.

    Inference mode (training_sql is None): the prompt ends at "SQL:".
    Training mode: the gold SQL is appended after "SQL:" (decoder-side
    labels are still the recommended supervision signal).
    The tokenizer truncates the encoding to max_input_tokens.

    Returns a 1-D tensor of input ids moved to *device*.
    """
    full_prompt = build_prompt(
        question,
        db_id,
        schema_text=get_schema_text(db_id),
        training_sql=training_sql,
    )
    encoded = tokenizer(
        full_prompt,
        truncation=True,
        max_length=max_input_tokens,
        padding=False,
        return_tensors="pt",
    )
    return encoded.input_ids[0].to(device)
|