Spaces:
Sleeping
Sleeping
File size: 7,690 Bytes
e3adea3 1cce353 e3adea3 3448fea e3adea3 1cce353 e3adea3 8157e12 e3adea3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 |
import os
import gradio as gr
from openai import OpenAI
import snowflake.connector
from typing import Dict, Any
import json
from datetime import date, datetime
import time
from decimal import Decimal
# Initialize the OpenAI client. The API key is read from the OPENAI_API_KEY
# environment variable; a missing variable raises KeyError at import time.
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
# ID of the assistant that runs are created against, read from the
# ASSISTANT_ID environment variable (also required at import time).
ASSISTANT_ID = os.environ['ASSISTANT_ID']
def _to_json_safe(value: Any) -> Any:
    """Convert one result cell into a value that ``json.dumps`` can encode.

    Dates, datetimes and times become ISO-8601 strings, ``Decimal`` becomes
    ``float`` and binary data becomes a hex string. Any other value is
    returned unchanged.
    """
    # Imported locally under an alias: at module level the name ``time``
    # refers to the ``time`` module, not ``datetime.time``.
    from datetime import time as dt_time

    if isinstance(value, (date, datetime, dt_time)):
        return value.isoformat()
    if isinstance(value, Decimal):
        # float() may lose precision for very large or very precise decimals.
        return float(value)
    if isinstance(value, (bytes, bytearray)):
        return bytes(value).hex()
    return value


def execute_snowflake_query(query: str) -> Dict[str, Any]:
    """Execute a Snowflake query and return JSON-serializable results.

    Connection settings are read from the SNOWFLAKE_USER, SNOWFLAKE_PW,
    SNOWFLAKE_ACCOUNT, SNOWFLAKE_WH, SNOWFLAKE_DB and SNOWFLAKE_SCHEMA
    environment variables. A new connection is opened for every call and
    closed before returning.

    Args:
        query: SQL text to execute verbatim.

    Returns:
        ``{"columns": [...], "rows": [{column: value, ...}, ...]}`` on
        success, or ``{"error": "Query failed with error: ..."}`` if anything
        goes wrong (missing configuration, connection failure, SQL error).
        This function never raises.
    """
    # NOTE(review): ``query`` is executed as-is. If it can originate from
    # model output or user input, make sure the Snowflake role used here is
    # restricted (ideally read-only) -- confirm with the deployment config.
    conn = None
    try:
        conn = snowflake.connector.connect(
            user=os.environ['SNOWFLAKE_USER'],
            password=os.environ['SNOWFLAKE_PW'],
            account=os.environ['SNOWFLAKE_ACCOUNT'],
            warehouse=os.environ['SNOWFLAKE_WH'],
            database=os.environ['SNOWFLAKE_DB'],
            schema=os.environ['SNOWFLAKE_SCHEMA'],
        )
        with conn.cursor() as cur:
            cur.execute(query)
            rows = cur.fetchall()
            # The first element of each description entry is the column name.
            column_names = [desc[0] for desc in cur.description]

        serialized_rows = [
            {name: _to_json_safe(val) for name, val in zip(column_names, row)}
            for row in rows
        ]
        return {"columns": column_names, "rows": serialized_rows}
    except Exception as e:
        # Errors are reported in-band so the caller can pass them on as the
        # query result instead of aborting.
        return {"error": f"Query failed with error: {e}"}
    finally:
        if conn is not None:
            conn.close()
def wait_on_run(run, thread, poll_interval: float = 0.5):
    """Poll a run until it leaves its transitional states.

    Args:
        run: Run object to wait on; its ``status`` and ``id`` are read.
        thread: Thread the run belongs to; only ``thread.id`` is used.
        poll_interval: Seconds to sleep before each status refresh.

    Returns:
        The most recently retrieved run. If ``run`` is already outside the
        transitional states, it is returned unchanged without any API call.
    """
    # "cancelling" is transitional as well: the run has not settled yet.
    transitional = ("queued", "in_progress", "cancelling")
    while run.status in transitional:
        # Sleep *before* refreshing so that no extra delay is added once the
        # run has reached a state the caller can act on.
        time.sleep(poll_interval)
        run = client.beta.threads.runs.retrieve(
            thread_id=thread.id,
            run_id=run.id,
        )
    return run
def format_sql(sql):
    """Return ``sql`` wrapped in a Markdown ``sql`` code fence.

    Falsy input (``None`` or an empty string) yields an empty string so that
    nothing is rendered.
    """
    return "" if not sql else f"```sql\n{sql}\n```"
def _answer_tool_call(tool_call, current_debug: list) -> tuple:
    """Execute one tool call requested by the assistant.

    Returns ``(tool_output, sql)``: ``tool_output`` is the dict to submit back
    to the run, ``sql`` is the query that was executed, or ``None`` when no
    query was run.
    """
    if tool_call.function.name != "ask_snowflake":
        # Answer unsupported tools with an error payload instead of silently
        # dropping them, so the submission covers every requested call.
        payload = {"error": f"Unsupported tool: {tool_call.function.name}"}
        return {"tool_call_id": tool_call.id, "output": json.dumps(payload)}, None

    query = json.loads(tool_call.function.arguments)["query"]
    current_debug.append("Executing SQL query...")
    query_result = execute_snowflake_query(query)
    output = {"tool_call_id": tool_call.id, "output": json.dumps(query_result)}
    return output, query


def process_query(
    user_input: str,
    history: list,
    thread_id: str,
    debug_info: str
) -> tuple:
    """Send a user message to the assistant and collect its reply.

    Args:
        user_input: Text typed by the user.
        history: Chat history as a list of ``[user, assistant]`` pairs.
        thread_id: ID of the current assistant thread; a falsy value makes
            this call create a new thread.
        debug_info: Debug text accumulated so far.

    Returns:
        ``(history, thread_id, debug_text, sql_markdown)``. On failure the
        error is appended to the history as the assistant's reply, the debug
        text records it and ``sql_markdown`` is ``""``; no exception escapes.
    """
    current_debug = ["Processing new query..."]
    # Work on a copy so the caller's list is never mutated; tolerate None.
    base_history = list(history or [])
    try:
        # Create a new thread if none exists
        if not thread_id:
            thread = client.beta.threads.create()
            thread_id = thread.id
            current_debug.append(f"Created new thread: {thread_id}")
        else:
            thread = client.beta.threads.retrieve(thread_id)
            current_debug.append(f"Using existing thread: {thread_id}")

        # Add the user's message to the thread
        client.beta.threads.messages.create(
            thread_id=thread_id,
            role="user",
            content=user_input
        )
        current_debug.append("Added user message to thread")

        run = client.beta.threads.runs.create(
            thread_id=thread_id,
            assistant_id=ASSISTANT_ID
        )
        current_debug.append("Created new run")

        # Most recent SQL executed on behalf of the assistant, for display.
        latest_sql = ""

        # Drive the run to completion, answering tool calls along the way.
        while True:
            run = wait_on_run(run, thread)
            if run.status == "completed":
                break
            if run.status == "failed":
                raise RuntimeError(f"Run failed: {run.last_error}")
            if run.status != "requires_action":
                # Any other status (e.g. cancelled, expired, incomplete) will
                # never turn into "completed"; stop instead of looping forever.
                raise RuntimeError(f"Run ended with status: {run.status}")

            tool_outputs = []
            for tool_call in run.required_action.submit_tool_outputs.tool_calls:
                output, sql = _answer_tool_call(tool_call, current_debug)
                tool_outputs.append(output)
                if sql is not None:
                    latest_sql = sql

            # Submit outputs back to the assistant
            run = client.beta.threads.runs.submit_tool_outputs(
                thread_id=thread_id,
                run_id=run.id,
                tool_outputs=tool_outputs
            )
            current_debug.append("Submitted query results to assistant")
            current_debug.append(f"Run status: {run.status}")

        # The first message yielded by the listing is used as the reply.
        messages = client.beta.threads.messages.list(thread_id=thread_id)
        last_message = next(iter(messages))

        base_history.extend([
            [user_input, None],  # User message
            [None, last_message.content[0].text.value]  # Assistant response
        ])
        full_debug = debug_info + "\n" + "\n".join(current_debug)
        return base_history, thread_id, full_debug, format_sql(latest_sql)
    except Exception as e:
        # Top-level boundary for the UI callback: show the error in the chat
        # rather than letting the handler crash.
        error_msg = f"Error: {str(e)}"
        current_debug.append(error_msg)
        full_debug = debug_info + "\n" + "\n".join(current_debug)
        return base_history + [[user_input, error_msg]], thread_id, full_debug, ""
def create_new_thread():
    """Return blank UI state for starting a fresh conversation.

    The tuple is ``(chat history, thread ID, debug text, SQL display)``. No
    API call is made here; only empty values and a status message are
    returned.
    """
    empty_history = []
    blank = ""
    return empty_history, blank, "Created new thread", blank
# Build the Gradio UI: conversation panel on the left, SQL/debug panel on the right.
with gr.Blocks(css="style.css") as interface:
    gr.Markdown("# Stock Trading Assistant")

    with gr.Row():
        # Left column: chat transcript plus the message entry row.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Conversation", height=750)
            with gr.Row():
                msg = gr.Textbox(
                    label="Your message",
                    placeholder="Ask about stocks...",
                    scale=4,
                )
                submit = gr.Button("Submit", scale=1)

        # Right column: last generated SQL, debug log and the reset button.
        with gr.Column(scale=1):
            sql_display = gr.Markdown(label="Generated SQL")
            debug_info = gr.Textbox(
                label="Debug Information",
                lines=10,
                max_lines=10,
            )
            new_thread_btn = gr.Button("New Thread")

    # Hidden per-session state holding the assistant thread ID.
    thread_id = gr.State("")

    # Both ways of sending a message (button click, Enter in the textbox)
    # share the same handler, inputs and outputs.
    query_inputs = [msg, chatbot, thread_id, debug_info]
    ui_outputs = [chatbot, thread_id, debug_info, sql_display]

    submit_click = submit.click(
        process_query,
        inputs=query_inputs,
        outputs=ui_outputs,
    )
    msg_submit = msg.submit(
        process_query,
        inputs=query_inputs,
        outputs=ui_outputs,
    )
    new_thread_click = new_thread_btn.click(
        create_new_thread,
        outputs=ui_outputs,
    )

    # Clear the message box once a submission has been processed.
    submit_click.then(lambda: "", None, msg)
    msg_submit.then(lambda: "", None, msg)

# Launch the interface only when run as a script.
if __name__ == "__main__":
    interface.launch()
|