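# ai_eee_sql_gen / app.py
# Hugging Face file-page residue kept for provenance:
#   uploader avatar: "laudes's picture"; commit message: "Upload 8 files";
#   commit 2cb3f69 (verified)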
import os
import gradio as gr
import logging
import yaml
from db import (
get_last_50_saved_queries,
initialize_local_db,
export_saved_queries_to_csv,
execute_sql_query,
fetch_and_save_schema,
show_last_50_saved_queries,
fetch_schema_info, # Now this function exists in db.py
)
from openai_integration import generate_sql_single_call # Import the updated function
# Configure the root logger at import time: INFO level, with each record
# formatted as "<timestamp> - <level> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Run the local-database initialisation imported from db.py before the UI is
# built (per the original note, this ensures the table is created).
initialize_local_db()
# Function to handle user query input and SQL generation with progress
def query_database(nl_query, progress=gr.Progress()):
    """Turn a natural-language query into SQL, run it, and report the outcome.

    Args:
        nl_query: The natural-language text entered by the user.
        progress: Gradio progress tracker. The ``gr.Progress()`` default is
            kept in the signature on purpose, as that is how the original
            requests progress reporting from Gradio.

    Returns:
        A 4-tuple ``(reformulated_query, sql_query, execution_result,
        total_cost_per_call)``, one value per wired output component.
        ``execution_result`` is the list returned by ``execute_sql_query`` or
        a single-cell placeholder table when there is nothing to show. On any
        exception a 4-tuple of fallback values is returned instead.
    """
    try:
        progress(0, desc="Starting Query Process")

        # A single call yields the reformulated query, the SQL and the cost.
        progress(0.5, desc="Generating Reformulated Query and SQL")
        reformulated_query, sql_query, total_cost_per_call = generate_sql_single_call(nl_query)

        # Placeholder table shown whenever no usable rows are available.
        execution_result = [["No results available."]]

        # Only execute when a query was produced and it is not an error string.
        if sql_query and not sql_query.startswith("Error"):
            progress(0.8, desc="Executing SQL Query")
            rows = execute_sql_query(sql_query)
            # Keep the placeholder unless we got a non-empty list of rows.
            if isinstance(rows, list) and rows:
                execution_result = rows

        progress(1, desc="Query Completed")
        return reformulated_query, sql_query, execution_result, total_cost_per_call
    except Exception as e:
        logging.error(f"Error during query generation or execution: {e}")
        # Must be four values to match the success path: the original nested
        # the trailing "" inside the table list and returned only three.
        return (
            "Error during query processing.",
            "",
            [["No results available due to an error."]],
            "",
        )
# Function to update the schema when requested
def update_schema():
    """Fetch the latest schema via ``fetch_and_save_schema`` and report status.

    Returns:
        A 2-tuple of the success message and the result of the ``gr.Info``
        call that shows the notification.

    Raises:
        gr.Error: If no schema data came back, or if the returned schema
            contains an ``"error"`` entry.
    """
    schema_info = fetch_and_save_schema()

    # Check for "nothing returned" first: evaluating `"error" in None` would
    # raise a TypeError instead of showing the intended message.
    if not schema_info:  # Empty dictionary or None
        raise gr.Error("No schema data was returned. The schema is empty.", duration=3)

    # The fetch signalled a failure through an "error" entry.
    if "error" in schema_info:
        raise gr.Error("Error fetching schema from the database.", duration=3)

    # Schema successfully fetched.
    return "Schema updated successfully", gr.Info("DB Schema Updated ℹ️", duration=3)
# Function to make hidden components visible after the process
def continue_process():
    """Return three "make visible" updates: SQL, result table and cost."""
    revealed_outputs = 3
    # A fresh update object is built for every output component.
    return tuple(gr.update(visible=True) for _ in range(revealed_outputs))
# Function to reset the interface to its initial state
def reset_interface():
    """Return the five updates that put the interface back to its start state.

    In order: clear two text values, hide two components, and make the last
    component non-interactive.
    """
    clear_first_text = gr.update(value="")
    clear_second_text = gr.update(value="")
    hide_first = gr.update(visible=False)
    hide_second = gr.update(visible=False)
    disable_last = gr.update(interactive=False)
    return clear_first_text, clear_second_text, hide_first, hide_second, disable_last
# Enable the submit button only when text is entered
def update_button_state(text):
    """Return an update that is interactive only if *text* has non-blank content."""
    has_content = bool(text.strip())
    return gr.update(interactive=has_content)
# Function to fetch table names from schema and format them for display
def get_table_names():
    """Return ``(original_name, formatted_name)`` pairs for every schema table.

    The formatted name splits the table name on underscores, capitalises each
    part and joins the parts with spaces. An empty list is returned when
    ``fetch_schema_info`` yields nothing.
    """
    schema = fetch_schema_info()
    if not schema:
        return []

    name_pairs = []
    for raw_name in schema.keys():
        display_name = ' '.join(part.capitalize() for part in raw_name.split('_'))
        name_pairs.append((raw_name, display_name))
    return name_pairs
# Fetch the table names once, at import time, as a list of
# (original_name, formatted_name) tuples.
table_names = get_table_names()
# Function to update the query textbox when a button is clicked
def insert_table_name(current_text, table_name):
    """Append *table_name* to the query text, separated by a single space.

    Args:
        current_text: Current content of the query textbox; may be empty or
            ``None``.
        table_name: Table name to add.

    Returns:
        ``table_name`` alone when there is no existing text (avoiding a
        leading space, or a ``TypeError`` on ``None``), otherwise
        ``current_text`` followed by a space and ``table_name``.
    """
    if not current_text:
        return table_name
    return current_text + " " + table_name
# Function to load examples from YAML file
def load_examples_from_yaml(file_path):
    """Load the example definitions from a YAML file.

    Args:
        file_path: Path of the YAML file to read.

    Returns:
        The parsed YAML content, or an empty list when the file cannot be
        read or parsed, or when it contains no data.
    """
    try:
        # Explicit encoding so the result does not depend on the platform's
        # default locale encoding.
        with open(file_path, 'r', encoding='utf-8') as file:
            examples = yaml.safe_load(file)
    except Exception as e:
        logging.error(f"Error loading examples: {e}")
        return []
    # safe_load returns None for an empty document; callers iterate over the
    # result, so normalise "no data" to an empty list.
    return examples or []
# Read the examples from the examples.yaml file located next to this module
EXAMPLES_FILE_PATH = os.path.join(os.path.dirname(__file__), 'examples.yaml')
examples_list = load_examples_from_yaml(EXAMPLES_FILE_PATH)
# Keep only the 'input' value of every example for the Gradio examples
example_inputs = [entry['input'] for entry in examples_list]
# One numeric label per example, counting from 1 ("1", "2", "3", ...)
example_labels = [str(number) for number, _ in enumerate(example_inputs, start=1)]
# Gradio interface setup
# Layout, top to bottom: query textbox, one button per database table,
# example queries, the generated outputs, submit/reset buttons, the saved
# queries table, the CSV download, the schema refresh button and the cost panel.
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Ubuntu"), "Arial", "sans-serif"], text_size='sm')) as ydcoza_face:
    text_input = gr.Textbox(lines=2, label="Text Query")
    gr.HTML("\n<p>Database Tables:</p>\n")

    # One button per table: the label is the formatted name, and a click
    # inserts the original table name into the query textbox.
    with gr.Row():
        for original_name, formatted_name in table_names:
            # size="sm": Gradio documents "sm"/"lg" as the button sizes; the
            # previous value "small" is not one of them.
            gr.Button(formatted_name, size="sm", elem_classes="ydcoza-small-button").click(
                # t=original_name binds the current loop value to this button.
                fn=lambda current_text, t=original_name: insert_table_name(current_text, t),
                inputs=text_input,
                outputs=text_input,
            )

    # Example queries loaded from the YAML file, shown with numbered labels.
    examples = gr.Examples(
        examples=example_inputs,  # The actual inputs from the YAML file
        example_labels=example_labels,  # Numbered labels for buttons
        label="Demo Natural Language Queries",
        inputs=[text_input],
    )

    reformulated_output = gr.Textbox(lines=2, label="Optimised Query", elem_id='ydcoza_markdown_output_desc')
    sql_output = gr.Code(label="Generated SQL", visible=False)
    # Dataframe for SQL results
    sql_result_output = gr.Dataframe(label="Query Results", elem_id='result_output', visible=False)
    # Disabled until the user has typed something (see text_input.change below).
    start_button = gr.Button("Submit Text Query", elem_id='ydcoza_gradio_button', interactive=False)

    # Reset button: clears the texts, hides SQL/results, disables submit.
    reset_button = gr.Button("Reset Interface", elem_id='ydcoza_gradio_button_reset')
    reset_button.click(
        fn=reset_interface,
        inputs=[],
        outputs=[text_input, reformulated_output, sql_output, sql_result_output, start_button],
    )

    gr.HTML('\n<span class="ydcoza_gradio_banner">View The last 50 Queries generated in Table format.</span>\n')
    saved_queries_output = gr.Dataframe(
        label="Last 50 Saved Queries",
        headers=["Query", "Optimised Query", "SQL", "Timestamp"],
        interactive=True,
        visible=False,
    )
    # Load the last 50 saved queries on click, then reveal the table.
    show_saved_queries_button = gr.Button("View Queries", elem_id='ydcoza_gradio_button')
    show_saved_queries_button.click(show_last_50_saved_queries, outputs=saved_queries_output).then(
        lambda: gr.update(visible=True), outputs=saved_queries_output
    )

    gr.HTML('\n<span class="ydcoza_gradio_banner">Download the generated Queries in .csv for you to explore.</span>\n')
    csv_file_output = gr.File(label="Download CSV", visible=False)  # Initially hidden
    download_csv_button = gr.Button("Download Queries", elem_id='ydcoza_gradio_button')
    # Export to CSV on click, then reveal the download component.
    download_csv_button.click(export_saved_queries_to_csv, outputs=csv_file_output).then(
        lambda: gr.update(visible=True), outputs=csv_file_output
    )

    gr.HTML('\n<span class="ydcoza_gradio_banner">If you made changes to the database structure we need to import the latest DB Schema.</span>\n')
    # Button that pulls the latest schema through update_schema.
    fetch_schema_button = gr.Button("Fetch Latest Schema", elem_id='ydcoza_gradio_button')
    fetch_schema_button.click(update_schema)

    # Output for the cost information (initially hidden)
    with gr.Row():
        html_output_cost = gr.HTML(elem_id='ydcoza_cost_output', visible=False)

    # Resetting must also clear and hide the cost panel, otherwise the cost of
    # the previous query stays on screen. It is wired here because the panel
    # is only created above, after the main reset handler was registered.
    reset_button.click(
        fn=lambda: gr.update(value="", visible=False),
        inputs=[],
        outputs=html_output_cost,
    )

    # Enable the submit button only while the textbox holds non-blank text.
    text_input.change(fn=update_button_state, inputs=text_input, outputs=start_button)

    # Submit: run the query, reveal SQL/results/cost, then disable the button.
    start_button.click(
        fn=query_database,
        inputs=[text_input],
        outputs=[reformulated_output, sql_output, sql_result_output, html_output_cost],
    ).then(
        continue_process,
        outputs=[sql_output, sql_result_output, html_output_cost],
    ).then(
        lambda: gr.update(interactive=False), outputs=start_button
    )
# Launch the Gradio interface
if __name__ == "__main__":
    # Start the Gradio app only when this file is run as a script.
    ydcoza_face.launch()