# pschofield2's picture
# Update app.py
# 8157e12 verified
import os
import gradio as gr
from openai import OpenAI
import snowflake.connector
from typing import Dict, Any
import json
from datetime import date, datetime
import time
from decimal import Decimal
# Shared OpenAI client used for every request; the key comes from the
# OPENAI_API_KEY environment variable (KeyError at import time if unset).
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
# ID of the pre-configured OpenAI Assistant that answers user queries, read
# from the ASSISTANT_ID environment variable (KeyError at import time if unset).
ASSISTANT_ID = os.environ["ASSISTANT_ID"]
def _serialize_value(val: Any) -> Any:
    """Return a JSON-serializable version of a single Snowflake result cell.

    Dates, datetimes and times become ISO-8601 strings, ``Decimal`` becomes
    ``float`` (precision may be lost), binary values become hex strings, and
    anything else is returned unchanged.
    """
    # At file level ``time`` is the time *module*, so bind the datetime.time
    # type under a distinct local name.
    from datetime import time as dt_time

    if isinstance(val, (date, datetime, dt_time)):
        return val.isoformat()
    if isinstance(val, Decimal):
        return float(val)
    if isinstance(val, (bytes, bytearray)):
        return val.hex()
    return val


def execute_snowflake_query(query: str) -> Dict[str, Any]:
    """Execute a Snowflake query and return results with column names.

    Connection settings are read from the SNOWFLAKE_USER, SNOWFLAKE_PW,
    SNOWFLAKE_ACCOUNT, SNOWFLAKE_WH, SNOWFLAKE_DB and SNOWFLAKE_SCHEMA
    environment variables on every call.

    Args:
        query: SQL text, executed verbatim.

    Returns:
        On success ``{"columns": [name, ...], "rows": [{name: value, ...}, ...]}``
        where every value is JSON-serializable. On any failure (including a
        missing environment variable) ``{"error": "Query failed with error: ..."}``;
        errors are returned rather than raised so they can be handed back to
        the assistant as the tool result.
    """
    # NOTE(security): ``query`` is generated by the model and run as-is. Use a
    # Snowflake role with read-only grants so a bad or injected query cannot
    # modify data.
    conn = None
    try:
        conn = snowflake.connector.connect(
            user=os.environ['SNOWFLAKE_USER'],
            password=os.environ['SNOWFLAKE_PW'],
            account=os.environ['SNOWFLAKE_ACCOUNT'],
            warehouse=os.environ['SNOWFLAKE_WH'],
            database=os.environ['SNOWFLAKE_DB'],
            schema=os.environ['SNOWFLAKE_SCHEMA']
        )
        with conn.cursor() as cur:
            cur.execute(query)
            rows = cur.fetchall()
            column_names = [desc[0] for desc in cur.description]
        # Convert every cell so the caller can json.dumps() the result safely.
        serialized_rows = [
            dict(zip(column_names, map(_serialize_value, row))) for row in rows
        ]
        return {"columns": column_names, "rows": serialized_rows}
    except Exception as e:
        return {"error": f"Query failed with error: {e}"}
    finally:
        if conn:
            conn.close()
def wait_on_run(run, thread, poll_interval: float = 0.5, timeout: "float | None" = None):
    """Poll an Assistants API run until it leaves a pending state.

    The run counts as pending while its status is "queued", "in_progress" or
    "cancelling"; any other status (e.g. "requires_action", "completed",
    "failed") is returned for the caller to handle.

    Args:
        run: The run to wait on (only ``run.status`` and ``run.id`` are read).
        thread: The run's thread (only ``thread.id`` is read).
        poll_interval: Seconds to sleep between status checks.
        timeout: Maximum seconds to wait; ``None`` (default) waits indefinitely.

    Returns:
        The most recently retrieved run object.

    Raises:
        TimeoutError: If ``timeout`` is set and the run is still pending when
            it elapses.
    """
    # "cancelling" is transitional too: returning on it makes callers that only
    # handle terminal states loop without progress.
    pending_statuses = ("queued", "in_progress", "cancelling")
    deadline = None if timeout is None else time.monotonic() + timeout
    while run.status in pending_statuses:
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"Run {run.id} still '{run.status}' after {timeout} seconds"
            )
        run = client.beta.threads.runs.retrieve(
            thread_id=thread.id,
            run_id=run.id,
        )
        time.sleep(poll_interval)
    return run
def format_sql(sql):
    """Wrap ``sql`` in a fenced ```sql Markdown code block for display.

    Falsy input (``None`` or ``""``) produces an empty string so the SQL
    panel is simply cleared.
    """
    fence = "```"
    return f"{fence}sql\n{sql}\n{fence}" if sql else ""
def _run_tool_calls(tool_calls, current_debug: list) -> tuple:
    """Execute the tool calls requested by a run and build their outputs.

    Every tool call receives an output entry, because the Assistants API
    expects outputs for all pending tool calls in one submit_tool_outputs
    request; a missing one would leave the run stuck.

    Args:
        tool_calls: ``run.required_action.submit_tool_outputs.tool_calls``.
        current_debug: Debug log list; progress lines are appended to it.

    Returns:
        ``(tool_outputs, last_sql)``: the ``{"tool_call_id", "output"}`` dicts
        to submit, and the last query passed to ``ask_snowflake`` in this
        batch ("" if none was executed).
    """
    tool_outputs = []
    last_sql = ""
    for tool_call in tool_calls:
        name = tool_call.function.name
        if name == "ask_snowflake":
            try:
                query = json.loads(tool_call.function.arguments)["query"]
            except (json.JSONDecodeError, KeyError, TypeError) as e:
                # Report bad arguments to the assistant instead of aborting.
                current_debug.append(f"Invalid ask_snowflake arguments: {e}")
                result = {"error": f"Invalid arguments for ask_snowflake: {e}"}
            else:
                last_sql = query
                current_debug.append("Executing SQL query...")
                result = execute_snowflake_query(query)
        else:
            current_debug.append(f"Unknown tool requested: {name}")
            result = {"error": f"Unknown tool: {name}"}
        tool_outputs.append({
            "tool_call_id": tool_call.id,
            "output": json.dumps(result)
        })
    return tool_outputs, last_sql


def process_query(
    user_input: str,
    history: list,
    thread_id: str,
    debug_info: str
) -> tuple:
    """Process a user query and return updated history and debug information.

    Creates an OpenAI thread on first use (empty ``thread_id``), posts the
    user's message, starts a run of ``ASSISTANT_ID`` and services its
    ``ask_snowflake`` tool calls until the run completes.

    Args:
        user_input: Text typed by the user.
        history: Chat history as a list of two-element lists; it is copied,
            never mutated. ``None`` is treated as empty.
        thread_id: Current OpenAI thread ID, or "" to start a new thread.
        debug_info: Existing debug text; this query's log lines are appended.
            ``None`` is treated as "".

    Returns:
        ``(new_history, thread_id, debug_text, sql_markdown)``. Errors are not
        raised: the message is shown as the reply and the SQL panel is cleared.
    """
    history = list(history or [])
    debug_info = debug_info or ""
    current_debug = ["Processing new query..."]
    try:
        # Create a new thread if none exists
        if not thread_id:
            thread = client.beta.threads.create()
            thread_id = thread.id
            current_debug.append(f"Created new thread: {thread_id}")
        else:
            thread = client.beta.threads.retrieve(thread_id)
            current_debug.append(f"Using existing thread: {thread_id}")

        client.beta.threads.messages.create(
            thread_id=thread_id,
            role="user",
            content=user_input
        )
        current_debug.append("Added user message to thread")

        run = client.beta.threads.runs.create(
            thread_id=thread_id,
            assistant_id=ASSISTANT_ID
        )
        current_debug.append("Created new run")

        latest_sql = ""
        while True:
            run = wait_on_run(run, thread)
            if run.status == "requires_action":
                tool_outputs, batch_sql = _run_tool_calls(
                    run.required_action.submit_tool_outputs.tool_calls,
                    current_debug,
                )
                if batch_sql:
                    latest_sql = batch_sql
                run = client.beta.threads.runs.submit_tool_outputs(
                    thread_id=thread_id,
                    run_id=run.id,
                    tool_outputs=tool_outputs
                )
                current_debug.append("Submitted query results to assistant")
            elif run.status == "completed":
                break
            elif run.status == "failed":
                raise RuntimeError(f"Run failed: {run.last_error}")
            else:
                # cancelled / expired / incomplete / cancelling: this loop cannot
                # advance the run, and retrying would spin forever.
                raise RuntimeError(f"Run ended with status: {run.status}")
            current_debug.append(f"Run status: {run.status}")

        # Newest message first; presumably the assistant's reply to this run.
        messages = client.beta.threads.messages.list(
            thread_id=thread_id, order="desc", limit=1
        )
        last_message = next(iter(messages))
        reply = last_message.content[0].text.value

        new_history = history + [
            [user_input, None],  # User message
            [None, reply],       # Assistant response
        ]
        full_debug = debug_info + "\n" + "\n".join(current_debug)
        return new_history, thread_id, full_debug, format_sql(latest_sql)
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        current_debug.append(error_msg)
        full_debug = debug_info + "\n" + "\n".join(current_debug)
        return history + [[user_input, error_msg]], thread_id, full_debug, ""
def create_new_thread():
    """Reset the UI so the next message starts a fresh conversation.

    No OpenAI thread is created here. The thread ID is cleared, and
    ``process_query`` creates a new thread when it receives an empty ID.

    Returns:
        ``(history, thread_id, debug_text, sql_markdown)``: an empty history,
        an empty thread ID, a status line, and an empty SQL panel.
    """
    cleared_history = []
    status_line = "Created new thread"
    return cleared_history, "", status_line, ""
# Create the Gradio interface.
# NOTE(review): this passes a file name to `css`; confirm the installed Gradio
# version accepts a path here (some versions expect CSS text or `css_paths`).
with gr.Blocks(css="style.css") as interface:
    gr.Markdown("# Stock Trading Assistant")
    with gr.Row():
        # Left column: the conversation and the message input.
        with gr.Column(scale=2):
            # History is a list of two-element lists, as produced by process_query.
            chatbot = gr.Chatbot(
                label="Conversation",
                height=750
            )
            with gr.Row():
                msg = gr.Textbox(
                    label="Your message",
                    placeholder="Ask about stocks...",
                    scale=4
                )
                submit = gr.Button("Submit", scale=1)
        # Right column: last executed SQL, debug log and thread reset.
        with gr.Column(scale=1):
            sql_display = gr.Markdown(label="Generated SQL")
            # Used as both input and output: process_query appends each
            # query's log lines to the existing text.
            debug_info = gr.Textbox(
                label="Debug Information",
                lines=10,
                max_lines=10
            )
            new_thread_btn = gr.Button("New Thread")
    # Hidden state for the OpenAI thread ID; "" means process_query will
    # create a new thread on the next message.
    thread_id = gr.State("")
    # Clicking Submit and pressing Enter in the textbox run the same handler
    # with the same inputs and outputs.
    submit_click = submit.click(
        process_query,
        inputs=[msg, chatbot, thread_id, debug_info],
        outputs=[chatbot, thread_id, debug_info, sql_display]
    )
    msg_submit = msg.submit(
        process_query,
        inputs=[msg, chatbot, thread_id, debug_info],
        outputs=[chatbot, thread_id, debug_info, sql_display]
    )
    # Reset chat, thread ID, debug text and SQL panel.
    new_thread_click = new_thread_btn.click(
        create_new_thread,
        outputs=[chatbot, thread_id, debug_info, sql_display]
    )
    # Clear the message box once process_query has finished (chained via .then).
    submit_click.then(lambda: "", None, msg)
    msg_submit.then(lambda: "", None, msg)
# Launch the interface when run as a script.
if __name__ == "__main__":
    interface.launch()