Spaces:
Runtime error
Runtime error
File size: 4,441 Bytes
df96495 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import html
import os
import re
from contextlib import redirect_stdout
from io import StringIO

import gradio as gr
from dotenv import load_dotenv
from langchain import SQLDatabase, SQLDatabaseChain
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.llms import AzureOpenAI
# Reference: https://www.youtube.com/watch?v=IN6Q5AwHyLc

# Load variables from the .env file located in the current working directory.
load_dotenv(f"{os.getcwd()}/.env")

# Azure-hosted completion model; model and deployment names come from the
# environment (a missing variable raises KeyError at import time).
# temperature=0 minimises sampling randomness in the generated SQL.
llm = AzureOpenAI(
    temperature=0,
    model_name=os.environ["OPENAI_MODEL_NAME"],
    deployment_name=os.environ["OPENAI_DEPLOYMENT_NAME"],
)

# SQLite sample database, resolved relative to the working directory.
sqlite_db_path = "data/Chinook.db"
db = SQLDatabase.from_uri("sqlite:///" + sqlite_db_path)

# Approach 1: LangChain's SQLDatabaseChain. verbose=True makes it print its
# progress to stdout, which the UI handlers capture and display.
db_chain = SQLDatabaseChain(database=db, llm=llm, verbose=True)

# Approach 2: a zero-shot ReAct agent equipped with the SQL database toolkit.
agent_executor = create_sql_agent(
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    toolkit=SQLDatabaseToolkit(llm=llm, db=db),
    llm=llm,
    verbose=True,
)
def clear_input():
    """Reset one panel: empty its question box and restore the output placeholder.

    Returns:
        A ``(question_text, output_html)`` tuple, matching the
        ``outputs=[textbox, html_field]`` order used by the Clear buttons.
    """
    placeholder_html = "Hit 'Submit' to see output here"
    empty_question = ""
    return (empty_question, placeholder_html)
# Matches the ANSI colour/style codes (e.g. "\x1b[1m", "\x1b[0m") that LangChain
# prints in verbose mode; the ESC byte is optional, so bare fragments such as
# "[32;1m" are removed as well.
_ANSI_CODE_RE = re.compile(r"(\x1b)?\[(\d+[m;])+")

# Shown instead of calling the model when the question box is empty.
_EMPTY_INPUT_MESSAGE = "Please enter a message before hitting Submit!"


def _run_verbose_as_html(run, user_message):
    """Call ``run(user_message)`` and return everything it printed, as HTML.

    The chain/agent is expected to be created with ``verbose=True`` so that its
    progress is written to stdout; that text is captured, cleaned up and
    formatted for display in a ``gr.HTML`` component.

    Args:
        run: Callable taking the question string (e.g. ``db_chain.run``).
        user_message: The user's question, passed through unchanged.

    Returns:
        The captured stdout with the first 6 characters dropped, ANSI codes
        removed, HTML special characters escaped and newlines turned into
        ``<br/>``.
    """
    with redirect_stdout(StringIO()) as buffer:
        run(user_message)
    # [6:]: skip the first two "\n" and the special (ANSI bold) tag LangChain
    # prints before its first line.
    text = _ANSI_CODE_RE.sub("", buffer.getvalue()[6:])
    # Escape before inserting <br/> so the echoed question and SQL operators
    # such as "<" are displayed literally instead of being parsed as HTML.
    return html.escape(text).replace("\n", "<br/>")


def generate_output_of_db_chain(user_message):
    """Answer *user_message* with ``db_chain`` and yield its verbose trace as HTML.

    Yields exactly one HTML string. For an empty or whitespace-only question a
    prompt message is yielded instead and the chain is not called.
    """
    print(user_message)
    if not (user_message or "").strip():
        print("Empty input")
        yield _EMPTY_INPUT_MESSAGE
        # Stop here: without this return the chain would still run on "".
        return
    yield _run_verbose_as_html(db_chain.run, user_message)


def generate_output_of_db_agent(user_message):
    """Answer *user_message* with ``agent_executor`` and yield its verbose trace as HTML.

    Yields exactly one HTML string. For an empty or whitespace-only question a
    prompt message is yielded instead and the agent is not called.
    """
    if not (user_message or "").strip():
        print("Empty input")
        yield _EMPTY_INPUT_MESSAGE
        return
    yield _run_verbose_as_html(agent_executor.run, user_message)
# Extra stylesheet passed to gr.Blocks(css=...); the "#..." selectors target
# components by their elem_id.
# NOTE(review): no component in the visible code sets elem_id="banner-image"
# or elem_id="chat-message", so neither rule appears to take effect — confirm
# whether the output fields were meant to carry one of these ids.
custom_css = """
#banner-image {
display: block;
margin-left: auto;
margin-right: auto;
}
#chat-message {
font-size: 14px;
min-height: 300px;
}
"""
# Two-panel Gradio UI: the top panel sends the question to the
# SQLDatabaseChain, the bottom panel to the SQL agent. Each panel has a
# question box, Clear/Submit buttons and an HTML field for the output.
with gr.Blocks(analytics_enabled=False, css=custom_css) as demo:
    gr.HTML("""<h1 align="center">LLM Mini-Series #4 💬</h1>""")
    with gr.Row():
        with gr.Column():
            # Plain string: nothing is interpolated, so no f-string is needed.
            gr.Markdown("\n💻 TODO Add some nice description text\n")

    # Panel 1: LangChain's SQLDatabaseChain
    gr.HTML("""<h2 align="left">Using LangChain's SQLDatabaseChain</h2>""")
    with gr.Row():
        with gr.Column():
            user_message = gr.Textbox(
                placeholder="Enter your message here",
                show_label=False,
                elem_id="q-input",
            )
            with gr.Row():
                clear_btn = gr.Button("Clear", elem_id="clear-btn", visible=True)
                submit_btn = gr.Button("Submit", elem_id="submit-btn", visible=True)
            with gr.Box():
                output_field = gr.HTML(
                    value="Hit 'Submit' to see output here",
                    label="Output of model",
                    interactive=False,
                )

    # Panel 2: agent-based approach
    gr.HTML("""<h2 align="left">Using an agent-based approach with LangChain</h2>""")
    with gr.Row():
        with gr.Column():
            user_message_agent = gr.Textbox(
                placeholder="Enter your message here",
                show_label=False,
                elem_id="q-agent-input",
            )
            with gr.Row():
                clear_agent_btn = gr.Button(
                    "Clear", elem_id="clear-agent-btn", visible=True
                )
                submit_agent_btn = gr.Button(
                    "Submit", elem_id="submit-agent-btn", visible=True
                )
            with gr.Box():
                output_agent_field = gr.HTML(
                    value="Hit 'Submit' to see output here",
                    label="Output of model",
                    interactive=False,
                )

    # Event wiring: Clear resets the question box and the output field;
    # Submit runs the matching handler and shows its HTML in the output field.
    clear_btn.click(clear_input, outputs=[user_message, output_field])
    submit_btn.click(
        generate_output_of_db_chain, inputs=[user_message], outputs=[output_field]
    )
    submit_agent_btn.click(
        generate_output_of_db_agent,
        inputs=[user_message_agent],
        outputs=[output_agent_field],
    )
    clear_agent_btn.click(clear_input, outputs=[user_message_agent, output_agent_field])

# Queueing is required for the generator (yield-based) handlers in Gradio 3.x;
# concurrency_count sets how many queued events are processed at once.
# To serve on a fixed port, also pass server_port=8080 to launch().
demo.queue(concurrency_count=16).launch(debug=True)
|