nisar9034 committed on
Commit
5e468f2
·
verified ·
1 Parent(s): 435f01f

Upload 5 files

Browse files
Files changed (5) hide show
  1. app.py +97 -0
  2. execution_checker.py +52 -0
  3. few_shot_retriever.py +71 -0
  4. requirements.txt +5 -3
  5. schema_linker.py +76 -0
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import json
3
+ import re # <--- Added this to handle reading the text box
4
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
5
+
6
+ # Import the tools from the rest of the team
7
+ from schema_linker import link_schema
8
+ from few_shot_retriever import FewShotRetriever
9
+ from execution_checker import get_best_query
10
+
11
+ # --- ADDED: Teammate D's Regex Parser ---
12
def parse_raw_sql_to_dict(raw_sql):
    """Extract ``{table_name: [column_name, ...]}`` from CREATE TABLE text.

    Compared with a plain "split on commas" approach this parser:

    * splits column definitions only on *top-level* commas, so types such as
      ``DECIMAL(10,2)`` stay in one piece;
    * skips table-level constraints (``PRIMARY KEY (...)``, ``FOREIGN KEY``,
      ``UNIQUE``, ``CHECK``, ``CONSTRAINT``) instead of reporting them as
      columns;
    * accepts ``IF NOT EXISTS`` / ``TEMP`` and identifiers quoted with
      ``"``, `````, or ``[...]``;
    * does not require a trailing ``;`` after the closing parenthesis.

    Args:
        raw_sql: Text containing zero or more CREATE TABLE statements.
            ``None`` or an empty string yields an empty dict.

    Returns:
        dict mapping each table name to the list of its column names, in
        declaration order. A table whose parenthesised body is never closed
        is skipped. If a table name appears twice, the last one wins.
    """
    header_re = re.compile(
        r'CREATE\s+(?:TEMP(?:ORARY)?\s+)?TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?'
        r'([`"\[]?\w+[`"\]]?)\s*\(',
        re.IGNORECASE,
    )
    # First keyword of a table-level constraint clause. A column that really
    # carries one of these names has to be quoted in SQL, and a quoted token
    # never equals the bare keyword, so it is still treated as a column.
    constraint_starts = {"CONSTRAINT", "PRIMARY", "FOREIGN", "UNIQUE", "CHECK"}
    quote_chars = '`"[]'

    text = raw_sql or ""
    schema_dict = {}

    for header in header_re.finditer(text):
        # Walk the body character by character, tracking parenthesis depth so
        # commas nested inside a type or constraint do not split a definition.
        depth = 1
        definitions = []
        current = []
        closed = False
        for pos in range(header.end(), len(text)):
            ch = text[pos]
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
                if depth == 0:
                    closed = True
                    break
            if ch == "," and depth == 1:
                definitions.append("".join(current))
                current = []
            else:
                current.append(ch)

        if not closed:
            # Unbalanced parentheses: the statement is incomplete, skip it.
            continue
        definitions.append("".join(current))

        cols = []
        for definition in definitions:
            tokens = definition.split()
            if not tokens:
                continue
            first = tokens[0]
            if first.upper() in constraint_starts:
                continue
            cols.append(first.strip(quote_chars))

        schema_dict[header.group(1).strip(quote_chars)] = cols

    return schema_dict
25
+
26
+ # 1. LOAD THE HEAVY AI MODELS ONCE
27
@st.cache_resource
def load_ai_models():
    """Create the few-shot retriever plus the text-to-SQL tokenizer and model.

    The function is wrapped in ``st.cache_resource`` so the heavy objects are
    meant to be built once rather than on every script rerun.

    Returns:
        tuple: ``(retriever, tokenizer, model)`` in that order.
    """
    # Pre-trained open-source checkpoint shared by tokenizer and model.
    checkpoint = "alpecevit/flan-t5-base-text2sql"

    return (
        FewShotRetriever(),
        T5Tokenizer.from_pretrained(checkpoint),
        T5ForConditionalGeneration.from_pretrained(checkpoint),
    )
37
+
38
retriever, tokenizer, model = load_ai_models()

# 2. BUILD THE WEBSITE DASHBOARD
st.title("Natural Language to SQL Engine")
st.write("Enter your database schema and question below.")

# Text box for the user to paste their raw CREATE TABLE statements
user_raw_schema = st.text_area(
    "Paste your CREATE TABLE statements here:",
    height=150,
    placeholder="CREATE TABLE employees (id INTEGER, name TEXT);\nCREATE TABLE departments (id INTEGER, location TEXT);",
)

# Text box for the English question
user_question = st.text_input(
    "What do you want to know?",
    placeholder="e.g., Show me all employees in Chicago",
)

# The big "Generate" button
if st.button("Generate SQL"):
    # Whitespace-only input counts as missing.
    if (user_raw_schema or "").strip() and (user_question or "").strip():

        with st.spinner("Processing through the pipeline..."):
            # Parse whatever CREATE TABLE text the user pasted.
            schema_dict = parse_raw_sql_to_dict(user_raw_schema)
            if not schema_dict:
                # Non-blocking: the SQLite check below still validates the
                # raw schema, but the prompt will carry no schema context.
                st.warning(
                    "No tables could be parsed from the schema; "
                    "the model will run without schema context."
                )

            # 1. Teammate A tags the schema
            tagged_schema = link_schema(user_question, schema_dict)

            # 2. Teammate B gets the cheat sheet. It is deliberately NOT put
            #    into the prompt yet, so the pre-trained model is not confused;
            #    it is kept for when the custom model is ready.
            few_shot_examples = retriever.get_few_shot_prompt(user_question)

            # 3. Teammate D glues it together for the prompt
            final_prompt = f"Translate English to SQLite: {user_question} \nSchema Context: \n{tagged_schema}"

            # 4. Generate 5 guesses using beam search
            inputs = tokenizer(final_prompt, return_tensors="pt", max_length=1024, truncation=True)
            outputs = model.generate(
                **inputs,
                max_length=256,
                num_beams=5,
                num_return_sequences=5,
            )

            # Decode once (the original decoded the same outputs twice).
            candidate_queries = [tokenizer.decode(out, skip_special_tokens=True) for out in outputs]

            # DEBUG: show the raw guesses so hallucinations are visible.
            st.warning(f"DEBUG - AI's raw guesses: {candidate_queries}")

            # 5. Teammate C acts as the firewall
            winning_sql = get_best_query(user_raw_schema, candidate_queries)

        # --- DISPLAY THE RESULT ---
        # get_best_query reports failures as plain strings starting with
        # "Error:"; those must not be presented as a successful query.
        if winning_sql.startswith("Error:"):
            st.error(winning_sql)
        else:
            st.success("Query Generated Successfully!")
            st.code(winning_sql, language="sql")

    else:
        st.error("Please provide both a schema and a question.")
execution_checker.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+
3
def get_best_query(schema_create_statements, candidate_queries):
    """Return the first candidate query that SQLite can compile.

    An in-memory SQLite database is created, the user's schema is executed
    to build the (empty) tables, and each candidate is prepared with
    ``EXPLAIN`` so it is compiled against the schema without being run.

    Args:
        schema_create_statements: SQL script with the CREATE TABLE statements.
        candidate_queries: Iterable of SQL strings, tried in order.

    Returns:
        str: ``"-- Selected Candidate #N (Syntax Valid)\\n<query>"`` for the
        first candidate that compiles, or a message starting with
        ``"Error:"`` when the schema is invalid or every candidate fails.
    """
    def _deny_attach(action, _arg1, _arg2, _db_name, _source):
        # The schema text comes straight from the user; refuse ATTACH so it
        # cannot open or create database files outside the in-memory DB.
        if action == sqlite3.SQLITE_ATTACH:
            return sqlite3.SQLITE_DENY
        return sqlite3.SQLITE_OK

    # sqlite3.Warning is not a subclass of sqlite3.Error. Older Python
    # versions raise it for multi-statement input; newer ones raise
    # ProgrammingError. Catch both so one bad candidate cannot crash the app.
    sqlite_failures = (sqlite3.Error, sqlite3.Warning)

    # 1. Temporary database in RAM; always closed by the ``finally`` below,
    #    including on the early "invalid schema" return.
    conn = sqlite3.connect(':memory:')
    try:
        conn.set_authorizer(_deny_attach)
        cursor = conn.cursor()

        # 2. Build the empty tables using the user's schema
        try:
            cursor.executescript(schema_create_statements)
        except sqlite_failures as e:
            return f"Error: The provided schema is invalid. ({e})"

        # 3. Test the AI's candidate queries
        for i, query in enumerate(candidate_queries):
            try:
                # EXPLAIN compiles the statement without querying data.
                cursor.execute(f"EXPLAIN {query}")
            except sqlite_failures as e:
                # Rejected by SQLite: log it and try the next candidate.
                print(f"Candidate {i+1} failed syntax check: {e}")
                continue
            return f"-- Selected Candidate #{i+1} (Syntax Valid)\n{query}"

        return "Error: All generated queries contained syntax errors."
    finally:
        conn.close()
33
+
34
+ # --- TESTING BLOCK ---
35
if __name__ == "__main__":
    # Raw CREATE statements, in the form the frontend hands them over.
    demo_schema = """
    CREATE TABLE employees (id INTEGER PRIMARY KEY, name TEXT, salary REAL);
    CREATE TABLE departments (id INTEGER PRIMARY KEY, name TEXT);
    """

    # Stand-ins for model output. The first two are rejected on purpose
    # (unknown function, unknown table); the third one compiles.
    bad_function = "SELECT SUMM(salary) FROM employees"
    bad_table = "SELECT sum(salary) FROM employees JOIN bad_table"
    good_query = "SELECT sum(salary) FROM employees"
    demo_candidates = [bad_function, bad_table, good_query]

    print("Testing AI Candidates against In-Memory DB...\n")
    winner = get_best_query(demo_schema, demo_candidates)
    print(f"\nWinning Query to show the user:\n{winner}")
52
+
few_shot_retriever.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from sentence_transformers import SentenceTransformer
3
+ from sklearn.metrics.pairwise import cosine_similarity
4
+
5
class FewShotRetriever:
    """Retrieve the stored question/SQL pairs most similar to a new question.

    Questions are embedded with a sentence-embedding model. The stored
    embeddings are computed once at construction time and compared with the
    embedding of each incoming question.
    """

    DEFAULT_MODEL_NAME = "BAAI/bge-base-en-v1.5"

    # Built-in answer bank used when the caller supplies no examples.
    DEFAULT_EXAMPLES = (
        {"q": "What is the average salary of IT staff?", "sql": "SELECT avg(salary) FROM staff WHERE dept = 'IT'"},
        {"q": "Count the number of patients in the ICU.", "sql": "SELECT count(*) FROM patients WHERE ward = 'ICU'"},
        {"q": "Show me the total budget for the marketing department.", "sql": "SELECT sum(budget) FROM departments WHERE name = 'Marketing'"},
        {"q": "Find the average age of all employees.", "sql": "SELECT avg(age) FROM employees"},
        {"q": "How many marketing staff earn more than 50000?", "sql": "SELECT count(*) FROM staff WHERE dept = 'Marketing' AND salary > 50000"},
    )

    def __init__(self, historical_data=None, model_name=DEFAULT_MODEL_NAME):
        """Load the embedding model and pre-compute the example embeddings.

        Args:
            historical_data: Optional iterable of ``{"q": ..., "sql": ...}``
                dicts (for instance loaded from a JSON file). When omitted,
                ``DEFAULT_EXAMPLES`` is used.
            model_name: Name of the sentence-transformers model to load.
        """
        # 1. LOAD THE EMBEDDING MODEL
        print("Loading BGE Model (This might take a minute the first time)...")
        self.model = SentenceTransformer(model_name)

        # 2. LOAD THE HISTORICAL DATA (The Answer Bank)
        source = self.DEFAULT_EXAMPLES if historical_data is None else historical_data
        self.historical_data = [dict(item) for item in source]

        # 3. PRE-CALCULATE THE VECTORS (once, not per request)
        historical_questions = [item["q"] for item in self.historical_data]
        # With an empty bank there is nothing to embed; get_few_shot_prompt
        # returns early in that case and never reads this attribute.
        self.historical_embs = (
            self.model.encode(historical_questions, normalize_embeddings=True)
            if historical_questions
            else None
        )

    def get_few_shot_prompt(self, user_query, top_k=2):
        """Format the ``top_k`` most similar stored examples as prompt text.

        Args:
            user_query: The new natural-language question.
            top_k: Maximum number of examples to include. Values larger than
                the number of stored examples are clamped.

        Returns:
            str: A header line followed by one "Example Question / Example
            SQL" pair per selected example, most similar first. An empty
            string is returned when no example can be selected (``top_k`` is
            zero or negative, or the bank is empty).
        """
        # Clamp first: a raw ``[-0:]`` slice would select EVERY example.
        limit = min(int(top_k), len(self.historical_data))
        if limit <= 0:
            return ""

        # 1. Embed the new question (a batch of one row).
        q_emb = self.model.encode([user_query], normalize_embeddings=True)

        # 2. Both sides are requested with normalize_embeddings=True, so the
        #    dot product of the vectors is their cosine similarity.
        scores = self.historical_embs @ q_emb[0]

        # 3. argsort() is ascending: take the last ``limit`` and reverse.
        top_indices = scores.argsort()[-limit:][::-1]

        # 4. Build the text block for the T5 model.
        parts = ["Here are some examples of translating English to SQL:\n\n"]
        for idx in top_indices:
            past_example = self.historical_data[idx]
            parts.append(
                f"Example Question: {past_example['q']}\n"
                f"Example SQL: {past_example['sql']}\n\n"
            )
        return "".join(parts)
53
+
54
+
55
+ # --- TESTING BLOCK ---
56
if __name__ == "__main__":
    # Constructing the retriever loads the embedding model.
    demo_retriever = FewShotRetriever()

    # A question that is not part of the stored examples.
    demo_question = "What is the average salary of the sales team?"

    print("\n--- INPUT ---")
    print(f"New User Question: {demo_question}")

    print("\n--- YOUR OUTPUT (The Cheat Sheet) ---")
    # Show the two stored examples closest to the question.
    print(demo_retriever.get_few_shot_prompt(demo_question, top_k=2))
71
+
requirements.txt CHANGED
@@ -1,3 +1,5 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
1
+ streamlit
2
+ transformers
3
+ torch
4
+ sentence-transformers
5
+ scikit-learn
schema_linker.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import string
2
+
3
def link_schema(user_query, raw_schema):
    """Tag schema columns whose name appears as a word in the question.

    The question is lower-cased and stripped of punctuation, then split on
    whitespace. Underscores are kept so that snake_case column names such as
    ``department_id`` can still match when the user types them verbatim.

    Args:
        user_query: The natural-language question.
        raw_schema: dict mapping table name -> list of column names.

    Returns:
        str: One line per table in the form
        ``Table: <name> | Cols: <col>, <col> (Exact Match: "<col>"), ...``,
        joined with ``" \\n"``. An empty schema yields an empty string.
    """
    # 1. CLEAN THE QUERY
    # Remove every punctuation mark except "_" in one pass.
    removable = string.punctuation.replace("_", "")
    cleaned_query = user_query.lower().translate(str.maketrans("", "", removable))

    # A set gives O(1) membership tests for each column below.
    query_words = set(cleaned_query.split())

    # 2. TAG EACH TABLE'S COLUMNS
    linked_schema_lines = []
    for table_name, column_list in raw_schema.items():
        tagged_columns = [
            # Compare case-insensitively, but print the column as declared.
            f'{col} (Exact Match: "{col}")' if col.lower() in query_words else col
            for col in column_list
        ]
        formatted_cols = ", ".join(tagged_columns)
        linked_schema_lines.append(f"Table: {table_name} | Cols: {formatted_cols}")

    # 3. RETURN THE RESULT
    return " \n".join(linked_schema_lines)
57
+
58
+
59
+ # --- TESTING BLOCK ---
60
if __name__ == "__main__":
    # Stand-in for the question the frontend UI passes in.
    demo_question = "What is the location and budget for the marketing department?"

    # Stand-in for a parsed database schema.
    demo_schema = dict(
        employees=["id", "name", "department_id", "salary"],
        departments=["id", "name", "location", "budget", "industry"],
    )

    print("--- INPUTS ---")
    print(f"Question: {demo_question}")

    print("\n--- YOUR OUTPUT ---")
    print(link_schema(demo_question, demo_schema))