import os
import shutil
import threading
import tempfile
from typing import Optional, List, Dict, Any

from dotenv import load_dotenv
from huggingface_hub import login
import gradio as gr

from scripts.text_inspector_tool import TextInspectorTool
from scripts.text_web_browser import (
    ArchiveSearchTool,
    FinderTool,
    FindNextTool,
    PageDownTool,
    PageUpTool,
    SimpleTextBrowser,
    VisitTool,
)
from scripts.visual_qa import visualizer
from scripts.legal_document_tool import LegalDocumentTool
from smolagents import (
    CodeAgent,
    HfApiModel,
    LiteLLMModel,
    OpenAIServerModel,
    TransformersModel,
    GoogleSearchTool,
    Tool,
)
from smolagents.agent_types import AgentText, AgentImage, AgentAudio
from smolagents.gradio_ui import pull_messages_from_step, handle_agent_output_types

# ------------------------ Configuration and Setup ------------------------

# Python modules the CodeAgent may import from generated code.
AUTHORIZED_IMPORTS = [
    "requests",
    "zipfile",
    "pandas",
    "numpy",
    "sympy",
    "json",
    "bs4",
    "pubchempy",
    "xml",
    "yahoo_finance",
    "Bio",
    "sklearn",
    "scipy",
    "pydub",
    "PIL",
    "chess",
    "PyPDF2",
    "pptx",
    "torch",
    "datetime",
    "fractions",
    "csv",
    # Fix: was "clean-text" — a hyphenated name can never be imported;
    # the clean-text package's import name is "cleantext".
    "cleantext",
    "langchain",
    "llama_index",
]

# Browser identity sent with every outbound HTTP request.
USER_AGENT = (
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
    "(KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
)

BROWSER_CONFIG = {
    "viewport_size": 1024 * 5,
    "downloads_folder": "downloads_folder",
    "request_kwargs": {
        "headers": {"User-Agent": USER_AGENT},
        "timeout": 300,
    },
    # NOTE(review): evaluated at import time, i.e. BEFORE setup_environment()
    # calls load_dotenv() — a key defined only in .env will be missed here.
    "serpapi_key": os.getenv("SERPAPI_API_KEY"),
}

# Map smolagents message roles onto roles OpenAI-style chat APIs accept.
custom_role_conversions = {"tool-call": "assistant", "tool-response": "user"}

# Multimedia file types supported (using Gradio-compatible format)
ALLOWED_FILE_TYPES = [
    ".pdf",  # application/pdf
    ".docx",  # application/vnd.openxmlformats-officedocument.wordprocessingml.document
    ".txt",  # text/plain
    ".png",  # image/png
    ".webp",  # image/webp
    ".jpeg",  # image/jpeg
    ".jpg",  # image/jpeg
    ".gif",  # image/gif
    ".mp4",  # video/mp4
    ".mp3",  # audio/mpeg
    ".wav",  # audio/wav
    ".ogg",  # audio/ogg
]
def setup_environment():
    """Load .env variables and log in to the Hugging Face Hub if a token is set."""
    load_dotenv(override=True)
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:  # Check if token is actually set
        login(hf_token)
        # Only the token suffix is printed so the secret never appears in full.
        print("HF_TOKEN (last 10 characters):", hf_token[-10:])
    else:
        print("HF_TOKEN not found in environment variables.")


# ------------------------ Model and Tool Management ------------------------


class ModelManager:
    """Manages model loading and initialization with Zhou Protocol patterns."""

    _instance = None
    _lock = threading.Lock()

    @classmethod
    def get_instance(cls):
        """Thread-safe singleton access to model manager."""
        if cls._instance is None:
            with cls._lock:
                # Double-checked locking: re-test after acquiring the lock.
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    def __init__(self):
        """Initialize with model cache."""
        # cache_key -> model instance; guarded by _cache_lock (the class
        # advertises thread-safe access, so the cache must be locked too).
        self.model_cache = {}
        self._cache_lock = threading.Lock()

    def load_model(self, chosen_inference: str, model_id: str, key_manager=None):
        """Load the specified model with appropriate configuration and caching.

        Args:
            chosen_inference: One of "hf_api", "hf_api_provider", "litellm",
                "openai", or "transformers".
            model_id: Backend model identifier. NOTE(review): the
                "hf_api_provider" and "transformers" branches ignore it —
                looks intentional, but worth confirming, since the cache key
                still includes it.
            key_manager: Object exposing ``get_key(name)``; required for "openai".

        Returns:
            The loaded (and cached) model instance.

        Raises:
            ValueError: Unknown inference type, or missing key manager for "openai".
        """
        cache_key = f"{chosen_inference}:{model_id}"
        # Return cached model if available
        with self._cache_lock:
            if cache_key in self.model_cache:
                return self.model_cache[cache_key]
        try:
            if chosen_inference == "hf_api":
                model = HfApiModel(model_id=model_id)
            elif chosen_inference == "hf_api_provider":
                model = HfApiModel(provider="together")
            elif chosen_inference == "litellm":
                model = LiteLLMModel(model_id=model_id)
            elif chosen_inference == "openai":
                if not key_manager:
                    raise ValueError("Key manager required for OpenAI model")
                model = OpenAIServerModel(
                    model_id=model_id,
                    api_key=key_manager.get_key("openai_api_key"),
                )
            elif chosen_inference == "transformers":
                model = TransformersModel(
                    model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct",
                    device_map="auto",
                    max_new_tokens=1000,
                )
            else:
                raise ValueError(f"Invalid inference type: {chosen_inference}")
            # Cache the model for future use
            with self._cache_lock:
                self.model_cache[cache_key] = model
            return model
        except Exception as e:
            print(f"✗ Couldn't load model: {e}")
            raise
class ToolRegistry:
    """Manages tool initialization and organization with validation."""

    @staticmethod
    def validate_tools(tools: "List[Tool]") -> "List[Tool]":
        """Return only proper Tool instances (drops None and any non-Tool value)."""
        return [tool for tool in tools if isinstance(tool, Tool)]

    @staticmethod
    def load_web_tools(model, browser, text_limit=20000):
        """Initialize and return web-related tools.

        Args:
            model: Model used by the text-inspector tool.
            browser: Shared SimpleTextBrowser driving all navigation tools.
            text_limit: Max characters the text inspector reads per document.
        """
        return [
            GoogleSearchTool(provider="serper"),
            VisitTool(browser),
            PageUpTool(browser),
            PageDownTool(browser),
            FinderTool(browser),
            FindNextTool(browser),
            ArchiveSearchTool(browser),
            TextInspectorTool(model, text_limit),
        ]

    @staticmethod
    def load_image_generation_tools():
        """Initialize and return the image generation tool (None on failure)."""
        try:
            return Tool.from_space(
                space_id="xkerser/FLUX.1-dev",
                name="image_generator",
                description="Generates high-quality AgentImage using the FLUX.1-dev model based on text prompts.",
            )
        except Exception as e:
            print(f"✗ Couldn't initialize image generation tool: {e}")
            return None

    @staticmethod
    def load_legal_document_tool():
        """Initialize and return the legal document processing tool."""
        try:
            # Create a simple instance with default parameters
            return LegalDocumentTool()
        except Exception as e:
            print(f"✗ Couldn't initialize legal document tool: {e}")
            # Return None instead of raising to make this tool optional
            return None


# ------------------------ Session Management ------------------------


class SessionManager:
    """Manages agent sessions with proper cleanup and lifecycle management."""

    _instance = None
    _lock = threading.Lock()

    @classmethod
    def get_instance(cls):
        """Thread-safe singleton access."""
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    def __init__(self):
        """Initialize session management structures."""
        self.sessions = {}  # session_id -> per-session data dict
        self.temp_files = {}  # session_id -> list of temp file paths to delete
        self.last_activity = {}  # session_id -> last-activity unix timestamp
        self._cleanup_thread = None
        self._running = False
        self._start_cleanup_thread()

    def _start_cleanup_thread(self):
        """Start a background daemon thread for session cleanup."""
        if self._cleanup_thread is None:
            self._running = True
            self._cleanup_thread = threading.Thread(
                target=self._cleanup_inactive_sessions, daemon=True
            )
            self._cleanup_thread.start()

    def _cleanup_inactive_sessions(self):
        """Periodically clean up inactive sessions."""
        import time

        # Session timeout in seconds (30 minutes)
        SESSION_TIMEOUT = 30 * 60
        while self._running:
            current_time = time.time()
            # Snapshot items() so concurrent registration can't break iteration.
            inactive_sessions = [
                session_id
                for session_id, last_time in list(self.last_activity.items())
                if (current_time - last_time) > SESSION_TIMEOUT
            ]
            # Clean up each inactive session
            for session_id in inactive_sessions:
                self.cleanup_session(session_id)
            # Sleep for a minute before next check
            time.sleep(60)

    def register_session(self, session_id):
        """Register a new session (idempotent) and stamp its activity time."""
        # Fix: `time` was only imported inside _cleanup_inactive_sessions and
        # the module has no top-level import, so this method raised NameError.
        import time

        if session_id not in self.sessions:
            self.sessions[session_id] = {}
            self.temp_files[session_id] = []
        # Update activity timestamp
        self.last_activity[session_id] = time.time()

    def update_activity(self, session_id):
        """Update the last activity timestamp for a session."""
        import time  # see register_session: no module-level `time` import

        self.last_activity[session_id] = time.time()

    def register_temp_file(self, session_id, file_path):
        """Register a temporary file with a session for later cleanup."""
        if session_id not in self.temp_files:
            self.temp_files[session_id] = []
        self.temp_files[session_id].append(file_path)

    def cleanup_session(self, session_id):
        """Delete a session's temp files and drop all of its bookkeeping entries."""
        # Remove temporary files
        if session_id in self.temp_files:
            for file_path in self.temp_files[session_id]:
                try:
                    if os.path.exists(file_path):
                        os.remove(file_path)
                except Exception as e:
                    print(f"Error removing temp file {file_path}: {e}")
            del self.temp_files[session_id]
        # Clean up session data
        if session_id in self.sessions:
            del self.sessions[session_id]
        # Clean up activity record
        if session_id in self.last_activity:
            del self.last_activity[session_id]

    def __del__(self):
        """Clean up all sessions when the manager is destroyed."""
        self._running = False
        if self._cleanup_thread and self._cleanup_thread.is_alive():
            self._cleanup_thread.join(timeout=1.0)
        # Clean up all remaining sessions
        for session_id in list(self.sessions.keys()):
            self.cleanup_session(session_id)
# ------------------------ Agent Creation and Execution ------------------------


class AgentFactory:
    """Factory for creating and managing agent instances with Zhou Protocol patterns."""

    _instance = None
    _lock = threading.Lock()

    @classmethod
    def get_instance(cls):
        """Thread-safe singleton access."""
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    def __init__(self):
        """Initialize with agent cache."""
        # session_id -> CodeAgent; one cached agent per UI session so a
        # session keeps its agent (memory, browser state) across turns.
        self.agent_cache = {}

    def create_agent(self, session_id: str = "default") -> "CodeAgent":
        """Creates a fresh agent instance with properly configured tools.

        Returns the cached agent for ``session_id`` when one exists;
        otherwise builds the model, browser, and tool set, creates the
        agent, caches it, and returns it.
        """
        # Return cached agent if available for this session
        if session_id in self.agent_cache:
            return self.agent_cache[session_id]

        # Initialize model — DeepSeek R1-1776 served via OpenRouter/Perplexity.
        model = LiteLLMModel(
            custom_role_conversions=custom_role_conversions,
            model_id="openrouter/perplexity/r1-1776",
        )

        # Initialize tools; each agent gets its own browser so navigation
        # state is not shared between sessions.
        text_limit = 30000
        browser = SimpleTextBrowser(**BROWSER_CONFIG)

        web_tools = ToolRegistry.load_web_tools(model, browser, text_limit)
        image_generator = ToolRegistry.load_image_generation_tools()
        legal_tool = ToolRegistry.load_legal_document_tool()

        # Combine and validate all tools
        all_tools = [visualizer] + web_tools
        # Optional tools come back as None when their initialization failed.
        if image_generator:
            all_tools.append(image_generator)
        if legal_tool:
            all_tools.append(legal_tool)
        all_tools = ToolRegistry.validate_tools(all_tools)

        agent = CodeAgent(
            model=model,
            tools=all_tools,  # Pass a single list containing all tools
            max_steps=10,
            verbosity_level=1,
            additional_authorized_imports=AUTHORIZED_IMPORTS,
            planning_interval=4,
        )
        # Cache the agent for future use
        self.agent_cache[session_id] = agent
        return agent

    def clear_agent(self, session_id: str):
        """Remove an agent from the cache (no-op for unknown sessions)."""
        if session_id in self.agent_cache:
            del self.agent_cache[session_id]


def stream_to_gradio(
    agent,
    task: str,
    reset_agent_memory: bool = False,
    additional_args: Optional[Dict[str, Any]] = None,
):
    """Runs an agent with the given task and streams messages as Gradio ChatMessages.

    Yields intermediate step messages as they arrive, then one final message
    rendering the run's final answer (text, image, audio, or plain repr).
    """
    step_log = None
    for step_log in agent.run(
        task, stream=True, reset=reset_agent_memory, additional_args=additional_args
    ):
        for message in pull_messages_from_step(step_log):
            yield message

    # Fix: `step_log` was used after the loop without initialization, so a
    # run that yielded no steps raised NameError; now it yields nothing.
    if step_log is None:
        return

    # The last log is the run's final_answer.
    final_answer = handle_agent_output_types(step_log)
    if isinstance(final_answer, AgentText):
        yield gr.ChatMessage(
            role="assistant",
            content=f"**Final answer:**\n{final_answer.to_string()}\n",
        )
    elif isinstance(final_answer, AgentImage):
        # Send as Gradio-compatible file object.
        yield gr.ChatMessage(
            role="assistant",
            content={"image": final_answer.to_string(), "type": "file"},
        )
    elif isinstance(final_answer, AgentAudio):
        # Send as Gradio-compatible file object.
        yield gr.ChatMessage(
            role="assistant",
            content={"audio": final_answer.to_string(), "type": "file"},
        )
    else:
        yield gr.ChatMessage(
            role="assistant", content=f"**Final answer:** {str(final_answer)}"
        )


# ------------------------ Gradio UI Components ------------------------


class GradioUI:
    """A Gradio-compliant interface to launch your agent with proper resource management."""

    def __init__(self):
        """Initialize the Gradio UI with proper session management."""
        self.session_manager = SessionManager.get_instance()
        self.agent_factory = AgentFactory.get_instance()
        # Per-process scratch directory that receives uploaded files.
        self.temp_dir = tempfile.mkdtemp(prefix="gradio_")

    def __del__(self):
        """Clean up resources when the UI is destroyed."""
        try:
            # Clean up the temporary directory
            if os.path.exists(self.temp_dir):
                shutil.rmtree(self.temp_dir, ignore_errors=True)
        except Exception as e:
            print(f"Error cleaning up temporary directory: {e}")
temporary directory: {e}") @staticmethod def _get_session_id(session_state): """Generate or retrieve a session ID.""" if "session_id" not in session_state: session_state["session_id"] = f"session_{id(session_state)}" return session_state["session_id"] def interact_with_agent(self, prompt, messages, session_state): """Main interaction handler with the agent.""" # Get or create session ID session_id = self._get_session_id(session_state) # Register/update the session self.session_manager.register_session(session_id) self.session_manager.update_activity(session_id) # Get or create session-specific agent agent = self.agent_factory.create_agent(session_id) try: # Log the existence of agent memory has_memory = hasattr(agent, "memory") print(f"Agent has memory: {has_memory}") if has_memory: print(f"Memory type: {type(agent.memory)}") messages.append(gr.ChatMessage(role="user", content=prompt)) yield messages for msg in stream_to_gradio(agent, task=prompt, reset_agent_memory=False): messages.append(msg) self.session_manager.update_activity(session_id) yield messages # Yield messages after each step yield messages # Yield messages one last time except gr.Error as e: # Handle Gradio-specific errors messages.append( gr.ChatMessage(role="assistant", content=f"Error: {str(e)}") ) yield messages except Exception as e: # Log the error but present a user-friendly message print(f"Error in interaction: {str(e)}") messages.append( gr.ChatMessage( role="assistant", content="I encountered an error processing your request. 
Please try again with a different query.", ) ) yield messages @gr.validate_input(file="file") def upload_file(self, file, file_uploads_log, session_state): """Handle file uploads with Gradio-compliant temporary file handling.""" session_id = self._get_session_id(session_state) self.session_manager.update_activity(session_id) if file is None: return gr.Textbox("No file uploaded", visible=True), file_uploads_log try: # Create a temporary file with a secure random name temp_file_path = "" with tempfile.NamedTemporaryFile(delete=False, dir=self.temp_dir) as tmp: # Copy the uploaded file to the temporary file shutil.copyfileobj(open(file.name, "rb"), tmp) temp_file_path = tmp.name # Register the temporary file with the session manager self.session_manager.register_temp_file(session_id, temp_file_path) # Store the original filename for reference orig_filename = os.path.basename(file.name) return ( gr.Textbox(f"File uploaded: {orig_filename}", visible=True), file_uploads_log + [(temp_file_path, orig_filename)], ) except Exception as e: print(f"Error handling file upload: {e}") return ( gr.Textbox(f"Error uploading file: {str(e)}", visible=True), file_uploads_log, ) def log_user_message(self, text_input, file_uploads_log, session_state): """Process user message and handle file references.""" session_id = self._get_session_id(session_state) self.session_manager.update_activity(session_id) message = text_input if file_uploads_log and len(file_uploads_log) > 0: # Include only the original filenames in the message filenames = [f[1] for f in file_uploads_log] message += f"\nYou have been provided with these files: {filenames}" # Include the actual file paths in the additional_args if "additional_args" not in session_state: session_state["additional_args"] = {} session_state["additional_args"]["file_paths"] = [ f[0] for f in file_uploads_log ] return ( message, gr.Textbox( value="", interactive=False, placeholder="Processing...", ), gr.Button(interactive=False), session_state, ) 
def launch(self, **kwargs): """Launch the Gradio UI with responsive layout.""" with gr.Blocks(theme="soft", css=self._get_responsive_css()) as demo: with gr.Row(equal_height=True) as main_row: # Sidebar (adapts to screen size via CSS) with gr.Column(scale=1, min_width=100) as sidebar: gr.Markdown( """# OpenDeepResearch ## Powered by Smolagents""" ) with gr.Group(): gr.Markdown("**What's on your mind?**", container=True) text_input = gr.Textbox( lines=3, label="Your request", container=False, placeholder="Enter your prompt here and press Shift+Enter or press the button", ) launch_research_btn = gr.Button("Run", variant="primary") # Clean file upload with Gradio-compliant file_types upload_file = gr.File( label="Upload a file", file_types=ALLOWED_FILE_TYPES, type="file", ) upload_status = gr.Textbox( label="Upload Status", interactive=False, visible=False ) # Footer with proper responsive behavior with gr.Row(visible=True) as footer: gr.HTML( """
""" ) # Main content area with gr.Column(scale=4, min_width=400) as content: # Add session state to store session-specific data session_state = gr.State({}) stored_messages = gr.State([]) file_uploads_log = gr.State([]) chatbot = gr.Chatbot( label="OpenDeepResearch", show_label=True, type="messages", avatar_images=( None, "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png", ), height=600, elem_id="research-chatbot", ) # Connect event handlers text_input.submit( self.log_user_message, [text_input, file_uploads_log, session_state], [stored_messages, text_input, launch_research_btn, session_state], ).then( self.interact_with_agent, [stored_messages, chatbot, session_state], [chatbot], ).then( lambda: ( gr.Textbox( interactive=True, placeholder="Enter your prompt here and press the button", ), gr.Button(interactive=True), ), None, [text_input, launch_research_btn], ) launch_research_btn.click( self.log_user_message, [text_input, file_uploads_log, session_state], [stored_messages, text_input, launch_research_btn, session_state], ).then( self.interact_with_agent, [stored_messages, chatbot, session_state], [chatbot], ).then( lambda: ( gr.Textbox( interactive=True, placeholder="Enter your prompt here and press the button", ), gr.Button(interactive=True), ), None, [text_input, launch_research_btn], ) upload_file.change( self.upload_file, [upload_file, file_uploads_log, session_state], [upload_status, file_uploads_log], ) # Clean up session on page unload demo.load( lambda: None, None, None, _js=""" () => { window.addEventListener('beforeunload', function() { // Notify backend about session end (would require additional endpoint) console.log('Cleaning up session'); }); } """, ) demo.queue(max_size=20).launch(debug=True, **kwargs) def _get_responsive_css(self): """Get CSS for responsive layout.""" return """ /* Responsive layout */ @media (max-width: 768px) { #research-chatbot { height: 400px !important; } /* Stack 
columns on small screens */ .gradio-row { flex-direction: column; } /* Adjust column widths */ .gradio-column { min-width: 100% !important; width: 100% !important; } } /* Base styling */ .gradio-container { max-width: 100% !important; } """ # ------------------------ Execution ------------------------ def main(): """Main entry point for the application.""" # Initialize environment setup_environment() # Ensure downloads folder exists for browser os.makedirs(f"./{BROWSER_CONFIG['downloads_folder']}", exist_ok=True) # Launch UI GradioUI().launch(share=True) if __name__ == "__main__": main()