Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # coding=utf-8 | |
| # Copyright 2024 The Footscray Coding Collective. All rights reserved. | |
| """ | |
| Financial Research Agent: Advanced Market Analysis and Data Access | |
| This script implements a comprehensive financial research agent capable of performing market analysis, | |
| retrieving financial data, and providing interactive research capabilities through either a GUI or | |
| command-line interface. | |
| The agent leverages the Smolagents framework to create an autonomous system that can: | |
| 1. Access and analyze real-time market data through Alpha Vantage API integration | |
| 2. Process financial documents and extract relevant information | |
| 3. Perform web searches and analyze webpage content | |
| 4. Create visualizations of financial data | |
| 5. Generate comprehensive financial analysis reports | |
| 6. Handle user uploads of various document types | |
| Key Components: | |
| ------------- | |
| - ModelManager: Handles loading and configuration of various LLM models | |
| - ToolRegistry: Manages initialization and organization of tools available to the agent | |
| - GradioUI: Provides a user-friendly interface with responsive design for desktop/mobile | |
| - A robust set of financial tools for retrieving stock data, financial statements, and market sentiment | |
| - Web browsing capabilities with text extraction and analysis | |
| - Document processing for PDFs, spreadsheets, and other common file formats | |
| - Visualization tools for creating charts and graphs from financial data | |
| Usage: | |
| ----- | |
| Run in UI mode (default): | |
| python app.py | |
| Run in headless mode with a specific query: | |
| python app.py --mode headless --query "Analyze Tesla's financial performance for 2023" | |
| Configuration: | |
| ------------ | |
| The script uses environment variables for API keys and other configuration settings. | |
| Required environment variables: | |
| - ALPHA_VANTAGE_API_KEY: For accessing financial data APIs | |
| - HF_TOKEN: For accessing Hugging Face models (optional) | |
| The agent also maintains detailed logs in the logs/ directory for debugging and auditing. | |
| Dependencies: | |
| ----------- | |
| - smolagents: Core framework for agent capabilities | |
| - gradio: For the web interface | |
| - Alpha Vantage API integration: For financial data | |
| - Various data processing libraries: For handling and analyzing financial information | |
| Technical Notes: | |
| -------------- | |
| - The agent runs with a configurable number of maximum steps (default: 20) | |
| - Planning occurs at regular intervals (default: every 4 steps) | |
| - The agent has access to a curated list of authorized Python imports for security | |
| - All file uploads are validated for type and size before processing | |
| Created by the Footscray Coding Collective | |
| Copyright 2024, All rights reserved | |
| """ | |
| import contextlib | |
| import datetime | |
| import logging | |
| import mimetypes | |
| import os | |
| import re | |
| import shutil | |
| from typing import Any, Dict, Generator, List, Optional, Tuple | |
| # Typer for CLI functionality | |
| import typer | |
| # Telemetry imports (optional) | |
| # with contextlib.suppress(ImportError): | |
| # from openinference.instrumentation.smolagents import SmolagentsInstrumentor | |
| # from phoenix.otel import register | |
| # Initialize telemetry for observability and tracing | |
| # register() | |
| # SmolagentsInstrumentor().instrument() | |
| # third-party | |
| import gradio as gr | |
| import pytz | |
| from dotenv import load_dotenv | |
| from huggingface_hub import login | |
| from rich.console import Console | |
| from rich.logging import RichHandler | |
| from smolagents import FinalAnswerTool # smolagents | |
| from smolagents import ( | |
| CodeAgent, | |
| GoogleSearchTool, | |
| HfApiModel, | |
| LiteLLMModel, | |
| OpenAIServerModel, | |
| Tool, | |
| TransformersModel, | |
| ) | |
| from smolagents.agent_types import AgentText, handle_agent_output_types | |
| from smolagents.gradio_ui import pull_messages_from_step | |
| # local | |
| from scripts.finance_tools import ( | |
| DataVisualizationTool, | |
| FinancialCalculatorTool, | |
| TrendAnalysisTool, | |
| get_balance_sheet_data, | |
| get_cash_flow_data, | |
| get_company_overview_data, | |
| get_earnings_data, | |
| get_income_statement_data, | |
| get_market_news_sentiment, | |
| get_stock_quote_data, | |
| get_time_series_daily, | |
| search_symbols, | |
| ) | |
| from scripts.flux_lora_tool import FluxLoRATool | |
| from scripts.text_cleaner_tool import TextCleanerTool | |
| from scripts.text_inspector_tool import TextInspectorTool | |
| from scripts.text_web_browser import ( | |
| ArchiveSearchTool, | |
| DownloadTool, | |
| FinderTool, | |
| FindNextTool, | |
| PageDownTool, | |
| PageUpTool, | |
| SimpleTextBrowser, | |
| VisitTool, | |
| ) | |
| from scripts.time_tools import get_temporal_context | |
| from scripts.visual_qa import visualizer | |
| # Initialize console and app | |
# Shared Rich console used for status output throughout this module.
console = Console()

# Typer application object; shell-completion installation is switched off.
app = typer.Typer(
    add_completion=False,
    help="Financial Research Agent - Access market data and analysis through a CLI or UI",
)
| # ------------------------ Configuration and Setup ------------------------ | |
| # Constants and configurations | |
# Modules passed to the agent as ``additional_authorized_imports``.
# Every entry must be a real importable module path: a misspelt entry is
# silently useless because it never matches an import the agent attempts.
AUTHORIZED_IMPORTS = [
    "requests",  # Web requests (fetching data from the internet)
    "pytz",  # Timezone handling
    "zipfile",  # Working with ZIP archives
    "pandas",  # Data manipulation and analysis (DataFrames)
    "numpy",  # Numerical computing (arrays, linear algebra)
    "sympy",  # Symbolic mathematics (algebra, calculus)
    "json",  # JSON data serialization/deserialization
    "bs4",  # Beautiful Soup for HTML/XML parsing
    "pubchempy",  # Accessing PubChem chemical database
    "yaml",  # YAML parsing
    "xml",  # XML processing
    "yahoo_finance",  # Fetching stock data
    "Bio",  # Bioinformatics tools (e.g., sequence analysis)
    "sklearn",  # Scikit-learn for machine learning
    "scipy",  # Scientific computing (stats, optimization)
    "pydub",  # Audio manipulation
    "PIL",  # Pillow for image processing
    "chess",  # Chess-related functionality
    "PyPDF2",  # PDF manipulation
    "pptx",  # PowerPoint file manipulation
    "torch",  # PyTorch for neural networks
    "datetime",  # Date and time handling
    "fractions",  # Rational number arithmetic
    "csv",  # CSV file reading/writing
    "cleantext",  # Text cleaning and normalization
    "os",  # Operating system interaction (file system, etc.) VERY IMPORTANT
    "re",  # Regular expressions for text processing
    "collections",  # Useful data structures (e.g., defaultdict, Counter)
    "math",  # Basic mathematical functions
    "random",  # Random number generation
    "io",  # Input/output streams
    "urllib.parse",  # URL parsing and manipulation (safe URL handling)
    "typing",  # Support for type hints (improve code clarity)
    "concurrent.futures",  # For parallel execution
    "time",  # Measuring time
    "tempfile",  # Creating temporary files and directories
    # Data Visualization (if needed) - Consider security implications carefully
    "matplotlib.pyplot",  # Plotting library (was "matplotlib.plt", which is not a module)
    "seaborn",  # Statistical data visualization (more advanced)
    # Web Scraping (more specific/controlled) - Consider ethical implications
    "lxml",  # Faster XML/HTML processing (alternative to bs4)
    "selenium",  # Automated browser control (for dynamic websites)
    # Database interaction (if needed) - Handle credentials securely!
    "sqlite3",  # SQLite database access
    # Task scheduling
    "schedule",  # Allow the agent to schedule tasks
    "uuid",  # Unique identifier generation
    "base64",  # Base64 encoding/decoding
    "smolagents",  # smolagents package to be able to create smolagents tools
]
# User-Agent header value attached to the browser's outgoing requests.
USER_AGENT = (
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
    "AppleWebKit/537.36 (KHTML, like Gecko) "
    "Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
)

# Keyword arguments unpacked into ``SimpleTextBrowser(**BROWSER_CONFIG)``.
# The SerpAPI key is read from the environment once, when this module is imported.
BROWSER_CONFIG = dict(
    viewport_size=5 * 1024,
    downloads_folder="data/downloads_folder",
    request_kwargs=dict(
        headers={"User-Agent": USER_AGENT},
        timeout=300,
    ),
    serpapi_key=os.getenv("SERPAPI_API_KEY"),
)
# Role remapping handed to the model wrapper via ``custom_role_conversions``.
CUSTOM_ROLE_CONVERSIONS = {
    "tool-call": "assistant",
    "tool-response": "user",
}

# MIME types accepted by the upload handler, grouped by media family.
ALLOWED_FILE_TYPES = (
    # Documents and text
    [
        "application/pdf",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "text/plain",
        "text/markdown",
        "application/json",
    ]
    # Images
    + ["image/png", "image/webp", "image/jpeg", "image/gif"]
    # Video
    + ["video/mp4"]
    # Audio
    + ["audio/mpeg", "audio/wav", "audio/ogg"]
)
| # Set up logging configuration | |
def setup_logging() -> Tuple[str, logging.Logger]:
    """
    Set up root logging to both the console and a timestamped log file.

    A ``logs`` directory is created next to this file if it does not exist.
    The log filename embeds the current Australia/Melbourne time. Records go
    to a Rich console handler and to a plain file handler.

    Returns:
        Tuple of (path to the log file, logger named after this module)
    """
    # Logs live in a "logs" folder beside this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    logs_dir = os.path.join(script_dir, "logs")
    os.makedirs(logs_dir, exist_ok=True)

    # One file per start-up, named by Melbourne local time.
    now_in_melbourne = datetime.datetime.now(pytz.timezone("Australia/Melbourne"))
    timestamp = now_in_melbourne.strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(logs_dir, f"smolagents_{timestamp}.log")

    console_handler = RichHandler(rich_tracebacks=True, show_time=True)
    file_handler = logging.FileHandler(log_file)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[console_handler, file_handler],
    )

    return log_file, logging.getLogger(__name__)


# Logging is configured once, at import time.
LOG_FILE, logger = setup_logging()
def setup_environment() -> None:
    """Load ``.env`` values and report which credentials are available.

    Values from the ``.env`` file override variables already set in the
    process. If ``HF_TOKEN`` is set it is used to log in to Hugging Face;
    otherwise a warning is printed. The presence of ``ALPHA_VANTAGE_API_KEY``
    is reported but the key itself is not validated here.
    """
    load_dotenv(override=True)

    # Hugging Face: log in only when a non-empty token is present.
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        console.print(
            "[yellow]HF_TOKEN not found in environment variables. "
            "Some features may not work properly.[/yellow]"
        )
    else:
        login(hf_token)
        console.print("HF_TOKEN loaded successfully")

    # Alpha Vantage: presence check only.
    try:
        if os.getenv("ALPHA_VANTAGE_API_KEY"):
            console.print("[green]✓ ALPHA_VANTAGE_API_KEY loaded successfully[/green]")
        else:
            console.print(
                "[yellow]⚠️ Warning: ALPHA_VANTAGE_API_KEY not found. "
                "Finance tools may not work properly.[/yellow]"
            )
    except Exception as e:
        console.print(f"[red]Error checking ALPHA_VANTAGE_API_KEY: {e}[/red]")
| # ------------------------ Model and Tool Management ------------------------ | |
class ModelManager:
    """Manages model loading and initialization.

    This class provides a static method to load the specified model with the
    appropriate configuration. It supports the following inference types:

    - hf_api: Use the Hugging Face API to load the model.
    - hf_api_provider: Use the Hugging Face API to load the model with the
      'together' provider.
    - litellm: Load the LiteLLM model with the specified model ID.
    - openai: Load the OpenAI model with the specified model ID and API key.
    - transformers: Load the Hugging Face transformers model with the
      specified model ID and configuration.
    """

    @staticmethod
    def load_model(chosen_inference: str, model_id: str, key_manager=None):
        """Load the specified model with appropriate configuration.

        Declared as a static method so it can be called on the class or on
        an instance alike.

        Args:
            chosen_inference (str): The inference type to use.
            model_id (str): The model ID to load. For ``hf_api_provider`` an
                empty value lets the backend pick its own default; for
                ``transformers`` an empty value falls back to
                ``HuggingFaceTB/SmolLM2-1.7B-Instruct``.
            key_manager: Object exposing ``get_key(name)``; required for the
                ``openai`` inference type, which reads ``"openai_api_key"``.

        Returns:
            The constructed model object for the chosen backend.

        Raises:
            ValueError: If the chosen inference type is invalid, or if
                ``openai`` is requested without a key manager.
            Exception: If an error occurs while loading the model. Every
                failure is printed to the console and then re-raised.
        """
        try:
            if chosen_inference == "hf_api":
                return HfApiModel(model_id=model_id)

            if chosen_inference == "hf_api_provider":
                # Honour the requested model instead of silently ignoring it.
                if model_id:
                    return HfApiModel(model_id=model_id, provider="together")
                return HfApiModel(provider="together")

            if chosen_inference == "litellm":
                return LiteLLMModel(model_id=model_id)

            if chosen_inference == "openai":
                if not key_manager:
                    raise ValueError("Key manager required for OpenAI model")
                return OpenAIServerModel(
                    model_id=model_id, api_key=key_manager.get_key("openai_api_key")
                )

            if chosen_inference == "transformers":
                return TransformersModel(
                    # The small default is used only when no model ID is given.
                    model_id=model_id or "HuggingFaceTB/SmolLM2-1.7B-Instruct",
                    device_map="auto",
                    max_new_tokens=1000,
                )

            raise ValueError(f"Invalid inference type: {chosen_inference}")
        except Exception as e:
            console.print(f"[red]✗ Couldn't load model: {e}[/red]")
            raise
| # ------------------------ Tool Registration ------------------------ | |
class ToolRegistry:
    """Manages tool initialization and organization using Zhou Protocol priorities.

    All loaders are static methods: none of them uses instance state, and
    declaring them static makes them callable on the class or on an instance.
    """

    @staticmethod
    def load_information_tools(model, text_limit=30000):
        """
        Initialize and return information analysis tools.

        This method creates tools for analyzing text from documents, and other sources.
        The information tools should be prioritized first in the agent's toolset.

        Args:
            model: Language model to use for analysis
            text_limit: Maximum character length for text summaries

        Returns:
            List of information analysis tools
        """
        return [
            TextInspectorTool(model, text_limit),
        ]

    @staticmethod
    def load_utility_tools():
        """
        Initialize and return utility tools for text cleaning and normalization.

        Returns:
            List of utility tools
        """
        return [
            TextCleanerTool(),
        ]

    @staticmethod
    def load_time_tools():
        """
        Initialize and return time-related tools.

        Returns:
            List of time-related tools
        """
        return [get_temporal_context]

    @staticmethod
    def load_finance_tools():
        """
        Initialize and return financial analysis tools.

        Returns:
            List of financial tools in priority order
        """
        return [
            # Analysis tools first (higher priority)
            DataVisualizationTool(),
            FinancialCalculatorTool(),
            TrendAnalysisTool(),
            # Data retrieval tools next
            search_symbols,
            get_stock_quote_data,
            get_company_overview_data,
            get_earnings_data,
            get_income_statement_data,
            get_balance_sheet_data,
            get_cash_flow_data,
            get_time_series_daily,
            get_market_news_sentiment,
        ]

    @staticmethod
    def load_web_tools(browser, text_limit=20000):
        """
        Initialize and return web interaction tools.

        Args:
            browser: Browser instance for web navigation
            text_limit: Maximum character length for text processing.
                Currently unused by this loader; kept so existing callers
                that pass it keep working.

        Returns:
            List of web tools in priority order
        """
        return [
            # Search tools first
            GoogleSearchTool(provider="serper"),
            # Navigation tools next
            VisitTool(browser),
            DownloadTool(browser),
            # Page interaction tools last
            PageUpTool(browser),
            PageDownTool(browser),
            FinderTool(browser),
            FindNextTool(browser),
            ArchiveSearchTool(browser),
        ]

    @staticmethod
    def load_image_generation_tools():
        """
        Initialize and return the image generation tool.

        Tries the hosted ``xkerser/FLUX.1-dev`` Space first; if that fails for
        any reason, a warning is printed and ``FluxLoRATool`` is used instead.

        Returns:
            A single tool object (not a list), either the Space-backed tool
            or the fallback
        """
        try:
            return Tool.from_space(
                space_id="xkerser/FLUX.1-dev",
                name="image_generator",
                description="Generates high-quality AgentImage using the FLUX.1-dev model based on text prompts.",
            )
        except Exception as e:
            console.print(
                f"[yellow]✗ Couldn't initialize image generation tool: {e}[/yellow]"
            )
            return FluxLoRATool()

    @staticmethod
    def load_final_answer_tool():
        """
        Return the final answer tool for providing conclusive responses.

        Returns:
            List containing the final answer tool
        """
        return [FinalAnswerTool()]
def create_agent(model_id: str = "openrouter/google/gemini-2.0-flash-001"):
    """
    Create a fresh agent instance with properly configured tools.

    This function creates a CodeAgent with tools organized by the Zhou Protocol
    priority system, ensuring the most relevant tools are considered first.

    Args:
        model_id: The ID of the model to use for the agent

    Returns:
        A configured CodeAgent instance

    Raises:
        RuntimeError: If agent creation fails. The original exception is
            attached as ``__cause__`` so the root failure stays visible.
    """
    try:
        # Initialize model with fallback system
        model = _load_model_with_fallback(model_id)

        # Initialize tools
        text_limit = 30000
        browser = SimpleTextBrowser(**BROWSER_CONFIG)

        # Collect all tools with proper Zhou Protocol prioritization
        information_tools = ToolRegistry.load_information_tools(model, text_limit)
        utility_tools = ToolRegistry.load_utility_tools()
        finance_tools = ToolRegistry.load_finance_tools()
        web_tools = ToolRegistry.load_web_tools(browser)
        time_tools = ToolRegistry.load_time_tools()
        image_generator = ToolRegistry.load_image_generation_tools()
        final_answer = ToolRegistry.load_final_answer_tool()

        # Combine all tools with information tools prioritized first
        all_tools = (
            information_tools  # Critical information extraction (highest priority)
            + utility_tools  # General utility functions
            + finance_tools  # Financial analysis capabilities
            + web_tools  # Web search and navigation
            + time_tools  # Time context tools
            + [visualizer]  # Image analysis
            + [image_generator]  # Image generation (loader returns a single tool)
            + final_answer  # Task completion (always last)
        )

        # Validate tools before creating agent
        _validate_tools(all_tools)

        return CodeAgent(
            model=model,
            tools=all_tools,
            max_steps=20,
            verbosity_level=2,
            additional_authorized_imports=AUTHORIZED_IMPORTS,
            planning_interval=4,
            description="""
This agent assists with comprehensive research and financial analysis. It first analyzes
any provided documents or text, then leverages specialized financial tools and web search
capabilities to provide thorough insights.
QUERY COMPREHENSION FRAMEWORK
Before answering any complex question, apply the Zhou Comprehension Pattern:
1. **Initial Parse**: What is literally being asked?
2. **Intent Detection**: What is the user actually trying to accomplish?
3. **Knowledge Assessment**: What information is needed to address this properly?
4. **Tool Selection**: Which tools provide the most direct path to a solution?
5. **Execution Planning**: What sequence of operations will yield the best result?
CLARIFICATION CHECKLIST
When faced with ambiguous queries, the agent should systematically clarify:
* **Scope**: "How comprehensive should this analysis be?"
* **Format**: "What form would you like the results in?"
* **Technical Level**: "Should I explain technical details or focus on practical applications?"
* **Time Horizon**: "Are you interested in historical data, current status, or future projections?"
* **Priority**: "Which aspect of this question is most important to you?"
""".strip(),
        )
    except Exception as e:
        console.print(f"[red]✗ Agent creation failed: {e}[/red]")
        # Chain explicitly so the traceback shows the underlying failure as the cause.
        raise RuntimeError(f"Agent creation failed: {e}") from e
def _load_model_with_fallback(model_id: str) -> Any:
    """
    Attempt to load the specified model with fallbacks if it fails.

    Each candidate is wrapped in a ``LiteLLMModel`` in turn, and the first one
    that constructs without raising is returned. A requested model that also
    appears in the built-in fallback list is tried only once.

    Args:
        model_id: Primary model ID to try loading

    Returns:
        Loaded model instance

    Raises:
        RuntimeError: If all model loading attempts fail. The last underlying
            error is attached as ``__cause__``.
    """
    # Fallback model chain; dict.fromkeys removes duplicates while keeping order.
    fallback_models = list(
        dict.fromkeys(
            [
                model_id,  # Try the requested model first
                "openrouter/anthropic/claude-3.7-sonnet",
                "openai/gpt-4o-mini",
                "anthropic/claude-3.7-sonnet",
                "HuggingFaceTB/SmolLM2-1.7B-Instruct",  # Last resort option
            ]
        )
    )

    last_error = None
    for candidate in fallback_models:
        try:
            return LiteLLMModel(
                custom_role_conversions=CUSTOM_ROLE_CONVERSIONS,
                model_id=candidate,
            )
        except Exception as e:
            last_error = e
            console.print(f"[yellow]Failed to load model {candidate}: {e}[/yellow]")

    # If we get here, all models failed
    raise RuntimeError(
        f"All model loading attempts failed. Last error: {last_error}"
    ) from last_error
def _validate_tools(tools):
    """
    Check that every entry in ``tools`` is a ``Tool`` instance.

    Entries are checked in order and the first offender aborts the check.

    Args:
        tools: Iterable of candidate tools

    Raises:
        ValueError: On the first entry that is not a ``Tool`` instance
    """
    for candidate in tools:
        if isinstance(candidate, Tool):
            continue
        raise ValueError(
            f"Invalid tool type: {type(candidate)}. "
            f"All tools must be instances of Tool class."
        )
| # ------------------------ Gradio UI Components ------------------------ | |
def stream_to_gradio(
    agent,
    task: str,
    reset_agent_memory: bool = False,
    additional_args: Optional[dict] = None,
):
    """Run the agent on ``task`` and yield Gradio chat messages as it progresses.

    Args:
        agent: Object whose ``run(task, stream=True, reset=..., additional_args=...)``
            yields step logs.
        task: The prompt to execute.
        reset_agent_memory: Forwarded to ``agent.run`` as ``reset``.
        additional_args: Forwarded unchanged to ``agent.run``.

    Yields:
        A "processing" placeholder first, then every message extracted from
        each step (text messages mentioning "document analysis" or "search"
        get an icon prefix), and finally the formatted final answer. Any
        exception is reported as a single error message instead of propagating.
    """
    try:
        # Initial processing indicator
        yield gr.ChatMessage(role="assistant", content="⏳ Processing your request...")

        first_message_yielded = False
        # Stays None when the run yields no steps, so that case can be reported cleanly.
        step_log = None

        for step_log in agent.run(
            task, stream=True, reset=reset_agent_memory, additional_args=additional_args
        ):
            # pull_messages_from_step is a generator; iterate over every message it yields.
            for message in pull_messages_from_step(step_log):
                content = getattr(message, "content", None)

                # Only plain-text content can be rewritten; anything else
                # passes through untouched instead of raising.
                if isinstance(content, str):
                    if not first_message_yielded:
                        # Strip the placeholder text if it leaked into the first message.
                        content = content.replace("⏳ Processing your request...", "")

                    content_lower = content.lower()
                    if "document analysis" in content_lower:
                        content = f"📄 **Document Analysis:** {content}"
                    elif "search" in content_lower:
                        content = f"🔍 **Search:** {content}"
                    message.content = content

                first_message_yielded = True
                yield message

        if step_log is None:
            yield gr.ChatMessage(
                role="assistant",
                content="❌ **Error:** The agent produced no output.\n\nPlease try again with a different query.",
            )
            return

        # Final answer with enhanced formatting
        final_answer = handle_agent_output_types(step_log)
        if isinstance(final_answer, AgentText):
            yield gr.ChatMessage(
                role="assistant",
                content=f"✅ **Final Answer:**\n\n{final_answer.to_string()}",
            )
        else:
            yield gr.ChatMessage(
                role="assistant", content=f"✅ **Final Answer:** {str(final_answer)}"
            )
    except Exception as e:
        yield gr.ChatMessage(
            role="assistant",
            content=f"❌ **Error:** {str(e)}\n\nPlease try again with a different query.",
        )
| # ------------------------ Gradio UI Components ------------------------ | |
| class GradioUI: | |
| """A one-line interface to launch your agent in Gradio.""" | |
def __init__(self, file_upload_folder: str | None = None):
    """Initialize the Gradio UI with optional file upload functionality.

    Args:
        file_upload_folder: Directory where uploaded files are stored.
            ``None`` leaves the upload feature disabled. The directory,
            including any missing parents, is created if needed.
    """
    self.file_upload_folder = file_upload_folder
    if self.file_upload_folder is not None:
        # makedirs(exist_ok=True) avoids the exists/mkdir race and handles nested paths.
        os.makedirs(self.file_upload_folder, exist_ok=True)
def interact_with_agent(
    self,
    prompt: str,
    messages: List[gr.ChatMessage],
    session_state: Dict[str, Any],
) -> Generator[List[gr.ChatMessage], None, None]:
    """Run one user turn against the session's agent, streaming the chat history.

    The agent is created lazily on the first turn and cached in
    ``session_state`` under ``"agent"``; later turns reuse it without
    resetting its memory.

    Args:
        prompt: The user's input prompt
        messages: Chat history so far; mutated in place as messages arrive
        session_state: Per-session dictionary (may carry ``"model_id"``)

    Yields:
        The same ``messages`` list after the user message is appended and
        after every streamed agent message, plus once more at the end
    """
    # First turn for this session: build the agent and cache it.
    if "agent" not in session_state:
        session_state["agent"] = create_agent(
            session_state.get("model_id", "openrouter/google/gemini-2.0-flash-001")
        )

    try:
        agent = session_state["agent"]

        # Diagnostic output about the agent's memory attribute.
        has_memory = hasattr(agent, "memory")
        console.print(f"Agent has memory: {has_memory}")
        if has_memory:
            console.print(f"Memory type: {type(agent.memory)}")

        messages.append(gr.ChatMessage(role="user", content=prompt))
        yield messages

        agent_stream = stream_to_gradio(agent, task=prompt, reset_agent_memory=False)
        for msg in agent_stream:
            messages.append(msg)
            yield messages  # Refresh the chat after each streamed message

        yield messages  # Final refresh once the stream is exhausted
    except Exception as e:
        console.print(f"[red]Error in interaction: {str(e)}[/red]")
        raise
def upload_file(
    self,
    file,
    file_uploads_log,
):
    """Validate an uploaded file and copy it into the upload folder.

    Args:
        file: Uploaded file object exposing a ``name`` path attribute, or ``None``.
        file_uploads_log: List of paths uploaded so far in this session.

    Returns:
        Tuple of (status ``gr.Textbox``, upload log). The log is returned
        unchanged on any rejection and extended with the stored path on success.
    """
    if file is None:
        return gr.Textbox("No file uploaded", visible=True), file_uploads_log

    try:
        mime_type, _ = mimetypes.guess_type(file.name)
    except Exception as e:
        return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log

    if mime_type not in ALLOWED_FILE_TYPES:
        return gr.Textbox("File type disallowed", visible=True), file_uploads_log

    # Sanitize file name: anything outside word chars, "-" and "." becomes "_".
    original_name = os.path.basename(file.name)
    sanitized_name = re.sub(r"[^\w\-.]", "_", original_name)

    # Split off only the final extension so interior dots in the stem survive
    # ("q1.report.pdf" keeps "q1.report" rather than collapsing to "q1report").
    stem, original_extension = os.path.splitext(sanitized_name)

    # Use the standard extension for the detected MIME type; fall back to the
    # file's own extension if the registry has none.
    extension = mimetypes.guess_extension(mime_type) or original_extension
    sanitized_name = stem + extension

    # Enforce the size limit before copying anything.
    max_file_size_mb = 50
    file_size_mb = os.path.getsize(file.name) / (1024 * 1024)
    if file_size_mb > max_file_size_mb:
        return (
            gr.Textbox(
                f"File size exceeds {max_file_size_mb} MB limit.", visible=True
            ),
            file_uploads_log,
        )

    # Save the uploaded file to the specified folder
    file_path = os.path.join(self.file_upload_folder, sanitized_name)
    shutil.copy(file.name, file_path)

    return gr.Textbox(
        f"File uploaded: {file_path}", visible=True
    ), file_uploads_log + [file_path]
def log_user_message(self, text_input, file_uploads_log):
    """Build the outgoing prompt and lock the input widgets while it runs.

    Args:
        text_input: Text the user typed.
        file_uploads_log: Paths of files uploaded in this session.

    Returns:
        Tuple of (prompt text, cleared non-interactive textbox update,
        non-interactive button update). When files were uploaded, a note
        listing them is appended to the prompt.
    """
    if len(file_uploads_log) > 0:
        outgoing = (
            text_input
            + f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}"
        )
    else:
        outgoing = text_input

    # Clear and disable the textbox, then disable the button.
    locked_textbox = gr.Textbox(
        value="",
        interactive=False,
        placeholder="Processing...",
    )
    locked_button = gr.Button(interactive=False)
    return outgoing, locked_textbox, locked_button
def detect_device(self, request: gr.Request):
    """Classify the requesting client as "Mobile" or "Desktop" from its headers.

    Checks, in order: the ``sec-ch-ua-mobile`` hint, mobile keywords in the
    ``user-agent`` string, then the ``sec-ch-ua-platform`` hint. Returns
    "Unknown device" for a missing request and "Desktop" when nothing matches.
    """
    if not request:
        return "Unknown device"

    headers = request.headers

    # 1) Explicit client hint: "?1" means mobile, any other non-empty value desktop.
    mobile_hint = headers.get("sec-ch-ua-mobile")
    if mobile_hint:
        if "?1" in mobile_hint:
            return "Mobile"
        return "Desktop"

    # 2) Keyword scan of the lower-cased user-agent string.
    agent_string = headers.get("user-agent", "").lower()
    for keyword in ("android", "iphone", "ipad", "mobile", "phone"):
        if keyword in agent_string:
            return "Mobile"

    # 3) Platform hint; the header value carries its own double quotes.
    platform_hint = headers.get("sec-ch-ua-platform", "").lower()
    if platform_hint in ('"android"', '"ios"'):
        return "Mobile"

    # Known desktop platforms and everything unrecognised end up here.
    return "Desktop"
def launch(self, **kwargs):
    """Launch the Gradio UI with responsive layout.

    Args:
        **kwargs: Forwarded to ``Blocks.launch``. ``debug`` defaults to
            ``True`` but may be overridden by the caller.
    """
    with gr.Blocks(theme="ocean", fill_height=True) as demo:
        # Register the layout as a render function so it actually runs for
        # each page load; a plain nested function would never be called and
        # the app would have no components.
        @gr.render()
        def layout(request: gr.Request):
            device = self.detect_device(request)
            console.print(f"device - {device}")
            # Desktop gets the sidebar layout; everything else the compact one.
            if device == "Desktop":
                return self._create_desktop_layout()
            return self._create_mobile_layout()

    # setdefault keeps the debug default without clashing with a caller-supplied value.
    kwargs.setdefault("debug", True)
    demo.queue(max_size=20).launch(**kwargs)  # Queue with a reasonable size
def _create_desktop_layout(self):
    """Create the desktop layout with sidebar.

    Builds a ``gr.Blocks`` containing a sidebar (title, prompt box, Run
    button, optional upload widgets, "Powered by" footer) and a chatbot,
    then wires the inputs through ``_connect_event_handlers``.

    Returns:
        The constructed ``gr.Blocks`` object.
    """
    # NOTE(review): the nesting of components under the sidebar is inferred
    # from statement order; confirm which widgets belong inside gr.Sidebar().
    with gr.Blocks(fill_height=True) as sidebar_demo:
        with gr.Sidebar():
            gr.Markdown(
                """#OpenDeepResearch - 3theSmolagents!
Model_id: google/gemini-2.0-flash-001"""
            )
            with gr.Group():
                gr.Markdown("**What's on your mind mate?**", container=True)
                text_input = gr.Textbox(
                    lines=3,
                    label="Your request",
                    container=False,
                    placeholder="Enter your prompt here and press Shift+Enter or press the button",
                )
                launch_research_btn = gr.Button("Run", variant="primary")
            # If an upload folder is provided, enable the upload feature
            if self.file_upload_folder is not None:
                upload_file = gr.File(label="Upload a file")
                upload_status = gr.Textbox(
                    label="Upload Status", interactive=False, visible=False
                )
                file_uploads_log = gr.State([])
                # Each change of the file widget runs the upload handler and
                # updates both the status box and the upload log.
                upload_file.change(
                    self.upload_file,
                    [upload_file, file_uploads_log],
                    [upload_status, file_uploads_log],
                )
            gr.HTML("<br><br><h4><center>Powered by:</center></h4>")
            with gr.Row():
                gr.HTML(
                    """
<div style="display: flex; align-items: center; gap: 8px; font-family: system-ui, -apple-system, sans-serif;">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png"
style="width: 32px; height: 32px; object-fit: contain;" alt="logo">
<a target="_blank" href="https://github.com/huggingface/smolagents">
<b>huggingface/smolagents</b>
</a>
</div>
"""
                )
        # Add session state to store session-specific data
        session_state = gr.State({})  # Initialize empty state for each session
        stored_messages = gr.State([])
        # The upload log only exists if the upload branch above ran; create
        # an empty one otherwise so the handlers always receive a state object.
        if "file_uploads_log" not in locals():
            file_uploads_log = gr.State([])
        chatbot = gr.Chatbot(
            label="Research-Assistant",
            type="messages",
            avatar_images=(
                None,
                "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
            ),
            resizeable=False,
            scale=1,
            elem_id="my-chatbot",
        )
        self._connect_event_handlers(
            text_input,
            launch_research_btn,
            file_uploads_log,
            stored_messages,
            chatbot,
            session_state,
        )
    return sidebar_demo
def _create_mobile_layout(self):
    """Create the mobile layout (simpler, without a sidebar).

    Components are created in this order inside a single ``gr.Blocks``:
    a title, the per-session state holders, the chat window, an optional
    file-upload widget (only when ``self.file_upload_folder`` is not
    ``None``), the prompt textbox and the "Run" button. The prompt and the
    button are then wired up through ``self._connect_event_handlers``.

    Returns:
        The ``gr.Blocks`` instance that holds the layout.
    """
    with gr.Blocks(fill_height=True) as simple_demo:
        # An ATX heading needs a space after the "#"; without it the title
        # is shown as literal text instead of being rendered as a heading.
        gr.Markdown("""# OpenDeepResearch - free the AI agents!""")

        # Per-session state: free-form session data, the pending user
        # message(s), and the log of files uploaded so far.
        session_state = gr.State({})
        stored_messages = gr.State([])
        file_uploads_log = gr.State([])

        chatbot = gr.Chatbot(
            label="Research-Assistant",
            type="messages",
            avatar_images=(
                None,
                "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
            ),
            resizeable=True,
            scale=1,
        )

        # If an upload folder is provided, enable the upload feature
        if self.file_upload_folder is not None:
            upload_file = gr.File(label="Upload a file")
            upload_status = gr.Textbox(
                label="Upload Status", interactive=False, visible=False
            )
            # self.upload_file receives the new file plus the current log
            # and writes back a status message and the updated log.
            upload_file.change(
                self.upload_file,
                [upload_file, file_uploads_log],
                [upload_status, file_uploads_log],
            )

        text_input = gr.Textbox(
            lines=1,
            label="What's on your mind mate?",
            placeholder="Chuck in a question and we'll take care of the rest",
        )
        launch_research_btn = gr.Button("Run", variant="primary")

        self._connect_event_handlers(
            text_input,
            launch_research_btn,
            file_uploads_log,
            stored_messages,
            chatbot,
            session_state,
        )
    return simple_demo
def _connect_event_handlers(
    self,
    text_input,
    launch_research_btn,
    file_uploads_log,
    stored_messages,
    chatbot,
    session_state,
):
    """Wire the prompt textbox and the run button to the agent pipeline.

    Submitting the textbox and clicking the button both start the same
    three-step chain:

    1. ``self.log_user_message`` is called with the textbox and the upload
       log; its outputs go to ``stored_messages``, the textbox and the
       button.
    2. ``self.interact_with_agent`` is called with ``stored_messages``,
       ``chatbot`` and ``session_state``; its output goes to ``chatbot``.
    3. The textbox and the button are set back to interactive and the
       textbox placeholder is replaced.
    """
    # Textbox submit is registered first, then the button click.
    for start_event in (text_input.submit, launch_research_btn.click):
        message_logged = start_event(
            self.log_user_message,
            [text_input, file_uploads_log],
            [stored_messages, text_input, launch_research_btn],
        )
        agent_finished = message_logged.then(
            self.interact_with_agent,
            [stored_messages, chatbot, session_state],
            [chatbot],
        )
        # Final step takes no inputs and only refreshes the two widgets.
        agent_finished.then(
            lambda: (
                gr.Textbox(
                    interactive=True,
                    placeholder="Enter your prompt here and press the button",
                ),
                gr.Button(interactive=True),
            ),
            None,
            [text_input, launch_research_btn],
        )
| # ------------------------ CLI Command ------------------------ | |
# Register the function on the Typer app: the script entry point calls
# `app()`, which can only dispatch to commands registered on it.
@app.command()
def run(
    mode: str = typer.Option(
        "ui",
        "--mode",
        "-m",
        help="Operating mode: 'ui' for Gradio interface or 'headless' for CLI mode",
    ),
    model_id: str = typer.Option(
        "openrouter/google/gemini-2.0-flash-001",
        "--model",
        help="Model ID to use for the agent",
    ),
    query: Optional[str] = typer.Option(
        None, "--query", "-q", help="Query to execute (required in headless mode)"
    ),
) -> None:
    """
    Run the financial research agent in either UI or headless mode.

    In UI mode, launches a Gradio interface for interactive use.
    In headless mode, processes a single query and outputs the result to the console.
    """
    # NOTE: the docstring above is kept short on purpose; Typer shows it as
    # the command's --help text. Details for maintainers:
    #   * mode     - "ui" or "headless"; anything else exits with code 1.
    #   * model_id - handed to create_agent() in headless mode only; the UI
    #                branch constructs GradioUI without it.
    #   * query    - must be non-empty in headless mode, otherwise the
    #                command exits with code 1.

    # Setup environment variables
    setup_environment()

    # Validate all arguments up front so that a bad invocation is rejected
    # before any "initializing" output or agent/UI construction happens.
    if mode not in ("ui", "headless"):
        console.print(
            f"[red]Error: Invalid mode '{mode}'. Use 'ui' or 'headless'[/red]"
        )
        raise typer.Exit(code=1)
    if mode == "headless" and not query:
        console.print("[red]Error: query parameter is required in headless mode[/red]")
        raise typer.Exit(code=1)

    console.print(f"[bold]Initializing agent with model:[/bold] {model_id}")

    if mode == "ui":
        console.print(
            "[bold green]Starting UI mode with Gradio interface...[/bold green]"
        )
        # Ensure downloads folder exists
        os.makedirs(f"./{BROWSER_CONFIG['downloads_folder']}", exist_ok=True)
        # Launch UI
        GradioUI(file_upload_folder="data/uploaded_files").launch()
        return

    # Headless mode: run the single query and print its result.
    console.print(f"[bold]Processing query in headless mode:[/bold] {query}")
    agent = create_agent(model_id)
    # Show a simple spinner during processing
    with console.status("[bold green]Processing query...[/bold green]"):
        result = agent.run(query)
    # Display the results
    console.print("\n[bold green]Results:[/bold green]")
    console.print(result)
| # ------------------------ Main Entry Point ------------------------ | |
if __name__ == "__main__":
    # Only start the Typer app when this file is executed as a script;
    # importing the module does not trigger the CLI.
    app()