Spaces:
Sleeping
Sleeping
Upload 5 files
Browse files- app.py +97 -0
- execution_checker.py +52 -0
- few_shot_retriever.py +71 -0
- requirements.txt +5 -3
- schema_linker.py +76 -0
app.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import json
|
| 3 |
+
import re # <--- Added this to handle reading the text box
|
| 4 |
+
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
| 5 |
+
|
| 6 |
+
# Import the tools from the rest of the team
|
| 7 |
+
from schema_linker import link_schema
|
| 8 |
+
from few_shot_retriever import FewShotRetriever
|
| 9 |
+
from execution_checker import get_best_query
|
| 10 |
+
|
| 11 |
+
# --- ADDED: Teammate D's Regex Parser ---
|
| 12 |
+
def parse_raw_sql_to_dict(raw_sql):
    """Convert pasted CREATE TABLE statements into ``{table_name: [column, ...]}``.

    Args:
        raw_sql: Text containing zero or more ``CREATE TABLE name (...)``
            statements. ``IF NOT EXISTS`` and quoted table names are accepted,
            and the last statement may omit its trailing semicolon.

    Returns:
        A dict mapping each table name to its column names in declaration
        order. Table-level constraints (``PRIMARY KEY (...)``,
        ``FOREIGN KEY ...``, ``UNIQUE (...)``, ``CHECK (...)``,
        ``CONSTRAINT ...``) are not reported as columns. Text without any
        CREATE TABLE statement yields an empty dict.
    """
    # Clauses that start with one of these words describe the table, not a column.
    constraint_keywords = {"PRIMARY", "FOREIGN", "UNIQUE", "CHECK", "CONSTRAINT"}

    schema_dict = {}
    table_blocks = re.findall(
        r'CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?["`\[]?(\w+)["`\]]?\s*\((.*?)\)\s*(?:;|\Z)',
        raw_sql,
        re.IGNORECASE | re.DOTALL,
    )
    for table_name, columns_str in table_blocks:
        cols = []
        for col_def in _split_top_level_commas(columns_str):
            col_def = col_def.strip()
            if not col_def:
                continue
            # Cut at "(" so "CHECK(x > 0)" is still recognised as a constraint.
            first_token = col_def.split()[0].split("(")[0]
            if first_token.upper() in constraint_keywords:
                continue
            col_name = first_token.strip('"`[]')
            if col_name:
                cols.append(col_name)
        schema_dict[table_name] = cols
    return schema_dict


def _split_top_level_commas(text):
    """Split *text* on commas that are not nested inside parentheses."""
    parts = []
    current = []
    depth = 0
    for ch in text:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth = max(depth - 1, 0)
        if ch == "," and depth == 0:
            parts.append("".join(current))
            current = []
        else:
            # Commas inside e.g. DECIMAL(10,2) stay part of the column definition.
            current.append(ch)
    parts.append("".join(current))
    return parts
|
| 25 |
+
|
| 26 |
+
# 1. LOAD THE HEAVY AI MODELS ONCE
|
| 27 |
+
@st.cache_resource
def load_ai_models():
    """Instantiate the retriever and the text-to-SQL tokenizer/model.

    Returns:
        A ``(retriever, tokenizer, model)`` tuple: a ``FewShotRetriever``
        plus the ``T5Tokenizer`` and ``T5ForConditionalGeneration`` loaded
        from the same pre-trained checkpoint.
    """
    # Pre-trained open-source text-to-SQL checkpoint shared by tokenizer and model.
    checkpoint = "alpecevit/flan-t5-base-text2sql"

    # Teammate B's retriever is built first, as before.
    few_shot = FewShotRetriever()
    t5_tokenizer = T5Tokenizer.from_pretrained(checkpoint)
    t5_model = T5ForConditionalGeneration.from_pretrained(checkpoint)
    return few_shot, t5_tokenizer, t5_model
|
| 37 |
+
|
| 38 |
+
retriever, tokenizer, model = load_ai_models()

# 2. BUILD THE WEBSITE DASHBOARD
st.title("Natural Language to SQL Engine")
st.write("Enter your database schema and question below.")

# Text box for the user to paste their raw CREATE TABLE statements
user_raw_schema = st.text_area(
    "Paste your CREATE TABLE statements here:",
    height=150,
    placeholder="CREATE TABLE employees (id INTEGER, name TEXT);\nCREATE TABLE departments (id INTEGER, location TEXT);"
)

# Text box for the English question
user_question = st.text_input("What do you want to know?", placeholder="e.g., Show me all employees in Chicago")

# The big "Generate" button
if st.button("Generate SQL"):
    # Whitespace-only input counts as missing.
    if user_raw_schema.strip() and user_question.strip():

        with st.spinner("Processing through the pipeline..."):

            # Parse whatever CREATE TABLE text the user pasted.
            schema_dict = parse_raw_sql_to_dict(user_raw_schema)

            # 1. Teammate A tags the schema
            tagged_schema = link_schema(user_question, schema_dict)

            # 2. Teammate B gets the cheat sheet. It is still computed but
            #    deliberately left out of the prompt below so the pre-trained
            #    model is not confused by it.
            few_shot_examples = retriever.get_few_shot_prompt(user_question)

            # 3. Teammate D glues it together for the Prompt
            final_prompt = f"Translate English to SQLite: {user_question} \nSchema Context: \n{tagged_schema}"

            # 4. Generate 5 guesses using Beam Search
            inputs = tokenizer(final_prompt, return_tensors="pt", max_length=1024, truncation=True)
            outputs = model.generate(
                **inputs,
                max_length=256,
                num_beams=5,
                num_return_sequences=5
            )

            # Decode once; the list feeds both the debug output and the checker.
            candidate_queries = [tokenizer.decode(out, skip_special_tokens=True) for out in outputs]

            # Debug aid: show the raw guesses so hallucinations are visible.
            st.warning(f"DEBUG - AI's raw guesses: {candidate_queries}")

            # 5. Teammate C acts as the firewall
            winning_sql = get_best_query(user_raw_schema, candidate_queries)

            # --- STEP C: DISPLAY THE RESULT ---
            # get_best_query reports failures as strings starting with
            # "Error:", so those must not be presented as a successful query.
            if winning_sql.startswith("Error:"):
                st.error(winning_sql)
            else:
                st.success("Query Generated Successfully!")
                st.code(winning_sql, language="sql")

    else:
        st.error("Please provide both a schema and a question.")
|
execution_checker.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sqlite3
|
| 2 |
+
|
| 3 |
+
def get_best_query(schema_create_statements, candidate_queries):
    """Return the first candidate query that SQLite can compile against the schema.

    An in-memory database is created, the user's CREATE statements are run
    to build empty tables, and each candidate is checked with ``EXPLAIN`` so
    it is compiled but never run against data.

    Args:
        schema_create_statements: SQL script that creates the user's tables.
        candidate_queries: Iterable of SQL strings, tried in order.

    Returns:
        ``"-- Selected Candidate #N (Syntax Valid)\\n<query>"`` for the first
        candidate that compiles, or a message starting with ``"Error:"`` when
        the schema is invalid or no candidate compiles.
    """
    # NOTE(review): the schema script comes straight from the user and is run
    # as-is; consider rejecting statements such as ATTACH before executing.
    conn = sqlite3.connect(':memory:')
    try:
        cursor = conn.cursor()

        # Build the empty tables using the user's schema.
        try:
            cursor.executescript(schema_create_statements)
        except sqlite3.Error as e:
            return f"Error: The provided schema is invalid. ({e})"

        # Test the AI's candidate queries in order.
        for i, query in enumerate(candidate_queries):
            try:
                cursor.execute(f"EXPLAIN {query}")
            except (sqlite3.Error, sqlite3.Warning) as e:
                # Also covers multi-statement candidates, which raise
                # sqlite3.Warning or ProgrammingError depending on the
                # Python version rather than OperationalError.
                print(f"Candidate {i+1} failed syntax check: {e}")
                continue
            return f"-- Selected Candidate #{i+1} (Syntax Valid)\n{query}"

        return "Error: All generated queries contained syntax errors."
    finally:
        # Close on every exit path, including the invalid-schema return.
        conn.close()
|
| 33 |
+
|
| 34 |
+
# --- TESTING BLOCK ---
|
| 35 |
+
if __name__ == "__main__":
    # Manual smoke test: the frontend would supply these raw CREATE statements.
    demo_schema = """
CREATE TABLE employees (id INTEGER PRIMARY KEY, name TEXT, salary REAL);
CREATE TABLE departments (id INTEGER PRIMARY KEY, name TEXT);
"""

    # Stand-ins for model output; the first two are deliberately broken.
    bad_function = "SELECT SUMM(salary) FROM employees"  # unknown function
    bad_table = "SELECT sum(salary) FROM employees JOIN bad_table"  # unknown table
    valid_query = "SELECT sum(salary) FROM employees"  # compiles cleanly
    demo_candidates = [bad_function, bad_table, valid_query]

    print("Testing AI Candidates against In-Memory DB...\n")
    winner = get_best_query(demo_schema, demo_candidates)
    print(f"\nWinning Query to show the user:\n{winner}")
|
| 52 |
+
|
few_shot_retriever.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from sentence_transformers import SentenceTransformer
|
| 3 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 4 |
+
|
| 5 |
+
class FewShotRetriever:
    """Look up stored question/SQL pairs that resemble a new question.

    On construction the sentence-embedding model is loaded and every stored
    question is embedded once; each lookup then only embeds the new question.
    """

    def __init__(self):
        # 1. Load the embedding model that turns text into vectors.
        print("Loading BGE Model (This might take a minute the first time)...")
        self.model = SentenceTransformer("BAAI/bge-base-en-v1.5")

        # 2. The answer bank: hardcoded for now so the class works standalone.
        self.historical_data = [
            {"q": "What is the average salary of IT staff?", "sql": "SELECT avg(salary) FROM staff WHERE dept = 'IT'"},
            {"q": "Count the number of patients in the ICU.", "sql": "SELECT count(*) FROM patients WHERE ward = 'ICU'"},
            {"q": "Show me the total budget for the marketing department.", "sql": "SELECT sum(budget) FROM departments WHERE name = 'Marketing'"},
            {"q": "Find the average age of all employees.", "sql": "SELECT avg(age) FROM employees"},
            {"q": "How many marketing staff earn more than 50000?", "sql": "SELECT count(*) FROM staff WHERE dept = 'Marketing' AND salary > 50000"}
        ]

        # 3. Embed all stored questions once, up front.
        historical_questions = [item["q"] for item in self.historical_data]
        self.historical_embs = self.model.encode(historical_questions, normalize_embeddings=True)

    def get_few_shot_prompt(self, user_query, top_k=2):
        """Format the ``top_k`` stored examples most similar to ``user_query``.

        Args:
            user_query: The new natural-language question.
            top_k: Maximum number of examples to include. Values of zero or
                less yield the header with no examples.

        Returns:
            A text block that starts with a fixed header line followed by
            "Example Question" / "Example SQL" pairs, most similar first.
        """
        prompt_prefix = "Here are some examples of translating English to SQL:\n\n"

        # Guard: with top_k == 0 the slice [-0:] would select EVERY example.
        if top_k <= 0:
            return prompt_prefix

        # 1. Embed the new question.
        q_emb = self.model.encode([user_query], normalize_embeddings=True)

        # 2. Cosine similarity between the new question and each stored one.
        scores = cosine_similarity(q_emb, self.historical_embs)[0]

        # 3. argsort() is ascending, so reverse it and keep the first top_k.
        top_indices = scores.argsort()[::-1][:top_k]

        # 4. Format the selected examples for the T5 model.
        examples = []
        for idx in top_indices:
            past_example = self.historical_data[idx]
            examples.append(f"Example Question: {past_example['q']}\n")
            examples.append(f"Example SQL: {past_example['sql']}\n\n")

        return prompt_prefix + "".join(examples)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# --- TESTING BLOCK ---
|
| 56 |
+
if __name__ == "__main__":
    # Manual smoke test; constructing the class loads the embedding model.
    demo_retriever = FewShotRetriever()

    # A question that is not in the stored examples.
    demo_question = "What is the average salary of the sales team?"

    print("\n--- INPUT ---")
    print(f"New User Question: {demo_question}")

    print("\n--- YOUR OUTPUT (The Cheat Sheet) ---")
    # Ask for the two closest stored examples and show the formatted block.
    print(demo_retriever.get_few_shot_prompt(demo_question, top_k=2))
|
| 71 |
+
|
requirements.txt
CHANGED
|
@@ -1,3 +1,5 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit
|
| 2 |
+
transformers
|
| 3 |
+
torch
|
| 4 |
+
sentence-transformers
|
| 5 |
+
scikit-learn
|
schema_linker.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import string
|
| 2 |
+
|
| 3 |
+
def link_schema(user_query, raw_schema):
    """Tag schema columns whose names appear as words in the user's question.

    Args:
        user_query: The natural-language question.
        raw_schema: Dict mapping each table name to a list of column names.

    Returns:
        One line per table, joined with ``" \\n"``, in the form
        ``Table: <name> | Cols: <col>, <col> (Exact Match: "<col>"), ...``.
        Matching is case-insensitive and ignores punctuation in the question.
        An empty schema yields an empty string.
    """
    lowered = user_query.lower()

    # Words with every punctuation mark removed ("budget!" -> "budget").
    strip_all = str.maketrans("", "", string.punctuation)
    # Same, but keeping "_" so snake_case names such as "department_id"
    # survive as typed instead of collapsing to "departmentid".
    strip_all_but_underscore = str.maketrans("", "", string.punctuation.replace("_", ""))

    query_words = set(lowered.translate(strip_all).split())
    query_words |= set(lowered.translate(strip_all_but_underscore).split())

    linked_schema_lines = []
    for table_name, column_list in raw_schema.items():
        tagged_columns = [
            f'{col} (Exact Match: "{col}")' if col.lower() in query_words else col
            for col in column_list
        ]
        formatted_cols = ", ".join(tagged_columns)
        linked_schema_lines.append(f"Table: {table_name} | Cols: {formatted_cols}")

    return " \n".join(linked_schema_lines)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# --- TESTING BLOCK ---
|
| 60 |
+
if __name__ == "__main__":
    # Manual smoke test mimicking what the frontend passes in.
    demo_question = "What is the location and budget for the marketing department?"

    # Small two-table schema in the {table: [columns]} shape link_schema expects.
    demo_schema = dict(
        employees=["id", "name", "department_id", "salary"],
        departments=["id", "name", "location", "budget", "industry"],
    )

    print("--- INPUTS ---")
    print(f"Question: {demo_question}")

    print("\n--- YOUR OUTPUT ---")
    print(link_schema(demo_question, demo_schema))
|