"""Gradio chat UI for an OpenAI assistant that answers stock questions by querying Snowflake."""

import os
import gradio as gr
from openai import OpenAI
import snowflake.connector
from typing import Dict, Any, List, Optional, Tuple
import json
from datetime import date, datetime, time as dt_time
import time
from decimal import Decimal

# Initialize the OpenAI client; the API key is read from the environment.
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
# ID of the pre-configured assistant, read from the ASSISTANT_ID environment variable.
ASSISTANT_ID = os.environ['ASSISTANT_ID']

# Run statuses in which the assistant is still working and the run should be polled.
# "cancelling" is included so a run being cancelled is not mistaken for a final state.
_PENDING_RUN_STATUSES = frozenset({"queued", "in_progress", "cancelling"})

# Name of the function tool the assistant calls to run SQL against Snowflake.
_SNOWFLAKE_TOOL_NAME = "ask_snowflake"


def _serialize_value(val: Any) -> Any:
    """Convert one Snowflake cell value into a JSON-friendly equivalent."""
    if isinstance(val, (date, datetime, dt_time)):
        return val.isoformat()
    if isinstance(val, Decimal):
        return float(val)
    if isinstance(val, (bytes, bytearray)):
        return val.hex()
    return val


def execute_snowflake_query(query: str) -> Dict[str, Any]:
    """Execute a Snowflake query and return results with column names.

    Connection settings come from the SNOWFLAKE_* environment variables.

    Args:
        query: SQL text to execute.

    Returns:
        ``{"columns": [...], "rows": [{column: value, ...}, ...]}`` on success, or
        ``{"error": "..."}`` if connecting or executing fails. Errors are returned
        rather than raised so they can be handed back to the assistant as tool output.
    """
    # NOTE(review): ``query`` is produced by the model and executed verbatim, so a
    # crafted prompt can make it run arbitrary SQL. Use a read-only Snowflake role.
    conn = None
    try:
        conn = snowflake.connector.connect(
            user=os.environ['SNOWFLAKE_USER'],
            password=os.environ['SNOWFLAKE_PW'],
            account=os.environ['SNOWFLAKE_ACCOUNT'],
            warehouse=os.environ['SNOWFLAKE_WH'],
            database=os.environ['SNOWFLAKE_DB'],
            schema=os.environ['SNOWFLAKE_SCHEMA'],
        )
        with conn.cursor() as cur:
            cur.execute(query)
            rows = cur.fetchall()
            # Guard against statements that produce no result-set description.
            column_names = [desc[0] for desc in (cur.description or [])]
        serialized_rows = [
            {name: _serialize_value(val) for name, val in zip(column_names, row)}
            for row in rows
        ]
        return {"columns": column_names, "rows": serialized_rows}
    except Exception as e:  # deliberate: every failure is reported back to the model
        return {"error": f"Query failed with error: {e}"}
    finally:
        if conn:
            conn.close()


def wait_on_run(run, thread, poll_interval: float = 0.5):
    """Poll a run until it leaves the queued/in-progress/cancelling states.

    Args:
        run: Run object returned by the OpenAI Assistants API.
        thread: The thread object the run belongs to, or the thread ID as a string.
        poll_interval: Seconds to sleep between status checks.

    Returns:
        The most recently retrieved run object. Its status is no longer pending
        (e.g. ``requires_action``, ``completed``, ``failed``).
    """
    thread_id = getattr(thread, "id", thread)
    while run.status in _PENDING_RUN_STATUSES:
        # Sleep before polling so a run that just finished is returned immediately.
        time.sleep(poll_interval)
        run = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run.id)
    return run


def format_sql(sql):
    """Wrap a SQL query in a fenced Markdown code block; return "" for empty input."""
    if not sql:
        return ""
    return f"```sql\n{sql}\n```"


def _handle_tool_calls(run, thread_id: str, current_debug: List[str]) -> Tuple[Any, str]:
    """Execute every tool call the run is waiting on and submit all outputs at once.

    Every tool call receives an output: an error payload for unknown tools or
    malformed arguments. Outputs for all tool calls must be submitted in one
    request, or the run cannot continue.

    Args:
        run: A run whose status is ``requires_action``.
        thread_id: ID of the thread the run belongs to.
        current_debug: Debug log lines; appended to in place.

    Returns:
        ``(run, latest_sql)``, where ``run`` is the object returned by
        ``submit_tool_outputs`` and ``latest_sql`` is the last SQL executed in this
        batch. ``latest_sql`` is ``""`` if no SQL was executed.
    """
    latest_sql = ""
    tool_outputs = []
    for tool_call in run.required_action.submit_tool_outputs.tool_calls:
        tool_name = tool_call.function.name
        if tool_name == _SNOWFLAKE_TOOL_NAME:
            try:
                query = json.loads(tool_call.function.arguments)["query"]
            except (json.JSONDecodeError, KeyError, TypeError) as e:
                result = {"error": f"Invalid tool arguments: {e}"}
                current_debug.append(f"Invalid arguments for {tool_name}: {e}")
            else:
                latest_sql = query
                current_debug.append("Executing SQL query...")
                result = execute_snowflake_query(query)
        else:
            result = {"error": f"Unknown tool: {tool_name}"}
            current_debug.append(f"Unknown tool requested: {tool_name}")
        tool_outputs.append({
            "tool_call_id": tool_call.id,
            # default=str: last-resort stringification for any cell type that
            # _serialize_value does not handle. The output is only read by the model.
            "output": json.dumps(result, default=str),
        })

    run = client.beta.threads.runs.submit_tool_outputs(
        thread_id=thread_id,
        run_id=run.id,
        tool_outputs=tool_outputs,
    )
    current_debug.append("Submitted query results to assistant")
    return run, latest_sql


def _latest_assistant_reply(thread_id: str) -> str:
    """Return the text of the newest message in the thread, if it came from the assistant."""
    messages = client.beta.threads.messages.list(thread_id=thread_id, order="desc", limit=1)
    latest = next(iter(messages), None)
    if latest is None or latest.role != "assistant":
        return "(The assistant did not return a message.)"
    # A message may hold several content blocks, not all of them text.
    parts = [block.text.value for block in latest.content if block.type == "text"]
    if not parts:
        return "(The assistant replied without text content.)"
    return "\n\n".join(parts)


def _join_debug(debug_info: str, current_debug: List[str]) -> str:
    """Append this query's debug lines to the existing debug text."""
    return debug_info + "\n" + "\n".join(current_debug)


def process_query(
    user_input: str, history: list, thread_id: str, debug_info: str
) -> tuple:
    """Process a user query and return updated history and debug information.

    Args:
        user_input: The message typed by the user.
        history: Chatbot history as a list of ``[user, assistant]`` pairs (may be None).
        thread_id: Current OpenAI thread ID; empty to start a new thread.
        debug_info: Existing debug text (may be None).

    Returns:
        ``(history, thread_id, debug_text, sql_markdown)``. On error, the error
        message is shown as the assistant's reply and the SQL panel is cleared.
    """
    history = list(history or [])
    debug_info = debug_info or ""
    current_debug = ["Processing new query..."]

    # The API rejects empty messages; leave the conversation unchanged instead.
    if not user_input or not user_input.strip():
        current_debug.append("Ignored empty message")
        return history, thread_id, _join_debug(debug_info, current_debug), ""

    try:
        if not thread_id:
            thread_id = client.beta.threads.create().id
            current_debug.append(f"Created new thread: {thread_id}")
        else:
            current_debug.append(f"Using existing thread: {thread_id}")

        client.beta.threads.messages.create(
            thread_id=thread_id,
            role="user",
            content=user_input,
        )
        current_debug.append("Added user message to thread")

        run = client.beta.threads.runs.create(
            thread_id=thread_id,
            assistant_id=ASSISTANT_ID,
        )
        current_debug.append("Created new run")

        latest_sql = ""
        # Keep servicing tool calls until the run reaches a final status.
        while True:
            run = wait_on_run(run, thread_id)
            if run.status == "requires_action":
                run, batch_sql = _handle_tool_calls(run, thread_id, current_debug)
                if batch_sql:
                    latest_sql = batch_sql
            elif run.status == "completed":
                break
            else:
                # failed / cancelled / expired / incomplete: the run will not
                # progress further, so stop instead of looping forever.
                raise RuntimeError(f"Run {run.status}: {run.last_error}")

        current_debug.append(f"Run status: {run.status}")

        reply = _latest_assistant_reply(thread_id)
        new_history = history + [[user_input, reply]]
        return new_history, thread_id, _join_debug(debug_info, current_debug), format_sql(latest_sql)

    except Exception as e:  # UI boundary: surface every failure in the chat
        error_msg = f"Error: {str(e)}"
        current_debug.append(error_msg)
        return history + [[user_input, error_msg]], thread_id, _join_debug(debug_info, current_debug), ""


def create_new_thread():
    """Reset the conversation state; a fresh thread is created with the next message."""
    return [], "", "Cleared conversation; a new thread will be created with the next message", ""


# Create the Gradio interface
with gr.Blocks(css="style.css") as interface:
    gr.Markdown("# Stock Trading Assistant")

    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Conversation", height=750)
            with gr.Row():
                msg = gr.Textbox(
                    label="Your message",
                    placeholder="Ask about stocks...",
                    scale=4,
                )
                submit = gr.Button("Submit", scale=1)

        with gr.Column(scale=1):
            sql_display = gr.Markdown(label="Generated SQL")
            debug_info = gr.Textbox(label="Debug Information", lines=10, max_lines=10)
            new_thread_btn = gr.Button("New Thread")

    # Hidden per-session state holding the current OpenAI thread ID
    thread_id = gr.State("")

    # Both the button and pressing Enter run the same query handler.
    query_io = dict(
        inputs=[msg, chatbot, thread_id, debug_info],
        outputs=[chatbot, thread_id, debug_info, sql_display],
    )
    submit_click = submit.click(process_query, **query_io)
    msg_submit = msg.submit(process_query, **query_io)

    new_thread_click = new_thread_btn.click(
        create_new_thread,
        outputs=[chatbot, thread_id, debug_info, sql_display],
    )

    # Clear the message box once the query has been processed
    submit_click.then(lambda: "", None, msg)
    msg_submit.then(lambda: "", None, msg)

# Launch the interface
if __name__ == "__main__":
    interface.launch()