"""Text-to-SQL inference engine.

Loads a CodeT5 seq2seq model (optionally with a LoRA/RLHF adapter), builds
schema-aware prompts, generates SQLite queries, and executes them read-only
against the project's bundled databases.
"""

import re
import sqlite3
import time
from pathlib import Path

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel

from src.sql_validator import SQLValidator
from src.schema_encoder import SchemaEncoder

PROJECT_ROOT = Path(__file__).resolve().parents[1]

# ================================
# DATABASE PATH AUTO DETECTION
# ================================
if (PROJECT_ROOT / "data/database").exists():
    DB_ROOT = PROJECT_ROOT / "data/database"
else:
    DB_ROOT = PROJECT_ROOT / "final_databases"


def normalize_question(q: str) -> str:
    """Lowercase, trim, and collapse whitespace in a user question.

    Also rewrites "distinct N" as "N distinct" so the number precedes the
    keyword, matching the phrasing the downstream LIMIT heuristic expects.
    """
    q = q.lower().strip()
    q = re.sub(r"distinct\s+(\d+)", r"\1 distinct", q)
    q = re.sub(r"\s+", " ", q)
    return q


def semantic_fix(question: str, sql: str) -> str:
    """Append a LIMIT clause when the question asks for N rows but SQL lacks one.

    Skipped when the SQL already contains LIMIT or is an aggregate COUNT
    query, where limiting rows would change the meaning.
    """
    q = question.lower().strip()
    s = sql.lower()

    num_match = re.search(r'\b(?:show|list|top|limit|get|first|last)\s+(\d+)\b', q)
    if num_match and "limit" not in s and "count(" not in s:
        limit_val = num_match.group(1)
        sql = sql.rstrip(";")
        sql = f"{sql.strip()} LIMIT {limit_val}"
    return sql


class Text2SQLEngine:
    """End-to-end question -> SQL -> result pipeline over SQLite databases."""

    def __init__(self, adapter_path=None, base_model_name="Salesforce/codet5-base", use_lora=True):
        """Load the base model and, unless ``use_lora`` is False, a LoRA adapter.

        Args:
            adapter_path: Explicit path to a PEFT adapter directory. When
                None, the adapter location is auto-detected under
                ``checkpoints/best_rlhf_model`` or ``best_rlhf_model``.
            base_model_name: Hugging Face id of the base seq2seq checkpoint.
            use_lora: When False, run the plain base model with no adapter.
        """
        # Device preference: Apple MPS, then CUDA, then CPU.
        self.device = "mps" if torch.backends.mps.is_available() else (
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.validator = SQLValidator(DB_ROOT)
        self.schema_encoder = SchemaEncoder(DB_ROOT)
        # Blocklist of data-modifying statements: this engine is read-only.
        self.dml_keywords = r'\b(delete|update|insert|drop|alter|truncate)\b'

        print("Loading base model...")
        base = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)

        if not use_lora:
            self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
            self.model = base.to(self.device)
            self.model.eval()
            return

        # BUGFIX: the original unconditionally overwrote a caller-supplied
        # adapter_path with the auto-detected location. Only auto-detect
        # when no explicit path was given.
        if adapter_path is None:
            if (PROJECT_ROOT / "checkpoints/best_rlhf_model").exists():
                adapter_path = PROJECT_ROOT / "checkpoints/best_rlhf_model"
            else:
                adapter_path = PROJECT_ROOT / "best_rlhf_model"
        adapter_path = Path(adapter_path).resolve()

        print("Loading tokenizer and LoRA adapter...")
        try:
            # Prefer the tokenizer saved alongside the adapter (it may carry
            # added tokens); fall back to the base tokenizer if absent.
            self.tokenizer = AutoTokenizer.from_pretrained(
                str(adapter_path), local_files_only=True
            )
        except Exception:
            self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)

        self.model = PeftModel.from_pretrained(base, str(adapter_path)).to(self.device)
        self.model.eval()
        print("✅ RLHF model ready\n")

    def build_prompt(self, question, schema):
        """Compose the instruction prompt fed to the seq2seq model."""
        return f"""You are an expert SQL generator.

Database schema:
{schema}

Generate a valid SQLite query for the question.

Question: {question}

SQL:
"""

    def get_schema(self, db_id):
        """Return the structured schema text for database *db_id*."""
        return self.schema_encoder.structured_schema(db_id)

    def extract_sql(self, text: str) -> str:
        """Pull the SQL statement out of raw model output.

        Drops any echo of the prompt before "SQL:", anchors on the first
        SELECT, and truncates at the first semicolon.
        """
        text = text.strip()
        if "SQL:" in text:
            text = text.split("SQL:")[-1]
        match = re.search(r"select[\s\S]*", text, re.IGNORECASE)
        if match:
            text = match.group(0)
        return text.split(";")[0].strip()

    def clean_sql(self, sql: str) -> str:
        """Normalize quoting and whitespace in generated SQL."""
        sql = sql.replace('"', "'")
        sql = re.sub(r"\s+", " ", sql)
        return sql.strip()

    def generate_sql(self, prompt):
        """Generate SQL for *prompt* with beam search and post-process it."""
        inputs = self.tokenizer(
            prompt, return_tensors="pt", truncation=True, max_length=512
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=128,
                num_beams=5,
                early_stopping=True
            )
        decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return self.clean_sql(self.extract_sql(decoded))

    def execute_sql(self, question, sql, db_id):
        """Execute *sql* against the *db_id* database.

        Returns:
            Tuple of (final_sql, columns, rows, error) where error is None
            on success, otherwise a human-readable message.
        """
        # DML statements are refused outright — the engine is read-only.
        if re.search(self.dml_keywords, sql, re.IGNORECASE):
            return sql, [], [], "❌ Security Alert"

        # FIXED DATABASE PATH
        db_path = DB_ROOT / f"{db_id}.sqlite"
        # BUGFIX: sqlite3.connect silently creates an empty database file
        # for a nonexistent path; fail fast with a clear error instead.
        if not db_path.exists():
            return sql, [], [], f"Database not found: {db_path}"

        sql = self.clean_sql(sql)
        sql = semantic_fix(question, sql)

        conn = None
        try:
            conn = sqlite3.connect(db_path)
            cursor = conn.cursor()
            cursor.execute(sql)
            rows = cursor.fetchall()
            columns = [d[0] for d in cursor.description] if cursor.description else []
            return sql, columns, rows, None
        except Exception as e:
            return sql, [], [], str(e)
        finally:
            # BUGFIX: the original leaked the connection on query errors.
            if conn is not None:
                conn.close()

    def ask(self, question, db_id):
        """Answer a natural-language question with SQL results.

        Returns a dict with keys: question, sql, columns, rows, error.
        """
        question = normalize_question(question)

        # Refuse questions that themselves contain DML keywords.
        if re.search(self.dml_keywords, question, re.IGNORECASE):
            return {
                "question": question,
                "sql": "-- BLOCKED",
                "columns": [],
                "rows": [],
                "error": "Malicious prompt"
            }

        schema = self.get_schema(db_id)
        prompt = self.build_prompt(question, schema)
        raw_sql = self.generate_sql(prompt)
        final_sql, cols, rows, error = self.execute_sql(question, raw_sql, db_id)

        return {
            "question": question,
            "sql": final_sql,
            "columns": cols,
            "rows": rows,
            "error": error
        }


# Process-wide lazily constructed singleton.
_engine = None


def get_engine():
    """Return the shared Text2SQLEngine, constructing it on first use."""
    global _engine
    if _engine is None:
        _engine = Text2SQLEngine()
    return _engine