Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import json | |
| import re # <--- Added this to handle reading the text box | |
| from transformers import T5Tokenizer, T5ForConditionalGeneration | |
| # Import the tools from the rest of the team | |
| from schema_linker import link_schema | |
| from few_shot_retriever import FewShotRetriever | |
| from execution_checker import get_best_query | |
| # --- ADDED: Teammate D's Regex Parser --- | |
# Matches the head of a CREATE TABLE statement up to and including the
# parenthesis that opens the column list; group 1 is the bare table name
# (optional "…", `…` or […] quoting is tolerated and not captured).
_CREATE_TABLE_HEAD = re.compile(
    r'CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?["`\[]?(\w+)["`\]]?\s*\(',
    re.IGNORECASE,
)

# Leading keywords of table-level constraint clauses; these entries sit in
# the column list but do not declare a column.
_TABLE_CONSTRAINT_KEYWORDS = frozenset(
    {"PRIMARY", "FOREIGN", "UNIQUE", "CHECK", "CONSTRAINT"}
)


def _balanced_body(sql, open_idx):
    """Return the text between the '(' at ``open_idx`` and its matching ')'.

    Returns None when the parenthesis is never closed.
    """
    depth = 0
    for idx in range(open_idx, len(sql)):
        char = sql[idx]
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                return sql[open_idx + 1:idx]
    return None


def _split_top_level_commas(text):
    """Split ``text`` on commas that are not nested inside parentheses."""
    parts = []
    current = []
    depth = 0
    for char in text:
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
        if char == "," and depth == 0:
            parts.append("".join(current))
            current = []
        else:
            current.append(char)
    parts.append("".join(current))
    return parts


def parse_raw_sql_to_dict(raw_sql):
    """Convert pasted CREATE TABLE statements into ``{table: [columns]}``.

    Compared with a plain ``split(',')`` over a ``...);`` regex match, this:

    * keeps commas inside parentheses together, so ``DECIMAL(10,2)`` does not
      produce a bogus ``2)`` column;
    * ignores table-level constraints (``PRIMARY KEY (...)``,
      ``FOREIGN KEY ...``, ``UNIQUE``, ``CHECK``, ``CONSTRAINT``);
    * accepts statements without a trailing semicolon, with
      ``IF NOT EXISTS``, and with quoted table/column identifiers.

    Limitations: commas or parentheses inside string literals (e.g. a
    ``DEFAULT 'a,b'`` clause) and SQL comments are not specially handled.

    Args:
        raw_sql: Text containing zero or more CREATE TABLE statements.
            ``None`` or an empty string yields an empty dict.

    Returns:
        dict mapping each table name to its column names in declaration
        order. If a table name occurs twice, the later definition wins.
    """
    sql = raw_sql or ""
    schema_dict = {}
    for header in _CREATE_TABLE_HEAD.finditer(sql):
        # header.end() - 1 is the index of the '(' that opens the column list.
        body = _balanced_body(sql, header.end() - 1)
        if body is None:
            # Unterminated column list: skip rather than guess.
            continue
        cols = []
        for col_def in _split_top_level_commas(body):
            tokens = col_def.split()
            if not tokens:
                continue
            # "CHECK(x > 0)" has no space before '(', so cut the keyword there.
            leading_word = tokens[0].split("(", 1)[0]
            if leading_word.upper() in _TABLE_CONSTRAINT_KEYWORDS:
                continue
            cols.append(tokens[0].strip('"`[]'))
        schema_dict[header.group(1)] = cols
    return schema_dict
| # 1. LOAD THE HEAVY AI MODELS ONCE | |
@st.cache_resource
def load_ai_models():
    """Build the few-shot retriever, tokenizer and text-to-SQL model.

    Streamlit re-executes the whole script on every widget interaction, so
    an uncached loader would rebuild all three objects on each click.
    ``st.cache_resource`` keeps one shared instance for the life of the
    server process, which is what "load once" requires.

    Returns:
        tuple: ``(retriever, tokenizer, model)`` where ``retriever`` is a
        ``FewShotRetriever``, ``tokenizer`` a ``T5Tokenizer`` and ``model``
        a ``T5ForConditionalGeneration``; the latter two are loaded from
        the same pre-trained checkpoint.
    """
    # Teammate B's retriever.
    retriever = FewShotRetriever()

    # Pre-trained open-source text-to-SQL checkpoint; named once so the
    # tokenizer and the model cannot drift apart.
    checkpoint = "alpecevit/flan-t5-base-text2sql"
    tokenizer = T5Tokenizer.from_pretrained(checkpoint)
    model = T5ForConditionalGeneration.from_pretrained(checkpoint)
    return retriever, tokenizer, model
# Module-level handles used by the pipeline below.
retriever, tokenizer, model = load_ai_models()

# 2. BUILD THE WEBSITE DASHBOARD
st.title("Natural Language to SQL Engine")
st.write("Enter your database schema and question below.")

# Text box for the user to paste their raw CREATE TABLE statements
user_raw_schema = st.text_area(
    "Paste your CREATE TABLE statements here:",
    height=150,
    placeholder="CREATE TABLE employees (id INTEGER, name TEXT);\nCREATE TABLE departments (id INTEGER, location TEXT);",
)

# Text box for the English question
user_question = st.text_input(
    "What do you want to know?",
    placeholder="e.g., Show me all employees in Chicago",
)

# The big "Generate" button
if st.button("Generate SQL"):
    # Whitespace-only input counts as missing.
    if not (user_raw_schema.strip() and user_question.strip()):
        st.error("Please provide both a schema and a question.")
    else:
        # Dynamically read whatever the user pasted.
        schema_dict = parse_raw_sql_to_dict(user_raw_schema)

        if not schema_dict:
            # Without any parsed table the model would get an empty schema
            # context, so stop here with an actionable message instead.
            st.error(
                "Could not find any CREATE TABLE statements in the schema. "
                "Please check the syntax."
            )
        else:
            with st.spinner("Processing through the pipeline..."):
                # 1. Teammate A tags the schema
                tagged_schema = link_schema(user_question, schema_dict)

                # 2. Teammate B gets the cheat sheet. Intentionally unused for
                #    now: it is kept for when the custom model is ready, and
                #    left out of the prompt so the pre-trained model is not
                #    confused by it.
                few_shot_examples = retriever.get_few_shot_prompt(user_question)

                # 3. Teammate D glues it together for the prompt
                final_prompt = f"Translate English to SQLite: {user_question} \nSchema Context: \n{tagged_schema}"

                # 4. Generate 5 guesses using beam search
                inputs = tokenizer(
                    final_prompt,
                    return_tensors="pt",
                    max_length=1024,
                    truncation=True,
                )
                outputs = model.generate(
                    **inputs,
                    max_length=256,
                    num_beams=5,
                    num_return_sequences=5,
                )
                # Decode once; the result feeds both the debug output and
                # the execution checker.
                candidate_queries = [
                    tokenizer.decode(out, skip_special_tokens=True)
                    for out in outputs
                ]

                # Debug aid: show the raw guesses so hallucinations are visible.
                st.warning(f"DEBUG - AI's raw guesses: {candidate_queries}")

                # 5. Teammate C acts as the firewall
                winning_sql = get_best_query(user_raw_schema, candidate_queries)

            # --- STEP C: DISPLAY THE RESULT ---
            # NOTE(review): assumes get_best_query returns a falsy value when
            # no candidate is usable — confirm against execution_checker.
            if winning_sql:
                st.success("Query Generated Successfully!")
                st.code(winning_sql, language="sql")
            else:
                st.error(
                    "No usable SQL query could be generated. "
                    "Try rephrasing your question."
                )