Spaces:
Sleeping
Sleeping
| from functools import partial | |
| import gradio as gr | |
| from smolagents import GradioUI | |
class CustomGradioUI(GradioUI):
    """Custom GradioUI that allows customization of the smolagents default interface.

    On top of the stock smolagents ``GradioUI`` this adds an OAuth login button,
    an optional Tavily API key field, example prompts, and a configurable list
    of allowed upload file types.

    Agents are built per browser session with ``agent_factory`` and kept in
    ``session_state["agent"]`` instead of being assigned to the shared
    ``self.agent``, so one user's Tavily key or OAuth token is never used by
    another user's agent.
    """

    # session_state key recording the (tavily_api_key, oauth token string) pair
    # the session's agent was built with, so it is rebuilt only when they change.
    _AGENT_CREDENTIALS_KEY = "_agent_credentials"

    def __init__(
        self,
        agent_factory,
        file_upload_folder=None,
        reset_agent_memory=False,
        allowed_file_types=None,
        examples=None,
    ):
        """Initialise the UI.

        Args:
            agent_factory: Callable returning a new agent. It is called with no
                arguments here, and with ``tavily_api_key=`` / ``oauth_token=``
                keyword arguments when a session needs its own agent.
            file_upload_folder: Folder for uploaded files; the upload widget is
                not shown when this is ``None``.
            reset_agent_memory: Forwarded unchanged to ``GradioUI``.
            allowed_file_types: Extensions accepted by the upload handler;
                ``[".pdf", ".docx", ".txt"]`` when falsy.
            examples: Example prompts shown under the chat; none when falsy.
        """
        # Keep the factory so agents can be rebuilt with per-session credentials.
        self.agent_factory = agent_factory
        # Template agent built without an explicit Tavily key or OAuth token;
        # presumably GradioUI derives self.name / self.description (used in
        # create_app) from it — confirm against the installed smolagents.
        super().__init__(agent_factory(), file_upload_folder, reset_agent_memory)
        self.allowed_file_types = allowed_file_types or [".pdf", ".docx", ".txt"]
        self.examples = examples or []

    def _ensure_session_agent(self, session_state, tavily_api_key, oauth_token):
        """Return this session's agent, (re)building it only if credentials changed.

        The agent is stored in ``session_state["agent"]``. A new one is created
        with ``self.agent_factory`` when the session has none yet or when the
        stored one was built with a different Tavily key / OAuth token;
        otherwise the existing agent (and its memory) is reused.

        Args:
            session_state: Per-session dict held in a ``gr.State``; mutated in place.
            tavily_api_key: Tavily key to build with, or ``None``.
            oauth_token: ``gr.OAuthToken`` of the logged-in user, or ``None``.

        Returns:
            The agent now stored in ``session_state["agent"]``.
        """
        token = oauth_token.token if oauth_token is not None else None
        credentials = (tavily_api_key, token)
        if (
            "agent" in session_state
            and session_state.get(self._AGENT_CREDENTIALS_KEY) == credentials
        ):
            return session_state["agent"]
        # Build first: if the factory raises, session_state is left untouched.
        agent = self.agent_factory(tavily_api_key=tavily_api_key, oauth_token=oauth_token)
        # NOTE(review): relies on smolagents' GradioUI.interact_with_agent using
        # session_state["agent"] when present rather than self.agent — confirm
        # against the installed smolagents version.
        session_state["agent"] = agent
        session_state[self._AGENT_CREDENTIALS_KEY] = credentials
        return agent

    def update_api_key(self, api_key, current_key, session_state, oauth_token: gr.OAuthToken | None = None):
        """Update this session's agent when the Tavily API key field changes.

        ``oauth_token`` is filled in by Gradio from its type annotation when the
        user is logged in; it is ``None`` otherwise.

        Returns:
            ``(new_current_key, status_markdown, session_state)``, matching the
            ``[current_api_key, api_key_status, session_state]`` outputs.
        """
        if api_key and api_key != current_key:
            # New key: rebuild this session's agent before recording the key.
            self._ensure_session_agent(session_state, api_key, oauth_token)
            session_state["tavily_api_key"] = api_key
            return api_key, gr.Markdown("✓ API key updated successfully", visible=True), session_state
        if not api_key and current_key:
            # Field cleared: fall back to an agent built without an explicit key.
            self._ensure_session_agent(session_state, None, oauth_token)
            session_state.pop("tavily_api_key", None)
            return "", gr.Markdown(
                "API key cleared, using environment variable if set", visible=True
            ), session_state
        # No effective change: keep the current key and hide the status line.
        return current_key, gr.Markdown("", visible=False), session_state

    def interact_with_agent_oauth(self, stored_messages, chatbot, session_state, oauth_token: gr.OAuthToken | None = None):
        """Run one agent turn with this session's current credentials.

        Ensures ``session_state["agent"]`` matches the session's Tavily key and
        OAuth token (rebuilding it only when they changed, so conversation
        memory survives between turns), then yields every update produced by
        the parent's ``interact_with_agent``.
        """
        self._ensure_session_agent(
            session_state, session_state.get("tavily_api_key"), oauth_token
        )
        yield from self.interact_with_agent(stored_messages, chatbot, session_state)

    def create_app(self):
        """Build and return the ``gr.Blocks`` app.

        The layout is built here rather than by calling ``GradioUI.create_app``
        so that it can add the OAuth login button, the Tavily key field and
        examples, use ``self.allowed_file_types`` for uploads, and route chat
        turns through ``interact_with_agent_oauth``.
        """
        with gr.Blocks(theme="ocean", fill_height=True) as demo:
            session_state = gr.State({})
            stored_messages = gr.State([])
            file_uploads_log = gr.State([])
            current_api_key = gr.State("")  # Store current Tavily API key

            with gr.Sidebar():
                gr.Markdown(
                    f"# {self.name.replace('_', ' ').capitalize()}"
                    "\n> This web ui allows you to interact with a `smolagents` agent "
                    "that can use tools and execute steps to complete tasks."
                    + (
                        f"\n\n**Agent description:**\n{self.description}"
                        if self.description
                        else ""
                    )
                )

                # OAuth login; handlers annotated with gr.OAuthToken receive the token.
                gr.LoginButton()

                with gr.Group():
                    gr.Markdown("**Your request**", container=True)
                    text_input = gr.Textbox(
                        lines=3,
                        label="Chat Message",
                        container=False,
                        placeholder=(
                            "Enter your prompt here and press Shift+Enter or press the button"
                        ),
                    )
                    submit_btn = gr.Button("Submit", variant="primary")

                # If an upload folder is provided, enable the upload feature
                if self.file_upload_folder is not None:
                    upload_file = gr.File(label="Upload a file")
                    upload_status = gr.Textbox(
                        label="Upload Status", interactive=False, visible=False
                    )
                    # Bind our allowed_file_types to the parent's upload handler.
                    upload_handler = partial(
                        self.upload_file, allowed_file_types=self.allowed_file_types
                    )
                    upload_file.change(
                        upload_handler,
                        [upload_file, file_uploads_log],
                        [upload_status, file_uploads_log],
                    )

                # Tavily API Key section
                with gr.Group():
                    gr.Markdown("**Tavily API Key**", container=True)
                    gr.Markdown(
                        "Get your free API key at [tavily.com](https://app.tavily.com/home)",
                        container=False,
                    )
                    tavily_api_key_input = gr.Textbox(
                        label="API Key (optional)",
                        type="password",
                        placeholder="Enter your Tavily API key to enable web search",
                        container=False,
                    )
                    api_key_status = gr.Markdown("", visible=False)

                    # Rebuild this session's agent when the API key changes.
                    tavily_api_key_input.change(
                        self.update_api_key,
                        [tavily_api_key_input, current_api_key, session_state],
                        [current_api_key, api_key_status, session_state],
                    )

                gr.HTML(
                    "<br><br><h4><center>Powered by <a target='_blank' href='https://github.com/huggingface/smolagents'><b>smolagents</b></a></center></h4>"
                )

            # Main chat interface
            chatbot = gr.Chatbot(
                label="Agent",
                type="messages",
                avatar_images=(
                    None,
                    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
                ),
                resizable=True,
                scale=1,
                latex_delimiters=[
                    {"left": r"$$", "right": r"$$", "display": True},
                    {"left": r"$", "right": r"$", "display": False},
                    {"left": r"\[", "right": r"\]", "display": True},
                    {"left": r"\(", "right": r"\)", "display": False},
                ],
            )

            # Add examples if provided
            if self.examples:
                gr.Examples(
                    examples=self.examples,
                    inputs=text_input,
                    cache_examples=False,
                )

            # Pressing Shift+Enter and clicking Submit run the same chain:
            # log the message, run the agent, then re-enable the input widgets.
            for trigger in (text_input.submit, submit_btn.click):
                trigger(
                    self.log_user_message,
                    [text_input, file_uploads_log],
                    [stored_messages, text_input, submit_btn],
                ).then(
                    self.interact_with_agent_oauth,
                    [stored_messages, chatbot, session_state],
                    [chatbot],
                ).then(
                    lambda: (
                        gr.Textbox(
                            interactive=True,
                            placeholder="Enter your prompt here and press Shift+Enter or the button",
                        ),
                        gr.Button(interactive=True),
                    ),
                    None,
                    [text_input, submit_btn],
                )

        return demo