# Leonardo — Update app.py (commit 56c8b27, verified; raw / history / blame, 26.1 kB)
# NOTE(review): the lines above were HuggingFace Hub page chrome scraped into the
# file; they are preserved here as a comment so the module parses.
import os
import shutil
import tempfile
import threading
import time
from typing import Any, Dict, List, Optional

import gradio as gr
from dotenv import load_dotenv
from huggingface_hub import login
from smolagents import (
    CodeAgent,
    GoogleSearchTool,
    HfApiModel,
    LiteLLMModel,
    OpenAIServerModel,
    Tool,
    TransformersModel,
)
from smolagents.agent_types import AgentAudio, AgentImage, AgentText
from smolagents.gradio_ui import handle_agent_output_types, pull_messages_from_step

from scripts.legal_document_tool import LegalDocumentTool
from scripts.text_inspector_tool import TextInspectorTool
from scripts.text_web_browser import (
    ArchiveSearchTool,
    FinderTool,
    FindNextTool,
    PageDownTool,
    PageUpTool,
    SimpleTextBrowser,
    VisitTool,
)
from scripts.visual_qa import visualizer
# ------------------------ Configuration and Setup ------------------------
# Constants and configurations

# Import names the CodeAgent's generated Python is allowed to use at runtime.
# NOTE(review): "clean-text" is the pip distribution name; the importable module
# is "cleantext" — confirm which form the agent sandbox actually checks.
AUTHORIZED_IMPORTS = [
    "requests",
    "zipfile",
    "pandas",
    "numpy",
    "sympy",
    "json",
    "bs4",
    "pubchempy",
    "xml",
    "yahoo_finance",
    "Bio",
    "sklearn",
    "scipy",
    "pydub",
    "PIL",
    "chess",
    "PyPDF2",
    "pptx",
    "torch",
    "datetime",
    "fractions",
    "csv",
    "clean-text",
    "langchain",
    "llama_index",
]

# Desktop Edge user-agent string so sites serve full (non-bot) pages to the browser tool.
user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"

# Keyword arguments for SimpleTextBrowser (consumed in AgentFactory.create_agent).
# NOTE(review): SERPAPI_API_KEY is read here at import time, BEFORE main() runs
# setup_environment()/load_dotenv() — a key that only exists in .env is missed.
BROWSER_CONFIG = {
    "viewport_size": 1024 * 5,  # characters shown per browser "page"
    "downloads_folder": "downloads_folder",
    "request_kwargs": {
        "headers": {"User-Agent": user_agent},
        "timeout": 300,  # seconds per HTTP request
    },
    "serpapi_key": os.getenv("SERPAPI_API_KEY"),
}

# Maps smolagents tool-call message roles onto roles LiteLLM backends accept.
custom_role_conversions = {"tool-call": "assistant", "tool-response": "user"}

# Multimedia file types supported (using Gradio-compatible format)
ALLOWED_FILE_TYPES = [
    ".pdf",  # application/pdf
    ".docx",  # application/vnd.openxmlformats-officedocument.wordprocessingml.document
    ".txt",  # text/plain
    ".png",  # image/png
    ".webp",  # image/webp
    ".jpeg",  # image/jpeg
    ".jpg",  # image/jpeg
    ".gif",  # image/gif
    ".mp4",  # video/mp4
    ".mp3",  # audio/mpeg
    ".wav",  # audio/wav
    ".ogg",  # audio/ogg
]
def setup_environment():
    """Load .env values and, when a token is present, log in to Hugging Face.

    Prints a short confirmation (last 10 characters of the token) or a
    warning when HF_TOKEN is absent.
    """
    load_dotenv(override=True)
    token = os.getenv("HF_TOKEN")
    if token:  # only authenticate when a token is actually configured
        login(token)
        print("HF_TOKEN (last 10 characters):", token[-10:])
    else:
        print("HF_TOKEN not found in environment variables.")
# ------------------------ Model and Tool Management ------------------------
class ModelManager:
    """Thread-safe singleton that builds and caches inference-model backends."""

    _instance = None
    _lock = threading.Lock()

    @classmethod
    def get_instance(cls):
        """Return the shared manager, creating it lazily (double-checked locking)."""
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    def __init__(self):
        """Start with an empty cache keyed by "<backend>:<model_id>"."""
        self.model_cache = {}

    def load_model(self, chosen_inference: str, model_id: str, key_manager=None):
        """Return a model for the requested backend, reusing cached instances.

        Raises ValueError for an unknown backend or a missing key manager;
        any construction failure is logged and re-raised.
        """
        cache_key = f"{chosen_inference}:{model_id}"
        cached = self.model_cache.get(cache_key)
        if cached is not None:
            return cached
        try:
            built = self._build_model(chosen_inference, model_id, key_manager)
            self.model_cache[cache_key] = built
            return built
        except Exception as exc:
            print(f"✗ Couldn't load model: {exc}")
            raise

    def _build_model(self, chosen_inference, model_id, key_manager):
        """Construct a fresh backend instance for *chosen_inference*."""
        if chosen_inference == "hf_api":
            return HfApiModel(model_id=model_id)
        if chosen_inference == "hf_api_provider":
            # NOTE(review): this branch ignores model_id and always uses the
            # "together" provider default — confirm that is intentional.
            return HfApiModel(provider="together")
        if chosen_inference == "litellm":
            return LiteLLMModel(model_id=model_id)
        if chosen_inference == "openai":
            if not key_manager:
                raise ValueError("Key manager required for OpenAI model")
            return OpenAIServerModel(
                model_id=model_id, api_key=key_manager.get_key("openai_api_key")
            )
        if chosen_inference == "transformers":
            # Hard-coded small model intended for local execution.
            return TransformersModel(
                model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct",
                device_map="auto",
                max_new_tokens=1000,
            )
        raise ValueError(f"Invalid inference type: {chosen_inference}")
class ToolRegistry:
    """Builds and validates the tool sets handed to the agent."""

    @staticmethod
    def validate_tools(tools: List[Tool]) -> List[Tool]:
        """Keep only genuine Tool instances, silently dropping None and junk."""
        valid = []
        for candidate in tools:
            if isinstance(candidate, Tool):
                valid.append(candidate)
        return valid

    @staticmethod
    def load_web_tools(model, browser, text_limit=20000):
        """Assemble the search/browse/inspect tool chain bound to *browser*."""
        search = GoogleSearchTool(provider="serper")
        navigation = [
            VisitTool(browser),
            PageUpTool(browser),
            PageDownTool(browser),
            FinderTool(browser),
            FindNextTool(browser),
            ArchiveSearchTool(browser),
        ]
        inspector = TextInspectorTool(model, text_limit)
        return [search] + navigation + [inspector]

    @staticmethod
    def load_image_generation_tools():
        """Wrap the FLUX.1-dev Space as a tool; returns None when unavailable."""
        try:
            return Tool.from_space(
                space_id="xkerser/FLUX.1-dev",
                name="image_generator",
                description="Generates high-quality AgentImage using the FLUX.1-dev model based on text prompts.",
            )
        except Exception as error:
            print(f"✗ Couldn't initialize image generation tool: {error}")
            return None

    @staticmethod
    def load_legal_document_tool():
        """Build the legal-document tool; optional, so failures yield None."""
        try:
            # Default construction — no configuration parameters required.
            return LegalDocumentTool()
        except Exception as error:
            print(f"✗ Couldn't initialize legal document tool: {error}")
            # Returning None (not raising) keeps this tool strictly optional.
            return None
# ------------------------ Session Management ------------------------
class SessionManager:
    """Manages agent sessions with proper cleanup and lifecycle management.

    Singleton. Tracks per-session state, temporary files and last-activity
    timestamps; a daemon thread reclaims sessions idle for 30+ minutes.
    """

    _instance = None
    _lock = threading.Lock()

    # Idle period (seconds) after which a session is reclaimed.
    SESSION_TIMEOUT = 30 * 60

    @classmethod
    def get_instance(cls):
        """Thread-safe singleton access (double-checked locking)."""
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    def __init__(self):
        """Initialize session bookkeeping and start the cleanup thread."""
        self.sessions = {}  # session_id -> arbitrary per-session data
        self.temp_files = {}  # session_id -> list of temp-file paths to delete
        self.last_activity = {}  # session_id -> unix timestamp of last use
        self._cleanup_thread = None
        self._running = False
        self._start_cleanup_thread()

    def _start_cleanup_thread(self):
        """Start the background reaper (daemon, so it never blocks interpreter exit)."""
        if self._cleanup_thread is None:
            self._running = True
            self._cleanup_thread = threading.Thread(
                target=self._cleanup_inactive_sessions, daemon=True
            )
            self._cleanup_thread.start()

    def _cleanup_inactive_sessions(self):
        """Periodically clean up sessions idle longer than SESSION_TIMEOUT."""
        while self._running:
            current_time = time.time()
            # Snapshot items() so concurrent registration from the request
            # thread cannot invalidate iteration mid-sweep.
            inactive_sessions = [
                session_id
                for session_id, last_time in list(self.last_activity.items())
                if (current_time - last_time) > self.SESSION_TIMEOUT
            ]
            for session_id in inactive_sessions:
                self.cleanup_session(session_id)
            # Sleep for a minute before the next sweep.
            time.sleep(60)

    def register_session(self, session_id):
        """Register a new session (idempotent) and stamp its activity.

        Bug fix: the original called time.time() here while ``time`` was only
        imported locally inside _cleanup_inactive_sessions, so every
        registration raised NameError; ``time`` is now a module-level import.
        """
        if session_id not in self.sessions:
            self.sessions[session_id] = {}
            self.temp_files[session_id] = []
        # Update activity timestamp
        self.last_activity[session_id] = time.time()

    def update_activity(self, session_id):
        """Update the last-activity timestamp for *session_id*."""
        self.last_activity[session_id] = time.time()

    def register_temp_file(self, session_id, file_path):
        """Associate *file_path* with *session_id* for deletion at cleanup."""
        if session_id not in self.temp_files:
            self.temp_files[session_id] = []
        self.temp_files[session_id].append(file_path)

    def cleanup_session(self, session_id):
        """Remove a session's temporary files and all bookkeeping entries."""
        if session_id in self.temp_files:
            for file_path in self.temp_files[session_id]:
                try:
                    if os.path.exists(file_path):
                        os.remove(file_path)
                except Exception as e:
                    print(f"Error removing temp file {file_path}: {e}")
            del self.temp_files[session_id]
        if session_id in self.sessions:
            del self.sessions[session_id]
        if session_id in self.last_activity:
            del self.last_activity[session_id]

    def __del__(self):
        """Stop the reaper thread and clean up any remaining sessions."""
        self._running = False
        if self._cleanup_thread and self._cleanup_thread.is_alive():
            self._cleanup_thread.join(timeout=1.0)
        for session_id in list(self.sessions.keys()):
            self.cleanup_session(session_id)
# ------------------------ Agent Creation and Execution ------------------------
class AgentFactory:
    """Thread-safe singleton that builds and caches per-session CodeAgents."""

    _instance = None
    _lock = threading.Lock()

    @classmethod
    def get_instance(cls):
        """Return the shared factory, creating it lazily (double-checked locking)."""
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    def __init__(self):
        """Start with an empty session_id -> agent cache."""
        self.agent_cache = {}

    def create_agent(self, session_id: str = "default") -> CodeAgent:
        """Return the cached agent for *session_id*, building one on first use."""
        existing = self.agent_cache.get(session_id)
        if existing is not None:
            return existing

        # Boss model: DeepSeek r1-1776 served via Perplexity on OpenRouter.
        model = LiteLLMModel(
            custom_role_conversions=custom_role_conversions,
            model_id="openrouter/perplexity/r1-1776",
        )

        # Build the browsing stack and gather every tool into one list.
        browser = SimpleTextBrowser(**BROWSER_CONFIG)
        assembled = [visualizer]
        assembled.extend(ToolRegistry.load_web_tools(model, browser, 30000))

        # Optional tools return None when they fail to initialize — skip those.
        for optional in (
            ToolRegistry.load_image_generation_tools(),
            ToolRegistry.load_legal_document_tool(),
        ):
            if optional:
                assembled.append(optional)

        agent = CodeAgent(
            model=model,
            tools=ToolRegistry.validate_tools(assembled),
            max_steps=10,
            verbosity_level=1,
            additional_authorized_imports=AUTHORIZED_IMPORTS,
            planning_interval=4,
        )
        self.agent_cache[session_id] = agent
        return agent

    def clear_agent(self, session_id: str):
        """Forget the cached agent for *session_id*, if any (no-op otherwise)."""
        self.agent_cache.pop(session_id, None)
def stream_to_gradio(
    agent,
    task: str,
    reset_agent_memory: bool = False,
    additional_args: Optional[Dict[str, Any]] = None,
):
    """Run *agent* on *task*, yielding gr.ChatMessage objects as steps complete.

    After the stream ends, the last step log is treated as the run's final
    answer and emitted in a media-appropriate chat message.
    """
    for step_log in agent.run(
        task, stream=True, reset=reset_agent_memory, additional_args=additional_args
    ):
        yield from pull_messages_from_step(step_log)

    # The last streamed log is the run's final answer; normalize its type.
    final_answer = handle_agent_output_types(step_log)

    if isinstance(final_answer, AgentText):
        yield gr.ChatMessage(
            role="assistant",
            content=f"**Final answer:**\n{final_answer.to_string()}\n",
        )
    elif isinstance(final_answer, AgentImage):
        # Gradio-compatible file object for image output.
        yield gr.ChatMessage(
            role="assistant",
            content={"image": final_answer.to_string(), "type": "file"},
        )
    elif isinstance(final_answer, AgentAudio):
        # Gradio-compatible file object for audio output.
        yield gr.ChatMessage(
            role="assistant",
            content={"audio": final_answer.to_string(), "type": "file"},
        )
    else:
        yield gr.ChatMessage(
            role="assistant", content=f"**Final answer:** {str(final_answer)}"
        )
# ------------------------ Gradio UI Components ------------------------
class GradioUI:
    """A Gradio-compliant interface to launch your agent with proper resource management.

    Wires a chat UI to per-session CodeAgent instances; uploaded files are
    copied into a private temp dir and tracked via SessionManager so they can
    be reclaimed when the session expires.
    """

    def __init__(self):
        """Grab the shared managers and create a private temp dir for uploads."""
        self.session_manager = SessionManager.get_instance()
        self.agent_factory = AgentFactory.get_instance()
        self.temp_dir = tempfile.mkdtemp(prefix="gradio_")

    def __del__(self):
        """Best-effort removal of the upload temp directory."""
        try:
            if os.path.exists(self.temp_dir):
                shutil.rmtree(self.temp_dir, ignore_errors=True)
        except Exception as e:
            print(f"Error cleaning up temporary directory: {e}")

    @staticmethod
    def _get_session_id(session_state):
        """Return the id stored in *session_state*, minting one if absent."""
        if "session_id" not in session_state:
            session_state["session_id"] = f"session_{id(session_state)}"
        return session_state["session_id"]

    def interact_with_agent(self, prompt, messages, session_state):
        """Stream the agent's handling of *prompt* into the chat history.

        Yields the growing *messages* list after every agent step so Gradio
        updates the chatbot incrementally.
        """
        session_id = self._get_session_id(session_state)
        self.session_manager.register_session(session_id)
        self.session_manager.update_activity(session_id)

        # Get or create the session-specific agent.
        agent = self.agent_factory.create_agent(session_id)
        try:
            # Log whether the agent carries memory across turns (debug aid).
            has_memory = hasattr(agent, "memory")
            print(f"Agent has memory: {has_memory}")
            if has_memory:
                print(f"Memory type: {type(agent.memory)}")
            messages.append(gr.ChatMessage(role="user", content=prompt))
            yield messages
            # Bug fix: forward uploaded-file paths collected by
            # log_user_message — previously they were stored in session_state
            # but never passed to the agent run.
            for msg in stream_to_gradio(
                agent,
                task=prompt,
                reset_agent_memory=False,
                additional_args=session_state.get("additional_args"),
            ):
                messages.append(msg)
                self.session_manager.update_activity(session_id)
                yield messages  # Yield messages after each step
            yield messages  # Yield messages one last time
        except gr.Error as e:
            # Gradio-specific errors are surfaced verbatim.
            messages.append(
                gr.ChatMessage(role="assistant", content=f"Error: {str(e)}")
            )
            yield messages
        except Exception as e:
            # Log the error but present a user-friendly message.
            print(f"Error in interaction: {str(e)}")
            messages.append(
                gr.ChatMessage(
                    role="assistant",
                    content="I encountered an error processing your request. Please try again with a different query.",
                )
            )
            yield messages

    def upload_file(self, file, file_uploads_log, session_state):
        """Copy an uploaded file into the managed temp dir and log it.

        Bug fix: the original decorated this with ``gr.validate_input``, which
        is not part of the Gradio API and raised at class-definition time;
        file-type filtering is already enforced by gr.File(file_types=...).
        Returns (status Textbox, updated upload log of (temp_path, name)).
        """
        session_id = self._get_session_id(session_state)
        self.session_manager.update_activity(session_id)
        if file is None:
            return gr.Textbox("No file uploaded", visible=True), file_uploads_log
        try:
            temp_file_path = ""
            with tempfile.NamedTemporaryFile(delete=False, dir=self.temp_dir) as tmp:
                # Bug fix: close the source handle promptly — the original
                # leaked the file object returned by open().
                with open(file.name, "rb") as src:
                    shutil.copyfileobj(src, tmp)
                temp_file_path = tmp.name
            # Register the temp copy so SessionManager deletes it on expiry.
            self.session_manager.register_temp_file(session_id, temp_file_path)
            orig_filename = os.path.basename(file.name)
            return (
                gr.Textbox(f"File uploaded: {orig_filename}", visible=True),
                file_uploads_log + [(temp_file_path, orig_filename)],
            )
        except Exception as e:
            print(f"Error handling file upload: {e}")
            return (
                gr.Textbox(f"Error uploading file: {str(e)}", visible=True),
                file_uploads_log,
            )

    def log_user_message(self, text_input, file_uploads_log, session_state):
        """Compose the agent prompt, appending uploaded-file references.

        Also stashes the temp-file paths in session_state["additional_args"]
        so interact_with_agent can forward them to the agent run.
        """
        session_id = self._get_session_id(session_state)
        self.session_manager.update_activity(session_id)
        message = text_input
        if file_uploads_log and len(file_uploads_log) > 0:
            # Show only original filenames in the visible message.
            filenames = [f[1] for f in file_uploads_log]
            message += f"\nYou have been provided with these files: {filenames}"
            # The actual temp-file paths travel via additional_args.
            if "additional_args" not in session_state:
                session_state["additional_args"] = {}
            session_state["additional_args"]["file_paths"] = [
                f[0] for f in file_uploads_log
            ]
        return (
            message,
            gr.Textbox(
                value="",
                interactive=False,
                placeholder="Processing...",
            ),
            gr.Button(interactive=False),
            session_state,
        )

    def launch(self, **kwargs):
        """Build the Blocks layout, wire events, and start the queued server."""
        with gr.Blocks(theme="soft", css=self._get_responsive_css()) as demo:
            with gr.Row(equal_height=True) as main_row:
                # Sidebar (adapts to screen size via CSS).
                with gr.Column(scale=1, min_width=100) as sidebar:
                    gr.Markdown(
                        """# OpenDeepResearch
                        ## Powered by Smolagents"""
                    )
                    with gr.Group():
                        gr.Markdown("**What's on your mind?**", container=True)
                        text_input = gr.Textbox(
                            lines=3,
                            label="Your request",
                            container=False,
                            placeholder="Enter your prompt here and press Shift+Enter or press the button",
                        )
                        launch_research_btn = gr.Button("Run", variant="primary")
                    # NOTE(review): type="file" was removed in Gradio 4 (use
                    # "filepath"/"binary") — confirm the pinned Gradio version.
                    upload_file = gr.File(
                        label="Upload a file",
                        file_types=ALLOWED_FILE_TYPES,
                        type="file",
                    )
                    upload_status = gr.Textbox(
                        label="Upload Status", interactive=False, visible=False
                    )
                    # Footer with proper responsive behavior.
                    with gr.Row(visible=True) as footer:
                        gr.HTML(
                            """
                            <div style="display: flex; align-items: center; gap: 8px; font-family: system-ui, -apple-system, sans-serif;">
                            <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png"
                            style="width: 32px; height: 32px; object-fit: contain;" alt="logo">
                            <a target="_blank" href="https://github.com/huggingface/smolagents">
                            <b>huggingface/smolagents</b>
                            </a>
                            </div>
                            """
                        )
                # Main content area.
                with gr.Column(scale=4, min_width=400) as content:
                    # Session-scoped state containers.
                    session_state = gr.State({})
                    stored_messages = gr.State([])
                    file_uploads_log = gr.State([])
                    chatbot = gr.Chatbot(
                        label="OpenDeepResearch",
                        show_label=True,
                        type="messages",
                        avatar_images=(
                            None,
                            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
                        ),
                        height=600,
                        elem_id="research-chatbot",
                    )
            # Submit via Enter: log message -> run agent -> re-enable inputs.
            text_input.submit(
                self.log_user_message,
                [text_input, file_uploads_log, session_state],
                [stored_messages, text_input, launch_research_btn, session_state],
            ).then(
                self.interact_with_agent,
                [stored_messages, chatbot, session_state],
                [chatbot],
            ).then(
                lambda: (
                    gr.Textbox(
                        interactive=True,
                        placeholder="Enter your prompt here and press the button",
                    ),
                    gr.Button(interactive=True),
                ),
                None,
                [text_input, launch_research_btn],
            )
            # Same pipeline for the explicit Run button.
            launch_research_btn.click(
                self.log_user_message,
                [text_input, file_uploads_log, session_state],
                [stored_messages, text_input, launch_research_btn, session_state],
            ).then(
                self.interact_with_agent,
                [stored_messages, chatbot, session_state],
                [chatbot],
            ).then(
                lambda: (
                    gr.Textbox(
                        interactive=True,
                        placeholder="Enter your prompt here and press the button",
                    ),
                    gr.Button(interactive=True),
                ),
                None,
                [text_input, launch_research_btn],
            )
            upload_file.change(
                self.upload_file,
                [upload_file, file_uploads_log, session_state],
                [upload_status, file_uploads_log],
            )
            # NOTE(review): _js was renamed to js in Gradio 4 — confirm the
            # pinned Gradio version before relying on this hook.
            demo.load(
                lambda: None,
                None,
                None,
                _js="""
                () => {
                    window.addEventListener('beforeunload', function() {
                        // Notify backend about session end (would require additional endpoint)
                        console.log('Cleaning up session');
                    });
                }
                """,
            )
        demo.queue(max_size=20).launch(debug=True, **kwargs)

    def _get_responsive_css(self):
        """Return CSS that stacks the layout and shrinks the chat on small screens."""
        return """
        /* Responsive layout */
        @media (max-width: 768px) {
            #research-chatbot {
                height: 400px !important;
            }
            /* Stack columns on small screens */
            .gradio-row {
                flex-direction: column;
            }
            /* Adjust column widths */
            .gradio-column {
                min-width: 100% !important;
                width: 100% !important;
            }
        }
        /* Base styling */
        .gradio-container {
            max-width: 100% !important;
        }
        """
# ------------------------ Execution ------------------------
def main():
    """Application entry point: configure env, prepare folders, start the UI."""
    setup_environment()
    # The browser tool writes downloads here; make sure the folder exists.
    downloads_dir = os.path.join(".", BROWSER_CONFIG["downloads_folder"])
    os.makedirs(downloads_dir, exist_ok=True)
    GradioUI().launch(share=True)


if __name__ == "__main__":
    main()