Spaces:
Runtime error
Runtime error
| import os | |
| import shutil | |
| import threading | |
| import tempfile | |
| from typing import Optional, List, Dict, Any | |
| from dotenv import load_dotenv | |
| from huggingface_hub import login | |
| import gradio as gr | |
| from scripts.text_inspector_tool import TextInspectorTool | |
| from scripts.text_web_browser import ( | |
| ArchiveSearchTool, | |
| FinderTool, | |
| FindNextTool, | |
| PageDownTool, | |
| PageUpTool, | |
| SimpleTextBrowser, | |
| VisitTool, | |
| ) | |
| from scripts.visual_qa import visualizer | |
| from scripts.legal_document_tool import LegalDocumentTool | |
| from smolagents import ( | |
| CodeAgent, | |
| HfApiModel, | |
| LiteLLMModel, | |
| OpenAIServerModel, | |
| TransformersModel, | |
| GoogleSearchTool, | |
| Tool, | |
| ) | |
| from smolagents.agent_types import AgentText, AgentImage, AgentAudio | |
| from smolagents.gradio_ui import pull_messages_from_step, handle_agent_output_types | |
| # ------------------------ Configuration and Setup ------------------------ | |
| # Constants and configurations | |
# Module names the agent's generated code is allowed to import; passed to
# CodeAgent as `additional_authorized_imports`.
AUTHORIZED_IMPORTS = [
    "requests",
    "zipfile",
    "pandas",
    "numpy",
    "sympy",
    "json",
    "bs4",
    "pubchempy",
    "xml",
    "yahoo_finance",
    "Bio",
    "sklearn",
    "scipy",
    "pydub",
    "PIL",
    "chess",
    "PyPDF2",
    "pptx",
    "torch",
    "datetime",
    "fractions",
    "csv",
    # "clean-text" is the PyPI distribution name; a hyphen is not valid in a
    # module name, so this entry can never match an import statement. It is
    # kept for backward compatibility and the importable name is listed too.
    "clean-text",
    "cleantext",
    "langchain",
    "llama_index",
]
# User-Agent header value sent with every request made by the text browser.
user_agent = (
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
    "AppleWebKit/537.36 (KHTML, like Gecko) "
    "Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
)

# Keyword arguments unpacked into SimpleTextBrowser(**BROWSER_CONFIG).
# NOTE: the SerpAPI key is read from the environment when this module is
# imported, i.e. before setup_environment() has loaded any .env file.
BROWSER_CONFIG = dict(
    viewport_size=5 * 1024,
    downloads_folder="downloads_folder",
    request_kwargs=dict(
        headers={"User-Agent": user_agent},
        timeout=300,
    ),
    serpapi_key=os.getenv("SERPAPI_API_KEY"),
)

# Role-name mapping handed to the LiteLLM model as `custom_role_conversions`.
custom_role_conversions = dict([("tool-call", "assistant"), ("tool-response", "user")])

# File extensions accepted by the upload widget (Gradio `file_types` format).
ALLOWED_FILE_TYPES = [
    # Documents: application/pdf, .docx (OpenXML word processing), text/plain
    ".pdf", ".docx", ".txt",
    # Images: image/png, image/webp, image/jpeg (x2), image/gif
    ".png", ".webp", ".jpeg", ".jpg", ".gif",
    # Video: video/mp4
    ".mp4",
    # Audio: audio/mpeg, audio/wav, audio/ogg
    ".mp3", ".wav", ".ogg",
]
def setup_environment():
    """Load variables from a .env file and log in to the Hugging Face Hub.

    Values found in the .env file override variables already present in the
    process environment (``override=True``). A login is attempted only when
    ``HF_TOKEN`` is set to a non-empty value; otherwise a notice is printed.
    """
    load_dotenv(override=True)
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        print("HF_TOKEN not found in environment variables.")
        return
    login(hf_token)
    # Do not echo any part of the token: it is a secret and the application
    # log may be readable by other people.
    print("HF_TOKEN found; logged in to the Hugging Face Hub.")
| # ------------------------ Model and Tool Management ------------------------ | |
class ModelManager:
    """Loads model backends and caches them per (inference type, model id).

    Obtain the shared instance through ``ModelManager.get_instance()``.
    """

    _instance = None
    _lock = threading.Lock()

    @classmethod
    def get_instance(cls):
        """Return the shared manager, creating it on first use.

        The instance check is repeated while holding ``_lock`` so that two
        threads arriving at the same time cannot both create an instance.
        """
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    def __init__(self):
        """Start with an empty model cache."""
        # Maps "<chosen_inference>:<model_id>" to the loaded model object.
        self.model_cache = {}

    def load_model(self, chosen_inference: str, model_id: str, key_manager=None):
        """Return the model for the given backend, loading it on a cache miss.

        Args:
            chosen_inference: One of "hf_api", "hf_api_provider", "litellm",
                "openai" or "transformers".
            model_id: Model identifier. It always takes part in the cache key,
                but the "hf_api_provider" and "transformers" branches do not
                forward it to the model constructor.
            key_manager: Object exposing ``get_key(name)``; required only for
                the "openai" backend.

        Returns:
            The cached or newly created model object.

        Raises:
            ValueError: If ``chosen_inference`` is unknown, or if "openai" is
                requested without a ``key_manager``.
            Exception: Any error raised by a model constructor is printed and
                re-raised unchanged.
        """
        cache_key = f"{chosen_inference}:{model_id}"
        if cache_key in self.model_cache:
            return self.model_cache[cache_key]

        try:
            if chosen_inference == "hf_api":
                model = HfApiModel(model_id=model_id)
            elif chosen_inference == "hf_api_provider":
                model = HfApiModel(provider="together")
            elif chosen_inference == "litellm":
                model = LiteLLMModel(model_id=model_id)
            elif chosen_inference == "openai":
                if not key_manager:
                    raise ValueError("Key manager required for OpenAI model")
                model = OpenAIServerModel(
                    model_id=model_id, api_key=key_manager.get_key("openai_api_key")
                )
            elif chosen_inference == "transformers":
                model = TransformersModel(
                    model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct",
                    device_map="auto",
                    max_new_tokens=1000,
                )
            else:
                raise ValueError(f"Invalid inference type: {chosen_inference}")
        except Exception as e:
            print(f"✗ Couldn't load model: {e}")
            raise

        # Only successfully constructed models are cached.
        self.model_cache[cache_key] = model
        return model
class ToolRegistry:
    """Namespace of helpers that build and validate the agent's tools.

    None of the helpers use instance state, so they are declared as static
    methods and can be called on the class or on an instance.
    """

    @staticmethod
    def validate_tools(tools: List[Tool]) -> List[Tool]:
        """Return only the entries of ``tools`` that are ``Tool`` instances.

        ``None`` placeholders (from optional tools that failed to load) and
        any other non-Tool objects are dropped; order is preserved.
        """
        return [tool for tool in tools if isinstance(tool, Tool)]

    @staticmethod
    def load_web_tools(model, browser, text_limit=20000):
        """Build the search, browsing and text-inspection tools.

        Args:
            model: Model handed to ``TextInspectorTool``.
            browser: Browser instance shared by all page-navigation tools.
            text_limit: Limit handed to ``TextInspectorTool``.

        Returns:
            A list of tool instances.
        """
        return [
            GoogleSearchTool(provider="serper"),
            VisitTool(browser),
            PageUpTool(browser),
            PageDownTool(browser),
            FinderTool(browser),
            FindNextTool(browser),
            ArchiveSearchTool(browser),
            TextInspectorTool(model, text_limit),
        ]

    @staticmethod
    def load_image_generation_tools():
        """Build the image generation tool from a Hugging Face Space.

        Returns:
            The tool, or ``None`` when it could not be created; the error is
            printed so that the tool stays optional.
        """
        try:
            return Tool.from_space(
                space_id="xkerser/FLUX.1-dev",
                name="image_generator",
                description="Generates high-quality AgentImage using the FLUX.1-dev model based on text prompts.",
            )
        except Exception as e:
            print(f"✗ Couldn't initialize image generation tool: {e}")
            return None

    @staticmethod
    def load_legal_document_tool():
        """Build the legal document tool with its default parameters.

        Returns:
            The tool, or ``None`` when it could not be created; the error is
            printed so that the tool stays optional.
        """
        try:
            return LegalDocumentTool()
        except Exception as e:
            print(f"✗ Couldn't initialize legal document tool: {e}")
            return None
| # ------------------------ Session Management ------------------------ | |
class SessionManager:
    """Tracks per-session data, temporary files and last-activity times.

    Creating an instance starts a daemon thread that periodically removes
    sessions that have been inactive for longer than ``SESSION_TIMEOUT``.
    Obtain the shared instance through ``SessionManager.get_instance()``.
    """

    _instance = None
    _lock = threading.Lock()

    # Seconds of inactivity after which a session is cleaned up (30 minutes).
    SESSION_TIMEOUT = 30 * 60
    # Seconds the cleanup thread sleeps between two sweeps.
    CLEANUP_INTERVAL = 60

    @classmethod
    def get_instance(cls):
        """Return the shared manager, creating it on first use.

        The instance check is repeated while holding ``_lock`` so that two
        threads arriving at the same time cannot both create an instance.
        """
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    def __init__(self):
        """Create empty bookkeeping tables and start the cleanup thread."""
        self.sessions = {}  # session_id -> session data dict
        self.temp_files = {}  # session_id -> list of temp file paths
        self.last_activity = {}  # session_id -> time.time() of last activity
        self._cleanup_thread = None
        self._running = False
        self._start_cleanup_thread()

    def _start_cleanup_thread(self):
        """Start the background cleanup thread unless one was already created."""
        if self._cleanup_thread is None:
            self._running = True
            self._cleanup_thread = threading.Thread(
                target=self._cleanup_inactive_sessions, daemon=True
            )
            self._cleanup_thread.start()

    def _cleanup_inactive_sessions(self):
        """Thread body: remove timed-out sessions, then sleep, until stopped."""
        import time

        while self._running:
            now = time.time()
            # Iterate over a snapshot: request handlers add and remove entries
            # from other threads, and iterating the live dict could raise
            # "dictionary changed size during iteration".
            expired = [
                session_id
                for session_id, last_time in list(self.last_activity.items())
                if (now - last_time) > self.SESSION_TIMEOUT
            ]
            for session_id in expired:
                self.cleanup_session(session_id)
            time.sleep(self.CLEANUP_INTERVAL)

    def register_session(self, session_id):
        """Create the bookkeeping entries for a session and mark it active.

        Calling this again for a known session only refreshes its timestamp.
        """
        if session_id not in self.sessions:
            self.sessions[session_id] = {}
            self.temp_files[session_id] = []
        self.update_activity(session_id)

    def update_activity(self, session_id):
        """Record the current time as the session's last activity."""
        # Imported here because the module does not import `time` at the top.
        import time

        self.last_activity[session_id] = time.time()

    def register_temp_file(self, session_id, file_path):
        """Remember ``file_path`` so it is deleted when the session is cleaned up."""
        self.temp_files.setdefault(session_id, []).append(file_path)

    def cleanup_session(self, session_id):
        """Delete the session's temporary files and drop all of its records.

        Unknown session ids are ignored. A failure to delete one file is
        printed and does not stop the rest of the cleanup.
        """
        for file_path in self.temp_files.pop(session_id, []):
            try:
                if os.path.exists(file_path):
                    os.remove(file_path)
            except Exception as e:
                print(f"Error removing temp file {file_path}: {e}")
        self.sessions.pop(session_id, None)
        self.last_activity.pop(session_id, None)

    def __del__(self):
        """Stop the cleanup thread and clean up every remaining session."""
        self._running = False
        cleanup_thread = getattr(self, "_cleanup_thread", None)
        if cleanup_thread and cleanup_thread.is_alive():
            cleanup_thread.join(timeout=1.0)
        # A session may own temp files or a timestamp without ever having been
        # registered in `sessions`, so gather ids from all three tables.
        remaining = set()
        for table_name in ("sessions", "temp_files", "last_activity"):
            remaining.update(getattr(self, table_name, {}))
        for session_id in remaining:
            self.cleanup_session(session_id)
| # ------------------------ Agent Creation and Execution ------------------------ | |
class AgentFactory:
    """Builds agents and caches one agent per session id.

    Obtain the shared instance through ``AgentFactory.get_instance()``.
    """

    _instance = None
    _lock = threading.Lock()

    @classmethod
    def get_instance(cls):
        """Return the shared factory, creating it on first use.

        The instance check is repeated while holding ``_lock`` so that two
        threads arriving at the same time cannot both create an instance.
        """
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    def __init__(self):
        """Start with an empty agent cache."""
        # Maps session_id to the CodeAgent built for that session.
        self.agent_cache = {}

    def create_agent(self, session_id: str = "default") -> "CodeAgent":
        """Return the agent for ``session_id``, building and caching it if needed.

        A newly built agent uses a LiteLLM model, the web/browser tools, the
        visualizer, and the optional image-generation and legal-document
        tools when those could be initialized.

        Args:
            session_id: Cache key; repeated calls with the same id return the
                same agent object until ``clear_agent`` is called.

        Returns:
            The cached or newly created ``CodeAgent``.
        """
        if session_id in self.agent_cache:
            return self.agent_cache[session_id]

        model = LiteLLMModel(
            custom_role_conversions=custom_role_conversions,
            model_id="openrouter/perplexity/r1-1776",  # currently serving
        )  # DEEPSEEK = openrouter/perplexity/r1-1776 <--- boss model

        text_limit = 30000
        browser = SimpleTextBrowser(**BROWSER_CONFIG)

        web_tools = ToolRegistry.load_web_tools(model, browser, text_limit)
        image_generator = ToolRegistry.load_image_generation_tools()
        legal_tool = ToolRegistry.load_legal_document_tool()

        all_tools = [visualizer] + web_tools
        # The optional loaders return None on failure; skip those.
        if image_generator:
            all_tools.append(image_generator)
        if legal_tool:
            all_tools.append(legal_tool)
        # Final pass that keeps only genuine Tool instances.
        all_tools = ToolRegistry.validate_tools(all_tools)

        agent = CodeAgent(
            model=model,
            tools=all_tools,
            max_steps=10,
            verbosity_level=1,
            additional_authorized_imports=AUTHORIZED_IMPORTS,
            planning_interval=4,
        )
        self.agent_cache[session_id] = agent
        return agent

    def clear_agent(self, session_id: str):
        """Drop the cached agent for ``session_id``; unknown ids are ignored."""
        self.agent_cache.pop(session_id, None)
def stream_to_gradio(
    agent,
    task: str,
    reset_agent_memory: bool = False,
    additional_args: Optional[Dict[str, Any]] = None,
):
    """Run ``agent`` on ``task`` and yield the run as Gradio chat messages.

    Every step produced by ``agent.run(..., stream=True)`` is converted with
    ``pull_messages_from_step`` and its messages are yielded as they arrive.
    The last step is treated as the final answer and yielded as one extra
    assistant message whose form depends on its type (text, image, audio, or
    anything else rendered with ``str``).

    Args:
        agent: Object exposing ``run(task, stream, reset, additional_args)``.
        task: The prompt to run.
        reset_agent_memory: Forwarded to ``agent.run`` as ``reset``.
        additional_args: Forwarded to ``agent.run`` unchanged.

    Yields:
        ``gr.ChatMessage`` objects. Nothing is yielded when the run produces
        no steps.
    """
    no_step = object()  # sentinel, so that a step that is None still counts
    step_log = no_step
    for step_log in agent.run(
        task, stream=True, reset=reset_agent_memory, additional_args=additional_args
    ):
        yield from pull_messages_from_step(step_log)

    if step_log is no_step:
        # The run produced no steps, so there is no final answer to report.
        return

    # The last item of the stream is the run's final answer.
    final_answer = handle_agent_output_types(step_log)

    if isinstance(final_answer, AgentText):
        yield gr.ChatMessage(
            role="assistant",
            content=f"**Final answer:**\n{final_answer.to_string()}\n",
        )
    elif isinstance(final_answer, AgentImage):
        # File content must be a dict with a "path" key for the chatbot.
        yield gr.ChatMessage(
            role="assistant",
            content={"path": final_answer.to_string(), "mime_type": "image/png"},
        )
    elif isinstance(final_answer, AgentAudio):
        yield gr.ChatMessage(
            role="assistant",
            content={"path": final_answer.to_string(), "mime_type": "audio/wav"},
        )
    else:
        yield gr.ChatMessage(
            role="assistant", content=f"**Final answer:** {str(final_answer)}"
        )
| # ------------------------ Gradio UI Components ------------------------ | |
class GradioUI:
    """Gradio front end that connects the chat page to per-session agents."""

    def __init__(self):
        """Fetch the shared managers and create a private upload directory."""
        self.session_manager = SessionManager.get_instance()
        self.agent_factory = AgentFactory.get_instance()
        self.temp_dir = tempfile.mkdtemp(prefix="gradio_")

    def __del__(self):
        """Remove the upload directory when the UI object is destroyed."""
        # `temp_dir` is missing when __init__ failed before creating it.
        temp_dir = getattr(self, "temp_dir", None)
        if not temp_dir:
            return
        try:
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir, ignore_errors=True)
        except Exception as e:
            print(f"Error cleaning up temporary directory: {e}")

    @staticmethod
    def _get_session_id(session_state):
        """Return the id stored in ``session_state``, creating one if absent.

        A random UUID is used rather than ``id(session_state)``: object ids
        can be reused after garbage collection, which could hand one visitor
        the cached agent of another.
        """
        import uuid

        if "session_id" not in session_state:
            session_state["session_id"] = f"session_{uuid.uuid4().hex}"
        return session_state["session_id"]

    def interact_with_agent(self, prompt, messages, session_state):
        """Run the session's agent on ``prompt`` and stream the chat history.

        Args:
            prompt: Text of the user's request.
            messages: Current chat history; new messages are appended to it.
            session_state: Per-session dict; its optional "additional_args"
                entry (e.g. uploaded file paths) is forwarded to the agent.

        Yields:
            The updated ``messages`` list after the user message and after
            every message produced by the agent. Errors are reported as an
            assistant message instead of being raised.
        """
        session_id = self._get_session_id(session_state)
        # register_session also refreshes the activity timestamp.
        self.session_manager.register_session(session_id)
        self.session_manager.update_activity(session_id)

        try:
            # Built inside the try block so that a failure to create the agent
            # is reported in the chat like any other error.
            agent = self.agent_factory.create_agent(session_id)

            has_memory = hasattr(agent, "memory")
            print(f"Agent has memory: {has_memory}")
            if has_memory:
                print(f"Memory type: {type(agent.memory)}")

            messages.append(gr.ChatMessage(role="user", content=prompt))
            yield messages

            for msg in stream_to_gradio(
                agent,
                task=prompt,
                reset_agent_memory=False,
                additional_args=session_state.get("additional_args"),
            ):
                messages.append(msg)
                self.session_manager.update_activity(session_id)
                yield messages  # Yield messages after each step
            yield messages  # Yield messages one last time
        except gr.Error as e:
            messages.append(
                gr.ChatMessage(role="assistant", content=f"Error: {str(e)}")
            )
            yield messages
        except Exception as e:
            # Log the details, but show the user a generic message.
            print(f"Error in interaction: {str(e)}")
            messages.append(
                gr.ChatMessage(
                    role="assistant",
                    content="I encountered an error processing your request. Please try again with a different query.",
                )
            )
            yield messages

    def upload_file(self, file, file_uploads_log, session_state):
        """Copy an uploaded file into the UI's temp directory and log it.

        Args:
            file: Value of the upload widget: ``None``, a path string, or an
                object exposing the path as ``.name``.
            file_uploads_log: List of ``(temp_path, original_name)`` tuples.
            session_state: Per-session dict used to look up the session id.

        Returns:
            A ``(status_textbox, file_uploads_log)`` tuple; the log gains one
            entry on success and is returned unchanged otherwise.
        """
        session_id = self._get_session_id(session_state)
        self.session_manager.update_activity(session_id)

        if file is None:
            return gr.Textbox("No file uploaded", visible=True), file_uploads_log

        try:
            source_path = file if isinstance(file, str) else file.name
            orig_filename = os.path.basename(source_path)
            # Keep the extension so that tools which detect the file type
            # from the file name still work on the copy.
            suffix = os.path.splitext(orig_filename)[1]

            with open(source_path, "rb") as src:
                with tempfile.NamedTemporaryFile(
                    delete=False, dir=self.temp_dir, suffix=suffix
                ) as tmp:
                    shutil.copyfileobj(src, tmp)
                    temp_file_path = tmp.name

            # The session manager deletes the copy when the session ends.
            self.session_manager.register_temp_file(session_id, temp_file_path)

            return (
                gr.Textbox(f"File uploaded: {orig_filename}", visible=True),
                file_uploads_log + [(temp_file_path, orig_filename)],
            )
        except Exception as e:
            print(f"Error handling file upload: {e}")
            return (
                gr.Textbox(f"Error uploading file: {str(e)}", visible=True),
                file_uploads_log,
            )

    def log_user_message(self, text_input, file_uploads_log, session_state):
        """Prepare the prompt for the agent and lock the input widgets.

        When files were uploaded, their original names are appended to the
        prompt and their temp paths are stored under
        ``session_state["additional_args"]["file_paths"]``.

        Returns:
            ``(message, textbox_update, button_update, session_state)``.
        """
        session_id = self._get_session_id(session_state)
        self.session_manager.update_activity(session_id)

        message = text_input
        if file_uploads_log:
            # Show only the original names to the model; the real paths
            # travel separately in additional_args.
            filenames = [entry[1] for entry in file_uploads_log]
            message += f"\nYou have been provided with these files: {filenames}"
            session_state.setdefault("additional_args", {})["file_paths"] = [
                entry[0] for entry in file_uploads_log
            ]

        return (
            message,
            gr.Textbox(
                value="",
                interactive=False,
                placeholder="Processing...",
            ),
            gr.Button(interactive=False),
            session_state,
        )

    def launch(self, **kwargs):
        """Build the page and start the Gradio server.

        Args:
            **kwargs: Forwarded to ``Blocks.launch``.
        """
        with gr.Blocks(theme="soft", css=self._get_responsive_css()) as demo:
            with gr.Row(equal_height=True):
                # Sidebar (adapts to screen size via CSS)
                with gr.Column(scale=1, min_width=100):
                    gr.Markdown("# OpenDeepResearch\n## Powered by Smolagents")
                    with gr.Group():
                        gr.Markdown("**What's on your mind?**", container=True)
                        text_input = gr.Textbox(
                            lines=3,
                            label="Your request",
                            container=False,
                            placeholder="Enter your prompt here and press Shift+Enter or press the button",
                        )
                        launch_research_btn = gr.Button("Run", variant="primary")
                    # "filepath" makes the handler receive the uploaded
                    # file's path; "file" is not an accepted value.
                    upload_file = gr.File(
                        label="Upload a file",
                        file_types=ALLOWED_FILE_TYPES,
                        type="filepath",
                    )
                    upload_status = gr.Textbox(
                        label="Upload Status", interactive=False, visible=False
                    )
                    with gr.Row(visible=True):
                        gr.HTML(
                            """
                            <div style="display: flex; align-items: center; gap: 8px; font-family: system-ui, -apple-system, sans-serif;">
                            <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png"
                            style="width: 32px; height: 32px; object-fit: contain;" alt="logo">
                            <a target="_blank" href="https://github.com/huggingface/smolagents">
                            <b>huggingface/smolagents</b>
                            </a>
                            </div>
                            """
                        )

                # Main content area
                with gr.Column(scale=4, min_width=400):
                    # Per-session values carried between event handlers.
                    session_state = gr.State({})
                    stored_messages = gr.State([])
                    file_uploads_log = gr.State([])
                    chatbot = gr.Chatbot(
                        label="OpenDeepResearch",
                        show_label=True,
                        type="messages",
                        avatar_images=(
                            None,
                            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
                        ),
                        height=600,
                        elem_id="research-chatbot",
                    )

            def _unlock_inputs():
                """Re-enable the prompt box and the Run button."""
                return (
                    gr.Textbox(
                        interactive=True,
                        placeholder="Enter your prompt here and press the button",
                    ),
                    gr.Button(interactive=True),
                )

            def _wire_submission(trigger):
                """Attach the log -> run agent -> unlock chain to an event."""
                trigger(
                    self.log_user_message,
                    [text_input, file_uploads_log, session_state],
                    [stored_messages, text_input, launch_research_btn, session_state],
                ).then(
                    self.interact_with_agent,
                    [stored_messages, chatbot, session_state],
                    [chatbot],
                ).then(
                    _unlock_inputs,
                    None,
                    [text_input, launch_research_btn],
                )

            # Pressing Enter in the prompt box and clicking Run do the same.
            _wire_submission(text_input.submit)
            _wire_submission(launch_research_btn.click)

            upload_file.change(
                self.upload_file,
                [upload_file, file_uploads_log, session_state],
                [upload_status, file_uploads_log],
            )

            # Client-side hook on page unload; it only logs for now.
            demo.load(
                lambda: None,
                None,
                None,
                js="""
                () => {
                window.addEventListener('beforeunload', function() {
                // Notify backend about session end (would require additional endpoint)
                console.log('Cleaning up session');
                });
                }
                """,
            )

        demo.queue(max_size=20).launch(debug=True, **kwargs)

    def _get_responsive_css(self):
        """Return the CSS that adapts the layout to narrow screens."""
        return """
        /* Responsive layout */
        @media (max-width: 768px) {
        #research-chatbot {
        height: 400px !important;
        }
        /* Stack columns on small screens */
        .gradio-row {
        flex-direction: column;
        }
        /* Adjust column widths */
        .gradio-column {
        min-width: 100% !important;
        width: 100% !important;
        }
        }
        /* Base styling */
        .gradio-container {
        max-width: 100% !important;
        }
        """
| # ------------------------ Execution ------------------------ | |
def main():
    """Prepare the environment and downloads folder, then start the UI."""
    setup_environment()

    # The text browser saves downloaded files here, so it must exist.
    downloads_dir = f"./{BROWSER_CONFIG['downloads_folder']}"
    os.makedirs(downloads_dir, exist_ok=True)

    ui = GradioUI()
    ui.launch(share=True)


if __name__ == "__main__":
    main()