import logging
import os

import gradio as gr
import yaml

from db import (
    get_last_50_saved_queries,
    initialize_local_db,
    export_saved_queries_to_csv,
    execute_sql_query,
    fetch_and_save_schema,
    show_last_50_saved_queries,
    fetch_schema_info,
)
from openai_integration import generate_sql_single_call

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Ensure the local database table is created before the UI is built.
initialize_local_db()


def query_database(nl_query, progress=gr.Progress()):
    """Turn a natural-language query into SQL, execute it, and return the UI outputs.

    The ``gr.Progress()`` default is Gradio's documented way of requesting a
    progress tracker; Gradio injects it at call time.

    Args:
        nl_query: The user's natural-language query text.
        progress: Gradio progress tracker used to report stage updates.

    Returns:
        A 4-tuple ``(reformulated_query, sql_query, execution_result, cost)``
        for the reformulated-query, SQL, results-Dataframe and cost outputs.
        ``execution_result`` is always a non-empty list of rows; a one-cell
        placeholder row is used when nothing was returned. If any step raises,
        an error message, an empty SQL string, a placeholder row and an empty
        cost string are returned instead, so every output still gets a value.
    """
    try:
        progress(0, desc="Starting Query Process")

        # A single model call yields both the reformulated query and the SQL.
        progress(0.5, desc="Generating Reformulated Query and SQL")
        reformulated_query, sql_query, total_cost_per_call = generate_sql_single_call(nl_query)

        # Only execute when generation produced SQL rather than an "Error..." string.
        if sql_query and not sql_query.startswith("Error"):
            progress(0.8, desc="Executing SQL Query")
            execution_result = execute_sql_query(sql_query)
            # The Dataframe needs a non-empty list of rows; otherwise show a placeholder.
            if not isinstance(execution_result, list) or not execution_result:
                execution_result = [["No results available."]]
        else:
            execution_result = [["No results available."]]

        progress(1, desc="Query Completed")
        return reformulated_query, sql_query, execution_result, total_cost_per_call
    except Exception as e:
        # Top-level UI boundary: log with traceback and still return one value
        # per wired output. (The previous version misplaced a bracket and
        # returned only three values for four outputs.)
        logging.exception("Error during query generation or execution: %s", e)
        return (
            "Error during query processing.",
            "",
            [["No results available due to an error."]],
            "",
        )


def update_schema():
    """Fetch the latest database schema and report the outcome to the user.

    Returns:
        ``("Schema updated successfully", <result of gr.Info(...)>)``; the
        ``gr.Info`` call displays a notification as a side effect.

    Raises:
        gr.Error: If nothing was returned, or the result contains an
            ``"error"`` entry.
    """
    schema_info = fetch_and_save_schema()

    # Check for an empty/None result first: testing ``"error" in schema_info``
    # on None would raise TypeError instead of a user-facing gr.Error.
    if not schema_info:
        raise gr.Error("No schema data was returned. The schema is empty.", duration=3)

    if "error" in schema_info:
        raise gr.Error("Error fetching schema from the database.", duration=3)

    return "Schema updated successfully", gr.Info("DB Schema Updated ℹ️", duration=3)


def continue_process():
    """Reveal the SQL, result and cost outputs once a query has been processed."""
    return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)


def reset_interface():
    """Return updates that restore the interface to its initial state.

    In order, the updates:
    - clear the query textbox
    - clear the reformulated-query textbox
    - hide the SQL output
    - hide the results output
    - disable the submit button
    """
    return (
        gr.update(value=""),
        gr.update(value=""),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(interactive=False),
    )


def update_button_state(text):
    """Enable the submit button only when ``text`` contains non-whitespace characters."""
    # ``text and ...`` also tolerates a None value instead of crashing on .strip().
    return gr.update(interactive=bool(text and text.strip()))


def get_table_names():
    """Return ``(table_name, display_name)`` pairs for the tables in the stored schema.

    The display name capitalizes each underscore-separated word, e.g.
    ``"saved_queries"`` -> ``"Saved Queries"``. Returns ``[]`` when no schema
    information is available.
    """
    schema_info = fetch_schema_info()
    if not schema_info:
        return []

    return [
        (table_name, " ".join(word.capitalize() for word in table_name.split("_")))
        for table_name in schema_info.keys()
    ]


# Table names as a list of (original_name, formatted_name) tuples, computed once at import.
table_names = get_table_names()


def insert_table_name(current_text, table_name):
    """Append ``table_name`` to the query text, separated by a single space."""
    return current_text + " " + table_name


def load_examples_from_yaml(file_path):
    """Load the demo examples from a YAML file.

    Args:
        file_path: Path to the YAML file. The UI reads an ``'input'`` key from
            each entry, so entries are expected to be mappings.

    Returns:
        The parsed YAML document, or ``[]`` when the file is empty or cannot
        be read/parsed (the error is logged).
    """
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            examples = yaml.safe_load(file)
    except Exception as e:
        # Best effort: a missing or broken examples file must not stop the app.
        logging.error("Error loading examples: %s", e)
        return []

    # safe_load returns None for an empty document; normalize so callers can iterate.
    return examples if examples is not None else []


# Load examples from YAML located next to this file.
EXAMPLES_FILE_PATH = os.path.join(os.path.dirname(__file__), 'examples.yaml')
examples_list = load_examples_from_yaml(EXAMPLES_FILE_PATH)
# Extract the inputs for the Gradio Examples component (tolerate a missing examples list).
example_inputs = [example['input'] for example in (examples_list or [])]
# Plain numbered labels for the example buttons: "1", "2", "3", ...
example_labels = [str(number) for number in range(1, len(example_inputs) + 1)]

# Gradio interface setup
with gr.Blocks(
    theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Ubuntu"), "Arial", "sans-serif"], text_size='sm')
) as ydcoza_face:
    text_input = gr.Textbox(lines=2, label="Text Query")
    gr.HTML("""
Database Tables:
""")

    # One button per table: the label is the formatted name, and clicking it
    # appends the original table name to the query text.
    with gr.Row():
        for original_name, formatted_name in table_names:
            # Bind original_name as a default argument so each lambda keeps its
            # own table name (avoids the late-binding closure pitfall).
            # size="sm" is a documented Gradio button size ("small" is not).
            gr.Button(formatted_name, size="sm", elem_classes="ydcoza-small-button").click(
                fn=lambda current_text, t=original_name: insert_table_name(current_text, t),
                inputs=text_input,
                outputs=text_input,
            )

    # Example queries loaded from the YAML file, shown as numbered buttons.
    examples = gr.Examples(
        examples=example_inputs,
        example_labels=example_labels,
        label="Demo Natural Language Queries",
        inputs=[text_input],
    )

    reformulated_output = gr.Textbox(lines=2, label="Optimised Query", elem_id='ydcoza_markdown_output_desc')
    sql_output = gr.Code(label="Generated SQL", visible=False)
    sql_result_output = gr.Dataframe(label="Query Results", elem_id='result_output', visible=False)

    # Disabled until the query textbox contains text (see text_input.change below).
    start_button = gr.Button("Submit Text Query", elem_id='ydcoza_gradio_button', interactive=False)

    # Reset button restores the textboxes, hides the outputs and disables submit.
    # NOTE(review): html_output_cost is not in these outputs, so it stays visible after a reset.
    reset_button = gr.Button("Reset Interface", elem_id='ydcoza_gradio_button_reset')
    reset_button.click(
        fn=reset_interface,
        inputs=[],
        outputs=[text_input, reformulated_output, sql_output, sql_result_output, start_button],
    )

    gr.HTML(""" """)

    saved_queries_output = gr.Dataframe(
        label="Last 50 Saved Queries",
        headers=["Query", "Optimised Query", "SQL", "Timestamp"],
        interactive=True,
        visible=False,
    )

    # Load the last 50 saved queries, then make the table visible.
    show_saved_queries_button = gr.Button("View Queries", elem_id='ydcoza_gradio_button')
    show_saved_queries_button.click(show_last_50_saved_queries, outputs=saved_queries_output).then(
        lambda: gr.update(visible=True), outputs=saved_queries_output
    )

    gr.HTML(""" """)

    # Export saved queries to CSV, then reveal the (initially hidden) download widget.
    csv_file_output = gr.File(label="Download CSV", visible=False)
    download_csv_button = gr.Button("Download Queries", elem_id='ydcoza_gradio_button')
    download_csv_button.click(export_saved_queries_to_csv, outputs=csv_file_output).then(
        lambda: gr.update(visible=True), outputs=csv_file_output
    )

    gr.HTML(""" """)

    # Pull the latest schema from the database via update_schema.
    fetch_schema_button = gr.Button("Fetch Latest Schema", elem_id='ydcoza_gradio_button')
    fetch_schema_button.click(update_schema)

    # Cost information output (initially hidden).
    with gr.Row():
        html_output_cost = gr.HTML(elem_id='ydcoza_cost_output', visible=False)

    # Toggle the submit button as the query text changes.
    text_input.change(fn=update_button_state, inputs=text_input, outputs=start_button)

    # Submit pipeline:
    # 1. run the query and fill the outputs
    # 2. reveal the SQL, result and cost outputs
    # 3. disable the submit button again
    start_button.click(
        fn=query_database,
        inputs=[text_input],
        outputs=[reformulated_output, sql_output, sql_result_output, html_output_cost],
    ).then(
        continue_process,
        outputs=[sql_output, sql_result_output, html_output_cost],
    ).then(
        lambda: gr.update(interactive=False),
        outputs=start_button,
    )

# Launch the Gradio interface
if __name__ == "__main__":
    ydcoza_face.launch()