# File size: 4,441 Bytes
# df96495
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import html
import os
import re
import threading
import traceback
from contextlib import redirect_stdout
from io import StringIO

import gradio as gr
from dotenv import load_dotenv
from langchain import SQLDatabase, SQLDatabaseChain
from langchain.llms import AzureOpenAI
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType

#https://www.youtube.com/watch?v=IN6Q5AwHyLc


# Load environment variables from the .env file in the current working
# directory. OPENAI_MODEL_NAME and OPENAI_DEPLOYMENT_NAME are read below; a
# missing one raises KeyError at import time.
load_dotenv(os.getcwd() + "/.env")

# Completion model shared by both the SQL chain and the SQL agent.
# temperature=0 keeps the generated SQL as repeatable as the model allows.
llm = AzureOpenAI(
    model_name=os.environ["OPENAI_MODEL_NAME"],
    deployment_name=os.environ["OPENAI_DEPLOYMENT_NAME"],
    temperature=0,
)

# SQLite sample database; the path is resolved relative to the current
# working directory, not to this file.
sqlite_db_path = "data/Chinook.db"
db = SQLDatabase.from_uri(f"sqlite:///{sqlite_db_path}")

# Build the chain via from_llm: passing llm= directly to the SQLDatabaseChain
# constructor is deprecated in LangChain and emits a warning, while from_llm
# wraps the model in the LLMChain the chain actually runs.
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)

# Agent (AgentType.ZERO_SHOT_REACT_DESCRIPTION) that answers questions about
# the same database using the SQL toolkit's tools.
agent_executor = create_sql_agent(
    llm=llm,
    toolkit=SQLDatabaseToolkit(db=db, llm=llm),
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)


def clear_input():
    """Reset one question/answer pair of the UI to its initial state.

    Returns:
        A 2-tuple ``(textbox_value, output_html)``: an empty string for the
        input textbox and the placeholder text for the output panel.
    """
    empty_textbox = ""
    placeholder_html = "Hit 'Submit' to see output here"
    return empty_textbox, placeholder_html


# ANSI escape sequences (e.g. "\x1b[1m", "\x1b[32;1m") printed by LangChain's
# verbose mode. The leading ESC is optional, so bare "[1m"-style fragments
# are removed as well.
_ANSI_ESCAPE_RE = re.compile(r"(\x1b)?\[(\d+[m;])+")

# Shown instead of running a model when the textbox is empty.
_EMPTY_INPUT_MESSAGE = "Please enter a message before hitting Submit!"

# contextlib.redirect_stdout swaps the process-wide sys.stdout. With several
# queued requests running at once, simultaneous captures would steal each
# other's output or restore the wrong stream, so captures are serialised.
_STDOUT_CAPTURE_LOCK = threading.Lock()


def _is_blank(user_message):
    """Return True when the textbox value is None, empty or whitespace-only."""
    return not user_message or not user_message.strip()


def _format_verbose_output(raw_output):
    """Convert captured console text into an HTML fragment for gr.HTML.

    ANSI escapes are stripped and leading newlines are dropped. The rest is
    HTML-escaped, because it is not trusted markup, and each newline becomes
    a ``<br/>`` tag.
    """
    text = _ANSI_ESCAPE_RE.sub("", raw_output).lstrip("\n")
    return html.escape(text).replace("\n", "<br/>")


def _run_and_capture(runnable, user_message):
    """Call ``runnable.run(user_message)`` and return what it printed, as HTML.

    An exception raised by the run does not propagate. Its traceback goes to
    stderr and its message is appended to the captured output, so the user
    still sees the partial trace that led to the failure.
    """
    buffer = StringIO()
    with _STDOUT_CAPTURE_LOCK, redirect_stdout(buffer):
        try:
            runnable.run(user_message)
        except Exception as err:  # UI boundary: report instead of crashing
            traceback.print_exc()  # stderr is not redirected
            print(f"\nError: {err}")
    return _format_verbose_output(buffer.getvalue())


def generate_output_of_db_chain(user_message):
    """Answer *user_message* with the SQLDatabaseChain.

    Yields exactly one string. For blank input that string is a prompt to
    enter a message; otherwise it is the chain's verbose trace rendered as
    HTML.
    """
    print(user_message)
    if _is_blank(user_message):
        print("Empty input")
        yield _EMPTY_INPUT_MESSAGE
        return  # never run the chain on an empty question
    yield _run_and_capture(db_chain, user_message)


def generate_output_of_db_agent(user_message):
    """Answer *user_message* with the SQL agent.

    Yields exactly one string. For blank input that string is a prompt to
    enter a message; otherwise it is the agent's verbose trace rendered as
    HTML.
    """
    if _is_blank(user_message):
        print("Empty input")
        yield _EMPTY_INPUT_MESSAGE
        return  # never run the agent on an empty question
    yield _run_and_capture(agent_executor, user_message)


# Extra CSS injected into the page.
# NOTE(review): no component below sets elem_id "banner-image" or
# "chat-message", so these rules currently match nothing — confirm intent.
custom_css = """
#banner-image {
    display: block;
    margin-left: auto;
    margin-right: auto;
}
#chat-message {
    font-size: 14px;
    min-height: 300px;
}
"""


# Two side-by-side demos: one driven by SQLDatabaseChain, one by the SQL agent.
# Each has an input textbox, Clear/Submit buttons and an HTML output panel.
with gr.Blocks(analytics_enabled=False, css=custom_css) as demo:
    gr.HTML("""<h1 align="center">LLM Mini-Series #4 💬</h1>""")

    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                💻 TODO Add some nice description text
                """
            )

    # normal SQL Chain
    gr.HTML("""<h2 align="left">Using LangChain's SQLDatabaseChain</h2>""")
    with gr.Row():
        with gr.Column():
            user_message = gr.Textbox(
                placeholder="Enter your message here",
                show_label=False,
                elem_id="q-input",
            )
            with gr.Row():
                clear_btn = gr.Button("Clear", elem_id="clear-btn", visible=True)
                submit_btn = gr.Button("Submit", elem_id="submit-btn", visible=True)

        with gr.Box():
            output_field = gr.HTML(
                value="Hit 'Submit' to see output here",
                label="Output of model",
                interactive=False,
            )

    # Agent-based approach (heading was previously missing its closing tag)
    gr.HTML(
        """<h2 align="left">Using an agent-based approach with LangChain</h2>"""
    )
    with gr.Row():
        with gr.Column():
            user_message_agent = gr.Textbox(
                placeholder="Enter your message here",
                show_label=False,
                elem_id="q-agent-input",
            )
            with gr.Row():
                clear_agent_btn = gr.Button(
                    "Clear", elem_id="clear-agent-btn", visible=True
                )
                submit_agent_btn = gr.Button(
                    "Submit", elem_id="submit-agent-btn", visible=True
                )

        with gr.Box():
            output_agent_field = gr.HTML(
                value="Hit 'Submit' to see output here",
                label="Output of model",
                interactive=False,
            )

    # Wire each Clear button to reset its textbox and output panel, and each
    # Submit button to the matching generator handler.
    clear_btn.click(clear_input, outputs=[user_message, output_field])
    submit_btn.click(
        generate_output_of_db_chain, inputs=[user_message], outputs=[output_field]
    )

    submit_agent_btn.click(
        generate_output_of_db_agent,
        inputs=[user_message_agent],
        outputs=[output_agent_field],
    )
    clear_agent_btn.click(clear_input, outputs=[user_message_agent, output_agent_field])


if __name__ == "__main__":
    # Pass server_port=8080 to launch() to pin the port.
    demo.queue(concurrency_count=16).launch(debug=True)