Spaces:
Runtime error
Runtime error
import html
import os
import re
from contextlib import redirect_stdout
from io import StringIO

import gradio as gr
from dotenv import load_dotenv
from langchain import SQLDatabase, SQLDatabaseChain
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.llms import AzureOpenAI
# Reference: https://www.youtube.com/watch?v=IN6Q5AwHyLc

# Load settings (OPENAI_MODEL_NAME, OPENAI_DEPLOYMENT_NAME, and whatever the
# Azure client needs) from a .env file in the current working directory.
load_dotenv(os.path.join(os.getcwd(), ".env"))

# Completion model shared by the SQL chain and the SQL agent; temperature=0
# keeps the generated SQL as deterministic as the model allows.
llm = AzureOpenAI(
    model_name=os.environ["OPENAI_MODEL_NAME"],
    deployment_name=os.environ["OPENAI_DEPLOYMENT_NAME"],
    temperature=0,
)

# Path is relative to the current working directory.
sqlite_db_path = "data/Chinook.db"
# SQLite creates an empty database file when the target does not exist, which
# would leave the chain and agent with no tables to query. Fail fast instead.
if not os.path.isfile(sqlite_db_path):
    raise FileNotFoundError(
        f"SQLite database not found at {os.path.abspath(sqlite_db_path)!r}"
    )
db = SQLDatabase.from_uri(f"sqlite:///{sqlite_db_path}")

# Approach 1: a single chain that turns the question into SQL and answers it.
# verbose=True makes LangChain print its intermediate steps to stdout; the
# generator handlers capture that output to show in the UI.
db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)

# Approach 2: a ReAct-style agent that uses the SQL toolkit's tools.
agent_executor = create_sql_agent(
    llm=llm,
    toolkit=SQLDatabaseToolkit(db=db, llm=llm),
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)
def clear_input():
    """Reset one question/answer pair of the UI.

    Returns:
        A 2-tuple ``(textbox_value, output_html)``: an empty string for the
        question textbox and the initial placeholder text for the output pane.
    """
    empty_question = ""
    placeholder = "Hit 'Submit' to see output here"
    return (empty_question, placeholder)
def generate_output_of_db_chain(user_message):
    """Run the SQLDatabaseChain on *user_message* and yield its trace as HTML.

    The chain is built with ``verbose=True``, so it prints its intermediate
    steps to stdout. That output is captured, stripped of terminal color
    codes, HTML-escaped, and has its newlines turned into ``<br/>`` so it can
    be displayed in a ``gr.HTML`` component.

    Args:
        user_message: The natural-language question typed by the user.

    Yields:
        Exactly one HTML string: a prompt to enter a message when the input is
        empty (or ``None``), otherwise the formatted chain trace.
    """
    if not user_message:
        print("Empty input")
        yield "Please enter a message before hitting Submit!"
        # Stop here; otherwise the chain would run on the empty input and
        # overwrite the prompt above.
        return

    # NOTE(review): redirect_stdout swaps the process-wide sys.stdout. With the
    # Gradio queue running several jobs concurrently, simultaneous requests can
    # mix their captured output. A LangChain callback handler would avoid this.
    with redirect_stdout(StringIO()) as f:
        db_chain.run(user_message)

    # [6:]: skip the first two "\n" and a 4-character special tag that
    # LangChain prints first (presumably an ANSI style escape).
    text = f.getvalue()[6:]
    # Remove terminal color/style sequences such as "\x1b[1m" or "\x1b[32;1m".
    text = re.sub(r"(\x1b)?\[(\d+[m;])+", "", text)
    # Escape generated SQL/results so "<" or "&" is shown, not parsed as markup.
    text = html.escape(text)
    yield text.replace("\n", "<br/>")
def generate_output_of_db_agent(user_message):
    """Run the SQL agent on *user_message* and yield its trace as HTML.

    The agent executor is built with ``verbose=True``, so it prints its
    reasoning steps to stdout. That output is captured, stripped of terminal
    color codes, HTML-escaped, and has its newlines turned into ``<br/>`` so it
    can be displayed in a ``gr.HTML`` component.

    Args:
        user_message: The natural-language question typed by the user.

    Yields:
        Exactly one HTML string: a prompt to enter a message when the input is
        empty (or ``None``), otherwise the formatted agent trace.
    """
    if not user_message:
        print("Empty input")
        yield "Please enter a message before hitting Submit!"
        return

    # NOTE(review): redirect_stdout swaps the process-wide sys.stdout. With the
    # Gradio queue running several jobs concurrently, simultaneous requests can
    # mix their captured output. A LangChain callback handler would avoid this.
    with redirect_stdout(StringIO()) as f:
        agent_executor.run(user_message)

    # [6:]: skip the first two "\n" and a 4-character special tag that
    # LangChain prints first (presumably an ANSI style escape).
    text = f.getvalue()[6:]
    # Remove terminal color/style sequences such as "\x1b[1m" or "\x1b[32;1m".
    text = re.sub(r"(\x1b)?\[(\d+[m;])+", "", text)
    # Escape SQL/observations so "<" or "&" is shown, not parsed as markup.
    text = html.escape(text)
    yield text.replace("\n", "<br/>")
# Extra CSS passed to gr.Blocks(css=...): centers the element with id
# "banner-image" and gives the element with id "chat-message" a 14px font and
# a 300px minimum height.
# NOTE(review): no component in this file's layout uses either id (the
# elem_ids in use are q-input, clear-btn, submit-btn, q-agent-input,
# clear-agent-btn and submit-agent-btn), so these rules may be leftovers.
custom_css = """
#banner-image {
display: block;
margin-left: auto;
margin-right: auto;
}
#chat-message {
font-size: 14px;
min-height: 300px;
}
"""
# The page has two stacked sections with identical controls: one wired to the
# plain SQLDatabaseChain, the other to the SQL agent. Each section has a
# question textbox, Clear/Submit buttons and an HTML output pane.
# NOTE(review): gr.Box and queue(concurrency_count=...) belong to the Gradio 3.x
# API and were removed in Gradio 4. Pin gradio<4 or migrate (gr.Group,
# default_concurrency_limit); confirm the installed version.
with gr.Blocks(analytics_enabled=False, css=custom_css) as demo:
    gr.HTML("""<h1 align="center">LLM Mini-Series #4 💬</h1>""")
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                "💻 Ask a question in plain English about the Chinook SQLite "
                "database and compare how LangChain's SQLDatabaseChain and an "
                "agent-based approach answer it."
            )

    # Section 1: normal SQL chain.
    gr.HTML("""<h2 align="left">Using LangChain's SQLDatabaseChain</h2>""")
    with gr.Row():
        with gr.Column():
            user_message = gr.Textbox(
                placeholder="Enter your message here",
                show_label=False,
                elem_id="q-input",
            )
            with gr.Row():
                clear_btn = gr.Button("Clear", elem_id="clear-btn", visible=True)
                submit_btn = gr.Button("Submit", elem_id="submit-btn", visible=True)
            with gr.Box():
                output_field = gr.HTML(
                    value="Hit 'Submit' to see output here",
                    label="Output of model",
                    interactive=False,
                )

    # Section 2: agent-based approach.
    gr.HTML("""<h2 align="left">Using an agent-based approach with LangChain</h2>""")
    with gr.Row():
        with gr.Column():
            user_message_agent = gr.Textbox(
                placeholder="Enter your message here",
                show_label=False,
                elem_id="q-agent-input",
            )
            with gr.Row():
                clear_agent_btn = gr.Button(
                    "Clear", elem_id="clear-agent-btn", visible=True
                )
                submit_agent_btn = gr.Button(
                    "Submit", elem_id="submit-agent-btn", visible=True
                )
            with gr.Box():
                output_agent_field = gr.HTML(
                    value="Hit 'Submit' to see output here",
                    label="Output of model",
                    interactive=False,
                )

    # Event wiring: Clear resets the textbox and the output pane; Submit streams
    # the generator handler's output into the section's HTML pane.
    clear_btn.click(clear_input, outputs=[user_message, output_field])
    submit_btn.click(
        generate_output_of_db_chain, inputs=[user_message], outputs=[output_field]
    )
    submit_agent_btn.click(
        generate_output_of_db_agent,
        inputs=[user_message_agent],
        outputs=[output_agent_field],
    )
    clear_agent_btn.click(clear_input, outputs=[user_message_agent, output_agent_field])

# Enable the request queue, which generator handlers need, with up to 16 jobs
# processed concurrently, then start the server (debug=True blocks and prints
# errors).
demo.queue(concurrency_count=16).launch(debug=True)