File size: 4,202 Bytes
5e468f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import streamlit as st
import json
import re # <--- Added this to handle reading the text box
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Import the tools from the rest of the team
from schema_linker import link_schema
from few_shot_retriever import FewShotRetriever
from execution_checker import get_best_query

# --- ADDED: Teammate D's Regex Parser ---
# Identifier forms accepted for table and column names: "double", `backtick`,
# [bracket]-quoted, or a bare word.
_SQL_IDENT = r'(?:"[^"]+"|`[^`]+`|\[[^\]]+\]|\w+)'

# Header of a CREATE TABLE statement up to and including its opening
# parenthesis. Allows TEMP/TEMPORARY, IF NOT EXISTS and a schema qualifier
# (e.g. main.employees); group 1 captures the bare (possibly quoted) table name.
_CREATE_TABLE_RE = re.compile(
    r"CREATE\s+(?:TEMP(?:ORARY)?\s+)?TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?"
    rf"(?:{_SQL_IDENT}\s*\.\s*)?({_SQL_IDENT})\s*\(",
    re.IGNORECASE,
)

# Table-level constraint clauses that live in the column list but are not columns.
# The trailing \b keeps columns such as "check_date" or "uniqueness".
_TABLE_CONSTRAINT_RE = re.compile(
    r"(?:CONSTRAINT|PRIMARY\s+KEY|FOREIGN\s+KEY|UNIQUE|CHECK)\b", re.IGNORECASE
)

# Leading identifier of a column definition (quoted names may contain spaces).
_COLUMN_NAME_RE = re.compile(r'"[^"]+"|`[^`]+`|\[[^\]]+\]|\S+')

# SQL line (--) and block (/* */) comments.
_SQL_COMMENT_RE = re.compile(r"--[^\n]*|/\*.*?\*/", re.DOTALL)

# Opening quote character -> matching closing character.
_SQL_QUOTES = {"'": "'", '"': '"', "`": "`", "[": "]"}


def _unquote_identifier(token):
    """Strip one layer of "", ``, or [] quoting from an SQL identifier."""
    if len(token) >= 2 and token[0] in '"`[' and _SQL_QUOTES[token[0]] == token[-1]:
        return token[1:-1]
    return token


def _split_top_level(text, start):
    """Split text[start:] on depth-0 commas, stopping at the ')' that closes the list.

    Commas and parentheses inside quotes or nested parentheses (for example
    DECIMAL(10,2) or CHECK (x IN (1, 2))) do not split. Returns the list of raw
    definitions, or None when the closing ')' is never found.
    """
    parts = []
    depth = 0
    closing_quote = None
    piece_start = start
    for i, ch in enumerate(text[start:], start=start):
        if closing_quote is not None:
            if ch == closing_quote:
                closing_quote = None
        elif ch in _SQL_QUOTES:
            closing_quote = _SQL_QUOTES[ch]
        elif ch == "(":
            depth += 1
        elif ch == ")":
            if depth == 0:
                parts.append(text[piece_start:i])
                return parts
            depth -= 1
        elif ch == "," and depth == 0:
            parts.append(text[piece_start:i])
            piece_start = i + 1
    return None


def parse_raw_sql_to_dict(raw_sql):
    """Convert pasted CREATE TABLE statements into ``{table_name: [column, ...]}``.

    Table and column names are returned without their quoting. Table-level
    constraints (PRIMARY KEY (...), FOREIGN KEY, CONSTRAINT, UNIQUE, CHECK) are
    skipped. A trailing semicolon is optional. Statements whose column list is
    never closed are ignored. SQL comments are removed first; note that a
    ``--`` inside a string literal is removed as well.

    Args:
        raw_sql: Text from the schema box; None is treated as empty.

    Returns:
        dict mapping each table name to its column names, in declaration order.
    """
    schema_dict = {}
    text = _SQL_COMMENT_RE.sub(" ", raw_sql or "")
    for match in _CREATE_TABLE_RE.finditer(text):
        definitions = _split_top_level(text, match.end())
        if definitions is None:
            continue  # unbalanced parentheses: truncated statement, skip it
        columns = []
        for definition in definitions:
            definition = definition.strip()
            if not definition or _TABLE_CONSTRAINT_RE.match(definition):
                continue
            columns.append(_unquote_identifier(_COLUMN_NAME_RE.match(definition).group()))
        schema_dict[_unquote_identifier(match.group(1))] = columns
    return schema_dict

# 1. LOAD THE HEAVY AI MODELS ONCE
@st.cache_resource
def load_ai_models():
    """Build the heavy pipeline objects once and return them as a tuple.

    ``st.cache_resource`` makes Streamlit hand back the cached objects on later
    reruns instead of reloading them.

    Returns:
        tuple: ``(retriever, tokenizer, model)``, i.e. Teammate B's
        ``FewShotRetriever`` plus the T5 tokenizer and seq2seq model loaded from
        the ``alpecevit/flan-t5-base-text2sql`` checkpoint.
    """
    # Pre-trained open-source text-to-SQL checkpoint, shared by tokenizer and model.
    checkpoint = "alpecevit/flan-t5-base-text2sql"

    example_retriever = FewShotRetriever()
    sql_tokenizer = T5Tokenizer.from_pretrained(checkpoint)
    sql_model = T5ForConditionalGeneration.from_pretrained(checkpoint)

    return example_retriever, sql_tokenizer, sql_model

retriever, tokenizer, model = load_ai_models()

# Beam width for generation; every beam is returned so the execution checker
# has this many candidate queries to choose from.
NUM_CANDIDATES = 5

# 2. BUILD THE WEBSITE DASHBOARD
st.title("Natural Language to SQL Engine")
st.write("Enter your database schema and question below.")

# Text box for the user to paste their raw CREATE TABLE statements
user_raw_schema = st.text_area(
    "Paste your CREATE TABLE statements here:",
    height=150,
    placeholder="CREATE TABLE employees (id INTEGER, name TEXT);\nCREATE TABLE departments (id INTEGER, location TEXT);",
)

# Text box for the English question
user_question = st.text_input("What do you want to know?", placeholder="e.g., Show me all employees in Chicago")

# The big "Generate" button
if st.button("Generate SQL"):
    question = user_question.strip()

    # Whitespace-only input counts as missing; st.stop() ends this script run.
    if not (user_raw_schema.strip() and question):
        st.error("Please provide both a schema and a question.")
        st.stop()

    # Parse the pasted DDL before any model work so that a schema with no
    # recognisable CREATE TABLE statements fails fast instead of prompting
    # the model with an empty schema.
    schema_dict = parse_raw_sql_to_dict(user_raw_schema)
    if not schema_dict:
        st.error("No CREATE TABLE statements could be found in the schema box.")
        st.stop()

    with st.spinner("Processing through the pipeline..."):
        # 1. Teammate A tags the schema
        tagged_schema = link_schema(question, schema_dict)

        # 2. Few-shot examples from Teammate B's retriever are deliberately kept
        #    out of the prompt (they confuse the pre-trained model), so the
        #    retrieval call is skipped rather than computed and thrown away.
        # TODO: call retriever.get_few_shot_prompt(question) again once a model
        #       trained on that prompt format is in use.

        # 3. Teammate D glues it together for the Prompt
        final_prompt = f"Translate English to SQLite: {question} \nSchema Context: \n{tagged_schema}"

        # 4. Generate NUM_CANDIDATES guesses using beam search
        inputs = tokenizer(final_prompt, return_tensors="pt", max_length=1024, truncation=True)
        outputs = model.generate(
            **inputs,
            max_length=256,
            num_beams=NUM_CANDIDATES,
            num_return_sequences=NUM_CANDIDATES,
        )
        candidate_queries = [tokenizer.decode(out, skip_special_tokens=True) for out in outputs]

        # Keep the raw candidates inspectable (to spot hallucinations) without
        # showing a warning banner on every run.
        with st.expander("DEBUG - AI's raw guesses"):
            st.write(candidate_queries)

        # 5. Teammate C acts as the firewall
        winning_sql = get_best_query(user_raw_schema, candidate_queries)

    # --- STEP C: DISPLAY THE RESULT ---
    # NOTE(review): assumes get_best_query returns a falsy value when no
    # candidate survives; confirm against execution_checker.
    if winning_sql:
        st.success("Query Generated Successfully!")
        st.code(winning_sql, language="sql")
    else:
        st.error("The execution checker did not return a usable query.")