File size: 7,023 Bytes
a3621be
 
 
 
 
 
 
fcf7df3
f67c385
 
76511dc
a3621be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f67c385
 
 
 
 
 
 
 
a3621be
 
 
 
f67c385
fcf7df3
5fd5cb7
9a23f24
5fd5cb7
9a23f24
5fd5cb7
9a23f24
5fd5cb7
fcf7df3
a3621be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fcf7df3
a3621be
 
 
 
 
 
5fd5cb7
 
 
9a23f24
5fd5cb7
9a23f24
5fd5cb7
9a23f24
5fd5cb7
 
 
f67c385
a3621be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#@title Gradio UI - 1
import gradio as gr
import os
from langchain.chat_models import init_chat_model
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain.agents import create_agent
import markdown as md
import base64

# from google.colab import userdata

# --- GLOBALS ---
# Shared mutable state populated by the UI callbacks in create_chatbot_interface():
# setup_llm() assigns `model`; connect_to_db() assigns `db`, `toolkit`, `agent`.
model = None  # βœ… prevent "NameError" before setup
db = None  # SQLDatabase handle; None until a connection string is submitted
toolkit = None  # SQLDatabaseToolkit built from `db` + `model`
agent = None  # LangChain agent; None until both LLM and DB are configured

# Define available models for each provider.
# Keys are the `model_provider` identifiers accepted by init_chat_model();
# values are the model-name choices shown in the model dropdown.
PROVIDER_MODELS = {
    "google_genai": ["gemini-2.0-pro", "gemini-2.5-flash-lite", "gemini-2.5-flash", "gemini-2.0-flash-thinking"],
    "openai": ["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"],
    "anthropic": ["claude-3-opus", "claude-3-sonnet", "claude-3-haiku"],
    "azure_openai": ["gpt-4-turbo", "gpt-4o-mini", "gpt-35-turbo"],
    "bedrock": ["anthropic.claude-3-sonnet-v1", "mistral.mixtral-8x7b"],
    "xai": ["grok-2", "grok-2-mini"],
    "deepseek": ["deepseek-chat", "deepseek-coder"],
    "perplexity": ["sonar-small-chat", "sonar-medium-chat", "sonar-large-chat"]
}

def encode_image(image_path):
    """Read the file at *image_path* and return its contents base64-encoded as a UTF-8 string."""
    with open(image_path, "rb") as fh:
        raw = fh.read()
    return base64.b64encode(raw).decode("utf-8")

# Encode the social/footer logos once at import time so they can be inlined
# into the footer HTML as base64 data.
# NOTE(review): these are relative paths, so the script must be launched from
# the directory containing `Images/` or these reads raise FileNotFoundError
# at import time — confirm the expected working directory.
github_logo_encoded = encode_image("Images/github-logo.png")
linkedin_logo_encoded = encode_image("Images/linkedin-logo.png")
website_logo_encoded = encode_image("Images/ai-logo.png")

def create_chatbot_interface():
    """Build and return the Gradio Blocks UI for the DB Assistant.

    The UI lets the user pick an LLM provider and model, supply an API key,
    connect to a SQL database by URI, and then chat with a LangChain agent
    that answers questions by generating and executing read-only SQL.

    Returns:
        gr.Blocks: the assembled (not yet launched) Gradio app.
    """
    # Environment variable each provider's client library reads its API key
    # from. Previously the key was always written to GOOGLE_API_KEY, so every
    # non-Google provider silently failed to authenticate.
    # NOTE(review): bedrock normally authenticates via AWS credentials rather
    # than a single key; the single-textbox flow is best-effort there — confirm.
    provider_env_vars = {
        "google_genai": "GOOGLE_API_KEY",
        "openai": "OPENAI_API_KEY",
        "anthropic": "ANTHROPIC_API_KEY",
        "azure_openai": "AZURE_OPENAI_API_KEY",
        "bedrock": "AWS_BEARER_TOKEN_BEDROCK",
        "xai": "XAI_API_KEY",
        "deepseek": "DEEPSEEK_API_KEY",
        "perplexity": "PPLX_API_KEY",
    }

    with gr.Blocks(theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Noto Sans")]),
                   css='footer {visibility: hidden}') as demo:
        gr.Markdown("# DB Assistant πŸ“ŠπŸ‘¨πŸ»β€πŸ­")
        with gr.Accordion("πŸ“” Description:", open=False):
            gr.Markdown(md.description)
        with gr.Accordion('πŸš€ Key Features', open=False):
            gr.Markdown(md.key_features)
        with gr.Accordion('βš™οΈ Tech Stack Overview', open=False):
            gr.Markdown(md.tech_stack_overview)
        with gr.Accordion('🧩 How It Works', open=False):
            gr.Markdown(md.how_it_works)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("## LLM Setup")
                llm_provider = gr.Dropdown(
                    list(PROVIDER_MODELS.keys()), label="Select Provider", value="google_genai"
                )
                llm_model = gr.Dropdown(
                    choices=PROVIDER_MODELS["google_genai"],
                    label="Select Model",
                    value="gemini-2.5-flash-lite"
                )
                api_key = gr.Textbox(label="Enter API Key", type="password")
                setup_llm_btn = gr.Button("Setup LLM")
                # Surfaces setup_llm()'s status string; previously the click
                # handler used outputs=None and the message was discarded.
                llm_status = gr.Markdown("")

                gr.Markdown("## Database Connection")
                db_connection_string = gr.Textbox(label="Enter Connection String", type="password")
                connect_db_btn = gr.Button("Connect to Database")
                db_status = gr.Markdown("")

            with gr.Column(scale=2):
                gr.Markdown("## Chatbot Interface")
                chatbot = gr.Chatbot()
                msg = gr.Textbox(label="Enter your question")
                clear = gr.Button("Clear")

        with gr.Accordion("🧠 Example Interaction", open=False):
            gr.Markdown(md.example)
        with gr.Accordion('πŸ§‘β€πŸ’» Ideal Use Cases', open=False):
            gr.Markdown(md.use_cases)
        with gr.Accordion('πŸͺ„ Future Enhancements', open=False):
            gr.Markdown(md.enhancements)
        with gr.Accordion('πŸ’‘ Credits', open=False):
            gr.Markdown(md.credits)

        gr.HTML(md.footer.format(github_logo_encoded, linkedin_logo_encoded, website_logo_encoded))

        # --- FUNCTIONS ---

        def update_model_dropdown(provider):
            """Refresh the model dropdown to the chosen provider's models."""
            models = PROVIDER_MODELS.get(provider, [])
            default = models[0] if models else None
            return gr.update(choices=models, value=default)

        def setup_llm(provider, model_name, key):
            """Initialise the chat model, exporting the key via the provider's env var.

            Returns a status string rendered in `llm_status`.
            """
            global model
            if not key:
                return "❌ Please enter an API key."
            # Bug fix: export the key under the env var the selected provider's
            # client actually reads, not unconditionally GOOGLE_API_KEY.
            os.environ[provider_env_vars.get(provider, "GOOGLE_API_KEY")] = key
            try:
                model = init_chat_model(model_name, model_provider=provider)
            except Exception as e:
                model = None
                return f"❌ Failed to initialise `{model_name}`: {e}"
            return f"βœ… LLM model `{model_name}` from `{provider}` setup successfully."

        def connect_to_db(db_url):
            """Connect to the database and build the SQL agent.

            Returns a status string rendered in `db_status`.
            """
            global db, toolkit, agent
            if model is None:
                return "❌ Please set up the LLM before connecting to the database."

            try:
                db = SQLDatabase.from_uri(db_url)
            except Exception as e:
                # Bad URI / unreachable server: report instead of raising.
                return f"❌ Could not connect to database: {e}"

            toolkit = SQLDatabaseToolkit(db=db, llm=model)
            tools = toolkit.get_tools()

            system_prompt = """
              You are an agent designed to interact with a SQL database.
              Given an input question, create a syntactically correct {dialect} query to run,
              then look at the results of the query and return the answer. Unless the user
              specifies a specific number of examples they wish to obtain, always limit your
              query to at most {top_k} results.

              You can order the results by a relevant column to return the most interesting
              examples in the database. Never query for all the columns from a specific table,
              only ask for the relevant columns given the question.

              You MUST double check your query before executing it. If you get an error while
              executing a query, rewrite the query and try again.

              DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
              database.

              To start you should ALWAYS look at the tables in the database to see what you
              can query. Do NOT skip this step.

              Then you should query the schema of the most relevant tables.
              """.format(
                  dialect=db.dialect,
                  top_k=5,
              )

            agent = create_agent(model, tools, system_prompt=system_prompt)
            tables = db.get_usable_table_names()
            return f"βœ… Connected to database successfully.\n\n**Available tables:** {tables}"

        def respond(message, chat_history):
            """Run the agent on *message* and append the final answer to the chat."""
            if agent is None:
                chat_history.append(
                    (message, "❌ Please set up the LLM and connect to a database first.")
                )
                return "", chat_history
            # With stream_mode="values" each step is the full message state.
            # Bug fix: keep only the final step's last message — the previous
            # `result +=` concatenated every intermediate step (user echo,
            # tool output) into the displayed answer.
            answer = ""
            for step in agent.stream(
                {"messages": [{"role": "user", "content": message}]},
                stream_mode="values",
            ):
                answer = step["messages"][-1].content
            chat_history.append((message, answer))
            return "", chat_history

        # --- EVENTS ---
        llm_provider.change(update_model_dropdown, inputs=llm_provider, outputs=llm_model)
        setup_llm_btn.click(setup_llm, inputs=[llm_provider, llm_model, api_key], outputs=llm_status)
        connect_db_btn.click(connect_to_db, inputs=[db_connection_string], outputs=db_status)
        msg.submit(respond, [msg, chatbot], [msg, chatbot])
        clear.click(lambda: None, None, chatbot, queue=False)

    return demo


if __name__ == "__main__":
    # Build the UI and serve it with Gradio's debug logging enabled.
    app = create_chatbot_interface()
    app.launch(debug=True)