# DeepSpaceSearch / my_ui.py
# Author: 2stacks
# Commit 6b5afed (verified): Fix Tavily API key persistence in session state
from functools import partial
import gradio as gr
from smolagents import GradioUI
class CustomGradioUI(GradioUI):
    """Custom GradioUI that allows customization of the smolagents default interface.

    Differences from the stock ``GradioUI``:

    * Agents are built per browser session from ``agent_factory``. A user's
      Tavily API key and Hugging Face OAuth token therefore only affect that
      user's session, never the shared ``self.agent``.
    * The sidebar adds a Hugging Face login button and a Tavily API key field.
    * Uploads are restricted to ``allowed_file_types``.
    * Optional example prompts are shown below the chat.
    """

    # session_state key remembering which (api_key, token) the cached agent was built with.
    _AGENT_CONFIG_KEY = "_agent_config"

    def __init__(
        self,
        agent_factory,
        file_upload_folder=None,
        reset_agent_memory=False,
        allowed_file_types=None,
        examples=None,
    ):
        """Create the UI.

        Args:
            agent_factory: Callable returning a new agent. It is called with no
                arguments for the template agent, and with the optional keyword
                arguments ``tavily_api_key`` and/or ``oauth_token`` for
                per-session agents.
            file_upload_folder: Folder for uploaded files; uploads are disabled
                when None.
            reset_agent_memory: Forwarded to ``GradioUI``.
            allowed_file_types: File extensions accepted by the upload widget.
                Defaults to ``[".pdf", ".docx", ".txt"]``.
            examples: Example prompts shown below the chat.
        """
        # Keep the factory so agents can be rebuilt with per-session credentials.
        self.agent_factory = agent_factory
        # Template agent built without an API key (presumably the factory falls
        # back to environment variables -- confirm in the factory). The parent
        # presumably derives self.name / self.description from it.
        super().__init__(agent_factory(), file_upload_folder, reset_agent_memory)
        self.allowed_file_types = allowed_file_types or [".pdf", ".docx", ".txt"]
        self.examples = examples or []

    def update_api_key(self, api_key, current_key, session_state, oauth_token: gr.OAuthToken | None = None):
        """Record this session's Tavily API key.

        Only ``session_state`` is updated; the session's agent is rebuilt
        lazily before the next message by ``_session_agent``. Building it here
        would run the agent factory on every keystroke, because this handler is
        bound to the textbox ``change`` event. Reassigning ``self.agent`` would
        also leak the key into every other session, since all sessions share
        this UI instance.

        Args:
            api_key: Current textbox contents.
            current_key: Key previously applied for this session.
            session_state: Per-session state dict.
            oauth_token: Injected by Gradio. Unused here; kept so the handler
                signature stays backward compatible.

        Returns:
            Tuple ``(applied_key, status_markdown, session_state)``.
        """
        # Tolerate None and whitespace picked up when pasting a key.
        api_key = (api_key or "").strip()
        if api_key and api_key != current_key:
            session_state["tavily_api_key"] = api_key
            return api_key, gr.Markdown("✓ API key updated successfully", visible=True), session_state
        if not api_key and current_key:
            # Fall back to the factory's default (env var) on the next message.
            session_state.pop("tavily_api_key", None)
            return "", gr.Markdown(
                "API key cleared, using environment variable if set", visible=True
            ), session_state
        return current_key, gr.Markdown("", visible=False), session_state

    def _session_agent(self, session_state, oauth_token=None):
        """Return this session's agent, rebuilding it when its credentials changed.

        The agent is cached in ``session_state["agent"]``. That is where
        smolagents' ``GradioUI.interact_with_agent`` looks for the agent to run
        (NOTE(review): true for current smolagents releases -- confirm against
        the pinned version). It is rebuilt only when the Tavily API key or the
        OAuth token differs from the ones it was built with. Otherwise the same
        agent, and its conversation memory, is reused between messages.
        """
        api_key = session_state.get("tavily_api_key") or None
        token_value = oauth_token.token if oauth_token is not None else None
        config = (api_key, token_value)
        if "agent" not in session_state or session_state.get(self._AGENT_CONFIG_KEY) != config:
            # Only pass what is set, mirroring the no-argument call in __init__.
            factory_kwargs = {}
            if api_key:
                factory_kwargs["tavily_api_key"] = api_key
            if oauth_token is not None:
                factory_kwargs["oauth_token"] = oauth_token
            session_state["agent"] = self.agent_factory(**factory_kwargs)
            session_state[self._AGENT_CONFIG_KEY] = config
        return session_state["agent"]

    def interact_with_agent_oauth(self, stored_messages, chatbot, session_state, oauth_token: gr.OAuthToken | None = None):
        """Run the parent chat handler with an agent configured for this session.

        Gradio injects ``oauth_token`` because of the ``gr.OAuthToken`` annotation.
        The session agent is refreshed first, and then every update produced
        by ``interact_with_agent`` is yielded through unchanged.
        """
        self._session_agent(session_state, oauth_token)
        yield from self.interact_with_agent(stored_messages, chatbot, session_state)

    def create_app(self):
        """Build the Gradio Blocks app.

        This re-implements ``GradioUI.create_app``; the parent version is not
        called. It adds the login button, the Tavily API key field, the custom
        allowed upload types and the example prompts.
        """

        def _reenable_input():
            # Make the textbox and button usable again once the agent finished.
            return (
                gr.Textbox(
                    interactive=True,
                    placeholder="Enter your prompt here and press Shift+Enter or the button",
                ),
                gr.Button(interactive=True),
            )

        with gr.Blocks(theme="ocean", fill_height=True) as demo:
            # Per-session state holders.
            session_state = gr.State({})
            stored_messages = gr.State([])
            file_uploads_log = gr.State([])
            current_api_key = gr.State("")  # Tavily API key currently applied for this session

            with gr.Sidebar():
                gr.Markdown(
                    f"# {self.name.replace('_', ' ').capitalize()}"
                    "\n> This web ui allows you to interact with a `smolagents` agent "
                    "that can use tools and execute steps to complete tasks."
                    + (
                        f"\n\n**Agent description:**\n{self.description}"
                        if self.description
                        else ""
                    )
                )

                # Hugging Face OAuth login; the token reaches handlers annotated
                # with gr.OAuthToken.
                gr.LoginButton()

                with gr.Group():
                    gr.Markdown("**Your request**", container=True)
                    text_input = gr.Textbox(
                        lines=3,
                        label="Chat Message",
                        container=False,
                        placeholder=(
                            "Enter your prompt here and press Shift+Enter or press the button"
                        ),
                    )
                    submit_btn = gr.Button("Submit", variant="primary")

                # If an upload folder is provided, enable the upload feature.
                if self.file_upload_folder is not None:
                    upload_file = gr.File(label="Upload a file")
                    upload_status = gr.Textbox(
                        label="Upload Status", interactive=False, visible=False
                    )
                    # Bind our allowed_file_types into the parent's upload handler.
                    upload_handler = partial(
                        self.upload_file, allowed_file_types=self.allowed_file_types
                    )
                    upload_file.change(
                        upload_handler,
                        [upload_file, file_uploads_log],
                        [upload_status, file_uploads_log],
                    )

                # Tavily API key section.
                with gr.Group():
                    gr.Markdown("**Tavily API Key**", container=True)
                    gr.Markdown(
                        "Get your free API key at [tavily.com](https://app.tavily.com/home)",
                        container=False,
                    )
                    tavily_api_key_input = gr.Textbox(
                        label="API Key (optional)",
                        type="password",
                        placeholder="Enter your Tavily API key to enable web search",
                        container=False,
                    )
                    api_key_status = gr.Markdown("", visible=False)
                    # Only records the key; the agent is rebuilt before the next message.
                    tavily_api_key_input.change(
                        self.update_api_key,
                        [tavily_api_key_input, current_api_key, session_state],
                        [current_api_key, api_key_status, session_state],
                    )

                gr.HTML(
                    "<br><br><h4><center>Powered by <a target='_blank' href='https://github.com/huggingface/smolagents'><b>smolagents</b></a></center></h4>"
                )

            # Main chat interface.
            chatbot = gr.Chatbot(
                label="Agent",
                type="messages",
                avatar_images=(
                    None,
                    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
                ),
                resizable=True,
                scale=1,
                latex_delimiters=[
                    {"left": r"$$", "right": r"$$", "display": True},
                    {"left": r"$", "right": r"$", "display": False},
                    {"left": r"\[", "right": r"\]", "display": True},
                    {"left": r"\(", "right": r"\)", "display": False},
                ],
            )

            if self.examples:
                gr.Examples(
                    examples=self.examples,
                    inputs=text_input,
                    cache_examples=False,
                )

            # Shift+Enter in the textbox and the Submit button run the same
            # pipeline: log the message, run the agent, then re-enable input.
            for trigger in (text_input.submit, submit_btn.click):
                trigger(
                    self.log_user_message,
                    [text_input, file_uploads_log],
                    [stored_messages, text_input, submit_btn],
                ).then(
                    self.interact_with_agent_oauth,
                    [stored_messages, chatbot, session_state],
                    [chatbot],
                ).then(
                    _reenable_input,
                    None,
                    [text_input, submit_btn],
                )

        return demo