| | import ast
|
| | import gradio as gr
|
| | from functions import sql_example_question_generator, sql_chatbot_with_fc
|
| | from data_sources import connect_sql_db
|
| | from utils import message_dict
|
| |
|
def hide_info():
    """Build a Gradio update that makes the target component invisible.

    Wired to the Submit button so the introductory description panel is
    collapsed once the user connects to a database.
    """
    hidden_update = gr.update(visible=False)
    return hidden_update
|
| |
|
# Host of the bundled demo database. Used both as the default URL field value
# and to decide when to show the curated example questions, so the two can't drift.
_DEMO_DB_HOST = "virtual-data-analyst-pg.cyetm2yjzppu.us-west-1.rds.amazonaws.com"

# Curated example prompts shown when connected to the demo database.
_DEMO_EXAMPLE_QUESTIONS = (
    "Describe the dataset",
    "What is the total revenue generated by each store?",
    "Can you generate and display a bar chart of film category to number of films in that category?",
    "Can you generate a pie chart showing the top 10 most rented films by revenue vs all other films?",
    "Can you generate a line chart of rental revenue over time?",
    "What is the relationship between film length and rental frequency?",
)

# Generic prompts used when question generation for a non-demo database fails.
_FALLBACK_EXAMPLE_QUESTIONS = (
    "Describe the dataset",
    "List the columns in the dataset",
    "What could this data be used for?",
)


def _build_example_questions(url, session_hash, db_tables, db_name):
    """Return example prompts for the ChatInterface as a fresh ``[[question], ...]`` list.

    The demo database gets the curated list. Any other database gets
    "Describe the dataset" followed by the questions produced by
    ``sql_example_question_generator``. If generation or parsing fails, the
    generic fallback prompts are used instead.
    """
    if _DEMO_DB_HOST in url:
        return [[question] for question in _DEMO_EXAMPLE_QUESTIONS]
    try:
        # literal_eval only accepts Python literals, so a malformed or hostile
        # generator reply cannot execute code here.
        generated = ast.literal_eval(
            sql_example_question_generator(session_hash, db_tables, db_name)
        )
        # A bare string or a dict would otherwise be iterated character-by-
        # character / key-by-key and turned into nonsense examples.
        if not isinstance(generated, (list, tuple)):
            raise ValueError(
                f"expected a list of questions, got {type(generated).__name__}"
            )
    except Exception as e:
        print("SQL QUESTION GENERATION ERROR")
        print(e)
        return [[question] for question in _FALLBACK_EXAMPLE_QUESTIONS]
    return [["Describe the dataset"]] + [[str(question)] for question in generated]


with gr.Blocks() as demo:
    description = gr.HTML("""
    <!-- Header -->
    <div class="max-w-4xl mx-auto mb-12 text-center">
        <div class="bg-blue-50 border border-blue-200 rounded-lg max-w-2xl mx-auto">
            <p>This tool allows users to communicate with and query real time data from a SQL DB (postgres for now, others can be added if requested) using natural
            language and the above features.</p>
            <p style="font-weight:bold;">Notice: the way this system is designed, no login information is retained and credentials are passed as session variables until the user leaves or
            refreshes the page in which they disappear. They are never saved to any files. I also make use of the Pandas read_sql_query function to apply SQL
            queries, which can't delete, drop, or add database lines to avoid unhappy accidents or glitches.
            That being said, it's probably not a good idea to connect a production database to a strange AI tool with an unfamiliar author.
            This should be for demonstration purposes.</p>
            <p>Contact me if this is something you would like built in your organization, on your infrastructure, and with the requisite privacy and control a production
            database analytics tool requires.</p>
        </div>
    </div>
    """, elem_classes="description_component")
    sql_url = gr.Textbox(label="URL", value=_DEMO_DB_HOST)
    with gr.Row():
        sql_port = gr.Textbox(label="Port", value="5432")
        sql_user = gr.Textbox(label="Username", value="postgres")
        # NOTE(review): a plaintext default password for the demo database is
        # committed to source here; consider loading it from deployment
        # configuration / a secrets store instead.
        sql_pass = gr.Textbox(label="Password", value="Vda-1988", type="password")
        sql_db_name = gr.Textbox(label="Database Name", value="dvdrental")

    submit = gr.Button(value="Submit")
    submit.click(fn=hide_info, outputs=description)

    @gr.render(inputs=[sql_url, sql_port, sql_user, sql_pass, sql_db_name], triggers=[submit.click])
    def sql_chat(request: gr.Request, url=sql_url.value, sql_port=sql_port.value, sql_user=sql_user.value, sql_pass=sql_pass.value, sql_db_name=sql_db_name.value):
        """Render the database chat UI when Submit is clicked.

        The session's SQL chat history in ``message_dict`` is reset first. The
        function then connects through ``process_sql_db`` and shows the
        connection message (element 1 of the result). If the status (element 0)
        is ``"success"``, it renders a ChatInterface backed by
        ``sql_chatbot_with_fc``, with the connection details and table listing
        (element 2) passed as additional inputs. Nothing is rendered for an
        empty URL.
        """
        # Start every (re)connect with no prior SQL conversation for this session.
        message_dict.setdefault(request.session_hash, {})['sql'] = None
        if not url:
            return
        print("SQL APP")
        process_message = process_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, request.session_hash)
        gr.HTML(value=process_message[1], padding=False)
        if process_message[0] != "success":
            return

        example_questions = _build_example_questions(url, request.session_hash, process_message[2], sql_db_name)

        # Hidden components forward the connection details to every
        # sql_chatbot_with_fc call via additional_inputs.
        # NOTE(review): values of visible=False components may still be sent to
        # the browser; gr.State may be preferable for the password — verify
        # against the Gradio version in use.
        session_hash = gr.Textbox(visible=False, value=request.session_hash)
        db_url = gr.Textbox(visible=False, value=url)
        db_port = gr.Textbox(visible=False, value=sql_port)
        db_user = gr.Textbox(visible=False, value=sql_user)
        db_pass = gr.Textbox(visible=False, value=sql_pass)
        db_name = gr.Textbox(visible=False, value=sql_db_name)
        db_tables = gr.Textbox(value=process_message[2], interactive=False, label="SQL Tables")
        # sanitize_html=False lets responses embed raw HTML (e.g. charts).
        # NOTE(review): that also means HTML coming from query results is not sanitized.
        bot = gr.Chatbot(type='messages', label="SQL Chat Window", render_markdown=True, sanitize_html=False, show_label=True, render=False, visible=True, elem_classes="chatbot")
        chat = gr.ChatInterface(
            fn=sql_chatbot_with_fc,
            type='messages',
            chatbot=bot,
            title="Chat with your Database",
            examples=example_questions,
            concurrency_limit=None,
            additional_inputs=[session_hash, db_url, db_port, db_user, db_pass, db_name, db_tables],
        )
|
| |
|
def process_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, session_hash):
    """Open a connection to the user's SQL database via ``connect_sql_db``.

    Returns ``connect_sql_db``'s result unchanged when ``url`` is non-empty.
    The caller ``sql_chat`` reads its elements 0 (status), 1 (message HTML)
    and 2 (table listing). Returns ``None`` when ``url`` is empty/falsy.
    """
    if not url:
        return None
    return connect_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, session_hash)
|
| |
|
if __name__ == "__main__":
    # Serve the Blocks app defined above only when run as a script, not on import.
    demo.launch()