Spaces:
Sleeping
Sleeping
Commit ·
862cf68
1
Parent(s): 14b487f
wdf
Browse files- README.md +8 -7
- app.py +209 -499
- best_rlhf_model/README.md +9 -0
- best_rlhf_model/adapter_config.json +21 -0
- {int8_dynamic/tokenizer → best_rlhf_model}/merges.txt +0 -0
- {int8_dynamic/tokenizer → best_rlhf_model}/special_tokens_map.json +0 -0
- {int8_dynamic/tokenizer → best_rlhf_model}/tokenizer_config.json +0 -1
- best_rlhf_model/vocab.json +0 -0
- int8_dynamic/meta.json +0 -7
- int8_dynamic/model.pt +0 -3
- int8_dynamic/tokenizer/tokenizer.json +0 -0
- int8_dynamic/tokenizer/vocab.json +0 -0
- requirements.txt +3 -5
- scripts/benchmark_parallel_reward.py +0 -202
- scripts/benchmark_quantization.py +0 -108
- scripts/benchmark_rollout_generation.py +0 -66
- scripts/error_dashboard.py +0 -99
- scripts/evaluate.py +0 -170
- scripts/plot_task2.py +0 -58
- scripts/plot_task3.py +0 -15
- scripts/plot_task3_plotly.py +0 -103
- scripts/quantize_export.py +0 -86
- scripts/quantized_infer_harness.py +0 -46
- src/constrained_decoding.py +0 -1058
- src/constrained_decoding_sample.py +0 -516
- src/eval_rl_fixed.py +329 -619
- src/evaluate_without_constraied.py +0 -503
- src/execution_reward copy.py +0 -831
- src/execution_reward.py +322 -744
- src/execution_reward_soft.py +0 -211
- src/load_lora_model.py +8 -70
- src/quantization_utils.py +0 -222
- src/quantized_text2sql_engine.py +0 -243
- src/schema_encoder.py +49 -33
- src/schema_utils.py +0 -222
- src/sql_validator.py +32 -314
- src/text2sql_engine.py +36 -682
README.md
CHANGED
|
@@ -1,13 +1,14 @@
|
|
| 1 |
---
|
| 2 |
title: Text2sql Demo
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
| 1 |
---
|
| 2 |
title: Text2sql Demo
|
| 3 |
+
emoji: 🐨
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: pink
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 6.8.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
| 11 |
+
short_description: 'to show the gradio interface '
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
"""
|
| 2 |
-
GRADIO DEMO UI
|
| 3 |
NL → SQL → Result Table
|
| 4 |
"""
|
| 5 |
|
|
@@ -7,563 +7,273 @@ import gradio as gr
|
|
| 7 |
import pandas as pd
|
| 8 |
import re
|
| 9 |
import time
|
| 10 |
-
import
|
| 11 |
-
import torch
|
| 12 |
-
import sys
|
| 13 |
-
import json
|
| 14 |
-
import subprocess
|
| 15 |
-
import base64
|
| 16 |
-
import io
|
| 17 |
-
from pathlib import Path
|
| 18 |
-
from typing import Iterator
|
| 19 |
-
|
| 20 |
-
# ==========================================
|
| 21 |
-
# RELATIVE PATH RESOLUTION (GLOBAL)
|
| 22 |
-
# ==========================================
|
| 23 |
-
try:
|
| 24 |
-
PROJECT_ROOT = Path(__file__).resolve().parent
|
| 25 |
-
except NameError:
|
| 26 |
-
PROJECT_ROOT = Path(".").resolve()
|
| 27 |
-
|
| 28 |
-
if (PROJECT_ROOT / "data" / "database").exists():
|
| 29 |
-
DB_ROOT = PROJECT_ROOT / "data" / "database"
|
| 30 |
-
else:
|
| 31 |
-
DB_ROOT = PROJECT_ROOT / "final_databases"
|
| 32 |
-
|
| 33 |
-
def get_db_path(db_id: str) -> str:
|
| 34 |
-
path1 = DB_ROOT / db_id / f"{db_id}.sqlite"
|
| 35 |
-
path2 = DB_ROOT / f"{db_id}.sqlite"
|
| 36 |
-
return str(path1) if path1.exists() else str(path2)
|
| 37 |
-
|
| 38 |
-
# ==========================================
|
| 39 |
-
# 🔥 CUDA MOCK PATCH FOR MAC (MPS) / CPU
|
| 40 |
-
# ==========================================
|
| 41 |
-
if not torch.cuda.is_available():
|
| 42 |
-
class MockCUDAEvent:
|
| 43 |
-
def __init__(self, enable_timing=False, blocking=False, interprocess=False):
|
| 44 |
-
self.t = 0.0
|
| 45 |
-
def record(self, stream=None):
|
| 46 |
-
self.t = time.perf_counter()
|
| 47 |
-
def elapsed_time(self, end_event):
|
| 48 |
-
return (end_event.t - self.t) * 1000.0
|
| 49 |
-
|
| 50 |
-
torch.cuda.Event = MockCUDAEvent
|
| 51 |
-
if not hasattr(torch.cuda, 'synchronize'):
|
| 52 |
-
torch.cuda.synchronize = lambda: None
|
| 53 |
-
|
| 54 |
-
# ==========================================
|
| 55 |
-
# IMPORTS & ENGINE SETUP
|
| 56 |
-
# ==========================================
|
| 57 |
-
from src.quantized_text2sql_engine import QuantizedText2SQLEngine
|
| 58 |
-
from src.schema_encoder import SchemaEncoder
|
| 59 |
-
|
| 60 |
-
DEFAULT_QUANT_ARTIFACT = str(PROJECT_ROOT / "int8_dynamic")
|
| 61 |
-
|
| 62 |
-
_ENGINE_CACHE = {}
|
| 63 |
-
_QUERY_LOG = []
|
| 64 |
-
_PERF_LOG = []
|
| 65 |
-
_SUCCESS_LOG = []
|
| 66 |
-
|
| 67 |
-
_OP_STATS = {
|
| 68 |
-
"SELECT": {"ok": 0, "fail": 0}, "WHERE": {"ok": 0, "fail": 0}, "JOIN": {"ok": 0, "fail": 0},
|
| 69 |
-
"GROUP_BY": {"ok": 0, "fail": 0}, "ORDER_BY": {"ok": 0, "fail": 0}, "HAVING": {"ok": 0, "fail": 0}, "LIMIT": {"ok": 0, "fail": 0},
|
| 70 |
-
}
|
| 71 |
-
|
| 72 |
-
def get_quant_engine(artifact_dir: str, use_constrained: bool = False, exec_workers: int = 8, use_cache: bool = True):
|
| 73 |
-
key = (artifact_dir, bool(use_constrained), int(exec_workers), bool(use_cache))
|
| 74 |
-
if key not in _ENGINE_CACHE:
|
| 75 |
-
try:
|
| 76 |
-
_ENGINE_CACHE[key] = QuantizedText2SQLEngine(artifact_dir, device="cpu", use_constrained=bool(use_constrained), exec_workers=int(exec_workers), use_cache=bool(use_cache))
|
| 77 |
-
except TypeError:
|
| 78 |
-
_ENGINE_CACHE[key] = QuantizedText2SQLEngine(artifact_dir)
|
| 79 |
-
return _ENGINE_CACHE[key]
|
| 80 |
-
|
| 81 |
-
# 🚨 LAZY LOADING: We DO NOT load the model here! We only load the fast Schema Encoder.
|
| 82 |
-
quant_engine = None
|
| 83 |
-
try:
|
| 84 |
-
schema_encoder = SchemaEncoder(DB_ROOT)
|
| 85 |
-
except Exception as e:
|
| 86 |
-
print(f"Warning: SchemaEncoder failed to load: {e}")
|
| 87 |
-
schema_encoder = None
|
| 88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
SAMPLES = [
|
| 90 |
-
("Show 10 distinct employee first names.", "chinook_1"),
|
| 91 |
-
("
|
| 92 |
-
("
|
| 93 |
-
("
|
| 94 |
-
("
|
| 95 |
-
("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
]
|
|
|
|
| 97 |
SAMPLE_QUESTIONS = [q[0] for q in SAMPLES]
|
| 98 |
|
|
|
|
|
|
|
|
|
|
| 99 |
def explain_sql(sql):
|
| 100 |
-
if not sql: return ""
|
| 101 |
explanation = "This SQL query retrieves information from the database."
|
| 102 |
sql_lower = sql.lower()
|
| 103 |
-
|
| 104 |
-
if "
|
| 105 |
-
|
| 106 |
-
if "
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
return explanation
|
| 109 |
|
| 110 |
-
def sql_ops(sql: str) -> list[str]:
|
| 111 |
-
s = (sql or "").lower()
|
| 112 |
-
ops = ["SELECT"]
|
| 113 |
-
if " where " in f" {s} ": ops.append("WHERE")
|
| 114 |
-
if " join " in f" {s} ": ops.append("JOIN")
|
| 115 |
-
if " group by " in f" {s} ": ops.append("GROUP_BY")
|
| 116 |
-
if " order by " in f" {s} ": ops.append("ORDER_BY")
|
| 117 |
-
if " having " in f" {s} ": ops.append("HAVING")
|
| 118 |
-
if " limit " in f" {s} ": ops.append("LIMIT")
|
| 119 |
-
return ops
|
| 120 |
-
|
| 121 |
-
def classify_error(sql: str, error_msg: str | None = None, *, timed_out: bool = False):
|
| 122 |
-
s = (sql or "").lower()
|
| 123 |
-
m = (error_msg or "").lower()
|
| 124 |
-
if timed_out or "interrupted" in m or "timeout" in m: return "timeout"
|
| 125 |
-
if not s.strip().startswith(("select", "with")): return "syntax_error"
|
| 126 |
-
if " join " in f" {s} " and " on " not in f" {s} ": return "missing_join"
|
| 127 |
-
if " where " in f" {s} " and not any(op in s for op in ["=", ">", "<", " in ", " like ", " between ", " is null", " is not null"]): return "wrong_where"
|
| 128 |
-
if ("is null" in s or "is not null" in s) and ("no such column" in m or "misuse" in m): return "null_handling"
|
| 129 |
-
if "no such table" in m: return "missing_table"
|
| 130 |
-
if "no such column" in m: return "missing_column"
|
| 131 |
-
if "ambiguous column name" in m: return "ambiguous_column"
|
| 132 |
-
if "datatype mismatch" in m or "type mismatch" in m: return "type_mismatch"
|
| 133 |
-
if "misuse of aggregate" in m or "misuse of aggregate function" in m: return "wrong_aggregation"
|
| 134 |
-
if "syntax error" in m: return "syntax_error"
|
| 135 |
-
if "near" in m and "syntax error" in m: return "syntax_error"
|
| 136 |
-
if "runtime" in m or "constraint failed" in m: return "runtime_error"
|
| 137 |
-
return "other"
|
| 138 |
-
|
| 139 |
-
def get_hint(error_type):
|
| 140 |
-
hints = {
|
| 141 |
-
"missing_join": "Check JOIN conditions between tables.", "wrong_aggregation": "Use proper aggregation like avg(column).",
|
| 142 |
-
"wrong_where": "Check WHERE condition syntax.", "syntax_error": "Ensure SQL starts with SELECT.",
|
| 143 |
-
"missing_table": "Use only tables from the provided schema.", "missing_column": "Use only columns from the provided schema.",
|
| 144 |
-
"ambiguous_column": "Disambiguate by using table.column.", "timeout": "Query took too long; simplify joins.", "other": "Review SQL logic."
|
| 145 |
-
}
|
| 146 |
-
return hints.get(error_type, "Review query.")
|
| 147 |
-
|
| 148 |
-
def is_relevant_to_schema(question, db_id):
|
| 149 |
-
if schema_encoder is None: return True
|
| 150 |
-
try: raw_schema = schema_encoder.structured_schema(db_id).lower()
|
| 151 |
-
except: return True
|
| 152 |
-
schema_words = set(re.findall(r'[a-z0-9_]+', raw_schema))
|
| 153 |
-
q_words = re.findall(r'[a-z0-9_]+', question.lower())
|
| 154 |
-
stop_words = {"show", "list", "all", "what", "is", "the", "how", "many", "count", "find", "get", "me", "a", "an", "of", "in", "for", "from", "with", "which", "are", "there", "give", "tell", "details", "info", "data", "everything"}
|
| 155 |
-
meaningful_q_words = [w for w in q_words if w not in stop_words and not w.isdigit()]
|
| 156 |
-
if not meaningful_q_words: return True
|
| 157 |
-
for word in meaningful_q_words:
|
| 158 |
-
singular_word = word[:-1] if word.endswith('s') else word
|
| 159 |
-
if word in schema_words or singular_word in schema_words: return True
|
| 160 |
-
return False
|
| 161 |
|
|
|
|
|
|
|
|
|
|
| 162 |
def run_query(method, sample_q, custom_q, db_id):
|
| 163 |
-
global quant_engine
|
| 164 |
|
| 165 |
-
#
|
| 166 |
-
if
|
| 167 |
-
print(f"First request detected! Loading AI model from {DEFAULT_QUANT_ARTIFACT}...", flush=True)
|
| 168 |
-
try:
|
| 169 |
-
quant_engine = get_quant_engine(DEFAULT_QUANT_ARTIFACT, use_constrained=False, exec_workers=8, use_cache=True)
|
| 170 |
-
if quant_engine is None:
|
| 171 |
-
return "-- ❌ ENGINE CRASH", pd.DataFrame(columns=["Error"]), "Failed to load model. Did you move the tokenizer files and add config.json to int8_dynamic/?"
|
| 172 |
-
except Exception as e:
|
| 173 |
-
return f"-- ❌ ENGINE CRASH\n-- {str(e)}", pd.DataFrame(columns=["Error Status"]), f"Critical failure loading model: {e}"
|
| 174 |
-
|
| 175 |
-
def _log(error_type: str, *, question: str, db_id_val: str, sql: str = "", error_msg: str = "") -> None:
|
| 176 |
-
_QUERY_LOG.append({"t": time.time(), "db_id": str(db_id_val), "question": str(question), "sql": str(sql), "error_type": str(error_type), "error_msg": str(error_msg)})
|
| 177 |
-
|
| 178 |
-
def _perf_log(payload: dict) -> None:
|
| 179 |
-
_PERF_LOG.append(payload)
|
| 180 |
-
if len(_PERF_LOG) > 1000: del _PERF_LOG[:200]
|
| 181 |
-
|
| 182 |
-
raw_question = sample_q if method == "💡 Pick a Sample" else custom_q
|
| 183 |
-
|
| 184 |
-
if not raw_question or str(raw_question).strip() == "":
|
| 185 |
-
return "-- No input provided", pd.DataFrame(columns=["Warning"]), "⚠️ Please enter a question."
|
| 186 |
-
if not db_id or str(db_id).strip() == "":
|
| 187 |
-
return "-- No database selected", pd.DataFrame(columns=["Warning"]), "⚠️ Please select a database."
|
| 188 |
-
|
| 189 |
-
typo_corrections = [(r'\bshaw\b', 'show'), (r'\bshw\b', 'show'), (r'\bsho\b', 'show'), (r'\blsit\b', 'list'), (r'\blis\b', 'list'), (r'\bfidn\b', 'find'), (r'\bfnd\b', 'find'), (r'\bgte\b', 'get')]
|
| 190 |
-
question = str(raw_question)
|
| 191 |
-
for bad, good in typo_corrections: question = re.sub(bad, good, question, flags=re.IGNORECASE)
|
| 192 |
-
q_lower = question.strip().lower()
|
| 193 |
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
return "
|
| 197 |
-
|
| 198 |
-
if
|
| 199 |
-
|
| 200 |
-
return "-- ❌ BLOCKED: Data Modification", pd.DataFrame(columns=["Security Alert"]), "🛑 Security Alert: Modifying or deleting data is strictly prohibited."
|
| 201 |
-
|
| 202 |
-
if not is_relevant_to_schema(question, db_id):
|
| 203 |
-
_log("out_of_domain", question=question, db_id_val=str(db_id), error_msg="out of domain")
|
| 204 |
-
return "-- ❌ BLOCKED: Out of Domain", pd.DataFrame(columns=["Domain Alert"]), f"🛑 Relevance Alert: I don't see anything related to your question in the '{db_id}' schema."
|
| 205 |
|
| 206 |
start_time = time.time()
|
| 207 |
-
t0 = time.perf_counter()
|
| 208 |
-
ui_warnings = ""
|
| 209 |
|
|
|
|
| 210 |
try:
|
| 211 |
-
|
| 212 |
-
result = quant_engine.ask(question, str(db_id), num_beams=4, max_new_tokens=120, timeout_s=2.0)
|
| 213 |
-
except TypeError:
|
| 214 |
-
result = quant_engine.ask(question, str(db_id))
|
| 215 |
except Exception as e:
|
| 216 |
-
|
| 217 |
-
return f"-- ❌ BACKEND CRASH\n-- {str(e)}", pd.DataFrame(columns=["Error Status"]), f"❌ CRITICAL BACKEND CRASH:\n{str(e)}"
|
| 218 |
-
|
| 219 |
-
final_sql = str(result.get("sql", ""))
|
| 220 |
-
model_sql = final_sql
|
| 221 |
-
|
| 222 |
-
num_match = re.search(r'\b(?:show|list|top|limit|get|first|last|sample|of)\s+(?:[a-zA-Z_]+\s+)?(\d+)\b', q_lower)
|
| 223 |
-
if not num_match and q_lower.startswith(("show", "list", "get")):
|
| 224 |
-
num_match = re.search(r'\b(\d+)\b', q_lower)
|
| 225 |
-
|
| 226 |
-
if num_match and final_sql:
|
| 227 |
-
limit_val = num_match.group(1)
|
| 228 |
-
final_sql = re.sub(rf"(?i)\s*(?:where|having|and)?\s*count\s*\(\s*\*\s*\)\s*=\s*{limit_val}", "", final_sql)
|
| 229 |
-
final_sql = re.sub(rf"(?i)\s*(?:where|and)\s+[a-zA-Z0-9_.]+\s*=\s*['\"]?{limit_val}['\"]?", "", final_sql)
|
| 230 |
-
final_sql = re.sub(r"(?i)\s*where\s*$", "", final_sql)
|
| 231 |
-
final_sql = re.sub(r"(?i)\s*where\s+(group by|order by|limit)", r" \1", final_sql)
|
| 232 |
-
|
| 233 |
-
agg_kws = ["most", "top", "highest", "lowest", "count", "many", "group", "frequent", "popular"]
|
| 234 |
-
if not any(k in q_lower for k in agg_kws):
|
| 235 |
-
final_sql = re.sub(r"(?i)\s*group by\s+[a-zA-Z0-9_.]+\s*order by\s+count\(\*\)\s*(?:desc|asc)?", "", final_sql)
|
| 236 |
-
final_sql = re.sub(r"(?i)\s*order by\s+count\(\*\)\s*(?:desc|asc)?", "", final_sql)
|
| 237 |
-
final_sql = re.sub(r"(?i),\s*count\(\*\)", "", final_sql)
|
| 238 |
-
final_sql = re.sub(r"(?i)count\(\*\)\s*,", "", final_sql)
|
| 239 |
-
|
| 240 |
-
if "group by" in final_sql.lower() and not re.search(r'(?i)\b(count|sum|avg|max|min)\b\(', final_sql):
|
| 241 |
-
final_sql = re.sub(r"(?i)\s*group by\s+[a-zA-Z0-9_.]+", "", final_sql)
|
| 242 |
-
|
| 243 |
-
if "limit" not in final_sql.lower():
|
| 244 |
-
final_sql = f"{final_sql.strip().rstrip(';')} LIMIT {limit_val}"
|
| 245 |
-
|
| 246 |
-
# Execution
|
| 247 |
-
from src.sql_validator import validate_sql_schema
|
| 248 |
-
db_path = get_db_path(str(db_id))
|
| 249 |
-
|
| 250 |
-
try: strict_valid, _ = validate_sql_schema(final_sql, db_path)
|
| 251 |
-
except Exception: strict_valid = False
|
| 252 |
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
|
|
|
| 256 |
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
sqlite_success = True
|
| 260 |
-
except Exception as e:
|
| 261 |
-
error_msg = str(e)
|
| 262 |
-
sqlite_success = False
|
| 263 |
-
|
| 264 |
-
if not sqlite_success and model_sql and model_sql != final_sql:
|
| 265 |
-
try:
|
| 266 |
-
alt_rows, alt_cols = quant_engine._execute_one(model_sql, db_path, timeout_s=2.0)
|
| 267 |
-
final_sql = model_sql
|
| 268 |
-
rows, cols = alt_rows, alt_cols
|
| 269 |
-
sqlite_success = True
|
| 270 |
-
error_msg = None
|
| 271 |
-
except Exception: pass
|
| 272 |
-
|
| 273 |
-
valid = sqlite_success
|
| 274 |
-
|
| 275 |
-
if error_msg or not valid:
|
| 276 |
-
et = classify_error(final_sql, str(error_msg or ""), timed_out=("interrupted" in str(error_msg or "").lower()))
|
| 277 |
-
_log(et, question=str(question), db_id_val=str(db_id), sql=str(final_sql), error_msg=str(error_msg or "Execution failed"))
|
| 278 |
-
|
| 279 |
-
latency = round(time.time() - start_time, 3)
|
| 280 |
-
t1 = time.perf_counter()
|
| 281 |
-
|
| 282 |
-
engine_stats_after = quant_engine.stats() if hasattr(quant_engine, 'stats') else {}
|
| 283 |
-
|
| 284 |
-
perf = {
|
| 285 |
-
"db_id": str(db_id), "use_constrained_decoding": False, "num_beams": 4,
|
| 286 |
-
"latency_total_ms": round((t1 - t0) * 1000.0, 2), "constraint_ok": bool(strict_valid), "has_error": bool(error_msg),
|
| 287 |
-
"exec_cache_hit_rate": float(engine_stats_after.get("exec_cache_hit_rate", 0.0) or 0.0),
|
| 288 |
-
}
|
| 289 |
-
_perf_log(perf)
|
| 290 |
-
|
| 291 |
-
window = _PERF_LOG[-50:]
|
| 292 |
-
avg_ms = sum(float(x.get("latency_total_ms", 0.0) or 0.0) for x in window) / len(window) if window else 0.0
|
| 293 |
-
constraint_rate = sum(1 for x in window if x.get("constraint_ok")) / len(window) if window else 0.0
|
| 294 |
-
|
| 295 |
-
perf_block = (
|
| 296 |
-
"\n\n---\nPerformance (task impact)\n"
|
| 297 |
-
f"- Total latency (ms): {perf['latency_total_ms']}\n"
|
| 298 |
-
f"- Strict Python Validator OK (Task 3): {perf['constraint_ok']}\n"
|
| 299 |
-
f"- Exec cache hit-rate (Task 1/5): {round(perf['exec_cache_hit_rate'], 3)}\n"
|
| 300 |
-
f"- Rolling avg latency last 50 (ms): {round(avg_ms, 2)}\n"
|
| 301 |
-
f"- Rolling constraint rate last 50: {round(constraint_rate, 3)}\n"
|
| 302 |
-
)
|
| 303 |
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
explanation
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
explanation = f"
|
| 319 |
-
|
| 320 |
-
ops = sql_ops(final_sql)
|
| 321 |
-
for op in ops:
|
| 322 |
-
if op in _OP_STATS: _OP_STATS[op]["ok"] += 1
|
| 323 |
-
_SUCCESS_LOG.append({"t": time.time(), "db_id": str(db_id), "question": question, "sql": final_sql, "ops": ops})
|
| 324 |
|
| 325 |
limit_match = re.search(r'LIMIT\s+(\d+)', final_sql, re.IGNORECASE)
|
| 326 |
-
if limit_match
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
env = os.environ.copy()
|
| 334 |
-
env["PYTHONPATH"] = project_root + (os.pathsep + env["PYTHONPATH"] if env.get("PYTHONPATH") else "")
|
| 335 |
-
env.setdefault("MPLBACKEND", "Agg")
|
| 336 |
-
env.setdefault("MPLCONFIGDIR", "/tmp/mplconfig")
|
| 337 |
-
try: os.makedirs(env["MPLCONFIGDIR"], exist_ok=True)
|
| 338 |
-
except Exception: pass
|
| 339 |
-
|
| 340 |
-
cmd = [sys.executable, "-u", "scripts/benchmark_parallel_reward.py", "--n", str(int(n_rollouts)), "--max-workers", str(int(max_workers)), "--skip-profile"]
|
| 341 |
-
proc = subprocess.Popen(cmd, cwd=project_root, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
|
| 342 |
-
last_yield = time.perf_counter()
|
| 343 |
-
lines: list[str] = []
|
| 344 |
-
yield "Running Task 1 benchmark...\n", "<i>Running...</i>"
|
| 345 |
-
|
| 346 |
-
assert proc.stdout is not None
|
| 347 |
-
for line in proc.stdout:
|
| 348 |
-
lines.append(line)
|
| 349 |
-
now = time.perf_counter()
|
| 350 |
-
if now - last_yield >= 0.5:
|
| 351 |
-
last_yield = now
|
| 352 |
-
yield "".join(lines[-200:]).strip(), "<i>Running...</i>"
|
| 353 |
-
|
| 354 |
-
proc.wait()
|
| 355 |
-
out = "".join(lines).strip()
|
| 356 |
-
|
| 357 |
-
plot_path = str(PROJECT_ROOT / "results" / "task1_plot.png")
|
| 358 |
-
if os.path.exists(plot_path):
|
| 359 |
-
try:
|
| 360 |
-
b64 = base64.b64encode(Path(plot_path).read_bytes()).decode("ascii")
|
| 361 |
-
yield out, f"<img src='data:image/png;base64,{b64}' style='max-width: 100%; border: 1px solid #e2e8f0; border-radius: 8px;' />"
|
| 362 |
-
return
|
| 363 |
-
except Exception:
|
| 364 |
-
yield out, f"<pre>{plot_path}</pre>"
|
| 365 |
-
return
|
| 366 |
-
|
| 367 |
-
yield out, "<i>No plot generated</i>"
|
| 368 |
-
|
| 369 |
-
def task2_dashboard_structured():
|
| 370 |
-
if not _QUERY_LOG:
|
| 371 |
-
empty_counts = pd.DataFrame(columns=["error_type", "count", "hint"])
|
| 372 |
-
empty_recent = pd.DataFrame(columns=["time", "db_id", "error_type", "question", "error_msg"])
|
| 373 |
-
return empty_counts, empty_recent, gr.update(choices=[], value=None)
|
| 374 |
-
|
| 375 |
-
counts = {}
|
| 376 |
-
for r in _QUERY_LOG[-1000:]:
|
| 377 |
-
k = r.get("error_type") or "other"
|
| 378 |
-
counts[k] = counts.get(k, 0) + 1
|
| 379 |
-
rows = [{"error_type": k, "count": int(v), "hint": get_hint(k)} for k, v in sorted(counts.items(), key=lambda x: (-x[1], x[0]))]
|
| 380 |
-
counts_df = pd.DataFrame(rows)
|
| 381 |
-
|
| 382 |
-
recent = []
|
| 383 |
-
for r in _QUERY_LOG[-100:]:
|
| 384 |
-
ts = r.get("t")
|
| 385 |
-
try: ts_s = time.strftime("%H:%M:%S", time.localtime(float(ts))) if ts else ""
|
| 386 |
-
except Exception: ts_s = ""
|
| 387 |
-
recent.append({"time": ts_s, "db_id": r.get("db_id", ""), "error_type": r.get("error_type", ""), "question": r.get("question", ""), "error_msg": r.get("error_msg", "")})
|
| 388 |
-
recent_df = pd.DataFrame(recent)
|
| 389 |
-
|
| 390 |
-
choices = [str(x["error_type"]) for x in rows]
|
| 391 |
-
default = choices[0] if choices else None
|
| 392 |
-
return counts_df, recent_df, gr.update(choices=choices, value=default)
|
| 393 |
-
|
| 394 |
-
def task2_error_examples(error_type: str) -> str:
|
| 395 |
-
if not error_type: return ""
|
| 396 |
-
hint = get_hint(error_type)
|
| 397 |
-
matches = [r for r in reversed(_QUERY_LOG) if (r.get("error_type") or "") == str(error_type)][:3]
|
| 398 |
-
if not matches: return f"Error type: {error_type}\nHint: {hint}\n\nNo examples yet."
|
| 399 |
-
out = [f"Error type: {error_type}", f"Hint: {hint}", ""]
|
| 400 |
-
for i, r in enumerate(matches, 1):
|
| 401 |
-
out.extend([f"Example {i}", f"DB: {r.get('db_id','')}", f"Q: {r.get('question','')}", f"SQL: {r.get('sql','')}", f"Msg: {r.get('error_msg','')}", ""])
|
| 402 |
-
return "\n".join(out).strip()
|
| 403 |
-
|
| 404 |
-
def _plot_op_stats_html() -> str:
|
| 405 |
-
try:
|
| 406 |
-
import matplotlib.pyplot as plt
|
| 407 |
-
labels = list(_OP_STATS.keys())
|
| 408 |
-
oks = [int(_OP_STATS[k]["ok"]) for k in labels]
|
| 409 |
-
fails = [int(_OP_STATS[k]["fail"]) for k in labels]
|
| 410 |
-
|
| 411 |
-
fig, ax = plt.subplots(figsize=(9, 3.5))
|
| 412 |
-
x = list(range(len(labels)))
|
| 413 |
-
ax.bar(x, oks, label="ok", color="#16a34a")
|
| 414 |
-
ax.bar(x, fails, bottom=oks, label="fail", color="#dc2626")
|
| 415 |
-
ax.set_xticks(x)
|
| 416 |
-
ax.set_xticklabels(labels, rotation=30, ha="right")
|
| 417 |
-
ax.set_title("Success/Failure by SQL operation")
|
| 418 |
-
ax.legend()
|
| 419 |
-
fig.tight_layout()
|
| 420 |
-
|
| 421 |
-
buf = io.BytesIO()
|
| 422 |
-
fig.savefig(buf, format="png", dpi=160)
|
| 423 |
-
plt.close(fig)
|
| 424 |
-
b64 = base64.b64encode(buf.getvalue()).decode("ascii")
|
| 425 |
-
return f"<img src='data:image/png;base64,{b64}' style='max-width: 100%; border: 1px solid #e2e8f0; border-radius: 8px;' />"
|
| 426 |
-
except Exception as e: return f"<pre>Plot error: {e}</pre>"
|
| 427 |
-
|
| 428 |
-
def task2_ops_table():
|
| 429 |
-
rows = []
|
| 430 |
-
for op, d in _OP_STATS.items():
|
| 431 |
-
ok = int(d.get("ok", 0))
|
| 432 |
-
fail = int(d.get("fail", 0))
|
| 433 |
-
total = ok + fail
|
| 434 |
-
rows.append({"op": op, "ok": ok, "fail": fail, "total": total, "success_rate": (ok / total) if total else 0.0})
|
| 435 |
-
return pd.DataFrame(rows), _plot_op_stats_html()
|
| 436 |
|
| 437 |
def toggle_input_method(method, current_sample):
|
| 438 |
if method == "💡 Pick a Sample":
|
|
|
|
| 439 |
db = next((db for q, db in SAMPLES if q == current_sample), "chinook_1")
|
| 440 |
-
return (
|
| 441 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 442 |
|
| 443 |
def load_sample(selected_question):
|
| 444 |
-
if not selected_question:
|
| 445 |
-
|
|
|
|
|
|
|
|
|
|
| 446 |
|
| 447 |
def clear_inputs():
|
| 448 |
-
return (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 449 |
|
| 450 |
def update_schema(db_id):
|
| 451 |
-
if not db_id
|
|
|
|
| 452 |
try:
|
| 453 |
-
raw_schema =
|
| 454 |
html_output = "<div style='max-height: 250px; overflow-y: auto; background: #f8fafc; padding: 12px; border-radius: 8px; border: 1px solid #e2e8f0; font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace; font-size: 0.9em; line-height: 1.6;'>"
|
| 455 |
for line in raw_schema.strip().split('\n'):
|
| 456 |
line = line.strip()
|
| 457 |
if not line: continue
|
| 458 |
match = re.search(r'^([a-zA-Z0-9_]+)\s*\((.*)\)', line)
|
| 459 |
-
if match:
|
| 460 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 461 |
html_output += "</div>"
|
| 462 |
return html_output
|
| 463 |
-
except Exception as e:
|
|
|
|
|
|
|
| 464 |
|
| 465 |
# =========================
|
| 466 |
# UI LAYOUT
|
| 467 |
# =========================
|
| 468 |
-
with gr.Blocks(title="Text-to-SQL RLHF") as demo:
|
| 469 |
-
|
|
|
|
|
|
|
| 470 |
<div style="text-align: center; background-color: #e0e7ff; padding: 20px; border-radius: 10px; margin-bottom: 20px; border: 1px solid #c7d2fe;">
|
| 471 |
<h1 style="color: #3730a3; margin-top: 0; margin-bottom: 10px; font-size: 2.2em;"> Text-to-SQL using RLHF + Execution Reward</h1>
|
| 472 |
<p style="color: #4f46e5; font-size: 1.1em; margin: 0;">Convert Natural Language to SQL, strictly validated and safely executed on local SQLite databases.</p>
|
| 473 |
</div>
|
| 474 |
-
|
|
|
|
| 475 |
|
| 476 |
-
DBS = sorted([
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 477 |
|
| 478 |
-
with gr.Tabs():
|
| 479 |
-
with gr.Tab("Inference"):
|
| 480 |
with gr.Row():
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
result_table = gr.Dataframe(label="Query Result Table", interactive=False, wrap=True)
|
| 502 |
-
explanation = gr.Textbox(label="AI Explanation + Execution Details", lines=8)
|
| 503 |
-
|
| 504 |
-
with gr.Tab("Diagnostics"):
|
| 505 |
-
gr.Markdown("## Diagnostics & Telemetry")
|
| 506 |
-
|
| 507 |
-
with gr.Accordion("Task 1: Parallel Reward Benchmark", open=False):
|
| 508 |
-
gr.Markdown("*(Simulates the heavy RLHF training workload by running hundreds of complex SQL queries concurrently to test SQLite multi-threading performance.)*")
|
| 509 |
-
t1_n = gr.Number(value=20, precision=0, label="Rollouts (n)")
|
| 510 |
-
t1_workers = gr.Number(value=10, precision=0, label="Max workers")
|
| 511 |
-
t1_run = gr.Button("Run Task 1 benchmark")
|
| 512 |
-
t1_out = gr.Textbox(label="Output", lines=12)
|
| 513 |
-
t1_plot = gr.HTML(label="Plot (if generated)")
|
| 514 |
-
t1_run.click(fn=task1_benchmark, inputs=[t1_n, t1_workers], outputs=[t1_out, t1_plot])
|
| 515 |
-
|
| 516 |
-
with gr.Accordion("Task 2: Error Dashboard", open=True):
|
| 517 |
-
gr.Markdown("*(Live telemetry tracking the most common SQL failures. Populates automatically when queries fail in the Inference tab.)*")
|
| 518 |
-
t2_refresh = gr.Button("Refresh dashboard")
|
| 519 |
-
t2_counts = gr.Dataframe(label="Error counts", interactive=False, wrap=True)
|
| 520 |
-
t2_recent = gr.Dataframe(label="Recent errors", interactive=False, wrap=True)
|
| 521 |
-
t2_type = gr.Dropdown(choices=[], value=None, label="Select error type")
|
| 522 |
-
t2_examples = gr.Textbox(label="Examples + hint", lines=10)
|
| 523 |
-
|
| 524 |
-
t2_refresh.click(fn=task2_dashboard_structured, inputs=[], outputs=[t2_counts, t2_recent, t2_type])
|
| 525 |
-
t2_type.change(fn=task2_error_examples, inputs=[t2_type], outputs=[t2_examples])
|
| 526 |
-
|
| 527 |
-
with gr.Accordion("Task 2: Clause Telemetry", open=False):
|
| 528 |
-
gr.Markdown("*(Analyzes which specific SQL clauses—SELECT, WHERE, JOIN, etc.—are most prone to errors during natural language generation.)*")
|
| 529 |
-
t2_ops_refresh = gr.Button("Refresh SQL-op stats")
|
| 530 |
-
t2_ops_tbl = gr.Dataframe(label="Success/failure by op", interactive=False, wrap=True)
|
| 531 |
-
t2_ops_plot = gr.HTML(label="Op plot")
|
| 532 |
-
t2_ops_refresh.click(fn=task2_ops_table, inputs=[], outputs=[t2_ops_tbl, t2_ops_plot])
|
| 533 |
-
|
| 534 |
-
# EVENT BINDING: The .then() forces the diagnostic tab to update live in the background!
|
| 535 |
-
input_method.change(fn=toggle_input_method, inputs=[input_method, sample_dropdown], outputs=[sample_dropdown, type_own_warning, custom_question, db_id])
|
| 536 |
sample_dropdown.change(fn=load_sample, inputs=[sample_dropdown], outputs=[db_id])
|
|
|
|
| 537 |
db_id.change(fn=update_schema, inputs=[db_id], outputs=[schema_display])
|
| 538 |
|
| 539 |
run_btn.click(
|
| 540 |
-
fn=run_query,
|
| 541 |
-
inputs=[input_method, sample_dropdown, custom_question, db_id],
|
| 542 |
outputs=[final_sql, result_table, explanation]
|
| 543 |
-
).then(
|
| 544 |
-
fn=task2_dashboard_structured, inputs=[], outputs=[t2_counts, t2_recent, t2_type]
|
| 545 |
-
).then(
|
| 546 |
-
fn=task2_ops_table, inputs=[], outputs=[t2_ops_tbl, t2_ops_plot]
|
| 547 |
)
|
| 548 |
|
| 549 |
-
clear_btn.click(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 550 |
|
| 551 |
if __name__ == "__main__":
|
| 552 |
-
|
| 553 |
-
base_port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
|
| 554 |
-
max_retries = 10
|
| 555 |
-
|
| 556 |
-
for port in range(base_port, base_port + max_retries):
|
| 557 |
-
try:
|
| 558 |
-
print(f"Attempting to start Gradio UI on {server_name}:{port}...", flush=True)
|
| 559 |
-
demo.launch(server_name=server_name, server_port=port)
|
| 560 |
-
break # If successful, exit the loop
|
| 561 |
-
except OSError as e:
|
| 562 |
-
if "Cannot find empty port" in str(e) or "Address already in use" in str(e):
|
| 563 |
-
print(f"⚠️ Port {port} is in use, trying next port...")
|
| 564 |
-
continue
|
| 565 |
-
else:
|
| 566 |
-
# If it's a different OSError, raise it normally
|
| 567 |
-
raise e
|
| 568 |
-
else:
|
| 569 |
-
print(f"❌ Could not find an open port between {base_port} and {base_port + max_retries - 1}.")
|
|
|
|
| 1 |
"""
|
| 2 |
+
GRADIO DEMO UI
|
| 3 |
NL → SQL → Result Table
|
| 4 |
"""
|
| 5 |
|
|
|
|
| 7 |
import pandas as pd
|
| 8 |
import re
|
| 9 |
import time
|
| 10 |
+
from src.text2sql_engine import get_engine
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
+
engine = get_engine()
|
| 13 |
+
|
| 14 |
+
# =========================
|
| 15 |
+
# SAMPLE QUESTIONS DATA
|
| 16 |
+
# =========================
|
| 17 |
SAMPLES = [
|
| 18 |
+
("Show 10 distinct employee first names.", "chinook_1"),
|
| 19 |
+
("Which artist has the most albums?", "chinook_1"),
|
| 20 |
+
("List all the tracks that belong to the 'Rock' genre.", "chinook_1"),
|
| 21 |
+
("What are the names of all the cities?", "flight_1"),
|
| 22 |
+
("Find the flight number and cost of the cheapest flight.", "flight_1"),
|
| 23 |
+
("List the airlines that fly out of New York.", "flight_1"),
|
| 24 |
+
("Which campus was opened between 1935 and 1939?", "csu_1"),
|
| 25 |
+
("Count the number of students in each department.", "college_2"),
|
| 26 |
+
("List the names of all clubs.", "club_1"),
|
| 27 |
+
("How many members does each club have?", "club_1"),
|
| 28 |
+
("Show the names of all cinemas.", "cinema"),
|
| 29 |
+
("Which cinema has the most screens?", "cinema")
|
| 30 |
]
|
| 31 |
+
|
| 32 |
SAMPLE_QUESTIONS = [q[0] for q in SAMPLES]
|
| 33 |
|
| 34 |
+
# =========================
|
| 35 |
+
# SQL EXPLAINER
|
| 36 |
+
# =========================
|
| 37 |
def explain_sql(sql):
    """Build a short plain-English summary of a SQL query.

    The summary is a fixed opening sentence followed by one bullet per
    recognized SQL clause (JOIN / WHERE / GROUP BY / ORDER BY / LIMIT)
    detected by a simple case-insensitive substring check.
    """
    lowered = sql.lower()
    pieces = ["This SQL query retrieves information from the database."]

    # (substring to look for, bullet to append) — order matters for the output.
    clause_notes = (
        ("join", "\n• It combines data from multiple tables using JOIN."),
        ("where", "\n• It filters rows using a WHERE condition."),
        ("group by", "\n• It groups results using GROUP BY."),
        ("order by", "\n• It sorts the results using ORDER BY."),
        ("limit", "\n• It limits the number of returned rows."),
    )
    for needle, note in clause_notes:
        if needle in lowered:
            pieces.append(note)

    return "".join(pieces)
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
+
# =========================
|
| 56 |
+
# CORE FUNCTIONS
|
| 57 |
+
# =========================
|
| 58 |
def run_query(method, sample_q, custom_q, db_id):
    """Handle the "Generate & Run SQL" click.

    Resolves the question from either the sample dropdown or the custom
    textbox, asks the Text2SQL engine to generate and execute SQL, and
    returns the (sql, dataframe, explanation) triple the Gradio outputs
    expect. Every failure path returns a message instead of raising, so the
    UI never shows an endless loading spinner.
    """
    # Pick whichever question source the radio selection points at.
    if method == "💡 Pick a Sample":
        question = sample_q
    else:
        question = custom_q

    # Guard clauses: refuse blank inputs before touching the engine.
    if not question or str(question).strip() == "":
        return "", pd.DataFrame(), "⚠️ Please enter a question."
    if not db_id or str(db_id).strip() == "":
        return "", pd.DataFrame(), "⚠️ Please select a database."

    started = time.time()

    # Any backend exception becomes an on-screen message rather than a crash.
    try:
        result = engine.ask(str(question), str(db_id))
    except Exception as e:
        return "", pd.DataFrame(), f"❌ CRITICAL BACKEND CRASH:\n{str(e)}"

    final_sql = result.get("sql", "")
    error_msg = result.get("error", None)
    rows = result.get("rows", [])
    cols = result.get("columns", [])

    latency = round(time.time() - started, 3)

    # SQL generation/execution errors reported by the engine itself.
    if error_msg:
        return final_sql, pd.DataFrame(), f"❌ SQL Error:\n{error_msg}"

    # Zero rows is still a successful execution — show an empty table.
    if not rows:
        empty_df = pd.DataFrame(columns=cols if cols else [])
        explanation = f"✅ Query executed successfully\n\nRows returned: 0\nExecution Time: {latency} sec\n\n{explain_sql(final_sql)}"
        return final_sql, empty_df, explanation

    # Successful execution with data.
    df = pd.DataFrame(rows, columns=cols)
    actual_rows = len(rows)

    explanation = f"✅ Query executed successfully\n\nRows returned: {actual_rows}\nExecution Time: {latency} sec\n\n{explain_sql(final_sql)}"

    # If the SQL carried a LIMIT but fewer rows matched, point that out.
    limit_match = re.search(r'LIMIT\s+(\d+)', final_sql, re.IGNORECASE)
    if limit_match:
        requested_limit = int(limit_match.group(1))
        if actual_rows < requested_limit:
            explanation += f"\n\nℹ️ Query allowed up to {requested_limit} rows but only {actual_rows} matched."

    return final_sql, df, explanation
|
| 109 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
def toggle_input_method(method, current_sample):
    """Switch the input widgets when the sample/custom radio changes.

    In sample mode: show the sample dropdown, hide the custom textbox and its
    warning, and lock the database dropdown to the database paired with the
    currently selected sample. In custom mode: hide the dropdown, reveal the
    warning and textbox, and unlock the database dropdown.
    """
    if method == "💡 Pick a Sample":
        # Database paired with the current sample; fall back to 'chinook_1'.
        matched_db = next((d for q, d in SAMPLES if q == current_sample), "chinook_1")
        return (
            gr.update(visible=True),                      # Show sample_dropdown
            gr.update(visible=False),                     # Hide type_own_warning
            gr.update(visible=False),                     # Hide custom_question
            gr.update(value=matched_db, interactive=False),  # Lock and reset db_id
        )

    return (
        gr.update(visible=False),   # Hide sample_dropdown
        gr.update(visible=True),    # Show type_own_warning
        gr.update(visible=True),    # Show custom_question
        gr.update(interactive=True),  # Unlock db_id
    )
|
| 128 |
+
|
| 129 |
|
| 130 |
def load_sample(selected_question):
    """Keep the database dropdown in sync with the chosen sample question."""
    if not selected_question:
        # Nothing selected — leave the dropdown untouched.
        return gr.update()
    matched_db = next((d for q, d in SAMPLES if q == selected_question), "chinook_1")
    return gr.update(value=matched_db)
|
| 135 |
+
|
| 136 |
|
| 137 |
def clear_inputs():
    """Reset every input widget to its initial state and blank all outputs.

    The return order must match the outputs list wired to the Clear button:
    input_method, sample_dropdown, type_own_warning, custom_question, db_id,
    then the three output components (SQL, table, explanation).
    """
    reset_widgets = (
        gr.update(value="💡 Pick a Sample"),
        gr.update(value=SAMPLE_QUESTIONS[0], visible=True),   # sample_dropdown
        gr.update(visible=False),                             # type_own_warning
        gr.update(value="", visible=False),                   # custom_question
        gr.update(value="chinook_1", interactive=False),      # db_id
    )
    blank_outputs = ("", pd.DataFrame(), "")
    return reset_widgets + blank_outputs
|
| 146 |
|
| 147 |
def update_schema(db_id):
    """Render the selected database's schema as a scrollable HTML panel.

    Each schema line shaped like ``table(col, col, ...)`` is styled with the
    table name in bold upper case and the columns dimmed in lower case; any
    other line is shown as plain text. Errors are returned as red HTML so the
    panel never raises.
    """
    if not db_id:
        return ""
    try:
        raw_schema = engine.get_schema(db_id)
        parts = ["<div style='max-height: 250px; overflow-y: auto; background: #f8fafc; padding: 12px; border-radius: 8px; border: 1px solid #e2e8f0; font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace; font-size: 0.9em; line-height: 1.6;'>"]
        for entry in raw_schema.strip().split('\n'):
            entry = entry.strip()
            if not entry:
                continue
            # "table(columns...)" lines get the table/column styling.
            m = re.search(r'^([a-zA-Z0-9_]+)\s*\((.*)\)', entry)
            if m:
                table_name = m.group(1).upper()
                columns = m.group(2).lower()
                parts.append(f"<div style='margin-bottom: 8px;'><strong style='color: #0f172a; font-size: 1.05em; font-weight: 800;'>{table_name}</strong> <span style='color: #64748b;'>( {columns} )</span></div>")
            else:
                parts.append(f"<div style='color: #475569;'>{entry}</div>")
        parts.append("</div>")
        return "".join(parts)
    except Exception as e:
        return f"<div style='color: red;'>Error loading schema: {str(e)}</div>"
|
| 167 |
+
|
| 168 |
|
| 169 |
# =========================
# UI LAYOUT
# =========================
with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL RLHF") as demo:

    # Static page header banner.
    gr.HTML(
        """
        <div style="text-align: center; background-color: #e0e7ff; padding: 20px; border-radius: 10px; margin-bottom: 20px; border: 1px solid #c7d2fe;">
            <h1 style="color: #3730a3; margin-top: 0; margin-bottom: 10px; font-size: 2.2em;"> Text-to-SQL using RLHF + Execution Reward</h1>
            <p style="color: #4f46e5; font-size: 1.1em; margin: 0;">Convert Natural Language to SQL, strictly validated and safely executed on local SQLite databases.</p>
        </div>
        """
    )

    # Database ids selectable in the dropdown, shown alphabetically.
    DBS = sorted([
        "flight_1", "student_assessment", "store_1", "bike_1", "book_2", "chinook_1",
        "academic", "aircraft", "car_1", "cinema", "club_1", "csu_1",
        "college_1", "college_2", "company_1", "company_employee",
        "customer_complaints", "department_store", "employee_hire_evaluation",
        "museum_visit", "products_for_hire", "restaurant_1",
        "school_finance", "shop_membership", "small_bank_1",
        "soccer_1", "student_1", "tvshow", "voter_1", "world_1"
    ])

    with gr.Row():
        # Left column: question input + database selection + schema viewer.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Configuration & Input")

            # Choose between a curated sample question and free-form input.
            input_method = gr.Radio(
                choices=["💡 Pick a Sample", "✍️ Type my own"],
                value="💡 Pick a Sample",
                label="How do you want to ask?"
            )

            # --- SAMPLE SECTION ---
            sample_dropdown = gr.Dropdown(
                choices=SAMPLE_QUESTIONS,
                value=SAMPLE_QUESTIONS[0],
                label="Select a Sample Question",
                info="The database will be selected automatically.",
                visible=True
            )

            # --- CUSTOM TYPE WARNING ---
            # Shown only in "Type my own" mode (see toggle_input_method).
            type_own_warning = gr.Markdown(
                "**⚠️ Please select a Database first, then type your custom question below:**",
                visible=False
            )

            gr.Markdown("---")

            # --- DATABASE SELECTION (Moved Up) ---
            # Locked while in sample mode; unlocked for custom questions.
            db_id = gr.Dropdown(
                choices=DBS,
                value="chinook_1",
                label="Select Database",
                interactive=False
            )

            # --- CUSTOM QUESTION BOX ---
            custom_question = gr.Textbox(
                label="Ask your Custom Question",
                placeholder="Type your own question here...",
                lines=3,
                visible=False
            )

            # Schema viewer for the currently selected database.
            gr.Markdown("#### 📋 Database Structure")
            gr.HTML("<p style='font-size: 0.85em; color: #64748b; margin-top: -10px; margin-bottom: 5px;'>Use these exact names! Table names are <strong>Dark</strong>, Column names are <span style='color: #94a3b8;'>Light</span>.</p>")
            schema_display = gr.HTML(value=update_schema("chinook_1"))

            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                run_btn = gr.Button(" Generate & Run SQL", variant="primary")

        # Right column: generated SQL, result table, and explanation.
        with gr.Column(scale=2):
            gr.Markdown("### 2. Execution Results")
            final_sql = gr.Code(language="sql", label="Final Executed SQL")
            result_table = gr.Dataframe(label="Query Result Table", interactive=False, wrap=True)
            explanation = gr.Textbox(label="AI Explanation + Execution Details", lines=8)

    # =========================
    # EVENT LISTENERS
    # =========================

    # Updated to handle the new Markdown warning toggle
    input_method.change(
        fn=toggle_input_method,
        inputs=[input_method, sample_dropdown],
        outputs=[sample_dropdown, type_own_warning, custom_question, db_id]
    )

    # Picking a sample re-selects its paired database.
    sample_dropdown.change(fn=load_sample, inputs=[sample_dropdown], outputs=[db_id])

    # Changing the database refreshes the schema panel.
    db_id.change(fn=update_schema, inputs=[db_id], outputs=[schema_display])

    run_btn.click(
        fn=run_query,
        inputs=[input_method, sample_dropdown, custom_question, db_id],
        outputs=[final_sql, result_table, explanation]
    )

    clear_btn.click(
        fn=clear_inputs,
        inputs=[],
        # Output list matches the updated clear_inputs() return values
        outputs=[input_method, sample_dropdown, type_own_warning, custom_question, db_id, final_sql, result_table, explanation]
    )

if __name__ == "__main__":
    demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
best_rlhf_model/README.md
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: peft
|
| 3 |
+
---
|
| 4 |
+
## Training procedure
|
| 5 |
+
|
| 6 |
+
### Framework versions
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
- PEFT 0.4.0
|
best_rlhf_model/adapter_config.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"auto_mapping": null,
|
| 3 |
+
"base_model_name_or_path": "Salesforce/codet5-base",
|
| 4 |
+
"bias": "none",
|
| 5 |
+
"fan_in_fan_out": false,
|
| 6 |
+
"inference_mode": true,
|
| 7 |
+
"init_lora_weights": true,
|
| 8 |
+
"layers_pattern": null,
|
| 9 |
+
"layers_to_transform": null,
|
| 10 |
+
"lora_alpha": 32,
|
| 11 |
+
"lora_dropout": 0.05,
|
| 12 |
+
"modules_to_save": null,
|
| 13 |
+
"peft_type": "LORA",
|
| 14 |
+
"r": 16,
|
| 15 |
+
"revision": null,
|
| 16 |
+
"target_modules": [
|
| 17 |
+
"q",
|
| 18 |
+
"v"
|
| 19 |
+
],
|
| 20 |
+
"task_type": "SEQ_2_SEQ_LM"
|
| 21 |
+
}
|
{int8_dynamic/tokenizer → best_rlhf_model}/merges.txt
RENAMED
|
File without changes
|
{int8_dynamic/tokenizer → best_rlhf_model}/special_tokens_map.json
RENAMED
|
File without changes
|
{int8_dynamic/tokenizer → best_rlhf_model}/tokenizer_config.json
RENAMED
|
@@ -954,6 +954,5 @@
|
|
| 954 |
"pad_token": "<pad>",
|
| 955 |
"sep_token": "</s>",
|
| 956 |
"tokenizer_class": "RobertaTokenizer",
|
| 957 |
-
"trim_offsets": true,
|
| 958 |
"unk_token": "<unk>"
|
| 959 |
}
|
|
|
|
| 954 |
"pad_token": "<pad>",
|
| 955 |
"sep_token": "</s>",
|
| 956 |
"tokenizer_class": "RobertaTokenizer",
|
|
|
|
| 957 |
"unk_token": "<unk>"
|
| 958 |
}
|
best_rlhf_model/vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
int8_dynamic/meta.json
DELETED
|
@@ -1,7 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"mode": "int8_dynamic",
|
| 3 |
-
"base_model": "Salesforce/codet5-base",
|
| 4 |
-
"adapter_path": "checkpoints/best_rlhf_model_2",
|
| 5 |
-
"created_at_s": 1774418718.320342,
|
| 6 |
-
"estimated_model_bytes": 98804736
|
| 7 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int8_dynamic/model.pt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:f398e044cd49fc84553b746d26ad79beb1dd565d90cf8f6f5e50d27f48d08228
|
| 3 |
-
size 322871519
|
|
|
|
|
|
|
|
|
|
|
|
int8_dynamic/tokenizer/tokenizer.json
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
int8_dynamic/tokenizer/vocab.json
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
requirements.txt
CHANGED
|
@@ -1,10 +1,8 @@
|
|
| 1 |
-
gradio
|
| 2 |
pandas
|
| 3 |
sqlparse
|
| 4 |
transformers
|
| 5 |
-
torch
|
| 6 |
peft
|
| 7 |
trl
|
| 8 |
-
sentencepiece
|
| 9 |
-
matplotlib
|
| 10 |
-
huggingface_hub
|
|
|
|
| 1 |
+
gradio
|
| 2 |
pandas
|
| 3 |
sqlparse
|
| 4 |
transformers
|
| 5 |
+
torch
|
| 6 |
peft
|
| 7 |
trl
|
| 8 |
+
sentencepiece
|
|
|
|
|
|
scripts/benchmark_parallel_reward.py
DELETED
|
@@ -1,202 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
# Ensure headless-safe matplotlib + writable cache when called from Gradio/subprocess.
|
| 3 |
-
os.environ.setdefault("MPLBACKEND", "Agg")
|
| 4 |
-
os.environ.setdefault("MPLCONFIGDIR", os.environ.get("MPLCONFIGDIR", "/tmp/mplconfig"))
|
| 5 |
-
import time
|
| 6 |
-
import json
|
| 7 |
-
import argparse
|
| 8 |
-
import matplotlib.pyplot as plt
|
| 9 |
-
import numpy as np
|
| 10 |
-
import sys
|
| 11 |
-
from pathlib import Path
|
| 12 |
-
|
| 13 |
-
# ==========================================
|
| 14 |
-
# RELATIVE PATH RESOLUTION
|
| 15 |
-
# ==========================================
|
| 16 |
-
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 17 |
-
sys.path.append(str(PROJECT_ROOT))
|
| 18 |
-
|
| 19 |
-
# Dynamically resolve where the databases are kept
|
| 20 |
-
if (PROJECT_ROOT / "data" / "database").exists() and list((PROJECT_ROOT / "data" / "database").rglob("*.sqlite")):
|
| 21 |
-
DB_ROOT = PROJECT_ROOT / "data" / "database"
|
| 22 |
-
else:
|
| 23 |
-
DB_ROOT = PROJECT_ROOT / "final_databases"
|
| 24 |
-
|
| 25 |
-
from src.execution_reward import (
|
| 26 |
-
execution_reward_batch_sequential,
|
| 27 |
-
execution_reward_batch_parallel,
|
| 28 |
-
execution_reward_batch_parallel_by_db,
|
| 29 |
-
execution_reward_timed,
|
| 30 |
-
set_use_cache,
|
| 31 |
-
set_use_schema_validation,
|
| 32 |
-
clear_result_cache
|
| 33 |
-
)
|
| 34 |
-
|
| 35 |
-
def generate_mock_rollouts(num_rollouts: int = 100, heavy_n: int = 500_000):
|
| 36 |
-
"""Generates heavy queries across multiple databases to properly test true concurrency."""
|
| 37 |
-
print(f"\nGenerating {num_rollouts} heavy rollouts to simulate RLHF query workload...", flush=True)
|
| 38 |
-
|
| 39 |
-
# Smart search for real databases
|
| 40 |
-
real_dbs = [str(p) for p in DB_ROOT.rglob("*.sqlite")]
|
| 41 |
-
|
| 42 |
-
if real_dbs:
|
| 43 |
-
print(f"Found {len(real_dbs)} real SQLite databases in {DB_ROOT}. Distributing workload...", flush=True)
|
| 44 |
-
else:
|
| 45 |
-
print(f"❌ CRITICAL ERROR: No real databases found in {DB_ROOT}. Cannot run benchmark.", flush=True)
|
| 46 |
-
sys.exit(1)
|
| 47 |
-
|
| 48 |
-
rollouts = []
|
| 49 |
-
for i in range(num_rollouts):
|
| 50 |
-
db_path = real_dbs[i % len(real_dbs)]
|
| 51 |
-
|
| 52 |
-
# Heavy deterministic CPU-ish query (may be cut off by the 2s timeout depending on machine).
|
| 53 |
-
heavy_sql = f"""
|
| 54 |
-
WITH RECURSIVE cnt(x) AS (
|
| 55 |
-
SELECT 1
|
| 56 |
-
UNION ALL
|
| 57 |
-
SELECT x+1 FROM cnt WHERE x < {heavy_n + (i % 10_000)}
|
| 58 |
-
)
|
| 59 |
-
SELECT sum(x) FROM cnt;
|
| 60 |
-
"""
|
| 61 |
-
clean_sql = heavy_sql.replace("\n", " ").strip()
|
| 62 |
-
rollouts.append((clean_sql, db_path, clean_sql))
|
| 63 |
-
if num_rollouts >= 500 and (i + 1) % 250 == 0:
|
| 64 |
-
print(f" generated {i + 1}/{num_rollouts}...", flush=True)
|
| 65 |
-
|
| 66 |
-
return rollouts
|
| 67 |
-
|
| 68 |
-
def profile_bottlenecks(rollouts, sample_size: int = 20, print_every: int = 5):
|
| 69 |
-
"""Profiles CPU usage to identify time spent in parsing, planning, and execution."""
|
| 70 |
-
print("\n" + "="*65)
|
| 71 |
-
print(" 🔍 CPU PROFILING: IDENTIFYING BOTTLENECKS (100 Rollouts)")
|
| 72 |
-
print("="*65)
|
| 73 |
-
|
| 74 |
-
clear_result_cache()
|
| 75 |
-
set_use_cache(False) # Disable cache to force real work
|
| 76 |
-
set_use_schema_validation(False) # CTE-heavy benchmark queries may fail schema validation
|
| 77 |
-
|
| 78 |
-
total_parse = 0.0
|
| 79 |
-
total_plan = 0.0
|
| 80 |
-
total_exec = 0.0
|
| 81 |
-
|
| 82 |
-
# Profile a small subset by default so the script prints quickly.
|
| 83 |
-
sample_size = min(int(sample_size), len(rollouts))
|
| 84 |
-
sample_rollouts = rollouts[:sample_size]
|
| 85 |
-
|
| 86 |
-
for i, (pred, db, gold) in enumerate(sample_rollouts, 1):
|
| 87 |
-
_, timings = execution_reward_timed(pred, db, gold, measure_plan=True)
|
| 88 |
-
total_parse += timings['parse_s']
|
| 89 |
-
total_plan += timings['plan_s']
|
| 90 |
-
total_exec += timings['exec_s']
|
| 91 |
-
if print_every and (i % int(print_every) == 0 or i == sample_size):
|
| 92 |
-
print(f" profiled {i}/{sample_size}...", flush=True)
|
| 93 |
-
|
| 94 |
-
total_time = total_parse + total_plan + total_exec
|
| 95 |
-
if total_time == 0: total_time = 0.0001 # Prevent div by zero
|
| 96 |
-
|
| 97 |
-
print(f"{'Phase':<15} | {'Avg Time (ms)':<15} | {'% of Total CPU':<15}")
|
| 98 |
-
print("-" * 65)
|
| 99 |
-
print(f"{'Regex Parsing':<15} | {(total_parse/sample_size)*1000:<15.2f} | {(total_parse/total_time)*100:<14.1f}%")
|
| 100 |
-
print(f"{'Query Planning':<15} | {(total_plan/sample_size)*1000:<15.2f} | {(total_plan/total_time)*100:<14.1f}%")
|
| 101 |
-
print(f"{'DB Execution':<15} | {(total_exec/sample_size)*1000:<15.2f} | {(total_exec/total_time)*100:<14.1f}%")
|
| 102 |
-
print("="*65 + "\n")
|
| 103 |
-
|
| 104 |
-
def run_benchmark_for_setting(rollouts, use_cache: bool, max_workers: int):
|
| 105 |
-
set_use_cache(use_cache)
|
| 106 |
-
set_use_schema_validation(False) # benchmark focuses on execution speed
|
| 107 |
-
|
| 108 |
-
# Sequential
|
| 109 |
-
clear_result_cache()
|
| 110 |
-
start_time = time.perf_counter()
|
| 111 |
-
execution_reward_batch_sequential(rollouts)
|
| 112 |
-
sequential_s = time.perf_counter() - start_time
|
| 113 |
-
|
| 114 |
-
# Parallel
|
| 115 |
-
clear_result_cache()
|
| 116 |
-
start_time = time.perf_counter()
|
| 117 |
-
# 1 thread per DB (recommended)
|
| 118 |
-
execution_reward_batch_parallel_by_db(rollouts, max_workers=max_workers)
|
| 119 |
-
parallel_s = time.perf_counter() - start_time
|
| 120 |
-
|
| 121 |
-
speedup = sequential_s / parallel_s if parallel_s > 0 else 0
|
| 122 |
-
|
| 123 |
-
return {
|
| 124 |
-
"sequential_s": sequential_s,
|
| 125 |
-
"parallel_s": parallel_s,
|
| 126 |
-
"speedup": speedup
|
| 127 |
-
}
|
| 128 |
-
|
| 129 |
-
def print_comparison_table(results):
|
| 130 |
-
print("="*65)
|
| 131 |
-
print(f"{'Setting':<16} | {'Sequential (s)':<14} | {'Parallel (s)':<14} | {'Speedup':<10}")
|
| 132 |
-
print("-" * 65)
|
| 133 |
-
for setting, key in [("With Cache", "with_cache"), ("Without Cache", "without_cache")]:
|
| 134 |
-
seq = results[key]['sequential_s']
|
| 135 |
-
par = results[key]['parallel_s']
|
| 136 |
-
spd = results[key]['speedup']
|
| 137 |
-
print(f"{setting:<16} | {seq:<14.4f} | {par:<14.4f} | {spd:<9.2f}x")
|
| 138 |
-
print("="*65 + "\n")
|
| 139 |
-
|
| 140 |
-
def plot_results(results, output_path: str):
|
| 141 |
-
labels = ['With Cache', 'Without Cache']
|
| 142 |
-
seq_times = [results['with_cache']['sequential_s'], results['without_cache']['sequential_s']]
|
| 143 |
-
par_times = [results['with_cache']['parallel_s'], results['without_cache']['parallel_s']]
|
| 144 |
-
|
| 145 |
-
x = np.arange(len(labels))
|
| 146 |
-
width = 0.35
|
| 147 |
-
|
| 148 |
-
fig, ax = plt.subplots(figsize=(8, 6))
|
| 149 |
-
ax.bar(x - width/2, seq_times, width, label='Sequential', color='#4C72B0')
|
| 150 |
-
ax.bar(x + width/2, par_times, width, label='Parallel', color='#DD8452')
|
| 151 |
-
|
| 152 |
-
ax.set_ylabel('Execution Time (seconds)')
|
| 153 |
-
ax.set_title('Text2SQL Reward Execution: Sequential vs Parallel')
|
| 154 |
-
ax.set_xticks(x)
|
| 155 |
-
ax.set_xticklabels(labels)
|
| 156 |
-
ax.legend()
|
| 157 |
-
|
| 158 |
-
for container in ax.containers:
|
| 159 |
-
ax.bar_label(container, fmt='%.2f', padding=3)
|
| 160 |
-
|
| 161 |
-
fig.tight_layout()
|
| 162 |
-
plt.savefig(output_path, dpi=300)
|
| 163 |
-
plt.close()
|
| 164 |
-
|
| 165 |
-
def main():
|
| 166 |
-
parser = argparse.ArgumentParser(description="Benchmark SQL Execution Reward")
|
| 167 |
-
parser.add_argument("--n", type=int, default=1000, help="Number of rollouts to benchmark")
|
| 168 |
-
parser.add_argument("--max-workers", type=int, default=20, help="Max workers for parallel execution")
|
| 169 |
-
parser.add_argument("--heavy-n", type=int, default=200_000, help="Recursive CTE upper bound (controls heaviness)")
|
| 170 |
-
parser.add_argument("--skip-profile", action="store_true", help="Skip the CPU profiling section for faster startup")
|
| 171 |
-
parser.add_argument("--profile-n", type=int, default=20, help="Number of rollouts to use for CPU profiling")
|
| 172 |
-
args = parser.parse_args()
|
| 173 |
-
|
| 174 |
-
os.makedirs(str(PROJECT_ROOT / "results"), exist_ok=True)
|
| 175 |
-
|
| 176 |
-
rollouts = generate_mock_rollouts(args.n, heavy_n=args.heavy_n)
|
| 177 |
-
|
| 178 |
-
if not args.skip_profile:
|
| 179 |
-
profile_bottlenecks(rollouts, sample_size=args.profile_n)
|
| 180 |
-
|
| 181 |
-
print("Starting Main Scalability Benchmarks...")
|
| 182 |
-
|
| 183 |
-
print("Running Experiment A: Cache ENABLED...")
|
| 184 |
-
results_with_cache = run_benchmark_for_setting(rollouts, use_cache=True, max_workers=args.max_workers)
|
| 185 |
-
|
| 186 |
-
print("Running Experiment B: Cache DISABLED...")
|
| 187 |
-
results_without_cache = run_benchmark_for_setting(rollouts, use_cache=False, max_workers=args.max_workers)
|
| 188 |
-
|
| 189 |
-
final_results = {
|
| 190 |
-
"with_cache": results_with_cache,
|
| 191 |
-
"without_cache": results_without_cache
|
| 192 |
-
}
|
| 193 |
-
|
| 194 |
-
json_path = str(PROJECT_ROOT / "results" / "task1_results.json")
|
| 195 |
-
with open(json_path, 'w') as f:
|
| 196 |
-
json.dump(final_results, f, indent=4)
|
| 197 |
-
|
| 198 |
-
print_comparison_table(final_results)
|
| 199 |
-
plot_results(final_results, str(PROJECT_ROOT / "results" / "task1_plot.png"))
|
| 200 |
-
|
| 201 |
-
if __name__ == "__main__":
|
| 202 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/benchmark_quantization.py
DELETED
|
@@ -1,108 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import argparse
|
| 4 |
-
import json
|
| 5 |
-
import os
|
| 6 |
-
import time
|
| 7 |
-
from pathlib import Path
|
| 8 |
-
from typing import Dict, List, Tuple
|
| 9 |
-
|
| 10 |
-
import numpy as np
|
| 11 |
-
import torch
|
| 12 |
-
|
| 13 |
-
from src.execution_reward import execution_reward
|
| 14 |
-
from src.prompting import encode_prompt
|
| 15 |
-
from src.quantization_utils import load_fp32_model, load_quant_artifact
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
def _load_dev_items(root: Path, n: int, seed: int = 42) -> List[dict]:
|
| 19 |
-
data = json.loads((root / "data" / "dev.json").read_text())
|
| 20 |
-
if n >= len(data):
|
| 21 |
-
return data
|
| 22 |
-
rng = np.random.default_rng(seed)
|
| 23 |
-
idxs = rng.choice(len(data), size=n, replace=False)
|
| 24 |
-
return [data[int(i)] for i in idxs]
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def _bench_variant(name: str, tok, model, items: List[dict], device: str) -> Dict[str, float]:
|
| 28 |
-
latencies: List[float] = []
|
| 29 |
-
ex = 0
|
| 30 |
-
|
| 31 |
-
# Warmup (1 item)
|
| 32 |
-
if items:
|
| 33 |
-
it = items[0]
|
| 34 |
-
_ = encode_prompt(tok, it["question"], it["db_id"], device=device, max_input_tokens=512).unsqueeze(0)
|
| 35 |
-
|
| 36 |
-
for it in items:
|
| 37 |
-
db_id = it["db_id"]
|
| 38 |
-
q = it["question"]
|
| 39 |
-
gold = it["query"]
|
| 40 |
-
db_path = str(Path("data") / "database" / db_id / f"{db_id}.sqlite")
|
| 41 |
-
|
| 42 |
-
input_ids = encode_prompt(tok, q, db_id, device=device, max_input_tokens=512).unsqueeze(0)
|
| 43 |
-
t0 = time.perf_counter()
|
| 44 |
-
out = model.generate(input_ids=input_ids, max_new_tokens=120, num_beams=8, repetition_penalty=1.2)
|
| 45 |
-
dt = time.perf_counter() - t0
|
| 46 |
-
latencies.append(dt)
|
| 47 |
-
|
| 48 |
-
pred = tok.decode(out[0], skip_special_tokens=True).strip()
|
| 49 |
-
r = execution_reward(pred, db_path, gold)
|
| 50 |
-
if float(r) >= 1.0:
|
| 51 |
-
ex += 1
|
| 52 |
-
|
| 53 |
-
p50 = float(np.percentile(latencies, 50)) if latencies else 0.0
|
| 54 |
-
p90 = float(np.percentile(latencies, 90)) if latencies else 0.0
|
| 55 |
-
mean = float(np.mean(latencies)) if latencies else 0.0
|
| 56 |
-
return {
|
| 57 |
-
"n": float(len(items)),
|
| 58 |
-
"ex": float(ex / max(len(items), 1)),
|
| 59 |
-
"lat_mean_s": mean,
|
| 60 |
-
"lat_p50_s": p50,
|
| 61 |
-
"lat_p90_s": p90,
|
| 62 |
-
}
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
def main() -> None:
|
| 66 |
-
p = argparse.ArgumentParser(description="Benchmark fp32 vs quantized artifacts (CPU-focused).")
|
| 67 |
-
p.add_argument("--base_model", default=os.environ.get("BASE_MODEL", "Salesforce/codet5-base"))
|
| 68 |
-
p.add_argument("--adapter", default="", help="Optional adapter for fp32 baseline.")
|
| 69 |
-
p.add_argument("--artifact_int8", default="", help="Artifact dir exported by scripts/quantize_export.py")
|
| 70 |
-
p.add_argument("--artifact_int8_decoder", default="", help="Artifact dir for decoder-only int8")
|
| 71 |
-
p.add_argument("--num_samples", type=int, default=100)
|
| 72 |
-
p.add_argument("--seed", type=int, default=42)
|
| 73 |
-
p.add_argument("--out", default="results/task5_quant_bench.json")
|
| 74 |
-
p.add_argument("--local_only", action="store_true")
|
| 75 |
-
args = p.parse_args()
|
| 76 |
-
|
| 77 |
-
device = "cpu"
|
| 78 |
-
root = Path(".")
|
| 79 |
-
items = _load_dev_items(root, args.num_samples, args.seed)
|
| 80 |
-
|
| 81 |
-
report: Dict[str, Dict[str, float]] = {}
|
| 82 |
-
|
| 83 |
-
tok, fp32 = load_fp32_model(
|
| 84 |
-
args.base_model,
|
| 85 |
-
adapter_path=args.adapter.strip() or None,
|
| 86 |
-
device=device,
|
| 87 |
-
local_only=args.local_only,
|
| 88 |
-
)
|
| 89 |
-
report["fp32"] = _bench_variant("fp32", tok, fp32, items, device)
|
| 90 |
-
|
| 91 |
-
if args.artifact_int8:
|
| 92 |
-
tok8, m8, _meta = load_quant_artifact(args.artifact_int8, device=device, local_only=True)
|
| 93 |
-
report["int8_dynamic"] = _bench_variant("int8_dynamic", tok8, m8, items, device)
|
| 94 |
-
|
| 95 |
-
if args.artifact_int8_decoder:
|
| 96 |
-
tokd, md, _meta = load_quant_artifact(args.artifact_int8_decoder, device=device, local_only=True)
|
| 97 |
-
report["int8_decoder_dynamic"] = _bench_variant("int8_decoder_dynamic", tokd, md, items, device)
|
| 98 |
-
|
| 99 |
-
out_path = Path(args.out)
|
| 100 |
-
out_path.parent.mkdir(parents=True, exist_ok=True)
|
| 101 |
-
out_path.write_text(json.dumps(report, indent=2))
|
| 102 |
-
print(json.dumps(report, indent=2))
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
if __name__ == "__main__":
|
| 106 |
-
torch.set_grad_enabled(False)
|
| 107 |
-
main()
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/benchmark_rollout_generation.py
DELETED
|
@@ -1,66 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import argparse
|
| 4 |
-
import json
|
| 5 |
-
import os
|
| 6 |
-
import time
|
| 7 |
-
from pathlib import Path
|
| 8 |
-
from typing import List
|
| 9 |
-
|
| 10 |
-
import numpy as np
|
| 11 |
-
import torch
|
| 12 |
-
|
| 13 |
-
from src.prompting import encode_prompt
|
| 14 |
-
from src.quantization_utils import load_fp32_model, load_quant_artifact
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
def _load_items(root: Path, n: int, seed: int = 42) -> List[dict]:
|
| 18 |
-
data = json.loads((root / "data" / "dev.json").read_text())
|
| 19 |
-
if n >= len(data):
|
| 20 |
-
return data
|
| 21 |
-
rng = np.random.default_rng(seed)
|
| 22 |
-
idxs = rng.choice(len(data), size=n, replace=False)
|
| 23 |
-
return [data[int(i)] for i in idxs]
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def _bench_generate(tok, model, items: List[dict], device: str) -> float:
|
| 27 |
-
t0 = time.perf_counter()
|
| 28 |
-
for it in items:
|
| 29 |
-
input_ids = encode_prompt(tok, it["question"], it["db_id"], device=device, max_input_tokens=512).unsqueeze(0)
|
| 30 |
-
_ = model.generate(input_ids=input_ids, max_new_tokens=64, num_beams=4)
|
| 31 |
-
return time.perf_counter() - t0
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def main() -> None:
|
| 35 |
-
p = argparse.ArgumentParser(description="Benchmark rollout generation latency for RL loops.")
|
| 36 |
-
p.add_argument("--base_model", default=os.environ.get("BASE_MODEL", "Salesforce/codet5-base"))
|
| 37 |
-
p.add_argument("--adapter", default="")
|
| 38 |
-
p.add_argument("--artifact", default="", help="Quantized artifact dir (optional).")
|
| 39 |
-
p.add_argument("--num_rollouts", type=int, default=128)
|
| 40 |
-
p.add_argument("--seed", type=int, default=42)
|
| 41 |
-
p.add_argument("--local_only", action="store_true")
|
| 42 |
-
args = p.parse_args()
|
| 43 |
-
|
| 44 |
-
device = "cpu"
|
| 45 |
-
root = Path(".")
|
| 46 |
-
items = _load_items(root, args.num_rollouts, args.seed)
|
| 47 |
-
|
| 48 |
-
tok, fp32 = load_fp32_model(
|
| 49 |
-
args.base_model,
|
| 50 |
-
adapter_path=args.adapter.strip() or None,
|
| 51 |
-
device=device,
|
| 52 |
-
local_only=args.local_only,
|
| 53 |
-
)
|
| 54 |
-
t_fp32 = _bench_generate(tok, fp32, items, device)
|
| 55 |
-
print(f"fp32: {t_fp32:.2f}s for {len(items)} rollouts ({len(items)/max(t_fp32,1e-9):.2f} rollouts/s)")
|
| 56 |
-
|
| 57 |
-
if args.artifact:
|
| 58 |
-
tokq, mq, meta = load_quant_artifact(args.artifact, device=device, local_only=True)
|
| 59 |
-
t_q = _bench_generate(tokq, mq, items, device)
|
| 60 |
-
mode = meta.get("mode", "quant")
|
| 61 |
-
print(f"{mode}: {t_q:.2f}s for {len(items)} rollouts ({len(items)/max(t_q,1e-9):.2f} rollouts/s)")
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
if __name__ == "__main__":
|
| 65 |
-
torch.set_grad_enabled(False)
|
| 66 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/error_dashboard.py
DELETED
|
@@ -1,99 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
import json
|
| 3 |
-
from collections import Counter
|
| 4 |
-
|
| 5 |
-
# ==============================
|
| 6 |
-
# LOAD LOGS
|
| 7 |
-
# ==============================
|
| 8 |
-
with open("results/error_logs.json") as f:
|
| 9 |
-
logs = json.load(f)
|
| 10 |
-
|
| 11 |
-
total_errors = len(logs)
|
| 12 |
-
|
| 13 |
-
# ==============================
|
| 14 |
-
# ERROR DISTRIBUTION
|
| 15 |
-
# ==============================
|
| 16 |
-
error_counts = Counter([e["error_type"] for e in logs])
|
| 17 |
-
|
| 18 |
-
print("\n" + "="*50)
|
| 19 |
-
print("📊 TEXT-to-SQL ERROR DASHBOARD")
|
| 20 |
-
print("="*50)
|
| 21 |
-
|
| 22 |
-
print(f"\n🔢 Total Errors Logged: {total_errors}")
|
| 23 |
-
|
| 24 |
-
print("\n📊 ERROR DISTRIBUTION:")
|
| 25 |
-
print("-"*30)
|
| 26 |
-
for k, v in error_counts.items():
|
| 27 |
-
percent = (v / total_errors) * 100
|
| 28 |
-
print(f"{k:<20} : {v:>4} ({percent:.1f}%)")
|
| 29 |
-
|
| 30 |
-
# ==============================
|
| 31 |
-
# TOP ERROR
|
| 32 |
-
# ==============================
|
| 33 |
-
top_error = error_counts.most_common(1)[0]
|
| 34 |
-
|
| 35 |
-
print("\n🔥 MOST COMMON ERROR:")
|
| 36 |
-
print("-"*30)
|
| 37 |
-
print(f"{top_error[0]} ({top_error[1]} times)")
|
| 38 |
-
|
| 39 |
-
# ==============================
|
| 40 |
-
# SQL OPERATION ANALYSIS
|
| 41 |
-
# ==============================
|
| 42 |
-
join_count = 0
|
| 43 |
-
where_count = 0
|
| 44 |
-
group_count = 0
|
| 45 |
-
order_count = 0
|
| 46 |
-
|
| 47 |
-
for e in logs:
|
| 48 |
-
sql = e["sql"].lower()
|
| 49 |
-
|
| 50 |
-
if "join" in sql:
|
| 51 |
-
join_count += 1
|
| 52 |
-
if "where" in sql:
|
| 53 |
-
where_count += 1
|
| 54 |
-
if "group by" in sql:
|
| 55 |
-
group_count += 1
|
| 56 |
-
if "order by" in sql:
|
| 57 |
-
order_count += 1
|
| 58 |
-
|
| 59 |
-
print("\n🧠 SQL OPERATION ANALYSIS:")
|
| 60 |
-
print("-"*30)
|
| 61 |
-
print(f"JOIN used in : {join_count} queries")
|
| 62 |
-
print(f"WHERE used in : {where_count} queries")
|
| 63 |
-
print(f"GROUP BY used in : {group_count} queries")
|
| 64 |
-
print(f"ORDER BY used in : {order_count} queries")
|
| 65 |
-
|
| 66 |
-
# ==============================
|
| 67 |
-
# SAMPLE ERRORS
|
| 68 |
-
# ==============================
|
| 69 |
-
print("\n🧪 SAMPLE ERROR CASES:")
|
| 70 |
-
print("-"*50)
|
| 71 |
-
|
| 72 |
-
for i, e in enumerate(logs[:3], 1):
|
| 73 |
-
print(f"\nCase {i}:")
|
| 74 |
-
print(f"Q : {e['question']}")
|
| 75 |
-
print(f"SQL : {e['sql']}")
|
| 76 |
-
print(f"Type: {e['error_type']}")
|
| 77 |
-
|
| 78 |
-
# ==============================
|
| 79 |
-
# FINAL INSIGHT
|
| 80 |
-
# ==============================
|
| 81 |
-
print("\n📌 FINAL INSIGHT:")
|
| 82 |
-
print("-"*30)
|
| 83 |
-
|
| 84 |
-
if top_error[0] == "wrong_column":
|
| 85 |
-
print("⚠️ Model struggles with column selection (schema understanding issue).")
|
| 86 |
-
|
| 87 |
-
elif top_error[0] == "wrong_table":
|
| 88 |
-
print("⚠️ Model struggles with correct table mapping.")
|
| 89 |
-
|
| 90 |
-
elif top_error[0] == "syntax_error":
|
| 91 |
-
print("⚠️ Model generates invalid SQL syntax.")
|
| 92 |
-
|
| 93 |
-
else:
|
| 94 |
-
print("⚠️ Mixed errors — needs general improvement.")
|
| 95 |
-
|
| 96 |
-
print("\n" + "="*50)
|
| 97 |
-
print("✅ DASHBOARD COMPLETE")
|
| 98 |
-
print("="*50)
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/evaluate.py
DELETED
|
@@ -1,170 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import os
|
| 4 |
-
import sqlite3
|
| 5 |
-
from contextlib import closing
|
| 6 |
-
from typing import Dict, List
|
| 7 |
-
|
| 8 |
-
import torch
|
| 9 |
-
from datasets import load_dataset
|
| 10 |
-
from peft import PeftModel
|
| 11 |
-
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
| 12 |
-
from trl import AutoModelForSeq2SeqLMWithValueHead
|
| 13 |
-
|
| 14 |
-
import sys
|
| 15 |
-
|
| 16 |
-
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 17 |
-
sys.path.append(PROJECT_ROOT)
|
| 18 |
-
from src.execution_reward import execution_reward # noqa: E402
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
BASE_MODEL = os.environ.get("BASE_MODEL", "t5-small")
|
| 22 |
-
DB_ROOT = os.path.join(PROJECT_ROOT, "data", "database")
|
| 23 |
-
|
| 24 |
-
# Prefer RL best model if present; otherwise fall back.
|
| 25 |
-
RL_DIR = os.path.join(PROJECT_ROOT, "outputs", "rlhf_text2sql", "best_model")
|
| 26 |
-
if not os.path.isdir(RL_DIR):
|
| 27 |
-
RL_DIR = os.path.join(PROJECT_ROOT, "outputs", "rlhf_text2sql")
|
| 28 |
-
|
| 29 |
-
SPLIT = "train[:100]" # quick sanity check
|
| 30 |
-
MAX_NEW_TOKENS = 128
|
| 31 |
-
|
| 32 |
-
PREFIX = "translate English to SQL:"
|
| 33 |
-
MAX_SCHEMA_CHARS = 1500
|
| 34 |
-
MAX_INPUT_TOKENS = 512
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
| 38 |
-
device = "mps" if torch.backends.mps.is_available() else "cpu"
|
| 39 |
-
print("Using device:", device)
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
def get_db_path(db_id: str) -> str:
|
| 43 |
-
return os.path.join(DB_ROOT, db_id, f"{db_id}.sqlite")
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
_SCHEMA_CACHE: Dict[str, str] = {}
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
def get_db_schema_text(db_path: str) -> str:
|
| 50 |
-
if db_path in _SCHEMA_CACHE:
|
| 51 |
-
return _SCHEMA_CACHE[db_path]
|
| 52 |
-
schema_text = ""
|
| 53 |
-
try:
|
| 54 |
-
with closing(sqlite3.connect(db_path)) as conn:
|
| 55 |
-
cur = conn.cursor()
|
| 56 |
-
tables = cur.execute(
|
| 57 |
-
"SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
|
| 58 |
-
).fetchall()
|
| 59 |
-
for (tname,) in tables:
|
| 60 |
-
cols = cur.execute(f'PRAGMA table_info(\"{tname}\")').fetchall()
|
| 61 |
-
col_names = [c[1] for c in cols if c and isinstance(c[1], str)]
|
| 62 |
-
schema_text += f"{tname}({', '.join(col_names)}) "
|
| 63 |
-
except Exception:
|
| 64 |
-
schema_text = ""
|
| 65 |
-
if len(schema_text) > MAX_SCHEMA_CHARS:
|
| 66 |
-
schema_text = schema_text[:MAX_SCHEMA_CHARS]
|
| 67 |
-
_SCHEMA_CACHE[db_path] = schema_text
|
| 68 |
-
return schema_text
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
def encode_prompt(tokenizer, question: str, schema: str) -> torch.Tensor:
|
| 72 |
-
schema = (schema or "")[:MAX_SCHEMA_CHARS]
|
| 73 |
-
prefix_schema = f"{PREFIX}\n\nSchema:\n"
|
| 74 |
-
mid = "\n\nQuestion:\n"
|
| 75 |
-
suffix = f"{question}\n\nSQL:"
|
| 76 |
-
|
| 77 |
-
prefix_ids = tokenizer.encode(prefix_schema, add_special_tokens=False)
|
| 78 |
-
schema_ids = tokenizer.encode(schema, add_special_tokens=False)
|
| 79 |
-
mid_ids = tokenizer.encode(mid, add_special_tokens=False)
|
| 80 |
-
suffix_ids = tokenizer.encode(suffix, add_special_tokens=False)
|
| 81 |
-
|
| 82 |
-
eos_id = tokenizer.eos_token_id
|
| 83 |
-
max_without_eos = MAX_INPUT_TOKENS - (1 if eos_id is not None else 0)
|
| 84 |
-
|
| 85 |
-
fixed_len = len(prefix_ids) + len(mid_ids) + len(suffix_ids)
|
| 86 |
-
if fixed_len > max_without_eos:
|
| 87 |
-
keep = max(0, max_without_eos - (len(prefix_ids) + len(mid_ids)))
|
| 88 |
-
suffix_ids = suffix_ids[:keep]
|
| 89 |
-
fixed_len = len(prefix_ids) + len(mid_ids) + len(suffix_ids)
|
| 90 |
-
|
| 91 |
-
remaining_for_schema = max_without_eos - fixed_len
|
| 92 |
-
if remaining_for_schema < 0:
|
| 93 |
-
remaining_for_schema = 0
|
| 94 |
-
schema_ids = schema_ids[:remaining_for_schema]
|
| 95 |
-
|
| 96 |
-
ids = (prefix_ids + schema_ids + mid_ids + suffix_ids)[:max_without_eos]
|
| 97 |
-
if eos_id is not None:
|
| 98 |
-
ids = ids + [eos_id]
|
| 99 |
-
|
| 100 |
-
return torch.tensor(ids, dtype=torch.long).to(device)
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
def load_model_and_tokenizer():
|
| 104 |
-
# Try loading the PPO-saved value-head model directly.
|
| 105 |
-
try:
|
| 106 |
-
tok = AutoTokenizer.from_pretrained(RL_DIR)
|
| 107 |
-
mdl = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(RL_DIR).to(device)
|
| 108 |
-
return tok, mdl
|
| 109 |
-
except Exception:
|
| 110 |
-
pass
|
| 111 |
-
|
| 112 |
-
# Fallback: treat RL_DIR as a LoRA adapter directory.
|
| 113 |
-
tok = AutoTokenizer.from_pretrained(BASE_MODEL)
|
| 114 |
-
if tok.pad_token_id is None:
|
| 115 |
-
tok.pad_token = tok.eos_token
|
| 116 |
-
base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
|
| 117 |
-
try:
|
| 118 |
-
base = PeftModel.from_pretrained(base, RL_DIR)
|
| 119 |
-
except Exception:
|
| 120 |
-
# Final fallback: use SFT adapter (if RL adapter not found)
|
| 121 |
-
sft_dir = os.path.join(PROJECT_ROOT, "checkpoints", "sft_adapter")
|
| 122 |
-
base = PeftModel.from_pretrained(base, sft_dir)
|
| 123 |
-
return tok, base
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
def main() -> None:
|
| 127 |
-
tokenizer, model = load_model_and_tokenizer()
|
| 128 |
-
model.eval()
|
| 129 |
-
|
| 130 |
-
ds = load_dataset("spider", split=SPLIT)
|
| 131 |
-
|
| 132 |
-
correct = 0
|
| 133 |
-
valid = 0
|
| 134 |
-
|
| 135 |
-
for i, ex in enumerate(ds, start=1):
|
| 136 |
-
question = ex["question"]
|
| 137 |
-
gold_sql = ex["query"]
|
| 138 |
-
db_id = ex["db_id"]
|
| 139 |
-
db_path = get_db_path(db_id)
|
| 140 |
-
schema = get_db_schema_text(db_path)
|
| 141 |
-
|
| 142 |
-
inp = encode_prompt(tokenizer, question, schema)
|
| 143 |
-
with torch.no_grad():
|
| 144 |
-
out = model.generate(
|
| 145 |
-
input_ids=inp.unsqueeze(0),
|
| 146 |
-
max_new_tokens=MAX_NEW_TOKENS,
|
| 147 |
-
do_sample=False,
|
| 148 |
-
num_beams=1,
|
| 149 |
-
pad_token_id=tokenizer.pad_token_id,
|
| 150 |
-
eos_token_id=tokenizer.eos_token_id,
|
| 151 |
-
)
|
| 152 |
-
pred_sql = tokenizer.decode(out[0], skip_special_tokens=True)
|
| 153 |
-
r = execution_reward(pred_sql, db_path, gold_sql)
|
| 154 |
-
if r > -1.0:
|
| 155 |
-
valid += 1
|
| 156 |
-
if r >= 1.0:
|
| 157 |
-
correct += 1
|
| 158 |
-
|
| 159 |
-
if i % 25 == 0:
|
| 160 |
-
print(f"Evaluated {i}/{len(ds)}")
|
| 161 |
-
|
| 162 |
-
n = len(ds)
|
| 163 |
-
print("\nRESULTS")
|
| 164 |
-
print(f"examples: {n}")
|
| 165 |
-
print(f"execution_accuracy: {correct/n:.3f}")
|
| 166 |
-
print(f"valid_sql_rate: {valid/n:.3f}")
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
if __name__ == "__main__":
|
| 170 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/plot_task2.py
DELETED
|
@@ -1,58 +0,0 @@
|
|
| 1 |
-
import matplotlib.pyplot as plt
|
| 2 |
-
import seaborn as sns
|
| 3 |
-
|
| 4 |
-
# ==========================================
|
| 5 |
-
# 1. EXTRACTED DATA FROM TERMINAL
|
| 6 |
-
# ==========================================
|
| 7 |
-
# Error Distribution Data
|
| 8 |
-
error_types = ['wrong_column', 'wrong_table', 'ambiguous_column', 'other']
|
| 9 |
-
error_counts = [61, 11, 4, 1]
|
| 10 |
-
|
| 11 |
-
# SQL Operation Analysis Data
|
| 12 |
-
sql_ops = ['WHERE', 'JOIN', 'ORDER BY', 'GROUP BY']
|
| 13 |
-
op_counts = [55, 36, 20, 14]
|
| 14 |
-
|
| 15 |
-
# ==========================================
|
| 16 |
-
# 2. SET UP THE DASHBOARD LAYOUT
|
| 17 |
-
# ==========================================
|
| 18 |
-
# Use a clean, modern aesthetic
|
| 19 |
-
sns.set_theme(style="whitegrid")
|
| 20 |
-
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
|
| 21 |
-
|
| 22 |
-
# ==========================================
|
| 23 |
-
# 3. PLOT 1: ERROR DISTRIBUTION (Horizontal Bar)
|
| 24 |
-
# ==========================================
|
| 25 |
-
sns.barplot(x=error_counts, y=error_types, ax=ax1, palette="flare")
|
| 26 |
-
ax1.set_title('Primary Cause of Failure (Total: 77 Errors)', fontsize=14, pad=15, fontweight='bold')
|
| 27 |
-
ax1.set_xlabel('Number of Queries')
|
| 28 |
-
ax1.set_ylabel('')
|
| 29 |
-
|
| 30 |
-
# Add actual numbers next to the bars
|
| 31 |
-
for i, v in enumerate(error_counts):
|
| 32 |
-
ax1.text(v + 1.5, i, f"{v}", color='#333333', va='center', fontweight='bold')
|
| 33 |
-
|
| 34 |
-
# ==========================================
|
| 35 |
-
# 4. PLOT 2: SQL OPERATIONS (Vertical Bar)
|
| 36 |
-
# ==========================================
|
| 37 |
-
sns.barplot(x=sql_ops, y=op_counts, ax=ax2, palette="crest")
|
| 38 |
-
ax2.set_title('Clauses Present in Failed Queries', fontsize=14, pad=15, fontweight='bold')
|
| 39 |
-
ax2.set_ylabel('Frequency')
|
| 40 |
-
ax2.set_xlabel('')
|
| 41 |
-
|
| 42 |
-
# Add actual numbers on top of the bars
|
| 43 |
-
for i, v in enumerate(op_counts):
|
| 44 |
-
ax2.text(i, v + 1, str(v), color='#333333', ha='center', fontweight='bold')
|
| 45 |
-
|
| 46 |
-
# ==========================================
|
| 47 |
-
# 5. RENDER AND SAVE
|
| 48 |
-
# ==========================================
|
| 49 |
-
plt.suptitle('Text-to-SQL Error Diagnostic Dashboard', fontsize=18, fontweight='heavy', y=1.05)
|
| 50 |
-
sns.despine(left=True, bottom=True) # Removes clunky borders
|
| 51 |
-
plt.tight_layout()
|
| 52 |
-
|
| 53 |
-
# Save the plot as a high-res image for your report!
|
| 54 |
-
plt.savefig('error_diagnostic_plot.png', dpi=300, bbox_inches='tight')
|
| 55 |
-
print("✅ Plot successfully saved as 'error_diagnostic_plot.png'")
|
| 56 |
-
|
| 57 |
-
# Display the plot
|
| 58 |
-
plt.show()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/plot_task3.py
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
import matplotlib.pyplot as plt
|
| 2 |
-
|
| 3 |
-
labels = ["Without", "With"]
|
| 4 |
-
constraint = [0, 88]
|
| 5 |
-
|
| 6 |
-
plt.figure()
|
| 7 |
-
plt.bar(labels, constraint)
|
| 8 |
-
|
| 9 |
-
plt.title("Constraint Satisfaction (Task 3)")
|
| 10 |
-
plt.ylabel("Percentage")
|
| 11 |
-
|
| 12 |
-
plt.savefig("task3_constraint.png")
|
| 13 |
-
plt.show()
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/plot_task3_plotly.py
DELETED
|
@@ -1,103 +0,0 @@
|
|
| 1 |
-
import plotly.graph_objects as go
|
| 2 |
-
from plotly.subplots import make_subplots
|
| 3 |
-
|
| 4 |
-
# ==========================================
|
| 5 |
-
# 1. YOUR DATA
|
| 6 |
-
# ==========================================
|
| 7 |
-
models = ['FP32 (Base)', 'INT8 Dynamic', 'INT8 Decoder-Only']
|
| 8 |
-
|
| 9 |
-
# Accuracy (multiplied by 100 for percentage)
|
| 10 |
-
accuracy = [36.0, 36.0, 38.0]
|
| 11 |
-
|
| 12 |
-
# Latency metrics
|
| 13 |
-
lat_mean = [3.11, 1.65, 1.66]
|
| 14 |
-
lat_p50 = [2.94, 1.54, 1.56]
|
| 15 |
-
lat_p90 = [4.64, 2.44, 2.48]
|
| 16 |
-
|
| 17 |
-
# ==========================================
|
| 18 |
-
# 2. SET UP THE SIDE-BY-SIDE LAYOUT
|
| 19 |
-
# ==========================================
|
| 20 |
-
fig = make_subplots(
|
| 21 |
-
rows=1, cols=2,
|
| 22 |
-
subplot_titles=(
|
| 23 |
-
"<b>Model Accuracy (Execution)</b>",
|
| 24 |
-
"<b>Inference Latency Profile</b>"
|
| 25 |
-
),
|
| 26 |
-
horizontal_spacing=0.1
|
| 27 |
-
)
|
| 28 |
-
|
| 29 |
-
# ==========================================
|
| 30 |
-
# 3. LEFT CHART: ACCURACY
|
| 31 |
-
# ==========================================
|
| 32 |
-
fig.add_trace(go.Bar(
|
| 33 |
-
x=models,
|
| 34 |
-
y=accuracy,
|
| 35 |
-
name="Execution Accuracy",
|
| 36 |
-
marker_color=['#94a3b8', '#38bdf8', '#10b981'], # Gray, Blue, Green
|
| 37 |
-
text=[f"{val:.1f}%" for val in accuracy],
|
| 38 |
-
textposition='auto',
|
| 39 |
-
textfont=dict(size=14, color='white', family="Arial Black"),
|
| 40 |
-
showlegend=False
|
| 41 |
-
), row=1, col=1)
|
| 42 |
-
|
| 43 |
-
# ==========================================
|
| 44 |
-
# 4. RIGHT CHART: LATENCY PROFILE
|
| 45 |
-
# ==========================================
|
| 46 |
-
# P50 Latency
|
| 47 |
-
fig.add_trace(go.Bar(
|
| 48 |
-
x=models, y=lat_p50,
|
| 49 |
-
name="Median (P50)",
|
| 50 |
-
marker_color="#ece80a" # Light Blue
|
| 51 |
-
), row=1, col=2)
|
| 52 |
-
|
| 53 |
-
# Mean Latency
|
| 54 |
-
fig.add_trace(go.Bar(
|
| 55 |
-
x=models, y=lat_mean,
|
| 56 |
-
name="Mean Latency",
|
| 57 |
-
marker_color="#3b4da9" # Standard Blue
|
| 58 |
-
), row=1, col=2)
|
| 59 |
-
|
| 60 |
-
# P90 Latency
|
| 61 |
-
fig.add_trace(go.Bar(
|
| 62 |
-
x=models, y=lat_p90,
|
| 63 |
-
name="90th Percentile (P90)",
|
| 64 |
-
marker_color="#d974e2" # Dark Blue
|
| 65 |
-
), row=1, col=2)
|
| 66 |
-
|
| 67 |
-
# ==========================================
|
| 68 |
-
# 5. APPLY ULTRA-MODERN STYLING
|
| 69 |
-
# ==========================================
|
| 70 |
-
fig.update_layout(
|
| 71 |
-
title=dict(
|
| 72 |
-
text="<b>Task 5: FP32 vs. INT8 Quantization Performance</b>",
|
| 73 |
-
font=dict(size=22, color='#1e293b'),
|
| 74 |
-
x=0.5
|
| 75 |
-
),
|
| 76 |
-
plot_bgcolor='white',
|
| 77 |
-
paper_bgcolor='white',
|
| 78 |
-
barmode='group',
|
| 79 |
-
legend=dict(
|
| 80 |
-
orientation="h",
|
| 81 |
-
yanchor="bottom", y=1.05,
|
| 82 |
-
xanchor="center", x=0.8,
|
| 83 |
-
bgcolor='rgba(255,255,255,0.8)'
|
| 84 |
-
),
|
| 85 |
-
font=dict(family="-apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif"),
|
| 86 |
-
margin=dict(t=120, b=60, l=60, r=40)
|
| 87 |
-
)
|
| 88 |
-
|
| 89 |
-
# Style Left Axes
|
| 90 |
-
fig.update_yaxes(title_text="<b>Accuracy (%)</b>", range=[0, 45], gridcolor='#f1f5f9', row=1, col=1)
|
| 91 |
-
fig.update_xaxes(tickfont=dict(weight='bold'), row=1, col=1)
|
| 92 |
-
|
| 93 |
-
# Style Right Axes
|
| 94 |
-
fig.update_yaxes(title_text="<b>Seconds per Query</b>", gridcolor='#f1f5f9', row=1, col=2)
|
| 95 |
-
fig.update_xaxes(tickfont=dict(weight='bold'), row=1, col=2)
|
| 96 |
-
|
| 97 |
-
# ==========================================
|
| 98 |
-
# 6. RENDER AND SAVE
|
| 99 |
-
# ==========================================
|
| 100 |
-
html_file = "task5_quantization_dashboard.html"
|
| 101 |
-
fig.write_html(html_file)
|
| 102 |
-
print(f"✅ Interactive Plotly Dashboard saved to: {html_file}")
|
| 103 |
-
fig.show()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/quantize_export.py
DELETED
|
@@ -1,86 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import argparse
|
| 4 |
-
import os
|
| 5 |
-
from pathlib import Path
|
| 6 |
-
|
| 7 |
-
import torch
|
| 8 |
-
|
| 9 |
-
from src.quantization_utils import (
|
| 10 |
-
load_bnb_quantized_model,
|
| 11 |
-
load_fp32_model,
|
| 12 |
-
quantize_dynamic_int8,
|
| 13 |
-
quantize_dynamic_int8_decoder_only,
|
| 14 |
-
save_quant_artifact,
|
| 15 |
-
)
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
def main() -> None:
|
| 19 |
-
p = argparse.ArgumentParser(description="Export quantized Seq2Seq model artifacts for CPU inference.")
|
| 20 |
-
p.add_argument("--base_model", default=os.environ.get("BASE_MODEL", "Salesforce/codet5-base"))
|
| 21 |
-
p.add_argument("--adapter", default="", help="Optional LoRA adapter directory.")
|
| 22 |
-
p.add_argument("--out_dir", required=True, help="Output directory for artifact.")
|
| 23 |
-
p.add_argument(
|
| 24 |
-
"--mode",
|
| 25 |
-
required=True,
|
| 26 |
-
choices=["fp32", "int8_dynamic", "int8_decoder_dynamic", "int8_bnb", "int4_bnb"],
|
| 27 |
-
)
|
| 28 |
-
p.add_argument("--device", default="cpu", help="cpu|cuda (bnb requires cuda)")
|
| 29 |
-
p.add_argument("--local_only", action="store_true", help="Do not hit network; use HF cache only.")
|
| 30 |
-
args = p.parse_args()
|
| 31 |
-
|
| 32 |
-
adapter = args.adapter.strip() or None
|
| 33 |
-
out_dir = Path(args.out_dir)
|
| 34 |
-
|
| 35 |
-
if args.mode == "fp32":
|
| 36 |
-
tok, model = load_fp32_model(args.base_model, adapter_path=adapter, device=args.device, local_only=args.local_only)
|
| 37 |
-
save_quant_artifact(out_dir, mode="fp32", base_model=args.base_model, adapter_path=adapter, tokenizer=tok, model=model)
|
| 38 |
-
return
|
| 39 |
-
|
| 40 |
-
if args.mode == "int8_dynamic":
|
| 41 |
-
tok, model = load_fp32_model(args.base_model, adapter_path=adapter, device="cpu", local_only=args.local_only)
|
| 42 |
-
model = quantize_dynamic_int8(model)
|
| 43 |
-
save_quant_artifact(out_dir, mode="int8_dynamic", base_model=args.base_model, adapter_path=adapter, tokenizer=tok, model=model)
|
| 44 |
-
return
|
| 45 |
-
|
| 46 |
-
if args.mode == "int8_decoder_dynamic":
|
| 47 |
-
tok, model = load_fp32_model(args.base_model, adapter_path=adapter, device="cpu", local_only=args.local_only)
|
| 48 |
-
model = quantize_dynamic_int8_decoder_only(model)
|
| 49 |
-
save_quant_artifact(
|
| 50 |
-
out_dir,
|
| 51 |
-
mode="int8_decoder_dynamic",
|
| 52 |
-
base_model=args.base_model,
|
| 53 |
-
adapter_path=adapter,
|
| 54 |
-
tokenizer=tok,
|
| 55 |
-
model=model,
|
| 56 |
-
)
|
| 57 |
-
return
|
| 58 |
-
|
| 59 |
-
if args.mode == "int8_bnb":
|
| 60 |
-
tok, model = load_bnb_quantized_model(
|
| 61 |
-
args.base_model,
|
| 62 |
-
adapter_path=adapter,
|
| 63 |
-
device=args.device,
|
| 64 |
-
local_only=args.local_only,
|
| 65 |
-
load_in_8bit=True,
|
| 66 |
-
)
|
| 67 |
-
# Note: saving bnb quantized weights in a portable way is non-trivial; we still save state_dict for reference.
|
| 68 |
-
save_quant_artifact(out_dir, mode="int8_bnb", base_model=args.base_model, adapter_path=adapter, tokenizer=tok, model=model)
|
| 69 |
-
return
|
| 70 |
-
|
| 71 |
-
if args.mode == "int4_bnb":
|
| 72 |
-
tok, model = load_bnb_quantized_model(
|
| 73 |
-
args.base_model,
|
| 74 |
-
adapter_path=adapter,
|
| 75 |
-
device=args.device,
|
| 76 |
-
local_only=args.local_only,
|
| 77 |
-
load_in_4bit=True,
|
| 78 |
-
)
|
| 79 |
-
save_quant_artifact(out_dir, mode="int4_bnb", base_model=args.base_model, adapter_path=adapter, tokenizer=tok, model=model)
|
| 80 |
-
return
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
if __name__ == "__main__":
|
| 84 |
-
torch.set_grad_enabled(False)
|
| 85 |
-
main()
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/quantized_infer_harness.py
DELETED
|
@@ -1,46 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import argparse
|
| 4 |
-
import json
|
| 5 |
-
import time
|
| 6 |
-
from pathlib import Path
|
| 7 |
-
|
| 8 |
-
from src.quantized_text2sql_engine import QuantizedText2SQLEngine
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def main() -> None:
|
| 12 |
-
p = argparse.ArgumentParser(description="Production-style inference harness for quantized artifacts.")
|
| 13 |
-
p.add_argument("--artifact", required=True, help="Quant artifact dir from scripts/quantize_export.py")
|
| 14 |
-
p.add_argument("--num_samples", type=int, default=128)
|
| 15 |
-
p.add_argument("--out", default="results/task5_quant_infer.json")
|
| 16 |
-
args = p.parse_args()
|
| 17 |
-
|
| 18 |
-
root = Path(".")
|
| 19 |
-
dev = json.loads((root / "data" / "dev.json").read_text())
|
| 20 |
-
dev = dev[: args.num_samples]
|
| 21 |
-
|
| 22 |
-
engine = QuantizedText2SQLEngine(args.artifact, device="cpu")
|
| 23 |
-
pairs = [(x["question"], x["db_id"]) for x in dev]
|
| 24 |
-
|
| 25 |
-
t0 = time.perf_counter()
|
| 26 |
-
results = engine.ask_batch_execute(pairs)
|
| 27 |
-
dt = time.perf_counter() - t0
|
| 28 |
-
|
| 29 |
-
out = {
|
| 30 |
-
"n": len(results),
|
| 31 |
-
"seconds": dt,
|
| 32 |
-
"qps": len(results) / max(dt, 1e-9),
|
| 33 |
-
"artifact": args.artifact,
|
| 34 |
-
"meta": engine.meta,
|
| 35 |
-
"results": results[:10], # sample
|
| 36 |
-
}
|
| 37 |
-
|
| 38 |
-
out_path = Path(args.out)
|
| 39 |
-
out_path.parent.mkdir(parents=True, exist_ok=True)
|
| 40 |
-
out_path.write_text(json.dumps(out, indent=2))
|
| 41 |
-
print(json.dumps(out, indent=2))
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
if __name__ == "__main__":
|
| 45 |
-
main()
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/constrained_decoding.py
DELETED
|
@@ -1,1058 +0,0 @@
|
|
| 1 |
-
# from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
# import re
|
| 4 |
-
# import threading
|
| 5 |
-
# from dataclasses import dataclass
|
| 6 |
-
# from typing import Dict, Iterable, List, Optional, Sequence, Set
|
| 7 |
-
|
| 8 |
-
# import torch
|
| 9 |
-
# from transformers.generation.logits_process import LogitsProcessor
|
| 10 |
-
|
| 11 |
-
# from schema_constraints import ConstraintGraph, build_constraint_graph
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
# def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
|
| 15 |
-
# s = re.sub(r"\s+", " ", prefix_text.lower())
|
| 16 |
-
# last_from = s.rfind(" from ")
|
| 17 |
-
# last_join = s.rfind(" join ")
|
| 18 |
-
# last_select = s.rfind(" select ")
|
| 19 |
-
# last_where = s.rfind(" where ")
|
| 20 |
-
# last_on = s.rfind(" on ")
|
| 21 |
-
# last_group = s.rfind(" group by ")
|
| 22 |
-
# last_order = s.rfind(" order by ")
|
| 23 |
-
# last_having = s.rfind(" having ")
|
| 24 |
-
|
| 25 |
-
# last_table_kw = max(last_from, last_join)
|
| 26 |
-
# last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
|
| 27 |
-
|
| 28 |
-
# if last_table_kw < 0 and last_col_kw < 0:
|
| 29 |
-
# return None
|
| 30 |
-
# if last_table_kw > last_col_kw:
|
| 31 |
-
# return "table"
|
| 32 |
-
# if last_col_kw > last_table_kw:
|
| 33 |
-
# return "column"
|
| 34 |
-
# return None
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
# class _TrieNode:
|
| 38 |
-
# __slots__ = ("children", "terminal")
|
| 39 |
-
|
| 40 |
-
# def __init__(self) -> None:
|
| 41 |
-
# self.children: Dict[int, _TrieNode] = {}
|
| 42 |
-
# self.terminal: bool = False
|
| 43 |
-
|
| 44 |
-
# def insert(self, token_ids: Sequence[int]) -> None:
|
| 45 |
-
# node: _TrieNode = self
|
| 46 |
-
# for tid in token_ids:
|
| 47 |
-
# tid_i = int(tid)
|
| 48 |
-
# nxt = node.children.get(tid_i)
|
| 49 |
-
# if nxt is None:
|
| 50 |
-
# nxt = _TrieNode()
|
| 51 |
-
# node.children[tid_i] = nxt
|
| 52 |
-
# node = nxt
|
| 53 |
-
# node.terminal = True
|
| 54 |
-
|
| 55 |
-
# def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
|
| 56 |
-
# node: _TrieNode = self
|
| 57 |
-
# for tid in prefix:
|
| 58 |
-
# node = node.children.get(int(tid)) # type: ignore[assignment]
|
| 59 |
-
# if node is None:
|
| 60 |
-
# return None
|
| 61 |
-
# return node
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
# def _encode_identifier(tokenizer, name: str) -> List[int]:
|
| 65 |
-
# # Leading space encourages word-start markers (e.g. "Ġ" in RoBERTa BPE).
|
| 66 |
-
# return tokenizer.encode(" " + name, add_special_tokens=False)
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
# def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
|
| 70 |
-
# trie = _TrieNode()
|
| 71 |
-
# for n in names:
|
| 72 |
-
# if not n:
|
| 73 |
-
# continue
|
| 74 |
-
# try:
|
| 75 |
-
# ids = _encode_identifier(tokenizer, n)
|
| 76 |
-
# except Exception:
|
| 77 |
-
# continue
|
| 78 |
-
# if ids:
|
| 79 |
-
# trie.insert(ids)
|
| 80 |
-
# return trie
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
# def _allow_always_token_ids(tokenizer) -> torch.Tensor:
|
| 84 |
-
# # Allow common delimiters so the model can end an identifier.
|
| 85 |
-
# toks = [",", ")", "(", "\n", ".", ";"]
|
| 86 |
-
# ids: Set[int] = set()
|
| 87 |
-
# for t in toks:
|
| 88 |
-
# try:
|
| 89 |
-
# for tid in tokenizer.encode(t, add_special_tokens=False):
|
| 90 |
-
# ids.add(int(tid))
|
| 91 |
-
# except Exception:
|
| 92 |
-
# continue
|
| 93 |
-
# return torch.tensor(sorted(ids), dtype=torch.long)
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
# @dataclass
|
| 97 |
-
# class _PerDbTokenSets:
|
| 98 |
-
# fp: str
|
| 99 |
-
# table_trie: _TrieNode
|
| 100 |
-
# column_trie: _TrieNode
|
| 101 |
-
# allow_always: torch.Tensor
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
# _DB_TOKENSET_LOCK = threading.Lock()
|
| 105 |
-
# _DB_TOKENSETS: Dict[str, _PerDbTokenSets] = {}
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
# def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
|
| 109 |
-
# with _DB_TOKENSET_LOCK:
|
| 110 |
-
# cached = _DB_TOKENSETS.get(graph.db_path)
|
| 111 |
-
# if cached is not None and cached.fp == graph.fingerprint:
|
| 112 |
-
# return cached
|
| 113 |
-
|
| 114 |
-
# out = _PerDbTokenSets(
|
| 115 |
-
# fp=graph.fingerprint,
|
| 116 |
-
# table_trie=_build_trie(tokenizer, graph.tables),
|
| 117 |
-
# column_trie=_build_trie(tokenizer, graph.all_columns),
|
| 118 |
-
# allow_always=_allow_always_token_ids(tokenizer),
|
| 119 |
-
# )
|
| 120 |
-
# with _DB_TOKENSET_LOCK:
|
| 121 |
-
# _DB_TOKENSETS[graph.db_path] = out
|
| 122 |
-
# return out
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
# class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 126 |
-
# """
|
| 127 |
-
# Schema-aware constrained decoding per item in the generation batch.
|
| 128 |
-
# Uses a tokenizer-based trie so multi-token identifiers can be constrained.
|
| 129 |
-
# """
|
| 130 |
-
|
| 131 |
-
# def __init__(self, tokenizer, db_paths: Sequence[str], *, max_prefix_tokens: int = 48):
|
| 132 |
-
# self.tokenizer = tokenizer
|
| 133 |
-
# self.db_paths = list(db_paths)
|
| 134 |
-
# self.max_prefix_tokens = int(max_prefix_tokens)
|
| 135 |
-
|
| 136 |
-
# self._graphs = [build_constraint_graph(p) for p in self.db_paths]
|
| 137 |
-
# self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
|
| 138 |
-
|
| 139 |
-
# def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
| 140 |
-
# if input_ids.dim() != 2 or scores.dim() != 2:
|
| 141 |
-
# return scores
|
| 142 |
-
|
| 143 |
-
# batch = input_ids.size(0)
|
| 144 |
-
# if batch != len(self._graphs):
|
| 145 |
-
# return scores
|
| 146 |
-
|
| 147 |
-
# for i in range(batch):
|
| 148 |
-
# tail_ids = input_ids[i, -self.max_prefix_tokens :].tolist()
|
| 149 |
-
# prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
|
| 150 |
-
# expected = _infer_expected_identifier(prefix_text)
|
| 151 |
-
# if expected is None:
|
| 152 |
-
# continue
|
| 153 |
-
|
| 154 |
-
# if expected == "table":
|
| 155 |
-
# m = re.search(r"(?:from|join)\s+([A-Za-z_][A-Za-z0-9_]*)$", prefix_text, flags=re.I)
|
| 156 |
-
# partial = m.group(1) if m else None
|
| 157 |
-
# if partial is None and not re.search(r"(?:from|join)\s*$", prefix_text, flags=re.I):
|
| 158 |
-
# continue
|
| 159 |
-
# trie = self._token_sets[i].table_trie
|
| 160 |
-
# else:
|
| 161 |
-
# m = re.search(
|
| 162 |
-
# r"(?:select|where|on|group by|order by|having)\s+([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?)$",
|
| 163 |
-
# prefix_text,
|
| 164 |
-
# flags=re.I,
|
| 165 |
-
# )
|
| 166 |
-
# partial = m.group(1) if m else None
|
| 167 |
-
# if partial is None and not re.search(
|
| 168 |
-
# r"(?:select|where|on|group by|order by|having)\s*$", prefix_text, flags=re.I
|
| 169 |
-
# ):
|
| 170 |
-
# continue
|
| 171 |
-
# trie = self._token_sets[i].column_trie
|
| 172 |
-
|
| 173 |
-
# if not partial:
|
| 174 |
-
# prefix_token_ids: List[int] = []
|
| 175 |
-
# else:
|
| 176 |
-
# try:
|
| 177 |
-
# prefix_token_ids = _encode_identifier(self.tokenizer, partial)
|
| 178 |
-
# except Exception:
|
| 179 |
-
# continue
|
| 180 |
-
|
| 181 |
-
# node = trie.walk(prefix_token_ids)
|
| 182 |
-
# if node is None or node.terminal:
|
| 183 |
-
# continue
|
| 184 |
-
|
| 185 |
-
# allowed_next = sorted(node.children.keys())
|
| 186 |
-
# if not allowed_next:
|
| 187 |
-
# continue
|
| 188 |
-
|
| 189 |
-
# allowed_next_t = torch.tensor(allowed_next, dtype=torch.long, device=scores.device)
|
| 190 |
-
# allow_always = self._token_sets[i].allow_always.to(scores.device)
|
| 191 |
-
# keep = torch.cat([allowed_next_t, allow_always]) if allow_always.numel() else allowed_next_t
|
| 192 |
-
|
| 193 |
-
# kept_scores = scores[i, keep].clone()
|
| 194 |
-
# scores[i, :] = -float("inf")
|
| 195 |
-
# scores[i, keep] = kept_scores
|
| 196 |
-
|
| 197 |
-
# return scores
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
# # Backwards-compatible names used elsewhere in the repo.
|
| 201 |
-
# class SchemaConstraintGraph:
|
| 202 |
-
# def __init__(self, db_path: str):
|
| 203 |
-
# self._graph = build_constraint_graph(db_path)
|
| 204 |
-
# self.tables = sorted(self._graph.tables)
|
| 205 |
-
# self.columns = sorted(self._graph.all_columns)
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
# class SchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 209 |
-
# def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
|
| 210 |
-
# self._proc = BatchSchemaConstrainedLogitsProcessor(tokenizer, [schema_graph._graph.db_path])
|
| 211 |
-
|
| 212 |
-
# def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
| 213 |
-
# return self._proc(input_ids, scores)
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
# from __future__ import annotations
|
| 219 |
-
|
| 220 |
-
# import re
|
| 221 |
-
# import threading
|
| 222 |
-
# from dataclasses import dataclass
|
| 223 |
-
# from typing import Dict, Iterable, List, Optional, Sequence, Set
|
| 224 |
-
|
| 225 |
-
# import torch
|
| 226 |
-
# from transformers.generation.logits_process import LogitsProcessor
|
| 227 |
-
|
| 228 |
-
# from schema_constraints import ConstraintGraph, build_constraint_graph
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
# # =========================================================
|
| 232 |
-
# # 🔍 IDENTIFIER TYPE DETECTION
|
| 233 |
-
# # =========================================================
|
| 234 |
-
# def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
|
| 235 |
-
# s = re.sub(r"\s+", " ", prefix_text.lower())
|
| 236 |
-
|
| 237 |
-
# last_from = s.rfind(" from ")
|
| 238 |
-
# last_join = s.rfind(" join ")
|
| 239 |
-
# last_select = s.rfind(" select ")
|
| 240 |
-
# last_where = s.rfind(" where ")
|
| 241 |
-
# last_on = s.rfind(" on ")
|
| 242 |
-
# last_group = s.rfind(" group by ")
|
| 243 |
-
# last_order = s.rfind(" order by ")
|
| 244 |
-
# last_having = s.rfind(" having ")
|
| 245 |
-
|
| 246 |
-
# last_table_kw = max(last_from, last_join)
|
| 247 |
-
# last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
|
| 248 |
-
|
| 249 |
-
# if last_table_kw < 0 and last_col_kw < 0:
|
| 250 |
-
# return None
|
| 251 |
-
# if last_table_kw > last_col_kw:
|
| 252 |
-
# return "table"
|
| 253 |
-
# if last_col_kw > last_table_kw:
|
| 254 |
-
# return "column"
|
| 255 |
-
# return None
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
# # =========================================================
|
| 259 |
-
# # 🌳 TRIE STRUCTURE
|
| 260 |
-
# # =========================================================
|
| 261 |
-
# class _TrieNode:
|
| 262 |
-
# __slots__ = ("children", "terminal")
|
| 263 |
-
|
| 264 |
-
# def __init__(self) -> None:
|
| 265 |
-
# self.children: Dict[int, _TrieNode] = {}
|
| 266 |
-
# self.terminal: bool = False
|
| 267 |
-
|
| 268 |
-
# def insert(self, token_ids: Sequence[int]) -> None:
|
| 269 |
-
# node = self
|
| 270 |
-
# for tid in token_ids:
|
| 271 |
-
# tid = int(tid)
|
| 272 |
-
# if tid not in node.children:
|
| 273 |
-
# node.children[tid] = _TrieNode()
|
| 274 |
-
# node = node.children[tid]
|
| 275 |
-
# node.terminal = True
|
| 276 |
-
|
| 277 |
-
# def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
|
| 278 |
-
# node = self
|
| 279 |
-
# for tid in prefix:
|
| 280 |
-
# node = node.children.get(int(tid))
|
| 281 |
-
# if node is None:
|
| 282 |
-
# return None
|
| 283 |
-
# return node
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
# # =========================================================
|
| 287 |
-
# # 🔤 TOKEN ENCODING
|
| 288 |
-
# # =========================================================
|
| 289 |
-
# def _encode_identifier(tokenizer, name: str) -> List[int]:
|
| 290 |
-
# return tokenizer.encode(" " + name, add_special_tokens=False)
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
# def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
|
| 294 |
-
# trie = _TrieNode()
|
| 295 |
-
# for name in names:
|
| 296 |
-
# try:
|
| 297 |
-
# ids = _encode_identifier(tokenizer, name)
|
| 298 |
-
# if ids:
|
| 299 |
-
# trie.insert(ids)
|
| 300 |
-
# except Exception:
|
| 301 |
-
# continue
|
| 302 |
-
# return trie
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
# def _allow_always_token_ids(tokenizer) -> torch.Tensor:
|
| 306 |
-
# tokens = [",", ")", "(", ".", ";", "\n"]
|
| 307 |
-
# ids: Set[int] = set()
|
| 308 |
-
|
| 309 |
-
# for t in tokens:
|
| 310 |
-
# try:
|
| 311 |
-
# ids.update(tokenizer.encode(t, add_special_tokens=False))
|
| 312 |
-
# except:
|
| 313 |
-
# pass
|
| 314 |
-
|
| 315 |
-
# return torch.tensor(sorted(ids), dtype=torch.long)
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
# # =========================================================
|
| 319 |
-
# # 📦 PER-DB CACHE
|
| 320 |
-
# # =========================================================
|
| 321 |
-
# @dataclass
|
| 322 |
-
# class _PerDbTokenSets:
|
| 323 |
-
# fp: str
|
| 324 |
-
# table_trie: _TrieNode
|
| 325 |
-
# column_trie: _TrieNode
|
| 326 |
-
# allow_always: torch.Tensor
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
# _DB_CACHE: Dict[str, _PerDbTokenSets] = {}
|
| 330 |
-
# _DB_LOCK = threading.Lock()
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
# def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
|
| 334 |
-
# with _DB_LOCK:
|
| 335 |
-
# cached = _DB_CACHE.get(graph.db_path)
|
| 336 |
-
# if cached and cached.fp == graph.fingerprint:
|
| 337 |
-
# return cached
|
| 338 |
-
|
| 339 |
-
# obj = _PerDbTokenSets(
|
| 340 |
-
# fp=graph.fingerprint,
|
| 341 |
-
# table_trie=_build_trie(tokenizer, graph.tables),
|
| 342 |
-
# column_trie=_build_trie(tokenizer, graph.all_columns),
|
| 343 |
-
# allow_always=_allow_always_token_ids(tokenizer),
|
| 344 |
-
# )
|
| 345 |
-
|
| 346 |
-
# with _DB_LOCK:
|
| 347 |
-
# _DB_CACHE[graph.db_path] = obj
|
| 348 |
-
|
| 349 |
-
# return obj
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
# # =========================================================
|
| 353 |
-
# # 🚀 MAIN LOGITS PROCESSOR
|
| 354 |
-
# # =========================================================
|
| 355 |
-
# class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 356 |
-
# def __init__(self, tokenizer, db_paths: Sequence[str], max_prefix_tokens: int = 48):
|
| 357 |
-
# self.tokenizer = tokenizer
|
| 358 |
-
# self.db_paths = list(db_paths)
|
| 359 |
-
# self.max_prefix_tokens = max_prefix_tokens
|
| 360 |
-
|
| 361 |
-
# self._graphs = [build_constraint_graph(p) for p in db_paths]
|
| 362 |
-
# self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
|
| 363 |
-
|
| 364 |
-
# # 📊 Metrics (IMPORTANT FOR REPORT)
|
| 365 |
-
# self.total_steps = 0
|
| 366 |
-
# self.constrained_steps = 0
|
| 367 |
-
|
| 368 |
-
# def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
| 369 |
-
# batch = input_ids.size(0)
|
| 370 |
-
|
| 371 |
-
# for i in range(batch):
|
| 372 |
-
# self.total_steps += 1
|
| 373 |
-
|
| 374 |
-
# tail_ids = input_ids[i, -self.max_prefix_tokens:].tolist()
|
| 375 |
-
# prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
|
| 376 |
-
|
| 377 |
-
# expected = _infer_expected_identifier(prefix_text)
|
| 378 |
-
# if expected is None:
|
| 379 |
-
# continue
|
| 380 |
-
|
| 381 |
-
# self.constrained_steps += 1
|
| 382 |
-
|
| 383 |
-
# # =========================
|
| 384 |
-
# # SELECT TRIE
|
| 385 |
-
# # =========================
|
| 386 |
-
# if expected == "table":
|
| 387 |
-
# trie = self._token_sets[i].table_trie
|
| 388 |
-
# else:
|
| 389 |
-
# trie = self._token_sets[i].column_trie
|
| 390 |
-
|
| 391 |
-
# # =========================
|
| 392 |
-
# # PARTIAL TOKEN MATCH
|
| 393 |
-
# # =========================
|
| 394 |
-
# match = re.search(r"([A-Za-z_][A-Za-z0-9_]*)$", prefix_text)
|
| 395 |
-
# partial = match.group(1) if match else ""
|
| 396 |
-
|
| 397 |
-
# try:
|
| 398 |
-
# prefix_ids = _encode_identifier(self.tokenizer, partial) if partial else []
|
| 399 |
-
# except:
|
| 400 |
-
# continue
|
| 401 |
-
|
| 402 |
-
# node = trie.walk(prefix_ids)
|
| 403 |
-
# if node is None or node.terminal:
|
| 404 |
-
# continue
|
| 405 |
-
|
| 406 |
-
# allowed_next = list(node.children.keys())
|
| 407 |
-
# if not allowed_next:
|
| 408 |
-
# continue
|
| 409 |
-
|
| 410 |
-
# allowed_next = torch.tensor(allowed_next, device=scores.device)
|
| 411 |
-
# allow_always = self._token_sets[i].allow_always.to(scores.device)
|
| 412 |
-
|
| 413 |
-
# keep = torch.cat([allowed_next, allow_always])
|
| 414 |
-
|
| 415 |
-
# kept_scores = scores[i, keep].clone()
|
| 416 |
-
# scores[i, :] = -float("inf")
|
| 417 |
-
# scores[i, keep] = kept_scores
|
| 418 |
-
|
| 419 |
-
# return scores
|
| 420 |
-
|
| 421 |
-
# # =========================================================
|
| 422 |
-
# # 📊 METRICS FOR REPORT
|
| 423 |
-
# # =========================================================
|
| 424 |
-
# def get_constraint_stats(self):
|
| 425 |
-
# if self.total_steps == 0:
|
| 426 |
-
# return 0
|
| 427 |
-
# return self.constrained_steps / self.total_steps
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
# # =========================================================
|
| 431 |
-
# # 🔁 BACKWARD COMPATIBILITY
|
| 432 |
-
# # =========================================================
|
| 433 |
-
# class SchemaConstraintGraph:
|
| 434 |
-
# def __init__(self, db_path: str):
|
| 435 |
-
# self._graph = build_constraint_graph(db_path)
|
| 436 |
-
# self.tables = sorted(self._graph.tables)
|
| 437 |
-
# self.columns = sorted(self._graph.all_columns)
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
# class SchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 441 |
-
# def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
|
| 442 |
-
# self.proc = BatchSchemaConstrainedLogitsProcessor(
|
| 443 |
-
# tokenizer, [schema_graph._graph.db_path]
|
| 444 |
-
# )
|
| 445 |
-
|
| 446 |
-
# def __call__(self, input_ids, scores):
|
| 447 |
-
# return self.proc(input_ids, scores)
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
# from __future__ import annotations
|
| 455 |
-
|
| 456 |
-
# import re
|
| 457 |
-
# import threading
|
| 458 |
-
# from dataclasses import dataclass
|
| 459 |
-
# from typing import Dict, Iterable, List, Optional, Sequence, Set
|
| 460 |
-
|
| 461 |
-
# import torch
|
| 462 |
-
# from transformers.generation.logits_process import LogitsProcessor
|
| 463 |
-
|
| 464 |
-
# from schema_constraints import ConstraintGraph, build_constraint_graph
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
# def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
|
| 468 |
-
# s = re.sub(r"\s+", " ", prefix_text.lower())
|
| 469 |
-
# last_from = s.rfind(" from ")
|
| 470 |
-
# last_join = s.rfind(" join ")
|
| 471 |
-
# last_select = s.rfind(" select ")
|
| 472 |
-
# last_where = s.rfind(" where ")
|
| 473 |
-
# last_on = s.rfind(" on ")
|
| 474 |
-
# last_group = s.rfind(" group by ")
|
| 475 |
-
# last_order = s.rfind(" order by ")
|
| 476 |
-
# last_having = s.rfind(" having ")
|
| 477 |
-
|
| 478 |
-
# last_table_kw = max(last_from, last_join)
|
| 479 |
-
# last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
|
| 480 |
-
|
| 481 |
-
# if last_table_kw < 0 and last_col_kw < 0:
|
| 482 |
-
# return None
|
| 483 |
-
# if last_table_kw > last_col_kw:
|
| 484 |
-
# return "table"
|
| 485 |
-
# if last_col_kw > last_table_kw:
|
| 486 |
-
# return "column"
|
| 487 |
-
# return None
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
# class _TrieNode:
|
| 491 |
-
# __slots__ = ("children", "terminal")
|
| 492 |
-
|
| 493 |
-
# def __init__(self) -> None:
|
| 494 |
-
# self.children: Dict[int, _TrieNode] = {}
|
| 495 |
-
# self.terminal: bool = False
|
| 496 |
-
|
| 497 |
-
# def insert(self, token_ids: Sequence[int]) -> None:
|
| 498 |
-
# node: _TrieNode = self
|
| 499 |
-
# for tid in token_ids:
|
| 500 |
-
# tid_i = int(tid)
|
| 501 |
-
# nxt = node.children.get(tid_i)
|
| 502 |
-
# if nxt is None:
|
| 503 |
-
# nxt = _TrieNode()
|
| 504 |
-
# node.children[tid_i] = nxt
|
| 505 |
-
# node = nxt
|
| 506 |
-
# node.terminal = True
|
| 507 |
-
|
| 508 |
-
# def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
|
| 509 |
-
# node: _TrieNode = self
|
| 510 |
-
# for tid in prefix:
|
| 511 |
-
# node = node.children.get(int(tid)) # type: ignore[assignment]
|
| 512 |
-
# if node is None:
|
| 513 |
-
# return None
|
| 514 |
-
# return node
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
# def _encode_identifier(tokenizer, name: str) -> List[int]:
|
| 518 |
-
# # Leading space encourages word-start markers (e.g. "Ġ" in RoBERTa BPE).
|
| 519 |
-
# return tokenizer.encode(" " + name, add_special_tokens=False)
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
# def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
|
| 523 |
-
# trie = _TrieNode()
|
| 524 |
-
# for n in names:
|
| 525 |
-
# if not n:
|
| 526 |
-
# continue
|
| 527 |
-
# try:
|
| 528 |
-
# ids = _encode_identifier(tokenizer, n)
|
| 529 |
-
# except Exception:
|
| 530 |
-
# continue
|
| 531 |
-
# if ids:
|
| 532 |
-
# trie.insert(ids)
|
| 533 |
-
# return trie
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
# def _allow_always_token_ids(tokenizer) -> torch.Tensor:
|
| 537 |
-
# # Allow common delimiters so the model can end an identifier.
|
| 538 |
-
# toks = [",", ")", "(", "\n", ".", ";"]
|
| 539 |
-
# ids: Set[int] = set()
|
| 540 |
-
# for t in toks:
|
| 541 |
-
# try:
|
| 542 |
-
# for tid in tokenizer.encode(t, add_special_tokens=False):
|
| 543 |
-
# ids.add(int(tid))
|
| 544 |
-
# except Exception:
|
| 545 |
-
# continue
|
| 546 |
-
# return torch.tensor(sorted(ids), dtype=torch.long)
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
# @dataclass
|
| 550 |
-
# class _PerDbTokenSets:
|
| 551 |
-
# fp: str
|
| 552 |
-
# table_trie: _TrieNode
|
| 553 |
-
# column_trie: _TrieNode
|
| 554 |
-
# allow_always: torch.Tensor
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
# _DB_TOKENSET_LOCK = threading.Lock()
|
| 558 |
-
# _DB_TOKENSETS: Dict[str, _PerDbTokenSets] = {}
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
# def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
|
| 562 |
-
# with _DB_TOKENSET_LOCK:
|
| 563 |
-
# cached = _DB_TOKENSETS.get(graph.db_path)
|
| 564 |
-
# if cached is not None and cached.fp == graph.fingerprint:
|
| 565 |
-
# return cached
|
| 566 |
-
|
| 567 |
-
# out = _PerDbTokenSets(
|
| 568 |
-
# fp=graph.fingerprint,
|
| 569 |
-
# table_trie=_build_trie(tokenizer, graph.tables),
|
| 570 |
-
# column_trie=_build_trie(tokenizer, graph.all_columns),
|
| 571 |
-
# allow_always=_allow_always_token_ids(tokenizer),
|
| 572 |
-
# )
|
| 573 |
-
# with _DB_TOKENSET_LOCK:
|
| 574 |
-
# _DB_TOKENSETS[graph.db_path] = out
|
| 575 |
-
# return out
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
# class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 579 |
-
# """
|
| 580 |
-
# Schema-aware constrained decoding per item in the generation batch.
|
| 581 |
-
# Uses a tokenizer-based trie so multi-token identifiers can be constrained.
|
| 582 |
-
# """
|
| 583 |
-
|
| 584 |
-
# def __init__(self, tokenizer, db_paths: Sequence[str], *, max_prefix_tokens: int = 48):
|
| 585 |
-
# self.tokenizer = tokenizer
|
| 586 |
-
# self.db_paths = list(db_paths)
|
| 587 |
-
# self.max_prefix_tokens = int(max_prefix_tokens)
|
| 588 |
-
|
| 589 |
-
# self._graphs = [build_constraint_graph(p) for p in self.db_paths]
|
| 590 |
-
# self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
|
| 591 |
-
|
| 592 |
-
# def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
| 593 |
-
# if input_ids.dim() != 2 or scores.dim() != 2:
|
| 594 |
-
# return scores
|
| 595 |
-
|
| 596 |
-
# batch = input_ids.size(0)
|
| 597 |
-
# if batch != len(self._graphs):
|
| 598 |
-
# return scores
|
| 599 |
-
|
| 600 |
-
# for i in range(batch):
|
| 601 |
-
# tail_ids = input_ids[i, -self.max_prefix_tokens :].tolist()
|
| 602 |
-
# prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
|
| 603 |
-
# expected = _infer_expected_identifier(prefix_text)
|
| 604 |
-
# if expected is None:
|
| 605 |
-
# continue
|
| 606 |
-
|
| 607 |
-
# if expected == "table":
|
| 608 |
-
# m = re.search(r"(?:from|join)\s+([A-Za-z_][A-Za-z0-9_]*)$", prefix_text, flags=re.I)
|
| 609 |
-
# partial = m.group(1) if m else None
|
| 610 |
-
# if partial is None and not re.search(r"(?:from|join)\s*$", prefix_text, flags=re.I):
|
| 611 |
-
# continue
|
| 612 |
-
# trie = self._token_sets[i].table_trie
|
| 613 |
-
# else:
|
| 614 |
-
# m = re.search(
|
| 615 |
-
# r"(?:select|where|on|group by|order by|having)\s+([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?)$",
|
| 616 |
-
# prefix_text,
|
| 617 |
-
# flags=re.I,
|
| 618 |
-
# )
|
| 619 |
-
# partial = m.group(1) if m else None
|
| 620 |
-
# if partial is None and not re.search(
|
| 621 |
-
# r"(?:select|where|on|group by|order by|having)\s*$", prefix_text, flags=re.I
|
| 622 |
-
# ):
|
| 623 |
-
# continue
|
| 624 |
-
# trie = self._token_sets[i].column_trie
|
| 625 |
-
|
| 626 |
-
# if not partial:
|
| 627 |
-
# prefix_token_ids: List[int] = []
|
| 628 |
-
# else:
|
| 629 |
-
# try:
|
| 630 |
-
# prefix_token_ids = _encode_identifier(self.tokenizer, partial)
|
| 631 |
-
# except Exception:
|
| 632 |
-
# continue
|
| 633 |
-
|
| 634 |
-
# node = trie.walk(prefix_token_ids)
|
| 635 |
-
# if node is None or node.terminal:
|
| 636 |
-
# continue
|
| 637 |
-
|
| 638 |
-
# allowed_next = sorted(node.children.keys())
|
| 639 |
-
# if not allowed_next:
|
| 640 |
-
# continue
|
| 641 |
-
|
| 642 |
-
# allowed_next_t = torch.tensor(allowed_next, dtype=torch.long, device=scores.device)
|
| 643 |
-
# allow_always = self._token_sets[i].allow_always.to(scores.device)
|
| 644 |
-
# keep = torch.cat([allowed_next_t, allow_always]) if allow_always.numel() else allowed_next_t
|
| 645 |
-
|
| 646 |
-
# kept_scores = scores[i, keep].clone()
|
| 647 |
-
# scores[i, :] = -float("inf")
|
| 648 |
-
# scores[i, keep] = kept_scores
|
| 649 |
-
|
| 650 |
-
# return scores
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
# # Backwards-compatible names used elsewhere in the repo.
|
| 654 |
-
# class SchemaConstraintGraph:
|
| 655 |
-
# def __init__(self, db_path: str):
|
| 656 |
-
# self._graph = build_constraint_graph(db_path)
|
| 657 |
-
# self.tables = sorted(self._graph.tables)
|
| 658 |
-
# self.columns = sorted(self._graph.all_columns)
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
# class SchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 662 |
-
# def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
|
| 663 |
-
# self._proc = BatchSchemaConstrainedLogitsProcessor(tokenizer, [schema_graph._graph.db_path])
|
| 664 |
-
|
| 665 |
-
# def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
| 666 |
-
# return self._proc(input_ids, scores)
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
# from __future__ import annotations
|
| 672 |
-
|
| 673 |
-
# import re
|
| 674 |
-
# import threading
|
| 675 |
-
# from dataclasses import dataclass
|
| 676 |
-
# from typing import Dict, Iterable, List, Optional, Sequence, Set
|
| 677 |
-
|
| 678 |
-
# import torch
|
| 679 |
-
# from transformers.generation.logits_process import LogitsProcessor
|
| 680 |
-
|
| 681 |
-
# from schema_constraints import ConstraintGraph, build_constraint_graph
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
# # =========================================================
|
| 685 |
-
# # 🔍 IDENTIFIER TYPE DETECTION
|
| 686 |
-
# # =========================================================
|
| 687 |
-
# def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
|
| 688 |
-
# s = re.sub(r"\s+", " ", prefix_text.lower())
|
| 689 |
-
|
| 690 |
-
# last_from = s.rfind(" from ")
|
| 691 |
-
# last_join = s.rfind(" join ")
|
| 692 |
-
# last_select = s.rfind(" select ")
|
| 693 |
-
# last_where = s.rfind(" where ")
|
| 694 |
-
# last_on = s.rfind(" on ")
|
| 695 |
-
# last_group = s.rfind(" group by ")
|
| 696 |
-
# last_order = s.rfind(" order by ")
|
| 697 |
-
# last_having = s.rfind(" having ")
|
| 698 |
-
|
| 699 |
-
# last_table_kw = max(last_from, last_join)
|
| 700 |
-
# last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
|
| 701 |
-
|
| 702 |
-
# if last_table_kw < 0 and last_col_kw < 0:
|
| 703 |
-
# return None
|
| 704 |
-
# if last_table_kw > last_col_kw:
|
| 705 |
-
# return "table"
|
| 706 |
-
# if last_col_kw > last_table_kw:
|
| 707 |
-
# return "column"
|
| 708 |
-
# return None
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
# # =========================================================
|
| 712 |
-
# # 🌳 TRIE STRUCTURE
|
| 713 |
-
# # =========================================================
|
| 714 |
-
# class _TrieNode:
|
| 715 |
-
# __slots__ = ("children", "terminal")
|
| 716 |
-
|
| 717 |
-
# def __init__(self) -> None:
|
| 718 |
-
# self.children: Dict[int, _TrieNode] = {}
|
| 719 |
-
# self.terminal: bool = False
|
| 720 |
-
|
| 721 |
-
# def insert(self, token_ids: Sequence[int]) -> None:
|
| 722 |
-
# node = self
|
| 723 |
-
# for tid in token_ids:
|
| 724 |
-
# tid = int(tid)
|
| 725 |
-
# if tid not in node.children:
|
| 726 |
-
# node.children[tid] = _TrieNode()
|
| 727 |
-
# node = node.children[tid]
|
| 728 |
-
# node.terminal = True
|
| 729 |
-
|
| 730 |
-
# def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
|
| 731 |
-
# node = self
|
| 732 |
-
# for tid in prefix:
|
| 733 |
-
# node = node.children.get(int(tid))
|
| 734 |
-
# if node is None:
|
| 735 |
-
# return None
|
| 736 |
-
# return node
|
| 737 |
-
|
| 738 |
-
|
| 739 |
-
# # =========================================================
|
| 740 |
-
# # 🔤 TOKEN ENCODING
|
| 741 |
-
# # =========================================================
|
| 742 |
-
# def _encode_identifier(tokenizer, name: str) -> List[int]:
|
| 743 |
-
# return tokenizer.encode(" " + name, add_special_tokens=False)
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
# def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
|
| 747 |
-
# trie = _TrieNode()
|
| 748 |
-
# for name in names:
|
| 749 |
-
# try:
|
| 750 |
-
# ids = _encode_identifier(tokenizer, name)
|
| 751 |
-
# if ids:
|
| 752 |
-
# trie.insert(ids)
|
| 753 |
-
# except Exception:
|
| 754 |
-
# continue
|
| 755 |
-
# return trie
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
# def _allow_always_token_ids(tokenizer) -> torch.Tensor:
|
| 759 |
-
# tokens = [",", ")", "(", ".", ";", "\n"]
|
| 760 |
-
# ids: Set[int] = set()
|
| 761 |
-
|
| 762 |
-
# for t in tokens:
|
| 763 |
-
# try:
|
| 764 |
-
# ids.update(tokenizer.encode(t, add_special_tokens=False))
|
| 765 |
-
# except:
|
| 766 |
-
# pass
|
| 767 |
-
|
| 768 |
-
# return torch.tensor(sorted(ids), dtype=torch.long)
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
# # =========================================================
|
| 772 |
-
# # 📦 PER-DB CACHE
|
| 773 |
-
# # =========================================================
|
| 774 |
-
# @dataclass
|
| 775 |
-
# class _PerDbTokenSets:
|
| 776 |
-
# fp: str
|
| 777 |
-
# table_trie: _TrieNode
|
| 778 |
-
# column_trie: _TrieNode
|
| 779 |
-
# allow_always: torch.Tensor
|
| 780 |
-
|
| 781 |
-
|
| 782 |
-
# _DB_CACHE: Dict[str, _PerDbTokenSets] = {}
|
| 783 |
-
# _DB_LOCK = threading.Lock()
|
| 784 |
-
|
| 785 |
-
|
| 786 |
-
# def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
|
| 787 |
-
# with _DB_LOCK:
|
| 788 |
-
# cached = _DB_CACHE.get(graph.db_path)
|
| 789 |
-
# if cached and cached.fp == graph.fingerprint:
|
| 790 |
-
# return cached
|
| 791 |
-
|
| 792 |
-
# obj = _PerDbTokenSets(
|
| 793 |
-
# fp=graph.fingerprint,
|
| 794 |
-
# table_trie=_build_trie(tokenizer, graph.tables),
|
| 795 |
-
# column_trie=_build_trie(tokenizer, graph.all_columns),
|
| 796 |
-
# allow_always=_allow_always_token_ids(tokenizer),
|
| 797 |
-
# )
|
| 798 |
-
|
| 799 |
-
# with _DB_LOCK:
|
| 800 |
-
# _DB_CACHE[graph.db_path] = obj
|
| 801 |
-
|
| 802 |
-
# return obj
|
| 803 |
-
|
| 804 |
-
|
| 805 |
-
# # =========================================================
|
| 806 |
-
# # 🚀 MAIN LOGITS PROCESSOR
|
| 807 |
-
# # =========================================================
|
| 808 |
-
# class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 809 |
-
# def __init__(self, tokenizer, db_paths: Sequence[str], max_prefix_tokens: int = 48):
|
| 810 |
-
# self.tokenizer = tokenizer
|
| 811 |
-
# self.db_paths = list(db_paths)
|
| 812 |
-
# self.max_prefix_tokens = max_prefix_tokens
|
| 813 |
-
|
| 814 |
-
# self._graphs = [build_constraint_graph(p) for p in db_paths]
|
| 815 |
-
# self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
|
| 816 |
-
|
| 817 |
-
# # 📊 Metrics (IMPORTANT FOR REPORT)
|
| 818 |
-
# self.total_steps = 0
|
| 819 |
-
# self.constrained_steps = 0
|
| 820 |
-
|
| 821 |
-
# def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
| 822 |
-
# batch = input_ids.size(0)
|
| 823 |
-
|
| 824 |
-
# for i in range(batch):
|
| 825 |
-
# self.total_steps += 1
|
| 826 |
-
|
| 827 |
-
# tail_ids = input_ids[i, -self.max_prefix_tokens:].tolist()
|
| 828 |
-
# prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
|
| 829 |
-
|
| 830 |
-
# expected = _infer_expected_identifier(prefix_text)
|
| 831 |
-
# if expected is None:
|
| 832 |
-
# continue
|
| 833 |
-
|
| 834 |
-
# self.constrained_steps += 1
|
| 835 |
-
|
| 836 |
-
# # =========================
|
| 837 |
-
# # SELECT TRIE
|
| 838 |
-
# # =========================
|
| 839 |
-
# if expected == "table":
|
| 840 |
-
# trie = self._token_sets[i].table_trie
|
| 841 |
-
# else:
|
| 842 |
-
# trie = self._token_sets[i].column_trie
|
| 843 |
-
|
| 844 |
-
# # =========================
|
| 845 |
-
# # PARTIAL TOKEN MATCH
|
| 846 |
-
# # =========================
|
| 847 |
-
# match = re.search(r"([A-Za-z_][A-Za-z0-9_]*)$", prefix_text)
|
| 848 |
-
# partial = match.group(1) if match else ""
|
| 849 |
-
|
| 850 |
-
# try:
|
| 851 |
-
# prefix_ids = _encode_identifier(self.tokenizer, partial) if partial else []
|
| 852 |
-
# except:
|
| 853 |
-
# continue
|
| 854 |
-
|
| 855 |
-
# node = trie.walk(prefix_ids)
|
| 856 |
-
# if node is None or node.terminal:
|
| 857 |
-
# continue
|
| 858 |
-
|
| 859 |
-
# allowed_next = list(node.children.keys())
|
| 860 |
-
# if not allowed_next:
|
| 861 |
-
# continue
|
| 862 |
-
|
| 863 |
-
# allowed_next = torch.tensor(allowed_next, device=scores.device)
|
| 864 |
-
# allow_always = self._token_sets[i].allow_always.to(scores.device)
|
| 865 |
-
|
| 866 |
-
# keep = torch.cat([allowed_next, allow_always])
|
| 867 |
-
|
| 868 |
-
# kept_scores = scores[i, keep].clone()
|
| 869 |
-
# scores[i, :] = -float("inf")
|
| 870 |
-
# scores[i, keep] = kept_scores
|
| 871 |
-
|
| 872 |
-
# return scores
|
| 873 |
-
|
| 874 |
-
# # =========================================================
|
| 875 |
-
# # 📊 METRICS FOR REPORT
|
| 876 |
-
# # =========================================================
|
| 877 |
-
# def get_constraint_stats(self):
|
| 878 |
-
# if self.total_steps == 0:
|
| 879 |
-
# return 0
|
| 880 |
-
# return self.constrained_steps / self.total_steps
|
| 881 |
-
|
| 882 |
-
|
| 883 |
-
# # =========================================================
|
| 884 |
-
# # 🔁 BACKWARD COMPATIBILITY
|
| 885 |
-
# # =========================================================
|
| 886 |
-
# class SchemaConstraintGraph:
|
| 887 |
-
# def __init__(self, db_path: str):
|
| 888 |
-
# self._graph = build_constraint_graph(db_path)
|
| 889 |
-
# self.tables = sorted(self._graph.tables)
|
| 890 |
-
# self.columns = sorted(self._graph.all_columns)
|
| 891 |
-
|
| 892 |
-
|
| 893 |
-
# class SchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 894 |
-
# def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
|
| 895 |
-
# self.proc = BatchSchemaConstrainedLogitsProcessor(
|
| 896 |
-
# tokenizer, [schema_graph._graph.db_path]
|
| 897 |
-
# )
|
| 898 |
-
|
| 899 |
-
# def __call__(self, input_ids, scores):
|
| 900 |
-
# return self.proc(input_ids, scores)
|
| 901 |
-
|
| 902 |
-
|
| 903 |
-
|
| 904 |
-
|
| 905 |
-
|
| 906 |
-
|
| 907 |
-
|
| 908 |
-
|
| 909 |
-
# ********* after task 3
|
| 910 |
-
|
| 911 |
-
import re
|
| 912 |
-
import threading
|
| 913 |
-
from functools import lru_cache
|
| 914 |
-
|
| 915 |
-
import torch
|
| 916 |
-
from transformers import LogitsProcessor
|
| 917 |
-
|
| 918 |
-
from src.schema_utils import get_constraint_graph
|
| 919 |
-
|
| 920 |
-
|
| 921 |
-
_TOKEN_CACHE_LOCK = threading.Lock()
|
| 922 |
-
_TOKEN_ID_CACHE = {} # (id(tokenizer), db_path) -> (allowed_ids_tensor, always_allow_ids_tensor)
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
def _encode_variants(tokenizer, text: str) -> list[int]:
|
| 926 |
-
ids: list[int] = []
|
| 927 |
-
for variant in (text, " " + text):
|
| 928 |
-
try:
|
| 929 |
-
ids.extend(tokenizer.encode(variant, add_special_tokens=False))
|
| 930 |
-
except Exception:
|
| 931 |
-
continue
|
| 932 |
-
# de-dup while keeping order
|
| 933 |
-
seen = set()
|
| 934 |
-
out = []
|
| 935 |
-
for i in ids:
|
| 936 |
-
if int(i) not in seen:
|
| 937 |
-
seen.add(int(i))
|
| 938 |
-
out.append(int(i))
|
| 939 |
-
return out
|
| 940 |
-
|
| 941 |
-
|
| 942 |
-
def _always_allow_ids(tokenizer) -> list[int]:
|
| 943 |
-
"""
|
| 944 |
-
Tokens we should never block, otherwise decoding can get stuck or generate garbage:
|
| 945 |
-
- EOS/PAD
|
| 946 |
-
- punctuation/operators needed for SQL formatting
|
| 947 |
-
- digits/quotes
|
| 948 |
-
"""
|
| 949 |
-
ids: list[int] = []
|
| 950 |
-
for special in [getattr(tokenizer, "eos_token_id", None), getattr(tokenizer, "pad_token_id", None)]:
|
| 951 |
-
if special is not None:
|
| 952 |
-
ids.append(int(special))
|
| 953 |
-
|
| 954 |
-
# Common SQL punctuation/operators
|
| 955 |
-
pieces = [
|
| 956 |
-
" ", "\n", "\t",
|
| 957 |
-
",", ".", "(", ")", ";",
|
| 958 |
-
"=", "!=", "<>", "<", ">", "<=", ">=",
|
| 959 |
-
"*", "+", "-", "/", "%",
|
| 960 |
-
"'", '"',
|
| 961 |
-
]
|
| 962 |
-
for p in pieces:
|
| 963 |
-
ids.extend(_encode_variants(tokenizer, p))
|
| 964 |
-
|
| 965 |
-
# digits
|
| 966 |
-
for d in "0123456789":
|
| 967 |
-
ids.extend(_encode_variants(tokenizer, d))
|
| 968 |
-
|
| 969 |
-
seen = set()
|
| 970 |
-
out = []
|
| 971 |
-
for i in ids:
|
| 972 |
-
if int(i) not in seen:
|
| 973 |
-
seen.add(int(i))
|
| 974 |
-
out.append(int(i))
|
| 975 |
-
return out
|
| 976 |
-
|
| 977 |
-
|
| 978 |
-
def _infer_expected_identifier_tail(tail_text: str):
|
| 979 |
-
"""
|
| 980 |
-
Returns ("table"|"column", partial_or_empty) if the tail looks like it's currently
|
| 981 |
-
emitting a table/column identifier. Otherwise returns None.
|
| 982 |
-
"""
|
| 983 |
-
t = re.sub(r"\s+", " ", (tail_text or "")).lower()
|
| 984 |
-
|
| 985 |
-
m = re.search(r"(?:from|join)\s+([a-z_][a-z0-9_]*)?$", t)
|
| 986 |
-
if m:
|
| 987 |
-
partial = m.group(1) or ""
|
| 988 |
-
# ensure we are actually after keyword (not elsewhere)
|
| 989 |
-
if re.search(r"(?:from|join)\s*$", t) or partial:
|
| 990 |
-
return "table", partial
|
| 991 |
-
|
| 992 |
-
m = re.search(
|
| 993 |
-
r"(?:select|where|on|group by|order by|having)\s+([a-z_][a-z0-9_]*(?:\.[a-z_][a-z0-9_]*)?)?$",
|
| 994 |
-
t,
|
| 995 |
-
)
|
| 996 |
-
if m:
|
| 997 |
-
partial = m.group(1) or ""
|
| 998 |
-
if re.search(r"(?:select|where|on|group by|order by|having)\s*$", t) or partial:
|
| 999 |
-
return "column", partial
|
| 1000 |
-
|
| 1001 |
-
return None
|
| 1002 |
-
|
| 1003 |
-
|
| 1004 |
-
class SchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 1005 |
-
def __init__(self, tokenizer, db_path):
|
| 1006 |
-
self.tokenizer = tokenizer
|
| 1007 |
-
|
| 1008 |
-
graph = get_constraint_graph(db_path)
|
| 1009 |
-
|
| 1010 |
-
key = (id(tokenizer), str(db_path))
|
| 1011 |
-
with _TOKEN_CACHE_LOCK:
|
| 1012 |
-
cached = _TOKEN_ID_CACHE.get(key)
|
| 1013 |
-
if cached is None:
|
| 1014 |
-
allowed_tokens = set(graph.get("tables", set())) | set(graph.get("columns", set()))
|
| 1015 |
-
|
| 1016 |
-
sql_keywords = {
|
| 1017 |
-
"select", "from", "where", "join", "on",
|
| 1018 |
-
"group", "by", "order", "limit", "having",
|
| 1019 |
-
"and", "or", "desc", "asc",
|
| 1020 |
-
"count", "avg", "min", "max", "sum",
|
| 1021 |
-
"distinct", "as", "in", "like", "between",
|
| 1022 |
-
"is", "null",
|
| 1023 |
-
}
|
| 1024 |
-
allowed_tokens |= sql_keywords
|
| 1025 |
-
|
| 1026 |
-
allowed_ids: list[int] = []
|
| 1027 |
-
for tok in sorted(allowed_tokens):
|
| 1028 |
-
allowed_ids.extend(_encode_variants(tokenizer, tok))
|
| 1029 |
-
always_ids = _always_allow_ids(tokenizer)
|
| 1030 |
-
|
| 1031 |
-
allowed_ids_t = torch.tensor(sorted(set(allowed_ids)), dtype=torch.long)
|
| 1032 |
-
always_ids_t = torch.tensor(sorted(set(always_ids)), dtype=torch.long)
|
| 1033 |
-
cached = (allowed_ids_t, always_ids_t)
|
| 1034 |
-
with _TOKEN_CACHE_LOCK:
|
| 1035 |
-
_TOKEN_ID_CACHE[key] = cached
|
| 1036 |
-
|
| 1037 |
-
self._allowed_ids_t, self._always_ids_t = cached
|
| 1038 |
-
|
| 1039 |
-
def __call__(self, input_ids, scores):
|
| 1040 |
-
# Decode only a tail window for speed (beam search calls this a lot).
|
| 1041 |
-
try:
|
| 1042 |
-
tail_ids = input_ids[0][-128:]
|
| 1043 |
-
except Exception:
|
| 1044 |
-
tail_ids = input_ids[0]
|
| 1045 |
-
tail = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
|
| 1046 |
-
|
| 1047 |
-
inferred = _infer_expected_identifier_tail(tail)
|
| 1048 |
-
if inferred is None:
|
| 1049 |
-
return scores
|
| 1050 |
-
|
| 1051 |
-
keep = torch.cat([self._allowed_ids_t.to(scores.device), self._always_ids_t.to(scores.device)])
|
| 1052 |
-
if keep.numel() == 0:
|
| 1053 |
-
return scores
|
| 1054 |
-
|
| 1055 |
-
kept_scores = scores[:, keep].clone()
|
| 1056 |
-
scores[:] = -float("inf")
|
| 1057 |
-
scores[:, keep] = kept_scores
|
| 1058 |
-
return scores
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/constrained_decoding_sample.py
DELETED
|
@@ -1,516 +0,0 @@
|
|
| 1 |
-
# from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
# import re
|
| 4 |
-
# import threading
|
| 5 |
-
# from dataclasses import dataclass
|
| 6 |
-
# from typing import Dict, Iterable, List, Optional, Sequence, Set
|
| 7 |
-
|
| 8 |
-
# import torch
|
| 9 |
-
# from transformers.generation.logits_process import LogitsProcessor
|
| 10 |
-
|
| 11 |
-
# from schema_constraints import ConstraintGraph, build_constraint_graph
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
# def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
|
| 15 |
-
# s = re.sub(r"\s+", " ", prefix_text.lower())
|
| 16 |
-
# last_from = s.rfind(" from ")
|
| 17 |
-
# last_join = s.rfind(" join ")
|
| 18 |
-
# last_select = s.rfind(" select ")
|
| 19 |
-
# last_where = s.rfind(" where ")
|
| 20 |
-
# last_on = s.rfind(" on ")
|
| 21 |
-
# last_group = s.rfind(" group by ")
|
| 22 |
-
# last_order = s.rfind(" order by ")
|
| 23 |
-
# last_having = s.rfind(" having ")
|
| 24 |
-
|
| 25 |
-
# last_table_kw = max(last_from, last_join)
|
| 26 |
-
# last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
|
| 27 |
-
|
| 28 |
-
# if last_table_kw < 0 and last_col_kw < 0:
|
| 29 |
-
# return None
|
| 30 |
-
# if last_table_kw > last_col_kw:
|
| 31 |
-
# return "table"
|
| 32 |
-
# if last_col_kw > last_table_kw:
|
| 33 |
-
# return "column"
|
| 34 |
-
# return None
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
# class _TrieNode:
|
| 38 |
-
# __slots__ = ("children", "terminal")
|
| 39 |
-
|
| 40 |
-
# def __init__(self) -> None:
|
| 41 |
-
# self.children: Dict[int, _TrieNode] = {}
|
| 42 |
-
# self.terminal: bool = False
|
| 43 |
-
|
| 44 |
-
# def insert(self, token_ids: Sequence[int]) -> None:
|
| 45 |
-
# node: _TrieNode = self
|
| 46 |
-
# for tid in token_ids:
|
| 47 |
-
# tid_i = int(tid)
|
| 48 |
-
# nxt = node.children.get(tid_i)
|
| 49 |
-
# if nxt is None:
|
| 50 |
-
# nxt = _TrieNode()
|
| 51 |
-
# node.children[tid_i] = nxt
|
| 52 |
-
# node = nxt
|
| 53 |
-
# node.terminal = True
|
| 54 |
-
|
| 55 |
-
# def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
|
| 56 |
-
# node: _TrieNode = self
|
| 57 |
-
# for tid in prefix:
|
| 58 |
-
# node = node.children.get(int(tid)) # type: ignore[assignment]
|
| 59 |
-
# if node is None:
|
| 60 |
-
# return None
|
| 61 |
-
# return node
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
# def _encode_identifier(tokenizer, name: str) -> List[int]:
|
| 65 |
-
# # Leading space encourages word-start markers (e.g. "Ġ" in RoBERTa BPE).
|
| 66 |
-
# return tokenizer.encode(" " + name, add_special_tokens=False)
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
# def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
|
| 70 |
-
# trie = _TrieNode()
|
| 71 |
-
# for n in names:
|
| 72 |
-
# if not n:
|
| 73 |
-
# continue
|
| 74 |
-
# try:
|
| 75 |
-
# ids = _encode_identifier(tokenizer, n)
|
| 76 |
-
# except Exception:
|
| 77 |
-
# continue
|
| 78 |
-
# if ids:
|
| 79 |
-
# trie.insert(ids)
|
| 80 |
-
# return trie
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
# def _allow_always_token_ids(tokenizer) -> torch.Tensor:
|
| 84 |
-
# # Allow common delimiters so the model can end an identifier.
|
| 85 |
-
# toks = [",", ")", "(", "\n", ".", ";"]
|
| 86 |
-
# ids: Set[int] = set()
|
| 87 |
-
# for t in toks:
|
| 88 |
-
# try:
|
| 89 |
-
# for tid in tokenizer.encode(t, add_special_tokens=False):
|
| 90 |
-
# ids.add(int(tid))
|
| 91 |
-
# except Exception:
|
| 92 |
-
# continue
|
| 93 |
-
# return torch.tensor(sorted(ids), dtype=torch.long)
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
# @dataclass
|
| 97 |
-
# class _PerDbTokenSets:
|
| 98 |
-
# fp: str
|
| 99 |
-
# table_trie: _TrieNode
|
| 100 |
-
# column_trie: _TrieNode
|
| 101 |
-
# allow_always: torch.Tensor
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
# _DB_TOKENSET_LOCK = threading.Lock()
|
| 105 |
-
# _DB_TOKENSETS: Dict[str, _PerDbTokenSets] = {}
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
# def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
|
| 109 |
-
# with _DB_TOKENSET_LOCK:
|
| 110 |
-
# cached = _DB_TOKENSETS.get(graph.db_path)
|
| 111 |
-
# if cached is not None and cached.fp == graph.fingerprint:
|
| 112 |
-
# return cached
|
| 113 |
-
|
| 114 |
-
# out = _PerDbTokenSets(
|
| 115 |
-
# fp=graph.fingerprint,
|
| 116 |
-
# table_trie=_build_trie(tokenizer, graph.tables),
|
| 117 |
-
# column_trie=_build_trie(tokenizer, graph.all_columns),
|
| 118 |
-
# allow_always=_allow_always_token_ids(tokenizer),
|
| 119 |
-
# )
|
| 120 |
-
# with _DB_TOKENSET_LOCK:
|
| 121 |
-
# _DB_TOKENSETS[graph.db_path] = out
|
| 122 |
-
# return out
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
# class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 126 |
-
# """
|
| 127 |
-
# Schema-aware constrained decoding per item in the generation batch.
|
| 128 |
-
# Uses a tokenizer-based trie so multi-token identifiers can be constrained.
|
| 129 |
-
# """
|
| 130 |
-
|
| 131 |
-
# def __init__(self, tokenizer, db_paths: Sequence[str], *, max_prefix_tokens: int = 48):
|
| 132 |
-
# self.tokenizer = tokenizer
|
| 133 |
-
# self.db_paths = list(db_paths)
|
| 134 |
-
# self.max_prefix_tokens = int(max_prefix_tokens)
|
| 135 |
-
|
| 136 |
-
# self._graphs = [build_constraint_graph(p) for p in self.db_paths]
|
| 137 |
-
# self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
|
| 138 |
-
|
| 139 |
-
# def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
| 140 |
-
# if input_ids.dim() != 2 or scores.dim() != 2:
|
| 141 |
-
# return scores
|
| 142 |
-
|
| 143 |
-
# batch = input_ids.size(0)
|
| 144 |
-
# if batch != len(self._graphs):
|
| 145 |
-
# return scores
|
| 146 |
-
|
| 147 |
-
# for i in range(batch):
|
| 148 |
-
# tail_ids = input_ids[i, -self.max_prefix_tokens :].tolist()
|
| 149 |
-
# prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
|
| 150 |
-
# expected = _infer_expected_identifier(prefix_text)
|
| 151 |
-
# if expected is None:
|
| 152 |
-
# continue
|
| 153 |
-
|
| 154 |
-
# if expected == "table":
|
| 155 |
-
# m = re.search(r"(?:from|join)\s+([A-Za-z_][A-Za-z0-9_]*)$", prefix_text, flags=re.I)
|
| 156 |
-
# partial = m.group(1) if m else None
|
| 157 |
-
# if partial is None and not re.search(r"(?:from|join)\s*$", prefix_text, flags=re.I):
|
| 158 |
-
# continue
|
| 159 |
-
# trie = self._token_sets[i].table_trie
|
| 160 |
-
# else:
|
| 161 |
-
# m = re.search(
|
| 162 |
-
# r"(?:select|where|on|group by|order by|having)\s+([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?)$",
|
| 163 |
-
# prefix_text,
|
| 164 |
-
# flags=re.I,
|
| 165 |
-
# )
|
| 166 |
-
# partial = m.group(1) if m else None
|
| 167 |
-
# if partial is None and not re.search(
|
| 168 |
-
# r"(?:select|where|on|group by|order by|having)\s*$", prefix_text, flags=re.I
|
| 169 |
-
# ):
|
| 170 |
-
# continue
|
| 171 |
-
# trie = self._token_sets[i].column_trie
|
| 172 |
-
|
| 173 |
-
# if not partial:
|
| 174 |
-
# prefix_token_ids: List[int] = []
|
| 175 |
-
# else:
|
| 176 |
-
# try:
|
| 177 |
-
# prefix_token_ids = _encode_identifier(self.tokenizer, partial)
|
| 178 |
-
# except Exception:
|
| 179 |
-
# continue
|
| 180 |
-
|
| 181 |
-
# node = trie.walk(prefix_token_ids)
|
| 182 |
-
# if node is None or node.terminal:
|
| 183 |
-
# continue
|
| 184 |
-
|
| 185 |
-
# allowed_next = sorted(node.children.keys())
|
| 186 |
-
# if not allowed_next:
|
| 187 |
-
# continue
|
| 188 |
-
|
| 189 |
-
# allowed_next_t = torch.tensor(allowed_next, dtype=torch.long, device=scores.device)
|
| 190 |
-
# allow_always = self._token_sets[i].allow_always.to(scores.device)
|
| 191 |
-
# keep = torch.cat([allowed_next_t, allow_always]) if allow_always.numel() else allowed_next_t
|
| 192 |
-
|
| 193 |
-
# kept_scores = scores[i, keep].clone()
|
| 194 |
-
# scores[i, :] = -float("inf")
|
| 195 |
-
# scores[i, keep] = kept_scores
|
| 196 |
-
|
| 197 |
-
# return scores
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
# # Backwards-compatible names used elsewhere in the repo.
|
| 201 |
-
# class SchemaConstraintGraph:
|
| 202 |
-
# def __init__(self, db_path: str):
|
| 203 |
-
# self._graph = build_constraint_graph(db_path)
|
| 204 |
-
# self.tables = sorted(self._graph.tables)
|
| 205 |
-
# self.columns = sorted(self._graph.all_columns)
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
# class SchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 209 |
-
# def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
|
| 210 |
-
# self._proc = BatchSchemaConstrainedLogitsProcessor(tokenizer, [schema_graph._graph.db_path])
|
| 211 |
-
|
| 212 |
-
# def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
| 213 |
-
# return self._proc(input_ids, scores)
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
# from __future__ import annotations
|
| 219 |
-
|
| 220 |
-
# import re
|
| 221 |
-
# import threading
|
| 222 |
-
# from dataclasses import dataclass
|
| 223 |
-
# from typing import Dict, Iterable, List, Optional, Sequence, Set
|
| 224 |
-
|
| 225 |
-
# import torch
|
| 226 |
-
# from transformers.generation.logits_process import LogitsProcessor
|
| 227 |
-
|
| 228 |
-
# from schema_constraints import ConstraintGraph, build_constraint_graph
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
# # =========================================================
|
| 232 |
-
# # 🔍 IDENTIFIER TYPE DETECTION
|
| 233 |
-
# # =========================================================
|
| 234 |
-
# def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
|
| 235 |
-
# s = re.sub(r"\s+", " ", prefix_text.lower())
|
| 236 |
-
|
| 237 |
-
# last_from = s.rfind(" from ")
|
| 238 |
-
# last_join = s.rfind(" join ")
|
| 239 |
-
# last_select = s.rfind(" select ")
|
| 240 |
-
# last_where = s.rfind(" where ")
|
| 241 |
-
# last_on = s.rfind(" on ")
|
| 242 |
-
# last_group = s.rfind(" group by ")
|
| 243 |
-
# last_order = s.rfind(" order by ")
|
| 244 |
-
# last_having = s.rfind(" having ")
|
| 245 |
-
|
| 246 |
-
# last_table_kw = max(last_from, last_join)
|
| 247 |
-
# last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
|
| 248 |
-
|
| 249 |
-
# if last_table_kw < 0 and last_col_kw < 0:
|
| 250 |
-
# return None
|
| 251 |
-
# if last_table_kw > last_col_kw:
|
| 252 |
-
# return "table"
|
| 253 |
-
# if last_col_kw > last_table_kw:
|
| 254 |
-
# return "column"
|
| 255 |
-
# return None
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
# # =========================================================
|
| 259 |
-
# # 🌳 TRIE STRUCTURE
|
| 260 |
-
# # =========================================================
|
| 261 |
-
# class _TrieNode:
|
| 262 |
-
# __slots__ = ("children", "terminal")
|
| 263 |
-
|
| 264 |
-
# def __init__(self) -> None:
|
| 265 |
-
# self.children: Dict[int, _TrieNode] = {}
|
| 266 |
-
# self.terminal: bool = False
|
| 267 |
-
|
| 268 |
-
# def insert(self, token_ids: Sequence[int]) -> None:
|
| 269 |
-
# node = self
|
| 270 |
-
# for tid in token_ids:
|
| 271 |
-
# tid = int(tid)
|
| 272 |
-
# if tid not in node.children:
|
| 273 |
-
# node.children[tid] = _TrieNode()
|
| 274 |
-
# node = node.children[tid]
|
| 275 |
-
# node.terminal = True
|
| 276 |
-
|
| 277 |
-
# def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
|
| 278 |
-
# node = self
|
| 279 |
-
# for tid in prefix:
|
| 280 |
-
# node = node.children.get(int(tid))
|
| 281 |
-
# if node is None:
|
| 282 |
-
# return None
|
| 283 |
-
# return node
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
# # =========================================================
|
| 287 |
-
# # 🔤 TOKEN ENCODING
|
| 288 |
-
# # =========================================================
|
| 289 |
-
# def _encode_identifier(tokenizer, name: str) -> List[int]:
|
| 290 |
-
# return tokenizer.encode(" " + name, add_special_tokens=False)
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
# def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
|
| 294 |
-
# trie = _TrieNode()
|
| 295 |
-
# for name in names:
|
| 296 |
-
# try:
|
| 297 |
-
# ids = _encode_identifier(tokenizer, name)
|
| 298 |
-
# if ids:
|
| 299 |
-
# trie.insert(ids)
|
| 300 |
-
# except Exception:
|
| 301 |
-
# continue
|
| 302 |
-
# return trie
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
# def _allow_always_token_ids(tokenizer) -> torch.Tensor:
|
| 306 |
-
# tokens = [",", ")", "(", ".", ";", "\n"]
|
| 307 |
-
# ids: Set[int] = set()
|
| 308 |
-
|
| 309 |
-
# for t in tokens:
|
| 310 |
-
# try:
|
| 311 |
-
# ids.update(tokenizer.encode(t, add_special_tokens=False))
|
| 312 |
-
# except:
|
| 313 |
-
# pass
|
| 314 |
-
|
| 315 |
-
# return torch.tensor(sorted(ids), dtype=torch.long)
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
# # =========================================================
|
| 319 |
-
# # 📦 PER-DB CACHE
|
| 320 |
-
# # =========================================================
|
| 321 |
-
# @dataclass
|
| 322 |
-
# class _PerDbTokenSets:
|
| 323 |
-
# fp: str
|
| 324 |
-
# table_trie: _TrieNode
|
| 325 |
-
# column_trie: _TrieNode
|
| 326 |
-
# allow_always: torch.Tensor
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
# _DB_CACHE: Dict[str, _PerDbTokenSets] = {}
|
| 330 |
-
# _DB_LOCK = threading.Lock()
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
# def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
|
| 334 |
-
# with _DB_LOCK:
|
| 335 |
-
# cached = _DB_CACHE.get(graph.db_path)
|
| 336 |
-
# if cached and cached.fp == graph.fingerprint:
|
| 337 |
-
# return cached
|
| 338 |
-
|
| 339 |
-
# obj = _PerDbTokenSets(
|
| 340 |
-
# fp=graph.fingerprint,
|
| 341 |
-
# table_trie=_build_trie(tokenizer, graph.tables),
|
| 342 |
-
# column_trie=_build_trie(tokenizer, graph.all_columns),
|
| 343 |
-
# allow_always=_allow_always_token_ids(tokenizer),
|
| 344 |
-
# )
|
| 345 |
-
|
| 346 |
-
# with _DB_LOCK:
|
| 347 |
-
# _DB_CACHE[graph.db_path] = obj
|
| 348 |
-
|
| 349 |
-
# return obj
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
# # =========================================================
|
| 353 |
-
# # 🚀 MAIN LOGITS PROCESSOR
|
| 354 |
-
# # =========================================================
|
| 355 |
-
# class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 356 |
-
# def __init__(self, tokenizer, db_paths: Sequence[str], max_prefix_tokens: int = 48):
|
| 357 |
-
# self.tokenizer = tokenizer
|
| 358 |
-
# self.db_paths = list(db_paths)
|
| 359 |
-
# self.max_prefix_tokens = max_prefix_tokens
|
| 360 |
-
|
| 361 |
-
# self._graphs = [build_constraint_graph(p) for p in db_paths]
|
| 362 |
-
# self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
|
| 363 |
-
|
| 364 |
-
# # 📊 Metrics (IMPORTANT FOR REPORT)
|
| 365 |
-
# self.total_steps = 0
|
| 366 |
-
# self.constrained_steps = 0
|
| 367 |
-
|
| 368 |
-
# def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
| 369 |
-
# batch = input_ids.size(0)
|
| 370 |
-
|
| 371 |
-
# for i in range(batch):
|
| 372 |
-
# self.total_steps += 1
|
| 373 |
-
|
| 374 |
-
# tail_ids = input_ids[i, -self.max_prefix_tokens:].tolist()
|
| 375 |
-
# prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
|
| 376 |
-
|
| 377 |
-
# expected = _infer_expected_identifier(prefix_text)
|
| 378 |
-
# if expected is None:
|
| 379 |
-
# continue
|
| 380 |
-
|
| 381 |
-
# self.constrained_steps += 1
|
| 382 |
-
|
| 383 |
-
# # =========================
|
| 384 |
-
# # SELECT TRIE
|
| 385 |
-
# # =========================
|
| 386 |
-
# if expected == "table":
|
| 387 |
-
# trie = self._token_sets[i].table_trie
|
| 388 |
-
# else:
|
| 389 |
-
# trie = self._token_sets[i].column_trie
|
| 390 |
-
|
| 391 |
-
# # =========================
|
| 392 |
-
# # PARTIAL TOKEN MATCH
|
| 393 |
-
# # =========================
|
| 394 |
-
# match = re.search(r"([A-Za-z_][A-Za-z0-9_]*)$", prefix_text)
|
| 395 |
-
# partial = match.group(1) if match else ""
|
| 396 |
-
|
| 397 |
-
# try:
|
| 398 |
-
# prefix_ids = _encode_identifier(self.tokenizer, partial) if partial else []
|
| 399 |
-
# except:
|
| 400 |
-
# continue
|
| 401 |
-
|
| 402 |
-
# node = trie.walk(prefix_ids)
|
| 403 |
-
# if node is None or node.terminal:
|
| 404 |
-
# continue
|
| 405 |
-
|
| 406 |
-
# allowed_next = list(node.children.keys())
|
| 407 |
-
# if not allowed_next:
|
| 408 |
-
# continue
|
| 409 |
-
|
| 410 |
-
# allowed_next = torch.tensor(allowed_next, device=scores.device)
|
| 411 |
-
# allow_always = self._token_sets[i].allow_always.to(scores.device)
|
| 412 |
-
|
| 413 |
-
# keep = torch.cat([allowed_next, allow_always])
|
| 414 |
-
|
| 415 |
-
# kept_scores = scores[i, keep].clone()
|
| 416 |
-
# scores[i, :] = -float("inf")
|
| 417 |
-
# scores[i, keep] = kept_scores
|
| 418 |
-
|
| 419 |
-
# return scores
|
| 420 |
-
|
| 421 |
-
# # =========================================================
|
| 422 |
-
# # 📊 METRICS FOR REPORT
|
| 423 |
-
# # =========================================================
|
| 424 |
-
# def get_constraint_stats(self):
|
| 425 |
-
# if self.total_steps == 0:
|
| 426 |
-
# return 0
|
| 427 |
-
# return self.constrained_steps / self.total_steps
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
# # =========================================================
|
| 431 |
-
# # 🔁 BACKWARD COMPATIBILITY
|
| 432 |
-
# # =========================================================
|
| 433 |
-
# class SchemaConstraintGraph:
|
| 434 |
-
# def __init__(self, db_path: str):
|
| 435 |
-
# self._graph = build_constraint_graph(db_path)
|
| 436 |
-
# self.tables = sorted(self._graph.tables)
|
| 437 |
-
# self.columns = sorted(self._graph.all_columns)
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
# class SchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 441 |
-
# def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
|
| 442 |
-
# self.proc = BatchSchemaConstrainedLogitsProcessor(
|
| 443 |
-
# tokenizer, [schema_graph._graph.db_path]
|
| 444 |
-
# )
|
| 445 |
-
|
| 446 |
-
# def __call__(self, input_ids, scores):
|
| 447 |
-
# return self.proc(input_ids, scores)
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
# ********* after task 3
|
| 457 |
-
|
| 458 |
-
import re
|
| 459 |
-
import torch
|
| 460 |
-
from transformers import LogitsProcessor
|
| 461 |
-
from src.schema_utils import get_constraint_graph
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
def _infer_expected_identifier(prefix_text: str):
|
| 465 |
-
s = prefix_text.lower()
|
| 466 |
-
|
| 467 |
-
if " from " in s or " join " in s:
|
| 468 |
-
return "table"
|
| 469 |
-
if any(k in s for k in ["select", "where", "on", "group by", "order by"]):
|
| 470 |
-
return "column"
|
| 471 |
-
|
| 472 |
-
return None
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
class SchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 476 |
-
def __init__(self, tokenizer, db_path):
|
| 477 |
-
self.tokenizer = tokenizer
|
| 478 |
-
|
| 479 |
-
graph = get_constraint_graph(db_path)
|
| 480 |
-
|
| 481 |
-
self.allowed_tokens = set(graph["tables"]) | set(graph["columns"])
|
| 482 |
-
|
| 483 |
-
self.sql_keywords = {
|
| 484 |
-
"select", "from", "where", "join", "on",
|
| 485 |
-
"group", "by", "order", "limit",
|
| 486 |
-
"and", "or", "desc", "asc",
|
| 487 |
-
"count", "avg", "min", "max", "sum", "*"
|
| 488 |
-
}
|
| 489 |
-
|
| 490 |
-
self.allowed_tokens |= self.sql_keywords
|
| 491 |
-
|
| 492 |
-
self.allowed_token_ids = set()
|
| 493 |
-
for token in self.allowed_tokens:
|
| 494 |
-
ids = tokenizer.encode(token, add_special_tokens=False)
|
| 495 |
-
for i in ids:
|
| 496 |
-
self.allowed_token_ids.add(i)
|
| 497 |
-
|
| 498 |
-
def __call__(self, input_ids, scores):
|
| 499 |
-
|
| 500 |
-
prefix = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
|
| 501 |
-
|
| 502 |
-
# 🔥 SOFT CONSTRAINT (FIX)
|
| 503 |
-
if len(prefix.strip()) < 10:
|
| 504 |
-
return scores
|
| 505 |
-
|
| 506 |
-
expected = _infer_expected_identifier(prefix)
|
| 507 |
-
|
| 508 |
-
if expected not in ["table", "column"]:
|
| 509 |
-
return scores
|
| 510 |
-
|
| 511 |
-
mask = torch.full_like(scores, float("-inf"))
|
| 512 |
-
|
| 513 |
-
for token_id in self.allowed_token_ids:
|
| 514 |
-
mask[:, token_id] = scores[:, token_id]
|
| 515 |
-
|
| 516 |
-
return mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/eval_rl_fixed.py
CHANGED
|
@@ -1,756 +1,466 @@
|
|
| 1 |
# import json
|
| 2 |
-
# import subprocess
|
| 3 |
-
# import sys
|
| 4 |
-
# import argparse
|
| 5 |
-
# import random
|
| 6 |
# import sqlite3
|
| 7 |
-
# import
|
| 8 |
-
# import re
|
| 9 |
-
# import os
|
| 10 |
# from pathlib import Path
|
| 11 |
-
|
| 12 |
# import torch
|
| 13 |
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 14 |
# from peft import PeftModel
|
| 15 |
|
| 16 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
-
#
|
| 19 |
-
#
|
| 20 |
-
#
|
| 21 |
-
#
|
| 22 |
-
# sql = sql.replace('"', "'")
|
| 23 |
-
# sql = re.sub(r"\s+", " ", sql)
|
| 24 |
-
# return sql.strip().lower().rstrip(";")
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
# # -------------------------------
|
| 28 |
-
# # 🔥 SAFE RESULT NORMALIZATION (FIX)
|
| 29 |
-
# # -------------------------------
|
| 30 |
-
# def normalize_result(res):
|
| 31 |
-
# try:
|
| 32 |
-
# return sorted([str(r) for r in res])
|
| 33 |
-
# except:
|
| 34 |
-
# return []
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
#
|
| 38 |
-
#
|
| 39 |
-
#
|
| 40 |
-
# def check_execution(pred_sql, gold_sql, db_path):
|
| 41 |
-
# try:
|
| 42 |
-
# conn = sqlite3.connect(db_path)
|
| 43 |
-
# conn.text_factory = lambda b: b.decode(errors='ignore')
|
| 44 |
|
| 45 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
-
#
|
| 48 |
-
#
|
| 49 |
|
| 50 |
-
# conn.set_progress_handler(timeout_handler, 10000)
|
| 51 |
|
| 52 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
-
#
|
| 55 |
-
#
|
| 56 |
|
| 57 |
-
#
|
| 58 |
-
#
|
| 59 |
|
| 60 |
# conn.close()
|
| 61 |
-
|
| 62 |
-
# # 🔥 FIXED COMPARISON
|
| 63 |
-
# return normalize_result(pred_res) == normalize_result(gold_res)
|
| 64 |
|
| 65 |
# except Exception:
|
| 66 |
# return False
|
| 67 |
|
| 68 |
|
| 69 |
-
# # -------------------------------
|
| 70 |
-
# # SPIDER PARSER
|
| 71 |
-
# # -------------------------------
|
| 72 |
-
# def _parse_spider_accuracy(stdout: str, metric_type: str):
|
| 73 |
-
# for line in stdout.splitlines():
|
| 74 |
-
# if metric_type == "exec" and line.strip().startswith("execution"):
|
| 75 |
-
# try:
|
| 76 |
-
# return float(line.split()[-1])
|
| 77 |
-
# except:
|
| 78 |
-
# pass
|
| 79 |
-
# elif metric_type == "match" and line.strip().startswith("exact"):
|
| 80 |
-
# try:
|
| 81 |
-
# return float(line.split()[-1])
|
| 82 |
-
# except:
|
| 83 |
-
# pass
|
| 84 |
-
# return None
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
# # -------------------------------
|
| 88 |
-
# # MAIN
|
| 89 |
-
# # -------------------------------
|
| 90 |
# def main():
|
| 91 |
# parser = argparse.ArgumentParser()
|
| 92 |
# parser.add_argument("--adapter", type=str, required=True)
|
| 93 |
-
# parser.add_argument("--num_samples", type=int, default=
|
| 94 |
-
# parser.add_argument("--shuffle_dev", action="store_true")
|
| 95 |
-
# parser.add_argument("--shuffle_seed", type=int, default=42)
|
| 96 |
# args = parser.parse_args()
|
| 97 |
|
| 98 |
# project_root = Path(__file__).resolve().parents[1]
|
| 99 |
-
# adapter_dir = project_root / args.adapter
|
| 100 |
|
| 101 |
-
# db_root = project_root / "data" / "database"
|
| 102 |
-
# table_json = project_root / "data" / "tables.json"
|
| 103 |
# dev_json = project_root / "data" / "dev.json"
|
|
|
|
| 104 |
|
| 105 |
-
#
|
| 106 |
-
# temp_gold_path = project_root / "temp_gold.sql"
|
| 107 |
|
| 108 |
-
#
|
| 109 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
-
#
|
| 112 |
-
#
|
| 113 |
-
# )
|
| 114 |
-
# print(f"Using device: {device}")
|
| 115 |
|
| 116 |
-
#
|
| 117 |
-
# tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
|
| 118 |
|
| 119 |
-
#
|
| 120 |
-
# tokenizer.pad_token = tokenizer.eos_token
|
| 121 |
|
| 122 |
-
#
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
-
#
|
|
|
|
| 125 |
|
| 126 |
-
#
|
| 127 |
|
| 128 |
-
#
|
| 129 |
-
# base,
|
| 130 |
-
# adapter_for_peft,
|
| 131 |
-
# local_files_only=True
|
| 132 |
-
# ).to(device)
|
| 133 |
-
|
| 134 |
-
# model = model.merge_and_unload()
|
| 135 |
-
# model.eval()
|
| 136 |
-
|
| 137 |
-
# # -------------------------------
|
| 138 |
-
# # LOAD DATA
|
| 139 |
-
# # -------------------------------
|
| 140 |
-
# with dev_json.open() as f:
|
| 141 |
-
# dev = json.load(f)
|
| 142 |
-
|
| 143 |
-
# if args.shuffle_dev:
|
| 144 |
-
# rng = random.Random(args.shuffle_seed)
|
| 145 |
-
# rng.shuffle(dev)
|
| 146 |
-
|
| 147 |
-
# dev = dev[: args.num_samples]
|
| 148 |
-
# total = len(dev)
|
| 149 |
-
|
| 150 |
-
# gen_kwargs = dict(
|
| 151 |
-
# max_new_tokens=160,
|
| 152 |
-
# num_beams=8,
|
| 153 |
-
# length_penalty=0.8,
|
| 154 |
-
# do_sample=False,
|
| 155 |
-
# early_stopping=True,
|
| 156 |
-
# pad_token_id=tokenizer.pad_token_id,
|
| 157 |
-
# eos_token_id=tokenizer.eos_token_id,
|
| 158 |
-
# )
|
| 159 |
-
|
| 160 |
-
# print(f"\n🚀 Evaluating {total} samples...\n")
|
| 161 |
-
|
| 162 |
-
# em_correct = 0
|
| 163 |
-
# ex_correct = 0
|
| 164 |
-
|
| 165 |
-
# with pred_path.open("w") as out_pred, temp_gold_path.open("w") as out_gold, torch.no_grad():
|
| 166 |
-
# for i, ex in enumerate(dev, start=1):
|
| 167 |
-
|
| 168 |
-
# db_id = ex["db_id"]
|
| 169 |
-
# question = ex["question"]
|
| 170 |
-
# gold_query = ex["query"]
|
| 171 |
-
# db_path = db_root / db_id / f"{db_id}.sqlite"
|
| 172 |
-
|
| 173 |
-
# # -------------------------------
|
| 174 |
-
# # GENERATE SQL
|
| 175 |
-
# # -------------------------------
|
| 176 |
-
# input_ids = encode_prompt(
|
| 177 |
-
# tokenizer,
|
| 178 |
-
# question,
|
| 179 |
-
# db_id,
|
| 180 |
-
# device=device,
|
| 181 |
-
# max_input_tokens=512
|
| 182 |
-
# )
|
| 183 |
-
|
| 184 |
-
# input_ids = input_ids.unsqueeze(0).to(device)
|
| 185 |
-
# attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)
|
| 186 |
|
|
|
|
| 187 |
# outputs = model.generate(
|
| 188 |
-
#
|
| 189 |
-
#
|
| 190 |
-
#
|
|
|
|
| 191 |
# )
|
| 192 |
|
| 193 |
-
#
|
| 194 |
-
|
| 195 |
-
# # -------------------------------
|
| 196 |
-
# # SAVE FOR SPIDER EVAL
|
| 197 |
-
# # -------------------------------
|
| 198 |
-
# out_pred.write(f"{pred_sql}\n")
|
| 199 |
-
# out_gold.write(f"{gold_query}\t{db_id}\n")
|
| 200 |
-
|
| 201 |
-
# # -------------------------------
|
| 202 |
-
# # LIVE METRICS
|
| 203 |
-
# # -------------------------------
|
| 204 |
-
# if normalize_sql(pred_sql) == normalize_sql(gold_query):
|
| 205 |
-
# em_correct += 1
|
| 206 |
-
|
| 207 |
-
# if check_execution(pred_sql, gold_query, db_path):
|
| 208 |
-
# ex_correct += 1
|
| 209 |
-
|
| 210 |
-
# if i % 20 == 0 or i == total:
|
| 211 |
-
# print(
|
| 212 |
-
# f"Progress: {i}/{total} | "
|
| 213 |
-
# f"EM: {(em_correct/i)*100:.2f}% | "
|
| 214 |
-
# f"EX: {(ex_correct/i)*100:.2f}%"
|
| 215 |
-
# )
|
| 216 |
|
| 217 |
-
#
|
|
|
|
| 218 |
|
| 219 |
-
#
|
| 220 |
|
| 221 |
-
#
|
| 222 |
-
#
|
| 223 |
-
# sys.executable, str(eval_script),
|
| 224 |
-
# "--gold", str(temp_gold_path),
|
| 225 |
-
# "--pred", str(pred_path),
|
| 226 |
-
# "--etype", "match",
|
| 227 |
-
# "--db", str(db_root),
|
| 228 |
-
# "--table", str(table_json),
|
| 229 |
-
# ]
|
| 230 |
|
| 231 |
-
#
|
| 232 |
-
#
|
| 233 |
-
|
| 234 |
-
# # EXECUTION
|
| 235 |
-
# cmd_exec = [
|
| 236 |
-
# sys.executable, str(eval_script),
|
| 237 |
-
# "--gold", str(temp_gold_path),
|
| 238 |
-
# "--pred", str(pred_path),
|
| 239 |
-
# "--etype", "exec",
|
| 240 |
-
# "--db", str(db_root),
|
| 241 |
-
# "--table", str(table_json),
|
| 242 |
-
# ]
|
| 243 |
-
|
| 244 |
-
# proc_exec = subprocess.run(cmd_exec, capture_output=True, text=True)
|
| 245 |
-
# exec_acc = _parse_spider_accuracy(proc_exec.stdout, "exec")
|
| 246 |
-
|
| 247 |
-
# print("==========================================")
|
| 248 |
-
# print(f"🎯 OFFICIAL SPIDER RESULTS FOR: {args.adapter}")
|
| 249 |
-
# print("==========================================")
|
| 250 |
-
|
| 251 |
-
# print(f"Exact Match Accuracy : {exact_acc*100:.2f}%" if exact_acc else "EM parsing failed")
|
| 252 |
-
# print(f"Execution Accuracy : {exec_acc*100:.2f}%" if exec_acc else "EX parsing failed")
|
| 253 |
|
| 254 |
-
# print("=============================
|
|
|
|
|
|
|
| 255 |
|
| 256 |
|
| 257 |
# if __name__ == "__main__":
|
| 258 |
# main()
|
| 259 |
|
| 260 |
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
# import json
|
| 266 |
# import sqlite3
|
| 267 |
-
# import re
|
| 268 |
-
# import time
|
| 269 |
-
# import sys
|
| 270 |
# import argparse
|
|
|
|
| 271 |
# from pathlib import Path
|
|
|
|
|
|
|
|
|
|
| 272 |
|
| 273 |
-
# #
|
| 274 |
-
#
|
| 275 |
-
#
|
| 276 |
-
#
|
| 277 |
-
#
|
| 278 |
-
# sys.path.insert(0, str(PROJECT_ROOT))
|
| 279 |
-
|
| 280 |
-
# from src.text2sql_engine import get_engine
|
| 281 |
-
# from src.sql_validator import validate_sql_schema
|
| 282 |
-
|
| 283 |
-
# # ==========================================
|
| 284 |
-
# # CONFIG
|
| 285 |
-
# # ==========================================
|
| 286 |
-
# DATA_PATH = PROJECT_ROOT / "data" / "dev.json"
|
| 287 |
-
# DB_ROOT = PROJECT_ROOT / "data" / "database"
|
| 288 |
-
|
| 289 |
-
# # ==========================================
|
| 290 |
-
# # NORMALIZATION
|
| 291 |
-
# # ==========================================
|
| 292 |
-
# def normalize_sql(sql):
|
| 293 |
-
# if not isinstance(sql, str):
|
| 294 |
-
# return ""
|
| 295 |
-
# sql = sql.replace('"', "'")
|
| 296 |
-
# sql = re.sub(r"\s+", " ", sql)
|
| 297 |
-
# return sql.strip().lower().rstrip(";")
|
| 298 |
-
|
| 299 |
-
# def normalize_result(res):
|
| 300 |
-
# try:
|
| 301 |
-
# return sorted([tuple(map(str, r)) for r in res])
|
| 302 |
-
# except:
|
| 303 |
-
# return []
|
| 304 |
-
|
| 305 |
-
# # ==========================================
|
| 306 |
-
# # EXECUTION
|
| 307 |
-
# # ==========================================
|
| 308 |
-
# def execute_sql(db_path, sql):
|
| 309 |
-
# try:
|
| 310 |
-
# conn = sqlite3.connect(db_path)
|
| 311 |
-
|
| 312 |
-
# start = time.time()
|
| 313 |
-
# def timeout():
|
| 314 |
-
# return 1 if (time.time() - start) > 2 else 0
|
| 315 |
-
|
| 316 |
-
# conn.set_progress_handler(timeout, 10000)
|
| 317 |
-
|
| 318 |
-
# cur = conn.cursor()
|
| 319 |
-
# cur.execute(sql)
|
| 320 |
-
# res = cur.fetchall()
|
| 321 |
-
|
| 322 |
-
# conn.close()
|
| 323 |
-
# return res
|
| 324 |
-
|
| 325 |
-
# except Exception:
|
| 326 |
-
# return None
|
| 327 |
-
|
| 328 |
-
# # ==========================================
|
| 329 |
-
# # EVALUATION
|
| 330 |
-
# # ==========================================
|
| 331 |
-
# def evaluate(engine, data, is_constrained=False, debug=False):
|
| 332 |
-
|
| 333 |
-
# attempted = 0
|
| 334 |
-
# total = 0
|
| 335 |
-
# exact_match = 0
|
| 336 |
-
# execution_match = 0
|
| 337 |
-
# constraint_ok = 0
|
| 338 |
-
|
| 339 |
-
# skipped_missing_db = 0
|
| 340 |
-
# skipped_exception = 0
|
| 341 |
-
# skipped_no_sql = 0
|
| 342 |
-
|
| 343 |
-
# total_time = 0
|
| 344 |
|
| 345 |
-
#
|
|
|
|
|
|
|
|
|
|
| 346 |
|
| 347 |
-
#
|
| 348 |
-
#
|
| 349 |
-
#
|
|
|
|
| 350 |
|
| 351 |
-
#
|
|
|
|
|
|
|
| 352 |
|
| 353 |
-
#
|
| 354 |
-
#
|
| 355 |
-
#
|
|
|
|
|
|
|
| 356 |
|
| 357 |
-
#
|
| 358 |
-
#
|
| 359 |
-
# result = engine.ask(question, db_id)
|
| 360 |
-
# total_time += (time.time() - start)
|
| 361 |
-
# except Exception:
|
| 362 |
-
# skipped_exception += 1
|
| 363 |
-
# continue
|
| 364 |
|
| 365 |
-
# if not isinstance(result, dict):
|
| 366 |
-
# continue
|
| 367 |
|
| 368 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
|
| 370 |
-
#
|
| 371 |
-
# if debug:
|
| 372 |
-
# print(f"\nQ: {question}")
|
| 373 |
-
# print(f"PRED: {pred_sql}")
|
| 374 |
-
# print(f"GOLD: {gold_sql}")
|
| 375 |
|
| 376 |
-
#
|
| 377 |
-
#
|
| 378 |
-
# continue
|
| 379 |
|
| 380 |
-
#
|
| 381 |
-
#
|
| 382 |
|
| 383 |
-
#
|
| 384 |
-
#
|
| 385 |
-
# try:
|
| 386 |
-
# is_valid, _ = validate_sql_schema(pred_sql, str(db_path))
|
| 387 |
-
# if is_valid:
|
| 388 |
-
# constraint_ok += 1
|
| 389 |
-
# except:
|
| 390 |
-
# pass
|
| 391 |
|
| 392 |
-
#
|
| 393 |
-
#
|
| 394 |
-
# exact_match += 1
|
| 395 |
|
| 396 |
-
# # EXECUTION MATCH
|
| 397 |
-
# pred_res = execute_sql(str(db_path), pred_sql)
|
| 398 |
-
# gold_res = execute_sql(str(db_path), gold_sql)
|
| 399 |
|
| 400 |
-
#
|
| 401 |
-
#
|
| 402 |
-
#
|
|
|
|
|
|
|
|
|
|
| 403 |
|
| 404 |
-
#
|
| 405 |
-
# if i % 10 == 0:
|
| 406 |
-
# print(
|
| 407 |
-
# f"[{i}/{len(data)}] "
|
| 408 |
-
# f"EM: {exact_match/max(total,1):.3f} | "
|
| 409 |
-
# f"EX: {execution_match/max(total,1):.3f} | "
|
| 410 |
-
# f"Constraint: {(constraint_ok/max(total,1)) if is_constrained else 0:.3f}"
|
| 411 |
-
# )
|
| 412 |
|
| 413 |
-
#
|
| 414 |
-
|
| 415 |
-
# return {
|
| 416 |
-
# "exact_match": exact_match / total if total > 0 else 0,
|
| 417 |
-
# "execution_accuracy": execution_match / total if total > 0 else 0,
|
| 418 |
-
# "constraint_rate": (constraint_ok / total if (is_constrained and total > 0) else 0),
|
| 419 |
-
# "avg_latency": avg_latency,
|
| 420 |
-
# "total": total,
|
| 421 |
-
# "attempted": attempted,
|
| 422 |
-
# "skipped_missing_db": skipped_missing_db,
|
| 423 |
-
# "skipped_exception": skipped_exception,
|
| 424 |
-
# "skipped_no_sql": skipped_no_sql,
|
| 425 |
-
# }
|
| 426 |
-
|
| 427 |
-
# # ==========================================
|
| 428 |
-
# # MAIN
|
| 429 |
-
# # ==========================================
|
| 430 |
-
# if __name__ == "__main__":
|
| 431 |
|
| 432 |
-
#
|
| 433 |
-
#
|
| 434 |
-
|
| 435 |
-
#
|
| 436 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
|
| 438 |
-
#
|
|
|
|
| 439 |
|
| 440 |
-
#
|
| 441 |
-
# data = json.load(f)[: args.num_samples]
|
| 442 |
|
| 443 |
-
#
|
| 444 |
-
# # 🔴 BASE MODEL
|
| 445 |
-
# # ==========================================
|
| 446 |
-
# print("\n🚀 Running BASE MODEL...\n")
|
| 447 |
|
| 448 |
-
#
|
| 449 |
-
#
|
| 450 |
-
#
|
| 451 |
-
#
|
| 452 |
-
# )
|
| 453 |
|
| 454 |
-
#
|
|
|
|
| 455 |
|
| 456 |
-
#
|
| 457 |
-
# # 🟡 RLHF (NO CONSTRAINT)
|
| 458 |
-
# # ==========================================
|
| 459 |
-
# print("\n🚀 Running RLHF (NO CONSTRAINT)...\n")
|
| 460 |
|
| 461 |
-
#
|
| 462 |
-
# adapter_path="checkpoints/best_rlhf_model",
|
| 463 |
-
# use_lora=True,
|
| 464 |
-
# use_constrained=False
|
| 465 |
-
# )
|
| 466 |
|
| 467 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 468 |
|
| 469 |
-
#
|
| 470 |
-
# # 🟢 RLHF + CONSTRAINT
|
| 471 |
-
# # ==========================================
|
| 472 |
-
# print("\n🚀 Running RLHF + CONSTRAINED...\n")
|
| 473 |
|
| 474 |
-
#
|
| 475 |
-
#
|
| 476 |
-
# use_lora=True,
|
| 477 |
-
# use_constrained=True
|
| 478 |
-
# )
|
| 479 |
|
| 480 |
-
#
|
| 481 |
|
| 482 |
-
#
|
| 483 |
-
#
|
| 484 |
-
# # ==========================================
|
| 485 |
-
# print("\n==========================================")
|
| 486 |
-
# print("🎯 FINAL RESULTS (3-WAY COMPARISON)")
|
| 487 |
-
# print("==========================================")
|
| 488 |
|
| 489 |
-
#
|
| 490 |
-
#
|
| 491 |
|
| 492 |
-
# print(
|
| 493 |
-
#
|
|
|
|
| 494 |
|
| 495 |
-
# print(f"RLHF + Constrain → EM: {res_const['exact_match']*100:.2f}% | "
|
| 496 |
-
# f"EX: {res_const['execution_accuracy']*100:.2f}% | "
|
| 497 |
-
# f"Constraint: {res_const['constraint_rate']*100:.2f}%")
|
| 498 |
|
| 499 |
-
#
|
|
|
|
| 500 |
|
| 501 |
|
| 502 |
import json
|
|
|
|
|
|
|
| 503 |
import argparse
|
|
|
|
| 504 |
import sqlite3
|
| 505 |
import time
|
| 506 |
import re
|
| 507 |
-
import os
|
| 508 |
from pathlib import Path
|
| 509 |
|
| 510 |
import torch
|
| 511 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 512 |
from peft import PeftModel
|
| 513 |
|
| 514 |
-
#
|
| 515 |
-
|
| 516 |
-
from prompting import encode_prompt
|
| 517 |
-
from src.sql_validator import validate_sql_schema
|
| 518 |
-
except ImportError:
|
| 519 |
-
import sys
|
| 520 |
-
sys.path.append(str(Path(__file__).resolve().parents[1]))
|
| 521 |
-
from src.prompting import encode_prompt
|
| 522 |
-
from src.sql_validator import validate_sql_schema
|
| 523 |
-
|
| 524 |
-
# =========================================================
|
| 525 |
-
# ERROR LOGGING
|
| 526 |
-
# =========================================================
|
| 527 |
-
ERROR_LOG_FILE = "results/error_logs.json"
|
| 528 |
-
|
| 529 |
-
def classify_error(sql, error_msg=""):
|
| 530 |
-
sql = sql.lower()
|
| 531 |
-
error_msg = str(error_msg).lower()
|
| 532 |
-
|
| 533 |
-
if "no such column" in error_msg:
|
| 534 |
-
return "wrong_column"
|
| 535 |
-
if "no such table" in error_msg:
|
| 536 |
-
return "wrong_table"
|
| 537 |
-
if "syntax error" in error_msg:
|
| 538 |
-
return "syntax_error"
|
| 539 |
-
if "ambiguous column" in error_msg:
|
| 540 |
-
return "ambiguous_column"
|
| 541 |
-
if "join" in sql and " on " not in sql:
|
| 542 |
-
return "missing_join"
|
| 543 |
-
|
| 544 |
-
return "other"
|
| 545 |
-
|
| 546 |
-
def log_error(question, sql, error, error_type):
|
| 547 |
-
os.makedirs(os.path.dirname(ERROR_LOG_FILE), exist_ok=True)
|
| 548 |
-
|
| 549 |
-
entry = {
|
| 550 |
-
"question": question,
|
| 551 |
-
"sql": sql,
|
| 552 |
-
"error": str(error),
|
| 553 |
-
"error_type": error_type,
|
| 554 |
-
"timestamp": time.time()
|
| 555 |
-
}
|
| 556 |
-
|
| 557 |
-
logs = []
|
| 558 |
-
if os.path.exists(ERROR_LOG_FILE):
|
| 559 |
-
try:
|
| 560 |
-
with open(ERROR_LOG_FILE, "r") as f:
|
| 561 |
-
content = f.read().strip()
|
| 562 |
-
if content:
|
| 563 |
-
logs = json.loads(content)
|
| 564 |
-
except:
|
| 565 |
-
logs = []
|
| 566 |
-
|
| 567 |
-
logs.append(entry)
|
| 568 |
-
|
| 569 |
-
with open(ERROR_LOG_FILE, "w") as f:
|
| 570 |
-
json.dump(logs, f, indent=2)
|
| 571 |
-
|
| 572 |
-
# =========================================================
|
| 573 |
-
# 🔥 FINAL FIX_SQL (BALANCED VERSION)
|
| 574 |
-
# =========================================================
|
| 575 |
-
def fix_sql(sql):
|
| 576 |
-
if not sql:
|
| 577 |
-
return "SELECT 1"
|
| 578 |
-
|
| 579 |
-
s = str(sql).strip()
|
| 580 |
-
|
| 581 |
-
# Extract SQL only
|
| 582 |
-
match = re.search(r"(?i)(select|with)[\s\S]*", s)
|
| 583 |
-
if match:
|
| 584 |
-
s = match.group(0)
|
| 585 |
-
|
| 586 |
-
s = s.split(";")[0].strip()
|
| 587 |
-
|
| 588 |
-
# NULL fixes
|
| 589 |
-
s = re.sub(r'(?i)=\s*null', 'IS NULL', s)
|
| 590 |
-
s = re.sub(r'(?i)!=\s*null', 'IS NOT NULL', s)
|
| 591 |
-
|
| 592 |
-
# Fix commas
|
| 593 |
-
s = re.sub(r',\s*,+', ',', s)
|
| 594 |
-
s = re.sub(r'(?i),\s*from', ' FROM', s)
|
| 595 |
-
|
| 596 |
-
# 🔥 LIGHT COLUMN SAFETY (main improvement)
|
| 597 |
-
if "select" in s.lower():
|
| 598 |
-
if len(re.findall(r'\w+\.\w+', s)) > 3:
|
| 599 |
-
s = re.sub(r'(?i)select\s+.*?\s+from', 'SELECT * FROM', s)
|
| 600 |
-
|
| 601 |
-
# 🔥 JOIN fix
|
| 602 |
-
if "join" in s.lower() and " on " not in s.lower():
|
| 603 |
-
s = re.sub(r'join\s+(\w+)', r'JOIN \1 ON 1=1', s, flags=re.I)
|
| 604 |
-
|
| 605 |
-
# Ensure valid SQL
|
| 606 |
-
if not s.lower().startswith(("select", "with")):
|
| 607 |
-
return "SELECT 1"
|
| 608 |
-
|
| 609 |
-
return s.strip()
|
| 610 |
-
|
| 611 |
-
# =========================================================
|
| 612 |
-
# NORMALIZATION
|
| 613 |
-
# =========================================================
|
| 614 |
-
def normalize_sql(sql):
|
| 615 |
-
if not sql:
|
| 616 |
-
return ""
|
| 617 |
-
return re.sub(r"\s+", " ", str(sql)).strip().lower()
|
| 618 |
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
# =========================================================
|
| 629 |
-
# EXECUTION HELPERS
|
| 630 |
-
# =========================================================
|
| 631 |
-
def is_executable(sql, db_path):
|
| 632 |
-
try:
|
| 633 |
-
conn = sqlite3.connect(db_path)
|
| 634 |
-
cur = conn.cursor()
|
| 635 |
-
cur.execute(sql)
|
| 636 |
-
conn.close()
|
| 637 |
-
return True
|
| 638 |
-
except:
|
| 639 |
-
return False
|
| 640 |
|
| 641 |
-
def check_execution(pred_sql, gold_sql, db_path
|
|
|
|
| 642 |
try:
|
| 643 |
conn = sqlite3.connect(db_path)
|
| 644 |
conn.text_factory = lambda b: b.decode(errors='ignore')
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 653 |
conn.close()
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
except Exception
|
| 658 |
-
error_type = classify_error(pred_sql, str(e))
|
| 659 |
-
log_error(question, pred_sql, str(e), error_type)
|
| 660 |
return False
|
| 661 |
|
| 662 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 663 |
# MAIN
|
| 664 |
-
#
|
| 665 |
def main():
|
| 666 |
parser = argparse.ArgumentParser()
|
| 667 |
-
parser.add_argument("--adapter", type=str, required=True)
|
| 668 |
-
parser.add_argument("--num_samples", type=int, default=700)
|
|
|
|
|
|
|
| 669 |
args = parser.parse_args()
|
| 670 |
|
| 671 |
-
project_root = Path(__file__).resolve().
|
| 672 |
-
|
| 673 |
-
project_root = project_root.parent
|
| 674 |
|
| 675 |
db_root = project_root / "data" / "database"
|
|
|
|
| 676 |
dev_json = project_root / "data" / "dev.json"
|
|
|
|
|
|
|
|
|
|
| 677 |
|
| 678 |
-
|
|
|
|
| 679 |
|
| 680 |
-
|
|
|
|
| 681 |
|
| 682 |
-
|
| 683 |
-
|
|
|
|
|
|
|
| 684 |
|
| 685 |
-
|
|
|
|
|
|
|
| 686 |
model = model.merge_and_unload()
|
| 687 |
model.eval()
|
| 688 |
|
| 689 |
-
with open(
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
em_correct = 0
|
| 693 |
-
ex_correct = 0
|
| 694 |
-
constraint_ok = 0
|
| 695 |
-
|
| 696 |
-
print(f"\n🚀 Evaluating {len(dev_data)} samples...\n")
|
| 697 |
-
|
| 698 |
-
for i, ex in enumerate(dev_data, 1):
|
| 699 |
-
db_id = ex["db_id"]
|
| 700 |
-
question = ex["question"]
|
| 701 |
-
gold_query = ex["query"]
|
| 702 |
-
|
| 703 |
-
db_path = db_root / db_id / f"{db_id}.sqlite"
|
| 704 |
-
|
| 705 |
-
input_tensor = encode_prompt(tokenizer, question, db_id, device=device).unsqueeze(0)
|
| 706 |
-
|
| 707 |
-
with torch.no_grad():
|
| 708 |
-
outputs = model.generate(
|
| 709 |
-
input_ids=input_tensor,
|
| 710 |
-
max_new_tokens=128,
|
| 711 |
-
num_beams=8,
|
| 712 |
-
num_return_sequences=8
|
| 713 |
-
)
|
| 714 |
|
| 715 |
-
|
|
|
|
|
|
|
| 716 |
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
raw_pred = tokenizer.decode(out, skip_special_tokens=True)
|
| 720 |
-
candidate_sql = fix_sql(raw_pred)
|
| 721 |
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 725 |
|
| 726 |
-
|
| 727 |
-
best_sql = fix_sql(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
| 728 |
|
| 729 |
-
|
| 730 |
-
|
| 731 |
-
except:
|
| 732 |
-
is_valid = False
|
| 733 |
-
|
| 734 |
-
if is_valid:
|
| 735 |
-
constraint_ok += 1
|
| 736 |
-
|
| 737 |
-
if normalize_sql(best_sql) == normalize_sql(gold_query):
|
| 738 |
-
em_correct += 1
|
| 739 |
-
|
| 740 |
-
if check_execution(best_sql, gold_query, str(db_path), question):
|
| 741 |
-
ex_correct += 1
|
| 742 |
-
|
| 743 |
-
if i % 50 == 0:
|
| 744 |
-
print(f"{i}/{len(dev_data)} done")
|
| 745 |
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 754 |
|
| 755 |
if __name__ == "__main__":
|
| 756 |
-
main()
|
|
|
|
| 1 |
# import json
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
# import sqlite3
|
| 3 |
+
# import argparse
|
|
|
|
|
|
|
| 4 |
# from pathlib import Path
|
|
|
|
| 5 |
# import torch
|
| 6 |
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 7 |
# from peft import PeftModel
|
| 8 |
|
| 9 |
+
# # ---------------- PROMPT (IDENTICAL TO TRAINING) ----------------
|
| 10 |
+
# def build_prompt(question, schema):
|
| 11 |
+
# return f"""
|
| 12 |
+
# Database Schema:
|
| 13 |
+
# {schema}
|
| 14 |
|
| 15 |
+
# Translate English to SQL:
|
| 16 |
+
# {question}
|
| 17 |
+
# SQL:
|
| 18 |
+
# """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
+
# # ---------------- LOAD SCHEMA ----------------
|
| 21 |
+
# def load_schema(db_path):
|
| 22 |
+
# conn = sqlite3.connect(db_path)
|
| 23 |
+
# cursor = conn.cursor()
|
| 24 |
|
| 25 |
+
# tables = cursor.execute(
|
| 26 |
+
# "SELECT name FROM sqlite_master WHERE type='table';"
|
| 27 |
+
# ).fetchall()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
+
# schema = ""
|
| 30 |
+
# for (table,) in tables:
|
| 31 |
+
# cols = cursor.execute(f"PRAGMA table_info({table});").fetchall()
|
| 32 |
+
# col_names = [c[1] for c in cols]
|
| 33 |
+
# schema += f"{table}({', '.join(col_names)})\n"
|
| 34 |
|
| 35 |
+
# conn.close()
|
| 36 |
+
# return schema
|
| 37 |
|
|
|
|
| 38 |
|
| 39 |
+
# # ---------------- EXECUTION CHECK ----------------
|
| 40 |
+
# def execution_match(pred_sql, gold_sql, db_path):
|
| 41 |
+
# try:
|
| 42 |
+
# conn = sqlite3.connect(db_path)
|
| 43 |
+
# cur = conn.cursor()
|
| 44 |
|
| 45 |
+
# cur.execute(pred_sql)
|
| 46 |
+
# pred = cur.fetchall()
|
| 47 |
|
| 48 |
+
# cur.execute(gold_sql)
|
| 49 |
+
# gold = cur.fetchall()
|
| 50 |
|
| 51 |
# conn.close()
|
| 52 |
+
# return pred == gold
|
|
|
|
|
|
|
| 53 |
|
| 54 |
# except Exception:
|
| 55 |
# return False
|
| 56 |
|
| 57 |
|
| 58 |
+
# # ---------------- MAIN ----------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
# def main():
|
| 60 |
# parser = argparse.ArgumentParser()
|
| 61 |
# parser.add_argument("--adapter", type=str, required=True)
|
| 62 |
+
# parser.add_argument("--num_samples", type=int, default=1034)
|
|
|
|
|
|
|
| 63 |
# args = parser.parse_args()
|
| 64 |
|
| 65 |
# project_root = Path(__file__).resolve().parents[1]
|
|
|
|
| 66 |
|
|
|
|
|
|
|
| 67 |
# dev_json = project_root / "data" / "dev.json"
|
| 68 |
+
# db_root = project_root / "data" / "database"
|
| 69 |
|
| 70 |
+
# device = "mps" if torch.backends.mps.is_available() else "cpu"
|
|
|
|
| 71 |
|
| 72 |
+
# # load model
|
| 73 |
+
# base_model = "Salesforce/codet5-base"
|
| 74 |
+
# tokenizer = AutoTokenizer.from_pretrained(args.adapter)
|
| 75 |
+
# base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
|
| 76 |
+
# model = PeftModel.from_pretrained(base, args.adapter).to(device)
|
| 77 |
+
# model = model.merge_and_unload()
|
| 78 |
|
| 79 |
+
# with open(dev_json) as f:
|
| 80 |
+
# dev = json.load(f)[: args.num_samples]
|
|
|
|
|
|
|
| 81 |
|
| 82 |
+
# correct = 0
|
|
|
|
| 83 |
|
| 84 |
+
# print(f"Evaluating {len(dev)} examples...\n")
|
|
|
|
| 85 |
|
| 86 |
+
# for i, ex in enumerate(dev, 1):
|
| 87 |
+
# question = ex["question"]
|
| 88 |
+
# db_id = ex["db_id"]
|
| 89 |
+
# gold_sql = ex["query"]
|
| 90 |
|
| 91 |
+
# db_path = db_root / db_id / f"{db_id}.sqlite"
|
| 92 |
+
# schema = load_schema(db_path)
|
| 93 |
|
| 94 |
+
# prompt = build_prompt(question, schema)
|
| 95 |
|
| 96 |
+
# inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
+
# with torch.no_grad():
|
| 99 |
# outputs = model.generate(
|
| 100 |
+
# **inputs,
|
| 101 |
+
# max_new_tokens=80,
|
| 102 |
+
# do_sample=False,
|
| 103 |
+
# num_beams=4,
|
| 104 |
# )
|
| 105 |
|
| 106 |
+
# pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
+
# if "SQL:" in pred_sql:
|
| 109 |
+
# pred_sql = pred_sql.split("SQL:")[-1].strip()
|
| 110 |
|
| 111 |
+
# match = execution_match(pred_sql, gold_sql, db_path)
|
| 112 |
|
| 113 |
+
# if match:
|
| 114 |
+
# correct += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
+
# if i % 10 == 0:
|
| 117 |
+
# print(f"{i}/{len(dev)} | Acc: {correct/i:.3f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
+
# print("\n=============================")
|
| 120 |
+
# print(f"FINAL EXECUTION ACCURACY: {correct/len(dev)*100:.2f}%")
|
| 121 |
+
# print("=============================")
|
| 122 |
|
| 123 |
|
| 124 |
# if __name__ == "__main__":
|
| 125 |
# main()
|
| 126 |
|
| 127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
# import json
|
| 129 |
# import sqlite3
|
|
|
|
|
|
|
|
|
|
| 130 |
# import argparse
|
| 131 |
+
# import time
|
| 132 |
# from pathlib import Path
|
| 133 |
+
# import torch
|
| 134 |
+
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 135 |
+
# from peft import PeftModel
|
| 136 |
|
| 137 |
+
# # ---------------- PROMPT (IDENTICAL TO TRAINING) ----------------
|
| 138 |
+
# def build_prompt(question, schema):
|
| 139 |
+
# return f"""
|
| 140 |
+
# Database Schema:
|
| 141 |
+
# {schema}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
+
# Translate English to SQL:
|
| 144 |
+
# {question}
|
| 145 |
+
# SQL:
|
| 146 |
+
# """
|
| 147 |
|
| 148 |
+
# # ---------------- LOAD SCHEMA ----------------
|
| 149 |
+
# def load_schema(db_path):
|
| 150 |
+
# conn = sqlite3.connect(db_path)
|
| 151 |
+
# cursor = conn.cursor()
|
| 152 |
|
| 153 |
+
# tables = cursor.execute(
|
| 154 |
+
# "SELECT name FROM sqlite_master WHERE type='table';"
|
| 155 |
+
# ).fetchall()
|
| 156 |
|
| 157 |
+
# schema = ""
|
| 158 |
+
# for (table,) in tables:
|
| 159 |
+
# cols = cursor.execute(f"PRAGMA table_info({table});").fetchall()
|
| 160 |
+
# col_names = [c[1] for c in cols]
|
| 161 |
+
# schema += f"{table}({', '.join(col_names)})\n"
|
| 162 |
|
| 163 |
+
# conn.close()
|
| 164 |
+
# return schema
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
|
|
|
|
|
|
| 166 |
|
| 167 |
+
# # ---------------- EXECUTION CHECK WITH TIMEOUT ----------------
|
| 168 |
+
# def execution_match(pred_sql, gold_sql, db_path):
|
| 169 |
+
# try:
|
| 170 |
+
# conn = sqlite3.connect(db_path)
|
| 171 |
+
|
| 172 |
+
# # --- 5-SECOND TIMEOUT SO EVALUATION DOESN'T FREEZE ---
|
| 173 |
+
# start_time = time.monotonic()
|
| 174 |
+
# def timeout_handler():
|
| 175 |
+
# return 1 if (time.monotonic() - start_time) > 5.0 else 0
|
| 176 |
+
# conn.set_progress_handler(timeout_handler, 10000)
|
| 177 |
|
| 178 |
+
# cur = conn.cursor()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
|
| 180 |
+
# cur.execute(pred_sql)
|
| 181 |
+
# pred = cur.fetchall()
|
|
|
|
| 182 |
|
| 183 |
+
# cur.execute(gold_sql)
|
| 184 |
+
# gold = cur.fetchall()
|
| 185 |
|
| 186 |
+
# conn.close()
|
| 187 |
+
# return pred == gold
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
|
| 189 |
+
# except Exception:
|
| 190 |
+
# return False
|
|
|
|
| 191 |
|
|
|
|
|
|
|
|
|
|
| 192 |
|
| 193 |
+
# # ---------------- MAIN ----------------
|
| 194 |
+
# def main():
|
| 195 |
+
# parser = argparse.ArgumentParser()
|
| 196 |
+
# parser.add_argument("--adapter", type=str, required=True)
|
| 197 |
+
# parser.add_argument("--num_samples", type=int, default=1034)
|
| 198 |
+
# args = parser.parse_args()
|
| 199 |
|
| 200 |
+
# project_root = Path(__file__).resolve().parents[1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
|
| 202 |
+
# dev_json = project_root / "data" / "dev.json"
|
| 203 |
+
# db_root = project_root / "data" / "database"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
|
| 205 |
+
# # 🎯 Added CUDA support for Nvidia GPUs
|
| 206 |
+
# device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
|
| 207 |
+
|
| 208 |
+
# # load model
|
| 209 |
+
# base_model = "Salesforce/codet5-base"
|
| 210 |
+
# print(f"Loading Base: {base_model}")
|
| 211 |
+
# print(f"Loading Adapter: {args.adapter}")
|
| 212 |
+
|
| 213 |
+
# tokenizer = AutoTokenizer.from_pretrained(args.adapter)
|
| 214 |
+
# base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
|
| 215 |
+
# model = PeftModel.from_pretrained(base, args.adapter).to(device)
|
| 216 |
+
# model = model.merge_and_unload()
|
| 217 |
|
| 218 |
+
# with open(dev_json) as f:
|
| 219 |
+
# dev = json.load(f)[: args.num_samples]
|
| 220 |
|
| 221 |
+
# correct = 0
|
|
|
|
| 222 |
|
| 223 |
+
# print(f"Evaluating {len(dev)} examples...\n")
|
|
|
|
|
|
|
|
|
|
| 224 |
|
| 225 |
+
# for i, ex in enumerate(dev, 1):
|
| 226 |
+
# question = ex["question"]
|
| 227 |
+
# db_id = ex["db_id"]
|
| 228 |
+
# gold_sql = ex["query"]
|
|
|
|
| 229 |
|
| 230 |
+
# db_path = db_root / db_id / f"{db_id}.sqlite"
|
| 231 |
+
# schema = load_schema(db_path)
|
| 232 |
|
| 233 |
+
# prompt = build_prompt(question, schema)
|
|
|
|
|
|
|
|
|
|
| 234 |
|
| 235 |
+
# inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
|
| 237 |
+
# with torch.no_grad():
|
| 238 |
+
# outputs = model.generate(
|
| 239 |
+
# **inputs,
|
| 240 |
+
# max_new_tokens=80,
|
| 241 |
+
# do_sample=False,
|
| 242 |
+
# num_beams=4,
|
| 243 |
+
# )
|
| 244 |
|
| 245 |
+
# pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
+
# if "SQL:" in pred_sql:
|
| 248 |
+
# pred_sql = pred_sql.split("SQL:")[-1].strip()
|
|
|
|
|
|
|
|
|
|
| 249 |
|
| 250 |
+
# match = execution_match(pred_sql, gold_sql, db_path)
|
| 251 |
|
| 252 |
+
# if match:
|
| 253 |
+
# correct += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
|
| 255 |
+
# if i % 10 == 0:
|
| 256 |
+
# print(f"{i}/{len(dev)} | Acc: {correct/i:.3f}")
|
| 257 |
|
| 258 |
+
# print("\n=============================")
|
| 259 |
+
# print(f"FINAL EXECUTION ACCURACY: {correct/len(dev)*100:.2f}%")
|
| 260 |
+
# print("=============================")
|
| 261 |
|
|
|
|
|
|
|
|
|
|
| 262 |
|
| 263 |
+
# if __name__ == "__main__":
|
| 264 |
+
# main()
|
| 265 |
|
| 266 |
|
| 267 |
import json
|
| 268 |
+
import subprocess
|
| 269 |
+
import sys
|
| 270 |
import argparse
|
| 271 |
+
import random
|
| 272 |
import sqlite3
|
| 273 |
import time
|
| 274 |
import re
|
|
|
|
| 275 |
from pathlib import Path
|
| 276 |
|
| 277 |
import torch
|
| 278 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 279 |
from peft import PeftModel
|
| 280 |
|
| 281 |
+
# Assuming you have a prompting.py that has encode_prompt
|
| 282 |
+
from prompting import encode_prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
|
| 284 |
+
# -------------------------------
|
| 285 |
+
# LIVE CHECK HELPERS
|
| 286 |
+
# -------------------------------
|
| 287 |
+
def normalize_sql(sql):
    """Canonicalise a SQL string for the live exact-match tracker.

    Double quotes are folded to single quotes, whitespace runs collapse
    to a single space, and the result is lower-cased with any trailing
    semicolon removed.
    """
    canonical = re.sub(r"\s+", " ", sql.replace('"', "'")).strip()
    return canonical.lower().rstrip(";")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
|
| 293 |
+
def check_execution(pred_sql, gold_sql, db_path):
    """Basic execution check for the live progress bar.

    Executes both the predicted and gold SQL against the SQLite database
    at *db_path* and compares their result sets ignoring row order.

    Returns False on any failure (invalid SQL, missing tables, the
    2-second timeout, or rows that cannot be sorted against each other).
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        # Some databases contain non-UTF8 text; drop undecodable bytes.
        conn.text_factory = lambda b: b.decode(errors='ignore')

        # 2-second timeout so the live tracker doesn't freeze forever:
        # a non-zero return from the progress handler aborts the query.
        start_time = time.monotonic()

        def timeout_handler():
            return 1 if (time.monotonic() - start_time) > 2.0 else 0

        conn.set_progress_handler(timeout_handler, 10000)

        cursor = conn.cursor()
        cursor.execute(pred_sql)
        pred_res = cursor.fetchall()

        cursor.execute(gold_sql)
        gold_res = cursor.fetchall()

        # Simple sorted check for the live tracker (order-insensitive).
        return sorted(pred_res) == sorted(gold_res)
    except Exception:
        return False
    finally:
        # Close even when execution raised; the original leaked the
        # connection on the error path.
        if conn is not None:
            conn.close()
|
| 317 |
|
| 318 |
+
# -------------------------------
|
| 319 |
+
# SPIDER PARSER
|
| 320 |
+
# -------------------------------
|
| 321 |
+
def _parse_spider_accuracy(stdout: str, metric_type: str) -> float | None:
|
| 322 |
+
for line in stdout.splitlines():
|
| 323 |
+
if metric_type == "exec" and line.strip().startswith("execution"):
|
| 324 |
+
try: return float(line.split()[-1])
|
| 325 |
+
except: pass
|
| 326 |
+
elif metric_type == "match" and line.strip().startswith("exact"):
|
| 327 |
+
try: return float(line.split()[-1])
|
| 328 |
+
except: pass
|
| 329 |
+
return None
|
| 330 |
+
|
| 331 |
+
# -------------------------------
|
| 332 |
# MAIN
|
| 333 |
+
# -------------------------------
|
| 334 |
def _run_spider_eval(eval_script, gold_path, pred_path, db_root, table_json, etype):
    """Run the official Spider evaluation script once and parse its score.

    *etype* is "match" (exact set match) or "exec" (execution accuracy).
    Returns the parsed accuracy as a float, or None when the script's
    stdout could not be parsed.
    """
    cmd = [
        sys.executable, str(eval_script),
        "--gold", str(gold_path),
        "--pred", str(pred_path),
        "--etype", etype,
        "--db", str(db_root),
        "--table", str(table_json),
    ]
    proc = subprocess.run(cmd, capture_output=True, text=True)
    return _parse_spider_accuracy(proc.stdout, etype)


def main():
    """Evaluate a LoRA-adapted CodeT5 text-to-SQL model on the Spider dev set.

    Generates predictions for ``--num_samples`` dev examples while
    live-tracking cheap exact-match / execution metrics, writes the
    prediction and gold files, then runs the official Spider evaluator
    twice (exact match and execution) for the final numbers.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True, help="Path to your SFT or RLHF checkpoint")
    parser.add_argument("--num_samples", type=int, default=700, help="Number of samples to evaluate")
    parser.add_argument("--shuffle_dev", action="store_true")
    parser.add_argument("--shuffle_seed", type=int, default=42)
    args = parser.parse_args()

    project_root = Path(__file__).resolve().parents[1]
    adapter_dir = project_root / args.adapter

    db_root = project_root / "data" / "database"
    table_json = project_root / "data" / "tables.json"
    dev_json = project_root / "data" / "dev.json"

    pred_path = project_root / "temp_predictions.txt"
    temp_gold_path = project_root / "temp_gold.sql"

    if not adapter_dir.exists():
        raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")

    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    BASE_MODEL = "Salesforce/codet5-base"
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading Model: {args.adapter}...")
    base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
    model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
    # Merge the adapter weights into the base model for faster inference.
    model = model.merge_and_unload()
    model.eval()

    with dev_json.open() as f:
        dev = json.load(f)

    if args.shuffle_dev:
        rng = random.Random(args.shuffle_seed)
        rng.shuffle(dev)

    dev = dev[: args.num_samples]
    total = len(dev)

    gen_kwargs = dict(
        max_new_tokens=160,
        num_beams=4,
        do_sample=False,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    print(f"\n🚀 Generating and live-tracking {total} samples...\n")

    em_correct = 0
    ex_correct = 0

    with pred_path.open("w") as out_pred, temp_gold_path.open("w") as out_gold, torch.no_grad():
        for i, ex in enumerate(dev, start=1):
            db_id = ex["db_id"]
            question = ex["question"]
            gold_query = ex["query"]
            db_path = db_root / db_id / f"{db_id}.sqlite"

            # Generate the candidate SQL with deterministic beam search.
            input_ids = encode_prompt(tokenizer, question, db_id, device=device, max_input_tokens=512)
            input_ids = input_ids.unsqueeze(0).to(device)
            attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)

            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
            pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

            # Persist for the official Spider evaluation afterwards.
            out_pred.write(f"{pred_sql}\n")
            out_gold.write(f"{gold_query}\t{db_id}\n")

            # --- LIVE TRACKING CHECKS (cheap, not the official numbers) ---
            if normalize_sql(pred_sql) == normalize_sql(gold_query):
                em_correct += 1
            if check_execution(pred_sql, gold_query, db_path):
                ex_correct += 1

            # Print progress every 10 samples and on the final one.
            # (The original comment said "every 50" but the code checks 10.)
            if i % 10 == 0 or i == total:
                print(f"Progress: {i}/{total} | Current EM: {(em_correct/i)*100:.2f}% | Current EX: {(ex_correct/i)*100:.2f}%")

    print("\nGeneration finished. Running Official Spider Evaluations for final numbers...\n")

    eval_script = project_root / "spider_eval" / "evaluation.py"

    # Official Spider evaluation, once per metric type.  The command
    # construction was previously duplicated inline for each metric.
    exact_acc = _run_spider_eval(eval_script, temp_gold_path, pred_path, db_root, table_json, "match")
    exec_acc = _run_spider_eval(eval_script, temp_gold_path, pred_path, db_root, table_json, "exec")

    print("==========================================")
    print(f"🎯 OFFICIAL SPIDER RESULTS FOR: {args.adapter}")
    print("==========================================")

    if exact_acc is not None:
        print(f"Exact Set Match Accuracy : {exact_acc*100:.2f}%")
    else:
        print("Exact Set Match Accuracy : Could not parse output")

    if exec_acc is not None:
        print(f"Execution Accuracy : {exec_acc*100:.2f}%")
    else:
        print("Execution Accuracy : Could not parse output")
    print("==========================================\n")


if __name__ == "__main__":
    main()
|
src/evaluate_without_constraied.py
DELETED
|
@@ -1,503 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
# *********** code till task 3 ************
|
| 3 |
-
|
| 4 |
-
# import json
|
| 5 |
-
# import subprocess
|
| 6 |
-
# import sys
|
| 7 |
-
# import argparse
|
| 8 |
-
# import random
|
| 9 |
-
# import sqlite3
|
| 10 |
-
# import time
|
| 11 |
-
# import re
|
| 12 |
-
# import os
|
| 13 |
-
# from pathlib import Path
|
| 14 |
-
|
| 15 |
-
# import torch
|
| 16 |
-
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 17 |
-
# from peft import PeftModel
|
| 18 |
-
|
| 19 |
-
# from prompting import encode_prompt
|
| 20 |
-
|
| 21 |
-
# # -------------------------------
|
| 22 |
-
# # NORMALIZATION
|
| 23 |
-
# # -------------------------------
|
| 24 |
-
# def normalize_sql(sql):
|
| 25 |
-
# sql = sql.replace('"', "'")
|
| 26 |
-
# sql = re.sub(r"\s+", " ", sql)
|
| 27 |
-
# return sql.strip().lower().rstrip(";")
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
# # -------------------------------
|
| 31 |
-
# # 🔥 SAFE RESULT NORMALIZATION (FIX)
|
| 32 |
-
# # -------------------------------
|
| 33 |
-
# def normalize_result(res):
|
| 34 |
-
# try:
|
| 35 |
-
# return sorted([str(r) for r in res])
|
| 36 |
-
# except:
|
| 37 |
-
# return []
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
# # -------------------------------
|
| 41 |
-
# # EXECUTION CHECK (FIXED)
|
| 42 |
-
# # -------------------------------
|
| 43 |
-
# def check_execution(pred_sql, gold_sql, db_path):
|
| 44 |
-
# try:
|
| 45 |
-
# conn = sqlite3.connect(db_path)
|
| 46 |
-
# conn.text_factory = lambda b: b.decode(errors='ignore')
|
| 47 |
-
|
| 48 |
-
# start_time = time.monotonic()
|
| 49 |
-
|
| 50 |
-
# def timeout_handler():
|
| 51 |
-
# return 1 if (time.monotonic() - start_time) > 2.0 else 0
|
| 52 |
-
|
| 53 |
-
# conn.set_progress_handler(timeout_handler, 10000)
|
| 54 |
-
|
| 55 |
-
# cursor = conn.cursor()
|
| 56 |
-
|
| 57 |
-
# cursor.execute(pred_sql)
|
| 58 |
-
# pred_res = cursor.fetchall()
|
| 59 |
-
|
| 60 |
-
# cursor.execute(gold_sql)
|
| 61 |
-
# gold_res = cursor.fetchall()
|
| 62 |
-
|
| 63 |
-
# conn.close()
|
| 64 |
-
|
| 65 |
-
# # 🔥 FIXED COMPARISON
|
| 66 |
-
# return normalize_result(pred_res) == normalize_result(gold_res)
|
| 67 |
-
|
| 68 |
-
# except Exception:
|
| 69 |
-
# return False
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
# # -------------------------------
|
| 73 |
-
# # SPIDER PARSER
|
| 74 |
-
# # -------------------------------
|
| 75 |
-
# def _parse_spider_accuracy(stdout: str, metric_type: str):
|
| 76 |
-
# for line in stdout.splitlines():
|
| 77 |
-
# if metric_type == "exec" and line.strip().startswith("execution"):
|
| 78 |
-
# try:
|
| 79 |
-
# return float(line.split()[-1])
|
| 80 |
-
# except:
|
| 81 |
-
# pass
|
| 82 |
-
# elif metric_type == "match" and line.strip().startswith("exact"):
|
| 83 |
-
# try:
|
| 84 |
-
# return float(line.split()[-1])
|
| 85 |
-
# except:
|
| 86 |
-
# pass
|
| 87 |
-
# return None
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
# # -------------------------------
|
| 91 |
-
# # MAIN
|
| 92 |
-
# # -------------------------------
|
| 93 |
-
# def main():
|
| 94 |
-
# parser = argparse.ArgumentParser()
|
| 95 |
-
# parser.add_argument("--adapter", type=str, required=True)
|
| 96 |
-
# parser.add_argument("--num_samples", type=int, default= 500)
|
| 97 |
-
# parser.add_argument("--shuffle_dev", action="store_true")
|
| 98 |
-
# parser.add_argument("--shuffle_seed", type=int, default=42)
|
| 99 |
-
# args = parser.parse_args()
|
| 100 |
-
|
| 101 |
-
# project_root = Path(__file__).resolve().parents[1]
|
| 102 |
-
# adapter_dir = project_root / args.adapter
|
| 103 |
-
|
| 104 |
-
# db_root = project_root / "data" / "database"
|
| 105 |
-
# table_json = project_root / "data" / "tables.json"
|
| 106 |
-
# dev_json = project_root / "data" / "dev.json"
|
| 107 |
-
|
| 108 |
-
# pred_path = project_root / "temp_predictions.txt"
|
| 109 |
-
# temp_gold_path = project_root / "temp_gold.sql"
|
| 110 |
-
|
| 111 |
-
# if not adapter_dir.exists():
|
| 112 |
-
# raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")
|
| 113 |
-
|
| 114 |
-
# device = "mps" if torch.backends.mps.is_available() else (
|
| 115 |
-
# "cuda" if torch.cuda.is_available() else "cpu"
|
| 116 |
-
# )
|
| 117 |
-
# print(f"Using device: {device}")
|
| 118 |
-
|
| 119 |
-
# BASE_MODEL = "Salesforce/codet5-base"
|
| 120 |
-
# tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
|
| 121 |
-
|
| 122 |
-
# if tokenizer.pad_token is None:
|
| 123 |
-
# tokenizer.pad_token = tokenizer.eos_token
|
| 124 |
-
|
| 125 |
-
# print(f"\n📦 Loading Model: {args.adapter}")
|
| 126 |
-
|
| 127 |
-
# base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
|
| 128 |
-
|
| 129 |
-
# adapter_for_peft = os.path.relpath(adapter_dir, project_root)
|
| 130 |
-
|
| 131 |
-
# model = PeftModel.from_pretrained(
|
| 132 |
-
# base,
|
| 133 |
-
# adapter_for_peft,
|
| 134 |
-
# local_files_only=True
|
| 135 |
-
# ).to(device)
|
| 136 |
-
|
| 137 |
-
# model = model.merge_and_unload()
|
| 138 |
-
# model.eval()
|
| 139 |
-
|
| 140 |
-
# # -------------------------------
|
| 141 |
-
# # LOAD DATA
|
| 142 |
-
# # -------------------------------
|
| 143 |
-
# with dev_json.open() as f:
|
| 144 |
-
# dev = json.load(f)
|
| 145 |
-
|
| 146 |
-
# if args.shuffle_dev:
|
| 147 |
-
# rng = random.Random(args.shuffle_seed)
|
| 148 |
-
# rng.shuffle(dev)
|
| 149 |
-
|
| 150 |
-
# dev = dev[: args.num_samples]
|
| 151 |
-
# total = len(dev)
|
| 152 |
-
|
| 153 |
-
# gen_kwargs = dict(
|
| 154 |
-
# max_new_tokens=160,
|
| 155 |
-
# num_beams=8,
|
| 156 |
-
# length_penalty=0.8,
|
| 157 |
-
# do_sample=False,
|
| 158 |
-
# early_stopping=True,
|
| 159 |
-
# pad_token_id=tokenizer.pad_token_id,
|
| 160 |
-
# eos_token_id=tokenizer.eos_token_id,
|
| 161 |
-
# )
|
| 162 |
-
|
| 163 |
-
# print(f"\n🚀 Evaluating {total} samples...\n")
|
| 164 |
-
|
| 165 |
-
# em_correct = 0
|
| 166 |
-
# ex_correct = 0
|
| 167 |
-
|
| 168 |
-
# with pred_path.open("w") as out_pred, temp_gold_path.open("w") as out_gold, torch.no_grad():
|
| 169 |
-
# for i, ex in enumerate(dev, start=1):
|
| 170 |
-
|
| 171 |
-
# db_id = ex["db_id"]
|
| 172 |
-
# question = ex["question"]
|
| 173 |
-
# gold_query = ex["query"]
|
| 174 |
-
# db_path = db_root / db_id / f"{db_id}.sqlite"
|
| 175 |
-
|
| 176 |
-
# # -------------------------------
|
| 177 |
-
# # GENERATE SQL
|
| 178 |
-
# # -------------------------------
|
| 179 |
-
# input_ids = encode_prompt(
|
| 180 |
-
# tokenizer,
|
| 181 |
-
# question,
|
| 182 |
-
# db_id,
|
| 183 |
-
# device=device,
|
| 184 |
-
# max_input_tokens=512
|
| 185 |
-
# )
|
| 186 |
-
|
| 187 |
-
# input_ids = input_ids.unsqueeze(0).to(device)
|
| 188 |
-
# attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)
|
| 189 |
-
|
| 190 |
-
# outputs = model.generate(
|
| 191 |
-
# input_ids=input_ids,
|
| 192 |
-
# attention_mask=attention_mask,
|
| 193 |
-
# **gen_kwargs
|
| 194 |
-
# )
|
| 195 |
-
|
| 196 |
-
# pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
|
| 197 |
-
|
| 198 |
-
# # -------------------------------
|
| 199 |
-
# # SAVE FOR SPIDER EVAL
|
| 200 |
-
# # -------------------------------
|
| 201 |
-
# out_pred.write(f"{pred_sql}\n")
|
| 202 |
-
# out_gold.write(f"{gold_query}\t{db_id}\n")
|
| 203 |
-
|
| 204 |
-
# # -------------------------------
|
| 205 |
-
# # LIVE METRICS
|
| 206 |
-
# # -------------------------------
|
| 207 |
-
# if normalize_sql(pred_sql) == normalize_sql(gold_query):
|
| 208 |
-
# em_correct += 1
|
| 209 |
-
|
| 210 |
-
# if check_execution(pred_sql, gold_query, db_path):
|
| 211 |
-
# ex_correct += 1
|
| 212 |
-
|
| 213 |
-
# if i % 20 == 0 or i == total:
|
| 214 |
-
# print(
|
| 215 |
-
# f"Progress: {i}/{total} | "
|
| 216 |
-
# f"EM: {(em_correct/i)*100:.2f}% | "
|
| 217 |
-
# f"EX: {(ex_correct/i)*100:.2f}%"
|
| 218 |
-
# )
|
| 219 |
-
|
| 220 |
-
# print("\n🚀 Running Official Spider Evaluation...\n")
|
| 221 |
-
|
| 222 |
-
# eval_script = project_root / "spider_eval" / "evaluation.py"
|
| 223 |
-
|
| 224 |
-
# # EXACT MATCH
|
| 225 |
-
# cmd_match = [
|
| 226 |
-
# sys.executable, str(eval_script),
|
| 227 |
-
# "--gold", str(temp_gold_path),
|
| 228 |
-
# "--pred", str(pred_path),
|
| 229 |
-
# "--etype", "match",
|
| 230 |
-
# "--db", str(db_root),
|
| 231 |
-
# "--table", str(table_json),
|
| 232 |
-
# ]
|
| 233 |
-
|
| 234 |
-
# proc_match = subprocess.run(cmd_match, capture_output=True, text=True)
|
| 235 |
-
# exact_acc = _parse_spider_accuracy(proc_match.stdout, "match")
|
| 236 |
-
|
| 237 |
-
# # EXECUTION
|
| 238 |
-
# cmd_exec = [
|
| 239 |
-
# sys.executable, str(eval_script),
|
| 240 |
-
# "--gold", str(temp_gold_path),
|
| 241 |
-
# "--pred", str(pred_path),
|
| 242 |
-
# "--etype", "exec",
|
| 243 |
-
# "--db", str(db_root),
|
| 244 |
-
# "--table", str(table_json),
|
| 245 |
-
# ]
|
| 246 |
-
|
| 247 |
-
# proc_exec = subprocess.run(cmd_exec, capture_output=True, text=True)
|
| 248 |
-
# exec_acc = _parse_spider_accuracy(proc_exec.stdout, "exec")
|
| 249 |
-
|
| 250 |
-
# print("==========================================")
|
| 251 |
-
# print(f"🎯 OFFICIAL SPIDER RESULTS FOR: {args.adapter}")
|
| 252 |
-
# print("==========================================")
|
| 253 |
-
|
| 254 |
-
# print(f"Exact Match Accuracy : {exact_acc*100:.2f}%" if exact_acc else "EM parsing failed")
|
| 255 |
-
# print(f"Execution Accuracy : {exec_acc*100:.2f}%" if exec_acc else "EX parsing failed")
|
| 256 |
-
|
| 257 |
-
# print("==========================================\n")
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
# if __name__ == "__main__":
|
| 261 |
-
# main()
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
# *********** for task 2 ****************************************
|
| 267 |
-
import json
|
| 268 |
-
import argparse
|
| 269 |
-
import random
|
| 270 |
-
import sqlite3
|
| 271 |
-
import re
|
| 272 |
-
import os
|
| 273 |
-
from pathlib import Path
|
| 274 |
-
from collections import defaultdict
|
| 275 |
-
|
| 276 |
-
import torch
|
| 277 |
-
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 278 |
-
from peft import PeftModel
|
| 279 |
-
|
| 280 |
-
from prompting import encode_prompt
|
| 281 |
-
|
| 282 |
-
# -------------------------------
|
| 283 |
-
# NORMALIZATION
|
| 284 |
-
# -------------------------------
|
| 285 |
-
def normalize_sql(sql):
|
| 286 |
-
sql = sql.replace('"', "'")
|
| 287 |
-
sql = re.sub(r"\s+", " ", sql)
|
| 288 |
-
return sql.strip().lower().rstrip(";")
|
| 289 |
-
|
| 290 |
-
def normalize_result(res):
|
| 291 |
-
try:
|
| 292 |
-
return sorted([str(r) for r in res])
|
| 293 |
-
except:
|
| 294 |
-
return []
|
| 295 |
-
|
| 296 |
-
# -------------------------------
|
| 297 |
-
# STEP 1: EXECUTION
|
| 298 |
-
# -------------------------------
|
| 299 |
-
def execute_with_error(sql, db_path):
|
| 300 |
-
try:
|
| 301 |
-
conn = sqlite3.connect(db_path)
|
| 302 |
-
cur = conn.cursor()
|
| 303 |
-
cur.execute(sql)
|
| 304 |
-
res = cur.fetchall()
|
| 305 |
-
conn.close()
|
| 306 |
-
return res, None
|
| 307 |
-
except Exception as e:
|
| 308 |
-
return None, str(e)
|
| 309 |
-
|
| 310 |
-
# -------------------------------
|
| 311 |
-
# STEP 2: ERROR CLASSIFICATION
|
| 312 |
-
# -------------------------------
|
| 313 |
-
def classify_error(sql, error_msg):
|
| 314 |
-
if error_msg is None:
|
| 315 |
-
return "correct"
|
| 316 |
-
|
| 317 |
-
err = error_msg.lower()
|
| 318 |
-
sql_l = sql.lower()
|
| 319 |
-
|
| 320 |
-
if "syntax" in err:
|
| 321 |
-
return "syntax_error"
|
| 322 |
-
if "no such table" in err:
|
| 323 |
-
return "wrong_table"
|
| 324 |
-
if "no such column" in err:
|
| 325 |
-
return "wrong_column"
|
| 326 |
-
if "ambiguous" in err:
|
| 327 |
-
return "missing_join"
|
| 328 |
-
if "datatype mismatch" in err:
|
| 329 |
-
return "type_error"
|
| 330 |
-
if "where" not in sql_l and any(x in sql_l for x in ["=", ">", "<"]):
|
| 331 |
-
return "missing_where"
|
| 332 |
-
|
| 333 |
-
return "other"
|
| 334 |
-
|
| 335 |
-
# -------------------------------
|
| 336 |
-
# STEP 4: HINTS
|
| 337 |
-
# -------------------------------
|
| 338 |
-
def generate_hint(error_type):
|
| 339 |
-
hints = {
|
| 340 |
-
"missing_join": "Try using JOIN between related tables.",
|
| 341 |
-
"wrong_column": "Check column names in schema.",
|
| 342 |
-
"missing_where": "Add WHERE condition.",
|
| 343 |
-
"syntax_error": "Fix SQL syntax.",
|
| 344 |
-
"wrong_table": "Verify table names.",
|
| 345 |
-
"type_error": "Check data types.",
|
| 346 |
-
"other": "Review SQL logic."
|
| 347 |
-
}
|
| 348 |
-
return hints.get(error_type, "")
|
| 349 |
-
|
| 350 |
-
# -------------------------------
|
| 351 |
-
# STEP 2 EXTRA: LIGHT ATTRIBUTION
|
| 352 |
-
# -------------------------------
|
| 353 |
-
def extract_keywords(question):
|
| 354 |
-
return [w for w in re.findall(r"\w+", question.lower()) if len(w) > 3]
|
| 355 |
-
|
| 356 |
-
# -------------------------------
|
| 357 |
-
# MAIN
|
| 358 |
-
# -------------------------------
|
| 359 |
-
def main():
|
| 360 |
-
parser = argparse.ArgumentParser()
|
| 361 |
-
parser.add_argument("--adapter", type=str, required=True)
|
| 362 |
-
parser.add_argument("--num_samples", type=int, default=200)
|
| 363 |
-
args = parser.parse_args()
|
| 364 |
-
|
| 365 |
-
project_root = Path(__file__).resolve().parents[1]
|
| 366 |
-
db_root = project_root / "data" / "database"
|
| 367 |
-
dev_json = project_root / "data" / "dev.json"
|
| 368 |
-
|
| 369 |
-
device = "mps" if torch.backends.mps.is_available() else "cpu"
|
| 370 |
-
|
| 371 |
-
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base")
|
| 372 |
-
base = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/codet5-base").to(device)
|
| 373 |
-
|
| 374 |
-
model = PeftModel.from_pretrained(
|
| 375 |
-
base,
|
| 376 |
-
os.path.relpath(project_root / args.adapter, project_root),
|
| 377 |
-
local_files_only=True
|
| 378 |
-
).to(device)
|
| 379 |
-
|
| 380 |
-
model = model.merge_and_unload()
|
| 381 |
-
model.eval()
|
| 382 |
-
|
| 383 |
-
with open(dev_json) as f:
|
| 384 |
-
dev = json.load(f)
|
| 385 |
-
|
| 386 |
-
dev = dev[:args.num_samples]
|
| 387 |
-
|
| 388 |
-
# STORAGE
|
| 389 |
-
error_counter = defaultdict(int)
|
| 390 |
-
error_examples = defaultdict(list)
|
| 391 |
-
success_examples = []
|
| 392 |
-
hint_examples = defaultdict(list)
|
| 393 |
-
operation_counter = defaultdict(int)
|
| 394 |
-
attribution_map = defaultdict(list)
|
| 395 |
-
|
| 396 |
-
em, ex = 0, 0
|
| 397 |
-
|
| 398 |
-
print(f"\n🚀 Evaluating {len(dev)} samples...\n")
|
| 399 |
-
|
| 400 |
-
for i, sample in enumerate(dev, 1):
|
| 401 |
-
|
| 402 |
-
db_id = sample["db_id"]
|
| 403 |
-
q = sample["question"]
|
| 404 |
-
gold = sample["query"]
|
| 405 |
-
db_path = db_root / db_id / f"{db_id}.sqlite"
|
| 406 |
-
|
| 407 |
-
input_ids = encode_prompt(tokenizer, q, db_id, device=device).unsqueeze(0)
|
| 408 |
-
|
| 409 |
-
out = model.generate(input_ids=input_ids, max_new_tokens=120, num_beams=8)
|
| 410 |
-
pred = tokenizer.decode(out[0], skip_special_tokens=True).strip()
|
| 411 |
-
|
| 412 |
-
# operation analysis
|
| 413 |
-
s = pred.lower()
|
| 414 |
-
if "select" in s: operation_counter["SELECT"] += 1
|
| 415 |
-
if "where" in s: operation_counter["WHERE"] += 1
|
| 416 |
-
if "join" in s: operation_counter["JOIN"] += 1
|
| 417 |
-
if "group by" in s: operation_counter["GROUP_BY"] += 1
|
| 418 |
-
if "order by" in s: operation_counter["ORDER_BY"] += 1
|
| 419 |
-
|
| 420 |
-
pred_res, err = execute_with_error(pred, db_path)
|
| 421 |
-
gold_res, _ = execute_with_error(gold, db_path)
|
| 422 |
-
|
| 423 |
-
error_type = classify_error(pred, err)
|
| 424 |
-
error_counter[error_type] += 1
|
| 425 |
-
|
| 426 |
-
# attribution
|
| 427 |
-
if err:
|
| 428 |
-
attribution_map[error_type].append(extract_keywords(q))
|
| 429 |
-
|
| 430 |
-
# examples
|
| 431 |
-
if len(error_examples[error_type]) < 3:
|
| 432 |
-
error_examples[error_type].append(pred)
|
| 433 |
-
|
| 434 |
-
# hints
|
| 435 |
-
if error_type != "correct":
|
| 436 |
-
hint = generate_hint(error_type)
|
| 437 |
-
if len(hint_examples[error_type]) < 3:
|
| 438 |
-
hint_examples[error_type].append((pred, hint))
|
| 439 |
-
|
| 440 |
-
# metrics
|
| 441 |
-
if normalize_sql(pred) == normalize_sql(gold):
|
| 442 |
-
em += 1
|
| 443 |
-
|
| 444 |
-
if pred_res and gold_res and normalize_result(pred_res) == normalize_result(gold_res):
|
| 445 |
-
ex += 1
|
| 446 |
-
if len(success_examples) < 5:
|
| 447 |
-
success_examples.append(pred)
|
| 448 |
-
|
| 449 |
-
if i % 20 == 0:
|
| 450 |
-
print(f"[{i}] EM: {em/i:.2f} | EX: {ex/i:.2f}")
|
| 451 |
-
|
| 452 |
-
# -------------------------------
|
| 453 |
-
# OUTPUT
|
| 454 |
-
# -------------------------------
|
| 455 |
-
print("\n🎯 FINAL RESULTS")
|
| 456 |
-
print(f"EM: {em/len(dev)*100:.2f}%")
|
| 457 |
-
print(f"EX: {ex/len(dev)*100:.2f}%")
|
| 458 |
-
|
| 459 |
-
print("\n🔥 ERROR SUMMARY")
|
| 460 |
-
for k, v in error_counter.items():
|
| 461 |
-
print(k, ":", v)
|
| 462 |
-
|
| 463 |
-
print("\n🔥 ERROR EXAMPLES")
|
| 464 |
-
for k in error_examples:
|
| 465 |
-
print("\n", k)
|
| 466 |
-
for e in error_examples[k]:
|
| 467 |
-
print(" ", e)
|
| 468 |
-
|
| 469 |
-
print("\n🔥 HINTS")
|
| 470 |
-
for k in hint_examples:
|
| 471 |
-
print("\n", k)
|
| 472 |
-
for sql, h in hint_examples[k]:
|
| 473 |
-
print(" ", sql)
|
| 474 |
-
print(" →", h)
|
| 475 |
-
|
| 476 |
-
print("\n🔥 ATTRIBUTION (KEYWORDS)")
|
| 477 |
-
for k in attribution_map:
|
| 478 |
-
print(k, ":", attribution_map[k][:3])
|
| 479 |
-
|
| 480 |
-
print("\n🔥 SQL OPERATIONS")
|
| 481 |
-
for k, v in operation_counter.items():
|
| 482 |
-
print(k, ":", v)
|
| 483 |
-
|
| 484 |
-
# -------------------------------
|
| 485 |
-
# ADVERSARIAL
|
| 486 |
-
# -------------------------------
|
| 487 |
-
print("\n🔥 ADVERSARIAL TESTS")
|
| 488 |
-
|
| 489 |
-
adv = [
|
| 490 |
-
"Find most expensive product",
|
| 491 |
-
"Top 3 students by marks",
|
| 492 |
-
"Average salary per department"
|
| 493 |
-
]
|
| 494 |
-
|
| 495 |
-
for q in adv:
|
| 496 |
-
inp = encode_prompt(tokenizer, q, dev[0]["db_id"], device=device).unsqueeze(0)
|
| 497 |
-
out = model.generate(input_ids=inp, max_new_tokens=120)
|
| 498 |
-
print("\nQ:", q)
|
| 499 |
-
print("SQL:", tokenizer.decode(out[0], skip_special_tokens=True))
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
if __name__ == "__main__":
|
| 503 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/execution_reward copy.py
DELETED
|
@@ -1,831 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
# from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
# import hashlib
|
| 6 |
-
# import os
|
| 7 |
-
# import queue
|
| 8 |
-
# import re
|
| 9 |
-
# import sqlite3
|
| 10 |
-
# import threading
|
| 11 |
-
# import time
|
| 12 |
-
# from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 13 |
-
# from dataclasses import dataclass
|
| 14 |
-
# from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
|
| 15 |
-
|
| 16 |
-
# # --- CACHE CONTROL ---
|
| 17 |
-
# USE_CACHE = True
|
| 18 |
-
# _REWARD_CACHE: Dict[str, float] = {}
|
| 19 |
-
|
| 20 |
-
# def set_use_cache(enabled: bool):
|
| 21 |
-
# """Dynamically toggle the reward cache for benchmarks."""
|
| 22 |
-
# global USE_CACHE
|
| 23 |
-
# USE_CACHE = enabled
|
| 24 |
-
|
| 25 |
-
# def _normalize_sql(sql: str) -> str:
|
| 26 |
-
# if not isinstance(sql, str):
|
| 27 |
-
# return ""
|
| 28 |
-
# s = sql.strip()
|
| 29 |
-
# if s.startswith("```"):
|
| 30 |
-
# s = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", s).strip()
|
| 31 |
-
# s = re.sub(r"\n?```$", "", s).strip()
|
| 32 |
-
# if s.lower().startswith("sql:"):
|
| 33 |
-
# s = s[4:].strip()
|
| 34 |
-
# if ";" in s:
|
| 35 |
-
# s = s.split(";", 1)[0].strip()
|
| 36 |
-
# return s
|
| 37 |
-
|
| 38 |
-
# def _connect_readonly(db_path: str) -> sqlite3.Connection:
|
| 39 |
-
# uri = f"file:{os.path.abspath(db_path)}?mode=ro"
|
| 40 |
-
# conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
|
| 41 |
-
# conn.execute("PRAGMA query_only = ON;")
|
| 42 |
-
# conn.execute("PRAGMA foreign_keys = ON;")
|
| 43 |
-
# return conn
|
| 44 |
-
|
| 45 |
-
# DEFAULT_QUERY_TIMEOUT_S = 2.0
|
| 46 |
-
|
| 47 |
-
# def _with_timeout(conn: sqlite3.Connection, timeout_s: float = DEFAULT_QUERY_TIMEOUT_S) -> None:
|
| 48 |
-
# start = time.monotonic()
|
| 49 |
-
# def _handler() -> int:
|
| 50 |
-
# return 1 if (time.monotonic() - start) > timeout_s else 0
|
| 51 |
-
# conn.set_progress_handler(_handler, 10_000)
|
| 52 |
-
|
| 53 |
-
# def _list_tables(conn: sqlite3.Connection) -> List[str]:
|
| 54 |
-
# try:
|
| 55 |
-
# cur = conn.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';")
|
| 56 |
-
# return [r[0] for r in cur.fetchall() if r and isinstance(r[0], str)]
|
| 57 |
-
# except sqlite3.Error:
|
| 58 |
-
# return []
|
| 59 |
-
|
| 60 |
-
# def _contains_table_name(sql: str, table_names: Sequence[str]) -> bool:
|
| 61 |
-
# s = sql.lower()
|
| 62 |
-
# for t in table_names:
|
| 63 |
-
# tl = t.lower()
|
| 64 |
-
# if not tl:
|
| 65 |
-
# continue
|
| 66 |
-
# if re.search(rf"\b{re.escape(tl)}\b", s):
|
| 67 |
-
# return True
|
| 68 |
-
# return False
|
| 69 |
-
|
| 70 |
-
# def _explain_query_plan(conn: sqlite3.Connection, sql: str) -> bool:
|
| 71 |
-
# try:
|
| 72 |
-
# _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
|
| 73 |
-
# conn.execute(f"EXPLAIN QUERY PLAN {sql}")
|
| 74 |
-
# return True
|
| 75 |
-
# except sqlite3.Error:
|
| 76 |
-
# return False
|
| 77 |
-
|
| 78 |
-
# def _execute(conn: sqlite3.Connection, sql: str, max_rows: int = 1000) -> Tuple[bool, List[Tuple], Optional[str]]:
|
| 79 |
-
# try:
|
| 80 |
-
# _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
|
| 81 |
-
# cur = conn.execute(sql)
|
| 82 |
-
# rows = cur.fetchmany(max_rows)
|
| 83 |
-
# norm_rows = [tuple(r) for r in rows]
|
| 84 |
-
# return True, norm_rows, None
|
| 85 |
-
# except sqlite3.Error as e:
|
| 86 |
-
# return False, [], str(e)
|
| 87 |
-
|
| 88 |
-
# _SQL_KEYWORDS_TO_IGNORE = {
|
| 89 |
-
# "select", "from", "where", "join", "inner", "left", "right", "full", "outer",
|
| 90 |
-
# "on", "group", "by", "order", "limit", "having", "distinct", "union", "intersect",
|
| 91 |
-
# "except", "as", "and", "or", "not", "in", "is", "null", "like", "between", "case",
|
| 92 |
-
# "when", "then", "else", "end", "asc", "desc"
|
| 93 |
-
# }
|
| 94 |
-
|
| 95 |
-
# _SQL_FUNCTIONS_TO_IGNORE = {
|
| 96 |
-
# "count", "avg", "min", "max", "sum", "lower", "upper", "substr", "coalesce",
|
| 97 |
-
# "round", "date", "datetime", "strftime"
|
| 98 |
-
# }
|
| 99 |
-
|
| 100 |
-
# # --- LIGHTWEIGHT PARSING ---
|
| 101 |
-
# def is_valid_select(sql: str):
|
| 102 |
-
# sql = sql.strip().lower()
|
| 103 |
-
# return sql.startswith("select") or sql.startswith("with")
|
| 104 |
-
|
| 105 |
-
# def extract_tables(sql: str) -> List[str]:
|
| 106 |
-
# sql = sql.lower()
|
| 107 |
-
# if "join" not in sql:
|
| 108 |
-
# tables = re.findall(r'from\s+(\w+)', sql)
|
| 109 |
-
# return list(set(tables))
|
| 110 |
-
|
| 111 |
-
# tables = re.findall(r'from\s+([a-zA-Z_][a-zA-Z0-9_]*)', sql)
|
| 112 |
-
# joins = re.findall(r'join\s+([a-zA-Z_][a-zA-Z0-9_]*)', sql)
|
| 113 |
-
# return list(set(tables + joins))
|
| 114 |
-
|
| 115 |
-
# def extract_columns(sql: str) -> List[str]:
|
| 116 |
-
# sql = sql.lower()
|
| 117 |
-
# match = re.search(r'select\s+(.*?)\s+from', sql)
|
| 118 |
-
# if not match:
|
| 119 |
-
# return []
|
| 120 |
-
# cols = match.group(1)
|
| 121 |
-
# if cols.strip() == "*":
|
| 122 |
-
# return ["*"]
|
| 123 |
-
# return [c.strip() for c in cols.split(",")]
|
| 124 |
-
|
| 125 |
-
# def _get_db_tables_and_columns(conn: sqlite3.Connection) -> Tuple[Set[str], Set[str]]:
|
| 126 |
-
# tables = set()
|
| 127 |
-
# columns = set()
|
| 128 |
-
# for t in _list_tables(conn):
|
| 129 |
-
# tl = t.lower()
|
| 130 |
-
# if not tl:
|
| 131 |
-
# continue
|
| 132 |
-
# tables.add(tl)
|
| 133 |
-
# try:
|
| 134 |
-
# cur = conn.execute(f'PRAGMA table_info("{t}")')
|
| 135 |
-
# for row in cur.fetchall():
|
| 136 |
-
# if row and isinstance(row[1], str):
|
| 137 |
-
# columns.add(row[1].lower())
|
| 138 |
-
# except sqlite3.Error:
|
| 139 |
-
# continue
|
| 140 |
-
# return tables, columns
|
| 141 |
-
|
| 142 |
-
# def _safe_results_equal(a: List[Tuple], b: List[Tuple]) -> bool:
|
| 143 |
-
# return a == b
|
| 144 |
-
|
| 145 |
-
# @dataclass
|
| 146 |
-
# class RewardDebugStats:
|
| 147 |
-
# total: int = 0
|
| 148 |
-
# parsed_ok: int = 0
|
| 149 |
-
# table_match: int = 0
|
| 150 |
-
# column_match: int = 0
|
| 151 |
-
# executed_ok: int = 0
|
| 152 |
-
# exact_match: int = 0
|
| 153 |
-
|
| 154 |
-
# _DEBUG = RewardDebugStats()
|
| 155 |
-
|
| 156 |
-
# def reset_debug_metrics() -> None:
|
| 157 |
-
# global _DEBUG
|
| 158 |
-
# _DEBUG = RewardDebugStats()
|
| 159 |
-
|
| 160 |
-
# def get_debug_metrics() -> dict:
|
| 161 |
-
# denom = max(_DEBUG.total, 1)
|
| 162 |
-
# return {
|
| 163 |
-
# "valid_sql_rate": _DEBUG.parsed_ok / denom,
|
| 164 |
-
# "table_match_rate": _DEBUG.table_match / denom,
|
| 165 |
-
# "column_match_rate": _DEBUG.column_match / denom,
|
| 166 |
-
# "execution_accuracy": _DEBUG.exact_match / denom,
|
| 167 |
-
# }
|
| 168 |
-
|
| 169 |
-
# EXECUTION_ERROR = "EXECUTION_ERROR"
|
| 170 |
-
|
| 171 |
-
# _RESULT_CACHE_LOCK = threading.Lock()
|
| 172 |
-
# _RESULT_CACHE: "Dict[str, Union[List[Tuple], str]]" = {}
|
| 173 |
-
# _RESULT_CACHE_MAX = 100_000
|
| 174 |
-
|
| 175 |
-
# def clear_result_cache() -> None:
|
| 176 |
-
# """Clear both DB query cache and reward cache."""
|
| 177 |
-
# with _RESULT_CACHE_LOCK:
|
| 178 |
-
# _RESULT_CACHE.clear()
|
| 179 |
-
# _REWARD_CACHE.clear()
|
| 180 |
-
|
| 181 |
-
# def _db_state_fingerprint(db_path: str) -> str:
|
| 182 |
-
# try:
|
| 183 |
-
# st = os.stat(db_path)
|
| 184 |
-
# return f"{st.st_mtime_ns}:{st.st_size}"
|
| 185 |
-
# except OSError:
|
| 186 |
-
# return "missing"
|
| 187 |
-
|
| 188 |
-
# def _result_cache_key(db_path: str, sql: str) -> str:
|
| 189 |
-
# fp = _db_state_fingerprint(db_path)
|
| 190 |
-
# payload = f"{fp}\0{sql}".encode("utf-8", errors="ignore")
|
| 191 |
-
# return hashlib.sha256(payload).hexdigest()
|
| 192 |
-
|
| 193 |
-
# class _ConnectionPool:
|
| 194 |
-
# def __init__(self, db_path: str, maxsize: int = 1) -> None:
|
| 195 |
-
# self.db_path = db_path
|
| 196 |
-
# self.pool = queue.LifoQueue(maxsize=maxsize)
|
| 197 |
-
# self.lock = threading.Lock()
|
| 198 |
-
|
| 199 |
-
# def acquire(self) -> sqlite3.Connection:
|
| 200 |
-
# try:
|
| 201 |
-
# return self.pool.get_nowait()
|
| 202 |
-
# except queue.Empty:
|
| 203 |
-
# with self.lock:
|
| 204 |
-
# try:
|
| 205 |
-
# return self.pool.get_nowait()
|
| 206 |
-
# except queue.Empty:
|
| 207 |
-
# return _connect_readonly(self.db_path)
|
| 208 |
-
|
| 209 |
-
# def release(self, conn: sqlite3.Connection) -> None:
|
| 210 |
-
# try:
|
| 211 |
-
# self.pool.put_nowait(conn)
|
| 212 |
-
# except queue.Full:
|
| 213 |
-
# try:
|
| 214 |
-
# conn.close()
|
| 215 |
-
# except Exception:
|
| 216 |
-
# pass
|
| 217 |
-
|
| 218 |
-
# _POOL_LOCK = threading.Lock()
|
| 219 |
-
# _POOLS: Dict[str, _ConnectionPool] = {}
|
| 220 |
-
|
| 221 |
-
# def _get_pool(db_path: str) -> _ConnectionPool:
|
| 222 |
-
# with _POOL_LOCK:
|
| 223 |
-
# pool = _POOLS.get(db_path)
|
| 224 |
-
# if pool is None:
|
| 225 |
-
# pool = _ConnectionPool(db_path=db_path, maxsize=1)
|
| 226 |
-
# _POOLS[db_path] = pool
|
| 227 |
-
# return pool
|
| 228 |
-
|
| 229 |
-
# class _PooledConnection:
|
| 230 |
-
# def __init__(self, db_path: str) -> None:
|
| 231 |
-
# self.db_path = db_path
|
| 232 |
-
# self.pool = _get_pool(db_path)
|
| 233 |
-
# self.conn: Optional[sqlite3.Connection] = None
|
| 234 |
-
|
| 235 |
-
# def __enter__(self) -> sqlite3.Connection:
|
| 236 |
-
# self.conn = self.pool.acquire()
|
| 237 |
-
# return self.conn
|
| 238 |
-
|
| 239 |
-
# def __exit__(self, exc_type, exc, tb) -> None:
|
| 240 |
-
# if self.conn is not None:
|
| 241 |
-
# self.pool.release(self.conn)
|
| 242 |
-
# self.conn = None
|
| 243 |
-
|
| 244 |
-
# def _cache_get(key: str) -> Optional[Union[List[Tuple], str]]:
|
| 245 |
-
# with _RESULT_CACHE_LOCK:
|
| 246 |
-
# return _RESULT_CACHE.get(key)
|
| 247 |
-
|
| 248 |
-
# def _cache_put(key: str, value: Union[List[Tuple], str]) -> None:
|
| 249 |
-
# with _RESULT_CACHE_LOCK:
|
| 250 |
-
# if len(_RESULT_CACHE) >= _RESULT_CACHE_MAX:
|
| 251 |
-
# _RESULT_CACHE.clear()
|
| 252 |
-
# _RESULT_CACHE[key] = value
|
| 253 |
-
|
| 254 |
-
# def execute_sql(conn: sqlite3.Connection, sql: str, *, max_rows: int = 1000) -> Union[List[Tuple], str]:
|
| 255 |
-
# try:
|
| 256 |
-
# _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
|
| 257 |
-
# cur = conn.execute(sql)
|
| 258 |
-
# rows = cur.fetchmany(max_rows)
|
| 259 |
-
# return [tuple(r) for r in rows]
|
| 260 |
-
# except Exception:
|
| 261 |
-
# return EXECUTION_ERROR
|
| 262 |
-
|
| 263 |
-
# def execute_sql_cached(db_path: str, sql: str, *, max_rows: int = 1000) -> Union[List[Tuple], str]:
|
| 264 |
-
# if not USE_CACHE:
|
| 265 |
-
# with _PooledConnection(db_path) as conn:
|
| 266 |
-
# return execute_sql(conn, sql, max_rows=max_rows)
|
| 267 |
-
|
| 268 |
-
# key = _result_cache_key(db_path, sql)
|
| 269 |
-
# cached = _cache_get(key)
|
| 270 |
-
# if cached is not None:
|
| 271 |
-
# return cached
|
| 272 |
-
# with _PooledConnection(db_path) as conn:
|
| 273 |
-
# res = execute_sql(conn, sql, max_rows=max_rows)
|
| 274 |
-
# _cache_put(key, res)
|
| 275 |
-
# return res
|
| 276 |
-
|
| 277 |
-
# def execution_reward_timed(
|
| 278 |
-
# pred_sql: str, db_path: str, gold_sql: str, *, measure_plan: bool = False,
|
| 279 |
-
# ) -> Tuple[float, Dict[str, float]]:
|
| 280 |
-
# timings = {"parse_s": 0.0, "plan_s": 0.0, "exec_s": 0.0}
|
| 281 |
-
# t0 = time.perf_counter()
|
| 282 |
-
# sql = _normalize_sql(pred_sql)
|
| 283 |
-
# gold = _normalize_sql(gold_sql)
|
| 284 |
-
|
| 285 |
-
# if not is_valid_select(sql):
|
| 286 |
-
# timings["parse_s"] = time.perf_counter() - t0
|
| 287 |
-
# return 0.0, timings
|
| 288 |
-
|
| 289 |
-
# t1 = time.perf_counter()
|
| 290 |
-
# timings["parse_s"] = t1 - t0
|
| 291 |
-
|
| 292 |
-
# if measure_plan:
|
| 293 |
-
# with _PooledConnection(db_path) as conn:
|
| 294 |
-
# p0 = time.perf_counter()
|
| 295 |
-
# _explain_query_plan(conn, sql)
|
| 296 |
-
# _explain_query_plan(conn, gold)
|
| 297 |
-
# timings["plan_s"] = time.perf_counter() - p0
|
| 298 |
-
|
| 299 |
-
# e0 = time.perf_counter()
|
| 300 |
-
# pred_res = execute_sql_cached(db_path, sql)
|
| 301 |
-
# if pred_res == EXECUTION_ERROR:
|
| 302 |
-
# timings["exec_s"] = time.perf_counter() - e0
|
| 303 |
-
# return 0.0, timings
|
| 304 |
-
# gold_res = execute_sql_cached(db_path, gold)
|
| 305 |
-
# timings["exec_s"] = time.perf_counter() - e0
|
| 306 |
-
# if gold_res == EXECUTION_ERROR:
|
| 307 |
-
# return 0.0, timings
|
| 308 |
-
|
| 309 |
-
# reward = -0.2
|
| 310 |
-
# reward += 0.2
|
| 311 |
-
# if _safe_results_equal(pred_res, gold_res):
|
| 312 |
-
# return 1.0, timings
|
| 313 |
-
# return max(-1.0, min(1.0, reward)), timings
|
| 314 |
-
|
| 315 |
-
# def execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
|
| 316 |
-
# try:
|
| 317 |
-
# sql = _normalize_sql(pred_sql)
|
| 318 |
-
# gold = _normalize_sql(gold_sql)
|
| 319 |
-
|
| 320 |
-
# if not is_valid_select(sql):
|
| 321 |
-
# return -1.0
|
| 322 |
-
|
| 323 |
-
# reward = -0.2
|
| 324 |
-
|
| 325 |
-
# pred_tables = set(extract_tables(sql))
|
| 326 |
-
# gold_tables = set(extract_tables(gold))
|
| 327 |
-
|
| 328 |
-
# if pred_tables == gold_tables and len(gold_tables) > 0:
|
| 329 |
-
# reward += 0.3
|
| 330 |
-
|
| 331 |
-
# pred_cols = set(extract_columns(sql))
|
| 332 |
-
# gold_cols = set(extract_columns(gold))
|
| 333 |
-
|
| 334 |
-
# if gold_cols:
|
| 335 |
-
# overlap = len(pred_cols & gold_cols) / len(gold_cols)
|
| 336 |
-
# reward += 0.3 * overlap
|
| 337 |
-
|
| 338 |
-
# pred_res = execute_sql_cached(db_path, sql)
|
| 339 |
-
# if pred_res == EXECUTION_ERROR:
|
| 340 |
-
# return 0.0
|
| 341 |
-
# reward += 0.2
|
| 342 |
-
|
| 343 |
-
# gold_res = execute_sql_cached(db_path, gold)
|
| 344 |
-
# if gold_res == EXECUTION_ERROR:
|
| 345 |
-
# return 0.0
|
| 346 |
-
# if _safe_results_equal(pred_res, gold_res):
|
| 347 |
-
# return 1.0
|
| 348 |
-
|
| 349 |
-
# return max(-1.0, min(1.0, reward))
|
| 350 |
-
|
| 351 |
-
# except Exception:
|
| 352 |
-
# return 0.0
|
| 353 |
-
|
| 354 |
-
# def cached_execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
|
| 355 |
-
# if not USE_CACHE:
|
| 356 |
-
# return execution_reward(pred_sql, db_path, gold_sql)
|
| 357 |
-
|
| 358 |
-
# key = f"{db_path}|{pred_sql}|{gold_sql}"
|
| 359 |
-
# if key not in _REWARD_CACHE:
|
| 360 |
-
# _REWARD_CACHE[key] = execution_reward(pred_sql, db_path, gold_sql)
|
| 361 |
-
# return _REWARD_CACHE[key]
|
| 362 |
-
|
| 363 |
-
# def execution_reward_batch_sequential(rollouts: Sequence[Tuple[str, str, str]]) -> List[float]:
|
| 364 |
-
# return [cached_execution_reward(pred_sql, db_path, gold_sql) for pred_sql, db_path, gold_sql in rollouts]
|
| 365 |
-
|
| 366 |
-
# def execution_reward_batch_parallel(rollouts: Sequence[Tuple[str, str, str]], *, max_workers: int = 20) -> List[float]:
|
| 367 |
-
# if not rollouts:
|
| 368 |
-
# return []
|
| 369 |
-
|
| 370 |
-
# unique_dbs = {db_path for _, db_path, _ in rollouts}
|
| 371 |
-
# worker_count = max(1, min(max_workers, len(unique_dbs)))
|
| 372 |
-
# results: List[Optional[float]] = [None] * len(rollouts)
|
| 373 |
-
|
| 374 |
-
# with ThreadPoolExecutor(max_workers=worker_count) as executor:
|
| 375 |
-
# futures = {
|
| 376 |
-
# executor.submit(cached_execution_reward, pred_sql, db_path, gold_sql): i
|
| 377 |
-
# for i, (pred_sql, db_path, gold_sql) in enumerate(rollouts)
|
| 378 |
-
# }
|
| 379 |
-
# for fut in as_completed(futures):
|
| 380 |
-
# idx = futures[fut]
|
| 381 |
-
# try:
|
| 382 |
-
# results[idx] = float(fut.result())
|
| 383 |
-
# except Exception:
|
| 384 |
-
# results[idx] = 0.0
|
| 385 |
-
|
| 386 |
-
# return [r if r is not None else 0.0 for r in results]
|
| 387 |
-
|
| 388 |
-
from __future__ import annotations
|
| 389 |
-
|
| 390 |
-
import os
|
| 391 |
-
import re
|
| 392 |
-
import sqlite3
|
| 393 |
-
import threading
|
| 394 |
-
import time
|
| 395 |
-
import json
|
| 396 |
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 397 |
-
from dataclasses import dataclass
|
| 398 |
-
from typing import Dict, List
|
| 399 |
-
|
| 400 |
-
from src.sql_validator import validate_sql_schema
|
| 401 |
-
|
| 402 |
-
# =========================================================
|
| 403 |
-
# 🔥 CONFIG FLAGS
|
| 404 |
-
# =========================================================
|
| 405 |
-
USE_SCHEMA_VALIDATION = True
|
| 406 |
-
USE_CACHE = True
|
| 407 |
-
DEFAULT_QUERY_TIMEOUT_S = 2.0
|
| 408 |
-
|
| 409 |
-
EXECUTION_ERROR = "EXECUTION_ERROR"
|
| 410 |
-
|
| 411 |
-
_REWARD_CACHE: Dict[str, float] = {}
|
| 412 |
-
|
| 413 |
-
# =========================================================
|
| 414 |
-
# 🔥 TASK 2: ERROR ANALYSIS + LOGGING
|
| 415 |
-
# =========================================================
|
| 416 |
-
ERROR_LOG_FILE = "results/error_logs.json"
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
def classify_error(sql: str) -> str:
|
| 420 |
-
sql = sql.lower()
|
| 421 |
-
|
| 422 |
-
if "join" in sql and " on " not in sql:
|
| 423 |
-
return "missing_join"
|
| 424 |
-
|
| 425 |
-
if "where" in sql and "=" not in sql and ">" not in sql and "<" not in sql:
|
| 426 |
-
return "wrong_where"
|
| 427 |
-
|
| 428 |
-
if "null" in sql:
|
| 429 |
-
return "null_handling"
|
| 430 |
-
|
| 431 |
-
if "group by" in sql and "count" not in sql:
|
| 432 |
-
return "wrong_groupby"
|
| 433 |
-
|
| 434 |
-
return "other"
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
def get_hint(error_type: str) -> str:
|
| 438 |
-
hints = {
|
| 439 |
-
"missing_join": "Add proper JOIN condition using ON.",
|
| 440 |
-
"wrong_where": "Check WHERE clause conditions.",
|
| 441 |
-
"null_handling": "Handle NULL values using IS NULL.",
|
| 442 |
-
"wrong_groupby": "Use aggregation functions with GROUP BY.",
|
| 443 |
-
"other": "Check SQL syntax and logic."
|
| 444 |
-
}
|
| 445 |
-
return hints.get(error_type, "Check query.")
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
def log_error(question: str, sql: str, error: str, error_type: str):
|
| 449 |
-
os.makedirs("results", exist_ok=True)
|
| 450 |
-
|
| 451 |
-
entry = {
|
| 452 |
-
"question": question,
|
| 453 |
-
"sql": sql,
|
| 454 |
-
"error": error,
|
| 455 |
-
"error_type": error_type,
|
| 456 |
-
"timestamp": time.time()
|
| 457 |
-
}
|
| 458 |
-
|
| 459 |
-
if os.path.exists(ERROR_LOG_FILE):
|
| 460 |
-
with open(ERROR_LOG_FILE, "r") as f:
|
| 461 |
-
logs = json.load(f)
|
| 462 |
-
else:
|
| 463 |
-
logs = []
|
| 464 |
-
|
| 465 |
-
logs.append(entry)
|
| 466 |
-
|
| 467 |
-
with open(ERROR_LOG_FILE, "w") as f:
|
| 468 |
-
json.dump(logs, f, indent=2)
|
| 469 |
-
|
| 470 |
-
# =========================================================
|
| 471 |
-
# CACHE/VALIDATION TOGGLES (Task 1)
|
| 472 |
-
# =========================================================
|
| 473 |
-
def set_use_cache(enabled: bool) -> None:
|
| 474 |
-
global USE_CACHE
|
| 475 |
-
USE_CACHE = bool(enabled)
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
def set_use_schema_validation(enabled: bool) -> None:
|
| 479 |
-
global USE_SCHEMA_VALIDATION
|
| 480 |
-
USE_SCHEMA_VALIDATION = bool(enabled)
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
# =========================================================
|
| 484 |
-
# SQL CLEANING
|
| 485 |
-
# =========================================================
|
| 486 |
-
def _normalize_sql(sql: str) -> str:
|
| 487 |
-
if not isinstance(sql, str):
|
| 488 |
-
return ""
|
| 489 |
-
s = sql.strip()
|
| 490 |
-
|
| 491 |
-
if s.startswith("```"):
|
| 492 |
-
s = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", s).strip()
|
| 493 |
-
s = re.sub(r"\n?```$", "", s).strip()
|
| 494 |
-
|
| 495 |
-
if s.lower().startswith("sql:"):
|
| 496 |
-
s = s[4:].strip()
|
| 497 |
-
|
| 498 |
-
if ";" in s:
|
| 499 |
-
s = s.split(";", 1)[0].strip()
|
| 500 |
-
|
| 501 |
-
return s
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
# =========================================================
|
| 505 |
-
# DB EXECUTION
|
| 506 |
-
# =========================================================
|
| 507 |
-
def _connect_readonly(db_path: str):
|
| 508 |
-
uri = f"file:{os.path.abspath(db_path)}?mode=ro"
|
| 509 |
-
conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
|
| 510 |
-
conn.execute("PRAGMA query_only = ON;")
|
| 511 |
-
conn.execute("PRAGMA foreign_keys = ON;")
|
| 512 |
-
return conn
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
def _with_timeout(conn: sqlite3.Connection, timeout_s: float = DEFAULT_QUERY_TIMEOUT_S):
|
| 516 |
-
start = time.monotonic()
|
| 517 |
-
|
| 518 |
-
def handler():
|
| 519 |
-
return 1 if (time.monotonic() - start) > timeout_s else 0
|
| 520 |
-
|
| 521 |
-
conn.set_progress_handler(handler, 10_000)
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
def execute_sql(conn, sql):
|
| 525 |
-
try:
|
| 526 |
-
_with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
|
| 527 |
-
cur = conn.execute(sql)
|
| 528 |
-
return cur.fetchall()
|
| 529 |
-
except Exception:
|
| 530 |
-
return EXECUTION_ERROR
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
_RESULT_CACHE = {}
|
| 534 |
-
_RESULT_LOCK = threading.Lock()
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
def execute_sql_cached(db_path, sql):
|
| 538 |
-
key = f"{db_path}|{sql}"
|
| 539 |
-
|
| 540 |
-
if USE_CACHE:
|
| 541 |
-
with _RESULT_LOCK:
|
| 542 |
-
if key in _RESULT_CACHE:
|
| 543 |
-
return _RESULT_CACHE[key]
|
| 544 |
-
|
| 545 |
-
conn = _connect_readonly(db_path)
|
| 546 |
-
result = execute_sql(conn, sql)
|
| 547 |
-
conn.close()
|
| 548 |
-
|
| 549 |
-
if USE_CACHE:
|
| 550 |
-
with _RESULT_LOCK:
|
| 551 |
-
_RESULT_CACHE[key] = result
|
| 552 |
-
|
| 553 |
-
return result
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
def execute_sql_cached_conn(conn: sqlite3.Connection, db_path: str, sql: str):
|
| 557 |
-
"""
|
| 558 |
-
Like execute_sql_cached(), but reuses an existing connection.
|
| 559 |
-
Intended for 1-thread-per-DB workloads (Task 1).
|
| 560 |
-
"""
|
| 561 |
-
key = f"{db_path}|{sql}"
|
| 562 |
-
if USE_CACHE:
|
| 563 |
-
with _RESULT_LOCK:
|
| 564 |
-
if key in _RESULT_CACHE:
|
| 565 |
-
return _RESULT_CACHE[key]
|
| 566 |
-
|
| 567 |
-
result = execute_sql(conn, sql)
|
| 568 |
-
|
| 569 |
-
if USE_CACHE:
|
| 570 |
-
with _RESULT_LOCK:
|
| 571 |
-
_RESULT_CACHE[key] = result
|
| 572 |
-
|
| 573 |
-
return result
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
def clear_result_cache() -> None:
|
| 577 |
-
global _RESULT_CACHE, _REWARD_CACHE
|
| 578 |
-
with _RESULT_LOCK:
|
| 579 |
-
_RESULT_CACHE.clear()
|
| 580 |
-
_REWARD_CACHE.clear()
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
# =========================================================
|
| 584 |
-
# SQL PARSING
|
| 585 |
-
# =========================================================
|
| 586 |
-
def is_valid_select(sql):
|
| 587 |
-
return sql.lower().startswith("select") or sql.lower().startswith("with")
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
def extract_tables(sql):
|
| 591 |
-
return re.findall(r'from\s+(\w+)', sql.lower())
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
def extract_columns(sql):
|
| 595 |
-
match = re.search(r'select\s+(.*?)\s+from', sql.lower())
|
| 596 |
-
if not match:
|
| 597 |
-
return []
|
| 598 |
-
cols = match.group(1)
|
| 599 |
-
return ["*"] if cols.strip() == "*" else [c.strip() for c in cols.split(",")]
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
def get_sql_operations(sql: str):
|
| 603 |
-
sql = sql.lower()
|
| 604 |
-
ops = []
|
| 605 |
-
|
| 606 |
-
if "select" in sql: ops.append("SELECT")
|
| 607 |
-
if "where" in sql: ops.append("WHERE")
|
| 608 |
-
if "join" in sql: ops.append("JOIN")
|
| 609 |
-
if "group by" in sql: ops.append("GROUP_BY")
|
| 610 |
-
if "order by" in sql: ops.append("ORDER_BY")
|
| 611 |
-
|
| 612 |
-
return ops
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
def _explain_query_plan(conn: sqlite3.Connection, sql: str) -> bool:
|
| 616 |
-
try:
|
| 617 |
-
_with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
|
| 618 |
-
conn.execute(f"EXPLAIN QUERY PLAN {sql}")
|
| 619 |
-
return True
|
| 620 |
-
except Exception:
|
| 621 |
-
return False
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
def execution_reward_timed(pred_sql: str, db_path: str, gold_sql: str, measure_plan: bool = False):
|
| 625 |
-
"""
|
| 626 |
-
Returns (reward, timings) where timings keys: parse_s, plan_s, exec_s.
|
| 627 |
-
Used by Task-1 benchmark to profile bottlenecks.
|
| 628 |
-
"""
|
| 629 |
-
timings = {"parse_s": 0.0, "plan_s": 0.0, "exec_s": 0.0}
|
| 630 |
-
t0 = time.perf_counter()
|
| 631 |
-
|
| 632 |
-
sql = _normalize_sql(pred_sql)
|
| 633 |
-
gold = _normalize_sql(gold_sql)
|
| 634 |
-
|
| 635 |
-
if not is_valid_select(sql):
|
| 636 |
-
timings["parse_s"] = time.perf_counter() - t0
|
| 637 |
-
return 0.0, timings
|
| 638 |
-
|
| 639 |
-
t1 = time.perf_counter()
|
| 640 |
-
timings["parse_s"] = t1 - t0
|
| 641 |
-
|
| 642 |
-
conn = _connect_readonly(db_path)
|
| 643 |
-
try:
|
| 644 |
-
if measure_plan:
|
| 645 |
-
p0 = time.perf_counter()
|
| 646 |
-
_explain_query_plan(conn, sql)
|
| 647 |
-
_explain_query_plan(conn, gold)
|
| 648 |
-
timings["plan_s"] = time.perf_counter() - p0
|
| 649 |
-
|
| 650 |
-
e0 = time.perf_counter()
|
| 651 |
-
pred_res = execute_sql_cached_conn(conn, db_path, sql)
|
| 652 |
-
if pred_res == EXECUTION_ERROR:
|
| 653 |
-
timings["exec_s"] = time.perf_counter() - e0
|
| 654 |
-
return 0.0, timings
|
| 655 |
-
gold_res = execute_sql_cached_conn(conn, db_path, gold)
|
| 656 |
-
timings["exec_s"] = time.perf_counter() - e0
|
| 657 |
-
if gold_res == EXECUTION_ERROR:
|
| 658 |
-
return 0.0, timings
|
| 659 |
-
|
| 660 |
-
reward = -0.2 + 0.2
|
| 661 |
-
if pred_res == gold_res:
|
| 662 |
-
return 1.0, timings
|
| 663 |
-
return max(-1.0, min(1.0, reward)), timings
|
| 664 |
-
finally:
|
| 665 |
-
try:
|
| 666 |
-
conn.close()
|
| 667 |
-
except Exception:
|
| 668 |
-
pass
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
# =========================================================
|
| 672 |
-
# 🔥 FINAL REWARD FUNCTION (TASK 2 INTEGRATED)
|
| 673 |
-
# =========================================================
|
| 674 |
-
def execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
|
| 675 |
-
try:
|
| 676 |
-
sql = _normalize_sql(pred_sql)
|
| 677 |
-
gold = _normalize_sql(gold_sql)
|
| 678 |
-
|
| 679 |
-
if not is_valid_select(sql):
|
| 680 |
-
return -1.0
|
| 681 |
-
|
| 682 |
-
reward = -0.2
|
| 683 |
-
|
| 684 |
-
# =========================
|
| 685 |
-
# SCHEMA VALIDATION (Task 3)
|
| 686 |
-
# =========================
|
| 687 |
-
if USE_SCHEMA_VALIDATION:
|
| 688 |
-
valid, _ = validate_sql_schema(sql, db_path)
|
| 689 |
-
if not valid:
|
| 690 |
-
error_type = classify_error(sql)
|
| 691 |
-
log_error("UNKNOWN", sql, "schema_invalid", error_type)
|
| 692 |
-
return 0.1
|
| 693 |
-
|
| 694 |
-
# =========================
|
| 695 |
-
# EXECUTION
|
| 696 |
-
# =========================
|
| 697 |
-
pred_res = execute_sql_cached(db_path, sql)
|
| 698 |
-
|
| 699 |
-
if pred_res == "EXECUTION_ERROR":
|
| 700 |
-
error_type = classify_error(sql)
|
| 701 |
-
|
| 702 |
-
log_error(
|
| 703 |
-
question="UNKNOWN",
|
| 704 |
-
sql=sql,
|
| 705 |
-
error="execution_error",
|
| 706 |
-
error_type=error_type
|
| 707 |
-
)
|
| 708 |
-
|
| 709 |
-
print(f"[ERROR] {error_type}")
|
| 710 |
-
print(f"[HINT] {get_hint(error_type)}")
|
| 711 |
-
|
| 712 |
-
return 0.1
|
| 713 |
-
|
| 714 |
-
reward += 0.2
|
| 715 |
-
|
| 716 |
-
gold_res = execute_sql_cached(db_path, gold)
|
| 717 |
-
|
| 718 |
-
if gold_res == "EXECUTION_ERROR":
|
| 719 |
-
return 0.1
|
| 720 |
-
|
| 721 |
-
if pred_res == gold_res:
|
| 722 |
-
return 1.0
|
| 723 |
-
|
| 724 |
-
return max(-1.0, min(1.0, reward))
|
| 725 |
-
|
| 726 |
-
except Exception as e:
|
| 727 |
-
log_error("UNKNOWN", pred_sql, str(e), "runtime_error")
|
| 728 |
-
return 0.0
|
| 729 |
-
|
| 730 |
-
|
| 731 |
-
# =========================================================
|
| 732 |
-
# BATCH EXECUTION (Task 1)
|
| 733 |
-
# =========================================================
|
| 734 |
-
def cached_execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
|
| 735 |
-
if not USE_CACHE:
|
| 736 |
-
return float(execution_reward(pred_sql, db_path, gold_sql))
|
| 737 |
-
key = f"{db_path}|{pred_sql}|{gold_sql}"
|
| 738 |
-
if key in _REWARD_CACHE:
|
| 739 |
-
return float(_REWARD_CACHE[key])
|
| 740 |
-
r = float(execution_reward(pred_sql, db_path, gold_sql))
|
| 741 |
-
_REWARD_CACHE[key] = r
|
| 742 |
-
return r
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
def execution_reward_batch_sequential(rollouts):
|
| 746 |
-
return [cached_execution_reward(p, d, g) for (p, d, g) in rollouts]
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
def execution_reward_batch_parallel(rollouts, max_workers=10):
|
| 750 |
-
results = [0.0] * len(rollouts)
|
| 751 |
-
|
| 752 |
-
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
| 753 |
-
futures = {
|
| 754 |
-
executor.submit(cached_execution_reward, p, d, g): i
|
| 755 |
-
for i, (p, d, g) in enumerate(rollouts)
|
| 756 |
-
}
|
| 757 |
-
|
| 758 |
-
for fut in as_completed(futures):
|
| 759 |
-
idx = futures[fut]
|
| 760 |
-
try:
|
| 761 |
-
results[idx] = fut.result()
|
| 762 |
-
except Exception:
|
| 763 |
-
results[idx] = 0.0
|
| 764 |
-
|
| 765 |
-
return results
|
| 766 |
-
|
| 767 |
-
|
| 768 |
-
def execution_reward_batch_parallel_by_db(rollouts, max_workers: int = 20):
|
| 769 |
-
"""
|
| 770 |
-
1 thread per DB path. Reuses a single readonly connection per DB worker.
|
| 771 |
-
Preserves input order.
|
| 772 |
-
"""
|
| 773 |
-
if not rollouts:
|
| 774 |
-
return []
|
| 775 |
-
|
| 776 |
-
by_db = {}
|
| 777 |
-
for idx, (pred_sql, db_path, gold_sql) in enumerate(rollouts):
|
| 778 |
-
by_db.setdefault(db_path, []).append((idx, pred_sql, gold_sql))
|
| 779 |
-
|
| 780 |
-
results = [0.0 for _ in range(len(rollouts))]
|
| 781 |
-
|
| 782 |
-
def _reward_with_conn(conn: sqlite3.Connection, pred_sql: str, db_path: str, gold_sql: str) -> float:
|
| 783 |
-
try:
|
| 784 |
-
sql = _normalize_sql(pred_sql)
|
| 785 |
-
gold = _normalize_sql(gold_sql)
|
| 786 |
-
|
| 787 |
-
if not is_valid_select(sql):
|
| 788 |
-
return -1.0
|
| 789 |
-
|
| 790 |
-
reward = -0.2
|
| 791 |
-
|
| 792 |
-
if USE_SCHEMA_VALIDATION:
|
| 793 |
-
valid, _ = validate_sql_schema(sql, db_path)
|
| 794 |
-
if not valid:
|
| 795 |
-
error_type = classify_error(sql)
|
| 796 |
-
log_error("UNKNOWN", sql, "schema_invalid", error_type)
|
| 797 |
-
return 0.1
|
| 798 |
-
|
| 799 |
-
pred_res = execute_sql_cached_conn(conn, db_path, sql)
|
| 800 |
-
if pred_res == EXECUTION_ERROR:
|
| 801 |
-
error_type = classify_error(sql)
|
| 802 |
-
log_error("UNKNOWN", sql, "execution_error", error_type)
|
| 803 |
-
return 0.1
|
| 804 |
-
|
| 805 |
-
reward += 0.2
|
| 806 |
-
gold_res = execute_sql_cached_conn(conn, db_path, gold)
|
| 807 |
-
if gold_res == EXECUTION_ERROR:
|
| 808 |
-
return 0.1
|
| 809 |
-
if pred_res == gold_res:
|
| 810 |
-
return 1.0
|
| 811 |
-
return max(-1.0, min(1.0, reward))
|
| 812 |
-
except Exception:
|
| 813 |
-
return 0.0
|
| 814 |
-
|
| 815 |
-
def _worker(db_path: str, items):
|
| 816 |
-
conn = _connect_readonly(db_path)
|
| 817 |
-
try:
|
| 818 |
-
for idx, pred, gold in items:
|
| 819 |
-
results[idx] = _reward_with_conn(conn, pred, db_path, gold)
|
| 820 |
-
finally:
|
| 821 |
-
try:
|
| 822 |
-
conn.close()
|
| 823 |
-
except Exception:
|
| 824 |
-
pass
|
| 825 |
-
|
| 826 |
-
with ThreadPoolExecutor(max_workers=int(max_workers)) as ex:
|
| 827 |
-
futures = [ex.submit(_worker, db_path, items) for db_path, items in by_db.items()]
|
| 828 |
-
for fut in as_completed(futures):
|
| 829 |
-
fut.result()
|
| 830 |
-
|
| 831 |
-
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/execution_reward.py
CHANGED
|
@@ -1,510 +1,41 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
# from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
# import hashlib
|
| 6 |
-
# import os
|
| 7 |
-
# import queue
|
| 8 |
-
# import re
|
| 9 |
-
# import sqlite3
|
| 10 |
-
# import threading
|
| 11 |
-
# import time
|
| 12 |
-
# from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 13 |
-
# from dataclasses import dataclass
|
| 14 |
-
# from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
|
| 15 |
-
|
| 16 |
-
# # --- CACHE CONTROL ---
|
| 17 |
-
# USE_CACHE = True
|
| 18 |
-
# _REWARD_CACHE: Dict[str, float] = {}
|
| 19 |
-
|
| 20 |
-
# def set_use_cache(enabled: bool):
|
| 21 |
-
# """Dynamically toggle the reward cache for benchmarks."""
|
| 22 |
-
# global USE_CACHE
|
| 23 |
-
# USE_CACHE = enabled
|
| 24 |
-
|
| 25 |
-
# def _normalize_sql(sql: str) -> str:
|
| 26 |
-
# if not isinstance(sql, str):
|
| 27 |
-
# return ""
|
| 28 |
-
# s = sql.strip()
|
| 29 |
-
# if s.startswith("```"):
|
| 30 |
-
# s = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", s).strip()
|
| 31 |
-
# s = re.sub(r"\n?```$", "", s).strip()
|
| 32 |
-
# if s.lower().startswith("sql:"):
|
| 33 |
-
# s = s[4:].strip()
|
| 34 |
-
# if ";" in s:
|
| 35 |
-
# s = s.split(";", 1)[0].strip()
|
| 36 |
-
# return s
|
| 37 |
-
|
| 38 |
-
# def _connect_readonly(db_path: str) -> sqlite3.Connection:
|
| 39 |
-
# uri = f"file:{os.path.abspath(db_path)}?mode=ro"
|
| 40 |
-
# conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
|
| 41 |
-
# conn.execute("PRAGMA query_only = ON;")
|
| 42 |
-
# conn.execute("PRAGMA foreign_keys = ON;")
|
| 43 |
-
# return conn
|
| 44 |
-
|
| 45 |
-
# DEFAULT_QUERY_TIMEOUT_S = 2.0
|
| 46 |
-
|
| 47 |
-
# def _with_timeout(conn: sqlite3.Connection, timeout_s: float = DEFAULT_QUERY_TIMEOUT_S) -> None:
|
| 48 |
-
# start = time.monotonic()
|
| 49 |
-
# def _handler() -> int:
|
| 50 |
-
# return 1 if (time.monotonic() - start) > timeout_s else 0
|
| 51 |
-
# conn.set_progress_handler(_handler, 10_000)
|
| 52 |
-
|
| 53 |
-
# def _list_tables(conn: sqlite3.Connection) -> List[str]:
|
| 54 |
-
# try:
|
| 55 |
-
# cur = conn.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';")
|
| 56 |
-
# return [r[0] for r in cur.fetchall() if r and isinstance(r[0], str)]
|
| 57 |
-
# except sqlite3.Error:
|
| 58 |
-
# return []
|
| 59 |
-
|
| 60 |
-
# def _contains_table_name(sql: str, table_names: Sequence[str]) -> bool:
|
| 61 |
-
# s = sql.lower()
|
| 62 |
-
# for t in table_names:
|
| 63 |
-
# tl = t.lower()
|
| 64 |
-
# if not tl:
|
| 65 |
-
# continue
|
| 66 |
-
# if re.search(rf"\b{re.escape(tl)}\b", s):
|
| 67 |
-
# return True
|
| 68 |
-
# return False
|
| 69 |
-
|
| 70 |
-
# def _explain_query_plan(conn: sqlite3.Connection, sql: str) -> bool:
|
| 71 |
-
# try:
|
| 72 |
-
# _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
|
| 73 |
-
# conn.execute(f"EXPLAIN QUERY PLAN {sql}")
|
| 74 |
-
# return True
|
| 75 |
-
# except sqlite3.Error:
|
| 76 |
-
# return False
|
| 77 |
-
|
| 78 |
-
# def _execute(conn: sqlite3.Connection, sql: str, max_rows: int = 1000) -> Tuple[bool, List[Tuple], Optional[str]]:
|
| 79 |
-
# try:
|
| 80 |
-
# _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
|
| 81 |
-
# cur = conn.execute(sql)
|
| 82 |
-
# rows = cur.fetchmany(max_rows)
|
| 83 |
-
# norm_rows = [tuple(r) for r in rows]
|
| 84 |
-
# return True, norm_rows, None
|
| 85 |
-
# except sqlite3.Error as e:
|
| 86 |
-
# return False, [], str(e)
|
| 87 |
-
|
| 88 |
-
# _SQL_KEYWORDS_TO_IGNORE = {
|
| 89 |
-
# "select", "from", "where", "join", "inner", "left", "right", "full", "outer",
|
| 90 |
-
# "on", "group", "by", "order", "limit", "having", "distinct", "union", "intersect",
|
| 91 |
-
# "except", "as", "and", "or", "not", "in", "is", "null", "like", "between", "case",
|
| 92 |
-
# "when", "then", "else", "end", "asc", "desc"
|
| 93 |
-
# }
|
| 94 |
-
|
| 95 |
-
# _SQL_FUNCTIONS_TO_IGNORE = {
|
| 96 |
-
# "count", "avg", "min", "max", "sum", "lower", "upper", "substr", "coalesce",
|
| 97 |
-
# "round", "date", "datetime", "strftime"
|
| 98 |
-
# }
|
| 99 |
-
|
| 100 |
-
# # --- LIGHTWEIGHT PARSING ---
|
| 101 |
-
# def is_valid_select(sql: str):
|
| 102 |
-
# sql = sql.strip().lower()
|
| 103 |
-
# return sql.startswith("select") or sql.startswith("with")
|
| 104 |
-
|
| 105 |
-
# def extract_tables(sql: str) -> List[str]:
|
| 106 |
-
# sql = sql.lower()
|
| 107 |
-
# if "join" not in sql:
|
| 108 |
-
# tables = re.findall(r'from\s+(\w+)', sql)
|
| 109 |
-
# return list(set(tables))
|
| 110 |
-
|
| 111 |
-
# tables = re.findall(r'from\s+([a-zA-Z_][a-zA-Z0-9_]*)', sql)
|
| 112 |
-
# joins = re.findall(r'join\s+([a-zA-Z_][a-zA-Z0-9_]*)', sql)
|
| 113 |
-
# return list(set(tables + joins))
|
| 114 |
-
|
| 115 |
-
# def extract_columns(sql: str) -> List[str]:
|
| 116 |
-
# sql = sql.lower()
|
| 117 |
-
# match = re.search(r'select\s+(.*?)\s+from', sql)
|
| 118 |
-
# if not match:
|
| 119 |
-
# return []
|
| 120 |
-
# cols = match.group(1)
|
| 121 |
-
# if cols.strip() == "*":
|
| 122 |
-
# return ["*"]
|
| 123 |
-
# return [c.strip() for c in cols.split(",")]
|
| 124 |
-
|
| 125 |
-
# def _get_db_tables_and_columns(conn: sqlite3.Connection) -> Tuple[Set[str], Set[str]]:
|
| 126 |
-
# tables = set()
|
| 127 |
-
# columns = set()
|
| 128 |
-
# for t in _list_tables(conn):
|
| 129 |
-
# tl = t.lower()
|
| 130 |
-
# if not tl:
|
| 131 |
-
# continue
|
| 132 |
-
# tables.add(tl)
|
| 133 |
-
# try:
|
| 134 |
-
# cur = conn.execute(f'PRAGMA table_info("{t}")')
|
| 135 |
-
# for row in cur.fetchall():
|
| 136 |
-
# if row and isinstance(row[1], str):
|
| 137 |
-
# columns.add(row[1].lower())
|
| 138 |
-
# except sqlite3.Error:
|
| 139 |
-
# continue
|
| 140 |
-
# return tables, columns
|
| 141 |
-
|
| 142 |
-
# def _safe_results_equal(a: List[Tuple], b: List[Tuple]) -> bool:
|
| 143 |
-
# return a == b
|
| 144 |
-
|
| 145 |
-
# @dataclass
|
| 146 |
-
# class RewardDebugStats:
|
| 147 |
-
# total: int = 0
|
| 148 |
-
# parsed_ok: int = 0
|
| 149 |
-
# table_match: int = 0
|
| 150 |
-
# column_match: int = 0
|
| 151 |
-
# executed_ok: int = 0
|
| 152 |
-
# exact_match: int = 0
|
| 153 |
-
|
| 154 |
-
# _DEBUG = RewardDebugStats()
|
| 155 |
-
|
| 156 |
-
# def reset_debug_metrics() -> None:
|
| 157 |
-
# global _DEBUG
|
| 158 |
-
# _DEBUG = RewardDebugStats()
|
| 159 |
-
|
| 160 |
-
# def get_debug_metrics() -> dict:
|
| 161 |
-
# denom = max(_DEBUG.total, 1)
|
| 162 |
-
# return {
|
| 163 |
-
# "valid_sql_rate": _DEBUG.parsed_ok / denom,
|
| 164 |
-
# "table_match_rate": _DEBUG.table_match / denom,
|
| 165 |
-
# "column_match_rate": _DEBUG.column_match / denom,
|
| 166 |
-
# "execution_accuracy": _DEBUG.exact_match / denom,
|
| 167 |
-
# }
|
| 168 |
-
|
| 169 |
-
# EXECUTION_ERROR = "EXECUTION_ERROR"
|
| 170 |
-
|
| 171 |
-
# _RESULT_CACHE_LOCK = threading.Lock()
|
| 172 |
-
# _RESULT_CACHE: "Dict[str, Union[List[Tuple], str]]" = {}
|
| 173 |
-
# _RESULT_CACHE_MAX = 100_000
|
| 174 |
-
|
| 175 |
-
# def clear_result_cache() -> None:
|
| 176 |
-
# """Clear both DB query cache and reward cache."""
|
| 177 |
-
# with _RESULT_CACHE_LOCK:
|
| 178 |
-
# _RESULT_CACHE.clear()
|
| 179 |
-
# _REWARD_CACHE.clear()
|
| 180 |
-
|
| 181 |
-
# def _db_state_fingerprint(db_path: str) -> str:
|
| 182 |
-
# try:
|
| 183 |
-
# st = os.stat(db_path)
|
| 184 |
-
# return f"{st.st_mtime_ns}:{st.st_size}"
|
| 185 |
-
# except OSError:
|
| 186 |
-
# return "missing"
|
| 187 |
-
|
| 188 |
-
# def _result_cache_key(db_path: str, sql: str) -> str:
|
| 189 |
-
# fp = _db_state_fingerprint(db_path)
|
| 190 |
-
# payload = f"{fp}\0{sql}".encode("utf-8", errors="ignore")
|
| 191 |
-
# return hashlib.sha256(payload).hexdigest()
|
| 192 |
-
|
| 193 |
-
# class _ConnectionPool:
|
| 194 |
-
# def __init__(self, db_path: str, maxsize: int = 1) -> None:
|
| 195 |
-
# self.db_path = db_path
|
| 196 |
-
# self.pool = queue.LifoQueue(maxsize=maxsize)
|
| 197 |
-
# self.lock = threading.Lock()
|
| 198 |
-
|
| 199 |
-
# def acquire(self) -> sqlite3.Connection:
|
| 200 |
-
# try:
|
| 201 |
-
# return self.pool.get_nowait()
|
| 202 |
-
# except queue.Empty:
|
| 203 |
-
# with self.lock:
|
| 204 |
-
# try:
|
| 205 |
-
# return self.pool.get_nowait()
|
| 206 |
-
# except queue.Empty:
|
| 207 |
-
# return _connect_readonly(self.db_path)
|
| 208 |
-
|
| 209 |
-
# def release(self, conn: sqlite3.Connection) -> None:
|
| 210 |
-
# try:
|
| 211 |
-
# self.pool.put_nowait(conn)
|
| 212 |
-
# except queue.Full:
|
| 213 |
-
# try:
|
| 214 |
-
# conn.close()
|
| 215 |
-
# except Exception:
|
| 216 |
-
# pass
|
| 217 |
-
|
| 218 |
-
# _POOL_LOCK = threading.Lock()
|
| 219 |
-
# _POOLS: Dict[str, _ConnectionPool] = {}
|
| 220 |
-
|
| 221 |
-
# def _get_pool(db_path: str) -> _ConnectionPool:
|
| 222 |
-
# with _POOL_LOCK:
|
| 223 |
-
# pool = _POOLS.get(db_path)
|
| 224 |
-
# if pool is None:
|
| 225 |
-
# pool = _ConnectionPool(db_path=db_path, maxsize=1)
|
| 226 |
-
# _POOLS[db_path] = pool
|
| 227 |
-
# return pool
|
| 228 |
-
|
| 229 |
-
# class _PooledConnection:
|
| 230 |
-
# def __init__(self, db_path: str) -> None:
|
| 231 |
-
# self.db_path = db_path
|
| 232 |
-
# self.pool = _get_pool(db_path)
|
| 233 |
-
# self.conn: Optional[sqlite3.Connection] = None
|
| 234 |
-
|
| 235 |
-
# def __enter__(self) -> sqlite3.Connection:
|
| 236 |
-
# self.conn = self.pool.acquire()
|
| 237 |
-
# return self.conn
|
| 238 |
-
|
| 239 |
-
# def __exit__(self, exc_type, exc, tb) -> None:
|
| 240 |
-
# if self.conn is not None:
|
| 241 |
-
# self.pool.release(self.conn)
|
| 242 |
-
# self.conn = None
|
| 243 |
-
|
| 244 |
-
# def _cache_get(key: str) -> Optional[Union[List[Tuple], str]]:
|
| 245 |
-
# with _RESULT_CACHE_LOCK:
|
| 246 |
-
# return _RESULT_CACHE.get(key)
|
| 247 |
-
|
| 248 |
-
# def _cache_put(key: str, value: Union[List[Tuple], str]) -> None:
|
| 249 |
-
# with _RESULT_CACHE_LOCK:
|
| 250 |
-
# if len(_RESULT_CACHE) >= _RESULT_CACHE_MAX:
|
| 251 |
-
# _RESULT_CACHE.clear()
|
| 252 |
-
# _RESULT_CACHE[key] = value
|
| 253 |
-
|
| 254 |
-
# def execute_sql(conn: sqlite3.Connection, sql: str, *, max_rows: int = 1000) -> Union[List[Tuple], str]:
|
| 255 |
-
# try:
|
| 256 |
-
# _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
|
| 257 |
-
# cur = conn.execute(sql)
|
| 258 |
-
# rows = cur.fetchmany(max_rows)
|
| 259 |
-
# return [tuple(r) for r in rows]
|
| 260 |
-
# except Exception:
|
| 261 |
-
# return EXECUTION_ERROR
|
| 262 |
-
|
| 263 |
-
# def execute_sql_cached(db_path: str, sql: str, *, max_rows: int = 1000) -> Union[List[Tuple], str]:
|
| 264 |
-
# if not USE_CACHE:
|
| 265 |
-
# with _PooledConnection(db_path) as conn:
|
| 266 |
-
# return execute_sql(conn, sql, max_rows=max_rows)
|
| 267 |
-
|
| 268 |
-
# key = _result_cache_key(db_path, sql)
|
| 269 |
-
# cached = _cache_get(key)
|
| 270 |
-
# if cached is not None:
|
| 271 |
-
# return cached
|
| 272 |
-
# with _PooledConnection(db_path) as conn:
|
| 273 |
-
# res = execute_sql(conn, sql, max_rows=max_rows)
|
| 274 |
-
# _cache_put(key, res)
|
| 275 |
-
# return res
|
| 276 |
-
|
| 277 |
-
# def execution_reward_timed(
|
| 278 |
-
# pred_sql: str, db_path: str, gold_sql: str, *, measure_plan: bool = False,
|
| 279 |
-
# ) -> Tuple[float, Dict[str, float]]:
|
| 280 |
-
# timings = {"parse_s": 0.0, "plan_s": 0.0, "exec_s": 0.0}
|
| 281 |
-
# t0 = time.perf_counter()
|
| 282 |
-
# sql = _normalize_sql(pred_sql)
|
| 283 |
-
# gold = _normalize_sql(gold_sql)
|
| 284 |
-
|
| 285 |
-
# if not is_valid_select(sql):
|
| 286 |
-
# timings["parse_s"] = time.perf_counter() - t0
|
| 287 |
-
# return 0.0, timings
|
| 288 |
-
|
| 289 |
-
# t1 = time.perf_counter()
|
| 290 |
-
# timings["parse_s"] = t1 - t0
|
| 291 |
-
|
| 292 |
-
# if measure_plan:
|
| 293 |
-
# with _PooledConnection(db_path) as conn:
|
| 294 |
-
# p0 = time.perf_counter()
|
| 295 |
-
# _explain_query_plan(conn, sql)
|
| 296 |
-
# _explain_query_plan(conn, gold)
|
| 297 |
-
# timings["plan_s"] = time.perf_counter() - p0
|
| 298 |
-
|
| 299 |
-
# e0 = time.perf_counter()
|
| 300 |
-
# pred_res = execute_sql_cached(db_path, sql)
|
| 301 |
-
# if pred_res == EXECUTION_ERROR:
|
| 302 |
-
# timings["exec_s"] = time.perf_counter() - e0
|
| 303 |
-
# return 0.0, timings
|
| 304 |
-
# gold_res = execute_sql_cached(db_path, gold)
|
| 305 |
-
# timings["exec_s"] = time.perf_counter() - e0
|
| 306 |
-
# if gold_res == EXECUTION_ERROR:
|
| 307 |
-
# return 0.0, timings
|
| 308 |
-
|
| 309 |
-
# reward = -0.2
|
| 310 |
-
# reward += 0.2
|
| 311 |
-
# if _safe_results_equal(pred_res, gold_res):
|
| 312 |
-
# return 1.0, timings
|
| 313 |
-
# return max(-1.0, min(1.0, reward)), timings
|
| 314 |
-
|
| 315 |
-
# def execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
|
| 316 |
-
# try:
|
| 317 |
-
# sql = _normalize_sql(pred_sql)
|
| 318 |
-
# gold = _normalize_sql(gold_sql)
|
| 319 |
-
|
| 320 |
-
# if not is_valid_select(sql):
|
| 321 |
-
# return -1.0
|
| 322 |
-
|
| 323 |
-
# reward = -0.2
|
| 324 |
-
|
| 325 |
-
# pred_tables = set(extract_tables(sql))
|
| 326 |
-
# gold_tables = set(extract_tables(gold))
|
| 327 |
-
|
| 328 |
-
# if pred_tables == gold_tables and len(gold_tables) > 0:
|
| 329 |
-
# reward += 0.3
|
| 330 |
-
|
| 331 |
-
# pred_cols = set(extract_columns(sql))
|
| 332 |
-
# gold_cols = set(extract_columns(gold))
|
| 333 |
-
|
| 334 |
-
# if gold_cols:
|
| 335 |
-
# overlap = len(pred_cols & gold_cols) / len(gold_cols)
|
| 336 |
-
# reward += 0.3 * overlap
|
| 337 |
-
|
| 338 |
-
# pred_res = execute_sql_cached(db_path, sql)
|
| 339 |
-
# if pred_res == EXECUTION_ERROR:
|
| 340 |
-
# return 0.0
|
| 341 |
-
# reward += 0.2
|
| 342 |
-
|
| 343 |
-
# gold_res = execute_sql_cached(db_path, gold)
|
| 344 |
-
# if gold_res == EXECUTION_ERROR:
|
| 345 |
-
# return 0.0
|
| 346 |
-
# if _safe_results_equal(pred_res, gold_res):
|
| 347 |
-
# return 1.0
|
| 348 |
-
|
| 349 |
-
# return max(-1.0, min(1.0, reward))
|
| 350 |
-
|
| 351 |
-
# except Exception:
|
| 352 |
-
# return 0.0
|
| 353 |
-
|
| 354 |
-
# def cached_execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
|
| 355 |
-
# if not USE_CACHE:
|
| 356 |
-
# return execution_reward(pred_sql, db_path, gold_sql)
|
| 357 |
-
|
| 358 |
-
# key = f"{db_path}|{pred_sql}|{gold_sql}"
|
| 359 |
-
# if key not in _REWARD_CACHE:
|
| 360 |
-
# _REWARD_CACHE[key] = execution_reward(pred_sql, db_path, gold_sql)
|
| 361 |
-
# return _REWARD_CACHE[key]
|
| 362 |
-
|
| 363 |
-
# def execution_reward_batch_sequential(rollouts: Sequence[Tuple[str, str, str]]) -> List[float]:
|
| 364 |
-
# return [cached_execution_reward(pred_sql, db_path, gold_sql) for pred_sql, db_path, gold_sql in rollouts]
|
| 365 |
-
|
| 366 |
-
# def execution_reward_batch_parallel(rollouts: Sequence[Tuple[str, str, str]], *, max_workers: int = 20) -> List[float]:
|
| 367 |
-
# if not rollouts:
|
| 368 |
-
# return []
|
| 369 |
-
|
| 370 |
-
# unique_dbs = {db_path for _, db_path, _ in rollouts}
|
| 371 |
-
# worker_count = max(1, min(max_workers, len(unique_dbs)))
|
| 372 |
-
# results: List[Optional[float]] = [None] * len(rollouts)
|
| 373 |
-
|
| 374 |
-
# with ThreadPoolExecutor(max_workers=worker_count) as executor:
|
| 375 |
-
# futures = {
|
| 376 |
-
# executor.submit(cached_execution_reward, pred_sql, db_path, gold_sql): i
|
| 377 |
-
# for i, (pred_sql, db_path, gold_sql) in enumerate(rollouts)
|
| 378 |
-
# }
|
| 379 |
-
# for fut in as_completed(futures):
|
| 380 |
-
# idx = futures[fut]
|
| 381 |
-
# try:
|
| 382 |
-
# results[idx] = float(fut.result())
|
| 383 |
-
# except Exception:
|
| 384 |
-
# results[idx] = 0.0
|
| 385 |
-
|
| 386 |
-
# return [r if r is not None else 0.0 for r in results]
|
| 387 |
-
|
| 388 |
from __future__ import annotations
|
| 389 |
|
| 390 |
import os
|
| 391 |
import re
|
| 392 |
import sqlite3
|
| 393 |
-
import threading
|
| 394 |
import time
|
| 395 |
-
import json
|
| 396 |
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 397 |
from dataclasses import dataclass
|
| 398 |
-
from typing import
|
| 399 |
-
|
| 400 |
-
from src.sql_validator import validate_sql_schema
|
| 401 |
-
|
| 402 |
-
# =========================================================
|
| 403 |
-
# 🔥 CONFIG FLAGS
|
| 404 |
-
# =========================================================
|
| 405 |
-
USE_SCHEMA_VALIDATION = True
|
| 406 |
-
USE_CACHE = True
|
| 407 |
-
DEFAULT_QUERY_TIMEOUT_S = 2.0
|
| 408 |
-
|
| 409 |
-
EXECUTION_ERROR = "EXECUTION_ERROR"
|
| 410 |
-
|
| 411 |
-
_REWARD_CACHE: Dict[str, float] = {}
|
| 412 |
-
|
| 413 |
-
# =========================================================
|
| 414 |
-
# 🔥 TASK 2: ERROR ANALYSIS + LOGGING
|
| 415 |
-
# =========================================================
|
| 416 |
-
ERROR_LOG_FILE = "results/error_logs.json"
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
def classify_error(sql: str) -> str:
|
| 420 |
-
sql = sql.lower()
|
| 421 |
-
|
| 422 |
-
if "join" in sql and " on " not in sql:
|
| 423 |
-
return "missing_join"
|
| 424 |
-
|
| 425 |
-
if "where" in sql and "=" not in sql and ">" not in sql and "<" not in sql:
|
| 426 |
-
return "wrong_where"
|
| 427 |
-
|
| 428 |
-
if "null" in sql:
|
| 429 |
-
return "null_handling"
|
| 430 |
-
|
| 431 |
-
if "group by" in sql and "count" not in sql:
|
| 432 |
-
return "wrong_groupby"
|
| 433 |
|
| 434 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 435 |
|
| 436 |
|
| 437 |
-
def get_hint(error_type: str) -> str:
|
| 438 |
-
hints = {
|
| 439 |
-
"missing_join": "Add proper JOIN condition using ON.",
|
| 440 |
-
"wrong_where": "Check WHERE clause conditions.",
|
| 441 |
-
"null_handling": "Handle NULL values using IS NULL.",
|
| 442 |
-
"wrong_groupby": "Use aggregation functions with GROUP BY.",
|
| 443 |
-
"other": "Check SQL syntax and logic."
|
| 444 |
-
}
|
| 445 |
-
return hints.get(error_type, "Check query.")
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
def log_error(question: str, sql: str, error: str, error_type: str):
|
| 449 |
-
os.makedirs("results", exist_ok=True)
|
| 450 |
-
|
| 451 |
-
entry = {
|
| 452 |
-
"question": question,
|
| 453 |
-
"sql": sql,
|
| 454 |
-
"error": error,
|
| 455 |
-
"error_type": error_type,
|
| 456 |
-
"timestamp": time.time()
|
| 457 |
-
}
|
| 458 |
-
|
| 459 |
-
if os.path.exists(ERROR_LOG_FILE):
|
| 460 |
-
with open(ERROR_LOG_FILE, "r") as f:
|
| 461 |
-
logs = json.load(f)
|
| 462 |
-
else:
|
| 463 |
-
logs = []
|
| 464 |
-
|
| 465 |
-
logs.append(entry)
|
| 466 |
-
|
| 467 |
-
with open(ERROR_LOG_FILE, "w") as f:
|
| 468 |
-
json.dump(logs, f, indent=2)
|
| 469 |
-
|
| 470 |
-
# =========================================================
|
| 471 |
-
# CACHE/VALIDATION TOGGLES (Task 1)
|
| 472 |
-
# =========================================================
|
| 473 |
-
def set_use_cache(enabled: bool) -> None:
|
| 474 |
-
global USE_CACHE
|
| 475 |
-
USE_CACHE = bool(enabled)
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
def set_use_schema_validation(enabled: bool) -> None:
|
| 479 |
-
global USE_SCHEMA_VALIDATION
|
| 480 |
-
USE_SCHEMA_VALIDATION = bool(enabled)
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
# =========================================================
|
| 484 |
-
# SQL CLEANING
|
| 485 |
-
# =========================================================
|
| 486 |
def _normalize_sql(sql: str) -> str:
|
| 487 |
if not isinstance(sql, str):
|
| 488 |
return ""
|
| 489 |
s = sql.strip()
|
| 490 |
-
|
| 491 |
if s.startswith("```"):
|
|
|
|
| 492 |
s = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", s).strip()
|
| 493 |
s = re.sub(r"\n?```$", "", s).strip()
|
| 494 |
-
|
| 495 |
if s.lower().startswith("sql:"):
|
| 496 |
s = s[4:].strip()
|
| 497 |
-
|
| 498 |
if ";" in s:
|
| 499 |
s = s.split(";", 1)[0].strip()
|
| 500 |
-
|
| 501 |
return s
|
| 502 |
|
| 503 |
|
| 504 |
-
|
| 505 |
-
#
|
| 506 |
-
#
|
| 507 |
-
def _connect_readonly(db_path: str):
|
| 508 |
uri = f"file:{os.path.abspath(db_path)}?mode=ro"
|
| 509 |
conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
|
| 510 |
conn.execute("PRAGMA query_only = ON;")
|
|
@@ -512,320 +43,367 @@ def _connect_readonly(db_path: str):
|
|
| 512 |
return conn
|
| 513 |
|
| 514 |
|
| 515 |
-
def _with_timeout(conn: sqlite3.Connection, timeout_s: float =
|
| 516 |
start = time.monotonic()
|
| 517 |
|
| 518 |
-
def
|
| 519 |
return 1 if (time.monotonic() - start) > timeout_s else 0
|
| 520 |
|
| 521 |
-
|
|
|
|
| 522 |
|
| 523 |
|
| 524 |
-
def
|
| 525 |
try:
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
_RESULT_CACHE = {}
|
| 534 |
-
_RESULT_LOCK = threading.Lock()
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
def execute_sql_cached(db_path, sql):
|
| 538 |
-
key = f"{db_path}|{sql}"
|
| 539 |
|
| 540 |
-
if USE_CACHE:
|
| 541 |
-
with _RESULT_LOCK:
|
| 542 |
-
if key in _RESULT_CACHE:
|
| 543 |
-
return _RESULT_CACHE[key]
|
| 544 |
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 548 |
|
| 549 |
-
if USE_CACHE:
|
| 550 |
-
with _RESULT_LOCK:
|
| 551 |
-
_RESULT_CACHE[key] = result
|
| 552 |
|
| 553 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 554 |
|
| 555 |
|
| 556 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 557 |
"""
|
| 558 |
-
|
| 559 |
-
|
| 560 |
"""
|
| 561 |
-
|
| 562 |
-
if
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 566 |
|
| 567 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 568 |
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 572 |
|
| 573 |
-
return result
|
| 574 |
|
|
|
|
|
|
|
|
|
|
| 575 |
|
| 576 |
-
def clear_result_cache() -> None:
|
| 577 |
-
global _RESULT_CACHE, _REWARD_CACHE
|
| 578 |
-
with _RESULT_LOCK:
|
| 579 |
-
_RESULT_CACHE.clear()
|
| 580 |
-
_REWARD_CACHE.clear()
|
| 581 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 582 |
|
| 583 |
-
# =========================================================
|
| 584 |
-
# SQL PARSING
|
| 585 |
-
# =========================================================
|
| 586 |
-
def is_valid_select(sql):
|
| 587 |
-
return sql.lower().startswith("select") or sql.lower().startswith("with")
|
| 588 |
|
|
|
|
| 589 |
|
| 590 |
-
def extract_tables(sql):
|
| 591 |
-
return re.findall(r'from\s+(\w+)', sql.lower())
|
| 592 |
|
|
|
|
|
|
|
|
|
|
| 593 |
|
| 594 |
-
def extract_columns(sql):
|
| 595 |
-
match = re.search(r'select\s+(.*?)\s+from', sql.lower())
|
| 596 |
-
if not match:
|
| 597 |
-
return []
|
| 598 |
-
cols = match.group(1)
|
| 599 |
-
return ["*"] if cols.strip() == "*" else [c.strip() for c in cols.split(",")]
|
| 600 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 601 |
|
| 602 |
-
|
| 603 |
-
sql = sql.lower()
|
| 604 |
-
ops = []
|
| 605 |
-
|
| 606 |
-
if "select" in sql: ops.append("SELECT")
|
| 607 |
-
if "where" in sql: ops.append("WHERE")
|
| 608 |
-
if "join" in sql: ops.append("JOIN")
|
| 609 |
-
if "group by" in sql: ops.append("GROUP_BY")
|
| 610 |
-
if "order by" in sql: ops.append("ORDER_BY")
|
| 611 |
|
| 612 |
-
return ops
|
| 613 |
|
|
|
|
|
|
|
|
|
|
| 614 |
|
| 615 |
-
|
|
|
|
| 616 |
try:
|
| 617 |
-
_with_timeout(conn, timeout_s=
|
| 618 |
-
conn.execute(
|
| 619 |
-
|
|
|
|
| 620 |
except Exception:
|
| 621 |
-
return
|
| 622 |
|
| 623 |
|
| 624 |
-
def
|
| 625 |
"""
|
| 626 |
-
|
| 627 |
-
|
|
|
|
| 628 |
"""
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
sql = _normalize_sql(pred_sql)
|
| 633 |
-
gold = _normalize_sql(gold_sql)
|
| 634 |
-
|
| 635 |
-
if not is_valid_select(sql):
|
| 636 |
-
timings["parse_s"] = time.perf_counter() - t0
|
| 637 |
-
return 0.0, timings
|
| 638 |
-
|
| 639 |
-
t1 = time.perf_counter()
|
| 640 |
-
timings["parse_s"] = t1 - t0
|
| 641 |
-
|
| 642 |
-
conn = _connect_readonly(db_path)
|
| 643 |
try:
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
timings["exec_s"] = time.perf_counter() - e0
|
| 657 |
-
if gold_res == EXECUTION_ERROR:
|
| 658 |
-
return 0.0, timings
|
| 659 |
-
|
| 660 |
-
reward = -0.2 + 0.2
|
| 661 |
-
if pred_res == gold_res:
|
| 662 |
-
return 1.0, timings
|
| 663 |
-
return max(-1.0, min(1.0, reward)), timings
|
| 664 |
-
finally:
|
| 665 |
-
try:
|
| 666 |
-
conn.close()
|
| 667 |
-
except Exception:
|
| 668 |
-
pass
|
| 669 |
-
|
| 670 |
|
| 671 |
-
# =========================================================
|
| 672 |
-
# 🔥 FINAL REWARD FUNCTION (TASK 2 INTEGRATED)
|
| 673 |
-
# =========================================================
|
| 674 |
def execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
|
| 675 |
try:
|
| 676 |
sql = _normalize_sql(pred_sql)
|
| 677 |
gold = _normalize_sql(gold_sql)
|
| 678 |
|
| 679 |
-
if not
|
| 680 |
return -1.0
|
| 681 |
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
# =========================
|
| 685 |
-
# SCHEMA VALIDATION (Task 3)
|
| 686 |
-
# =========================
|
| 687 |
-
if USE_SCHEMA_VALIDATION:
|
| 688 |
-
valid, _ = validate_sql_schema(sql, db_path)
|
| 689 |
-
if not valid:
|
| 690 |
-
error_type = classify_error(sql)
|
| 691 |
-
log_error("UNKNOWN", sql, "schema_invalid", error_type)
|
| 692 |
-
return 0.1
|
| 693 |
-
|
| 694 |
-
# =========================
|
| 695 |
-
# EXECUTION
|
| 696 |
-
# =========================
|
| 697 |
-
pred_res = execute_sql_cached(db_path, sql)
|
| 698 |
-
|
| 699 |
-
if pred_res == "EXECUTION_ERROR":
|
| 700 |
-
error_type = classify_error(sql)
|
| 701 |
-
|
| 702 |
-
log_error(
|
| 703 |
-
question="UNKNOWN",
|
| 704 |
-
sql=sql,
|
| 705 |
-
error="execution_error",
|
| 706 |
-
error_type=error_type
|
| 707 |
-
)
|
| 708 |
-
|
| 709 |
-
print(f"[ERROR] {error_type}")
|
| 710 |
-
print(f"[HINT] {get_hint(error_type)}")
|
| 711 |
-
|
| 712 |
-
return 0.1
|
| 713 |
-
|
| 714 |
-
reward += 0.2
|
| 715 |
-
|
| 716 |
-
gold_res = execute_sql_cached(db_path, gold)
|
| 717 |
-
|
| 718 |
-
if gold_res == "EXECUTION_ERROR":
|
| 719 |
-
return 0.1
|
| 720 |
-
|
| 721 |
-
if pred_res == gold_res:
|
| 722 |
-
return 1.0
|
| 723 |
-
|
| 724 |
-
return max(-1.0, min(1.0, reward))
|
| 725 |
-
|
| 726 |
-
except Exception as e:
|
| 727 |
-
log_error("UNKNOWN", pred_sql, str(e), "runtime_error")
|
| 728 |
-
return 0.0
|
| 729 |
-
|
| 730 |
-
|
| 731 |
-
# =========================================================
|
| 732 |
-
# BATCH EXECUTION (Task 1)
|
| 733 |
-
# =========================================================
|
| 734 |
-
def cached_execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
|
| 735 |
-
if not USE_CACHE:
|
| 736 |
-
return float(execution_reward(pred_sql, db_path, gold_sql))
|
| 737 |
-
key = f"{db_path}|{pred_sql}|{gold_sql}"
|
| 738 |
-
if key in _REWARD_CACHE:
|
| 739 |
-
return float(_REWARD_CACHE[key])
|
| 740 |
-
r = float(execution_reward(pred_sql, db_path, gold_sql))
|
| 741 |
-
_REWARD_CACHE[key] = r
|
| 742 |
-
return r
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
def execution_reward_batch_sequential(rollouts):
|
| 746 |
-
return [cached_execution_reward(p, d, g) for (p, d, g) in rollouts]
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
def execution_reward_batch_parallel(rollouts, max_workers=10):
|
| 750 |
-
results = [0.0] * len(rollouts)
|
| 751 |
-
|
| 752 |
-
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
| 753 |
-
futures = {
|
| 754 |
-
executor.submit(cached_execution_reward, p, d, g): i
|
| 755 |
-
for i, (p, d, g) in enumerate(rollouts)
|
| 756 |
-
}
|
| 757 |
|
| 758 |
-
|
| 759 |
-
idx = futures[fut]
|
| 760 |
-
try:
|
| 761 |
-
results[idx] = fut.result()
|
| 762 |
-
except Exception:
|
| 763 |
-
results[idx] = 0.0
|
| 764 |
|
| 765 |
-
|
|
|
|
| 766 |
|
|
|
|
|
|
|
| 767 |
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
1 thread per DB path. Reuses a single readonly connection per DB worker.
|
| 771 |
-
Preserves input order.
|
| 772 |
-
"""
|
| 773 |
-
if not rollouts:
|
| 774 |
-
return []
|
| 775 |
|
| 776 |
-
|
| 777 |
-
|
| 778 |
-
|
| 779 |
|
| 780 |
-
|
|
|
|
|
|
|
|
|
|
| 781 |
|
| 782 |
-
|
| 783 |
-
|
| 784 |
-
sql = _normalize_sql(pred_sql)
|
| 785 |
-
gold = _normalize_sql(gold_sql)
|
| 786 |
-
|
| 787 |
-
if not is_valid_select(sql):
|
| 788 |
-
return -1.0
|
| 789 |
-
|
| 790 |
-
reward = -0.2
|
| 791 |
-
|
| 792 |
-
if USE_SCHEMA_VALIDATION:
|
| 793 |
-
valid, _ = validate_sql_schema(sql, db_path)
|
| 794 |
-
if not valid:
|
| 795 |
-
error_type = classify_error(sql)
|
| 796 |
-
log_error("UNKNOWN", sql, "schema_invalid", error_type)
|
| 797 |
-
return 0.1
|
| 798 |
-
|
| 799 |
-
pred_res = execute_sql_cached_conn(conn, db_path, sql)
|
| 800 |
-
if pred_res == EXECUTION_ERROR:
|
| 801 |
-
error_type = classify_error(sql)
|
| 802 |
-
log_error("UNKNOWN", sql, "execution_error", error_type)
|
| 803 |
-
return 0.1
|
| 804 |
-
|
| 805 |
-
reward += 0.2
|
| 806 |
-
gold_res = execute_sql_cached_conn(conn, db_path, gold)
|
| 807 |
-
if gold_res == EXECUTION_ERROR:
|
| 808 |
-
return 0.1
|
| 809 |
-
if pred_res == gold_res:
|
| 810 |
return 1.0
|
| 811 |
-
return max(-1.0, min(1.0, reward))
|
| 812 |
-
except Exception:
|
| 813 |
-
return 0.0
|
| 814 |
|
| 815 |
-
|
| 816 |
-
conn = _connect_readonly(db_path)
|
| 817 |
-
try:
|
| 818 |
-
for idx, pred, gold in items:
|
| 819 |
-
results[idx] = _reward_with_conn(conn, pred, db_path, gold)
|
| 820 |
-
finally:
|
| 821 |
-
try:
|
| 822 |
-
conn.close()
|
| 823 |
-
except Exception:
|
| 824 |
-
pass
|
| 825 |
-
|
| 826 |
-
with ThreadPoolExecutor(max_workers=int(max_workers)) as ex:
|
| 827 |
-
futures = [ex.submit(_worker, db_path, items) for db_path, items in by_db.items()]
|
| 828 |
-
for fut in as_completed(futures):
|
| 829 |
-
fut.result()
|
| 830 |
|
| 831 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import os
|
| 4 |
import re
|
| 5 |
import sqlite3
|
|
|
|
| 6 |
import time
|
|
|
|
|
|
|
| 7 |
from dataclasses import dataclass
|
| 8 |
+
from typing import List, Optional, Sequence, Set, Tuple, Union
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
+
try:
|
| 11 |
+
import sqlparse
|
| 12 |
+
from sqlparse.sql import Function, Identifier, IdentifierList, Statement, Token, Where
|
| 13 |
+
from sqlparse.tokens import DML, Keyword, Name, Number, Punctuation, String, Whitespace
|
| 14 |
+
except Exception: # pragma: no cover
|
| 15 |
+
sqlparse = None # type: ignore[assignment]
|
| 16 |
+
Statement = object # type: ignore[misc,assignment]
|
| 17 |
+
Token = object # type: ignore[misc,assignment]
|
| 18 |
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
def _normalize_sql(sql: str) -> str:
|
| 21 |
if not isinstance(sql, str):
|
| 22 |
return ""
|
| 23 |
s = sql.strip()
|
|
|
|
| 24 |
if s.startswith("```"):
|
| 25 |
+
# Strip markdown fences if present.
|
| 26 |
s = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", s).strip()
|
| 27 |
s = re.sub(r"\n?```$", "", s).strip()
|
|
|
|
| 28 |
if s.lower().startswith("sql:"):
|
| 29 |
s = s[4:].strip()
|
| 30 |
+
# Keep only the first statement to avoid accidental multi-statement execution.
|
| 31 |
if ";" in s:
|
| 32 |
s = s.split(";", 1)[0].strip()
|
|
|
|
| 33 |
return s
|
| 34 |
|
| 35 |
|
| 36 |
+
def _connect_readonly(db_path: str) -> sqlite3.Connection:
|
| 37 |
+
# Read-only prevents any accidental mutation during reward computation.
|
| 38 |
+
# Note: requires SQLite URI support (built-in).
|
|
|
|
| 39 |
uri = f"file:{os.path.abspath(db_path)}?mode=ro"
|
| 40 |
conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
|
| 41 |
conn.execute("PRAGMA query_only = ON;")
|
|
|
|
| 43 |
return conn
|
| 44 |
|
| 45 |
|
| 46 |
+
def _with_timeout(conn: sqlite3.Connection, timeout_s: float = 1.0) -> None:
|
| 47 |
start = time.monotonic()
|
| 48 |
|
| 49 |
+
def _handler() -> int:
|
| 50 |
return 1 if (time.monotonic() - start) > timeout_s else 0
|
| 51 |
|
| 52 |
+
# Call handler every N VM opcodes.
|
| 53 |
+
conn.set_progress_handler(_handler, 10_000)
|
| 54 |
|
| 55 |
|
| 56 |
+
def _list_tables(conn: sqlite3.Connection) -> List[str]:
|
| 57 |
try:
|
| 58 |
+
cur = conn.execute(
|
| 59 |
+
"SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
|
| 60 |
+
)
|
| 61 |
+
return [r[0] for r in cur.fetchall() if r and isinstance(r[0], str)]
|
| 62 |
+
except sqlite3.Error:
|
| 63 |
+
return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
+
def _contains_table_name(sql: str, table_names: Sequence[str]) -> bool:
|
| 67 |
+
s = sql.lower()
|
| 68 |
+
for t in table_names:
|
| 69 |
+
tl = t.lower()
|
| 70 |
+
if not tl:
|
| 71 |
+
continue
|
| 72 |
+
if re.search(rf"\b{re.escape(tl)}\b", s):
|
| 73 |
+
return True
|
| 74 |
+
return False
|
| 75 |
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
+
def _explain_query_plan(conn: sqlite3.Connection, sql: str) -> bool:
|
| 78 |
+
try:
|
| 79 |
+
_with_timeout(conn, timeout_s=1.0)
|
| 80 |
+
conn.execute(f"EXPLAIN QUERY PLAN {sql}")
|
| 81 |
+
return True
|
| 82 |
+
except sqlite3.Error:
|
| 83 |
+
return False
|
| 84 |
|
| 85 |
|
| 86 |
+
def _execute(conn: sqlite3.Connection, sql: str, max_rows: int = 1000) -> Tuple[bool, List[Tuple], Optional[str]]:
|
| 87 |
+
try:
|
| 88 |
+
_with_timeout(conn, timeout_s=1.0)
|
| 89 |
+
cur = conn.execute(sql)
|
| 90 |
+
rows = cur.fetchmany(max_rows)
|
| 91 |
+
# Normalize to plain tuples for deterministic comparison.
|
| 92 |
+
norm_rows = [tuple(r) for r in rows]
|
| 93 |
+
return True, norm_rows, None
|
| 94 |
+
except sqlite3.Error as e:
|
| 95 |
+
return False, [], str(e)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
_SQL_KEYWORDS_TO_IGNORE = {
|
| 99 |
+
"select",
|
| 100 |
+
"from",
|
| 101 |
+
"where",
|
| 102 |
+
"join",
|
| 103 |
+
"inner",
|
| 104 |
+
"left",
|
| 105 |
+
"right",
|
| 106 |
+
"full",
|
| 107 |
+
"outer",
|
| 108 |
+
"on",
|
| 109 |
+
"group",
|
| 110 |
+
"by",
|
| 111 |
+
"order",
|
| 112 |
+
"limit",
|
| 113 |
+
"having",
|
| 114 |
+
"distinct",
|
| 115 |
+
"union",
|
| 116 |
+
"intersect",
|
| 117 |
+
"except",
|
| 118 |
+
"as",
|
| 119 |
+
"and",
|
| 120 |
+
"or",
|
| 121 |
+
"not",
|
| 122 |
+
"in",
|
| 123 |
+
"is",
|
| 124 |
+
"null",
|
| 125 |
+
"like",
|
| 126 |
+
"between",
|
| 127 |
+
"case",
|
| 128 |
+
"when",
|
| 129 |
+
"then",
|
| 130 |
+
"else",
|
| 131 |
+
"end",
|
| 132 |
+
"asc",
|
| 133 |
+
"desc",
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
_SQL_FUNCTIONS_TO_IGNORE = {
|
| 137 |
+
"count",
|
| 138 |
+
"avg",
|
| 139 |
+
"min",
|
| 140 |
+
"max",
|
| 141 |
+
"sum",
|
| 142 |
+
"lower",
|
| 143 |
+
"upper",
|
| 144 |
+
"substr",
|
| 145 |
+
"coalesce",
|
| 146 |
+
"round",
|
| 147 |
+
"date",
|
| 148 |
+
"datetime",
|
| 149 |
+
"strftime",
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def extract_tables(sql: str) -> Set[str]:
|
| 154 |
"""
|
| 155 |
+
Best-effort table extraction from SQL using sqlparse.
|
| 156 |
+
Returns lowercase table names (unqualified).
|
| 157 |
"""
|
| 158 |
+
sql = _normalize_sql(sql)
|
| 159 |
+
if not sql:
|
| 160 |
+
return set()
|
| 161 |
+
if sqlparse is None:
|
| 162 |
+
# Fallback: naive regex for FROM/JOIN.
|
| 163 |
+
found = set()
|
| 164 |
+
for m in re.finditer(r"\b(from|join)\s+([a-zA-Z_][\w$]*)", sql, flags=re.I):
|
| 165 |
+
found.add(m.group(2).lower())
|
| 166 |
+
return found
|
| 167 |
|
| 168 |
+
try:
|
| 169 |
+
statements = sqlparse.parse(sql)
|
| 170 |
+
except Exception:
|
| 171 |
+
return set()
|
| 172 |
+
|
| 173 |
+
tables: Set[str] = set()
|
| 174 |
+
|
| 175 |
+
def _add_identifier_as_table(ident: Identifier) -> None:
|
| 176 |
+
# Prefer real name over alias; strip any schema prefix.
|
| 177 |
+
name = ident.get_real_name() or ident.get_name()
|
| 178 |
+
if not name:
|
| 179 |
+
return
|
| 180 |
+
tables.add(name.lower())
|
| 181 |
+
|
| 182 |
+
for st in statements:
|
| 183 |
+
if not isinstance(st, Statement):
|
| 184 |
+
continue
|
| 185 |
+
seen_from = False
|
| 186 |
+
for tok in st.flatten():
|
| 187 |
+
if tok.ttype in Whitespace:
|
| 188 |
+
continue
|
| 189 |
+
if tok.ttype is Keyword and tok.value.upper() in {"FROM", "JOIN", "INNER JOIN", "LEFT JOIN", "RIGHT JOIN", "FULL JOIN"}:
|
| 190 |
+
seen_from = True
|
| 191 |
+
continue
|
| 192 |
+
if not seen_from:
|
| 193 |
+
continue
|
| 194 |
+
|
| 195 |
+
if isinstance(tok, Identifier):
|
| 196 |
+
_add_identifier_as_table(tok)
|
| 197 |
+
seen_from = False
|
| 198 |
+
elif tok.ttype is Name:
|
| 199 |
+
tables.add(tok.value.lower())
|
| 200 |
+
seen_from = False
|
| 201 |
+
elif tok.ttype is Keyword and tok.value.upper() in {"WHERE", "GROUP", "ORDER", "HAVING", "LIMIT"}:
|
| 202 |
+
seen_from = False
|
| 203 |
+
|
| 204 |
+
return tables
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def extract_columns(sql: str) -> Set[str]:
|
| 208 |
+
"""
|
| 209 |
+
Best-effort column extraction from SQL using sqlparse.
|
| 210 |
+
Returns lowercase column names (unqualified).
|
| 211 |
+
"""
|
| 212 |
+
sql = _normalize_sql(sql)
|
| 213 |
+
if not sql:
|
| 214 |
+
return set()
|
| 215 |
+
if sqlparse is None:
|
| 216 |
+
# Fallback: naive dotted identifiers and bare names after SELECT/WHERE/etc.
|
| 217 |
+
cols = set()
|
| 218 |
+
for m in re.finditer(r"\b([a-zA-Z_][\w$]*)\b", sql):
|
| 219 |
+
w = m.group(1).lower()
|
| 220 |
+
if w in _SQL_KEYWORDS_TO_IGNORE or w in _SQL_FUNCTIONS_TO_IGNORE:
|
| 221 |
+
continue
|
| 222 |
+
cols.add(w)
|
| 223 |
+
return cols
|
| 224 |
|
| 225 |
+
try:
|
| 226 |
+
statements = sqlparse.parse(sql)
|
| 227 |
+
except Exception:
|
| 228 |
+
return set()
|
| 229 |
+
|
| 230 |
+
cols: Set[str] = set()
|
| 231 |
+
|
| 232 |
+
def _maybe_add_col(name: Optional[str]) -> None:
|
| 233 |
+
if not name:
|
| 234 |
+
return
|
| 235 |
+
n = name.strip().strip('"').strip("'").lower()
|
| 236 |
+
if not n or n == "*":
|
| 237 |
+
return
|
| 238 |
+
if n in _SQL_KEYWORDS_TO_IGNORE or n in _SQL_FUNCTIONS_TO_IGNORE:
|
| 239 |
+
return
|
| 240 |
+
cols.add(n)
|
| 241 |
+
|
| 242 |
+
def _handle_identifier(ident: Identifier) -> None:
|
| 243 |
+
# If qualified (t.col), keep only col for overlap/hallucination checks.
|
| 244 |
+
_maybe_add_col(ident.get_real_name() or ident.get_name())
|
| 245 |
+
|
| 246 |
+
for st in statements:
|
| 247 |
+
if not isinstance(st, Statement):
|
| 248 |
+
continue
|
| 249 |
+
for tok in st.flatten():
|
| 250 |
+
# Skip whitespace/punctuation/string literals/numbers.
|
| 251 |
+
if getattr(tok, "ttype", None) in (Whitespace, Punctuation, String, Number):
|
| 252 |
+
continue
|
| 253 |
+
|
| 254 |
+
if isinstance(tok, Function):
|
| 255 |
+
fname = tok.get_name()
|
| 256 |
+
if fname:
|
| 257 |
+
# Don't treat function name as a column.
|
| 258 |
+
pass
|
| 259 |
+
continue
|
| 260 |
+
|
| 261 |
+
if isinstance(tok, IdentifierList):
|
| 262 |
+
for ident in tok.get_identifiers():
|
| 263 |
+
if isinstance(ident, Identifier):
|
| 264 |
+
_handle_identifier(ident)
|
| 265 |
+
continue
|
| 266 |
+
|
| 267 |
+
if isinstance(tok, Identifier):
|
| 268 |
+
_handle_identifier(tok)
|
| 269 |
+
continue
|
| 270 |
+
|
| 271 |
+
if getattr(tok, "ttype", None) is Name:
|
| 272 |
+
_maybe_add_col(tok.value)
|
| 273 |
+
|
| 274 |
+
return cols
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def _get_db_tables_and_columns(conn: sqlite3.Connection) -> Tuple[Set[str], Set[str]]:
|
| 278 |
+
"""
|
| 279 |
+
Return (tables, columns) sets from SQLite schema; all lowercased.
|
| 280 |
+
Columns are returned as a global set (unqualified).
|
| 281 |
+
"""
|
| 282 |
+
tables = set()
|
| 283 |
+
columns = set()
|
| 284 |
+
for t in _list_tables(conn):
|
| 285 |
+
tl = t.lower()
|
| 286 |
+
if not tl:
|
| 287 |
+
continue
|
| 288 |
+
tables.add(tl)
|
| 289 |
+
try:
|
| 290 |
+
cur = conn.execute(f'PRAGMA table_info("{t}")')
|
| 291 |
+
for row in cur.fetchall():
|
| 292 |
+
if row and isinstance(row[1], str):
|
| 293 |
+
columns.add(row[1].lower())
|
| 294 |
+
except sqlite3.Error:
|
| 295 |
+
continue
|
| 296 |
+
return tables, columns
|
| 297 |
|
|
|
|
| 298 |
|
| 299 |
+
def _safe_results_equal(a: List[Tuple], b: List[Tuple]) -> bool:
|
| 300 |
+
# Deterministic comparison: compare exact row tuples in order.
|
| 301 |
+
return a == b
|
| 302 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
|
| 304 |
+
@dataclass
|
| 305 |
+
class RewardDebugStats:
|
| 306 |
+
total: int = 0
|
| 307 |
+
parsed_ok: int = 0
|
| 308 |
+
table_match: int = 0
|
| 309 |
+
column_match: int = 0
|
| 310 |
+
executed_ok: int = 0
|
| 311 |
+
exact_match: int = 0
|
| 312 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
+
_DEBUG = RewardDebugStats()
|
| 315 |
|
|
|
|
|
|
|
| 316 |
|
| 317 |
+
def reset_debug_metrics() -> None:
|
| 318 |
+
global _DEBUG
|
| 319 |
+
_DEBUG = RewardDebugStats()
|
| 320 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
|
| 322 |
+
def get_debug_metrics() -> dict:
|
| 323 |
+
denom = max(_DEBUG.total, 1)
|
| 324 |
+
return {
|
| 325 |
+
"valid_sql_rate": _DEBUG.parsed_ok / denom,
|
| 326 |
+
"table_match_rate": _DEBUG.table_match / denom,
|
| 327 |
+
"column_match_rate": _DEBUG.column_match / denom,
|
| 328 |
+
"execution_accuracy": _DEBUG.exact_match / denom,
|
| 329 |
+
}
|
| 330 |
|
| 331 |
+
EXECUTION_ERROR = "EXECUTION_ERROR"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
|
|
|
|
| 333 |
|
| 334 |
+
def execute_sql(conn: sqlite3.Connection, sql: str, *, max_rows: int = 1000) -> Union[List[Tuple], str]:
|
| 335 |
+
"""
|
| 336 |
+
Execute SQL safely.
|
| 337 |
|
| 338 |
+
If sqlite raises ANY exception, return EXECUTION_ERROR (NOT empty list).
|
| 339 |
+
"""
|
| 340 |
try:
|
| 341 |
+
_with_timeout(conn, timeout_s=1.0)
|
| 342 |
+
cur = conn.execute(sql)
|
| 343 |
+
rows = cur.fetchmany(max_rows)
|
| 344 |
+
return [tuple(r) for r in rows]
|
| 345 |
except Exception:
|
| 346 |
+
return EXECUTION_ERROR
|
| 347 |
|
| 348 |
|
| 349 |
+
def _sqlparse_valid_select(sql: str) -> bool:
|
| 350 |
"""
|
| 351 |
+
Parse validation using sqlparse:
|
| 352 |
+
- parse() non-empty
|
| 353 |
+
- contains a SELECT statement
|
| 354 |
"""
|
| 355 |
+
if sqlparse is None:
|
| 356 |
+
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
try:
|
| 358 |
+
stmts = sqlparse.parse(sql)
|
| 359 |
+
if not stmts:
|
| 360 |
+
return False
|
| 361 |
+
for st in stmts:
|
| 362 |
+
try:
|
| 363 |
+
if hasattr(st, "get_type") and st.get_type() == "SELECT":
|
| 364 |
+
return True
|
| 365 |
+
except Exception:
|
| 366 |
+
continue
|
| 367 |
+
return False
|
| 368 |
+
except Exception:
|
| 369 |
+
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
|
|
|
|
|
|
|
|
|
|
| 371 |
def execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
|
| 372 |
try:
|
| 373 |
sql = _normalize_sql(pred_sql)
|
| 374 |
gold = _normalize_sql(gold_sql)
|
| 375 |
|
| 376 |
+
if not sql or "SELECT" not in sql.upper():
|
| 377 |
return -1.0
|
| 378 |
|
| 379 |
+
if not _sqlparse_valid_select(sql):
|
| 380 |
+
return -1.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
|
| 382 |
+
reward = -0.2 # valid SQL baseline
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
|
| 384 |
+
pred_tables = extract_tables(sql)
|
| 385 |
+
gold_tables = extract_tables(gold)
|
| 386 |
|
| 387 |
+
if pred_tables == gold_tables and len(gold_tables) > 0:
|
| 388 |
+
reward += 0.3
|
| 389 |
|
| 390 |
+
pred_cols = extract_columns(sql)
|
| 391 |
+
gold_cols = extract_columns(gold)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
|
| 393 |
+
if gold_cols:
|
| 394 |
+
overlap = len(pred_cols & gold_cols) / len(gold_cols)
|
| 395 |
+
reward += 0.3 * overlap
|
| 396 |
|
| 397 |
+
with _connect_readonly(db_path) as conn:
|
| 398 |
+
pred_res = execute_sql(conn, sql)
|
| 399 |
+
if pred_res != EXECUTION_ERROR:
|
| 400 |
+
reward += 0.2
|
| 401 |
|
| 402 |
+
gold_res = execute_sql(conn, gold)
|
| 403 |
+
if pred_res != EXECUTION_ERROR and _safe_results_equal(pred_res, gold_res):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
return 1.0
|
|
|
|
|
|
|
|
|
|
| 405 |
|
| 406 |
+
return max(-1.0, min(1.0, reward))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
|
| 408 |
+
except Exception:
|
| 409 |
+
return -1.0
|
src/execution_reward_soft.py
DELETED
|
@@ -1,211 +0,0 @@
|
|
| 1 |
-
import random
|
| 2 |
-
import threading
|
| 3 |
-
from collections import Counter
|
| 4 |
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 5 |
-
from src.execution_reward import (
|
| 6 |
-
_normalize_sql,
|
| 7 |
-
is_valid_select,
|
| 8 |
-
execute_sql_cached,
|
| 9 |
-
execute_sql_cached_conn,
|
| 10 |
-
EXECUTION_ERROR,
|
| 11 |
-
validate_sql_schema,
|
| 12 |
-
USE_SCHEMA_VALIDATION,
|
| 13 |
-
_connect_readonly,
|
| 14 |
-
)
|
| 15 |
-
|
| 16 |
-
# =========================================================
|
| 17 |
-
# 🔥 SOFT REWARD CORE
|
| 18 |
-
# =========================================================
|
| 19 |
-
def compute_soft_reward(pred_res, gold_res, sample_k=10):
|
| 20 |
-
try:
|
| 21 |
-
# =================================================
|
| 22 |
-
# 1. EDGE CASES
|
| 23 |
-
# =================================================
|
| 24 |
-
if not gold_res:
|
| 25 |
-
return 1.0 if not pred_res else 0.3
|
| 26 |
-
|
| 27 |
-
if not pred_res:
|
| 28 |
-
return -0.05
|
| 29 |
-
|
| 30 |
-
# =================================================
|
| 31 |
-
# 2. SAFE HASHING
|
| 32 |
-
# =================================================
|
| 33 |
-
def make_hashable(row):
|
| 34 |
-
return tuple(str(item) for item in row)
|
| 35 |
-
|
| 36 |
-
pred_counter = Counter(make_hashable(r) for r in pred_res)
|
| 37 |
-
|
| 38 |
-
# =================================================
|
| 39 |
-
# 3. SAMPLING
|
| 40 |
-
# =================================================
|
| 41 |
-
k = min(sample_k, len(gold_res))
|
| 42 |
-
sample = random.sample(gold_res, k)
|
| 43 |
-
|
| 44 |
-
# =================================================
|
| 45 |
-
# 4. MATCH COUNT
|
| 46 |
-
# =================================================
|
| 47 |
-
match = 0
|
| 48 |
-
for row in sample:
|
| 49 |
-
key = make_hashable(row)
|
| 50 |
-
if pred_counter.get(key, 0) > 0:
|
| 51 |
-
pred_counter[key] -= 1
|
| 52 |
-
match += 1
|
| 53 |
-
|
| 54 |
-
score = match / max(len(sample), 1)
|
| 55 |
-
|
| 56 |
-
# =================================================
|
| 57 |
-
# 5. 🔥 ANTI-CHEAT LENGTH PENALTY
|
| 58 |
-
# =================================================
|
| 59 |
-
len_ratio = len(pred_res) / max(len(gold_res), 1)
|
| 60 |
-
|
| 61 |
-
if len_ratio > 1.5:
|
| 62 |
-
score = score / (len_ratio ** 0.5) # 🔥 smoother penalty
|
| 63 |
-
|
| 64 |
-
# =================================================
|
| 65 |
-
# 6. CLAMP SCORE (IMPORTANT FOR STABILITY)
|
| 66 |
-
# =================================================
|
| 67 |
-
score = max(0.0, min(1.0, score))
|
| 68 |
-
|
| 69 |
-
# =================================================
|
| 70 |
-
# 7. FINAL REWARD
|
| 71 |
-
# =================================================
|
| 72 |
-
return 0.3 + 0.7 * score
|
| 73 |
-
|
| 74 |
-
except Exception:
|
| 75 |
-
return -0.05
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
# =========================================================
|
| 79 |
-
# 🔥 MAIN EXECUTION REWARD
|
| 80 |
-
# =========================================================
|
| 81 |
-
_TLS = threading.local()
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
def _get_thread_conn(db_path: str):
|
| 85 |
-
conns = getattr(_TLS, "conns", None)
|
| 86 |
-
if conns is None:
|
| 87 |
-
conns = {}
|
| 88 |
-
_TLS.conns = conns
|
| 89 |
-
conn = conns.get(db_path)
|
| 90 |
-
if conn is None:
|
| 91 |
-
conn = _connect_readonly(db_path)
|
| 92 |
-
conns[db_path] = conn
|
| 93 |
-
return conn
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
def execution_reward_soft_pooled(pred_sql, db_path, gold_sql, *, sample_k: int = 10):
|
| 97 |
-
"""
|
| 98 |
-
Soft execution reward, but reuses a per-thread read-only SQLite connection.
|
| 99 |
-
This avoids connect/close overhead in RL loops.
|
| 100 |
-
"""
|
| 101 |
-
try:
|
| 102 |
-
sql = _normalize_sql(pred_sql)
|
| 103 |
-
gold = _normalize_sql(gold_sql)
|
| 104 |
-
|
| 105 |
-
if not is_valid_select(sql):
|
| 106 |
-
return -0.05
|
| 107 |
-
|
| 108 |
-
if USE_SCHEMA_VALIDATION:
|
| 109 |
-
ok, _ = validate_sql_schema(sql, db_path)
|
| 110 |
-
if not ok:
|
| 111 |
-
return -0.05
|
| 112 |
-
|
| 113 |
-
conn = _get_thread_conn(db_path)
|
| 114 |
-
pred_res = execute_sql_cached_conn(conn, db_path, sql)
|
| 115 |
-
if pred_res == EXECUTION_ERROR:
|
| 116 |
-
return -0.05
|
| 117 |
-
|
| 118 |
-
gold_res = execute_sql_cached_conn(conn, db_path, gold)
|
| 119 |
-
if gold_res == EXECUTION_ERROR:
|
| 120 |
-
return -0.05
|
| 121 |
-
|
| 122 |
-
return compute_soft_reward(pred_res, gold_res, sample_k=int(sample_k))
|
| 123 |
-
except Exception:
|
| 124 |
-
return -0.05
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
def execution_reward_soft(pred_sql, db_path, gold_sql):
|
| 128 |
-
try:
|
| 129 |
-
sql = _normalize_sql(pred_sql)
|
| 130 |
-
gold = _normalize_sql(gold_sql)
|
| 131 |
-
|
| 132 |
-
# =================================================
|
| 133 |
-
# BASIC VALIDATION
|
| 134 |
-
# =================================================
|
| 135 |
-
if not is_valid_select(sql):
|
| 136 |
-
return -0.05
|
| 137 |
-
|
| 138 |
-
if USE_SCHEMA_VALIDATION:
|
| 139 |
-
ok, _ = validate_sql_schema(sql, db_path)
|
| 140 |
-
if not ok:
|
| 141 |
-
return -0.05
|
| 142 |
-
|
| 143 |
-
# =================================================
|
| 144 |
-
# EXECUTION
|
| 145 |
-
# =================================================
|
| 146 |
-
pred_res = execute_sql_cached(db_path, sql)
|
| 147 |
-
if pred_res == EXECUTION_ERROR:
|
| 148 |
-
return -0.05
|
| 149 |
-
|
| 150 |
-
gold_res = execute_sql_cached(db_path, gold)
|
| 151 |
-
if gold_res == EXECUTION_ERROR:
|
| 152 |
-
return -0.05
|
| 153 |
-
|
| 154 |
-
return compute_soft_reward(pred_res, gold_res)
|
| 155 |
-
|
| 156 |
-
except Exception:
|
| 157 |
-
return -0.05
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
def execution_reward_soft_batch_parallel_by_db(rollouts, *, max_workers: int = 20, sample_k: int = 10):
|
| 161 |
-
"""
|
| 162 |
-
rollouts: Sequence[(pred_sql, db_path, gold_sql)]
|
| 163 |
-
Executes with 1-thread-per-DB grouping for better connection reuse.
|
| 164 |
-
Returns rewards in the same order as input.
|
| 165 |
-
"""
|
| 166 |
-
if not rollouts:
|
| 167 |
-
return []
|
| 168 |
-
|
| 169 |
-
# Group by DB so each worker can hold a single connection and reuse it.
|
| 170 |
-
by_db = {}
|
| 171 |
-
for idx, (pred_sql, db_path, gold_sql) in enumerate(rollouts):
|
| 172 |
-
by_db.setdefault(db_path, []).append((idx, pred_sql, gold_sql))
|
| 173 |
-
|
| 174 |
-
out = [0.0 for _ in range(len(rollouts))]
|
| 175 |
-
|
| 176 |
-
def _worker(db_path: str, items):
|
| 177 |
-
conn = _connect_readonly(db_path)
|
| 178 |
-
try:
|
| 179 |
-
for idx, pred_sql, gold_sql in items:
|
| 180 |
-
# Do NOT use the global thread-local here; this worker owns the connection.
|
| 181 |
-
try:
|
| 182 |
-
sql = _normalize_sql(pred_sql)
|
| 183 |
-
gold = _normalize_sql(gold_sql)
|
| 184 |
-
if not is_valid_select(sql):
|
| 185 |
-
out[idx] = -0.05
|
| 186 |
-
continue
|
| 187 |
-
if USE_SCHEMA_VALIDATION:
|
| 188 |
-
ok, _ = validate_sql_schema(sql, db_path)
|
| 189 |
-
if not ok:
|
| 190 |
-
out[idx] = -0.05
|
| 191 |
-
continue
|
| 192 |
-
pred_res = execute_sql_cached_conn(conn, db_path, sql)
|
| 193 |
-
if pred_res == EXECUTION_ERROR:
|
| 194 |
-
out[idx] = -0.05
|
| 195 |
-
continue
|
| 196 |
-
gold_res = execute_sql_cached_conn(conn, db_path, gold)
|
| 197 |
-
if gold_res == EXECUTION_ERROR:
|
| 198 |
-
out[idx] = -0.05
|
| 199 |
-
continue
|
| 200 |
-
out[idx] = float(compute_soft_reward(pred_res, gold_res, sample_k=int(sample_k)))
|
| 201 |
-
except Exception:
|
| 202 |
-
out[idx] = -0.05
|
| 203 |
-
finally:
|
| 204 |
-
conn.close()
|
| 205 |
-
|
| 206 |
-
with ThreadPoolExecutor(max_workers=int(max_workers)) as ex:
|
| 207 |
-
futures = [ex.submit(_worker, db_path, items) for db_path, items in by_db.items()]
|
| 208 |
-
for fut in as_completed(futures):
|
| 209 |
-
fut.result()
|
| 210 |
-
|
| 211 |
-
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/load_lora_model.py
CHANGED
|
@@ -1,84 +1,21 @@
|
|
| 1 |
-
# import torch
|
| 2 |
-
# from transformers import T5ForConditionalGeneration, T5Tokenizer
|
| 3 |
-
# from peft import LoraConfig, get_peft_model, TaskType
|
| 4 |
-
|
| 5 |
-
# device = "mps" if torch.backends.mps.is_available() else "cpu"
|
| 6 |
-
|
| 7 |
-
# MODEL_PATH = "../outputs/model" # your supervised trained model
|
| 8 |
-
|
| 9 |
-
# print("Loading base model...")
|
| 10 |
-
# model = T5ForConditionalGeneration.from_pretrained(MODEL_PATH).to(device)
|
| 11 |
-
|
| 12 |
-
# tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
| 13 |
-
|
| 14 |
-
# # ---------------- LoRA CONFIG ----------------
|
| 15 |
-
# lora_config = LoraConfig(
|
| 16 |
-
# r=8, # rank (small brain attachment)
|
| 17 |
-
# lora_alpha=16,
|
| 18 |
-
# target_modules=["q", "v"], # attention matrices only
|
| 19 |
-
# lora_dropout=0.05,
|
| 20 |
-
# bias="none",
|
| 21 |
-
# task_type=TaskType.SEQ_2_SEQ_LM
|
| 22 |
-
# )
|
| 23 |
-
|
| 24 |
-
# print("Attaching LoRA adapters...")
|
| 25 |
-
# model = get_peft_model(model, lora_config)
|
| 26 |
-
|
| 27 |
-
# model.print_trainable_parameters()
|
| 28 |
-
|
| 29 |
-
# print("READY ✔ LoRA model loaded")
|
| 30 |
-
|
| 31 |
-
# ****************** task 5 @#$%^&*I(O)(*&^%$#$%^&*(*&^%$#$%^&*^%$#%^)
|
| 32 |
-
# )
|
| 33 |
-
#
|
| 34 |
-
#
|
| 35 |
import torch
|
| 36 |
from transformers import T5ForConditionalGeneration, T5Tokenizer
|
| 37 |
from peft import LoraConfig, get_peft_model, TaskType
|
| 38 |
|
| 39 |
-
# ---------------- DEVICE SETUP ----------------
|
| 40 |
device = "mps" if torch.backends.mps.is_available() else "cpu"
|
| 41 |
|
| 42 |
-
MODEL_PATH = "../outputs/model"
|
| 43 |
-
|
| 44 |
-
# ---------------- LOAD TOKENIZER ----------------
|
| 45 |
-
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
| 46 |
-
|
| 47 |
-
# ---------------- LOAD MODEL WITH QUANTIZATION ----------------
|
| 48 |
-
def load_model(quantization=None):
|
| 49 |
-
print(f"Loading model with quantization = {quantization}")
|
| 50 |
-
|
| 51 |
-
if quantization == "int8":
|
| 52 |
-
model = T5ForConditionalGeneration.from_pretrained(
|
| 53 |
-
MODEL_PATH,
|
| 54 |
-
load_in_8bit=True,
|
| 55 |
-
device_map="auto"
|
| 56 |
-
)
|
| 57 |
-
|
| 58 |
-
elif quantization == "int4":
|
| 59 |
-
model = T5ForConditionalGeneration.from_pretrained(
|
| 60 |
-
MODEL_PATH,
|
| 61 |
-
load_in_4bit=True,
|
| 62 |
-
device_map="auto"
|
| 63 |
-
)
|
| 64 |
-
|
| 65 |
-
else: # fp32
|
| 66 |
-
model = T5ForConditionalGeneration.from_pretrained(MODEL_PATH).to(device)
|
| 67 |
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
# 👉 CHANGE THIS VALUE TO TEST
|
| 72 |
-
QUANTIZATION = "int8" # options: None, "int8", "int4"
|
| 73 |
-
|
| 74 |
-
model = load_model(QUANTIZATION)
|
| 75 |
|
|
|
|
| 76 |
|
| 77 |
# ---------------- LoRA CONFIG ----------------
|
| 78 |
lora_config = LoraConfig(
|
| 79 |
-
r=8,
|
| 80 |
lora_alpha=16,
|
| 81 |
-
target_modules=["q", "v"],
|
| 82 |
lora_dropout=0.05,
|
| 83 |
bias="none",
|
| 84 |
task_type=TaskType.SEQ_2_SEQ_LM
|
|
@@ -89,4 +26,5 @@ model = get_peft_model(model, lora_config)
|
|
| 89 |
|
| 90 |
model.print_trainable_parameters()
|
| 91 |
|
| 92 |
-
print("READY ✔ LoRA
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
from transformers import T5ForConditionalGeneration, T5Tokenizer
|
| 3 |
from peft import LoraConfig, get_peft_model, TaskType
|
| 4 |
|
|
|
|
| 5 |
device = "mps" if torch.backends.mps.is_available() else "cpu"
|
| 6 |
|
| 7 |
+
MODEL_PATH = "../outputs/model" # your supervised trained model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
print("Loading base model...")
|
| 10 |
+
model = T5ForConditionalGeneration.from_pretrained(MODEL_PATH).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
+
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
| 13 |
|
| 14 |
# ---------------- LoRA CONFIG ----------------
|
| 15 |
lora_config = LoraConfig(
|
| 16 |
+
r=8, # rank (small brain attachment)
|
| 17 |
lora_alpha=16,
|
| 18 |
+
target_modules=["q", "v"], # attention matrices only
|
| 19 |
lora_dropout=0.05,
|
| 20 |
bias="none",
|
| 21 |
task_type=TaskType.SEQ_2_SEQ_LM
|
|
|
|
| 26 |
|
| 27 |
model.print_trainable_parameters()
|
| 28 |
|
| 29 |
+
print("READY ✔ LoRA model loaded")
|
| 30 |
+
|
src/quantization_utils.py
DELETED
|
@@ -1,222 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import json
|
| 4 |
-
import os
|
| 5 |
-
import time
|
| 6 |
-
from dataclasses import dataclass
|
| 7 |
-
from pathlib import Path
|
| 8 |
-
from typing import Any, Dict, Optional, Tuple
|
| 9 |
-
|
| 10 |
-
import torch
|
| 11 |
-
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
| 12 |
-
|
| 13 |
-
try:
|
| 14 |
-
from transformers import BitsAndBytesConfig # type: ignore
|
| 15 |
-
except Exception: # pragma: no cover
|
| 16 |
-
BitsAndBytesConfig = None # type: ignore
|
| 17 |
-
|
| 18 |
-
try:
|
| 19 |
-
from peft import PeftModel
|
| 20 |
-
except Exception as e: # pragma: no cover
|
| 21 |
-
PeftModel = None # type: ignore
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
@dataclass(frozen=True)
|
| 25 |
-
class QuantArtifact:
|
| 26 |
-
out_dir: Path
|
| 27 |
-
mode: str # fp32 | int8_dynamic | int8_decoder_dynamic | int8_bnb | int4_bnb
|
| 28 |
-
base_model: str
|
| 29 |
-
adapter_path: Optional[str]
|
| 30 |
-
created_at_s: float
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
def _bool_env(name: str, default: str = "0") -> bool:
|
| 34 |
-
return os.environ.get(name, default).strip() in {"1", "true", "True", "yes", "Y"}
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def estimate_model_bytes(model: torch.nn.Module) -> int:
|
| 38 |
-
total = 0
|
| 39 |
-
for p in model.parameters():
|
| 40 |
-
total += p.numel() * p.element_size()
|
| 41 |
-
for b in model.buffers():
|
| 42 |
-
total += b.numel() * b.element_size()
|
| 43 |
-
return int(total)
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
def _load_tokenizer(base_model: str, *, local_only: bool) -> Any:
|
| 47 |
-
tok = AutoTokenizer.from_pretrained(base_model, local_files_only=local_only)
|
| 48 |
-
if tok.pad_token_id is None and getattr(tok, "eos_token_id", None) is not None:
|
| 49 |
-
tok.pad_token = tok.eos_token
|
| 50 |
-
return tok
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
def load_fp32_model(
|
| 54 |
-
base_model: str,
|
| 55 |
-
*,
|
| 56 |
-
adapter_path: Optional[str] = None,
|
| 57 |
-
device: str = "cpu",
|
| 58 |
-
local_only: bool = True,
|
| 59 |
-
torch_dtype: torch.dtype = torch.float32,
|
| 60 |
-
merge_lora: bool = True,
|
| 61 |
-
) -> Tuple[Any, torch.nn.Module]:
|
| 62 |
-
tok = _load_tokenizer(base_model, local_only=local_only)
|
| 63 |
-
model = AutoModelForSeq2SeqLM.from_pretrained(
|
| 64 |
-
base_model,
|
| 65 |
-
local_files_only=local_only,
|
| 66 |
-
torch_dtype=torch_dtype,
|
| 67 |
-
).to(device)
|
| 68 |
-
|
| 69 |
-
if adapter_path:
|
| 70 |
-
if PeftModel is None:
|
| 71 |
-
raise RuntimeError("peft is required to load adapters.")
|
| 72 |
-
model = PeftModel.from_pretrained(model, adapter_path).to(device)
|
| 73 |
-
if merge_lora and hasattr(model, "merge_and_unload"):
|
| 74 |
-
model = model.merge_and_unload()
|
| 75 |
-
model = model.to(device)
|
| 76 |
-
|
| 77 |
-
model.eval()
|
| 78 |
-
return tok, model
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
def quantize_dynamic_int8(model: torch.nn.Module) -> torch.nn.Module:
|
| 82 |
-
# CPU-only; quantized kernels run on CPU.
|
| 83 |
-
# Ensure a quantization engine is selected (PyTorch may default to "none" on macOS).
|
| 84 |
-
try:
|
| 85 |
-
supported = list(getattr(torch.backends.quantized, "supported_engines", []))
|
| 86 |
-
current = getattr(torch.backends.quantized, "engine", "none")
|
| 87 |
-
if current in {"none", None, ""}:
|
| 88 |
-
if "fbgemm" in supported:
|
| 89 |
-
torch.backends.quantized.engine = "fbgemm"
|
| 90 |
-
elif "qnnpack" in supported:
|
| 91 |
-
torch.backends.quantized.engine = "qnnpack"
|
| 92 |
-
except Exception: # pragma: no cover
|
| 93 |
-
pass
|
| 94 |
-
return torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
def quantize_dynamic_int8_decoder_only(model: Any) -> Any:
|
| 98 |
-
"""
|
| 99 |
-
Mixed-precision (Task 5): encoder fp32, decoder int8 dynamic quantized.
|
| 100 |
-
"""
|
| 101 |
-
if not hasattr(model, "decoder"):
|
| 102 |
-
raise ValueError("Model has no decoder attribute.")
|
| 103 |
-
try:
|
| 104 |
-
supported = list(getattr(torch.backends.quantized, "supported_engines", []))
|
| 105 |
-
current = getattr(torch.backends.quantized, "engine", "none")
|
| 106 |
-
if current in {"none", None, ""}:
|
| 107 |
-
if "fbgemm" in supported:
|
| 108 |
-
torch.backends.quantized.engine = "fbgemm"
|
| 109 |
-
elif "qnnpack" in supported:
|
| 110 |
-
torch.backends.quantized.engine = "qnnpack"
|
| 111 |
-
except Exception: # pragma: no cover
|
| 112 |
-
pass
|
| 113 |
-
model.decoder = torch.quantization.quantize_dynamic(model.decoder, {torch.nn.Linear}, dtype=torch.qint8)
|
| 114 |
-
return model
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
def load_bnb_quantized_model(
|
| 118 |
-
base_model: str,
|
| 119 |
-
*,
|
| 120 |
-
adapter_path: Optional[str],
|
| 121 |
-
device: str,
|
| 122 |
-
local_only: bool,
|
| 123 |
-
load_in_8bit: bool = False,
|
| 124 |
-
load_in_4bit: bool = False,
|
| 125 |
-
) -> Tuple[Any, torch.nn.Module]:
|
| 126 |
-
"""
|
| 127 |
-
bitsandbytes int8/int4 (requires bitsandbytes + CUDA). Not supported on CPU/MPS.
|
| 128 |
-
"""
|
| 129 |
-
if BitsAndBytesConfig is None:
|
| 130 |
-
raise RuntimeError("transformers BitsAndBytesConfig not available; upgrade transformers or install extras.")
|
| 131 |
-
if device != "cuda":
|
| 132 |
-
raise RuntimeError("bitsandbytes quantization requires CUDA (device=cuda).")
|
| 133 |
-
if not (load_in_8bit or load_in_4bit):
|
| 134 |
-
raise ValueError("Specify load_in_8bit or load_in_4bit.")
|
| 135 |
-
|
| 136 |
-
tok = _load_tokenizer(base_model, local_only=local_only)
|
| 137 |
-
qconf = BitsAndBytesConfig(load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit)
|
| 138 |
-
model = AutoModelForSeq2SeqLM.from_pretrained(
|
| 139 |
-
base_model,
|
| 140 |
-
local_files_only=local_only,
|
| 141 |
-
quantization_config=qconf,
|
| 142 |
-
device_map="auto",
|
| 143 |
-
)
|
| 144 |
-
if adapter_path:
|
| 145 |
-
if PeftModel is None:
|
| 146 |
-
raise RuntimeError("peft is required to load adapters.")
|
| 147 |
-
model = PeftModel.from_pretrained(model, adapter_path)
|
| 148 |
-
model.eval()
|
| 149 |
-
return tok, model
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
def save_quant_artifact(
|
| 153 |
-
out_dir: str | Path,
|
| 154 |
-
*,
|
| 155 |
-
mode: str,
|
| 156 |
-
base_model: str,
|
| 157 |
-
adapter_path: Optional[str],
|
| 158 |
-
tokenizer: Any,
|
| 159 |
-
model: torch.nn.Module,
|
| 160 |
-
) -> QuantArtifact:
|
| 161 |
-
out = Path(out_dir)
|
| 162 |
-
out.mkdir(parents=True, exist_ok=True)
|
| 163 |
-
(out / "tokenizer").mkdir(exist_ok=True)
|
| 164 |
-
|
| 165 |
-
tokenizer.save_pretrained(out / "tokenizer")
|
| 166 |
-
torch.save(model.state_dict(), out / "model.pt")
|
| 167 |
-
|
| 168 |
-
meta: Dict[str, Any] = {
|
| 169 |
-
"mode": mode,
|
| 170 |
-
"base_model": base_model,
|
| 171 |
-
"adapter_path": adapter_path,
|
| 172 |
-
"created_at_s": time.time(),
|
| 173 |
-
"estimated_model_bytes": estimate_model_bytes(model),
|
| 174 |
-
}
|
| 175 |
-
(out / "meta.json").write_text(json.dumps(meta, indent=2))
|
| 176 |
-
|
| 177 |
-
return QuantArtifact(
|
| 178 |
-
out_dir=out,
|
| 179 |
-
mode=mode,
|
| 180 |
-
base_model=base_model,
|
| 181 |
-
adapter_path=adapter_path,
|
| 182 |
-
created_at_s=float(meta["created_at_s"]),
|
| 183 |
-
)
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
def load_quant_artifact(
|
| 187 |
-
artifact_dir: str | Path,
|
| 188 |
-
*,
|
| 189 |
-
device: str = "cpu",
|
| 190 |
-
local_only: bool = True,
|
| 191 |
-
) -> Tuple[Any, torch.nn.Module, Dict[str, Any]]:
|
| 192 |
-
"""
|
| 193 |
-
Loads a previously exported quant artifact.
|
| 194 |
-
For dynamic quant modes, we reconstruct the architecture, apply the same quantization,
|
| 195 |
-
then load the saved state_dict.
|
| 196 |
-
"""
|
| 197 |
-
adir = Path(artifact_dir)
|
| 198 |
-
meta = json.loads((adir / "meta.json").read_text())
|
| 199 |
-
mode = meta["mode"]
|
| 200 |
-
base_model = meta["base_model"]
|
| 201 |
-
|
| 202 |
-
tok = AutoTokenizer.from_pretrained(adir / "tokenizer", local_files_only=True)
|
| 203 |
-
if tok.pad_token_id is None and getattr(tok, "eos_token_id", None) is not None:
|
| 204 |
-
tok.pad_token = tok.eos_token
|
| 205 |
-
|
| 206 |
-
model = AutoModelForSeq2SeqLM.from_pretrained(base_model, local_files_only=local_only).to(device)
|
| 207 |
-
model.eval()
|
| 208 |
-
|
| 209 |
-
if mode == "int8_dynamic":
|
| 210 |
-
model = quantize_dynamic_int8(model)
|
| 211 |
-
elif mode == "int8_decoder_dynamic":
|
| 212 |
-
model = quantize_dynamic_int8_decoder_only(model)
|
| 213 |
-
elif mode in {"fp32"}:
|
| 214 |
-
pass
|
| 215 |
-
else:
|
| 216 |
-
raise RuntimeError(f"Unsupported artifact mode for local loading: {mode}")
|
| 217 |
-
|
| 218 |
-
state = torch.load(adir / "model.pt", map_location=device)
|
| 219 |
-
model.load_state_dict(state, strict=False)
|
| 220 |
-
model.to(device)
|
| 221 |
-
model.eval()
|
| 222 |
-
return tok, model, meta
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/quantized_text2sql_engine.py
DELETED
|
@@ -1,243 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import os
|
| 4 |
-
import sqlite3
|
| 5 |
-
import threading
|
| 6 |
-
import time
|
| 7 |
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 8 |
-
from collections import OrderedDict
|
| 9 |
-
from pathlib import Path
|
| 10 |
-
from typing import Any, Dict, List, Sequence, Tuple
|
| 11 |
-
|
| 12 |
-
import torch
|
| 13 |
-
|
| 14 |
-
from src.quantization_utils import load_quant_artifact
|
| 15 |
-
from src.schema_encoder import SchemaEncoder
|
| 16 |
-
from src.sql_validator import validate_sql_schema
|
| 17 |
-
|
| 18 |
-
# ==========================================
|
| 19 |
-
# RELATIVE PATH RESOLUTION (GLOBAL)
|
| 20 |
-
# ==========================================
|
| 21 |
-
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 22 |
-
|
| 23 |
-
if (PROJECT_ROOT / "data" / "database").exists():
|
| 24 |
-
DB_ROOT = PROJECT_ROOT / "data" / "database"
|
| 25 |
-
else:
|
| 26 |
-
DB_ROOT = PROJECT_ROOT / "final_databases"
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
class QuantizedText2SQLEngine:
|
| 30 |
-
def __init__(
|
| 31 |
-
self,
|
| 32 |
-
artifact_dir: str,
|
| 33 |
-
*,
|
| 34 |
-
device: str = "cpu",
|
| 35 |
-
use_constrained: bool = False,
|
| 36 |
-
exec_workers: int | None = None,
|
| 37 |
-
default_timeout_s: float = 2.0,
|
| 38 |
-
use_cache: bool = True,
|
| 39 |
-
cache_max_entries: int = 50_000,
|
| 40 |
-
):
|
| 41 |
-
self.device = device
|
| 42 |
-
self.use_constrained = bool(use_constrained)
|
| 43 |
-
self.tokenizer, self.model, self.meta = load_quant_artifact(artifact_dir, device=device, local_only=True)
|
| 44 |
-
self.schema_encoder = SchemaEncoder(DB_ROOT)
|
| 45 |
-
|
| 46 |
-
if exec_workers is None:
|
| 47 |
-
exec_workers = int(os.environ.get("SQL_EXEC_WORKERS", "8"))
|
| 48 |
-
|
| 49 |
-
self.exec_pool = ThreadPoolExecutor(max_workers=int(exec_workers))
|
| 50 |
-
self.default_timeout_s = float(default_timeout_s)
|
| 51 |
-
self.use_cache = bool(use_cache)
|
| 52 |
-
self.cache_max_entries = int(cache_max_entries)
|
| 53 |
-
self._cache: "OrderedDict[tuple[str, str], tuple[list, list]]" = OrderedDict()
|
| 54 |
-
self._cache_lock = threading.Lock()
|
| 55 |
-
self._stats_lock = threading.Lock()
|
| 56 |
-
self._exec_cache_hits = 0
|
| 57 |
-
self._exec_cache_misses = 0
|
| 58 |
-
self._exec_calls = 0
|
| 59 |
-
self._tls = threading.local()
|
| 60 |
-
|
| 61 |
-
def _get_db_path(self, db_id: str) -> str:
|
| 62 |
-
"""Smart resolver for flat vs nested database folders"""
|
| 63 |
-
path1 = DB_ROOT / db_id / f"{db_id}.sqlite"
|
| 64 |
-
path2 = DB_ROOT / f"{db_id}.sqlite"
|
| 65 |
-
return str(path1) if path1.exists() else str(path2)
|
| 66 |
-
|
| 67 |
-
def build_prompt(self, question: str, db_id: str) -> str:
|
| 68 |
-
schema = self.schema_encoder.structured_schema(db_id)
|
| 69 |
-
return (
|
| 70 |
-
"You are a SQLite expert.\n\n"
|
| 71 |
-
f"Database: {db_id}\n\n"
|
| 72 |
-
"Schema:\n"
|
| 73 |
-
f"{schema}\n\n"
|
| 74 |
-
"Question:\n"
|
| 75 |
-
f"{question}\n\n"
|
| 76 |
-
"SQL:"
|
| 77 |
-
)
|
| 78 |
-
|
| 79 |
-
def generate_sql_batch(
|
| 80 |
-
self,
|
| 81 |
-
pairs: Sequence[Tuple[str, str]],
|
| 82 |
-
*,
|
| 83 |
-
max_new_tokens: int = 120,
|
| 84 |
-
num_beams: int = 8,
|
| 85 |
-
repetition_penalty: float = 1.2,
|
| 86 |
-
) -> List[str]:
|
| 87 |
-
prompts = [self.build_prompt(q, db_id) for q, db_id in pairs]
|
| 88 |
-
|
| 89 |
-
if self.use_constrained:
|
| 90 |
-
from transformers.generation.logits_process import LogitsProcessorList
|
| 91 |
-
from src.constrained_decoding import SchemaConstrainedLogitsProcessor
|
| 92 |
-
|
| 93 |
-
sqls: List[str] = []
|
| 94 |
-
for (q, db_id), prompt in zip(pairs, prompts):
|
| 95 |
-
db_path = self._get_db_path(db_id)
|
| 96 |
-
enc = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(self.device)
|
| 97 |
-
proc = LogitsProcessorList([SchemaConstrainedLogitsProcessor(self.tokenizer, db_path)])
|
| 98 |
-
|
| 99 |
-
out = self.model.generate(
|
| 100 |
-
**enc,
|
| 101 |
-
max_new_tokens=int(max_new_tokens),
|
| 102 |
-
num_beams=int(num_beams),
|
| 103 |
-
repetition_penalty=float(repetition_penalty),
|
| 104 |
-
logits_processor=proc,
|
| 105 |
-
)
|
| 106 |
-
sqls.append(self.tokenizer.decode(out[0], skip_special_tokens=True).strip())
|
| 107 |
-
return sqls
|
| 108 |
-
|
| 109 |
-
enc = self.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(self.device)
|
| 110 |
-
out = self.model.generate(
|
| 111 |
-
**enc,
|
| 112 |
-
max_new_tokens=int(max_new_tokens),
|
| 113 |
-
num_beams=int(num_beams),
|
| 114 |
-
repetition_penalty=float(repetition_penalty),
|
| 115 |
-
)
|
| 116 |
-
return [self.tokenizer.decode(x, skip_special_tokens=True).strip() for x in out]
|
| 117 |
-
|
| 118 |
-
def _get_thread_conn(self, db_path: str) -> sqlite3.Connection:
|
| 119 |
-
conns = getattr(self._tls, "conns", None)
|
| 120 |
-
if conns is None:
|
| 121 |
-
conns = {}
|
| 122 |
-
self._tls.conns = conns
|
| 123 |
-
conn = conns.get(db_path)
|
| 124 |
-
if conn is None:
|
| 125 |
-
conn = sqlite3.connect(db_path)
|
| 126 |
-
conn.text_factory = lambda b: b.decode(errors="ignore")
|
| 127 |
-
conns[db_path] = conn
|
| 128 |
-
return conn
|
| 129 |
-
|
| 130 |
-
def _cache_get(self, key: tuple[str, str]) -> tuple[list, list] | None:
|
| 131 |
-
if not self.use_cache: return None
|
| 132 |
-
with self._cache_lock:
|
| 133 |
-
hit = self._cache.get(key)
|
| 134 |
-
if hit is None: return None
|
| 135 |
-
self._cache.move_to_end(key)
|
| 136 |
-
return hit
|
| 137 |
-
|
| 138 |
-
def _cache_put(self, key: tuple[str, str], value: tuple[list, list]) -> None:
|
| 139 |
-
if not self.use_cache: return
|
| 140 |
-
with self._cache_lock:
|
| 141 |
-
self._cache[key] = value
|
| 142 |
-
self._cache.move_to_end(key)
|
| 143 |
-
while len(self._cache) > self.cache_max_entries:
|
| 144 |
-
self._cache.popitem(last=False)
|
| 145 |
-
|
| 146 |
-
def _execute_one(self, sql: str, db_path: str, timeout_s: float | None = None):
|
| 147 |
-
timeout_s = float(self.default_timeout_s if timeout_s is None else timeout_s)
|
| 148 |
-
key = (db_path, sql)
|
| 149 |
-
cached = self._cache_get(key)
|
| 150 |
-
|
| 151 |
-
with self._stats_lock: self._exec_calls += 1
|
| 152 |
-
|
| 153 |
-
if cached is not None:
|
| 154 |
-
with self._stats_lock: self._exec_cache_hits += 1
|
| 155 |
-
return cached
|
| 156 |
-
|
| 157 |
-
with self._stats_lock: self._exec_cache_misses += 1
|
| 158 |
-
|
| 159 |
-
conn = self._get_thread_conn(db_path)
|
| 160 |
-
start_t = time.monotonic()
|
| 161 |
-
|
| 162 |
-
def handler():
|
| 163 |
-
return 1 if (time.monotonic() - start_t) > timeout_s else 0
|
| 164 |
-
|
| 165 |
-
conn.set_progress_handler(handler, 10_000)
|
| 166 |
-
cur = conn.cursor()
|
| 167 |
-
cur.execute(sql)
|
| 168 |
-
rows = cur.fetchall()
|
| 169 |
-
cols = [d[0] for d in cur.description] if cur.description else []
|
| 170 |
-
out = (rows, cols)
|
| 171 |
-
self._cache_put(key, out)
|
| 172 |
-
return out
|
| 173 |
-
|
| 174 |
-
def stats(self) -> Dict[str, Any]:
|
| 175 |
-
with self._stats_lock:
|
| 176 |
-
calls, hits, misses = int(self._exec_calls), int(self._exec_cache_hits), int(self._exec_cache_misses)
|
| 177 |
-
|
| 178 |
-
hit_rate = (hits / calls) if calls else 0.0
|
| 179 |
-
return {
|
| 180 |
-
"exec_calls": calls, "exec_cache_hits": hits, "exec_cache_misses": misses,
|
| 181 |
-
"exec_cache_hit_rate": float(hit_rate), "use_cache": bool(self.use_cache),
|
| 182 |
-
"exec_workers": int(getattr(self.exec_pool, "_max_workers", 0) or 0),
|
| 183 |
-
}
|
| 184 |
-
|
| 185 |
-
def reset_stats(self) -> None:
|
| 186 |
-
with self._stats_lock:
|
| 187 |
-
self._exec_calls = self._exec_cache_hits = self._exec_cache_misses = 0
|
| 188 |
-
|
| 189 |
-
def execute_sql(self, sql: str, db_id: str, *, timeout_s: float | None = None, validate_schema: bool = True):
|
| 190 |
-
db_path = self._get_db_path(db_id)
|
| 191 |
-
if validate_schema:
|
| 192 |
-
try: ok, _ = validate_sql_schema(sql, db_path)
|
| 193 |
-
except Exception: ok = False
|
| 194 |
-
if not ok: raise ValueError("Invalid schema")
|
| 195 |
-
return self._execute_one(sql, db_path, timeout_s=timeout_s)
|
| 196 |
-
|
| 197 |
-
def ask(
|
| 198 |
-
self,
|
| 199 |
-
question: str,
|
| 200 |
-
db_id: str,
|
| 201 |
-
*,
|
| 202 |
-
max_new_tokens: int = 120,
|
| 203 |
-
num_beams: int = 8,
|
| 204 |
-
repetition_penalty: float = 1.2,
|
| 205 |
-
timeout_s: float | None = None,
|
| 206 |
-
) -> Dict[str, Any]:
|
| 207 |
-
sql = self.generate_sql_batch(
|
| 208 |
-
[(question, db_id)],
|
| 209 |
-
max_new_tokens=max_new_tokens,
|
| 210 |
-
num_beams=num_beams,
|
| 211 |
-
repetition_penalty=repetition_penalty,
|
| 212 |
-
)[0]
|
| 213 |
-
|
| 214 |
-
db_path = self._get_db_path(db_id)
|
| 215 |
-
|
| 216 |
-
try: ok, _ = validate_sql_schema(sql, db_path)
|
| 217 |
-
except Exception: ok = False
|
| 218 |
-
|
| 219 |
-
if not ok: return {"sql": sql, "rows": [], "columns": [], "error": "Invalid schema"}
|
| 220 |
-
|
| 221 |
-
try:
|
| 222 |
-
rows, cols = self._execute_one(sql, db_path, timeout_s=timeout_s)
|
| 223 |
-
return {"sql": sql, "rows": rows, "columns": cols, "error": None}
|
| 224 |
-
except Exception as e:
|
| 225 |
-
return {"sql": sql, "rows": [], "columns": [], "error": str(e)}
|
| 226 |
-
|
| 227 |
-
def ask_batch_execute(self, pairs: Sequence[Tuple[str, str]]) -> List[Dict[str, Any]]:
    """Generate SQL for each ``(question, db_id)`` pair and execute them all
    concurrently on the engine's executor.

    Results arrive in completion order (not input order); each entry carries
    ``db_id``, ``sql``, ``rows``, ``columns`` and ``error``.
    """
    sqls = self.generate_sql_batch(pairs)

    # Map each submitted future back to the (sql, db_id) it belongs to.
    pending = {}
    for (question, db_id), sql in zip(pairs, sqls):
        path = self._get_db_path(db_id)
        fut = self.exec_pool.submit(self._execute_one, sql, path)
        pending[fut] = (sql, db_id)

    results: List[Dict[str, Any]] = []
    for fut in as_completed(pending):
        sql, db_id = pending[fut]
        try:
            rows, cols = fut.result()
        except Exception as exc:
            results.append({"db_id": db_id, "sql": sql, "rows": [], "columns": [], "error": str(exc)})
        else:
            results.append({"db_id": db_id, "sql": sql, "rows": rows, "columns": cols, "error": None})
    return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/schema_encoder.py
CHANGED
|
@@ -1,38 +1,54 @@
|
|
| 1 |
import sqlite3
|
| 2 |
-
|
| 3 |
|
| 4 |
class SchemaEncoder:
|
|
|
|
| 5 |
def __init__(self, db_root):
|
| 6 |
-
self.db_root =
|
| 7 |
-
|
| 8 |
-
def
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
# Get all tables
|
| 28 |
-
cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
|
| 29 |
-
tables = [r[0] for r in cur.fetchall() if r[0] != "sqlite_sequence"]
|
| 30 |
-
|
| 31 |
-
schema_str = ""
|
| 32 |
-
for table in tables:
|
| 33 |
-
cur.execute(f"PRAGMA table_info(`{table}`);")
|
| 34 |
-
cols = [c[1] for c in cur.fetchall()]
|
| 35 |
-
schema_str += f"{table} ({', '.join(cols)})\n"
|
| 36 |
-
|
| 37 |
conn.close()
|
| 38 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import sqlite3
|
| 2 |
+
|
| 3 |
|
| 4 |
class SchemaEncoder:
    """Render a SQLite database's schema as text for model prompting."""

    def __init__(self, db_root):
        # Root directory (pathlib.Path-like) containing "<db_id>.sqlite" files.
        self.db_root = db_root

    def get_tables_and_columns(self, db_id):
        """Return ``{table_name: [column_name, ...]}`` for the given database.

        BUGFIX: SQLite's internal bookkeeping tables (``sqlite_sequence``
        etc.) are excluded again — the previous filter was dropped in this
        rewrite — and the table name is quoted in PRAGMA so names that are
        reserved words or contain spaces still work.  The connection is
        closed even when a query fails.
        """
        # FIXED PATH
        db_path = self.db_root / f"{db_id}.sqlite"

        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            tables = cursor.execute(
                "SELECT name FROM sqlite_master "
                "WHERE type='table' AND name NOT LIKE 'sqlite_%';"
            ).fetchall()

            schema = {}
            for (table,) in tables:
                # Escape embedded quotes, then quote the identifier.
                quoted = table.replace('"', '""')
                cols = cursor.execute(f'PRAGMA table_info("{quoted}");').fetchall()
                schema[table] = [c[1] for c in cols]
        finally:
            conn.close()
        return schema

    # -----------------------------------
    # Strategy 1: Structured
    # -----------------------------------
    def structured_schema(self, db_id):
        """Compact one-line-per-table format: ``table(col1, col2)``."""
        schema = self.get_tables_and_columns(db_id)

        lines = []
        for table, cols in schema.items():
            lines.append(f"{table}({', '.join(cols)})")

        return "\n".join(lines)

    # -----------------------------------
    # Strategy 2: Natural Language
    # -----------------------------------
    def natural_language_schema(self, db_id):
        """English sentences describing each table's columns."""
        schema = self.get_tables_and_columns(db_id)

        lines = []
        for table, cols in schema.items():
            col_text = ", ".join(cols)
            lines.append(f"The table '{table}' contains the columns: {col_text}.")

        return "\n".join(lines)
|
src/schema_utils.py
DELETED
|
@@ -1,222 +0,0 @@
|
|
| 1 |
-
# import os
|
| 2 |
-
# import sqlite3
|
| 3 |
-
# import threading
|
| 4 |
-
# from typing import Dict, List, Set, Tuple
|
| 5 |
-
|
| 6 |
-
# def get_schema(db_path):
|
| 7 |
-
# schema_map = get_table_to_columns(db_path)
|
| 8 |
-
# schema_text = ""
|
| 9 |
-
# for table, col_names in schema_map.items():
|
| 10 |
-
# schema_text += f"{table}({', '.join(col_names)})\n"
|
| 11 |
-
# return schema_text
|
| 12 |
-
|
| 13 |
-
# _SCHEMA_LOCK = threading.Lock()
|
| 14 |
-
# _SCHEMA_CACHE: Dict[str, Tuple[str, Dict[str, List[str]]]] = {}
|
| 15 |
-
|
| 16 |
-
# def _db_state_fingerprint(db_path: str) -> str:
|
| 17 |
-
# try:
|
| 18 |
-
# st = os.stat(db_path)
|
| 19 |
-
# return f"{st.st_mtime_ns}:{st.st_size}"
|
| 20 |
-
# except OSError:
|
| 21 |
-
# return "missing"
|
| 22 |
-
|
| 23 |
-
# def _connect_readonly(db_path: str) -> sqlite3.Connection:
|
| 24 |
-
# uri = f"file:{os.path.abspath(db_path)}?mode=ro"
|
| 25 |
-
# conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
|
| 26 |
-
# conn.execute("PRAGMA query_only = ON;")
|
| 27 |
-
# conn.execute("PRAGMA foreign_keys = ON;")
|
| 28 |
-
# return conn
|
| 29 |
-
|
| 30 |
-
# def get_table_to_columns(db_path: str) -> Dict[str, List[str]]:
|
| 31 |
-
# """
|
| 32 |
-
# Return mapping of table -> column names for the SQLite DB at db_path.
|
| 33 |
-
# Tables and columns are returned lowercased.
|
| 34 |
-
# """
|
| 35 |
-
# fp = _db_state_fingerprint(db_path)
|
| 36 |
-
# with _SCHEMA_LOCK:
|
| 37 |
-
# cached = _SCHEMA_CACHE.get(db_path)
|
| 38 |
-
# if cached is not None and cached[0] == fp:
|
| 39 |
-
# return cached[1]
|
| 40 |
-
|
| 41 |
-
# schema: Dict[str, List[str]] = {}
|
| 42 |
-
# with _connect_readonly(db_path) as conn:
|
| 43 |
-
# cur = conn.execute(
|
| 44 |
-
# "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
|
| 45 |
-
# )
|
| 46 |
-
# tables = [r[0] for r in cur.fetchall() if r and isinstance(r[0], str)]
|
| 47 |
-
# for table in tables:
|
| 48 |
-
# table_l = table.lower()
|
| 49 |
-
# try:
|
| 50 |
-
# cur = conn.execute(f'PRAGMA table_info("{table}")')
|
| 51 |
-
# cols = [row[1].lower() for row in cur.fetchall() if row and isinstance(row[1], str)]
|
| 52 |
-
# schema[table_l] = cols
|
| 53 |
-
# except sqlite3.Error:
|
| 54 |
-
# schema[table_l] = []
|
| 55 |
-
|
| 56 |
-
# with _SCHEMA_LOCK:
|
| 57 |
-
# _SCHEMA_CACHE[db_path] = (fp, schema)
|
| 58 |
-
# return schema
|
| 59 |
-
|
| 60 |
-
# def get_db_tables_and_columns(db_path: str) -> Tuple[Set[str], Set[str]]:
|
| 61 |
-
# schema = get_table_to_columns(db_path)
|
| 62 |
-
# tables = set(schema.keys())
|
| 63 |
-
# columns: Set[str] = set()
|
| 64 |
-
# for cols in schema.values():
|
| 65 |
-
# columns.update(cols)
|
| 66 |
-
# return tables, columns
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
import os
|
| 70 |
-
import sqlite3
|
| 71 |
-
import threading
|
| 72 |
-
from typing import Dict, List, Set, Tuple
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
# ===============================
|
| 76 |
-
# 🔥 SCHEMA TEXT (for prompting)
|
| 77 |
-
# ===============================
|
| 78 |
-
def get_schema(db_path: str) -> str:
    """Render the database schema as one ``table(col, col)`` line per table.

    The returned string ends with a trailing newline whenever the database
    has at least one table.
    """
    table_map = get_table_to_columns(db_path)
    parts = [f"{name}({', '.join(cols)})\n" for name, cols in table_map.items()]
    return "".join(parts)
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
# ===============================
|
| 89 |
-
# 🔥 CACHE + LOCK
|
| 90 |
-
# ===============================
|
| 91 |
-
_SCHEMA_LOCK = threading.Lock()
|
| 92 |
-
_SCHEMA_CACHE: Dict[str, Tuple[str, Dict[str, List[str]]]] = {}
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
def _db_state_fingerprint(db_path: str) -> str:
|
| 96 |
-
try:
|
| 97 |
-
st = os.stat(db_path)
|
| 98 |
-
return f"{st.st_mtime_ns}:{st.st_size}"
|
| 99 |
-
except OSError:
|
| 100 |
-
return "missing"
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
def _connect_readonly(db_path: str) -> sqlite3.Connection:
|
| 104 |
-
uri = f"file:{os.path.abspath(db_path)}?mode=ro"
|
| 105 |
-
conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
|
| 106 |
-
conn.execute("PRAGMA query_only = ON;")
|
| 107 |
-
conn.execute("PRAGMA foreign_keys = ON;")
|
| 108 |
-
return conn
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
# ===============================
|
| 112 |
-
# 🔥 CORE: TABLE → COLUMNS
|
| 113 |
-
# ===============================
|
| 114 |
-
def get_table_to_columns(db_path: str) -> Dict[str, List[str]]:
    """
    Return mapping of table -> column names (ONLY names, no types).

    Tables and columns are lowercased.  Results are cached per file
    fingerprint (mtime + size) under the module lock, so repeated lookups
    against an unchanged database file are free.
    """
    fp = _db_state_fingerprint(db_path)

    with _SCHEMA_LOCK:
        cached = _SCHEMA_CACHE.get(db_path)
        if cached is not None and cached[0] == fp:
            return cached[1]

    schema: Dict[str, List[str]] = {}

    with _connect_readonly(db_path) as conn:
        cur = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
        )
        tables = [r[0] for r in cur.fetchall() if r and isinstance(r[0], str)]

        for table in tables:
            table_l = table.lower()
            try:
                cur = conn.execute(f'PRAGMA table_info("{table}")')
                cols = [row[1].lower() for row in cur.fetchall()]
                # BUGFIX: deduplicate while PRESERVING declaration order.
                # list(set(cols)) randomised the column order between runs,
                # which broke any consumer relying on positional schema text.
                schema[table_l] = list(dict.fromkeys(cols))
            except sqlite3.Error:
                # Unreadable table metadata: record the table with no columns.
                schema[table_l] = []

    with _SCHEMA_LOCK:
        _SCHEMA_CACHE[db_path] = (fp, schema)

    return schema
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
# ===============================
|
| 157 |
-
# 🔥 TABLE + COLUMN SETS
|
| 158 |
-
# ===============================
|
| 159 |
-
def get_db_tables_and_columns(db_path: str) -> Tuple[Set[str], Set[str]]:
    """Return ``(table_names, all_column_names)`` sets for the database."""
    mapping = get_table_to_columns(db_path)

    all_columns: Set[str] = set()
    for column_names in mapping.values():
        all_columns.update(column_names)

    return set(mapping), all_columns
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
# ===============================
|
| 172 |
-
# 🔥 FOREIGN KEYS (IMPORTANT)
|
| 173 |
-
# ===============================
|
| 174 |
-
def get_foreign_keys(db_path: str) -> List[Tuple[str, str, str, str]]:
    """Collect every foreign-key relation in the database.

    Each entry is ``(table, column, referenced_table, referenced_column)``,
    all lowercased.  Tables whose FK listing fails are skipped silently.
    """
    relations: List[Tuple[str, str, str, str]] = []

    with _connect_readonly(db_path) as conn:
        name_rows = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
        ).fetchall()

        for (table,) in name_rows:
            try:
                fk_rows = conn.execute(f'PRAGMA foreign_key_list("{table}")').fetchall()
            except sqlite3.Error:
                continue
            owner = table.lower()
            for row in fk_rows:
                # PRAGMA foreign_key_list columns: row[2] = referenced table,
                # row[3] = local column, row[4] = referenced column.
                relations.append((owner, row[3].lower(), row[2].lower(), row[4].lower()))

    return relations
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
# ===============================
|
| 206 |
-
# 🔥 FINAL: CONSTRAINT GRAPH
|
| 207 |
-
# ===============================
|
| 208 |
-
def get_constraint_graph(db_path: str):
    """
    Build schema constraint graph:
    - tables
    - columns
    - foreign key relations
    """
    table_names, column_names = get_db_tables_and_columns(db_path)
    return {
        "tables": table_names,
        "columns": column_names,
        "foreign_keys": get_foreign_keys(db_path),
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/sql_validator.py
CHANGED
|
@@ -1,209 +1,6 @@
|
|
| 1 |
-
|
| 2 |
-
# from pathlib import Path
|
| 3 |
-
# from typing import Optional, Set, Tuple
|
| 4 |
-
|
| 5 |
-
# from schema_utils import get_db_tables_and_columns, get_table_to_columns
|
| 6 |
-
|
| 7 |
-
# class SQLValidator:
|
| 8 |
-
|
| 9 |
-
# def __init__(self, db_root):
|
| 10 |
-
# self.db_root = Path(db_root)
|
| 11 |
-
|
| 12 |
-
# # ---------------------------
|
| 13 |
-
# # Load schema
|
| 14 |
-
# # ---------------------------
|
| 15 |
-
# def load_schema(self, db_id):
|
| 16 |
-
# db_path = self.db_root / db_id / f"{db_id}.sqlite"
|
| 17 |
-
# return get_table_to_columns(str(db_path))
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
# # ---------------------------
|
| 21 |
-
# # Basic syntax check
|
| 22 |
-
# # ---------------------------
|
| 23 |
-
# def basic_structure_valid(self, sql):
|
| 24 |
-
# s = sql.lower()
|
| 25 |
-
|
| 26 |
-
# if "select" not in s or "from" not in s:
|
| 27 |
-
# return False, "Missing SELECT or FROM"
|
| 28 |
-
|
| 29 |
-
# if len(s.split()) < 4:
|
| 30 |
-
# return False, "Too short to be SQL"
|
| 31 |
-
|
| 32 |
-
# return True, None
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
# # ---------------------------
|
| 36 |
-
# # Extract identifiers
|
| 37 |
-
# # ---------------------------
|
| 38 |
-
# def extract_identifiers(self, sql):
|
| 39 |
-
# tokens = re.findall(r"[A-Za-z_]+", sql.lower())
|
| 40 |
-
# return set(tokens)
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
# # ---------------------------
|
| 44 |
-
# # Table validation
|
| 45 |
-
# # ---------------------------
|
| 46 |
-
# def validate_tables(self, sql, schema):
|
| 47 |
-
# words = self.extract_identifiers(sql)
|
| 48 |
-
# tables = set(schema.keys())
|
| 49 |
-
|
| 50 |
-
# used_tables = [w for w in words if w in tables]
|
| 51 |
-
|
| 52 |
-
# if not used_tables:
|
| 53 |
-
# return False, "No valid table used"
|
| 54 |
-
|
| 55 |
-
# return True, None
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
# # ---------------------------
|
| 59 |
-
# # Column validation
|
| 60 |
-
# # ---------------------------
|
| 61 |
-
# def validate_columns(self, sql, schema):
|
| 62 |
-
# words = self.extract_identifiers(sql)
|
| 63 |
-
|
| 64 |
-
# valid_columns = set()
|
| 65 |
-
# for cols in schema.values():
|
| 66 |
-
# valid_columns.update(cols)
|
| 67 |
-
|
| 68 |
-
# # ignore SQL keywords
|
| 69 |
-
# keywords = {
|
| 70 |
-
# "select","from","where","join","on","group","by",
|
| 71 |
-
# "order","limit","count","sum","avg","min","max",
|
| 72 |
-
# "and","or","in","like","distinct","asc","desc"
|
| 73 |
-
# }
|
| 74 |
-
|
| 75 |
-
# invalid = []
|
| 76 |
-
# for w in words:
|
| 77 |
-
# if w not in valid_columns and w not in schema and w not in keywords:
|
| 78 |
-
# if not w.isdigit():
|
| 79 |
-
# invalid.append(w)
|
| 80 |
-
|
| 81 |
-
# # allow small hallucinations but block many
|
| 82 |
-
# if len(invalid) > 3:
|
| 83 |
-
# return False, f"Too many unknown identifiers: {invalid[:5]}"
|
| 84 |
-
|
| 85 |
-
# return True, None
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
# # ---------------------------
|
| 89 |
-
# # Dangerous query protection
|
| 90 |
-
# # ---------------------------
|
| 91 |
-
# def block_dangerous(self, sql):
|
| 92 |
-
# bad = ["drop", "delete", "update", "insert", "alter"]
|
| 93 |
-
|
| 94 |
-
# s = sql.lower()
|
| 95 |
-
# for b in bad:
|
| 96 |
-
# if b in s:
|
| 97 |
-
# return False, f"Dangerous keyword detected: {b}"
|
| 98 |
-
|
| 99 |
-
# return True, None
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
# # ---------------------------
|
| 103 |
-
# # Main validation
|
| 104 |
-
# # ---------------------------
|
| 105 |
-
# def validate(self, sql, db_id):
|
| 106 |
-
|
| 107 |
-
# schema = self.load_schema(db_id)
|
| 108 |
-
|
| 109 |
-
# checks = [
|
| 110 |
-
# self.block_dangerous(sql),
|
| 111 |
-
# self.basic_structure_valid(sql),
|
| 112 |
-
# self.validate_tables(sql, schema),
|
| 113 |
-
# self.validate_columns(sql, schema),
|
| 114 |
-
# ]
|
| 115 |
-
|
| 116 |
-
# for ok, msg in checks:
|
| 117 |
-
# if not ok:
|
| 118 |
-
# return False, msg
|
| 119 |
-
|
| 120 |
-
# return True, None
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
# _VALIDATION_CACHE = {}
|
| 124 |
-
# _VALIDATION_CACHE_MAX = 100_000
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
# def _db_state_fingerprint(db_path: str) -> str:
|
| 128 |
-
# try:
|
| 129 |
-
# st = Path(db_path).stat()
|
| 130 |
-
# return f"{st.st_mtime_ns}:{st.st_size}"
|
| 131 |
-
# except OSError:
|
| 132 |
-
# return "missing"
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
# def _extract_referenced_tables(sql: str) -> Set[str]:
|
| 136 |
-
# # Best-effort: FROM/JOIN targets (unquoted identifiers).
|
| 137 |
-
# tokens = re.findall(r"\b(from|join)\s+([a-zA-Z_][\w$]*)", sql, flags=re.I)
|
| 138 |
-
# return {t[1].lower() for t in tokens if t and len(t) > 1}
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
# def validate_sql_schema(sql: str, db_path: str) -> Tuple[bool, Optional[str]]:
|
| 142 |
-
# """
|
| 143 |
-
# Strict schema validation for reward computation.
|
| 144 |
-
# - References must resolve to real tables/columns in the target DB.
|
| 145 |
-
# - Returns (ok, message). On failure, message is a short reason.
|
| 146 |
-
# """
|
| 147 |
-
# fp = _db_state_fingerprint(db_path)
|
| 148 |
-
# key = f"{fp}|{sql}"
|
| 149 |
-
# cached = _VALIDATION_CACHE.get(key)
|
| 150 |
-
# if cached is not None:
|
| 151 |
-
# return cached
|
| 152 |
-
|
| 153 |
-
# valid_tables, valid_columns = get_db_tables_and_columns(db_path)
|
| 154 |
-
|
| 155 |
-
# referenced_tables = _extract_referenced_tables(sql)
|
| 156 |
-
# unknown_tables = sorted(t for t in referenced_tables if t not in valid_tables)
|
| 157 |
-
# if unknown_tables:
|
| 158 |
-
# out = (False, f"Unknown table(s): {unknown_tables[:5]}")
|
| 159 |
-
# if len(_VALIDATION_CACHE) >= _VALIDATION_CACHE_MAX:
|
| 160 |
-
# _VALIDATION_CACHE.clear()
|
| 161 |
-
# _VALIDATION_CACHE[key] = out
|
| 162 |
-
# return out
|
| 163 |
-
|
| 164 |
-
# # Column-level correctness is hard to do reliably with regex alone; rely on SQLite compilation.
|
| 165 |
-
# # This does not execute the query, but will fail for unknown tables/columns.
|
| 166 |
-
# try:
|
| 167 |
-
# import sqlite3 # local import to keep module lightweight
|
| 168 |
-
|
| 169 |
-
# uri = f"file:{Path(db_path).resolve()}?mode=ro"
|
| 170 |
-
# conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
|
| 171 |
-
# try:
|
| 172 |
-
# conn.execute("PRAGMA query_only = ON;")
|
| 173 |
-
# conn.execute("PRAGMA foreign_keys = ON;")
|
| 174 |
-
# conn.execute(f"EXPLAIN QUERY PLAN {sql}")
|
| 175 |
-
# finally:
|
| 176 |
-
# conn.close()
|
| 177 |
-
# except Exception as e:
|
| 178 |
-
# msg = str(e).lower()
|
| 179 |
-
# if "no such table" in msg:
|
| 180 |
-
# out = (False, "Unknown table")
|
| 181 |
-
# elif "no such column" in msg:
|
| 182 |
-
# out = (False, "Unknown column")
|
| 183 |
-
# else:
|
| 184 |
-
# out = (False, "Schema validation failed")
|
| 185 |
-
|
| 186 |
-
# if len(_VALIDATION_CACHE) >= _VALIDATION_CACHE_MAX:
|
| 187 |
-
# _VALIDATION_CACHE.clear()
|
| 188 |
-
# _VALIDATION_CACHE[key] = out
|
| 189 |
-
# return out
|
| 190 |
-
|
| 191 |
-
# out = (True, None)
|
| 192 |
-
# if len(_VALIDATION_CACHE) >= _VALIDATION_CACHE_MAX:
|
| 193 |
-
# _VALIDATION_CACHE.clear()
|
| 194 |
-
# _VALIDATION_CACHE[key] = out
|
| 195 |
-
# return out
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
import re
|
| 202 |
from pathlib import Path
|
| 203 |
-
from typing import Optional, Set, Tuple, Dict, List
|
| 204 |
-
|
| 205 |
-
from src.schema_utils import get_db_tables_and_columns, get_table_to_columns, get_constraint_graph
|
| 206 |
-
|
| 207 |
|
| 208 |
class SQLValidator:
|
| 209 |
|
|
@@ -215,7 +12,23 @@ class SQLValidator:
|
|
| 215 |
# ---------------------------
|
| 216 |
def load_schema(self, db_id):
|
| 217 |
db_path = self.db_root / db_id / f"{db_id}.sqlite"
|
| 218 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
|
| 220 |
# ---------------------------
|
| 221 |
# Basic syntax check
|
|
@@ -231,13 +44,15 @@ class SQLValidator:
|
|
| 231 |
|
| 232 |
return True, None
|
| 233 |
|
|
|
|
| 234 |
# ---------------------------
|
| 235 |
# Extract identifiers
|
| 236 |
# ---------------------------
|
| 237 |
def extract_identifiers(self, sql):
|
| 238 |
-
tokens = re.findall(r"[A-Za-z_]
|
| 239 |
return set(tokens)
|
| 240 |
|
|
|
|
| 241 |
# ---------------------------
|
| 242 |
# Table validation
|
| 243 |
# ---------------------------
|
|
@@ -252,6 +67,7 @@ class SQLValidator:
|
|
| 252 |
|
| 253 |
return True, None
|
| 254 |
|
|
|
|
| 255 |
# ---------------------------
|
| 256 |
# Column validation
|
| 257 |
# ---------------------------
|
|
@@ -262,29 +78,26 @@ class SQLValidator:
|
|
| 262 |
for cols in schema.values():
|
| 263 |
valid_columns.update(cols)
|
| 264 |
|
|
|
|
| 265 |
keywords = {
|
| 266 |
"select","from","where","join","on","group","by",
|
| 267 |
"order","limit","count","sum","avg","min","max",
|
| 268 |
-
"and","or","in","like","distinct","asc","desc"
|
| 269 |
-
"having","as","inner","left","right","outer"
|
| 270 |
}
|
| 271 |
|
| 272 |
invalid = []
|
| 273 |
for w in words:
|
| 274 |
-
if
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
# stricter than before
|
| 283 |
-
if len(invalid) > 2:
|
| 284 |
-
return False, f"Unknown identifiers: {invalid[:5]}"
|
| 285 |
|
| 286 |
return True, None
|
| 287 |
|
|
|
|
| 288 |
# ---------------------------
|
| 289 |
# Dangerous query protection
|
| 290 |
# ---------------------------
|
|
@@ -298,18 +111,6 @@ class SQLValidator:
|
|
| 298 |
|
| 299 |
return True, None
|
| 300 |
|
| 301 |
-
# ---------------------------
|
| 302 |
-
# FK-aware JOIN validation (NEW 🔥)
|
| 303 |
-
# ---------------------------
|
| 304 |
-
def validate_joins(self, db_id):
|
| 305 |
-
db_path = self.db_root / db_id / f"{db_id}.sqlite"
|
| 306 |
-
graph = get_constraint_graph(str(db_path))
|
| 307 |
-
|
| 308 |
-
# not strict enforcement, just check FK existence
|
| 309 |
-
if len(graph["foreign_keys"]) == 0:
|
| 310 |
-
return True, None
|
| 311 |
-
|
| 312 |
-
return True, None # placeholder (safe for now)
|
| 313 |
|
| 314 |
# ---------------------------
|
| 315 |
# Main validation
|
|
@@ -330,86 +131,3 @@ class SQLValidator:
|
|
| 330 |
return False, msg
|
| 331 |
|
| 332 |
return True, None
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
# ===============================
|
| 336 |
-
# 🔥 FAST SCHEMA VALIDATION (REWARD)
|
| 337 |
-
# ===============================
|
| 338 |
-
_VALIDATION_CACHE = {}
|
| 339 |
-
_VALIDATION_CACHE_MAX = 100_000
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
def _db_state_fingerprint(db_path: str) -> str:
|
| 343 |
-
try:
|
| 344 |
-
st = Path(db_path).stat()
|
| 345 |
-
return f"{st.st_mtime_ns}:{st.st_size}"
|
| 346 |
-
except OSError:
|
| 347 |
-
return "missing"
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
def _extract_referenced_tables(sql: str) -> Set[str]:
|
| 351 |
-
tokens = re.findall(r"\b(from|join)\s+([a-zA-Z_][\w$]*)", sql, flags=re.I)
|
| 352 |
-
return {t[1].lower() for t in tokens if t and len(t) > 1}
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
def validate_sql_schema(sql: str, db_path: str) -> Tuple[bool, Optional[str]]:
|
| 356 |
-
"""
|
| 357 |
-
STRICT schema validation (Task 3 core)
|
| 358 |
-
"""
|
| 359 |
-
|
| 360 |
-
fp = _db_state_fingerprint(db_path)
|
| 361 |
-
key = f"{fp}|{sql}"
|
| 362 |
-
|
| 363 |
-
cached = _VALIDATION_CACHE.get(key)
|
| 364 |
-
if cached is not None:
|
| 365 |
-
return cached
|
| 366 |
-
|
| 367 |
-
valid_tables, valid_columns = get_db_tables_and_columns(db_path)
|
| 368 |
-
|
| 369 |
-
# ---------------------------
|
| 370 |
-
# Table validation
|
| 371 |
-
# ---------------------------
|
| 372 |
-
referenced_tables = _extract_referenced_tables(sql)
|
| 373 |
-
|
| 374 |
-
unknown_tables = [t for t in referenced_tables if t not in valid_tables]
|
| 375 |
-
|
| 376 |
-
if unknown_tables:
|
| 377 |
-
out = (False, f"Unknown table(s): {unknown_tables[:3]}")
|
| 378 |
-
_VALIDATION_CACHE[key] = out
|
| 379 |
-
return out
|
| 380 |
-
|
| 381 |
-
# ---------------------------
|
| 382 |
-
# Column validation via SQLite planner
|
| 383 |
-
# ---------------------------
|
| 384 |
-
try:
|
| 385 |
-
import sqlite3
|
| 386 |
-
|
| 387 |
-
uri = f"file:{Path(db_path).resolve()}?mode=ro"
|
| 388 |
-
conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
|
| 389 |
-
|
| 390 |
-
try:
|
| 391 |
-
conn.execute("PRAGMA query_only = ON;")
|
| 392 |
-
conn.execute("PRAGMA foreign_keys = ON;")
|
| 393 |
-
|
| 394 |
-
# 🔥 Key idea: no execution, only planning
|
| 395 |
-
conn.execute(f"EXPLAIN QUERY PLAN {sql}")
|
| 396 |
-
|
| 397 |
-
finally:
|
| 398 |
-
conn.close()
|
| 399 |
-
|
| 400 |
-
except Exception as e:
|
| 401 |
-
msg = str(e).lower()
|
| 402 |
-
|
| 403 |
-
if "no such table" in msg:
|
| 404 |
-
out = (False, "Unknown table")
|
| 405 |
-
elif "no such column" in msg:
|
| 406 |
-
out = (False, "Unknown column")
|
| 407 |
-
else:
|
| 408 |
-
out = (False, "Invalid SQL")
|
| 409 |
-
|
| 410 |
-
_VALIDATION_CACHE[key] = out
|
| 411 |
-
return out
|
| 412 |
-
|
| 413 |
-
out = (True, None)
|
| 414 |
-
_VALIDATION_CACHE[key] = out
|
| 415 |
-
return out
|
|
|
|
| 1 |
+
import sqlite3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import re
|
| 3 |
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
class SQLValidator:
|
| 6 |
|
|
|
|
| 12 |
# ---------------------------
|
| 13 |
def load_schema(self, db_id):
    """Load ``{table: [columns]}`` (all lowercased) for database *db_id*.

    The database file is expected at ``db_root/<db_id>/<db_id>.sqlite``.

    BUGFIX: internal ``sqlite_*`` bookkeeping tables are excluded, the
    table name is quoted in PRAGMA (reserved words / spaces / injection),
    and the connection is closed even when a query raises.
    """
    db_path = self.db_root / db_id / f"{db_id}.sqlite"

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        tables = cursor.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type='table' AND name NOT LIKE 'sqlite_%';"
        ).fetchall()

        schema = {}
        for (table,) in tables:
            # Escape embedded quotes, then quote the identifier.
            quoted = table.replace('"', '""')
            cols = cursor.execute(f'PRAGMA table_info("{quoted}");').fetchall()
            schema[table.lower()] = [c[1].lower() for c in cols]
    finally:
        conn.close()
    return schema
|
| 31 |
+
|
| 32 |
|
| 33 |
# ---------------------------
|
| 34 |
# Basic syntax check
|
|
|
|
| 44 |
|
| 45 |
return True, None
|
| 46 |
|
| 47 |
+
|
| 48 |
# ---------------------------
|
| 49 |
# Extract identifiers
|
| 50 |
# ---------------------------
|
| 51 |
def extract_identifiers(self, sql):
|
| 52 |
+
tokens = re.findall(r"[A-Za-z_]+", sql.lower())
|
| 53 |
return set(tokens)
|
| 54 |
|
| 55 |
+
|
| 56 |
# ---------------------------
|
| 57 |
# Table validation
|
| 58 |
# ---------------------------
|
|
|
|
| 67 |
|
| 68 |
return True, None
|
| 69 |
|
| 70 |
+
|
| 71 |
# ---------------------------
|
| 72 |
# Column validation
|
| 73 |
# ---------------------------
|
|
|
|
| 78 |
for cols in schema.values():
|
| 79 |
valid_columns.update(cols)
|
| 80 |
|
| 81 |
+
# ignore SQL keywords
|
| 82 |
keywords = {
|
| 83 |
"select","from","where","join","on","group","by",
|
| 84 |
"order","limit","count","sum","avg","min","max",
|
| 85 |
+
"and","or","in","like","distinct","asc","desc"
|
|
|
|
| 86 |
}
|
| 87 |
|
| 88 |
invalid = []
|
| 89 |
for w in words:
|
| 90 |
+
if w not in valid_columns and w not in schema and w not in keywords:
|
| 91 |
+
if not w.isdigit():
|
| 92 |
+
invalid.append(w)
|
| 93 |
+
|
| 94 |
+
# allow small hallucinations but block many
|
| 95 |
+
if len(invalid) > 3:
|
| 96 |
+
return False, f"Too many unknown identifiers: {invalid[:5]}"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
return True, None
|
| 99 |
|
| 100 |
+
|
| 101 |
# ---------------------------
|
| 102 |
# Dangerous query protection
|
| 103 |
# ---------------------------
|
|
|
|
| 111 |
|
| 112 |
return True, None
|
| 113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
# ---------------------------
|
| 116 |
# Main validation
|
|
|
|
| 131 |
return False, msg
|
| 132 |
|
| 133 |
return True, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/text2sql_engine.py
CHANGED
|
@@ -1,223 +1,3 @@
|
|
| 1 |
-
# import sqlite3
|
| 2 |
-
# import torch
|
| 3 |
-
# import re
|
| 4 |
-
# import time
|
| 5 |
-
# from pathlib import Path
|
| 6 |
-
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 7 |
-
# from peft import PeftModel
|
| 8 |
-
# from src.sql_validator import SQLValidator
|
| 9 |
-
# from src.schema_encoder import SchemaEncoder
|
| 10 |
-
|
| 11 |
-
# PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 12 |
-
|
| 13 |
-
# # ================================
|
| 14 |
-
# # DATABASE PATH AUTO DETECTION
|
| 15 |
-
# # ================================
|
| 16 |
-
# if (PROJECT_ROOT / "data/database").exists():
|
| 17 |
-
# DB_ROOT = PROJECT_ROOT / "data/database"
|
| 18 |
-
# else:
|
| 19 |
-
# DB_ROOT = PROJECT_ROOT / "final_databases"
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
# def normalize_question(q: str):
|
| 23 |
-
# q = q.lower().strip()
|
| 24 |
-
# q = re.sub(r"distinct\s+(\d+)", r"\1 distinct", q)
|
| 25 |
-
# q = re.sub(r"\s+", " ", q)
|
| 26 |
-
# return q
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
# def semantic_fix(question, sql):
|
| 30 |
-
# q = question.lower().strip()
|
| 31 |
-
# s = sql.lower()
|
| 32 |
-
|
| 33 |
-
# num_match = re.search(r'\b(?:show|list|top|limit|get|first|last)\s+(\d+)\b', q)
|
| 34 |
-
|
| 35 |
-
# if num_match and "limit" not in s and "count(" not in s:
|
| 36 |
-
# limit_val = num_match.group(1)
|
| 37 |
-
# sql = sql.rstrip(";")
|
| 38 |
-
# sql = f"{sql.strip()} LIMIT {limit_val}"
|
| 39 |
-
|
| 40 |
-
# return sql
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
# class Text2SQLEngine:
|
| 44 |
-
# def __init__(self,
|
| 45 |
-
# adapter_path=None,
|
| 46 |
-
# base_model_name="Salesforce/codet5-base",
|
| 47 |
-
# use_lora=True):
|
| 48 |
-
|
| 49 |
-
# self.device = "mps" if torch.backends.mps.is_available() else (
|
| 50 |
-
# "cuda" if torch.cuda.is_available() else "cpu"
|
| 51 |
-
# )
|
| 52 |
-
|
| 53 |
-
# self.validator = SQLValidator(DB_ROOT)
|
| 54 |
-
# self.schema_encoder = SchemaEncoder(DB_ROOT)
|
| 55 |
-
|
| 56 |
-
# self.dml_keywords = r'\b(delete|update|insert|drop|alter|truncate)\b'
|
| 57 |
-
|
| 58 |
-
# print("Loading base model...")
|
| 59 |
-
# base = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
|
| 60 |
-
|
| 61 |
-
# if not use_lora:
|
| 62 |
-
# self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
|
| 63 |
-
# self.model = base.to(self.device)
|
| 64 |
-
# self.model.eval()
|
| 65 |
-
# return
|
| 66 |
-
|
| 67 |
-
# if (PROJECT_ROOT / "checkpoints/best_rlhf_model").exists():
|
| 68 |
-
# adapter_path = PROJECT_ROOT / "checkpoints/best_rlhf_model"
|
| 69 |
-
# else:
|
| 70 |
-
# adapter_path = PROJECT_ROOT / "best_rlhf_model"
|
| 71 |
-
|
| 72 |
-
# adapter_path = adapter_path.resolve()
|
| 73 |
-
|
| 74 |
-
# print("Loading tokenizer and LoRA adapter...")
|
| 75 |
-
|
| 76 |
-
# try:
|
| 77 |
-
# self.tokenizer = AutoTokenizer.from_pretrained(
|
| 78 |
-
# str(adapter_path),
|
| 79 |
-
# local_files_only=True
|
| 80 |
-
# )
|
| 81 |
-
# except Exception:
|
| 82 |
-
# self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
|
| 83 |
-
|
| 84 |
-
# self.model = PeftModel.from_pretrained(base, str(adapter_path)).to(self.device)
|
| 85 |
-
# self.model.eval()
|
| 86 |
-
|
| 87 |
-
# print("✅ RLHF model ready\n")
|
| 88 |
-
|
| 89 |
-
# def build_prompt(self, question, schema):
|
| 90 |
-
# return f"""You are an expert SQL generator.
|
| 91 |
-
# Database schema:
|
| 92 |
-
# {schema}
|
| 93 |
-
# Generate a valid SQLite query for the question.
|
| 94 |
-
# Question:
|
| 95 |
-
# {question}
|
| 96 |
-
# SQL:
|
| 97 |
-
# """
|
| 98 |
-
|
| 99 |
-
# def get_schema(self, db_id):
|
| 100 |
-
# return self.schema_encoder.structured_schema(db_id)
|
| 101 |
-
|
| 102 |
-
# def extract_sql(self, text: str):
|
| 103 |
-
|
| 104 |
-
# text = text.strip()
|
| 105 |
-
|
| 106 |
-
# if "SQL:" in text:
|
| 107 |
-
# text = text.split("SQL:")[-1]
|
| 108 |
-
|
| 109 |
-
# match = re.search(r"select[\s\S]*", text, re.IGNORECASE)
|
| 110 |
-
|
| 111 |
-
# if match:
|
| 112 |
-
# text = match.group(0)
|
| 113 |
-
|
| 114 |
-
# return text.split(";")[0].strip()
|
| 115 |
-
|
| 116 |
-
# def clean_sql(self, sql: str):
|
| 117 |
-
|
| 118 |
-
# sql = sql.replace('"', "'")
|
| 119 |
-
# sql = re.sub(r"\s+", " ", sql)
|
| 120 |
-
|
| 121 |
-
# return sql.strip()
|
| 122 |
-
|
| 123 |
-
# def generate_sql(self, prompt):
|
| 124 |
-
|
| 125 |
-
# inputs = self.tokenizer(
|
| 126 |
-
# prompt,
|
| 127 |
-
# return_tensors="pt",
|
| 128 |
-
# truncation=True,
|
| 129 |
-
# max_length=512
|
| 130 |
-
# ).to(self.device)
|
| 131 |
-
|
| 132 |
-
# with torch.no_grad():
|
| 133 |
-
|
| 134 |
-
# outputs = self.model.generate(
|
| 135 |
-
# **inputs,
|
| 136 |
-
# max_new_tokens=128,
|
| 137 |
-
# num_beams=5,
|
| 138 |
-
# early_stopping=True
|
| 139 |
-
# )
|
| 140 |
-
|
| 141 |
-
# decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 142 |
-
|
| 143 |
-
# return self.clean_sql(self.extract_sql(decoded))
|
| 144 |
-
|
| 145 |
-
# def execute_sql(self, question, sql, db_id):
|
| 146 |
-
|
| 147 |
-
# if re.search(self.dml_keywords, sql, re.IGNORECASE):
|
| 148 |
-
# return sql, [], [], "❌ Security Alert"
|
| 149 |
-
|
| 150 |
-
# # FIXED DATABASE PATH
|
| 151 |
-
# db_path = DB_ROOT / f"{db_id}.sqlite"
|
| 152 |
-
|
| 153 |
-
# sql = self.clean_sql(sql)
|
| 154 |
-
# sql = semantic_fix(question, sql)
|
| 155 |
-
|
| 156 |
-
# try:
|
| 157 |
-
|
| 158 |
-
# conn = sqlite3.connect(db_path)
|
| 159 |
-
|
| 160 |
-
# cursor = conn.cursor()
|
| 161 |
-
|
| 162 |
-
# cursor.execute(sql)
|
| 163 |
-
|
| 164 |
-
# rows = cursor.fetchall()
|
| 165 |
-
|
| 166 |
-
# columns = [d[0] for d in cursor.description] if cursor.description else []
|
| 167 |
-
|
| 168 |
-
# conn.close()
|
| 169 |
-
|
| 170 |
-
# return sql, columns, rows, None
|
| 171 |
-
|
| 172 |
-
# except Exception as e:
|
| 173 |
-
|
| 174 |
-
# return sql, [], [], str(e)
|
| 175 |
-
|
| 176 |
-
# def ask(self, question, db_id):
|
| 177 |
-
|
| 178 |
-
# question = normalize_question(question)
|
| 179 |
-
|
| 180 |
-
# if re.search(self.dml_keywords, question, re.IGNORECASE):
|
| 181 |
-
|
| 182 |
-
# return {
|
| 183 |
-
# "question": question,
|
| 184 |
-
# "sql": "-- BLOCKED",
|
| 185 |
-
# "columns": [],
|
| 186 |
-
# "rows": [],
|
| 187 |
-
# "error": "Malicious prompt"
|
| 188 |
-
# }
|
| 189 |
-
|
| 190 |
-
# schema = self.get_schema(db_id)
|
| 191 |
-
|
| 192 |
-
# prompt = self.build_prompt(question, schema)
|
| 193 |
-
|
| 194 |
-
# raw_sql = self.generate_sql(prompt)
|
| 195 |
-
|
| 196 |
-
# final_sql, cols, rows, error = self.execute_sql(question, raw_sql, db_id)
|
| 197 |
-
|
| 198 |
-
# return {
|
| 199 |
-
# "question": question,
|
| 200 |
-
# "sql": final_sql,
|
| 201 |
-
# "columns": cols,
|
| 202 |
-
# "rows": rows,
|
| 203 |
-
# "error": error
|
| 204 |
-
# }
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
# _engine = None
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
# def get_engine():
|
| 211 |
-
|
| 212 |
-
# global _engine
|
| 213 |
-
|
| 214 |
-
# if _engine is None:
|
| 215 |
-
# _engine = Text2SQLEngine()
|
| 216 |
-
|
| 217 |
-
# return _engine
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
import sqlite3
|
| 222 |
import torch
|
| 223 |
import re
|
|
@@ -226,7 +6,7 @@ from pathlib import Path
|
|
| 226 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 227 |
from peft import PeftModel
|
| 228 |
from src.sql_validator import SQLValidator
|
| 229 |
-
from src.schema_encoder import SchemaEncoder
|
| 230 |
|
| 231 |
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 232 |
|
|
@@ -239,92 +19,6 @@ else:
|
|
| 239 |
DB_ROOT = PROJECT_ROOT / "final_databases"
|
| 240 |
|
| 241 |
|
| 242 |
-
# ==========================================
|
| 243 |
-
# INPUT VALIDATION & RELEVANCE (From Code 1)
|
| 244 |
-
# ==========================================
|
| 245 |
-
def is_valid_question(q: str):
|
| 246 |
-
"""Extremely relaxed valid question checker. As long as there is 1 word, it passes."""
|
| 247 |
-
words = re.findall(r"[a-zA-Z0-9]+", q)
|
| 248 |
-
return len(words) >= 1
|
| 249 |
-
|
| 250 |
-
def is_relevant_to_db(question: str, schema_graph: dict):
|
| 251 |
-
"""
|
| 252 |
-
Lexical heuristic to block completely out-of-domain questions
|
| 253 |
-
while allowing valid plurals.
|
| 254 |
-
"""
|
| 255 |
-
q_words = set(re.findall(r'\b[a-z]{3,}\b', question.lower()))
|
| 256 |
-
stop_words = {"show", "list", "all", "and", "the", "get", "find", "how", "many", "what", "where", "which", "who", "give", "display", "count", "from", "for", "with", "that", "have", "has", "are", "there"}
|
| 257 |
-
q_words = q_words - stop_words
|
| 258 |
-
|
| 259 |
-
if not q_words:
|
| 260 |
-
return True
|
| 261 |
-
|
| 262 |
-
schema_words = set()
|
| 263 |
-
for table, cols in schema_graph.items():
|
| 264 |
-
schema_words.update(re.findall(r'\b[a-z]{3,}\b', table.lower()))
|
| 265 |
-
for col in cols:
|
| 266 |
-
schema_words.update(re.findall(r'\b[a-z]{3,}\b', col.lower()))
|
| 267 |
-
|
| 268 |
-
synonyms = {
|
| 269 |
-
"customer": ["client", "buyer", "shopper", "person", "people", "user"],
|
| 270 |
-
"employee": ["staff", "worker", "boss", "manager", "person", "people"],
|
| 271 |
-
"track": ["song", "music", "audio", "tune"],
|
| 272 |
-
"album": ["record", "cd", "music"],
|
| 273 |
-
"artist": ["singer", "band", "musician", "creator"],
|
| 274 |
-
"invoice": ["bill", "receipt", "purchase", "sale", "order", "buy", "bought", "cost"],
|
| 275 |
-
"city": ["town", "location", "place"],
|
| 276 |
-
"country": ["nation", "location", "place"],
|
| 277 |
-
"flight": ["plane", "airline", "trip", "fly", "airport"],
|
| 278 |
-
"student": ["pupil", "learner", "kid", "child"],
|
| 279 |
-
"club": ["group", "organization", "team"],
|
| 280 |
-
"course": ["class", "subject"],
|
| 281 |
-
"cinema": ["movie", "film", "theater", "screen"]
|
| 282 |
-
}
|
| 283 |
-
|
| 284 |
-
extended_schema_words = set(schema_words)
|
| 285 |
-
for db_word in schema_words:
|
| 286 |
-
if db_word in synonyms:
|
| 287 |
-
extended_schema_words.update(synonyms[db_word])
|
| 288 |
-
|
| 289 |
-
extended_schema_words.update({"id", "name", "total", "sum", "average", "avg", "min", "max", "number", "amount", "record", "data", "info", "information", "detail", "first", "last", "most", "least", "cheapest", "expensive", "best"})
|
| 290 |
-
|
| 291 |
-
# Check if the word OR its singular form is in the schema
|
| 292 |
-
for qw in q_words:
|
| 293 |
-
qw_singular = qw[:-1] if qw.endswith('s') else qw
|
| 294 |
-
if qw in extended_schema_words or qw_singular in extended_schema_words:
|
| 295 |
-
return True
|
| 296 |
-
|
| 297 |
-
return False
|
| 298 |
-
|
| 299 |
-
# ==========================================
|
| 300 |
-
# SCHEMA CONSTRAINTS (From Code 1)
|
| 301 |
-
# ==========================================
|
| 302 |
-
def apply_schema_constraints(sql, schema_graph):
|
| 303 |
-
sql = sql.lower()
|
| 304 |
-
|
| 305 |
-
used_tables = [t[1] for t in re.findall(r'(from|join)\s+(\w+)', sql)]
|
| 306 |
-
for t in used_tables:
|
| 307 |
-
if t not in schema_graph:
|
| 308 |
-
return None
|
| 309 |
-
|
| 310 |
-
valid_columns = set()
|
| 311 |
-
for cols in schema_graph.values():
|
| 312 |
-
valid_columns.update(cols)
|
| 313 |
-
|
| 314 |
-
col_blocks = re.findall(r'select\s+(.*?)\s+from', sql)
|
| 315 |
-
for block in col_blocks:
|
| 316 |
-
for c in block.split(","):
|
| 317 |
-
c = c.strip().split()[-1]
|
| 318 |
-
if "." in c:
|
| 319 |
-
c = c.split(".")[-1]
|
| 320 |
-
|
| 321 |
-
if c != "*" and "(" not in c and c != "":
|
| 322 |
-
if c not in valid_columns:
|
| 323 |
-
return None
|
| 324 |
-
|
| 325 |
-
return sql
|
| 326 |
-
|
| 327 |
-
|
| 328 |
def normalize_question(q: str):
|
| 329 |
q = q.lower().strip()
|
| 330 |
q = re.sub(r"distinct\s+(\d+)", r"\1 distinct", q)
|
|
@@ -350,8 +44,7 @@ class Text2SQLEngine:
|
|
| 350 |
def __init__(self,
|
| 351 |
adapter_path=None,
|
| 352 |
base_model_name="Salesforce/codet5-base",
|
| 353 |
-
use_lora=True
|
| 354 |
-
use_constrained_decoding=True): # Added constrained decoding flag
|
| 355 |
|
| 356 |
self.device = "mps" if torch.backends.mps.is_available() else (
|
| 357 |
"cuda" if torch.cuda.is_available() else "cpu"
|
|
@@ -359,7 +52,6 @@ class Text2SQLEngine:
|
|
| 359 |
|
| 360 |
self.validator = SQLValidator(DB_ROOT)
|
| 361 |
self.schema_encoder = SchemaEncoder(DB_ROOT)
|
| 362 |
-
self.use_constrained_decoding = use_constrained_decoding
|
| 363 |
|
| 364 |
self.dml_keywords = r'\b(delete|update|insert|drop|alter|truncate)\b'
|
| 365 |
|
|
@@ -408,20 +100,28 @@ SQL:
|
|
| 408 |
return self.schema_encoder.structured_schema(db_id)
|
| 409 |
|
| 410 |
def extract_sql(self, text: str):
|
|
|
|
| 411 |
text = text.strip()
|
|
|
|
| 412 |
if "SQL:" in text:
|
| 413 |
text = text.split("SQL:")[-1]
|
|
|
|
| 414 |
match = re.search(r"select[\s\S]*", text, re.IGNORECASE)
|
|
|
|
| 415 |
if match:
|
| 416 |
text = match.group(0)
|
|
|
|
| 417 |
return text.split(";")[0].strip()
|
| 418 |
|
| 419 |
def clean_sql(self, sql: str):
|
|
|
|
| 420 |
sql = sql.replace('"', "'")
|
| 421 |
sql = re.sub(r"\s+", " ", sql)
|
|
|
|
| 422 |
return sql.strip()
|
| 423 |
|
| 424 |
def generate_sql(self, prompt):
|
|
|
|
| 425 |
inputs = self.tokenizer(
|
| 426 |
prompt,
|
| 427 |
return_tensors="pt",
|
|
@@ -430,6 +130,7 @@ SQL:
|
|
| 430 |
).to(self.device)
|
| 431 |
|
| 432 |
with torch.no_grad():
|
|
|
|
| 433 |
outputs = self.model.generate(
|
| 434 |
**inputs,
|
| 435 |
max_new_tokens=128,
|
|
@@ -438,85 +139,64 @@ SQL:
|
|
| 438 |
)
|
| 439 |
|
| 440 |
decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
| 441 |
return self.clean_sql(self.extract_sql(decoded))
|
| 442 |
|
| 443 |
def execute_sql(self, question, sql, db_id):
|
|
|
|
| 444 |
if re.search(self.dml_keywords, sql, re.IGNORECASE):
|
| 445 |
return sql, [], [], "❌ Security Alert"
|
| 446 |
|
| 447 |
-
# FIXED DATABASE PATH
|
| 448 |
db_path = DB_ROOT / f"{db_id}.sqlite"
|
| 449 |
|
| 450 |
sql = self.clean_sql(sql)
|
| 451 |
sql = semantic_fix(question, sql)
|
| 452 |
|
| 453 |
try:
|
|
|
|
| 454 |
conn = sqlite3.connect(db_path)
|
|
|
|
| 455 |
cursor = conn.cursor()
|
|
|
|
| 456 |
cursor.execute(sql)
|
|
|
|
| 457 |
rows = cursor.fetchall()
|
|
|
|
| 458 |
columns = [d[0] for d in cursor.description] if cursor.description else []
|
|
|
|
| 459 |
conn.close()
|
|
|
|
| 460 |
return sql, columns, rows, None
|
|
|
|
| 461 |
except Exception as e:
|
|
|
|
| 462 |
return sql, [], [], str(e)
|
| 463 |
|
| 464 |
def ask(self, question, db_id):
|
| 465 |
-
# 1. Normalize
|
| 466 |
-
question_norm = normalize_question(question)
|
| 467 |
-
question_context = f"Database question: {question_norm}"
|
| 468 |
|
| 469 |
-
|
| 470 |
-
|
|
|
|
|
|
|
| 471 |
return {
|
| 472 |
-
"question":
|
| 473 |
"sql": "-- BLOCKED",
|
| 474 |
"columns": [],
|
| 475 |
"rows": [],
|
| 476 |
-
"error": "
|
| 477 |
}
|
| 478 |
|
| 479 |
-
# 3. Check basic validity of question
|
| 480 |
-
if not is_valid_question(question_context):
|
| 481 |
-
return {"sql": "", "error": "❌ Invalid input. Please type words."}
|
| 482 |
-
|
| 483 |
schema = self.get_schema(db_id)
|
| 484 |
-
schema_graph = build_schema_graph(schema)
|
| 485 |
|
| 486 |
-
|
| 487 |
-
if not is_relevant_to_db(question_norm, schema_graph):
|
| 488 |
-
return {"sql": "", "error": "❌ Question is completely out of domain for the selected database."}
|
| 489 |
|
| 490 |
-
# 5. INITIAL GENERATION
|
| 491 |
-
prompt = self.build_prompt(question_context, schema)
|
| 492 |
raw_sql = self.generate_sql(prompt)
|
| 493 |
|
| 494 |
-
|
| 495 |
-
if self.use_constrained_decoding:
|
| 496 |
-
filtered_sql = apply_schema_constraints(raw_sql, schema_graph)
|
| 497 |
-
|
| 498 |
-
if filtered_sql is None:
|
| 499 |
-
constraint_prompt = f"""Use ONLY valid schema.
|
| 500 |
-
Database schema:
|
| 501 |
-
{schema}
|
| 502 |
-
Generate a valid SQLite query for the question.
|
| 503 |
-
Question:
|
| 504 |
-
{question_context}
|
| 505 |
-
SQL:
|
| 506 |
-
"""
|
| 507 |
-
sql_retry = self.generate_sql(constraint_prompt)
|
| 508 |
-
filtered_sql = apply_schema_constraints(sql_retry, schema_graph)
|
| 509 |
-
|
| 510 |
-
if filtered_sql:
|
| 511 |
-
raw_sql = filtered_sql
|
| 512 |
-
else:
|
| 513 |
-
raw_sql = sql_retry
|
| 514 |
-
|
| 515 |
-
# 7. EXECUTION
|
| 516 |
-
final_sql, cols, rows, error = self.execute_sql(question_norm, raw_sql, db_id)
|
| 517 |
|
| 518 |
return {
|
| 519 |
-
"question":
|
| 520 |
"sql": final_sql,
|
| 521 |
"columns": cols,
|
| 522 |
"rows": rows,
|
|
@@ -526,338 +206,12 @@ SQL:
|
|
| 526 |
|
| 527 |
_engine = None
|
| 528 |
|
| 529 |
-
|
|
|
|
|
|
|
| 530 |
global _engine
|
| 531 |
|
| 532 |
if _engine is None:
|
| 533 |
-
_engine = Text2SQLEngine(
|
| 534 |
|
| 535 |
return _engine
|
| 536 |
-
|
| 537 |
-
# import sqlite3
|
| 538 |
-
# import torch
|
| 539 |
-
# import re
|
| 540 |
-
# import os
|
| 541 |
-
# from pathlib import Path
|
| 542 |
-
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 543 |
-
# from peft import PeftModel
|
| 544 |
-
# from src.sql_validator import SQLValidator
|
| 545 |
-
# from src.schema_encoder import SchemaEncoder # Removed build_schema_graph import
|
| 546 |
-
|
| 547 |
-
# PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 548 |
-
|
| 549 |
-
# # ================================
|
| 550 |
-
# # DATABASE PATH AUTO DETECTION
|
| 551 |
-
# # ================================
|
| 552 |
-
# if (PROJECT_ROOT / "data/database").exists():
|
| 553 |
-
# DB_ROOT = PROJECT_ROOT / "data/database"
|
| 554 |
-
# else:
|
| 555 |
-
# DB_ROOT = PROJECT_ROOT / "final_databases"
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
# # ==========================================
|
| 559 |
-
# # SCHEMA PARSING
|
| 560 |
-
# # ==========================================
|
| 561 |
-
# def build_schema_graph(schema_text):
|
| 562 |
-
# """
|
| 563 |
-
# Parses a structured schema text string into a dictionary graph.
|
| 564 |
-
# Matches formats like: table_name(col1, col2, col3)
|
| 565 |
-
# """
|
| 566 |
-
# tables = {}
|
| 567 |
-
# for match in re.findall(r'(\w+)\s*\((.*?)\)', schema_text):
|
| 568 |
-
# table = match[0]
|
| 569 |
-
# cols = [c.strip().split()[0] for c in match[1].split(",")]
|
| 570 |
-
# tables[table] = cols
|
| 571 |
-
# return tables
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
# # ==========================================
|
| 575 |
-
# # INPUT VALIDATION & RELEVANCE
|
| 576 |
-
# # ==========================================
|
| 577 |
-
# def is_valid_question(q: str):
|
| 578 |
-
# q = q.strip().lower()
|
| 579 |
-
|
| 580 |
-
# if len(q) < 3:
|
| 581 |
-
# return False
|
| 582 |
-
|
| 583 |
-
# words = re.findall(r"[a-zA-Z]+", q)
|
| 584 |
-
# if len(words) < 1:
|
| 585 |
-
# return False
|
| 586 |
-
|
| 587 |
-
# return True
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
# def is_relevant_to_db(question: str, schema_graph: dict):
|
| 591 |
-
# q_words = set(re.findall(r'\b[a-z]{3,}\b', question.lower()))
|
| 592 |
-
# stop_words = {"show", "list", "all", "and", "the", "get", "find", "how", "many", "what", "where", "which", "who", "give", "display", "count", "from", "for", "with", "that", "have", "has", "are", "there"}
|
| 593 |
-
# q_words = q_words - stop_words
|
| 594 |
-
|
| 595 |
-
# if not q_words:
|
| 596 |
-
# return True
|
| 597 |
-
|
| 598 |
-
# schema_words = set()
|
| 599 |
-
# for table, cols in schema_graph.items():
|
| 600 |
-
# schema_words.update(re.findall(r'\b[a-z]{3,}\b', table.lower()))
|
| 601 |
-
# for col in cols:
|
| 602 |
-
# schema_words.update(re.findall(r'\b[a-z]{3,}\b', col.lower()))
|
| 603 |
-
|
| 604 |
-
# synonyms = {
|
| 605 |
-
# "customer": ["client", "buyer", "shopper", "person", "people", "user"],
|
| 606 |
-
# "employee": ["staff", "worker", "boss", "manager", "person", "people"],
|
| 607 |
-
# "track": ["song", "music", "audio", "tune"],
|
| 608 |
-
# "album": ["record", "cd", "music"],
|
| 609 |
-
# "artist": ["singer", "band", "musician", "creator"],
|
| 610 |
-
# "invoice": ["bill", "receipt", "purchase", "sale", "order", "buy", "bought", "cost"],
|
| 611 |
-
# "city": ["town", "location", "place"],
|
| 612 |
-
# "country": ["nation", "location", "place"],
|
| 613 |
-
# "flight": ["plane", "airline", "trip", "fly", "airport"],
|
| 614 |
-
# "student": ["pupil", "learner", "kid", "child"],
|
| 615 |
-
# "club": ["group", "organization", "team"],
|
| 616 |
-
# "course": ["class", "subject"],
|
| 617 |
-
# "cinema": ["movie", "film", "theater", "screen"]
|
| 618 |
-
# }
|
| 619 |
-
|
| 620 |
-
# extended_schema_words = set(schema_words)
|
| 621 |
-
# for db_word in schema_words:
|
| 622 |
-
# if db_word in synonyms:
|
| 623 |
-
# extended_schema_words.update(synonyms[db_word])
|
| 624 |
-
|
| 625 |
-
# extended_schema_words.update({"id", "name", "total", "sum", "average", "avg", "min", "max", "number", "amount", "record", "data", "info", "information", "detail", "first", "last", "most", "least", "cheapest", "expensive", "best"})
|
| 626 |
-
|
| 627 |
-
# for qw in q_words:
|
| 628 |
-
# qw_singular = qw[:-1] if qw.endswith('s') else qw
|
| 629 |
-
# if qw in extended_schema_words or qw_singular in extended_schema_words:
|
| 630 |
-
# return True
|
| 631 |
-
|
| 632 |
-
# return False
|
| 633 |
-
|
| 634 |
-
# def normalize_question(q: str):
|
| 635 |
-
# return re.sub(r"\s+", " ", q.lower().strip())
|
| 636 |
-
|
| 637 |
-
# def semantic_fix(question, sql):
|
| 638 |
-
# q = question.lower()
|
| 639 |
-
# num_match = re.search(r'\b(?:show|list|top|get)\s+(\d+)\b', q)
|
| 640 |
-
|
| 641 |
-
# if num_match and "limit" not in sql.lower():
|
| 642 |
-
# sql = f"{sql} LIMIT {num_match.group(1)}"
|
| 643 |
-
|
| 644 |
-
# return sql
|
| 645 |
-
|
| 646 |
-
# # ==========================================
|
| 647 |
-
# # SCHEMA CONSTRAINTS (SIMULATED LOGIT BLOCKING)
|
| 648 |
-
# # ==========================================
|
| 649 |
-
# def apply_schema_constraints(sql, schema_graph):
|
| 650 |
-
# sql = sql.lower()
|
| 651 |
-
|
| 652 |
-
# used_tables = [t[1] for t in re.findall(r'(from|join)\s+(\w+)', sql)]
|
| 653 |
-
# for t in used_tables:
|
| 654 |
-
# if t not in schema_graph:
|
| 655 |
-
# return None
|
| 656 |
-
|
| 657 |
-
# valid_columns = set()
|
| 658 |
-
# for cols in schema_graph.values():
|
| 659 |
-
# valid_columns.update(cols)
|
| 660 |
-
|
| 661 |
-
# col_blocks = re.findall(r'select\s+(.*?)\s+from', sql)
|
| 662 |
-
# for block in col_blocks:
|
| 663 |
-
# for c in block.split(","):
|
| 664 |
-
# c = c.strip().split()[-1]
|
| 665 |
-
# if "." in c:
|
| 666 |
-
# c = c.split(".")[-1]
|
| 667 |
-
|
| 668 |
-
# if c != "*" and "(" not in c and c != "":
|
| 669 |
-
# if c not in valid_columns:
|
| 670 |
-
# return None
|
| 671 |
-
|
| 672 |
-
# return sql
|
| 673 |
-
|
| 674 |
-
# # ==========================================
|
| 675 |
-
# # ENGINE
|
| 676 |
-
# # ==========================================
|
| 677 |
-
# class Text2SQLEngine:
|
| 678 |
-
|
| 679 |
-
# def __init__(self,
|
| 680 |
-
# adapter_path="checkpoints/best_rlhf_model_2",
|
| 681 |
-
# base_model_name="Salesforce/codet5-base",
|
| 682 |
-
# use_lora=True,
|
| 683 |
-
# use_constrained_decoding=False):
|
| 684 |
-
|
| 685 |
-
# self.device = "mps" if torch.backends.mps.is_available() else (
|
| 686 |
-
# "cuda" if torch.cuda.is_available() else "cpu"
|
| 687 |
-
# )
|
| 688 |
-
|
| 689 |
-
# self.validator = SQLValidator(DB_ROOT)
|
| 690 |
-
# self.schema_encoder = SchemaEncoder(DB_ROOT)
|
| 691 |
-
|
| 692 |
-
# self.use_constrained_decoding = use_constrained_decoding
|
| 693 |
-
# self.dml_keywords = r'\b(delete|update|insert|drop|alter|truncate|create)\b'
|
| 694 |
-
|
| 695 |
-
# print(f"\n📦 Loading model on {self.device}...")
|
| 696 |
-
|
| 697 |
-
# base = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
|
| 698 |
-
|
| 699 |
-
# # Override the redundant special tokens to prevent the tokenizer crash
|
| 700 |
-
# self.tokenizer = AutoTokenizer.from_pretrained(
|
| 701 |
-
# base_model_name,
|
| 702 |
-
# use_fast=False,
|
| 703 |
-
# additional_special_tokens=[]
|
| 704 |
-
# )
|
| 705 |
-
|
| 706 |
-
# # 🔥 FIXED LOADA ADAPTER PATH LOGIC
|
| 707 |
-
# if use_lora:
|
| 708 |
-
# if adapter_path and (PROJECT_ROOT / adapter_path).exists():
|
| 709 |
-
# adapter_path = PROJECT_ROOT / adapter_path
|
| 710 |
-
# elif (PROJECT_ROOT / "checkpoints/best_rlhf_model_2").exists():
|
| 711 |
-
# adapter_path = PROJECT_ROOT / "checkpoints/best_rlhf_model_2"
|
| 712 |
-
# else:
|
| 713 |
-
# adapter_path = PROJECT_ROOT / "best_rlhf_model_2"
|
| 714 |
-
|
| 715 |
-
# adapter_path = adapter_path.resolve()
|
| 716 |
-
|
| 717 |
-
# if adapter_path.exists():
|
| 718 |
-
# try:
|
| 719 |
-
# self.model = PeftModel.from_pretrained(
|
| 720 |
-
# base,
|
| 721 |
-
# str(adapter_path),
|
| 722 |
-
# local_files_only=True
|
| 723 |
-
# ).to(self.device)
|
| 724 |
-
# print(f"✅ LoRA loaded from {adapter_path}")
|
| 725 |
-
# except Exception as e:
|
| 726 |
-
# print(f"⚠️ LoRA load failed: {e}")
|
| 727 |
-
# self.model = base.to(self.device)
|
| 728 |
-
# else:
|
| 729 |
-
# print(f"⚠️ Adapter not found at {adapter_path}, using base model")
|
| 730 |
-
# self.model = base.to(self.device)
|
| 731 |
-
# else:
|
| 732 |
-
# self.model = base.to(self.device)
|
| 733 |
-
|
| 734 |
-
# self.model.eval()
|
| 735 |
-
|
| 736 |
-
# def build_prompt(self, question, schema):
|
| 737 |
-
# return f"""
|
| 738 |
-
# You are an expert SQL generator.
|
| 739 |
-
|
| 740 |
-
# IMPORTANT:
|
| 741 |
-
# - Use correct tables and columns
|
| 742 |
-
# - Use JOINs when needed
|
| 743 |
-
|
| 744 |
-
# Schema:
|
| 745 |
-
# {schema}
|
| 746 |
-
|
| 747 |
-
# Question:
|
| 748 |
-
# {question}
|
| 749 |
-
|
| 750 |
-
# SQL:
|
| 751 |
-
# """
|
| 752 |
-
|
| 753 |
-
# def get_schema(self, db_id):
|
| 754 |
-
# return self.schema_encoder.structured_schema(db_id)
|
| 755 |
-
|
| 756 |
-
# def extract_sql(self, text):
|
| 757 |
-
# match = re.search(r"(select|with)[\s\S]*", text, re.IGNORECASE)
|
| 758 |
-
# return match.group(0).split(";")[0].strip() if match else ""
|
| 759 |
-
|
| 760 |
-
# def clean_sql(self, sql):
|
| 761 |
-
# return re.sub(r"\s+", " ", sql.replace('"', "'")).strip()
|
| 762 |
-
|
| 763 |
-
# def generate_sql(self, prompt):
|
| 764 |
-
# inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
| 765 |
-
|
| 766 |
-
# with torch.no_grad():
|
| 767 |
-
# outputs = self.model.generate(
|
| 768 |
-
# **inputs,
|
| 769 |
-
# max_new_tokens=128,
|
| 770 |
-
# num_beams=8,
|
| 771 |
-
# length_penalty=0.8,
|
| 772 |
-
# early_stopping=True
|
| 773 |
-
# )
|
| 774 |
-
|
| 775 |
-
# decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 776 |
-
# return self.clean_sql(self.extract_sql(decoded))
|
| 777 |
-
|
| 778 |
-
# def execute_sql(self, question, sql, db_id):
|
| 779 |
-
|
| 780 |
-
# if re.search(self.dml_keywords, sql, re.IGNORECASE):
|
| 781 |
-
# return "", [], [], "❌ Blocked malicious SQL (Contains INSERT/UPDATE/DELETE/DROP)"
|
| 782 |
-
|
| 783 |
-
# # 🔥 FIXED DATABASE PATH
|
| 784 |
-
# db_path = DB_ROOT / f"{db_id}.sqlite"
|
| 785 |
-
# sql = semantic_fix(question, sql)
|
| 786 |
-
|
| 787 |
-
# try:
|
| 788 |
-
# conn = sqlite3.connect(db_path)
|
| 789 |
-
# cursor = conn.cursor()
|
| 790 |
-
# cursor.execute(sql)
|
| 791 |
-
|
| 792 |
-
# rows = cursor.fetchall()
|
| 793 |
-
# columns = [d[0] for d in cursor.description] if cursor.description else []
|
| 794 |
-
|
| 795 |
-
# conn.close()
|
| 796 |
-
# return sql, columns, rows, None
|
| 797 |
-
|
| 798 |
-
# except Exception as e:
|
| 799 |
-
# return sql, [], [], str(e)
|
| 800 |
-
|
| 801 |
-
# def ask(self, question, db_id):
|
| 802 |
-
|
| 803 |
-
# question = normalize_question(question)
|
| 804 |
-
# question_context = f"Database question: {question}"
|
| 805 |
-
|
| 806 |
-
# if re.search(self.dml_keywords, question_context, re.IGNORECASE):
|
| 807 |
-
# return {"sql": "", "error": "❌ Blocked dangerous query from input text."}
|
| 808 |
-
|
| 809 |
-
# if not is_valid_question(question_context):
|
| 810 |
-
# return {"sql": "", "error": "❌ Invalid input. Please type words."}
|
| 811 |
-
|
| 812 |
-
# schema = self.get_schema(db_id)
|
| 813 |
-
# schema_graph = build_schema_graph(schema)
|
| 814 |
-
|
| 815 |
-
# if not is_relevant_to_db(question, schema_graph):
|
| 816 |
-
# return {"sql": "", "error": "❌ Question is completely out of domain for the selected database."}
|
| 817 |
-
|
| 818 |
-
# sql = self.generate_sql(self.build_prompt(question_context, schema))
|
| 819 |
-
|
| 820 |
-
# if self.use_constrained_decoding:
|
| 821 |
-
# filtered_sql = apply_schema_constraints(sql, schema_graph)
|
| 822 |
-
|
| 823 |
-
# if filtered_sql is None:
|
| 824 |
-
# constraint_prompt = f"""
|
| 825 |
-
# Use ONLY valid schema.
|
| 826 |
-
# Schema:
|
| 827 |
-
# {schema}
|
| 828 |
-
|
| 829 |
-
# Question:
|
| 830 |
-
# {question_context}
|
| 831 |
-
|
| 832 |
-
# SQL:
|
| 833 |
-
# """
|
| 834 |
-
# sql_retry = self.generate_sql(constraint_prompt)
|
| 835 |
-
# filtered_sql = apply_schema_constraints(sql_retry, schema_graph)
|
| 836 |
-
|
| 837 |
-
# if filtered_sql:
|
| 838 |
-
# sql = filtered_sql
|
| 839 |
-
# else:
|
| 840 |
-
# sql = sql_retry
|
| 841 |
-
|
| 842 |
-
# final_sql, cols, rows, error = self.execute_sql(question_context, sql, db_id)
|
| 843 |
-
|
| 844 |
-
# return {
|
| 845 |
-
# "question": question_context,
|
| 846 |
-
# "sql": final_sql,
|
| 847 |
-
# "columns": cols,
|
| 848 |
-
# "rows": rows,
|
| 849 |
-
# "error": error
|
| 850 |
-
# }
|
| 851 |
-
|
| 852 |
-
# def get_engine(
|
| 853 |
-
# adapter_path="checkpoints/best_rlhf_model_2",
|
| 854 |
-
# base_model_name="Salesforce/codet5-base",
|
| 855 |
-
# use_lora=True,
|
| 856 |
-
# use_constrained=True
|
| 857 |
-
# ):
|
| 858 |
-
# return Text2SQLEngine(
|
| 859 |
-
# adapter_path=adapter_path,
|
| 860 |
-
# base_model_name=base_model_name,
|
| 861 |
-
# use_lora=use_lora,
|
| 862 |
-
# use_constrained_decoding=use_constrained
|
| 863 |
-
# )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import sqlite3
|
| 2 |
import torch
|
| 3 |
import re
|
|
|
|
| 6 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 7 |
from peft import PeftModel
|
| 8 |
from src.sql_validator import SQLValidator
|
| 9 |
+
from src.schema_encoder import SchemaEncoder
|
| 10 |
|
| 11 |
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 12 |
|
|
|
|
| 19 |
DB_ROOT = PROJECT_ROOT / "final_databases"
|
| 20 |
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
def normalize_question(q: str):
|
| 23 |
q = q.lower().strip()
|
| 24 |
q = re.sub(r"distinct\s+(\d+)", r"\1 distinct", q)
|
|
|
|
| 44 |
def __init__(self,
|
| 45 |
adapter_path=None,
|
| 46 |
base_model_name="Salesforce/codet5-base",
|
| 47 |
+
use_lora=True):
|
|
|
|
| 48 |
|
| 49 |
self.device = "mps" if torch.backends.mps.is_available() else (
|
| 50 |
"cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 52 |
|
| 53 |
self.validator = SQLValidator(DB_ROOT)
|
| 54 |
self.schema_encoder = SchemaEncoder(DB_ROOT)
|
|
|
|
| 55 |
|
| 56 |
self.dml_keywords = r'\b(delete|update|insert|drop|alter|truncate)\b'
|
| 57 |
|
|
|
|
| 100 |
return self.schema_encoder.structured_schema(db_id)
|
| 101 |
|
| 102 |
def extract_sql(self, text: str):
    """Pull the SQL statement out of raw model output.

    Keeps everything after an optional "SQL:" marker, trims any
    leading chatter before the first SELECT keyword, and drops
    everything after the first semicolon so only one statement
    survives.
    """
    candidate = text.strip()
    # Prefer whatever follows an explicit "SQL:" marker, if present.
    if "SQL:" in candidate:
        candidate = candidate.split("SQL:")[-1]
    # Discard prose that precedes the first SELECT (case-insensitive).
    found = re.search(r"select[\s\S]*", candidate, re.IGNORECASE)
    if found is not None:
        candidate = found.group(0)
    # Keep only the first statement and strip surrounding whitespace.
    return candidate.split(";")[0].strip()
|
| 115 |
|
| 116 |
def clean_sql(self, sql: str):
    """Normalize a SQL string: single-quote literals, single spaces.

    SQLite string literals conventionally use single quotes, so
    double quotes are rewritten; runs of whitespace (including
    newlines) collapse to one space.
    """
    normalized = sql.replace('"', "'")
    normalized = re.sub(r"\s+", " ", normalized)
    return normalized.strip()
|
| 122 |
|
| 123 |
def generate_sql(self, prompt):
|
| 124 |
+
|
| 125 |
inputs = self.tokenizer(
|
| 126 |
prompt,
|
| 127 |
return_tensors="pt",
|
|
|
|
| 130 |
).to(self.device)
|
| 131 |
|
| 132 |
with torch.no_grad():
|
| 133 |
+
|
| 134 |
outputs = self.model.generate(
|
| 135 |
**inputs,
|
| 136 |
max_new_tokens=128,
|
|
|
|
| 139 |
)
|
| 140 |
|
| 141 |
decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 142 |
+
|
| 143 |
return self.clean_sql(self.extract_sql(decoded))
|
| 144 |
|
| 145 |
def execute_sql(self, question, sql, db_id):
    """Safely execute a generated SQL query against one database.

    Parameters:
        question: the (normalized) natural-language question, passed to
            ``semantic_fix`` for post-hoc query repair.
        sql: the raw SQL string produced by the model.
        db_id: database identifier; resolved to ``DB_ROOT/<db_id>.sqlite``.

    Returns:
        (final_sql, columns, rows, error) — ``error`` is None on success,
        otherwise a human-readable message; on failure columns/rows are
        empty lists.
    """
    # Refuse any data-modifying statement outright (defense in depth:
    # ask() also screens the question, but the model output is checked too).
    if re.search(self.dml_keywords, sql, re.IGNORECASE):
        return sql, [], [], "❌ Security Alert"

    # Databases live flat under DB_ROOT as <db_id>.sqlite.
    db_path = DB_ROOT / f"{db_id}.sqlite"

    sql = self.clean_sql(sql)
    sql = semantic_fix(question, sql)

    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()
        cursor.execute(sql)
        rows = cursor.fetchall()
        # cursor.description is None for statements with no result set.
        columns = [d[0] for d in cursor.description] if cursor.description else []
        return sql, columns, rows, None
    except Exception as e:
        return sql, [], [], str(e)
    finally:
        # Bug fix: the original only closed the connection on the success
        # path, leaking it whenever execute()/fetchall() raised.
        if conn is not None:
            conn.close()
|
| 175 |
|
| 176 |
def ask(self, question, db_id):
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
+
question = normalize_question(question)
|
| 179 |
+
|
| 180 |
+
if re.search(self.dml_keywords, question, re.IGNORECASE):
|
| 181 |
+
|
| 182 |
return {
|
| 183 |
+
"question": question,
|
| 184 |
"sql": "-- BLOCKED",
|
| 185 |
"columns": [],
|
| 186 |
"rows": [],
|
| 187 |
+
"error": "Malicious prompt"
|
| 188 |
}
|
| 189 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
schema = self.get_schema(db_id)
|
|
|
|
| 191 |
|
| 192 |
+
prompt = self.build_prompt(question, schema)
|
|
|
|
|
|
|
| 193 |
|
|
|
|
|
|
|
| 194 |
raw_sql = self.generate_sql(prompt)
|
| 195 |
|
| 196 |
+
final_sql, cols, rows, error = self.execute_sql(question, raw_sql, db_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
|
| 198 |
return {
|
| 199 |
+
"question": question,
|
| 200 |
"sql": final_sql,
|
| 201 |
"columns": cols,
|
| 202 |
"rows": rows,
|
|
|
|
| 206 |
|
| 207 |
_engine = None
|
| 208 |
|
| 209 |
+
|
| 210 |
+
def get_engine():
    """Lazily build and cache a single Text2SQLEngine for the process.

    The engine loads model weights, so construction is deferred until
    the first call and the instance is reused thereafter.
    """
    global _engine
    engine = _engine
    if engine is None:
        engine = Text2SQLEngine()
        _engine = engine
    return engine
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|