Spaces:
Sleeping
Sleeping
| import chainlit as cl | |
| import pandas as pd | |
| import time | |
| from typing import Dict, Any | |
| from agents.table_selection import table_selection_agent | |
| from agents.data_retrieval import sample_data_retrieval_agent | |
| from agents.sql_generation import sql_generation_agent | |
| from agents.validation import query_validation_and_optimization | |
| from agents.execution import execution_agent | |
| from utils.bigquery_utils import init_bigquery_connection | |
| from utils.feedback_utils import save_feedback_to_bigquery | |
@cl.on_chat_start
async def on_chat_start():
    """Prepare a new chat session and greet the user.

    Creates a BigQuery client with ``init_bigquery_connection`` and stores it
    in the Chainlit user session under the ``"client"`` key, where the message
    handler reads it back. Three messages are then sent: a welcome line, a
    lead-in for the examples, and the bulleted example questions.
    """
    # The ``@cl.on_chat_start`` decorator registers this coroutine as the
    # session-start hook; without it nothing ever calls the function and the
    # "client" session key is never populated.
    cl.user_session.set("client", init_bigquery_connection())

    author = "SQL Assistant"

    await cl.Message(
        content="👋 Welcome to the Natural Language to SQL Query Assistant! Ask me any question about your e-commerce data.",
        author=author,
    ).send()

    # Examples are shown as plain text rather than clickable actions.
    await cl.Message(
        content="Here are some example questions you can ask:",
        author=author,
    ).send()

    sample_questions = (
        "What are the top 5 products by revenue?",
        "How many orders were placed in the last month?",
        "Which customers spent the most in 2023?",
        "What is the average order value by product category?",
    )

    # All examples go out in one message, one bullet per question.
    bullets = "\n\n".join(f"• {question}" for question in sample_questions)
    hint = "\n\n(You can copy and paste any of these examples to try them out)"
    await cl.Message(content=bullets + hint, author=author).send()
def _feedback_context():
    """Return ``(client, original_query, generated_sql, optimized_sql)`` from the user session."""
    session = cl.user_session
    return (
        session.get("client"),
        session.get("original_query"),
        session.get("generated_sql"),
        session.get("optimized_sql"),
    )


async def _say(content: str) -> None:
    """Send ``content`` to the chat as a message authored by "SQL Assistant"."""
    await cl.Message(content=content, author="SQL Assistant").send()


async def _set_status(status_msg, content: str) -> None:
    """Overwrite the text of the progress message and push the update to the UI."""
    status_msg.content = content
    await status_msg.update()


async def _record_detailed_feedback(feedback_text: str) -> None:
    """Save the user's free-text explanation that follows a thumbs-down."""
    # Leave feedback mode *before* saving: if the save raises, the session
    # must not stay stuck treating every later message as feedback.
    cl.user_session.set("awaiting_feedback", False)

    client, original_query, generated_sql, optimized_sql = _feedback_context()
    saved = save_feedback_to_bigquery(
        client,
        original_query,
        generated_sql,
        optimized_sql,
        f"negative: {feedback_text}",
    )
    if saved:
        await _say("Thanks for your detailed feedback! I've saved it to improve future responses.")
    else:
        await _say("Thanks for your feedback! (Note: There was an issue saving it to the database)")


def _results_to_dataframe(execution_result, client, optimized_sql) -> pd.DataFrame:
    """Build a DataFrame from the execution result, naming columns when possible.

    Rows that are tuples carry no column names, so the names are taken from
    the schema of a query job for ``optimized_sql``; if that fails for any
    reason, generic ``Column_<i>`` names are used. Any other row type is
    handed to ``pd.DataFrame`` unchanged.
    """
    if not isinstance(execution_result[0], tuple):
        return pd.DataFrame(execution_result)

    try:
        # NOTE(review): this appears to submit the optimized query a second
        # time solely to read its schema — presumably the execution agent has
        # already run it. Consider returning the schema from that agent.
        schema = client.query(optimized_sql).result().schema
        return pd.DataFrame(execution_result, columns=[field.name for field in schema])
    except Exception:
        # Best-effort fallback: positional column names.
        width = len(execution_result[0])
        return pd.DataFrame(execution_result, columns=[f"Column_{i}" for i in range(width)])


async def _ask_for_feedback(df: pd.DataFrame) -> None:
    """Summarise the result size and collect thumbs-up / thumbs-down feedback."""
    num_rows = len(df)
    num_cols = len(df.columns)

    res = await cl.AskActionMessage(
        content=f"The query returned {num_rows} rows and {num_cols} columns.\n\nWas this result helpful?",
        actions=[
            cl.Action(name="feedback_positive", payload={"value": "positive"}, label="👍 Good results"),
            cl.Action(name="feedback_negative", payload={"value": "negative"}, label="👎 Not what I wanted"),
        ],
    ).send()
    if not res:
        # Nothing was selected (falsy response); skip feedback entirely.
        return

    feedback_value = res.get("payload", {}).get("value")
    client, original_query, generated_sql, optimized_sql = _feedback_context()

    if feedback_value == "positive":
        saved = save_feedback_to_bigquery(
            client, original_query, generated_sql, optimized_sql, "positive"
        )
        if saved:
            await _say("Thanks for your positive feedback! I've saved it to improve future responses.")
        else:
            await _say("Thanks for your feedback! (Note: There was an issue saving it to the database)")
    elif feedback_value == "negative":
        await _say("I'm sorry the results weren't what you expected. Please type your feedback about what was wrong.")
        # The next user message is routed to _record_detailed_feedback.
        cl.user_session.set("awaiting_feedback", True)
        # Record the bare thumbs-down now, in case no details ever arrive.
        save_feedback_to_bigquery(
            client, original_query, generated_sql, optimized_sql, "negative"
        )


@cl.on_message
async def on_message(message: cl.Message):
    """Handle one user message: either detailed feedback or a new question.

    If the session flag ``"awaiting_feedback"`` is set, the message text is
    saved as detailed negative feedback and processing stops. Otherwise the
    text is treated as a natural-language question and run through five
    stages — table selection, sample-data retrieval, SQL generation,
    validation/optimization and execution — with a progress message updated
    before each stage. The generated and optimized SQL are shown to the user
    and stored in the session; results are rendered as a table and followed
    by a feedback prompt.

    Any exception raised during the pipeline is reported to the user in the
    chat rather than propagated.

    Args:
        message: The incoming Chainlit message; only ``message.content`` is used.
    """
    query = message.content

    if cl.user_session.get("awaiting_feedback", False):
        await _record_detailed_feedback(query)
        return

    client = cl.user_session.get("client")
    # Kept in the session so feedback can be tied back to the question.
    cl.user_session.set("original_query", query)

    thinking_msg = await cl.Message(content="🤔 Thinking...", author="SQL Assistant").send()

    try:
        # The agents are synchronous; cl.make_async runs each one in a worker
        # thread so the event loop is not blocked while they work.

        # Step 1: pick the relevant tables.
        await _set_status(thinking_msg, "🔍 Analyzing relevant tables...")
        state: Dict[str, Any] = {"sql_query": query, "client": client}
        tables_state = await cl.make_async(table_selection_agent)(state)
        relevant_tables = tables_state.get("relevant_tables", [])

        # Short pause before each follow-up, purely for pacing in the UI.
        await cl.sleep(1)
        if relevant_tables:
            tables_text = "I've identified these relevant tables for your query:\n\n"
            tables_text += "\n".join(f"- `{table}`" for table in relevant_tables)
            await _say(tables_text)

        # Step 2: retrieve sample data for those tables.
        await _set_status(thinking_msg, "📊 Retrieving sample data...")
        await cl.sleep(1)
        state.update(tables_state)
        sample_data_state = await cl.make_async(sample_data_retrieval_agent)(state)

        # Step 3: generate SQL.
        await _set_status(thinking_msg, "💻 Generating SQL query...")
        await cl.sleep(1)
        state.update(sample_data_state)
        sql_state = await cl.make_async(sql_generation_agent)(state)
        generated_sql = sql_state.get("generated_sql", "No SQL generated")
        cl.user_session.set("generated_sql", generated_sql)
        await _say(f"Here's the SQL query I generated:\n\n```sql\n{generated_sql}\n```")

        # Step 4: validate and optimize.
        await _set_status(thinking_msg, "🔧 Optimizing the query...")
        await cl.sleep(1)
        state.update(sql_state)
        optimization_state = await cl.make_async(query_validation_and_optimization)(state)
        optimized_sql = optimization_state.get("optimized_sql", "No optimized SQL")
        cl.user_session.set("optimized_sql", optimized_sql)
        await _say(f"Here's the optimized version of the query:\n\n```sql\n{optimized_sql}\n```")

        # Step 5: execute.
        await _set_status(thinking_msg, "⚙️ Executing query...")
        await cl.sleep(1)
        state.update(optimization_state)
        execution_state = await cl.make_async(execution_agent)(state)
        execution_result = execution_state.get("execution_result", {})

        if isinstance(execution_result, dict) and "error" in execution_result:
            error_msg = execution_result.get("error", "Unknown error occurred")
            await _say(f"❌ Error executing query: {error_msg}")
            return

        if not execution_result:
            await _say("✅ Query executed successfully but returned no results.")
            return

        # Only table construction and display are treated as "formatting";
        # feedback failures must not be reported under that label.
        try:
            df = await cl.make_async(_results_to_dataframe)(
                execution_result, client, optimized_sql
            )
            await _say("✅ Query executed successfully! Here are the results:")
            await cl.Message(
                content="",
                elements=[cl.Dataframe(data=df)],
                author="SQL Assistant",
            ).send()
        except Exception as e:
            await _say(f"❌ Error formatting results: {str(e)}")
            return

        await _ask_for_feedback(df)

    except Exception as e:
        # Top-level boundary: surface the failure in the chat instead of
        # letting the handler die silently.
        await _set_status(thinking_msg, f"❌ Error: {str(e)}")
        await _say(f"I encountered an error while processing your query: {str(e)}")
| # Callback handlers for actions | |
@cl.action_callback("feedback_positive")
async def on_feedback_positive(action):
    """Record a thumbs-up for the most recent query.

    Reads the BigQuery client, the original question, and the generated and
    optimized SQL from the user session, saves them with the label
    ``"positive"``, and tells the user whether the save succeeded.

    Args:
        action: The triggering Chainlit action; not used by the handler.
    """
    # Registered under the same name as the "feedback_positive" action so
    # Chainlit can dispatch to it; undecorated, the function is never called.
    session = cl.user_session
    saved = save_feedback_to_bigquery(
        session.get("client"),
        session.get("original_query"),
        session.get("generated_sql"),
        session.get("optimized_sql"),
        "positive",
    )

    if saved:
        reply = "Thanks for your positive feedback! I've saved it to improve future responses."
    else:
        reply = "Thanks for your feedback! (Note: There was an issue saving it to the database)"
    await cl.Message(content=reply, author="SQL Assistant").send()
@cl.action_callback("feedback_negative")
async def on_feedback_negative(action):
    """Record a thumbs-down and ask the user what went wrong.

    Sends a prompt asking for details, sets the ``"awaiting_feedback"``
    session flag so the next user message is treated as that explanation,
    and immediately saves a bare ``"negative"`` entry built from the session's
    original question and generated/optimized SQL.

    Args:
        action: The triggering Chainlit action; not used by the handler.
    """
    # Registered under the same name as the "feedback_negative" action so
    # Chainlit can dispatch to it; undecorated, the function is never called.
    await cl.Message(
        content="I'm sorry the results weren't what you expected. Please type your feedback about what was wrong.",
        author="SQL Assistant",
    ).send()

    session = cl.user_session
    session.set("awaiting_feedback", True)

    # Save the initial negative signal now; the result of the save is
    # deliberately not reported to the user here.
    save_feedback_to_bigquery(
        session.get("client"),
        session.get("original_query"),
        session.get("generated_sql"),
        session.get("optimized_sql"),
        "negative",
    )
# Chainlit apps are started through the Chainlit CLI rather than by running
# this module directly, e.g.:
#     chainlit run new_app.py -w
# The guard below is therefore an intentional no-op.
if __name__ == "__main__":
    pass