# Hugging Face Space file view (Space status: Sleeping) — file size: 7,023 bytes.
#@title Gradio UI - 1
import gradio as gr
import os
from langchain.chat_models import init_chat_model
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain.agents import create_agent
import markdown as md
import base64
# from google.colab import userdata
# --- GLOBALS ---
# Shared application state, filled in by the UI callbacks. Everything starts as
# None so code that reads these before setup sees None rather than a NameError.
model = None  # ✅ prevent "NameError" before setup
db = None
toolkit = None
agent = None

# Define available models for each provider.
# Keys are LangChain `init_chat_model` provider names; the first entry of each
# list is what the model dropdown selects when that provider is chosen.
# NOTE(review): some IDs below (e.g. "claude-3-opus", "gemini-2.0-pro",
# "sonar-small-chat") may not be accepted by the provider APIs as written —
# confirm against each provider's current model list.
PROVIDER_MODELS = {
    "google_genai": ["gemini-2.0-pro", "gemini-2.5-flash-lite", "gemini-2.5-flash", "gemini-2.0-flash-thinking"],
    "openai": ["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"],
    "anthropic": ["claude-3-opus", "claude-3-sonnet", "claude-3-haiku"],
    "azure_openai": ["gpt-4-turbo", "gpt-4o-mini", "gpt-35-turbo"],
    "bedrock": ["anthropic.claude-3-sonnet-v1", "mistral.mixtral-8x7b"],
    "xai": ["grok-2", "grok-2-mini"],
    "deepseek": ["deepseek-chat", "deepseek-coder"],
    "perplexity": ["sonar-small-chat", "sonar-medium-chat", "sonar-large-chat"],
}
def encode_image(image_path):
    """Read the file at *image_path* and return its raw bytes as base64 text.

    The result is an ASCII/UTF-8 string suitable for embedding in a
    ``data:`` URI inside HTML.
    """
    with open(image_path, mode="rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode("utf-8")
# Base64-encode the footer logos once, at import time, so the footer HTML can
# inline them. Order matters: github, linkedin, website.
github_logo_encoded, linkedin_logo_encoded, website_logo_encoded = map(
    encode_image,
    ("Images/github-logo.png", "Images/linkedin-logo.png", "Images/ai-logo.png"),
)
def create_chatbot_interface():
with gr.Blocks(theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Noto Sans")]),
css='footer {visibility: hidden}') as demo:
gr.Markdown("# DB Assistant ππ¨π»βπ")
with gr.Accordion("π Description:", open=False):
gr.Markdown(md.description)
with gr.Accordion('π Key Features', open=False):
gr.Markdown(md.key_features)
with gr.Accordion('βοΈ Tech Stack Overview', open=False):
gr.Markdown(md.tech_stack_overview)
with gr.Accordion('π§© How It Works', open=False):
gr.Markdown(md.how_it_works)
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("## LLM Setup")
llm_provider = gr.Dropdown(
list(PROVIDER_MODELS.keys()), label="Select Provider", value="google_genai"
)
llm_model = gr.Dropdown(
choices=PROVIDER_MODELS["google_genai"],
label="Select Model",
value="gemini-2.5-flash-lite"
)
api_key = gr.Textbox(label="Enter API Key", type="password")
setup_llm_btn = gr.Button("Setup LLM")
gr.Markdown("## Database Connection")
db_connection_string = gr.Textbox(label="Enter Connection String", type="password")
connect_db_btn = gr.Button("Connect to Database")
db_status = gr.Markdown("")
with gr.Column(scale=2):
gr.Markdown("## Chatbot Interface")
chatbot = gr.Chatbot()
msg = gr.Textbox(label="Enter your question")
clear = gr.Button("Clear")
with gr.Accordion("π§ Example Interaction", open=False):
gr.Markdown(md.example)
with gr.Accordion('π§βπ» Ideal Use Cases', open=False):
gr.Markdown(md.use_cases)
with gr.Accordion('πͺ Future Enhancements', open=False):
gr.Markdown(md.enhancements)
with gr.Accordion('π‘ Credits', open=False):
gr.Markdown(md.credits)
gr.HTML(md.footer.format(github_logo_encoded, linkedin_logo_encoded, website_logo_encoded))
# --- FUNCTIONS ---
def update_model_dropdown(provider):
models = PROVIDER_MODELS.get(provider, [])
default = models[0] if models else None
return gr.update(choices=models, value=default)
def setup_llm(provider, model_name, key):
os.environ["GOOGLE_API_KEY"] = key
global model
model = init_chat_model(model_name, model_provider=provider)
return f"β
LLM model `{model_name}` from `{provider}` setup successfully."
def connect_to_db(db_url):
global db, toolkit, tools, agent, system_prompt, model
if model is None:
return "β Please set up the LLM before connecting to the database."
db = SQLDatabase.from_uri(db_url)
toolkit = SQLDatabaseToolkit(db=db, llm=model)
tools = toolkit.get_tools()
system_prompt = """
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run,
then look at the results of the query and return the answer. Unless the user
specifies a specific number of examples they wish to obtain, always limit your
query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.
You MUST double check your query before executing it. If you get an error while
executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
database.
To start you should ALWAYS look at the tables in the database to see what you
can query. Do NOT skip this step.
Then you should query the schema of the most relevant tables.
""".format(
dialect=db.dialect,
top_k=5,
)
agent = create_agent(model, tools, system_prompt=system_prompt)
tables = db.get_usable_table_names()
return f"β
Connected to database successfully.\n\n**Available tables:** {tables}"
def respond(message, chat_history):
result = ""
for step in agent.stream(
{"messages": [{"role": "user", "content": message}]},
stream_mode="values",
):
result += step["messages"][-1].content
chat_history.append((message, result))
return "", chat_history
# --- EVENTS ---
llm_provider.change(update_model_dropdown, inputs=llm_provider, outputs=llm_model)
setup_llm_btn.click(setup_llm, inputs=[llm_provider, llm_model, api_key], outputs=None)
connect_db_btn.click(connect_to_db, inputs=[db_connection_string], outputs=db_status)
msg.submit(respond, [msg, chatbot], [msg, chatbot])
clear.click(lambda: None, None, chatbot, queue=False)
return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks app, then serve it. debug=True makes
    # launch() block the main thread and print errors to the console.
    demo = create_chatbot_interface()
    demo.launch(debug=True)