Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| from openai import OpenAI | |
| import snowflake.connector | |
| from typing import Dict, Any | |
| import json | |
| from datetime import date, datetime | |
| import time | |
| from decimal import Decimal | |
# Shared OpenAI client; the API key is read from the environment at import time.
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

# ID of the assistant that runs are created against, also read from the environment.
ASSISTANT_ID = os.environ["ASSISTANT_ID"]
def _to_json_safe(value):
    """Convert a single result cell into a value ``json.dumps`` accepts."""
    # Imported locally under an alias because the module-level name ``time``
    # refers to the ``time`` module, not ``datetime.time``.
    from datetime import time as time_of_day

    # ``datetime`` is a subclass of ``date``, so this covers date, datetime and time.
    if isinstance(value, (date, time_of_day)):
        return value.isoformat()
    if isinstance(value, Decimal):
        return float(value)
    if isinstance(value, (bytes, bytearray)):
        # Binary cells are rendered as a hex string so they survive JSON encoding.
        return value.hex()
    return value


def execute_snowflake_query(query: str) -> Dict[str, Any]:
    """Execute a Snowflake query and return results with column names.

    Connection settings are read from the ``SNOWFLAKE_*`` environment variables.

    Args:
        query: SQL text to execute.

    Returns:
        ``{"columns": [...], "rows": [{column: value, ...}, ...]}`` on success,
        with date/time, ``Decimal`` and binary cells converted to JSON-friendly
        values. On any failure (missing configuration, connection error, SQL
        error) returns ``{"error": "Query failed with error: ..."}`` instead of
        raising.
    """
    conn = None
    try:
        conn = snowflake.connector.connect(
            user=os.environ['SNOWFLAKE_USER'],
            password=os.environ['SNOWFLAKE_PW'],
            account=os.environ['SNOWFLAKE_ACCOUNT'],
            warehouse=os.environ['SNOWFLAKE_WH'],
            database=os.environ['SNOWFLAKE_DB'],
            schema=os.environ['SNOWFLAKE_SCHEMA'],
        )
        with conn.cursor() as cur:
            cur.execute(query)
            rows = cur.fetchall()
            column_names = [desc[0] for desc in cur.description]

        serialized_rows = [
            {name: _to_json_safe(cell) for name, cell in zip(column_names, row)}
            for row in rows
        ]
        return {"columns": column_names, "rows": serialized_rows}
    except Exception as e:
        # Failures are reported as data so the caller can forward them as a
        # tool output rather than crash.
        return {"error": f"Query failed with error: {e}"}
    finally:
        if conn is not None:
            conn.close()
def wait_on_run(run, thread):
    """Poll a run until it leaves its transient states.

    Args:
        run: The run object to watch; its ``status`` and ``id`` are read.
        thread: The thread the run belongs to; only ``thread.id`` is used.

    Returns:
        The most recently retrieved run object, whose status is no longer
        ``queued``, ``in_progress`` or ``cancelling``. If the run passed in is
        already outside those states it is returned unchanged without polling.
    """
    # "cancelling" is transient as well: returning it would hand the caller a
    # run that has not settled yet.
    while run.status in ("queued", "in_progress", "cancelling"):
        # Sleep before polling so that a finished status is returned at once
        # instead of being followed by one more pointless delay.
        time.sleep(0.5)
        run = client.beta.threads.runs.retrieve(
            thread_id=thread.id,
            run_id=run.id,
        )
    return run
def format_sql(sql):
    """Wrap *sql* in a Markdown ``sql`` code fence; falsy input yields ""."""
    return f"```sql\n{sql}\n```" if sql else ""
def _run_tool_call(tool_call) -> tuple:
    """Execute one tool call requested by the assistant.

    Returns ``(tool_output, sql)`` where ``tool_output`` is the dict to submit
    back for this call and ``sql`` is the query text that was run, or ``None``
    when no query was executed.
    """
    sql = None
    if tool_call.function.name != "ask_snowflake":
        # Every requested call needs an answer, so unknown tools get an error
        # payload rather than being skipped.
        result = {"error": f"Unknown tool: {tool_call.function.name}"}
    else:
        try:
            sql = json.loads(tool_call.function.arguments)["query"]
        except (ValueError, KeyError, TypeError) as exc:
            sql = None
            result = {"error": f"Invalid tool arguments: {exc}"}
        else:
            result = execute_snowflake_query(sql)
    # default=str is a deliberate last resort: an unexpected cell type should
    # reach the assistant as text instead of aborting the whole run.
    return {"tool_call_id": tool_call.id, "output": json.dumps(result, default=str)}, sql


def _latest_reply_text(messages) -> str:
    """Return the text of the first (most recent) message, or "" if there is none."""
    latest = next(iter(messages), None)
    if latest is None:
        return ""
    # Only text parts carry ``.text.value``; other content kinds are skipped.
    return "\n".join(
        part.text.value
        for part in latest.content
        if getattr(part, "type", None) == "text"
    )


def _merge_debug(previous: str, current: list) -> str:
    """Append this query's debug lines to the earlier log without a leading blank line."""
    lines = ([previous] if previous else []) + current
    return "\n".join(lines)


def process_query(
    user_input: str,
    history: list,
    thread_id: str,
    debug_info: str
) -> tuple:
    """Process a user query and return updated history and debug information.

    Sends ``user_input`` to the assistant thread (creating the thread when
    ``thread_id`` is empty), services any tool calls the run requests, and
    collects the assistant's reply.

    Args:
        user_input: The message typed by the user.
        history: Existing chat history as a list of ``[user, assistant]`` pairs.
        thread_id: ID of the current assistant thread, or "" to start a new one.
        debug_info: Debug log accumulated so far.

    Returns:
        ``(history, thread_id, debug_log, sql_markdown)``. Errors are not
        raised; they are appended to the history and the debug log, and the
        SQL element is "".
    """
    history = list(history or [])
    current_debug = ["Processing new query..."]
    try:
        # Create a new thread if none exists
        if not thread_id:
            thread = client.beta.threads.create()
            thread_id = thread.id
            current_debug.append(f"Created new thread: {thread_id}")
        else:
            thread = client.beta.threads.retrieve(thread_id)
            current_debug.append(f"Using existing thread: {thread_id}")

        # Add the user's message to the thread
        client.beta.threads.messages.create(
            thread_id=thread_id,
            role="user",
            content=user_input,
        )
        current_debug.append("Added user message to thread")

        run = client.beta.threads.runs.create(
            thread_id=thread_id,
            assistant_id=ASSISTANT_ID,
        )
        current_debug.append("Created new run")

        latest_sql = ""
        # Drive the run to completion, answering tool calls along the way
        while True:
            run = wait_on_run(run, thread)
            if run.status == "completed":
                break
            if run.status == "requires_action":
                tool_outputs = []
                for tool_call in run.required_action.submit_tool_outputs.tool_calls:
                    output, sql = _run_tool_call(tool_call)
                    if sql is not None:
                        latest_sql = sql
                        current_debug.append("Executing SQL query...")
                    tool_outputs.append(output)
                # Submit outputs back to the assistant
                run = client.beta.threads.runs.submit_tool_outputs(
                    thread_id=thread_id,
                    run_id=run.id,
                    tool_outputs=tool_outputs,
                )
                current_debug.append("Submitted query results to assistant")
            elif run.status == "failed":
                raise RuntimeError(f"Run failed: {run.last_error}")
            elif run.status in ("cancelled", "expired", "incomplete"):
                # These statuses never change again; looping on them would spin forever.
                raise RuntimeError(f"Run ended with status '{run.status}'")
            else:
                # A status wait_on_run did not wait on: pause and refresh so
                # the loop neither spins nor works with stale data.
                time.sleep(0.5)
                run = client.beta.threads.runs.retrieve(
                    thread_id=thread_id,
                    run_id=run.id,
                )
            current_debug.append(f"Run status: {run.status}")

        # The first listed message is taken as the assistant's latest reply
        messages = client.beta.threads.messages.list(thread_id=thread_id)
        reply = _latest_reply_text(messages)

        new_history = history + [
            [user_input, None],  # User message
            [None, reply],       # Assistant response
        ]
        full_debug = _merge_debug(debug_info, current_debug)
        return new_history, thread_id, full_debug, format_sql(latest_sql)
    except Exception as e:
        # Top-level boundary for the UI handler: show the failure in the chat
        # and the debug panel instead of letting it propagate.
        error_msg = f"Error: {str(e)}"
        current_debug.append(error_msg)
        full_debug = _merge_debug(debug_info, current_debug)
        return history + [[user_input, error_msg]], thread_id, full_debug, ""
def create_new_thread():
    """Return reset values: empty history, blank thread ID, a status line, blank SQL.

    No thread is created here; the thread ID is simply reset to an empty string.
    """
    empty_history = []
    blank_thread_id = ""
    status_line = "Created new thread"
    blank_sql = ""
    return empty_history, blank_thread_id, status_line, blank_sql
# Build the Gradio UI: chat on the left, SQL/debug panel on the right.
with gr.Blocks(css="style.css") as interface:
    gr.Markdown("# Stock Trading Assistant")

    with gr.Row():
        # Left column: conversation plus the message entry row
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Conversation", height=750)
            with gr.Row():
                msg = gr.Textbox(
                    label="Your message",
                    placeholder="Ask about stocks...",
                    scale=4,
                )
                submit = gr.Button("Submit", scale=1)

        # Right column: latest generated SQL, debug log and thread reset
        with gr.Column(scale=1):
            sql_display = gr.Markdown(label="Generated SQL")
            debug_info = gr.Textbox(
                label="Debug Information",
                lines=10,
                max_lines=10,
            )
            new_thread_btn = gr.Button("New Thread")

    # Per-session thread ID, not rendered in the page
    thread_id = gr.State("")

    # Both ways of sending a message share the same inputs and outputs.
    query_inputs = [msg, chatbot, thread_id, debug_info]
    panel_outputs = [chatbot, thread_id, debug_info, sql_display]

    submit_click = submit.click(
        process_query,
        inputs=query_inputs,
        outputs=panel_outputs,
    )
    msg_submit = msg.submit(
        process_query,
        inputs=query_inputs,
        outputs=panel_outputs,
    )
    new_thread_click = new_thread_btn.click(
        create_new_thread,
        outputs=panel_outputs,
    )

    # Empty the message box once a submission has been handled
    submit_click.then(lambda: "", None, msg)
    msg_submit.then(lambda: "", None, msg)

# Start the app only when run as a script
if __name__ == "__main__":
    interface.launch()