# text2sql-demo / src/text2sql_engine.py
# Commit aa2c432 (tjhalanigrid): Fix database path and add SQL explanation + execution time
import sqlite3
import torch
import re
import time
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
from src.sql_validator import SQLValidator
from src.schema_encoder import SchemaEncoder
PROJECT_ROOT = Path(__file__).resolve().parents[1]

# ================================
# DATABASE PATH AUTO DETECTION
# ================================
# Prefer the repo-local data/database directory; fall back to final_databases
# (the layout used when the demo is deployed flat).
_preferred_db_dir = PROJECT_ROOT / "data/database"
DB_ROOT = _preferred_db_dir if _preferred_db_dir.exists() else PROJECT_ROOT / "final_databases"
def normalize_question(q: str):
    """Normalize a user question: lowercase, trim, collapse whitespace.

    Also rewrites "distinct N" as "N distinct" so phrasing matches the
    form the model was trained on.
    """
    cleaned = q.strip().lower()
    # Swap "distinct 5" -> "5 distinct" before whitespace collapsing.
    cleaned = re.sub(r"distinct\s+(\d+)", r"\1 distinct", cleaned)
    return re.sub(r"\s+", " ", cleaned)
def semantic_fix(question, sql):
    """Append a LIMIT clause when the question asks for N rows.

    If the question contains a phrase like "show 5" / "top 3" and the SQL
    neither limits its output nor aggregates with count(), a matching
    "LIMIT N" is appended. Otherwise the SQL is returned unchanged.
    """
    lowered_question = question.lower().strip()
    lowered_sql = sql.lower()
    hit = re.search(r'\b(?:show|list|top|limit|get|first|last)\s+(\d+)\b', lowered_question)
    if not hit:
        return sql
    # Leave aggregates and already-limited queries alone.
    if "limit" in lowered_sql or "count(" in lowered_sql:
        return sql
    return f"{sql.rstrip(';').strip()} LIMIT {hit.group(1)}"
class Text2SQLEngine:
    """Natural-language-to-SQL engine over SQLite databases.

    Wraps a CodeT5 seq2seq model (optionally with a LoRA/RLHF adapter),
    a schema encoder for prompt construction, lightweight SQL
    post-processing, and guarded read-only execution against the
    databases under DB_ROOT.
    """

    def __init__(self,
                 adapter_path=None,
                 base_model_name="Salesforce/codet5-base",
                 use_lora=True):
        """Load the base model and, when use_lora is True, a LoRA adapter.

        Args:
            adapter_path: Optional explicit path to a LoRA adapter directory.
                When None, the default checkpoint locations are auto-detected.
            base_model_name: HuggingFace model id of the seq2seq base model.
            use_lora: When False, the plain base model is used (no adapter).
        """
        # Prefer Apple MPS, then CUDA, then CPU.
        self.device = "mps" if torch.backends.mps.is_available() else (
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.validator = SQLValidator(DB_ROOT)
        self.schema_encoder = SchemaEncoder(DB_ROOT)
        # Any of these keywords in a question or generated SQL is rejected.
        self.dml_keywords = r'\b(delete|update|insert|drop|alter|truncate)\b'
        print("Loading base model...")
        base = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
        if not use_lora:
            self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
            self.model = base.to(self.device)
            self.model.eval()
            return
        # BUG FIX: the adapter_path argument was previously ignored and always
        # overwritten by auto-detection; it is now honored when provided.
        if adapter_path is None:
            if (PROJECT_ROOT / "checkpoints/best_rlhf_model").exists():
                adapter_path = PROJECT_ROOT / "checkpoints/best_rlhf_model"
            else:
                adapter_path = PROJECT_ROOT / "best_rlhf_model"
        adapter_path = Path(adapter_path).resolve()
        print("Loading tokenizer and LoRA adapter...")
        try:
            # Prefer a tokenizer saved alongside the adapter, if one exists.
            self.tokenizer = AutoTokenizer.from_pretrained(
                str(adapter_path),
                local_files_only=True
            )
        except Exception:
            # Fall back to the base model's tokenizer.
            self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        self.model = PeftModel.from_pretrained(base, str(adapter_path)).to(self.device)
        self.model.eval()
        print("✅ RLHF model ready\n")

    def build_prompt(self, question, schema):
        """Compose the generation prompt embedding the schema and question."""
        return f"""You are an expert SQL generator.
Database schema:
{schema}
Generate a valid SQLite query for the question.
Question:
{question}
SQL:
"""

    def get_schema(self, db_id):
        """Return the structured schema text for database `db_id`."""
        return self.schema_encoder.structured_schema(db_id)

    def extract_sql(self, text: str):
        """Pull the SQL statement out of raw model output.

        Takes the text after the last "SQL:" marker, keeps everything from
        the first SELECT onward, and truncates at the first semicolon.
        """
        text = text.strip()
        if "SQL:" in text:
            text = text.split("SQL:")[-1]
        match = re.search(r"select[\s\S]*", text, re.IGNORECASE)
        if match:
            text = match.group(0)
        return text.split(";")[0].strip()

    def clean_sql(self, sql: str):
        """Normalize quoting and collapse whitespace in a SQL string."""
        sql = sql.replace('"', "'")
        sql = re.sub(r"\s+", " ", sql)
        return sql.strip()

    def generate_sql(self, prompt):
        """Run beam-search generation and return the cleaned SQL string."""
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=128,
                num_beams=5,
                early_stopping=True
            )
        decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return self.clean_sql(self.extract_sql(decoded))

    def execute_sql(self, question, sql, db_id):
        """Execute `sql` against the `db_id` database (read-only guard).

        Returns:
            Tuple (final_sql, columns, rows, error); `error` is None on
            success, otherwise a human-readable message.
        """
        # Refuse any data-modifying statement outright.
        if re.search(self.dml_keywords, sql, re.IGNORECASE):
            return sql, [], [], "❌ Security Alert"
        # FIXED DATABASE PATH
        db_path = DB_ROOT / f"{db_id}.sqlite"
        sql = self.clean_sql(sql)
        sql = semantic_fix(question, sql)
        conn = None
        try:
            conn = sqlite3.connect(db_path)
            cursor = conn.cursor()
            cursor.execute(sql)
            rows = cursor.fetchall()
            columns = [d[0] for d in cursor.description] if cursor.description else []
            return sql, columns, rows, None
        except Exception as e:
            return sql, [], [], str(e)
        finally:
            # BUG FIX: previously the connection leaked when execute() raised;
            # close it on every path.
            if conn is not None:
                conn.close()

    def ask(self, question, db_id):
        """Full pipeline: normalize → guard → generate → execute.

        Returns a dict with keys question, sql, columns, rows, error, and
        execution_time (wall-clock seconds; backward-compatible extra key).
        """
        start = time.time()
        question = normalize_question(question)
        # Block questions that look like DML/DDL injection attempts.
        if re.search(self.dml_keywords, question, re.IGNORECASE):
            return {
                "question": question,
                "sql": "-- BLOCKED",
                "columns": [],
                "rows": [],
                "error": "Malicious prompt",
                "execution_time": time.time() - start
            }
        schema = self.get_schema(db_id)
        prompt = self.build_prompt(question, schema)
        raw_sql = self.generate_sql(prompt)
        final_sql, cols, rows, error = self.execute_sql(question, raw_sql, db_id)
        return {
            "question": question,
            "sql": final_sql,
            "columns": cols,
            "rows": rows,
            "error": error,
            "execution_time": time.time() - start
        }
# Module-level cache holding the lazily-constructed engine singleton.
_engine = None


def get_engine():
    """Return the process-wide Text2SQLEngine, creating it on first use."""
    global _engine
    if _engine is not None:
        return _engine
    _engine = Text2SQLEngine()
    return _engine