import json
import re  # used to parse the CREATE TABLE text box

import streamlit as st
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Import the tools from the rest of the team
from schema_linker import link_schema
from few_shot_retriever import FewShotRetriever
from execution_checker import get_best_query


# --- Teammate D's schema parser ---

# Matches the header of a CREATE TABLE statement up to and including the opening
# parenthesis of its column list. Tolerates TEMP/TEMPORARY, IF NOT EXISTS, an
# optional schema qualifier (e.g. "main.") and quoted table names; group 1 is
# the bare table name.
_CREATE_TABLE_RE = re.compile(
    r"CREATE\s+(?:TEMP(?:ORARY)?\s+)?TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?"
    r"(?:[`\"\[]?\w+[`\"\]]?\.)?"
    r"[`\"\[]?(\w+)[`\"\]]?\s*\(",
    re.IGNORECASE,
)

# Leading identifier of one column-list entry: a quoted name ("x", `x`, [x])
# or a bare word that ends at whitespace or an opening parenthesis.
_LEADING_IDENT_RE = re.compile(r'\s*(?:"([^"]+)"|`([^`]+)`|\[([^\]]+)\]|([^\s(]+))')

# Column-list entries starting with one of these (unquoted) keywords are table
# constraints such as "PRIMARY KEY (id)", not column definitions.
_TABLE_CONSTRAINT_KEYWORDS = frozenset({"PRIMARY", "FOREIGN", "UNIQUE", "CHECK", "CONSTRAINT"})

# Quote characters inside which commas and parentheses are not structural.
_QUOTE_CHARS = "'\"`"


def _split_column_list(text, start):
    """Split the parenthesised list whose opening '(' is just before ``text[start]``.

    Scans forward tracking parenthesis depth and quoted strings, splitting on
    top-level commas only, so types like ``DECIMAL(10,2)`` and defaults like
    ``'a,b'`` stay inside a single entry.

    Returns:
        The raw entries as a list of strings, or None if the closing
        parenthesis is never found.
    """
    entries = []
    current = []
    depth = 0
    quote = None
    for ch in text[start:]:
        if quote is not None:
            if ch == quote:
                quote = None
        elif ch in _QUOTE_CHARS:
            quote = ch
        elif ch == "(":
            depth += 1
        elif ch == ")":
            if depth == 0:
                # Matching close of the column list: we're done.
                entries.append("".join(current))
                return entries
            depth -= 1
        elif ch == "," and depth == 0:
            entries.append("".join(current))
            current = []
            continue
        current.append(ch)
    return None


def _column_name(entry):
    """Return the column declared by one column-list *entry*.

    Returns None for blank entries and for table-level constraints. Quoted
    names are returned without their quote characters.
    """
    match = _LEADING_IDENT_RE.match(entry)
    if match is None:
        return None
    quoted = match.group(1) or match.group(2) or match.group(3)
    if quoted:
        # A quoted identifier is always a column, even if it spells a keyword.
        return quoted
    bare = match.group(4)
    if bare.upper() in _TABLE_CONSTRAINT_KEYWORDS:
        return None
    return bare


def parse_raw_sql_to_dict(raw_sql):
    """Convert the pasted CREATE TABLE statements into a Python dictionary.

    Args:
        raw_sql: Text containing one or more CREATE TABLE statements. A
            trailing semicolon is optional.

    Returns:
        dict mapping each table name to the list of its column names, in
        declaration order. Constraint clauses (PRIMARY KEY, FOREIGN KEY,
        UNIQUE, CHECK, CONSTRAINT) are not reported as columns. If a table is
        declared twice, the last declaration wins. Statements with unbalanced
        parentheses are skipped. Empty/None input yields an empty dict.
    """
    schema_dict = {}
    if not raw_sql:
        return schema_dict
    for header in _CREATE_TABLE_RE.finditer(raw_sql):
        entries = _split_column_list(raw_sql, header.end())
        if entries is None:
            # Unterminated column list: skip rather than guess at its contents.
            continue
        schema_dict[header.group(1)] = [
            name for name in map(_column_name, entries) if name is not None
        ]
    return schema_dict


# Pre-trained open-source text-to-SQL checkpoint on the Hugging Face Hub.
MODEL_NAME = "alpecevit/flan-t5-base-text2sql"


# 1. LOAD THE HEAVY AI MODELS ONCE
@st.cache_resource
def load_ai_models():
    """Load the few-shot retriever and the pre-trained T5 text-to-SQL model.

    ``st.cache_resource`` caches the returned objects across Streamlit reruns,
    so the weights are not reloaded every time the user clicks the button.

    Returns:
        tuple: ``(retriever, tokenizer, model)``.
    """
    retriever = FewShotRetriever()  # Teammate B's retriever
    tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
    model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
    return retriever, tokenizer, model


retriever, tokenizer, model = load_ai_models()
# 2. BUILD THE WEBSITE DASHBOARD
st.title("Natural Language to SQL Engine")
st.write("Enter your database schema and question below.")

# Text box for the user to paste their raw CREATE TABLE statements
user_raw_schema = st.text_area(
    "Paste your CREATE TABLE statements here:",
    height=150,
    placeholder=(
        "CREATE TABLE employees (id INTEGER, name TEXT);\n"
        "CREATE TABLE departments (id INTEGER, location TEXT);"
    ),
)

# Text box for the English question
user_question = st.text_input(
    "What do you want to know?",
    placeholder="e.g., Show me all employees in Chicago",
)

# Number of beam-search candidates handed to the execution checker.
NUM_CANDIDATES = 5


def _run_pipeline(raw_schema, question):
    """Translate *question* into SQL for *raw_schema* and render the result on the page.

    Pipeline: parse the schema box, tag it with the schema linker (Teammate A),
    fetch few-shot examples (Teammate B), generate NUM_CANDIDATES queries with
    beam search, then let the execution checker (Teammate C) pick the winner.
    Problems are reported with ``st.error`` and stop the pipeline early.
    """
    # 1. Dynamically parse whatever the user pasted into {table: [columns]}.
    schema_dict = parse_raw_sql_to_dict(raw_schema)
    if not schema_dict:
        st.error("No CREATE TABLE statements were recognized in the schema box.")
        return

    # 2. Teammate A tags the schema
    tagged_schema = link_schema(question, schema_dict)

    # 3. Teammate B gets the cheat sheet. Deliberately NOT put in the prompt:
    #    the pre-trained model gets confused by it. Kept for a fine-tuned model.
    few_shot_examples = retriever.get_few_shot_prompt(question)  # noqa: F841

    # 4. Teammate D glues the prompt together in the format the model expects.
    final_prompt = f"Translate English to SQLite: {question} \nSchema Context: \n{tagged_schema}"

    # 5. Generate NUM_CANDIDATES guesses using beam search.
    inputs = tokenizer(final_prompt, return_tensors="pt", max_length=1024, truncation=True)
    outputs = model.generate(
        **inputs,
        max_length=256,
        num_beams=NUM_CANDIDATES,
        num_return_sequences=NUM_CANDIDATES,
    )
    candidate_queries = [tokenizer.decode(out, skip_special_tokens=True) for out in outputs]

    # Debug view of the raw guesses (handy for spotting hallucinated tables or
    # columns), collapsed so it doesn't compete with the final answer.
    with st.expander("DEBUG - AI's raw guesses"):
        for candidate in candidate_queries:
            st.code(candidate, language="sql")

    # 6. Teammate C acts as the firewall
    winning_sql = get_best_query(raw_schema, candidate_queries)
    # NOTE(review): assumes get_best_query returns a falsy value when no
    # candidate survives — confirm against execution_checker.
    if not winning_sql:
        st.error("The execution checker did not return a query for this question.")
        return

    # --- DISPLAY THE RESULT ---
    st.success("Query Generated Successfully!")
    st.code(winning_sql, language="sql")


# The big "Generate" button
if st.button("Generate SQL"):
    # Whitespace-only input counts as missing.
    schema_text = (user_raw_schema or "").strip()
    question_text = (user_question or "").strip()
    if schema_text and question_text:
        with st.spinner("Processing through the pipeline..."):
            _run_pipeline(schema_text, question_text)
    else:
        st.error("Please provide both a schema and a question.")