# Source: text2sql-demo/src/sql_validator.py
# Author: tjhalanigrid
# Commit: dc59b01 ("Add src folder")
import sqlite3
import re
from pathlib import Path
class SQLValidator:
    """Heuristic sanity checker for SQL text against a SQLite schema.

    Databases are looked up as ``<db_root>/<db_id>/<db_id>.sqlite``. Every
    check returns an ``(ok, message)`` tuple, where ``message`` is ``None``
    on success and a short human-readable reason on failure.

    The checks are token-based heuristics, not a SQL parser: they look at
    the words that appear in the query, not at its grammar.
    """

    # Statement keywords that must not appear in a query. Order matters: the
    # first match in this order is the one reported.
    _DANGEROUS_KEYWORDS = ("drop", "delete", "update", "insert", "alter")

    # Words that are part of SQL itself and therefore never count as unknown
    # identifiers in validate_columns().
    _SQL_KEYWORDS = frozenset({
        "select", "from", "where", "join", "on", "group", "by",
        "order", "limit", "count", "sum", "avg", "min", "max",
        "and", "or", "in", "like", "distinct", "asc", "desc",
        "as", "having", "not", "null", "is", "between", "exists",
        "inner", "left", "right", "outer", "full", "cross", "natural",
        "using", "union", "intersect", "except", "all", "offset",
        "case", "when", "then", "else", "end",
    })

    # Single-quoted SQL string literal; a doubled quote ('') is an escaped
    # quote inside the literal.
    _STRING_LITERAL_RE = re.compile(r"'(?:[^']|'')*'")

    # Identifier-like word: starts with a letter or underscore and may
    # contain digits afterwards. The leading \b keeps the tail of numeric
    # literals such as 1e5 or 0x1f from being picked up as words.
    _IDENTIFIER_RE = re.compile(r"\b[A-Za-z_][A-Za-z0-9_]*")

    def __init__(self, db_root):
        """Remember the directory that holds one sub-directory per database."""
        self.db_root = Path(db_root)

    # ---------------------------
    # Load schema
    # ---------------------------
    def load_schema(self, db_id):
        """Return ``{table_name: [column_name, ...]}`` for database *db_id*.

        Table and column names are lower-cased. Errors raised by ``sqlite3``
        (for example when the database directory does not exist) propagate
        to the caller.
        """
        db_path = self.db_root / db_id / f"{db_id}.sqlite"
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            tables = cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table';"
            ).fetchall()
            schema = {}
            for (table,) in tables:
                # Quote the name so tables containing spaces, quotes or
                # reserved words do not break the PRAGMA statement.
                quoted = '"' + table.replace('"', '""') + '"'
                cols = cursor.execute(
                    f"PRAGMA table_info({quoted});"
                ).fetchall()
                # Index 1 of each table_info row is the column name.
                schema[table.lower()] = [c[1].lower() for c in cols]
            return schema
        finally:
            # Close the connection even when one of the queries fails.
            conn.close()

    # ---------------------------
    # Basic syntax check
    # ---------------------------
    def basic_structure_valid(self, sql):
        """Check that *sql* contains SELECT and FROM and at least 4 words."""
        s = sql.lower()
        # Whole-word match, so e.g. "selection" does not count as SELECT.
        if not re.search(r"\bselect\b", s) or not re.search(r"\bfrom\b", s):
            return False, "Missing SELECT or FROM"
        if len(s.split()) < 4:
            return False, "Too short to be SQL"
        return True, None

    # ---------------------------
    # Extract identifiers
    # ---------------------------
    def extract_identifiers(self, sql):
        """Return the set of lower-cased identifier-like words in *sql*.

        The contents of single-quoted string literals are ignored, because
        they are data rather than table or column names.
        """
        without_literals = self._STRING_LITERAL_RE.sub(" ", sql)
        return set(self._IDENTIFIER_RE.findall(without_literals.lower()))

    # ---------------------------
    # Table validation
    # ---------------------------
    def validate_tables(self, sql, schema):
        """Check that *sql* mentions at least one table from *schema*."""
        if self.extract_identifiers(sql).isdisjoint(schema):
            return False, "No valid table used"
        return True, None

    # ---------------------------
    # Column validation
    # ---------------------------
    def validate_columns(self, sql, schema, max_unknown=3):
        """Reject *sql* when it uses too many words unknown to *schema*.

        A word is unknown when it is neither a table name, a column name of
        any table, nor a recognised SQL keyword. Up to *max_unknown* unknown
        words are tolerated (aliases, small hallucinations); more than that
        fails the check. The failure message lists at most five of them in
        alphabetical order.
        """
        known = set(schema)
        for cols in schema.values():
            known.update(cols)

        # Sorted so the reported words do not depend on set iteration order.
        unknown = sorted(
            self.extract_identifiers(sql) - known - self._SQL_KEYWORDS
        )
        if len(unknown) > max_unknown:
            return False, f"Too many unknown identifiers: {unknown[:5]}"
        return True, None

    # ---------------------------
    # Dangerous query protection
    # ---------------------------
    def block_dangerous(self, sql):
        """Reject *sql* when it contains a data-modifying keyword.

        Keywords are matched as whole words, so identifiers that merely
        contain one (``last_updated``, ``deleted_at``) are allowed. String
        literals are deliberately NOT stripped here: the check stays
        conservative and flags a keyword wherever it appears.
        """
        s = sql.lower()
        for keyword in self._DANGEROUS_KEYWORDS:
            if re.search(rf"\b{keyword}\b", s):
                return False, f"Dangerous keyword detected: {keyword}"
        return True, None

    # ---------------------------
    # Main validation
    # ---------------------------
    def validate(self, sql, db_id):
        """Run all checks on *sql* and return the first failure, if any.

        The string-only checks run first, so a query that fails them is
        rejected without opening the database. The schema of *db_id* is
        loaded only for the table and column checks.
        """
        for check in (self.block_dangerous, self.basic_structure_valid):
            ok, msg = check(sql)
            if not ok:
                return False, msg

        schema = self.load_schema(db_id)
        for check in (self.validate_tables, self.validate_columns):
            ok, msg = check(sql, schema)
            if not ok:
                return False, msg
        return True, None