Spaces:
Sleeping
Sleeping
File size: 4,202 Bytes
5e468f2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 | import streamlit as st
import json
import re # <--- Added this to handle reading the text box
from transformers import T5Tokenizer, T5ForConditionalGeneration
# Import the tools from the rest of the team
from schema_linker import link_schema
from few_shot_retriever import FewShotRetriever
from execution_checker import get_best_query
# --- ADDED: Teammate D's Regex Parser ---
# SQL comments are removed before parsing; replacing them with a space keeps
# tokens on either side of a block comment from being glued together.
_SQL_COMMENT_RE = re.compile(r"--[^\n]*|/\*.*?\*/", re.DOTALL)

# Header of a CREATE TABLE statement, up to and including the "(" that opens
# the column list. Accepts TEMP/TEMPORARY, IF NOT EXISTS, an optional
# "schema." prefix and identifiers quoted with "", `` or [].
_CREATE_TABLE_RE = re.compile(
    r'CREATE\s+(?:TEMP(?:ORARY)?\s+)?TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?'
    r'(?:[`"\[]?\w+[`"\]]?\.)?'
    r'[`"\[]?(\w+)[`"\]]?\s*\(',
    re.IGNORECASE,
)

# An (unquoted) definition starting with one of these keywords is a
# table-level constraint such as `PRIMARY KEY (id)`, not a column.
_TABLE_CONSTRAINT_KEYWORDS = frozenset({"CONSTRAINT", "PRIMARY", "FOREIGN", "UNIQUE", "CHECK"})


def _column_list_body(text, start):
    """Return the text between an already-consumed '(' at ``start - 1`` and its matching ')'.

    Returns None when the parentheses never balance (unterminated statement).
    """
    depth = 1
    for index, char in enumerate(text[start:], start):
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                return text[start:index]
    return None


def _split_top_level_commas(body):
    """Split ``body`` on commas that are not nested inside parentheses."""
    parts = []
    current = []
    depth = 0
    for char in body:
        if char == "," and depth == 0:
            parts.append("".join(current))
            current = []
            continue
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
        current.append(char)
    parts.append("".join(current))
    return parts


def parse_raw_sql_to_dict(raw_sql):
    """Convert the CREATE TABLE box into a ``{table_name: [column_name, ...]}`` dict.

    Each CREATE TABLE statement contributes one entry, keyed by the bare table
    name (quotes and any schema prefix removed). Columns are listed in
    declaration order. Table-level constraints (PRIMARY KEY, FOREIGN KEY,
    UNIQUE, CHECK, CONSTRAINT) are skipped. Commas nested inside parentheses,
    e.g. ``DECIMAL(10,2)``, do not split a column definition. A trailing ``;``
    is optional. If the same table name appears twice, the later statement
    wins.

    Known limitations: commas or parentheses inside string literals and
    identifiers containing spaces are not handled.

    Args:
        raw_sql: The raw text pasted by the user; may be empty or None.

    Returns:
        A dict mapping table name to its list of column names; ``{}`` when no
        CREATE TABLE statement is found.
    """
    schema_dict = {}
    if not raw_sql:
        return schema_dict

    sql = _SQL_COMMENT_RE.sub(" ", raw_sql)
    for match in _CREATE_TABLE_RE.finditer(sql):
        body = _column_list_body(sql, match.end())
        if body is None:
            # Unbalanced parentheses: skip the malformed statement.
            continue
        cols = []
        for col_def in _split_top_level_commas(body):
            tokens = col_def.split()
            if not tokens:
                continue
            first = tokens[0]
            # `UNIQUE(a, b)` has no space before "(", so look at the word before it.
            if first.split("(")[0].upper() in _TABLE_CONSTRAINT_KEYWORDS:
                continue
            cols.append(first.strip('`"[]'))
        schema_dict[match.group(1)] = cols
    return schema_dict
# 1. LOAD THE HEAVY AI MODELS ONCE
@st.cache_resource
def load_ai_models(model_name="alpecevit/flan-t5-base-text2sql"):
    """Load the few-shot retriever, tokenizer and text-to-SQL model.

    The function is wrapped in ``st.cache_resource``, so the heavy objects are
    created once per distinct ``model_name``. Later Streamlit reruns reuse them
    instead of reloading the weights.

    Args:
        model_name: Hugging Face checkpoint id passed to ``from_pretrained``.
            The tokenizer and the model are loaded from the same id so they
            cannot drift apart.

    Returns:
        A ``(retriever, tokenizer, model)`` tuple.
    """
    # Teammate B's few-shot example retriever.
    retriever = FewShotRetriever()
    # Pre-trained open-source text-to-SQL checkpoint.
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    return retriever, tokenizer, model
retriever, tokenizer, model = load_ai_models()

# Number of beam-search hypotheses generated and handed to the execution checker.
NUM_CANDIDATES = 5

# 2. BUILD THE WEBSITE DASHBOARD
st.title("Natural Language to SQL Engine")
st.write("Enter your database schema and question below.")

# Text box for the user to paste their raw CREATE TABLE statements
user_raw_schema = st.text_area(
    "Paste your CREATE TABLE statements here:",
    height=150,
    placeholder="CREATE TABLE employees (id INTEGER, name TEXT);\nCREATE TABLE departments (id INTEGER, location TEXT);",
)

# Text box for the English question
user_question = st.text_input("What do you want to know?", placeholder="e.g., Show me all employees in Chicago")

# The big "Generate" button
if st.button("Generate SQL"):
    # Whitespace-only input is treated the same as an empty box.
    schema_text = (user_raw_schema or "").strip()
    question_text = (user_question or "").strip()
    # Parse up front so an unreadable schema is reported before the model runs.
    schema_dict = parse_raw_sql_to_dict(schema_text) if schema_text else {}

    if not (schema_text and question_text):
        st.error("Please provide both a schema and a question.")
    elif not schema_dict:
        st.error("No CREATE TABLE statements could be read from the schema box.")
    else:
        with st.spinner("Processing through the pipeline..."):
            # 1. Teammate A tags the schema
            tagged_schema = link_schema(question_text, schema_dict)

            # 2. Teammate B gets the cheat sheet. The result is not used in the
            #    prompt yet (see step 3); it is kept for when the model is ready.
            few_shot_examples = retriever.get_few_shot_prompt(question_text)

            # 3. Teammate D glues it together for the prompt.
            #    `few_shot_examples` is deliberately left out so we don't
            #    confuse the pre-trained model.
            final_prompt = f"Translate English to SQLite: {question_text} \nSchema Context: \n{tagged_schema}"

            # 4. Generate NUM_CANDIDATES guesses using beam search.
            inputs = tokenizer(final_prompt, return_tensors="pt", max_length=1024, truncation=True)
            outputs = model.generate(
                **inputs,
                max_length=256,
                num_beams=NUM_CANDIDATES,
                num_return_sequences=NUM_CANDIDATES,
            )
            candidate_queries = tokenizer.batch_decode(outputs, skip_special_tokens=True)

            # DEBUG: show the model's raw guesses so hallucinations are easy to spot.
            st.warning(f"DEBUG - AI's raw guesses: {candidate_queries}")

            # 5. Teammate C acts as the firewall
            winning_sql = get_best_query(schema_text, candidate_queries)

        # --- STEP C: DISPLAY THE RESULT ---
        # Only claim success when the checker actually returned a query.
        if winning_sql:
            st.success("Query Generated Successfully!")
            st.code(winning_sql, language="sql")
        else:
            st.error("The pipeline did not return a usable SQL query.")