Omkar1872 commited on
Commit
655b738
·
verified ·
1 Parent(s): 06e440b

Update streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +76 -624
streamlit_app.py CHANGED
@@ -1,624 +1,76 @@
1
- import sys
2
- sys.dont_write_bytecode =True
3
-
4
- import os
5
- sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
6
-
7
- import chainlit as cl
8
- import json
9
- from enum import Enum
10
- from typing import List,Type
11
- from pydantic import BaseModel,Field
12
- from core.helper import print_colored
13
- import asyncio
14
- import base64
15
- import plotly.graph_objects as go
16
-
17
- # -------------------------------- Structured Agent -------------------------------------
18
-
19
- import json
20
- from enum import Enum
21
- from typing import List,Type
22
- from core.helper import print_colored
23
- from pydantic import BaseModel,Field
24
-
25
- class ChainlitStructuredAgent:
26
-
27
- def __init__(self,model,agent_name,agent_description,agent_instructions,tools=[],assistant_agents=[],max_allowed_attempts=10,vission_model=None,vission_model_prompt=None) -> None:
28
- self.model = model
29
- self.vission_model = vission_model
30
- self.vission_model_prompt = vission_model_prompt
31
- self.agent_name = agent_name
32
- self.agent_description = agent_description
33
- self.agent_instructions=agent_instructions
34
- self.tools = tools
35
- self.assistant_agents = assistant_agents
36
- self.tool_names = []
37
- self.max_allowed_attempts= max_allowed_attempts
38
- self.attempts_made = 0
39
- self.messages = []
40
-
41
- if len(self.assistant_agents):
42
-
43
- self.prepare_prompt()
44
- self.agents_as_tools = {agent.agent_name:agent for agent in assistant_agents}
45
- self.assistants_names = []
46
-
47
- self.response_format = self.prepare_Default_tools()
48
-
49
- if len(self.tools):
50
-
51
- self.tool_objects = {i:j for i,j in zip(self.tool_names,tools)}
52
-
53
- tool_schemas = self.prepare_schema_from_tool(self.tools)
54
- self.agent_instructions+="""\n## Available Tools:\n"""
55
- self.agent_instructions+=f"""\nYou have access to the following tools:\n{tool_schemas}\nYou must use any one of these tools to answer the user question.\n\n"""
56
- self.agent_instructions+="""IMPORTANT!: You must provide your response in the below json format.
57
- {
58
- "thoughts":["Always you should think before taking any action"],
59
- "tool_name":"Name of the tool",
60
- "tool_args":{"arg_name":"arg_value"}
61
- }
62
- """
63
-
64
- def prepare_Default_tools(self):
65
-
66
- # Prepare final answer tool
67
- class FinalAnswer(BaseModel):
68
- final_answer : str = Field(description="Your final response to the user")
69
- def run(self):
70
- return self.final_answer
71
-
72
- self.tools.append(FinalAnswer)
73
-
74
- # Prepare Assign Task tool
75
- if len(self.assistant_agents):
76
-
77
- self.assistants_names = [i.agent_name for i in self.assistant_agents]
78
-
79
- recipients = Enum("recipient", {name: name for name in self.assistants_names})
80
-
81
- assistant_description = f"Select the correct Agent to assign the task : {self.assistants_names}\n\n"
82
-
83
- for assistant in self.assistant_agents:
84
-
85
- assistant_description+=assistant.agent_name+" : "+assistant.agent_description+"\n"
86
-
87
- class AssignTask(BaseModel):
88
-
89
- """Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. To continue the dialogue, invoke this tool again with the desired recipient agent and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as the user does not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved. Do not send more than 1 message at a time."""
90
-
91
- my_primary_instructions: str = Field(...,
92
- description="Please repeat your primary instructions step-by-step, including both completed "
93
- "and the following next steps that you need to perform. For multi-step, complex tasks, first break them down "
94
- "into smaller steps yourself. Then, issue each step individually to the "
95
- "recipient agent via the message parameter. Each identified step should be "
96
- "sent in separate message. Keep in mind, that the recipient agent does not have access "
97
- "to these instructions. You must include recipient agent-specific instructions "
98
- "in the message or additional_instructions parameters.")
99
- recipient: recipients = Field(..., description=assistant_description,examples=self.assistants_names)
100
-
101
- task_details: str = Field(...,
102
- description="Specify the task required for the recipient agent to complete. Focus on "
103
- "clarifying what the task entails, rather than providing exact "
104
- "instructions.")
105
-
106
- additional_instructions: str = Field(description="Any additional instructions or clarifications that you would like to provide to the recipient agent.")
107
-
108
- self.tools.append(AssignTask)
109
-
110
- self.tool_names = [i.__name__ for i in self.tools]
111
-
112
- # class ToolChoices(BaseModel):
113
- # thoughts: List[str] = Field(description="Your Thoughts")
114
- # tool_name : Literal[*self.tool_names] = Field(description=f"Select an appropriate tools from : {self.tool_names}",examples=self.tool_names)
115
- # tool_args : Union[*self.tools]
116
-
117
- # return ToolChoices
118
-
119
- def prepare_schema_from_tool(self,Tools: List[Type[BaseModel]]) -> List[dict]:
120
- schemas = ""
121
- for tool in Tools:
122
- schema = tool.model_json_schema()
123
- schemas+="\n"
124
- schemas += f"""
125
- "Tool Name": {tool.__name__},
126
- "Tool Description": {tool.__doc__},
127
- "Tool Parameters":
128
- "Properties": {schema["properties"]},
129
- "Required": {schema["required"]},
130
- "Type": {schema["type"]}\n"""
131
- schemas+="\n"
132
-
133
- return schemas
134
-
135
- def prepare_prompt(self):
136
-
137
- if len(self.assistant_agents):
138
-
139
- self.agent_instructions+="\n**Task Assignment**: You can assign tasks to the following agents who are responsible to help you to achieve your goal.\n"
140
-
141
- self.agent_instructions+="-----------------------------------------------\n"
142
-
143
- for agent in self.assistant_agents:
144
-
145
- self.agent_instructions+="- **Agent Name**: "+agent.agent_name+"\n"
146
- self.agent_instructions+="- **Agent Description**:\n"+agent.agent_description+"\n"
147
-
148
- self.agent_instructions+="\n-----------------------------------------------\n"
149
-
150
- def prepare_messages(self,content,role=None,messages=[]):
151
-
152
- if not len(messages):
153
-
154
- messages = [
155
- {"role":"system","content":self.agent_instructions},
156
- {"role":"user","content":content}
157
- ]
158
-
159
- else:
160
-
161
- messages.append({"role":role,"content":content})
162
-
163
- return messages
164
-
165
- def construct_message_from_output(self,output):
166
-
167
- content = ""
168
-
169
- for key, value in output.items():
170
-
171
- if key=='task_checklist':
172
-
173
- content+=f"Task Check List:\n{"\n".join(value)}\n"
174
-
175
- elif key=='completed_tasks':
176
-
177
- content+=f"Completed Tasks:\n{"\n".join(value)}\n"
178
-
179
- elif key=='thoughts':
180
-
181
- content+=f"Thought: {"\n".join(value)}\n"
182
-
183
- elif key=='tool_name':
184
-
185
- if value=="AssignTask":
186
-
187
- content+=f"I need to assign a task to the '{output['tool_args']['recipient']}' with the following instructions:"
188
- else:
189
-
190
- content+=f"We need to use the '{value}' tool with the following arguments: "
191
-
192
- else:
193
-
194
- content+=f"{value}."
195
-
196
- return content
197
- @cl.step(type="tool")
198
- def execute_tool(self,messages,tool_details):
199
-
200
- try:
201
-
202
- # assistant_content=self.construct_message_from_output(tool_details)
203
- assistant_content=str(tool_details)
204
-
205
- except Exception as e:
206
-
207
- invalid_arg_error_message ="Error while executing tool. Please check the tool name or provide a valid arguments to the tool: "+str(e)
208
-
209
- tool_content = invalid_arg_error_message
210
-
211
- assistant_content = str(tool_details)
212
-
213
- messages.append({"role":"assistant","content":assistant_content})
214
-
215
- messages.append({"role":"user","content":tool_content})
216
-
217
- return messages
218
-
219
- if tool_details['tool_name'] in self.tool_names :
220
-
221
- if tool_details['tool_name'] == 'AssignTask':
222
-
223
- try:
224
-
225
- arguments = tool_details['tool_args']
226
-
227
- task_details =arguments.get('task_details',"")
228
-
229
- additional_instructions =arguments.get('additional_instructions',"")
230
-
231
- print_colored(f"{self.agent_name} assigned a task to {arguments['recipient']}","cyan")
232
-
233
- assistant_agent = self.agents_as_tools[arguments['recipient']]
234
-
235
- user_input = task_details + "\n" + additional_instructions
236
-
237
- print_colored("Task Details: \n"+user_input,"white")
238
-
239
- tool_content = assistant_agent.run(user_input)
240
-
241
- tool_content = f"Response from the {arguments['recipient']} : "+tool_content
242
-
243
- except Exception as e:
244
-
245
- print_colored("Error Tool: "+str(e),"red")
246
-
247
- tool_content = f"Error while assigning task to {arguments['recipient']}. Please provide a correct agent name: {[i.agent_name for i in self.assistant_agents]}"
248
-
249
- else:
250
-
251
- try:
252
-
253
- print_colored(f"{self.agent_name} : Calling Tool {tool_details['tool_name']}","yellow")
254
-
255
- tool_output = self.tool_objects[tool_details['tool_name']](**tool_details['tool_args']).run()
256
-
257
- print_colored(f"Got output....{len(tool_output)}","orange")
258
-
259
- tool_content=f"Output From {tool_details['tool_name']} Tool: {tool_output}"
260
-
261
- if tool_details['tool_name'] == 'JupyterNotebookTool' and len(tool_output):
262
-
263
- for i in range(len(tool_output)):
264
-
265
- elements = []
266
-
267
- # print('Keys: ', tool_output[i].keys())
268
-
269
- # print('expected_output_type: ', tool_output[i].get('expected_output_type',""))
270
-
271
- # print('display_type: ', tool_output[i].get('display_type',""))
272
-
273
- if tool_output[i]['display_type']=='text':
274
-
275
- elements.append(cl.Text(name="Tool Output", content=tool_output[i]['final_output'], display="inline"))
276
-
277
- # asyncio.run(cl.Message("",elements=elements).send())
278
-
279
- elif tool_output[i]['display_type']=='image':
280
-
281
- image_data = tool_output[i]['output'].split('base64,')[1]
282
-
283
- # Decode the base64 image data
284
- image_bytes = base64.b64decode(image_data)
285
-
286
- elements.append(cl.Image(name="Plot",size='medium',content=image_bytes,display="inline"))
287
-
288
- # asyncio.run(cl.Message("",elements=elements).send())
289
-
290
- image_base64 = base64.b64encode(base64.b64decode(image_data)).decode('utf-8')
291
-
292
- if self.vission_model:
293
-
294
- print_colored("Describing the Image...","green")
295
-
296
- plot_dec=asyncio.run(self.vission_model.get_output(self.vission_model_prompt,base64_image=image_base64))
297
-
298
- tool_output[i]['final_output'] +="\n"+plot_dec
299
-
300
- print_colored(f"Image Description.: {tool_output[i]['final_output']}","green")
301
-
302
- tool_output[i]['final_output'] =plot_dec
303
-
304
- elements.append(cl.Text(name="Insights",content=plot_dec))
305
-
306
- # asyncio.run(cl.Message("",elements=elements).send())
307
-
308
- elif tool_output[i]['display_type']=='plotly':
309
-
310
- print_colored("Trying to display plotly.....","red")
311
-
312
- fig = go.Figure(data=tool_output[i]['output'])
313
-
314
- print_colored(f"Created fig.....{type(fig)}","red")
315
-
316
- # Convert the Plotly figure to a PNG image
317
- img_bytes = fig.to_image(format="png")
318
-
319
- print_colored(f"Created img_bytes.....{type(fig)}","red")
320
-
321
- # Encode the image to base64
322
- image_base64 = base64.b64encode(img_bytes).decode('utf-8')
323
-
324
- print_colored(f"Created image_base64.....{type(fig)}","red")
325
-
326
- elements.append(cl.Plotly(name="Chart", size="medium",figure=fig, display="inline"))
327
-
328
- # asyncio.run(cl.Message("",elements=elements).send())
329
-
330
- if self.vission_model:
331
-
332
- print_colored("Describing the plotly...","green")
333
-
334
- plot_dec = asyncio.run(self.vission_model.get_output(self.vission_model_prompt,base64_image=image_base64))
335
-
336
- print_colored(f"Plotly Description.: {tool_output[i]['final_output']}","green")
337
-
338
- tool_output[i]['final_output'] =plot_dec
339
-
340
- elements.append(cl.Text(name="Insights",content=plot_dec,display="inline"))
341
-
342
- # asyncio.run(cl.Message("",elements=elements).send())
343
- else:
344
-
345
- elements.append(cl.Text(name="", content=tool_output[i]['final_output'], display="inline"))
346
-
347
- asyncio.run(cl.Message("Output",elements=elements).send())
348
-
349
- tool_output = [i['final_output'][:30000] for i in tool_output]
350
-
351
- tool_content=f"Output From {tool_details['tool_name']} Tool: {"\n".join(tool_output)}"
352
-
353
- # elements = [cl.Text("Tool Output",content=str(tool_output))]
354
-
355
- # asyncio.run(cl.Message("",elements=elements).send())
356
-
357
- except Exception as e:
358
-
359
- print_colored("Error Tool: "+str(e),"red")
360
-
361
- tool_content = "Error while executing tool. Please check the tool name or provide a valid arguments to the tool: "+str(e)
362
-
363
- else:
364
-
365
- tool_content= "There is no such a tool available. Here are the available tools : "+str(self.tool_names)
366
-
367
- messages.append({"role":"assistant","content":assistant_content.strip()})
368
- messages.append({"role":"user","content":tool_content.strip()})
369
-
370
- return messages
371
-
372
- async def run(self,user_input=None,messages=[]):
373
-
374
- if self.attempts_made<=self.max_allowed_attempts:
375
-
376
- print_colored(f"Attempt Number : {self.attempts_made}/{self.max_allowed_attempts}","pink")
377
-
378
- self.attempts_made+=1
379
-
380
- if user_input:
381
-
382
- messages = self.prepare_messages(user_input,role="user",messages=messages)
383
-
384
- tool_details,total_tokens = await self.model.aget_output(messages)
385
-
386
- if not isinstance(tool_details,dict):
387
-
388
- return "I am not able to process your request"
389
-
390
- # tool_details = json.loads(tool_details)
391
-
392
- if tool_details['tool_name']=='FinalAnswer':
393
-
394
- thoughts = tool_details.get("thoughts","")
395
-
396
- if isinstance(thoughts,list):
397
-
398
- thoughts = '\n'.join(thoughts)
399
-
400
- print_colored(f"Thoughts: {thoughts}","green")
401
-
402
- print_colored(f"{self.agent_name} : {tool_details['tool_args']['final_answer']}","green")
403
-
404
- messages.append({"role":"assistant","content":str(tool_details)})
405
-
406
- self.messages = messages
407
-
408
- asyncio.run(cl.Message(tool_details['tool_args']['final_answer']).send())
409
-
410
- # return tool_details['tool_args']['final_answer']
411
-
412
- else:
413
-
414
- print()
415
-
416
- thoughts = tool_details.get("thoughts","")
417
-
418
- if isinstance(thoughts,list):
419
-
420
- thoughts = '\n'.join(thoughts)
421
-
422
- print_colored(f"Thoughts: {thoughts}","green")
423
-
424
- print_colored(f"Tool Name: {tool_details['tool_name']}","blue")
425
- # print_colored(f"Tool Name: {tool_details['tool_args']}","blue")
426
- print()
427
-
428
- if tool_details['tool_name'] == 'JupyterNotebookTool':
429
-
430
- asyncio.run(cl.Message(thoughts.strip()).send())
431
-
432
- # html = f"""<pre><code>{tool_details['tool_args']['python_code']}</code></pre>"""
433
- html = f"""```python\n\n{tool_details['tool_args']['python_code']}\n\n```"""
434
-
435
- asyncio.run(cl.Message(html).send())
436
-
437
- messages = self.execute_tool(messages,tool_details)
438
-
439
- self.messages = messages
440
-
441
- return await self.run(messages=messages)
442
-
443
- else:
444
-
445
- self.messages = messages
446
-
447
- print_colored(f"{self.agent_name} : Sorry! Max Attempt Exceeded, I can't take anymore tasks","red")
448
-
449
- return "Sorry! Max Attempt Exceeded, I can't take anymore tasks"
450
-
451
- import uuid
452
- import asyncio
453
- from pydantic import BaseModel,Field
454
- from core.helper import print_colored
455
- from core.models import OpenaiChatModel,OpenAIVissionModel,AnthropicModel
456
- from core.text2sql.query_generator_2 import Text2SQL
457
- from core.tools.JupyterTool import NotebookManager
458
- from pydantic import BaseModel, Field
459
-
460
- SQl_Engine = Text2SQL("gpt-4o-mini","",db_type='mysql',host='host',port=3306,username='name',password='password',database='db_name',add_additional_context=True,max_attempts=10)
461
-
462
- class GetRelavantTables(BaseModel):
463
- """
464
- Tool to retrieve relevant tables based on the user's question.
465
- """
466
-
467
- user_question : str = Field("Provide the user question as it is.")
468
-
469
- sub_questions: list[str] = Field(
470
- description=(
471
- "Split the user question into multiple sub-questions if answering it requires data from multiple tables. "
472
- "Each sub-question should focus on a specific aspect of the user query, referring to relevant columns or tables. "
473
- "For example: \n"
474
- "User question: 'What is the total sales of XYZ product last month?'\n"
475
- "Sub-questions: ['Which column contains product names?', 'Which column contains sales details?',]\n"
476
- "Ensure sub-questions are precise and map clearly to tables or columns needed to answer the main query."
477
- "With each table you will get it's relationship to other tables. So dont need seperate question for that."
478
- "Avoid duplication. Focus and retrieving the accurate columns."
479
- )
480
- )
481
-
482
- def run(self):
483
-
484
- docs = SQl_Engine.get_relavant_documents(self.sub_questions, top_n_similar_docs=30,filtered_tables=2)
485
-
486
- filter_result = asyncio.run(SQl_Engine.filter_columns(self.user_question, docs))
487
-
488
- final_schema="# Here are the relevant table schema.\n\n"
489
-
490
- final_schema = "\n\n".join([f"""{i['filtered_columns']}\n\n### Here is the details on how this table connected to other tables\n: {i['common_columns']}""".strip() for i in filter_result if i['filtered_columns']])
491
-
492
- with open("relevant_schemas.txt","a") as f:
493
- f.write(f"User Question : {self.user_question}\n\n")
494
- f.write(f"Sub Queries\n: {"\n\t".join(self.sub_questions)}\n\n")
495
- f.write(final_schema.strip())
496
- f.write("\n\n***************************************************************\n\n")
497
-
498
- return final_schema
499
-
500
- # return "\n\n------------------------------------\n\n".join([i['text_data'] for i in docs])
501
-
502
- class ExecuteInertmediateQuery(BaseModel):
503
- """
504
- Use this tool to execute an intermediate SQL sub-query to retrieve the unique categories or values available in a specific column.
505
-
506
- Do not directly use the user-provided values in the WHERE clause. First, execute a query to fetch the unique values from the column.
507
-
508
- Based on the retrieved results, validate and construct the final query to ensure accuracy and alignment with the data.
509
-
510
- For example, if the user requests 'drug-x' but the table contains 'Drug-X', this tool ensures the correct value appears at the top.
511
- """
512
-
513
- user_question: str = Field(description="The user question")
514
- sub_query: str = Field(description="The sub-query to execute")
515
-
516
- def run(self):
517
-
518
- print("Sub Quer: ",self.sub_query)
519
-
520
- asyncio.run(cl.Message(f"```sql\n\n{self.sub_query}\n\n```").send())
521
-
522
-
523
- return f"Observation: {SQl_Engine.execute_inertmediate_query(self.user_question, self.sub_query)}"
524
-
525
-
526
- class ExecuteFinalQuery(BaseModel):
527
- """
528
- Tool to execute the final SQL query and save the output to an Excel file to do further analysis. It will return the exact path to the saved output
529
- """
530
-
531
- final_query: str = Field(description="The final query to execute")
532
-
533
- def run(self):
534
-
535
- df = SQl_Engine.run_sql_query(self.final_query)
536
-
537
- asyncio.run(cl.Message(f"```sql\n\n{self.final_query}\n\n```").send())
538
-
539
- if not df.empty:
540
-
541
- if df.shape[0]>=20:
542
-
543
- df.to_excel("work_dir/query_output.xlsx", index=False)
544
-
545
- return f"Observation: The data has been stored at `work_dir/query_output.xlsx`. The data contains {len(df)} rows, here is the snapshot(First 5 rows) of the dataframe : \n\n" + df.head(5).to_markdown()
546
-
547
- else:
548
-
549
- df.to_excel("work_dir/query_output.xlsx", index=False)
550
-
551
- return "Observation: The data has been stored at `work_dir/query_output.xlsx`. \n\n Here is the output of the query" + df.to_markdown()
552
- else:
553
- return "The Query Returned Empty DatafRame: \n\n" + df.head(5).to_markdown()
554
-
555
-
556
- @cl.on_chat_start
557
- def start_message():
558
-
559
- user_session_id = str(uuid.uuid4())
560
-
561
- notebookmanager = NotebookManager(user_session_id)
562
-
563
- class JupyterNotebookTool(BaseModel):
564
-
565
- """A tool for executing Python code in a stateful Jupyter notebook environment."""
566
-
567
- python_code: str = Field(description="A valid python code to execute in a new jupyter notebook cell")
568
-
569
- # expected_output_type : List[str] = Field(description="What output type should the script produce? It might be a single output or a combination of these: [text, dataframe, plotly chart, image, log or nothing]. Please specify the expected output(s) type in the exact order they should appear")
570
-
571
- def run(self):
572
-
573
- result = notebookmanager.run_code(self.python_code)
574
-
575
- return result
576
-
577
- description = "Responsible for Answering User question."
578
-
579
- instruction = open(r"prompts/system_prompt.md","r").read()
580
-
581
- tools = [GetRelavantTables,ExecuteInertmediateQuery,ExecuteFinalQuery,JupyterNotebookTool]
582
-
583
- model = OpenaiChatModel(model_name="gpt-4o-mini",verbose=True)
584
-
585
- # model_name= 'claude-3-5-sonnet-20240620'
586
-
587
- # model = AnthropicModel(api_key=api_key)
588
-
589
- vissionmodel = OpenAIVissionModel(model="gpt-4o-mini")
590
-
591
- vission_prompt = "You are provided with a plot from a data analysis, You need to explain all the insights and metrics to the user"
592
-
593
- agent = ChainlitStructuredAgent(model,"AI Assistant",description,instruction,tools,max_allowed_attempts=30,vission_model=vissionmodel,vission_model_prompt=vission_prompt)
594
-
595
- cl.user_session.set("messages",[])
596
-
597
- cl.user_session.set("agent",agent)
598
-
599
- cl.user_session.set("notebookmanager",notebookmanager)
600
-
601
- cl.user_session.set("user_session_id",user_session_id)
602
-
603
-
604
- @cl.on_message
605
- async def on_message(user_input: cl.Message):
606
-
607
- messages=cl.user_session.get("messages")
608
-
609
- agent=cl.user_session.get("agent")
610
-
611
- response = await agent.run(user_input.content,messages)
612
-
613
- cl.user_session.set("messages",agent.messages)
614
-
615
- # await cl.Message(response).send()
616
-
617
- @cl.on_chat_end
618
- def delete_notebook():
619
-
620
- notebookmanager=cl.user_session.get("notebookmanager")
621
-
622
- user_session_id=cl.user_session.get("user_session_id")
623
-
624
- notebookmanager.delete_notebook(user_session_id)
 
1
+ # streamlit_app.py
2
+ import streamlit as st
3
+ import pandas as pd
4
+ from sqlalchemy import create_engine, text
5
+ import openai
6
+ import os
7
+
8
+ # ---- CONFIG ----
9
+ # Set your API key as an environment variable or in a .env file
10
+ openai.api_key = os.getenv("OPENAI_API_KEY")
11
+
12
+ # Database connection (update these with your credentials)
13
+ DB_TYPE = "mysql+pymysql"
14
+ DB_USER = "username"
15
+ DB_PASS = "password"
16
+ DB_HOST = "host"
17
+ DB_PORT = "3306"
18
+ DB_NAME = "db_name"
19
+
20
+ DATABASE_URL = f"{DB_TYPE}://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
21
+ engine = create_engine(DATABASE_URL)
22
+
23
+ # ---- FUNCTIONS ----
24
+ def generate_sql(user_question, table_names=[]):
25
+ """
26
+ Generates SQL query from user question using OpenAI GPT
27
+ """
28
+ table_info = ""
29
+ if table_names:
30
+ table_info = f"These are your tables: {table_names}\n"
31
+
32
+ prompt = f"""
33
+ You are an expert SQL generator.
34
+ {table_info}
35
+ Write a SQL query that answers the following question:
36
+ \"\"\"{user_question}\"\"\"
37
+ Only return SQL, do not explain.
38
+ """
39
+ response = openai.Completion.create(
40
+ engine="text-davinci-003",
41
+ prompt=prompt,
42
+ temperature=0,
43
+ max_tokens=300
44
+ )
45
+ sql_query = response.choices[0].text.strip()
46
+ return sql_query
47
+
48
+ def run_query(sql_query):
49
+ """
50
+ Runs SQL query using SQLAlchemy
51
+ """
52
+ try:
53
+ with engine.connect() as conn:
54
+ result = pd.read_sql(text(sql_query), conn)
55
+ return result
56
+ except Exception as e:
57
+ return f"Error executing query: {e}"
58
+
59
+ # ---- STREAMLIT UI ----
60
+ st.title("🧠 AI SQL Assistant")
61
+ st.markdown("Ask a question about your database, and it will generate SQL and show results.")
62
+
63
+ user_question = st.text_input("Enter your question:")
64
+
65
+ if st.button("Run Query") and user_question:
66
+ with st.spinner("Generating SQL..."):
67
+ sql_query = generate_sql(user_question)
68
+ st.code(sql_query, language="sql")
69
+
70
+ with st.spinner("Executing SQL..."):
71
+ result = run_query(sql_query)
72
+ if isinstance(result, pd.DataFrame):
73
+ st.success("Query executed successfully!")
74
+ st.dataframe(result)
75
+ else:
76
+ st.error(result)