pschofield2 commited on
Commit
e3adea3
·
verified ·
1 Parent(s): 1d2a022

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +222 -0
app.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from openai import OpenAI
4
+ import snowflake.connector
5
+ from typing import Dict, Any
6
+ import json
7
+ from datetime import date, datetime
8
+ import time
9
+
10
+ # Initialize OpenAI client
11
+ client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
12
+ ASSISTANT_ID = "your_assistant_id" # Replace with your actual assistant ID
13
+
14
+ def execute_snowflake_query(query: str) -> Dict[str, Any]:
15
+ """Execute a Snowflake query and return results with column names."""
16
+ conn = None
17
+ try:
18
+ conn = snowflake.connector.connect(
19
+ user=os.environ['SNOWFLAKE_USER'],
20
+ password=os.environ['SNOWFLAKE_PW'],
21
+ account=os.environ['SNOWFLAKE_ACCOUNT'],
22
+ warehouse=os.environ['SNOWFLAKE_WH'],
23
+ database=os.environ['SNOWFLAKE_DB'],
24
+ schema=os.environ['SNOWFLAKE_SCHEMA']
25
+ )
26
+
27
+ with conn.cursor() as cur:
28
+ cur.execute(query)
29
+ rows = cur.fetchall()
30
+ column_names = [desc[0] for desc in cur.description]
31
+
32
+ serialized_rows = [
33
+ {
34
+ column_names[i]: (
35
+ val.isoformat() if isinstance(val, (date, datetime)) else val
36
+ )
37
+ for i, val in enumerate(row)
38
+ }
39
+ for row in rows
40
+ ]
41
+
42
+ return {"columns": column_names, "rows": serialized_rows}
43
+ except Exception as e:
44
+ return {"error": f"Query failed with error: {e}"}
45
+ finally:
46
+ if conn:
47
+ conn.close()
48
+
49
+ def wait_on_run(run, thread):
50
+ """Wait for the assistant's run to complete."""
51
+ while run.status == "queued" or run.status == "in_progress":
52
+ run = client.beta.threads.runs.retrieve(
53
+ thread_id=thread.id,
54
+ run_id=run.id,
55
+ )
56
+ time.sleep(0.5)
57
+ return run
58
+
59
+ def format_sql(sql):
60
+ """Format SQL query for display."""
61
+ if not sql:
62
+ return ""
63
+ return f"```sql\n{sql}\n```"
64
+
65
+ def process_query(
66
+ user_input: str,
67
+ history: list,
68
+ thread_id: str,
69
+ debug_info: str
70
+ ) -> tuple:
71
+ """Process a user query and return updated history and debug information."""
72
+
73
+ # Initialize debug information for this query
74
+ current_debug = ["Processing new query..."]
75
+
76
+ try:
77
+ # Create a new thread if none exists
78
+ if not thread_id:
79
+ thread = client.beta.threads.create()
80
+ thread_id = thread.id
81
+ current_debug.append(f"Created new thread: {thread_id}")
82
+ else:
83
+ thread = client.beta.threads.retrieve(thread_id)
84
+ current_debug.append(f"Using existing thread: {thread_id}")
85
+
86
+ # Add the user's message to the thread
87
+ message = client.beta.threads.messages.create(
88
+ thread_id=thread_id,
89
+ role="user",
90
+ content=user_input
91
+ )
92
+ current_debug.append("Added user message to thread")
93
+
94
+ # Create a run
95
+ run = client.beta.threads.runs.create(
96
+ thread_id=thread_id,
97
+ assistant_id=ASSISTANT_ID
98
+ )
99
+ current_debug.append("Created new run")
100
+
101
+ # Initialize variables for SQL tracking
102
+ latest_sql = ""
103
+
104
+ # Wait for run to complete, handling any required actions
105
+ while True:
106
+ run = wait_on_run(run, thread)
107
+
108
+ if run.status == "requires_action":
109
+ # Handle tool calls
110
+ tool_outputs = []
111
+ for tool_call in run.required_action.submit_tool_outputs.tool_calls:
112
+ if tool_call.function.name == "ask_snowflake":
113
+ # Extract and execute SQL query
114
+ query = json.loads(tool_call.function.arguments)["query"]
115
+ latest_sql = query
116
+ current_debug.append(f"Executing SQL query...")
117
+
118
+ # Execute query and get results
119
+ query_result = execute_snowflake_query(query)
120
+ tool_outputs.append({
121
+ "tool_call_id": tool_call.id,
122
+ "output": json.dumps(query_result)
123
+ })
124
+
125
+ # Submit outputs back to the assistant
126
+ run = client.beta.threads.runs.submit_tool_outputs(
127
+ thread_id=thread_id,
128
+ run_id=run.id,
129
+ tool_outputs=tool_outputs
130
+ )
131
+ current_debug.append("Submitted query results to assistant")
132
+
133
+ elif run.status == "completed":
134
+ break
135
+ elif run.status == "failed":
136
+ raise Exception(f"Run failed: {run.last_error}")
137
+
138
+ current_debug.append(f"Run status: {run.status}")
139
+
140
+ # Get messages after run completion
141
+ messages = client.beta.threads.messages.list(thread_id=thread_id)
142
+
143
+ # Update conversation history
144
+ new_history = history.copy()
145
+ last_message = next(iter(messages)) # Get most recent message
146
+
147
+ # Add the user's input and assistant's response to history
148
+ new_history.extend([
149
+ [user_input, None], # User message
150
+ [None, last_message.content[0].text.value] # Assistant response
151
+ ])
152
+
153
+ # Combine all debug information
154
+ full_debug = debug_info + "\n" + "\n".join(current_debug)
155
+
156
+ return new_history, thread_id, full_debug, format_sql(latest_sql)
157
+
158
+ except Exception as e:
159
+ error_msg = f"Error: {str(e)}"
160
+ current_debug.append(error_msg)
161
+ full_debug = debug_info + "\n" + "\n".join(current_debug)
162
+ return history + [[user_input, error_msg]], thread_id, full_debug, ""
163
+
164
+ def create_new_thread():
165
+ """Create a new thread and return empty states."""
166
+ return [], "", "Created new thread", ""
167
+
168
+ # Create the Gradio interface
169
+ with gr.Blocks(css="style.css") as interface:
170
+ gr.Markdown("# Stock Trading Assistant")
171
+
172
+ with gr.Row():
173
+ with gr.Column(scale=2):
174
+ chatbot = gr.Chatbot(
175
+ label="Conversation",
176
+ height=500
177
+ )
178
+ with gr.Row():
179
+ msg = gr.Textbox(
180
+ label="Your message",
181
+ placeholder="Ask about stocks...",
182
+ scale=4
183
+ )
184
+ submit = gr.Button("Submit", scale=1)
185
+
186
+ with gr.Column(scale=1):
187
+ sql_display = gr.Markdown(label="Generated SQL")
188
+ debug_info = gr.Textbox(
189
+ label="Debug Information",
190
+ lines=10,
191
+ max_lines=10
192
+ )
193
+ new_thread_btn = gr.Button("New Thread")
194
+
195
+ # Hidden state for thread ID
196
+ thread_id = gr.State("")
197
+
198
+ # Set up event handlers
199
+ submit_click = submit.click(
200
+ process_query,
201
+ inputs=[msg, chatbot, thread_id, debug_info],
202
+ outputs=[chatbot, thread_id, debug_info, sql_display]
203
+ )
204
+
205
+ msg_submit = msg.submit(
206
+ process_query,
207
+ inputs=[msg, chatbot, thread_id, debug_info],
208
+ outputs=[chatbot, thread_id, debug_info, sql_display]
209
+ )
210
+
211
+ new_thread_click = new_thread_btn.click(
212
+ create_new_thread,
213
+ outputs=[chatbot, thread_id, debug_info, sql_display]
214
+ )
215
+
216
+ # Clear message box after submission
217
+ submit_click.then(lambda: "", None, msg)
218
+ msg_submit.then(lambda: "", None, msg)
219
+
220
+ # Launch the interface
221
+ if __name__ == "__main__":
222
+ interface.launch()