Update streamlit_app.py
Browse files- streamlit_app.py +76 -624
streamlit_app.py
CHANGED
|
@@ -1,624 +1,76 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
import
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
{
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
"
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
self.assistants_names = [i.agent_name for i in self.assistant_agents]
|
| 78 |
-
|
| 79 |
-
recipients = Enum("recipient", {name: name for name in self.assistants_names})
|
| 80 |
-
|
| 81 |
-
assistant_description = f"Select the correct Agent to assign the task : {self.assistants_names}\n\n"
|
| 82 |
-
|
| 83 |
-
for assistant in self.assistant_agents:
|
| 84 |
-
|
| 85 |
-
assistant_description+=assistant.agent_name+" : "+assistant.agent_description+"\n"
|
| 86 |
-
|
| 87 |
-
class AssignTask(BaseModel):
|
| 88 |
-
|
| 89 |
-
"""Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. To continue the dialogue, invoke this tool again with the desired recipient agent and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as the user does not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved. Do not send more than 1 message at a time."""
|
| 90 |
-
|
| 91 |
-
my_primary_instructions: str = Field(...,
|
| 92 |
-
description="Please repeat your primary instructions step-by-step, including both completed "
|
| 93 |
-
"and the following next steps that you need to perform. For multi-step, complex tasks, first break them down "
|
| 94 |
-
"into smaller steps yourself. Then, issue each step individually to the "
|
| 95 |
-
"recipient agent via the message parameter. Each identified step should be "
|
| 96 |
-
"sent in separate message. Keep in mind, that the recipient agent does not have access "
|
| 97 |
-
"to these instructions. You must include recipient agent-specific instructions "
|
| 98 |
-
"in the message or additional_instructions parameters.")
|
| 99 |
-
recipient: recipients = Field(..., description=assistant_description,examples=self.assistants_names)
|
| 100 |
-
|
| 101 |
-
task_details: str = Field(...,
|
| 102 |
-
description="Specify the task required for the recipient agent to complete. Focus on "
|
| 103 |
-
"clarifying what the task entails, rather than providing exact "
|
| 104 |
-
"instructions.")
|
| 105 |
-
|
| 106 |
-
additional_instructions: str = Field(description="Any additional instructions or clarifications that you would like to provide to the recipient agent.")
|
| 107 |
-
|
| 108 |
-
self.tools.append(AssignTask)
|
| 109 |
-
|
| 110 |
-
self.tool_names = [i.__name__ for i in self.tools]
|
| 111 |
-
|
| 112 |
-
# class ToolChoices(BaseModel):
|
| 113 |
-
# thoughts: List[str] = Field(description="Your Thoughts")
|
| 114 |
-
# tool_name : Literal[*self.tool_names] = Field(description=f"Select an appropriate tools from : {self.tool_names}",examples=self.tool_names)
|
| 115 |
-
# tool_args : Union[*self.tools]
|
| 116 |
-
|
| 117 |
-
# return ToolChoices
|
| 118 |
-
|
| 119 |
-
def prepare_schema_from_tool(self,Tools: List[Type[BaseModel]]) -> List[dict]:
|
| 120 |
-
schemas = ""
|
| 121 |
-
for tool in Tools:
|
| 122 |
-
schema = tool.model_json_schema()
|
| 123 |
-
schemas+="\n"
|
| 124 |
-
schemas += f"""
|
| 125 |
-
"Tool Name": {tool.__name__},
|
| 126 |
-
"Tool Description": {tool.__doc__},
|
| 127 |
-
"Tool Parameters":
|
| 128 |
-
"Properties": {schema["properties"]},
|
| 129 |
-
"Required": {schema["required"]},
|
| 130 |
-
"Type": {schema["type"]}\n"""
|
| 131 |
-
schemas+="\n"
|
| 132 |
-
|
| 133 |
-
return schemas
|
| 134 |
-
|
| 135 |
-
def prepare_prompt(self):
|
| 136 |
-
|
| 137 |
-
if len(self.assistant_agents):
|
| 138 |
-
|
| 139 |
-
self.agent_instructions+="\n**Task Assignment**: You can assign tasks to the following agents who are responsible to help you to achieve your goal.\n"
|
| 140 |
-
|
| 141 |
-
self.agent_instructions+="-----------------------------------------------\n"
|
| 142 |
-
|
| 143 |
-
for agent in self.assistant_agents:
|
| 144 |
-
|
| 145 |
-
self.agent_instructions+="- **Agent Name**: "+agent.agent_name+"\n"
|
| 146 |
-
self.agent_instructions+="- **Agent Description**:\n"+agent.agent_description+"\n"
|
| 147 |
-
|
| 148 |
-
self.agent_instructions+="\n-----------------------------------------------\n"
|
| 149 |
-
|
| 150 |
-
def prepare_messages(self,content,role=None,messages=[]):
|
| 151 |
-
|
| 152 |
-
if not len(messages):
|
| 153 |
-
|
| 154 |
-
messages = [
|
| 155 |
-
{"role":"system","content":self.agent_instructions},
|
| 156 |
-
{"role":"user","content":content}
|
| 157 |
-
]
|
| 158 |
-
|
| 159 |
-
else:
|
| 160 |
-
|
| 161 |
-
messages.append({"role":role,"content":content})
|
| 162 |
-
|
| 163 |
-
return messages
|
| 164 |
-
|
| 165 |
-
def construct_message_from_output(self,output):
|
| 166 |
-
|
| 167 |
-
content = ""
|
| 168 |
-
|
| 169 |
-
for key, value in output.items():
|
| 170 |
-
|
| 171 |
-
if key=='task_checklist':
|
| 172 |
-
|
| 173 |
-
content+=f"Task Check List:\n{"\n".join(value)}\n"
|
| 174 |
-
|
| 175 |
-
elif key=='completed_tasks':
|
| 176 |
-
|
| 177 |
-
content+=f"Completed Tasks:\n{"\n".join(value)}\n"
|
| 178 |
-
|
| 179 |
-
elif key=='thoughts':
|
| 180 |
-
|
| 181 |
-
content+=f"Thought: {"\n".join(value)}\n"
|
| 182 |
-
|
| 183 |
-
elif key=='tool_name':
|
| 184 |
-
|
| 185 |
-
if value=="AssignTask":
|
| 186 |
-
|
| 187 |
-
content+=f"I need to assign a task to the '{output['tool_args']['recipient']}' with the following instructions:"
|
| 188 |
-
else:
|
| 189 |
-
|
| 190 |
-
content+=f"We need to use the '{value}' tool with the following arguments: "
|
| 191 |
-
|
| 192 |
-
else:
|
| 193 |
-
|
| 194 |
-
content+=f"{value}."
|
| 195 |
-
|
| 196 |
-
return content
|
| 197 |
-
@cl.step(type="tool")
|
| 198 |
-
def execute_tool(self,messages,tool_details):
|
| 199 |
-
|
| 200 |
-
try:
|
| 201 |
-
|
| 202 |
-
# assistant_content=self.construct_message_from_output(tool_details)
|
| 203 |
-
assistant_content=str(tool_details)
|
| 204 |
-
|
| 205 |
-
except Exception as e:
|
| 206 |
-
|
| 207 |
-
invalid_arg_error_message ="Error while executing tool. Please check the tool name or provide a valid arguments to the tool: "+str(e)
|
| 208 |
-
|
| 209 |
-
tool_content = invalid_arg_error_message
|
| 210 |
-
|
| 211 |
-
assistant_content = str(tool_details)
|
| 212 |
-
|
| 213 |
-
messages.append({"role":"assistant","content":assistant_content})
|
| 214 |
-
|
| 215 |
-
messages.append({"role":"user","content":tool_content})
|
| 216 |
-
|
| 217 |
-
return messages
|
| 218 |
-
|
| 219 |
-
if tool_details['tool_name'] in self.tool_names :
|
| 220 |
-
|
| 221 |
-
if tool_details['tool_name'] == 'AssignTask':
|
| 222 |
-
|
| 223 |
-
try:
|
| 224 |
-
|
| 225 |
-
arguments = tool_details['tool_args']
|
| 226 |
-
|
| 227 |
-
task_details =arguments.get('task_details',"")
|
| 228 |
-
|
| 229 |
-
additional_instructions =arguments.get('additional_instructions',"")
|
| 230 |
-
|
| 231 |
-
print_colored(f"{self.agent_name} assigned a task to {arguments['recipient']}","cyan")
|
| 232 |
-
|
| 233 |
-
assistant_agent = self.agents_as_tools[arguments['recipient']]
|
| 234 |
-
|
| 235 |
-
user_input = task_details + "\n" + additional_instructions
|
| 236 |
-
|
| 237 |
-
print_colored("Task Details: \n"+user_input,"white")
|
| 238 |
-
|
| 239 |
-
tool_content = assistant_agent.run(user_input)
|
| 240 |
-
|
| 241 |
-
tool_content = f"Response from the {arguments['recipient']} : "+tool_content
|
| 242 |
-
|
| 243 |
-
except Exception as e:
|
| 244 |
-
|
| 245 |
-
print_colored("Error Tool: "+str(e),"red")
|
| 246 |
-
|
| 247 |
-
tool_content = f"Error while assigning task to {arguments['recipient']}. Please provide a correct agent name: {[i.agent_name for i in self.assistant_agents]}"
|
| 248 |
-
|
| 249 |
-
else:
|
| 250 |
-
|
| 251 |
-
try:
|
| 252 |
-
|
| 253 |
-
print_colored(f"{self.agent_name} : Calling Tool {tool_details['tool_name']}","yellow")
|
| 254 |
-
|
| 255 |
-
tool_output = self.tool_objects[tool_details['tool_name']](**tool_details['tool_args']).run()
|
| 256 |
-
|
| 257 |
-
print_colored(f"Got output....{len(tool_output)}","orange")
|
| 258 |
-
|
| 259 |
-
tool_content=f"Output From {tool_details['tool_name']} Tool: {tool_output}"
|
| 260 |
-
|
| 261 |
-
if tool_details['tool_name'] == 'JupyterNotebookTool' and len(tool_output):
|
| 262 |
-
|
| 263 |
-
for i in range(len(tool_output)):
|
| 264 |
-
|
| 265 |
-
elements = []
|
| 266 |
-
|
| 267 |
-
# print('Keys: ', tool_output[i].keys())
|
| 268 |
-
|
| 269 |
-
# print('expected_output_type: ', tool_output[i].get('expected_output_type',""))
|
| 270 |
-
|
| 271 |
-
# print('display_type: ', tool_output[i].get('display_type',""))
|
| 272 |
-
|
| 273 |
-
if tool_output[i]['display_type']=='text':
|
| 274 |
-
|
| 275 |
-
elements.append(cl.Text(name="Tool Output", content=tool_output[i]['final_output'], display="inline"))
|
| 276 |
-
|
| 277 |
-
# asyncio.run(cl.Message("",elements=elements).send())
|
| 278 |
-
|
| 279 |
-
elif tool_output[i]['display_type']=='image':
|
| 280 |
-
|
| 281 |
-
image_data = tool_output[i]['output'].split('base64,')[1]
|
| 282 |
-
|
| 283 |
-
# Decode the base64 image data
|
| 284 |
-
image_bytes = base64.b64decode(image_data)
|
| 285 |
-
|
| 286 |
-
elements.append(cl.Image(name="Plot",size='medium',content=image_bytes,display="inline"))
|
| 287 |
-
|
| 288 |
-
# asyncio.run(cl.Message("",elements=elements).send())
|
| 289 |
-
|
| 290 |
-
image_base64 = base64.b64encode(base64.b64decode(image_data)).decode('utf-8')
|
| 291 |
-
|
| 292 |
-
if self.vission_model:
|
| 293 |
-
|
| 294 |
-
print_colored("Describing the Image...","green")
|
| 295 |
-
|
| 296 |
-
plot_dec=asyncio.run(self.vission_model.get_output(self.vission_model_prompt,base64_image=image_base64))
|
| 297 |
-
|
| 298 |
-
tool_output[i]['final_output'] +="\n"+plot_dec
|
| 299 |
-
|
| 300 |
-
print_colored(f"Image Description.: {tool_output[i]['final_output']}","green")
|
| 301 |
-
|
| 302 |
-
tool_output[i]['final_output'] =plot_dec
|
| 303 |
-
|
| 304 |
-
elements.append(cl.Text(name="Insights",content=plot_dec))
|
| 305 |
-
|
| 306 |
-
# asyncio.run(cl.Message("",elements=elements).send())
|
| 307 |
-
|
| 308 |
-
elif tool_output[i]['display_type']=='plotly':
|
| 309 |
-
|
| 310 |
-
print_colored("Trying to display plotly.....","red")
|
| 311 |
-
|
| 312 |
-
fig = go.Figure(data=tool_output[i]['output'])
|
| 313 |
-
|
| 314 |
-
print_colored(f"Created fig.....{type(fig)}","red")
|
| 315 |
-
|
| 316 |
-
# Convert the Plotly figure to a PNG image
|
| 317 |
-
img_bytes = fig.to_image(format="png")
|
| 318 |
-
|
| 319 |
-
print_colored(f"Created img_bytes.....{type(fig)}","red")
|
| 320 |
-
|
| 321 |
-
# Encode the image to base64
|
| 322 |
-
image_base64 = base64.b64encode(img_bytes).decode('utf-8')
|
| 323 |
-
|
| 324 |
-
print_colored(f"Created image_base64.....{type(fig)}","red")
|
| 325 |
-
|
| 326 |
-
elements.append(cl.Plotly(name="Chart", size="medium",figure=fig, display="inline"))
|
| 327 |
-
|
| 328 |
-
# asyncio.run(cl.Message("",elements=elements).send())
|
| 329 |
-
|
| 330 |
-
if self.vission_model:
|
| 331 |
-
|
| 332 |
-
print_colored("Describing the plotly...","green")
|
| 333 |
-
|
| 334 |
-
plot_dec = asyncio.run(self.vission_model.get_output(self.vission_model_prompt,base64_image=image_base64))
|
| 335 |
-
|
| 336 |
-
print_colored(f"Plotly Description.: {tool_output[i]['final_output']}","green")
|
| 337 |
-
|
| 338 |
-
tool_output[i]['final_output'] =plot_dec
|
| 339 |
-
|
| 340 |
-
elements.append(cl.Text(name="Insights",content=plot_dec,display="inline"))
|
| 341 |
-
|
| 342 |
-
# asyncio.run(cl.Message("",elements=elements).send())
|
| 343 |
-
else:
|
| 344 |
-
|
| 345 |
-
elements.append(cl.Text(name="", content=tool_output[i]['final_output'], display="inline"))
|
| 346 |
-
|
| 347 |
-
asyncio.run(cl.Message("Output",elements=elements).send())
|
| 348 |
-
|
| 349 |
-
tool_output = [i['final_output'][:30000] for i in tool_output]
|
| 350 |
-
|
| 351 |
-
tool_content=f"Output From {tool_details['tool_name']} Tool: {"\n".join(tool_output)}"
|
| 352 |
-
|
| 353 |
-
# elements = [cl.Text("Tool Output",content=str(tool_output))]
|
| 354 |
-
|
| 355 |
-
# asyncio.run(cl.Message("",elements=elements).send())
|
| 356 |
-
|
| 357 |
-
except Exception as e:
|
| 358 |
-
|
| 359 |
-
print_colored("Error Tool: "+str(e),"red")
|
| 360 |
-
|
| 361 |
-
tool_content = "Error while executing tool. Please check the tool name or provide a valid arguments to the tool: "+str(e)
|
| 362 |
-
|
| 363 |
-
else:
|
| 364 |
-
|
| 365 |
-
tool_content= "There is no such a tool available. Here are the available tools : "+str(self.tool_names)
|
| 366 |
-
|
| 367 |
-
messages.append({"role":"assistant","content":assistant_content.strip()})
|
| 368 |
-
messages.append({"role":"user","content":tool_content.strip()})
|
| 369 |
-
|
| 370 |
-
return messages
|
| 371 |
-
|
| 372 |
-
async def run(self,user_input=None,messages=[]):
|
| 373 |
-
|
| 374 |
-
if self.attempts_made<=self.max_allowed_attempts:
|
| 375 |
-
|
| 376 |
-
print_colored(f"Attempt Number : {self.attempts_made}/{self.max_allowed_attempts}","pink")
|
| 377 |
-
|
| 378 |
-
self.attempts_made+=1
|
| 379 |
-
|
| 380 |
-
if user_input:
|
| 381 |
-
|
| 382 |
-
messages = self.prepare_messages(user_input,role="user",messages=messages)
|
| 383 |
-
|
| 384 |
-
tool_details,total_tokens = await self.model.aget_output(messages)
|
| 385 |
-
|
| 386 |
-
if not isinstance(tool_details,dict):
|
| 387 |
-
|
| 388 |
-
return "I am not able to process your request"
|
| 389 |
-
|
| 390 |
-
# tool_details = json.loads(tool_details)
|
| 391 |
-
|
| 392 |
-
if tool_details['tool_name']=='FinalAnswer':
|
| 393 |
-
|
| 394 |
-
thoughts = tool_details.get("thoughts","")
|
| 395 |
-
|
| 396 |
-
if isinstance(thoughts,list):
|
| 397 |
-
|
| 398 |
-
thoughts = '\n'.join(thoughts)
|
| 399 |
-
|
| 400 |
-
print_colored(f"Thoughts: {thoughts}","green")
|
| 401 |
-
|
| 402 |
-
print_colored(f"{self.agent_name} : {tool_details['tool_args']['final_answer']}","green")
|
| 403 |
-
|
| 404 |
-
messages.append({"role":"assistant","content":str(tool_details)})
|
| 405 |
-
|
| 406 |
-
self.messages = messages
|
| 407 |
-
|
| 408 |
-
asyncio.run(cl.Message(tool_details['tool_args']['final_answer']).send())
|
| 409 |
-
|
| 410 |
-
# return tool_details['tool_args']['final_answer']
|
| 411 |
-
|
| 412 |
-
else:
|
| 413 |
-
|
| 414 |
-
print()
|
| 415 |
-
|
| 416 |
-
thoughts = tool_details.get("thoughts","")
|
| 417 |
-
|
| 418 |
-
if isinstance(thoughts,list):
|
| 419 |
-
|
| 420 |
-
thoughts = '\n'.join(thoughts)
|
| 421 |
-
|
| 422 |
-
print_colored(f"Thoughts: {thoughts}","green")
|
| 423 |
-
|
| 424 |
-
print_colored(f"Tool Name: {tool_details['tool_name']}","blue")
|
| 425 |
-
# print_colored(f"Tool Name: {tool_details['tool_args']}","blue")
|
| 426 |
-
print()
|
| 427 |
-
|
| 428 |
-
if tool_details['tool_name'] == 'JupyterNotebookTool':
|
| 429 |
-
|
| 430 |
-
asyncio.run(cl.Message(thoughts.strip()).send())
|
| 431 |
-
|
| 432 |
-
# html = f"""<pre><code>{tool_details['tool_args']['python_code']}</code></pre>"""
|
| 433 |
-
html = f"""```python\n\n{tool_details['tool_args']['python_code']}\n\n```"""
|
| 434 |
-
|
| 435 |
-
asyncio.run(cl.Message(html).send())
|
| 436 |
-
|
| 437 |
-
messages = self.execute_tool(messages,tool_details)
|
| 438 |
-
|
| 439 |
-
self.messages = messages
|
| 440 |
-
|
| 441 |
-
return await self.run(messages=messages)
|
| 442 |
-
|
| 443 |
-
else:
|
| 444 |
-
|
| 445 |
-
self.messages = messages
|
| 446 |
-
|
| 447 |
-
print_colored(f"{self.agent_name} : Sorry! Max Attempt Exceeded, I can't take anymore tasks","red")
|
| 448 |
-
|
| 449 |
-
return "Sorry! Max Attempt Exceeded, I can't take anymore tasks"
|
| 450 |
-
|
| 451 |
-
import uuid
|
| 452 |
-
import asyncio
|
| 453 |
-
from pydantic import BaseModel,Field
|
| 454 |
-
from core.helper import print_colored
|
| 455 |
-
from core.models import OpenaiChatModel,OpenAIVissionModel,AnthropicModel
|
| 456 |
-
from core.text2sql.query_generator_2 import Text2SQL
|
| 457 |
-
from core.tools.JupyterTool import NotebookManager
|
| 458 |
-
from pydantic import BaseModel, Field
|
| 459 |
-
|
| 460 |
-
SQl_Engine = Text2SQL("gpt-4o-mini","",db_type='mysql',host='host',port=3306,username='name',password='password',database='db_name',add_additional_context=True,max_attempts=10)
|
| 461 |
-
|
| 462 |
-
class GetRelavantTables(BaseModel):
|
| 463 |
-
"""
|
| 464 |
-
Tool to retrieve relevant tables based on the user's question.
|
| 465 |
-
"""
|
| 466 |
-
|
| 467 |
-
user_question : str = Field("Provide the user question as it is.")
|
| 468 |
-
|
| 469 |
-
sub_questions: list[str] = Field(
|
| 470 |
-
description=(
|
| 471 |
-
"Split the user question into multiple sub-questions if answering it requires data from multiple tables. "
|
| 472 |
-
"Each sub-question should focus on a specific aspect of the user query, referring to relevant columns or tables. "
|
| 473 |
-
"For example: \n"
|
| 474 |
-
"User question: 'What is the total sales of XYZ product last month?'\n"
|
| 475 |
-
"Sub-questions: ['Which column contains product names?', 'Which column contains sales details?',]\n"
|
| 476 |
-
"Ensure sub-questions are precise and map clearly to tables or columns needed to answer the main query."
|
| 477 |
-
"With each table you will get it's relationship to other tables. So dont need seperate question for that."
|
| 478 |
-
"Avoid duplication. Focus and retrieving the accurate columns."
|
| 479 |
-
)
|
| 480 |
-
)
|
| 481 |
-
|
| 482 |
-
def run(self):
|
| 483 |
-
|
| 484 |
-
docs = SQl_Engine.get_relavant_documents(self.sub_questions, top_n_similar_docs=30,filtered_tables=2)
|
| 485 |
-
|
| 486 |
-
filter_result = asyncio.run(SQl_Engine.filter_columns(self.user_question, docs))
|
| 487 |
-
|
| 488 |
-
final_schema="# Here are the relevant table schema.\n\n"
|
| 489 |
-
|
| 490 |
-
final_schema = "\n\n".join([f"""{i['filtered_columns']}\n\n### Here is the details on how this table connected to other tables\n: {i['common_columns']}""".strip() for i in filter_result if i['filtered_columns']])
|
| 491 |
-
|
| 492 |
-
with open("relevant_schemas.txt","a") as f:
|
| 493 |
-
f.write(f"User Question : {self.user_question}\n\n")
|
| 494 |
-
f.write(f"Sub Queries\n: {"\n\t".join(self.sub_questions)}\n\n")
|
| 495 |
-
f.write(final_schema.strip())
|
| 496 |
-
f.write("\n\n***************************************************************\n\n")
|
| 497 |
-
|
| 498 |
-
return final_schema
|
| 499 |
-
|
| 500 |
-
# return "\n\n------------------------------------\n\n".join([i['text_data'] for i in docs])
|
| 501 |
-
|
| 502 |
-
class ExecuteInertmediateQuery(BaseModel):
|
| 503 |
-
"""
|
| 504 |
-
Use this tool to execute an intermediate SQL sub-query to retrieve the unique categories or values available in a specific column.
|
| 505 |
-
|
| 506 |
-
Do not directly use the user-provided values in the WHERE clause. First, execute a query to fetch the unique values from the column.
|
| 507 |
-
|
| 508 |
-
Based on the retrieved results, validate and construct the final query to ensure accuracy and alignment with the data.
|
| 509 |
-
|
| 510 |
-
For example, if the user requests 'drug-x' but the table contains 'Drug-X', this tool ensures the correct value appears at the top.
|
| 511 |
-
"""
|
| 512 |
-
|
| 513 |
-
user_question: str = Field(description="The user question")
|
| 514 |
-
sub_query: str = Field(description="The sub-query to execute")
|
| 515 |
-
|
| 516 |
-
def run(self):
|
| 517 |
-
|
| 518 |
-
print("Sub Quer: ",self.sub_query)
|
| 519 |
-
|
| 520 |
-
asyncio.run(cl.Message(f"```sql\n\n{self.sub_query}\n\n```").send())
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
return f"Observation: {SQl_Engine.execute_inertmediate_query(self.user_question, self.sub_query)}"
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
class ExecuteFinalQuery(BaseModel):
|
| 527 |
-
"""
|
| 528 |
-
Tool to execute the final SQL query and save the output to an Excel file to do further analysis. It will return the exact path to the saved output
|
| 529 |
-
"""
|
| 530 |
-
|
| 531 |
-
final_query: str = Field(description="The final query to execute")
|
| 532 |
-
|
| 533 |
-
def run(self):
|
| 534 |
-
|
| 535 |
-
df = SQl_Engine.run_sql_query(self.final_query)
|
| 536 |
-
|
| 537 |
-
asyncio.run(cl.Message(f"```sql\n\n{self.final_query}\n\n```").send())
|
| 538 |
-
|
| 539 |
-
if not df.empty:
|
| 540 |
-
|
| 541 |
-
if df.shape[0]>=20:
|
| 542 |
-
|
| 543 |
-
df.to_excel("work_dir/query_output.xlsx", index=False)
|
| 544 |
-
|
| 545 |
-
return f"Observation: The data has been stored at `work_dir/query_output.xlsx`. The data contains {len(df)} rows, here is the snapshot(First 5 rows) of the dataframe : \n\n" + df.head(5).to_markdown()
|
| 546 |
-
|
| 547 |
-
else:
|
| 548 |
-
|
| 549 |
-
df.to_excel("work_dir/query_output.xlsx", index=False)
|
| 550 |
-
|
| 551 |
-
return "Observation: The data has been stored at `work_dir/query_output.xlsx`. \n\n Here is the output of the query" + df.to_markdown()
|
| 552 |
-
else:
|
| 553 |
-
return "The Query Returned Empty DatafRame: \n\n" + df.head(5).to_markdown()
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
@cl.on_chat_start
|
| 557 |
-
def start_message():
|
| 558 |
-
|
| 559 |
-
user_session_id = str(uuid.uuid4())
|
| 560 |
-
|
| 561 |
-
notebookmanager = NotebookManager(user_session_id)
|
| 562 |
-
|
| 563 |
-
class JupyterNotebookTool(BaseModel):
|
| 564 |
-
|
| 565 |
-
"""A tool for executing Python code in a stateful Jupyter notebook environment."""
|
| 566 |
-
|
| 567 |
-
python_code: str = Field(description="A valid python code to execute in a new jupyter notebook cell")
|
| 568 |
-
|
| 569 |
-
# expected_output_type : List[str] = Field(description="What output type should the script produce? It might be a single output or a combination of these: [text, dataframe, plotly chart, image, log or nothing]. Please specify the expected output(s) type in the exact order they should appear")
|
| 570 |
-
|
| 571 |
-
def run(self):
|
| 572 |
-
|
| 573 |
-
result = notebookmanager.run_code(self.python_code)
|
| 574 |
-
|
| 575 |
-
return result
|
| 576 |
-
|
| 577 |
-
description = "Responsible for Answering User question."
|
| 578 |
-
|
| 579 |
-
instruction = open(r"prompts/system_prompt.md","r").read()
|
| 580 |
-
|
| 581 |
-
tools = [GetRelavantTables,ExecuteInertmediateQuery,ExecuteFinalQuery,JupyterNotebookTool]
|
| 582 |
-
|
| 583 |
-
model = OpenaiChatModel(model_name="gpt-4o-mini",verbose=True)
|
| 584 |
-
|
| 585 |
-
# model_name= 'claude-3-5-sonnet-20240620'
|
| 586 |
-
|
| 587 |
-
# model = AnthropicModel(api_key=api_key)
|
| 588 |
-
|
| 589 |
-
vissionmodel = OpenAIVissionModel(model="gpt-4o-mini")
|
| 590 |
-
|
| 591 |
-
vission_prompt = "You are provided with a plot from a data analysis, You need to explain all the insights and metrics to the user"
|
| 592 |
-
|
| 593 |
-
agent = ChainlitStructuredAgent(model,"AI Assistant",description,instruction,tools,max_allowed_attempts=30,vission_model=vissionmodel,vission_model_prompt=vission_prompt)
|
| 594 |
-
|
| 595 |
-
cl.user_session.set("messages",[])
|
| 596 |
-
|
| 597 |
-
cl.user_session.set("agent",agent)
|
| 598 |
-
|
| 599 |
-
cl.user_session.set("notebookmanager",notebookmanager)
|
| 600 |
-
|
| 601 |
-
cl.user_session.set("user_session_id",user_session_id)
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
@cl.on_message
|
| 605 |
-
async def on_message(user_input: cl.Message):
|
| 606 |
-
|
| 607 |
-
messages=cl.user_session.get("messages")
|
| 608 |
-
|
| 609 |
-
agent=cl.user_session.get("agent")
|
| 610 |
-
|
| 611 |
-
response = await agent.run(user_input.content,messages)
|
| 612 |
-
|
| 613 |
-
cl.user_session.set("messages",agent.messages)
|
| 614 |
-
|
| 615 |
-
# await cl.Message(response).send()
|
| 616 |
-
|
| 617 |
-
@cl.on_chat_end
|
| 618 |
-
def delete_notebook():
|
| 619 |
-
|
| 620 |
-
notebookmanager=cl.user_session.get("notebookmanager")
|
| 621 |
-
|
| 622 |
-
user_session_id=cl.user_session.get("user_session_id")
|
| 623 |
-
|
| 624 |
-
notebookmanager.delete_notebook(user_session_id)
|
|
|
|
| 1 |
+
# streamlit_app.py
|
| 2 |
+
import streamlit as st
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from sqlalchemy import create_engine, text
|
| 5 |
+
import openai
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
# ---- CONFIG ----
|
| 9 |
+
# Set your API key as an environment variable or in a .env file
|
| 10 |
+
openai.api_key = os.getenv("OPENAI_API_KEY")
|
| 11 |
+
|
| 12 |
+
# Database connection (update these with your credentials)
|
| 13 |
+
DB_TYPE = "mysql+pymysql"
|
| 14 |
+
DB_USER = "username"
|
| 15 |
+
DB_PASS = "password"
|
| 16 |
+
DB_HOST = "host"
|
| 17 |
+
DB_PORT = "3306"
|
| 18 |
+
DB_NAME = "db_name"
|
| 19 |
+
|
| 20 |
+
DATABASE_URL = f"{DB_TYPE}://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
|
| 21 |
+
engine = create_engine(DATABASE_URL)
|
| 22 |
+
|
| 23 |
+
# ---- FUNCTIONS ----
|
| 24 |
+
def generate_sql(user_question, table_names=[]):
|
| 25 |
+
"""
|
| 26 |
+
Generates SQL query from user question using OpenAI GPT
|
| 27 |
+
"""
|
| 28 |
+
table_info = ""
|
| 29 |
+
if table_names:
|
| 30 |
+
table_info = f"These are your tables: {table_names}\n"
|
| 31 |
+
|
| 32 |
+
prompt = f"""
|
| 33 |
+
You are an expert SQL generator.
|
| 34 |
+
{table_info}
|
| 35 |
+
Write a SQL query that answers the following question:
|
| 36 |
+
\"\"\"{user_question}\"\"\"
|
| 37 |
+
Only return SQL, do not explain.
|
| 38 |
+
"""
|
| 39 |
+
response = openai.Completion.create(
|
| 40 |
+
engine="text-davinci-003",
|
| 41 |
+
prompt=prompt,
|
| 42 |
+
temperature=0,
|
| 43 |
+
max_tokens=300
|
| 44 |
+
)
|
| 45 |
+
sql_query = response.choices[0].text.strip()
|
| 46 |
+
return sql_query
|
| 47 |
+
|
| 48 |
+
def run_query(sql_query):
|
| 49 |
+
"""
|
| 50 |
+
Runs SQL query using SQLAlchemy
|
| 51 |
+
"""
|
| 52 |
+
try:
|
| 53 |
+
with engine.connect() as conn:
|
| 54 |
+
result = pd.read_sql(text(sql_query), conn)
|
| 55 |
+
return result
|
| 56 |
+
except Exception as e:
|
| 57 |
+
return f"Error executing query: {e}"
|
| 58 |
+
|
| 59 |
+
# ---- STREAMLIT UI ----
|
| 60 |
+
st.title("🧠 AI SQL Assistant")
|
| 61 |
+
st.markdown("Ask a question about your database, and it will generate SQL and show results.")
|
| 62 |
+
|
| 63 |
+
user_question = st.text_input("Enter your question:")
|
| 64 |
+
|
| 65 |
+
if st.button("Run Query") and user_question:
|
| 66 |
+
with st.spinner("Generating SQL..."):
|
| 67 |
+
sql_query = generate_sql(user_question)
|
| 68 |
+
st.code(sql_query, language="sql")
|
| 69 |
+
|
| 70 |
+
with st.spinner("Executing SQL..."):
|
| 71 |
+
result = run_query(sql_query)
|
| 72 |
+
if isinstance(result, pd.DataFrame):
|
| 73 |
+
st.success("Query executed successfully!")
|
| 74 |
+
st.dataframe(result)
|
| 75 |
+
else:
|
| 76 |
+
st.error(result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|