Spaces:
Running
Running
| import sqlite3 | |
| import torch | |
| import re | |
| import time | |
| from pathlib import Path | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| from peft import PeftModel | |
| from src.sql_validator import SQLValidator | |
| from src.schema_encoder import SchemaEncoder | |
# Parent of the directory that contains this file.
PROJECT_ROOT = Path(__file__).resolve().parents[1]

# ================================
# DATABASE PATH AUTO DETECTION
# ================================
# Use "data/database" when it exists under the project root; otherwise fall
# back to "final_databases".
DB_ROOT = (
    PROJECT_ROOT / "data/database"
    if (PROJECT_ROOT / "data/database").exists()
    else PROJECT_ROOT / "final_databases"
)
def normalize_question(q: str):
    """Return a canonical form of a user question.

    The text is lower-cased and trimmed, a phrase of the form
    ``distinct <number>`` is rewritten to ``<number> distinct``, and every run
    of whitespace is collapsed to a single space.
    """
    text = q.lower().strip()
    # Move the number in front of "distinct" ("distinct 5" -> "5 distinct").
    reordered = re.sub(r"distinct\s+(\d+)", r"\1 distinct", text)
    return re.sub(r"\s+", " ", reordered)
def semantic_fix(question, sql):
    """Append a ``LIMIT n`` clause when the question asks for ``n`` rows.

    A limit is added only when the question contains one of the trigger words
    (show, list, top, limit, get, first, last) directly followed by a number,
    and the SQL has neither a ``LIMIT`` keyword nor a ``count(`` call.

    Args:
        question: Natural-language question.
        sql: SQL text produced for that question.

    Returns:
        ``sql`` unchanged when no limit applies; otherwise the statement with
        trailing semicolons/whitespace removed and ``LIMIT n`` appended.
    """
    q = question.lower().strip()
    num_match = re.search(
        r'\b(?:show|list|top|limit|get|first|last)\s+(\d+)\b', q
    )
    if not num_match:
        return sql

    lowered = sql.lower()
    # Match LIMIT as a whole word so identifiers such as "credit_limit" do not
    # masquerade as an existing LIMIT clause.
    if re.search(r"\blimit\b", lowered) or "count(" in lowered:
        return sql

    # Drop trailing semicolons *and* the whitespace around them; stripping the
    # semicolon before the whitespace would leave "...; LIMIT n" behind.
    statement = re.sub(r"[\s;]+$", "", sql).lstrip()
    return f"{statement} LIMIT {num_match.group(1)}"
class Text2SQLEngine:
    """Turn natural-language questions into SQLite queries and run them.

    The engine loads a seq2seq model (optionally wrapped with a LoRA adapter),
    builds a schema-aware prompt, extracts the SQL from the generated text and
    executes it against ``DB_ROOT / f"{db_id}.sqlite"``.
    """

    def __init__(self,
                 adapter_path=None,
                 base_model_name="Salesforce/codet5-base",
                 use_lora=True):
        """Load the tokenizer and model.

        Args:
            adapter_path: Directory of the LoRA adapter. When ``None`` the
                adapter is looked up at
                ``PROJECT_ROOT / "checkpoints/best_rlhf_model"`` and, if that
                does not exist, ``PROJECT_ROOT / "best_rlhf_model"``.
            base_model_name: Name or path passed to
                ``AutoModelForSeq2SeqLM.from_pretrained``.
            use_lora: When false, only the base model and its tokenizer are
                loaded and no adapter is applied.
        """
        # Preference order: Apple MPS, then CUDA, then CPU.
        self.device = "mps" if torch.backends.mps.is_available() else (
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.validator = SQLValidator(DB_ROOT)
        self.schema_encoder = SchemaEncoder(DB_ROOT)
        # Keywords that mark a question or statement as data-modifying.
        self.dml_keywords = r'\b(delete|update|insert|drop|alter|truncate)\b'

        print("Loading base model...")
        base = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)

        if not use_lora:
            self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
            self.model = base.to(self.device)
            self.model.eval()
            return

        # Honour an explicitly supplied adapter directory; auto-detect only
        # when the caller did not provide one.
        if adapter_path is None:
            adapter_path = PROJECT_ROOT / "checkpoints/best_rlhf_model"
            if not adapter_path.exists():
                adapter_path = PROJECT_ROOT / "best_rlhf_model"
        adapter_path = Path(adapter_path).resolve()

        print("Loading tokenizer and LoRA adapter...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                str(adapter_path),
                local_files_only=True
            )
        except Exception:
            # Best effort: if the adapter directory has no usable tokenizer,
            # fall back to the base model's tokenizer.
            self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)

        self.model = PeftModel.from_pretrained(base, str(adapter_path)).to(self.device)
        self.model.eval()
        print("✅ RLHF model ready\n")

    def build_prompt(self, question, schema):
        """Return the model prompt for ``question`` over ``schema``."""
        # NOTE(review): the prompt layout presumably has to match the format
        # the model was trained on -- confirm before changing it.
        return f"""You are an expert SQL generator.
Database schema:
{schema}
Generate a valid SQLite query for the question.
Question:
{question}
SQL:
"""

    def get_schema(self, db_id):
        """Return the schema encoder's structured schema for ``db_id``."""
        return self.schema_encoder.structured_schema(db_id)

    def extract_sql(self, text: str):
        """Pull the first SQL statement out of generated text.

        Keeps what follows the last ``SQL:`` marker (if any), starts at the
        first ``select`` (case-insensitive, if any) and cuts at the first
        semicolon.
        """
        text = text.strip()
        if "SQL:" in text:
            text = text.split("SQL:")[-1]
        match = re.search(r"select[\s\S]*", text, re.IGNORECASE)
        if match:
            text = match.group(0)
        return text.split(";")[0].strip()

    def clean_sql(self, sql: str):
        """Normalise quoting and whitespace of a SQL string.

        Double quotes become single quotes and every whitespace run is
        collapsed to one space.
        """
        # NOTE(review): this also converts double-quoted identifiers into
        # string literals -- confirm that is intended.
        sql = sql.replace('"', "'")
        sql = re.sub(r"\s+", " ", sql)
        return sql.strip()

    def generate_sql(self, prompt):
        """Generate SQL for ``prompt`` with beam search and return it cleaned."""
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=128,
                num_beams=5,
                early_stopping=True
            )
        decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return self.clean_sql(self.extract_sql(decoded))

    def execute_sql(self, question, sql, db_id):
        """Run ``sql`` against the database identified by ``db_id``.

        Args:
            question: Original question; used to decide whether a LIMIT
                clause should be appended.
            sql: SQL text to execute.
            db_id: Database name; the file is ``DB_ROOT / f"{db_id}.sqlite"``.

        Returns:
            A tuple ``(sql, columns, rows, error)``. ``error`` is ``None`` on
            success; otherwise it is a message and ``columns``/``rows`` are
            empty lists.
        """
        # Refuse anything that looks like a data-modifying statement.
        if re.search(self.dml_keywords, sql, re.IGNORECASE):
            return sql, [], [], "❌ Security Alert"

        db_path = DB_ROOT / f"{db_id}.sqlite"
        sql = self.clean_sql(sql)
        sql = semantic_fix(question, sql)

        # sqlite3.connect() would silently create an empty database file for
        # an unknown db_id, so check for the file first.
        if not db_path.is_file():
            return sql, [], [], f"Database not found: {db_id}"

        try:
            conn = sqlite3.connect(db_path)
            try:
                cursor = conn.cursor()
                cursor.execute(sql)
                rows = cursor.fetchall()
                columns = [d[0] for d in cursor.description] if cursor.description else []
            finally:
                # Always release the connection, including when the query fails.
                conn.close()
        except Exception as e:
            # Boundary with the caller/UI: report the failure as text.
            return sql, [], [], str(e)
        return sql, columns, rows, None

    def ask(self, question, db_id):
        """Answer ``question`` against database ``db_id``.

        Returns:
            A dict with the keys ``question``, ``sql``, ``columns``, ``rows``
            and ``error``. Questions containing a data-modifying keyword are
            rejected without calling the model.
        """
        question = normalize_question(question)
        if re.search(self.dml_keywords, question, re.IGNORECASE):
            return {
                "question": question,
                "sql": "-- BLOCKED",
                "columns": [],
                "rows": [],
                "error": "Malicious prompt"
            }

        schema = self.get_schema(db_id)
        prompt = self.build_prompt(question, schema)
        raw_sql = self.generate_sql(prompt)
        final_sql, cols, rows, error = self.execute_sql(question, raw_sql, db_id)
        return {
            "question": question,
            "sql": final_sql,
            "columns": cols,
            "rows": rows,
            "error": error
        }
# Engine instance shared by all callers; created on the first get_engine() call.
_engine = None


def get_engine():
    """Return the shared ``Text2SQLEngine``, building it on first use."""
    global _engine
    if _engine is not None:
        return _engine
    _engine = Text2SQLEngine()
    return _engine