Spaces:
Running
Running
| from fastapi import FastAPI, HTTPException, UploadFile, File, Query, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse, FileResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
| from typing import List, Optional, Dict | |
| import json | |
| import httpx | |
| import os | |
| import logging | |
| import asyncio | |
| import re | |
| import signal | |
| import sys | |
| import threading | |
| from datetime import datetime | |
| from concurrent.futures import ThreadPoolExecutor | |
| from .agents import AGENT_REGISTRY, get_default_counters, get_registry_for_frontend | |
| # Thread pool for running sync generators without blocking the event loop | |
| # Use daemon threads so they don't block shutdown | |
| _executor = ThreadPoolExecutor(max_workers=10) | |
| # Flag to signal shutdown to running threads | |
| _shutdown_flag = False | |
| async def _stream_sync_generator(sync_gen_func, *args, **kwargs): | |
| """Run a sync generator in a thread, yielding SSE-formatted JSON lines. | |
| This is the standard pattern for all agent handlers: wrap a blocking | |
| sync generator so it doesn't block the async event loop. | |
| """ | |
| loop = asyncio.get_event_loop() | |
| queue = asyncio.Queue() | |
| def run(): | |
| try: | |
| for update in sync_gen_func(*args, **kwargs): | |
| loop.call_soon_threadsafe(queue.put_nowait, update) | |
| finally: | |
| loop.call_soon_threadsafe(queue.put_nowait, None) | |
| future = loop.run_in_executor(_executor, run) | |
| while True: | |
| update = await queue.get() | |
| if update is None: | |
| break | |
| yield f"data: {json.dumps(update)}\n\n" | |
| await asyncio.wrap_future(future) | |
def signal_handler(signum, frame):
    """Handle Ctrl+C by setting shutdown flag and exiting.

    Args:
        signum: Signal number delivered by the OS.
        frame: Current stack frame (unused).
    """
    global _shutdown_flag
    _shutdown_flag = True
    # NOTE: logger is configured further down in this module; that's fine
    # because the handler only runs after import completes.
    logger.info("Shutdown signal received, cleaning up...")
    # Exit immediately; worker threads checking _shutdown_flag are abandoned.
    sys.exit(0)


# Register signal handlers for interactive (SIGINT) and service (SIGTERM) shutdown
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
# ============================================
# Abort Registry
# ============================================
# Per-agent abort flags for cooperative cancellation
_abort_flags: Dict[str, threading.Event] = {}
_agent_children: Dict[str, set] = {}  # parent_id -> set of child_ids


def register_agent(agent_id: str, parent_id: Optional[str] = None) -> threading.Event:
    """Register an agent and return its abort event.

    Args:
        agent_id: Unique ID of the agent/tab being registered.
        parent_id: Optional parent agent ID; when given, the new agent is
            recorded as a child so aborting the parent cascades to it.
            (FIX: annotated as Optional[str] — the default is None.)

    Returns:
        The threading.Event the agent should poll for cooperative cancellation.
    """
    event = threading.Event()
    _abort_flags[agent_id] = event
    if parent_id:
        _agent_children.setdefault(parent_id, set()).add(agent_id)
    return event


def unregister_agent(agent_id: str):
    """Clean up abort flag and parent-child mappings for a finished agent."""
    _abort_flags.pop(agent_id, None)
    # Remove from every parent's children set.
    for children in _agent_children.values():
        children.discard(agent_id)
    # Drop any child bookkeeping the agent itself owned.
    _agent_children.pop(agent_id, None)


def abort_agent_tree(agent_id: str) -> list:
    """Recursively abort an agent and all its children.

    Returns:
        List of agent IDs whose abort events were set.
    """
    aborted = []
    if agent_id in _abort_flags:
        _abort_flags[agent_id].set()
        aborted.append(agent_id)
    # Iterate a copy: children may unregister while we recurse.
    for child_id in list(_agent_children.get(agent_id, [])):
        aborted.extend(abort_agent_tree(child_id))
    return aborted
# Configure logging to match uvicorn's format
logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s: %(message)s'
)
logger = logging.getLogger(__name__)

# Silence noisy third-party loggers
logging.getLogger("e2b").setLevel(logging.WARNING)
logging.getLogger("e2b.api").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)

# FastAPI application object; routes and middleware are attached to it below.
app = FastAPI(title="AgentUI API")
# ============================================
# Environment Variable Fallbacks
# ============================================
# These allow API keys to be set via environment variables as fallback
# when not configured in settings. Useful for HF Spaces deployment.
def get_env_fallback(value: Optional[str], env_var: str) -> Optional[str]:
    """Return *value* when truthy, otherwise the named environment variable."""
    return value or os.environ.get(env_var)


# Get the project root directory (parent of backend/)
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# ---- Optional feature imports ----
# Each capability degrades gracefully: a failed import disables the feature
# flag and logs a warning instead of crashing the server at startup.

# For code execution
try:
    from e2b_code_interpreter import Sandbox
    from .code import stream_code_execution
    from openai import OpenAI
    E2B_AVAILABLE = True
except ImportError:
    E2B_AVAILABLE = False
    logger.warning("E2B not available. Code execution will be disabled.")

# For research
try:
    from .research import stream_research
    RESEARCH_AVAILABLE = True
except ImportError as e:
    RESEARCH_AVAILABLE = False
    logger.warning(f"Research dependencies not available ({e}). Install with: pip install trafilatura requests")

# For command center with tool-based launching
try:
    from .command import stream_command_center
    COMMAND_AVAILABLE = True
except ImportError:
    COMMAND_AVAILABLE = False
    logger.warning("Command center tool handling not available.")

# For agent with web tools
try:
    from .agent import stream_agent_execution
    AGENT_AVAILABLE = True
except ImportError:
    AGENT_AVAILABLE = False
    logger.warning("Agent web tools not available. Install with: pip install readability-lxml markdownify")

# For image agent with HuggingFace tools
try:
    from .image import stream_image_execution
    IMAGE_AVAILABLE = True
except ImportError:
    IMAGE_AVAILABLE = False
    logger.warning("Image agent not available. Install with: pip install huggingface_hub Pillow")
# Session management for sandboxes
# Maps session_id -> live E2B Sandbox instance.
# NOTE(review): 'any' here is the builtin function, not typing.Any — harmless
# at runtime but confusing to type checkers; verify before changing.
SANDBOXES: Dict[str, any] = {}
# Sandbox lifetime passed to Sandbox.create(timeout=...) — presumably seconds;
# confirm against the E2B SDK documentation.
SANDBOX_TIMEOUT = 300

# Conversation history per tab (persistent across requests)
# Structure: {tab_id: [messages...]}
CONVERSATION_HISTORY: Dict[str, List[Dict]] = {}

# Figure stores per tab (persistent across requests so re-entry works without multimodal)
# Structure: {tab_id: {figure_name: base64_png, ...}}
# Global figure store: all agents write here so cross-agent references work.
# Keys are namespaced like "figure_T{tab}_{N}" so there are no collisions.
FIGURE_STORE: Dict[str, dict] = {}
# Per-tab counters to track the next figure number for each tab
FIGURE_COUNTERS: Dict[str, int] = {}

# Multi-user isolation
MULTI_USER = False
USERS_ROOT = None  # Set to FILES_ROOT/users/ when multi-user enabled
USER_SESSIONS: Dict[str, str] = {}  # user_id -> current_session_name
USER_WORKSPACE_FILES: Dict[str, str] = {}  # user_id -> workspace_file_path
def get_user_id(request: Request) -> str:
    """Get user ID from request. Returns '' in single-user mode."""
    if not MULTI_USER:
        return ''
    # The header (set by apiFetch) takes precedence over the query param
    # (used by window.open); fall back to 'anonymous' when neither is present.
    header_id = request.headers.get('x-session-id')
    query_id = request.query_params.get('session_id')
    return header_id or query_id or 'anonymous'
def get_user_files_root(user_id: str) -> str:
    """Return (and create) this user's files directory; FILES_ROOT in single-user mode."""
    if not MULTI_USER:
        return FILES_ROOT
    path = os.path.join(USERS_ROOT, user_id)
    os.makedirs(path, exist_ok=True)
    return path
def get_user_sessions_root(user_id: str) -> str:
    """Return (and create) the user's sessions directory under their files root."""
    sessions_dir = os.path.join(get_user_files_root(user_id), "sessions")
    os.makedirs(sessions_dir, exist_ok=True)
    return sessions_dir
def get_user_settings_file(user_id: str) -> str:
    """Return the settings-file path for this user, seeding it from defaults.

    In single-user mode the shared SETTINGS_FILE is used; otherwise a
    per-user copy is created on first access.
    """
    if not MULTI_USER:
        return SETTINGS_FILE
    base_dir = os.path.join(USERS_ROOT, user_id)
    os.makedirs(base_dir, exist_ok=True)
    settings_path = os.path.join(base_dir, "settings.json")
    # New user: start them off with a copy of the default settings.
    if not os.path.exists(settings_path) and os.path.exists(SETTINGS_FILE):
        import shutil
        shutil.copy2(SETTINGS_FILE, settings_path)
    return settings_path
def get_user_current_session(user_id: str) -> Optional[str]:
    """Current session name for the user (global CURRENT_SESSION in single-user mode)."""
    return CURRENT_SESSION if not MULTI_USER else USER_SESSIONS.get(user_id)
def set_user_current_session(user_id: str, session_name: Optional[str]):
    """Set (or clear, when None) the user's current session name."""
    global CURRENT_SESSION
    if not MULTI_USER:
        CURRENT_SESSION = session_name
        return
    if session_name is None:
        USER_SESSIONS.pop(user_id, None)
    else:
        USER_SESSIONS[user_id] = session_name
def get_user_workspace_file(user_id: str) -> Optional[str]:
    """Workspace-file path for the user (global WORKSPACE_FILE in single-user mode)."""
    return WORKSPACE_FILE if not MULTI_USER else USER_WORKSPACE_FILES.get(user_id)
def set_user_workspace_file(user_id: str, path: Optional[str]):
    """Set (or clear, when None) the user's workspace file path."""
    global WORKSPACE_FILE
    if not MULTI_USER:
        WORKSPACE_FILE = path
        return
    if path is None:
        USER_WORKSPACE_FILES.pop(user_id, None)
    else:
        USER_WORKSPACE_FILES[user_id] = path
def user_key(user_id: str, key: str) -> str:
    """Prefix a dict key with user_id for isolation (pass-through when no user)."""
    return f"{user_id}:{key}" if user_id else key
async def safe_stream_wrapper(generator):
    """
    Wrap a streaming generator to gracefully handle client disconnections.

    This prevents 'socket.send() raised exception' spam in logs when
    clients disconnect during SSE streaming.

    Args:
        generator: An async generator producing SSE strings.

    Yields:
        Items from the wrapped generator until it finishes or the client
        disconnects. Unexpected errors are re-raised.
    """
    try:
        async for item in generator:
            yield item
    except (ConnectionResetError, BrokenPipeError):
        # FIX: handle the definite disconnect types directly; the original
        # listed them alongside Exception in one tuple, which was redundant.
        return
    except Exception as e:
        # Heuristic: treat "disconnect"/"closed" errors as normal client
        # departures; everything else is a real failure.
        message = str(e).lower()
        if "disconnect" in message or "closed" in message:
            return
        # Re-raise unexpected errors (the unused error_name local was removed).
        raise
def sync_to_async_generator(sync_gen):
    """Convert a synchronous generator to an async generator with disconnect handling."""
    async def _adapter():
        try:
            for item in sync_gen:
                yield item
        except (ConnectionResetError, BrokenPipeError):
            # Client went away mid-stream; just stop.
            return
        except GeneratorExit:
            # The async generator was closed (client disconnected).
            return
    return _adapter()
# CORS middleware for frontend connection — wide open so any origin can call
# the API during development / HF Spaces hosting.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify your frontend URL
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Agent type registry is in agents.py — system prompts, tools, and metadata are all defined there


class Message(BaseModel):
    """A single chat message in the OpenAI-compatible wire format."""
    # Role string forwarded verbatim to the LLM endpoint — presumably
    # "system"/"user"/"assistant"/"tool"; verify against frontend callers.
    role: str
    content: str
    tool_call_id: Optional[str] = None  # Required for role="tool" messages
    tool_calls: Optional[List[Dict]] = None  # Required for assistant messages with tool use
class FrontendContext(BaseModel):
    """Dynamic context from the frontend that can affect system prompts"""
    theme: Optional[Dict] = None  # Current theme colors {name, accent, bg, etc.}
    open_agents: Optional[List[str]] = None  # List of open agent types/names
class ChatRequest(BaseModel):
    """Payload for the streaming chat endpoint.

    Bundles the conversation, the target LLM endpoint/credentials, and all
    optional per-agent configuration (code, research, image) in one request.
    """
    messages: List[Message]
    agent_type: str = "command"  # Which agent handler to route to (see agents.py registry)
    stream: bool = True
    endpoint: str  # User's configured LLM endpoint
    token: Optional[str] = None  # Optional auth token
    model: Optional[str] = "gpt-4"  # Model name
    extra_params: Optional[Dict] = None  # Extra parameters for API calls (e.g., enable_thinking)
    multimodal: bool = False  # Whether the model supports vision/image input
    e2b_key: Optional[str] = None  # E2B API key for code execution
    serper_key: Optional[str] = None  # Serper API key for research
    hf_token: Optional[str] = None  # HuggingFace token for image generation
    image_gen_model: Optional[str] = None  # HuggingFace model for text-to-image
    image_edit_model: Optional[str] = None  # HuggingFace model for image-to-image
    research_sub_agent_model: Optional[str] = None  # Model for research sub-tasks
    research_sub_agent_endpoint: Optional[str] = None  # Endpoint for research sub-agent (may differ from main)
    research_sub_agent_token: Optional[str] = None  # Token for research sub-agent endpoint
    research_sub_agent_extra_params: Optional[Dict] = None  # Extra params for research sub-agent
    research_parallel_workers: Optional[int] = None  # Number of parallel workers for research
    research_max_websites: Optional[int] = None  # Max websites to analyze per research session
    agent_id: Optional[str] = None  # Unique agent/tab ID for session management
    parent_agent_id: Optional[str] = None  # Parent agent ID for abort propagation
    frontend_context: Optional[FrontendContext] = None  # Dynamic context from frontend
class AbortRequest(BaseModel):
    """Request body for aborting a running agent (and its children)."""
    agent_id: str  # ID previously supplied as ChatRequest.agent_id
class TitleRequest(BaseModel):
    """Request body for generating a short conversation title."""
    query: str  # The user query to summarize into a 2-3 word title
    endpoint: str  # User's configured LLM endpoint
    token: Optional[str] = None  # Optional auth token
    model: Optional[str] = "gpt-4"  # Model name
class SandboxRequest(BaseModel):
    """Request body for creating/attaching an E2B sandbox to a session."""
    session_id: str
    e2b_key: str
class SandboxStopRequest(BaseModel):
    """Request body for stopping a session's E2B sandbox."""
    session_id: str
async def stream_code_agent(
    messages: List[dict],
    endpoint: str,
    token: Optional[str],
    model: str,
    e2b_key: str,
    session_id: str,
    tab_id: str = "default",
    parent_agent_id: Optional[str] = None,
    frontend_context: Optional[Dict] = None,
    extra_params: Optional[Dict] = None,
    files_root: str = None,
    multimodal: bool = False
):
    """Handle code agent with execution capabilities.

    Registers an abort event for this tab, streams from the inner
    implementation, and always cleans up the registration afterwards.
    """
    abort_event = register_agent(tab_id, parent_agent_id)
    try:
        inner = _stream_code_agent_inner(
            messages, endpoint, token, model, e2b_key, session_id,
            tab_id, frontend_context, extra_params, abort_event,
            files_root, multimodal,
        )
        async for piece in inner:
            yield piece
    finally:
        unregister_agent(tab_id)
async def _stream_code_agent_inner(messages, endpoint, token, model, e2b_key, session_id, tab_id, frontend_context, extra_params, abort_event, files_root=None, multimodal=False):
    """Core code-agent loop.

    Ensures a per-session E2B sandbox exists, streams the blocking code
    executor through the SSE adapter, resyncs this tab's figure counter from
    the shared store, and retries once after sandbox-timeout errors by
    recreating the sandbox.
    """
    if not E2B_AVAILABLE:
        yield f"data: {json.dumps({'type': 'error', 'content': 'E2B not available. Install with: pip install e2b-code-interpreter'})}\n\n"
        return
    if not e2b_key:
        yield f"data: {json.dumps({'type': 'error', 'content': 'E2B API key required for code execution. Please configure in settings.'})}\n\n"
        return
    try:
        # Reuse the session's sandbox if one is already running.
        # NOTE: the key is passed via process-wide env var, so concurrent
        # requests with different keys would race — acceptable single-tenant.
        if session_id not in SANDBOXES:
            os.environ["E2B_API_KEY"] = e2b_key
            SANDBOXES[session_id] = Sandbox.create(timeout=SANDBOX_TIMEOUT)
        sbx = SANDBOXES[session_id]
        client = OpenAI(base_url=endpoint, api_key=token)
        system_prompt = get_system_prompt("code", frontend_context)
        full_messages = [{"role": "system", "content": system_prompt}] + messages
        # Ensure per-tab counter exists
        if tab_id not in FIGURE_COUNTERS:
            FIGURE_COUNTERS[tab_id] = 0
        async for chunk in _stream_sync_generator(
            stream_code_execution, client, model, full_messages, sbx,
            files_root=files_root or FILES_ROOT, extra_params=extra_params,
            abort_event=abort_event, multimodal=multimodal, tab_id=tab_id,
            figure_store=FIGURE_STORE,
        ):
            yield chunk
        # Derive counter from store keys for this tab's prefix
        # (keys are namespaced "figure_T{tab}_{N}"; take the max N seen).
        prefix = f"figure_T{tab_id}_"
        max_counter = 0
        for name in FIGURE_STORE:
            if name.startswith(prefix):
                m = re.search(r'_(\d+)$', name)
                if m:
                    max_counter = max(max_counter, int(m.group(1)))
        FIGURE_COUNTERS[tab_id] = max_counter
    except Exception as e:
        import traceback
        error_message = f"Code execution error: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_message)
        # Check if this is a sandbox timeout error (502)
        error_str = str(e)
        if "502" in error_str or "sandbox was not found" in error_str.lower() or "timeout" in error_str.lower():
            if session_id in SANDBOXES:
                try:
                    # Best-effort kill of the dead sandbox; bare except is
                    # deliberate — the sandbox may already be gone.
                    SANDBOXES[session_id].kill()
                except:
                    pass
                del SANDBOXES[session_id]
            yield f"data: {json.dumps({'type': 'info', 'content': 'Sandbox timed out. Creating new sandbox and retrying...'})}\n\n"
            try:
                # Recreate the sandbox and retry the whole execution once.
                os.environ["E2B_API_KEY"] = e2b_key
                SANDBOXES[session_id] = Sandbox.create(timeout=SANDBOX_TIMEOUT)
                sbx = SANDBOXES[session_id]
                yield f"data: {json.dumps({'type': 'info', 'content': 'New sandbox created. Retrying execution...'})}\n\n"
                async for chunk in _stream_sync_generator(
                    stream_code_execution, client, model, full_messages, sbx,
                    files_root=files_root or FILES_ROOT, extra_params=extra_params,
                    abort_event=abort_event, multimodal=multimodal, tab_id=tab_id,
                    figure_store=FIGURE_STORE,
                ):
                    yield chunk
            except Exception as retry_error:
                yield f"data: {json.dumps({'type': 'error', 'content': f'Failed to retry after timeout: {str(retry_error)}'})}\n\n"
        else:
            yield f"data: {json.dumps({'type': 'error', 'content': error_message})}\n\n"
async def stream_research_agent(
    messages: List[dict],
    endpoint: str,
    token: Optional[str],
    model: str,
    serper_key: str,
    sub_agent_model: Optional[str] = None,
    parallel_workers: Optional[int] = None,
    max_websites: Optional[int] = None,
    tab_id: str = "default",
    parent_agent_id: Optional[str] = None,
    sub_agent_endpoint: Optional[str] = None,
    sub_agent_token: Optional[str] = None,
    extra_params: Optional[Dict] = None,
    sub_agent_extra_params: Optional[Dict] = None
):
    """Handle research agent with web search.

    Registers an abort event, streams from the inner implementation, and
    guarantees the registration is removed on completion or error.
    """
    abort_event = register_agent(tab_id, parent_agent_id)
    try:
        stream = _stream_research_agent_inner(
            messages, endpoint, token, model, serper_key, sub_agent_model,
            parallel_workers, max_websites, tab_id, sub_agent_endpoint,
            sub_agent_token, extra_params, sub_agent_extra_params, abort_event,
        )
        async for piece in stream:
            yield piece
    finally:
        unregister_agent(tab_id)
async def _stream_research_agent_inner(messages, endpoint, token, model, serper_key, sub_agent_model, parallel_workers, max_websites, tab_id, sub_agent_endpoint, sub_agent_token, extra_params, sub_agent_extra_params, abort_event):
    """Core research loop: validate prerequisites, build LLM clients, and
    stream the blocking research pipeline through the SSE adapter."""
    if not RESEARCH_AVAILABLE:
        yield f"data: {json.dumps({'type': 'error', 'content': 'Research dependencies not available. Install with: pip install trafilatura requests'})}\n\n"
        return
    if not serper_key:
        yield f"data: {json.dumps({'type': 'error', 'content': 'Serper API key required for research. Please configure in settings.'})}\n\n"
        return
    try:
        # The research question is the most recent user message.
        question = messages[-1]['content'] if messages else ""
        if not question:
            yield f"data: {json.dumps({'type': 'error', 'content': 'No research question provided'})}\n\n"
            return

        client = OpenAI(base_url=endpoint, api_key=token)

        # A dedicated sub-agent client is only needed when its endpoint differs.
        sub_agent_client = None
        if sub_agent_endpoint and sub_agent_endpoint != endpoint:
            sub_agent_client = OpenAI(base_url=sub_agent_endpoint, api_key=sub_agent_token)

        # System prompt for research (includes the file tree).
        system_prompt = get_system_prompt("research")

        # Fall back to defaults when optional knobs are unset:
        # sub-agent model -> main model, workers -> 8, max websites -> 50.
        analysis_model = sub_agent_model or model
        workers = parallel_workers or 8
        max_sites = max_websites or 50

        async for chunk in _stream_sync_generator(
            stream_research, client, model, question, serper_key,
            max_websites=max_sites, system_prompt=system_prompt,
            sub_agent_model=analysis_model, parallel_workers=workers,
            sub_agent_client=sub_agent_client, extra_params=extra_params,
            sub_agent_extra_params=sub_agent_extra_params, abort_event=abort_event
        ):
            yield chunk
    except Exception as e:
        import traceback
        error_message = f"Research error: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_message)
        yield f"data: {json.dumps({'type': 'error', 'content': error_message})}\n\n"
async def stream_command_center_handler(
    messages: List[dict],
    endpoint: str,
    token: Optional[str],
    model: str,
    tab_id: str = "0",
    extra_params: Optional[Dict] = None,
    files_root: str = None,
):
    """Handle command center with tool-based agent launching.

    Registers an abort event for the tab, streams from the inner
    implementation, and always cleans up the registration.
    """
    abort_event = register_agent(tab_id)
    try:
        stream = _stream_command_center_inner(
            messages, endpoint, token, model, tab_id,
            extra_params, abort_event, files_root=files_root,
        )
        async for piece in stream:
            yield piece
    finally:
        unregister_agent(tab_id)
async def _stream_command_center_inner(messages, endpoint, token, model, tab_id, extra_params, abort_event, files_root=None):
    """Core command-center loop; degrades to plain chat when tools are unavailable."""
    if not COMMAND_AVAILABLE:
        # Tool handling missing: fall back to a plain streamed chat response.
        async for piece in stream_chat_response(messages, endpoint, token, model, "command", tab_id, extra_params):
            yield piece
        return
    try:
        client = OpenAI(base_url=endpoint, api_key=token)
        full_messages = [{"role": "system", "content": get_system_prompt("command")}] + messages
        async for piece in _stream_sync_generator(
            stream_command_center, client, model, full_messages,
            extra_params=extra_params, abort_event=abort_event,
            files_root=files_root or FILES_ROOT
        ):
            yield piece
    except Exception as e:
        import traceback
        error_message = f"Command center error: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_message)
        yield f"data: {json.dumps({'type': 'error', 'content': error_message})}\n\n"
async def stream_web_agent(
    messages: List[dict],
    endpoint: str,
    token: Optional[str],
    model: str,
    serper_key: str,
    tab_id: str = "default",
    parent_agent_id: Optional[str] = None,
    extra_params: Optional[Dict] = None,
    multimodal: bool = False
):
    """Handle web agent with tools (search, read, screenshot).

    Registers an abort event for cooperative cancellation, delegates to the
    inner implementation, and always cleans up the registration.
    """
    abort_event = register_agent(tab_id, parent_agent_id)
    try:
        stream = _stream_web_agent_inner(
            messages, endpoint, token, model, serper_key,
            tab_id, extra_params, abort_event, multimodal,
        )
        async for piece in stream:
            yield piece
    finally:
        unregister_agent(tab_id)
async def _stream_web_agent_inner(messages, endpoint, token, model, serper_key, tab_id, extra_params, abort_event, multimodal=False):
    """Core web-agent loop; degrades to plain chat when web tools are unavailable."""
    if not AGENT_AVAILABLE:
        # Web tools missing: fall back to a plain streamed chat response.
        async for piece in stream_chat_response(messages, endpoint, token, model, "agent", tab_id, extra_params):
            yield piece
        return
    try:
        client = OpenAI(base_url=endpoint, api_key=token)
        full_messages = [{"role": "system", "content": get_system_prompt("agent")}] + messages
        async for piece in _stream_sync_generator(
            stream_agent_execution, client, model, full_messages, serper_key,
            extra_params=extra_params, abort_event=abort_event, multimodal=multimodal
        ):
            yield piece
    except Exception as e:
        import traceback
        error_message = f"Agent error: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_message)
        yield f"data: {json.dumps({'type': 'error', 'content': error_message})}\n\n"
async def stream_image_agent(
    messages: List[dict],
    endpoint: str,
    token: Optional[str],
    model: str,
    hf_token: str,
    image_gen_model: Optional[str] = None,
    image_edit_model: Optional[str] = None,
    tab_id: str = "default",
    parent_agent_id: Optional[str] = None,
    extra_params: Optional[Dict] = None,
    files_root: str = None,
    multimodal: bool = False
):
    """Handle image agent with HuggingFace image generation tools.

    Registers an abort event for this tab, streams from the inner
    implementation, and always cleans up the registration.
    """
    abort_event = register_agent(tab_id, parent_agent_id)
    try:
        stream = _stream_image_agent_inner(
            messages, endpoint, token, model, hf_token,
            image_gen_model, image_edit_model, tab_id,
            extra_params, abort_event, files_root, multimodal,
        )
        async for piece in stream:
            yield piece
    finally:
        unregister_agent(tab_id)
async def _stream_image_agent_inner(messages, endpoint, token, model, hf_token, image_gen_model, image_edit_model, tab_id, extra_params, abort_event, files_root=None, multimodal=False):
    """Core image-agent loop.

    Validates prerequisites, streams the HuggingFace image tool executor
    through the SSE adapter, then resyncs this tab's figure counter from the
    shared figure store so numbering continues across requests.
    """
    if not IMAGE_AVAILABLE:
        yield f"data: {json.dumps({'type': 'error', 'content': 'Image agent not available. Install with: pip install huggingface_hub Pillow'})}\n\n"
        return
    if not hf_token:
        yield f"data: {json.dumps({'type': 'error', 'content': 'HuggingFace token required for image generation. Please configure in settings or set HF_TOKEN environment variable.'})}\n\n"
        return
    # Ensure per-tab counter exists
    if tab_id not in FIGURE_COUNTERS:
        FIGURE_COUNTERS[tab_id] = 0
    try:
        client = OpenAI(base_url=endpoint, api_key=token)
        system_prompt = get_system_prompt("image")
        full_messages = [{"role": "system", "content": system_prompt}] + messages
        async for chunk in _stream_sync_generator(
            stream_image_execution, client, model, full_messages, hf_token,
            image_gen_model=image_gen_model, image_edit_model=image_edit_model,
            extra_params=extra_params, abort_event=abort_event,
            files_root=files_root, multimodal=multimodal,
            tab_id=tab_id,
            image_store=FIGURE_STORE,
            image_counter=FIGURE_COUNTERS[tab_id],
        ):
            yield chunk
        # Derive counter from store keys for this tab's prefix
        # (keys are namespaced "figure_T{tab}_{N}"; take the max N seen).
        prefix = f"figure_T{tab_id}_"
        max_counter = 0
        for name in FIGURE_STORE:
            if name.startswith(prefix):
                m = re.search(r'_(\d+)$', name)
                if m:
                    max_counter = max(max_counter, int(m.group(1)))
        FIGURE_COUNTERS[tab_id] = max_counter
    except Exception as e:
        import traceback
        error_message = f"Image agent error: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_message)
        yield f"data: {json.dumps({'type': 'error', 'content': error_message})}\n\n"
async def stream_chat_response(
    messages: List[dict],
    endpoint: str,
    token: Optional[str],
    model: str,
    agent_type: str,
    tab_id: str = "default",
    extra_params: Optional[Dict] = None
):
    """Proxy stream from user's configured LLM endpoint.

    Sends an OpenAI-compatible /chat/completions streaming request and
    re-emits each content delta as this app's own SSE event format
    ({'type': 'content'|'done'|'error', ...}).

    Yields:
        str: SSE "data:" lines for the frontend.
    """
    try:
        logger.info(f"Stream request: endpoint={endpoint}, model={model}, messages={len(messages)}, token={'yes' if token else 'no'}")
        # Prepare messages with appropriate system prompt based on agent type (with file tree)
        system_prompt = get_system_prompt(agent_type)
        full_messages = [
            {"role": "system", "content": system_prompt}
        ] + messages
        # Handle Hugging Face endpoint with fallback to HF_TOKEN
        if not token and "huggingface.co" in endpoint:
            token = os.getenv("HF_TOKEN")
            if token:
                logger.debug("Using HF_TOKEN from environment for Hugging Face endpoint")
            else:
                logger.warning("No token provided and HF_TOKEN not found in environment!")
        # Prepare headers
        headers = {
            "Content-Type": "application/json"
        }
        if token:
            headers["Authorization"] = f"Bearer {token}"
        # Prepare request body (OpenAI-compatible format)
        request_body = {
            "model": model,
            "messages": full_messages,
            "stream": True,
            "temperature": 0.7
        }
        # Apply any extra params (e.g., enable_thinking)
        if extra_params:
            request_body.update(extra_params)
        logger.debug(f"Sending request to: {endpoint}/chat/completions")
        # Make streaming request to user's endpoint
        async with httpx.AsyncClient(timeout=60.0) as client:
            async with client.stream(
                "POST",
                f"{endpoint}/chat/completions",
                json=request_body,
                headers=headers
            ) as response:
                if response.status_code != 200:
                    # Must read the body explicitly on a streaming response.
                    error_text = await response.aread()
                    error_detail = error_text.decode() if error_text else ""
                    # Try to extract JSON error message; fall back to short status text
                    try:
                        error_detail = json.loads(error_detail).get("error", {}).get("message", error_detail)
                    except (json.JSONDecodeError, AttributeError):
                        pass
                    # HTML error pages are useless to the user — show the status only.
                    if "<html" in error_detail.lower():
                        error_detail = f"Status {response.status_code}"
                    error_message = f"LLM API error ({response.status_code}): {error_detail}"
                    logger.error(f"LLM API error: {error_message}")
                    yield f"data: {json.dumps({'type': 'error', 'content': error_message})}\n\n"
                    return
                # Stream the response line by line (SSE from the upstream LLM)
                async for line in response.aiter_lines():
                    if line.startswith("data: "):
                        data_str = line[6:]  # Remove "data: " prefix
                        if data_str.strip() == "[DONE]":
                            yield f"data: {json.dumps({'type': 'done'})}\n\n"
                            continue
                        try:
                            data = json.loads(data_str)
                            # Extract content from OpenAI-compatible response
                            if "choices" in data and len(data["choices"]) > 0:
                                delta = data["choices"][0].get("delta", {})
                                content = delta.get("content")
                                if content:
                                    # Forward the content token
                                    yield f"data: {json.dumps({'type': 'content', 'content': content})}\n\n"
                        except json.JSONDecodeError:
                            # Skip malformed JSON
                            continue
        # Always send done event after stream completes (in case LLM didn't send [DONE])
        yield f"data: {json.dumps({'type': 'done'})}\n\n"
    except httpx.RequestError as e:
        error_message = f"Connection error to LLM endpoint: {str(e)}"
        logger.error(f"HTTP Request Error: {e}")
        yield f"data: {json.dumps({'type': 'error', 'content': error_message})}\n\n"
    except Exception as e:
        import traceback
        error_message = f"Error: {str(e) or 'Unknown error occurred'}"
        logger.error(f"Exception in stream_chat_response: {e}\n{traceback.format_exc()}")
        yield f"data: {json.dumps({'type': 'error', 'content': error_message})}\n\n"
# NOTE: Root "/" endpoint now serves index.html (see serve_index at end of file)
# API info is available at /api/info instead
async def api_info():
    """Return basic metadata describing this API server."""
    endpoints = {
        "/api/chat/stream": "POST - Proxy streaming chat to user's LLM endpoint"
    }
    return {
        "message": "AgentUI API - LLM Proxy Server",
        "version": "1.0.0",
        "endpoints": endpoints,
    }
async def get_agents():
    """Return agent type registry for frontend consumption."""
    registry = get_registry_for_frontend()
    return {"agents": registry}
async def generate_title(request: TitleRequest):
    """Generate a short 2-3 word title for a user query.

    Calls the user's configured LLM endpoint with a title-generation prompt
    and strips stray quotes from the model's answer.

    Raises:
        HTTPException: With the upstream status code when the LLM call fails,
            or 500 for unexpected errors.
    """
    try:
        # Create headers
        headers = {"Content-Type": "application/json"}
        if request.token:
            headers["Authorization"] = f"Bearer {request.token}"
        # Call the LLM to generate a title
        async with httpx.AsyncClient(timeout=30.0) as client:
            llm_response = await client.post(
                f"{request.endpoint}/chat/completions",
                headers=headers,
                json={
                    "model": request.model,
                    "messages": [
                        {
                            "role": "system",
                            "content": "You are a helpful assistant that generates concise 2-3 word titles for user queries. Respond with ONLY the title, no additional text, punctuation, or quotes."
                        },
                        {
                            "role": "user",
                            "content": f"Generate a 2-3 word title for this query: {request.query}"
                        }
                    ],
                    "temperature": 0.3,
                    "max_tokens": 20
                }
            )
            if llm_response.status_code != 200:
                raise HTTPException(status_code=llm_response.status_code, detail="LLM API error")
            result = llm_response.json()
            title = result["choices"][0]["message"]["content"].strip()
            # Remove any quotes that might be in the response
            title = title.replace('"', '').replace("'", '')
            return {"title": title}
    except HTTPException:
        # FIX: re-raise HTTPException untouched — previously the broad handler
        # below flattened the upstream status code into a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
async def abort_agent(raw_request: Request, request: AbortRequest):
    """Abort a running agent, together with every child it spawned."""
    user_id = get_user_id(raw_request)
    scoped_agent_id = user_key(user_id, request.agent_id)
    aborted = abort_agent_tree(scoped_agent_id)
    logger.info(f"Aborted agents: {aborted}")
    return {"aborted": aborted}
async def chat_stream(raw_request: Request, request: ChatRequest):
    """Proxy streaming chat to user's configured LLM endpoint.

    Dispatches on ``request.agent_type`` to the matching handler
    ("code", "research", "image", "agent", "command"), falling back to
    plain chat streaming. Every branch returns an SSE StreamingResponse.

    Raises:
        HTTPException: 400 when messages or the LLM endpoint are missing.
    """
    logger.debug(f"Chat stream request: agent_type={request.agent_type}")
    if not request.messages:
        raise HTTPException(status_code=400, detail="Messages are required")
    if not request.endpoint:
        raise HTTPException(status_code=400, detail="LLM endpoint is required")

    def _sse(generator):
        # Shared SSE wrapper: the same headers previously repeated in every
        # routing branch below.
        return StreamingResponse(
            generator,
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",  # Disable nginx buffering
            },
        )

    # Multi-user isolation
    user_id = get_user_id(raw_request)
    files_root = get_user_files_root(user_id)

    # Convert Pydantic models to dicts, preserving tool call fields
    messages = []
    for msg in request.messages:
        m = {"role": msg.role, "content": msg.content}
        if msg.tool_call_id is not None:
            m["tool_call_id"] = msg.tool_call_id
        if msg.tool_calls is not None:
            m["tool_calls"] = msg.tool_calls
        messages.append(m)

    # Tab id for debugging (prefixed with user_id for dict isolation)
    tab_id = request.agent_id or "0"
    uk_tab_id = user_key(user_id, tab_id)
    uk_parent = user_key(user_id, request.parent_agent_id) if request.parent_agent_id else None

    # Convert frontend_context to dict if provided
    frontend_context = request.frontend_context.model_dump() if request.frontend_context else None

    # Apply environment variable fallbacks for API keys
    e2b_key = get_env_fallback(request.e2b_key, "E2B_API_KEY")
    serper_key = get_env_fallback(request.serper_key, "SERPER_API_KEY")
    hf_token = get_env_fallback(request.hf_token, "HF_TOKEN")
    token = get_env_fallback(request.token, "LLM_API_KEY")
    # For image generation: fall back to the LLM provider token (often the same HF token)
    if not hf_token:
        hf_token = token

    model = request.model or "gpt-4"

    # Route to code execution handler
    if request.agent_type == "code":
        # Use user-prefixed agent_id as session key for sandbox isolation
        session_id = user_key(user_id, request.agent_id or "default")
        return _sse(stream_code_agent(
            messages,
            request.endpoint,
            token,
            model,
            e2b_key or "",
            session_id,
            uk_tab_id,
            uk_parent,
            frontend_context,
            request.extra_params,
            files_root=files_root,
            multimodal=request.multimodal,
        ))

    # Route to research handler
    if request.agent_type == "research":
        # Use sub-agent endpoint/token if provided, otherwise fall back to main
        sub_agent_endpoint = request.research_sub_agent_endpoint or request.endpoint
        sub_agent_token = request.research_sub_agent_token if request.research_sub_agent_endpoint else token
        return _sse(stream_research_agent(
            messages,
            request.endpoint,
            token,
            model,
            serper_key or "",
            request.research_sub_agent_model,
            request.research_parallel_workers,
            None,
            uk_tab_id,
            uk_parent,
            sub_agent_endpoint,
            sub_agent_token,
            request.extra_params,
            request.research_sub_agent_extra_params,
        ))

    # Route to image handler with HuggingFace tools
    if request.agent_type == "image":
        return _sse(stream_image_agent(
            messages,
            request.endpoint,
            token,
            model,
            hf_token or "",
            request.image_gen_model,
            request.image_edit_model,
            uk_tab_id,
            uk_parent,
            request.extra_params,
            files_root=files_root,
            multimodal=request.multimodal,
        ))

    # Route to agent handler with web tools
    if request.agent_type == "agent":
        return _sse(stream_web_agent(
            messages,
            request.endpoint,
            token,
            model,
            serper_key or "",
            uk_tab_id,
            uk_parent,
            request.extra_params,
            multimodal=request.multimodal,
        ))

    # Route to command center handler (with tool-based launching)
    if request.agent_type == "command":
        return _sse(stream_command_center_handler(
            messages,
            request.endpoint,
            token,
            model,
            uk_tab_id,
            request.extra_params,
            files_root=files_root,
        ))

    # Regular chat for other agent types
    return _sse(stream_chat_response(
        messages,
        request.endpoint,
        token,
        model,
        request.agent_type,
        uk_tab_id,
        request.extra_params,
    ))
async def start_sandbox(raw_request: Request, request: SandboxRequest):
    """Start (or reuse) an E2B sandbox for a code agent session.

    Returns a {"success": bool, ...} dict rather than raising, so the
    frontend can surface sandbox problems inline.
    """
    user_id = get_user_id(raw_request)
    session_id = user_key(user_id, request.session_id)
    e2b_key = request.e2b_key
    if not E2B_AVAILABLE:
        return {"success": False, "error": "E2B not available. Install with: pip install e2b-code-interpreter"}
    if not e2b_key:
        return {"success": False, "error": "E2B API key required"}
    if not request.session_id:
        return {"success": False, "error": "Session ID required"}
    try:
        # Check if sandbox already exists and is alive
        if session_id in SANDBOXES:
            try:
                sbx = SANDBOXES[session_id]
                # Cheap liveness probe: any response means the sandbox is usable
                sbx.run_code("1+1")
                return {"success": True, "message": "Sandbox already running"}
            # Bug fix: bare "except:" also swallowed KeyboardInterrupt/SystemExit
            except Exception:
                # Sandbox is dead: best-effort kill, then drop it and recreate
                try:
                    SANDBOXES[session_id].kill()
                except Exception:
                    pass
                del SANDBOXES[session_id]
        # Create new sandbox
        os.environ["E2B_API_KEY"] = e2b_key
        sbx = Sandbox.create(timeout=SANDBOX_TIMEOUT)
        # Suppress noisy warnings (matplotlib, deprecation, etc.)
        sbx.run_code("import warnings; warnings.filterwarnings('ignore'); import logging; logging.disable(logging.WARNING)")
        SANDBOXES[session_id] = sbx
        return {"success": True, "message": "Sandbox started successfully"}
    except Exception as e:
        return {"success": False, "error": f"Failed to start sandbox: {str(e)}"}
async def stop_sandbox(raw_request: Request, request: SandboxStopRequest):
    """Stop and discard the sandbox for a code agent session.

    Bug fix: the entry is removed from SANDBOXES *before* kill() so that a
    sandbox whose kill fails (already dead, network error) doesn't stay in
    the registry and make every subsequent stop fail the same way.
    """
    user_id = get_user_id(raw_request)
    session_id = user_key(user_id, request.session_id)
    if not request.session_id:
        return {"success": False, "error": "Session ID required"}
    sbx = SANDBOXES.pop(session_id, None)
    if sbx is None:
        return {"success": True, "message": "No sandbox found for this session"}
    try:
        sbx.kill()
    except Exception as e:
        return {"success": False, "error": f"Failed to stop sandbox: {str(e)}"}
    return {"success": True, "message": "Sandbox stopped"}
async def add_tool_response(raw_request: Request, request: dict):
    """Append a tool-role message to a tab's conversation history.

    Body: {"tab_id": ..., "tool_call_id": ..., "content": ...}. Called when a
    launched agent returns a result that must appear in the parent's history.
    """
    user_id = get_user_id(raw_request)
    tab_id = request.get("tab_id", "0")
    uk_tab_id = user_key(user_id, tab_id)
    tool_call_id = request.get("tool_call_id")
    content = request.get("content")
    # Bug fix: "not content" rejected a legitimate empty-string tool result;
    # only a *missing* content field is an error.
    if not tool_call_id or content is None:
        return {"success": False, "error": "tool_call_id and content are required"}
    # setdefault initializes this tab's history on first use
    # (no `global` needed: the dict is mutated, never rebound)
    CONVERSATION_HISTORY.setdefault(uk_tab_id, []).append({
        "role": "tool",
        "tool_call_id": tool_call_id,
        "content": content
    })
    return {"success": True}
async def get_debug_messages(request: Request, tab_id: str):
    """Legacy debug endpoint — always returns an empty call list.

    Debug data is now streamed via SSE events (debug_call_input/output) and
    kept in the frontend; this stub exists purely for backward compatibility.
    """
    return {"calls": []}
async def health():
    """Liveness probe for load balancers and deployment checks."""
    return {"status": "healthy"}
| # File paths - use ~/.config/agentui/ by default (cross-platform standard) | |
| # Falls back to ~/.config/productive/ for backward compatibility | |
| # These can be overridden via command-line arguments or set_*_file functions | |
def get_default_config_dir():
    """Resolve the config directory, honoring $XDG_CONFIG_HOME.

    Prefers <config_home>/agentui; the legacy <config_home>/productive
    directory is used only when it is the sole existing installation
    (i.e. it exists and the new directory does not).
    """
    config_home = os.environ.get(
        "XDG_CONFIG_HOME", os.path.join(os.path.expanduser("~"), ".config")
    )
    preferred = os.path.join(config_home, "agentui")
    legacy = os.path.join(config_home, "productive")
    # Legacy wins only for pre-rename installs that were never migrated.
    if os.path.exists(legacy) and not os.path.exists(preferred):
        return legacy
    return preferred
# Module-level path configuration, resolved at import time.
# Overridable via the set_*() helpers below and the CLI flags in start().
CONFIG_DIR = get_default_config_dir()
os.makedirs(CONFIG_DIR, exist_ok=True)
SETTINGS_FILE = os.path.join(CONFIG_DIR, "settings.json")
FILES_ROOT = os.getcwd()  # Root directory for file tree (current working directory)
SESSIONS_ROOT = os.path.join(FILES_ROOT, "sessions")  # Sessions are stored in sessions/ subfolder
CURRENT_SESSION = None  # Name of the current session (None = no session selected)
WORKSPACE_FILE = None  # Set when session is selected
# Directories/patterns to exclude from file tree
# (build artifacts, virtualenvs, VCS metadata, and server-managed folders)
FILES_EXCLUDE = {
    'node_modules', '__pycache__', '.git', '.pytest_cache',
    'env', 'venv', 'env312', '.venv', 'dist', 'build',
    '.egg-info', '.tox', '.coverage', 'htmlcov',
    'test-results', 'playwright-report', 'sessions', 'users'
}
def set_settings_file(path: str):
    """Point the server at an alternate settings.json (testing hook)."""
    global SETTINGS_FILE
    SETTINGS_FILE = path
def set_workspace_file(path: str):
    """Point the server at an alternate workspace.json (testing hook)."""
    global WORKSPACE_FILE
    WORKSPACE_FILE = path
def set_config_dir(directory: str):
    """Relocate the config directory (creating it) and its settings.json."""
    global SETTINGS_FILE, CONFIG_DIR
    os.makedirs(directory, exist_ok=True)
    CONFIG_DIR = directory
    SETTINGS_FILE = os.path.join(directory, "settings.json")
def set_files_root(directory: str):
    """Change the root directory used for file-tree browsing."""
    global FILES_ROOT
    FILES_ROOT = directory
async def get_config():
    """Expose server configuration flags (currently just multi-user mode)."""
    return {"multiUser": MULTI_USER}
async def check_user_exists(username: str):
    """Report whether a per-user directory already exists (multi-user mode only)."""
    if not MULTI_USER or not USERS_ROOT:
        # Single-user mode has no user directories at all.
        return {"exists": False}
    return {"exists": os.path.isdir(os.path.join(USERS_ROOT, username))}
async def get_settings(request: Request):
    """Load the user's settings.json, migrating legacy keys.

    Returns built-in defaults when no settings file exists yet; annotates
    the loaded dict with the path it came from (_settingsPath).
    """
    user_id = get_user_id(request)
    settings_file = get_user_settings_file(user_id)
    try:
        if not os.path.exists(settings_file):
            # Fresh install: sensible defaults
            return {
                "endpoint": "https://api.openai.com/v1",
                "token": "",
                "model": "gpt-4"
            }
        with open(settings_file, "r") as f:
            settings = json.load(f)
        # Migrate old "notebooks" key to "agents"
        if "notebooks" in settings and "agents" not in settings:
            settings["agents"] = settings.pop("notebooks")
        settings["_settingsPath"] = settings_file
        return settings
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to read settings: {str(e)}")
async def save_settings(request: Request, settings: dict):
    """Persist the user's settings dict to their settings.json."""
    user_id = get_user_id(request)
    settings_file = get_user_settings_file(user_id)
    try:
        with open(settings_file, "w") as f:
            json.dump(settings, f, indent=2)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to save settings: {str(e)}")
    return {"success": True}
| # ============================================ | |
| # Session Management | |
| # ============================================ | |
def get_session_path(session_name: str, sessions_root: str = None) -> str:
    """Return the folder path for *session_name* under the sessions root."""
    root = sessions_root if sessions_root else SESSIONS_ROOT
    return os.path.join(root, session_name)
def list_sessions(sessions_root: str = None) -> list:
    """Enumerate sessions (folders holding a workspace.json), newest first.

    Each entry is {"name": folder_name, "modified": workspace mtime}.
    A missing sessions root yields an empty list.
    """
    root = sessions_root or SESSIONS_ROOT
    found = []
    if os.path.exists(root):
        for entry in os.listdir(root):
            entry_path = os.path.join(root, entry)
            workspace = os.path.join(entry_path, "workspace.json")
            # Only directories that actually contain a workspace count.
            if os.path.isdir(entry_path) and os.path.exists(workspace):
                found.append({
                    "name": entry,
                    "modified": os.path.getmtime(workspace),
                })
    # Most recently modified first
    found.sort(key=lambda s: s["modified"], reverse=True)
    return found
def select_session(session_name: str, user_id: str = '') -> bool:
    """Select a session and update paths (per-user in multi-user mode).

    Returns False when the session folder does not exist. On success,
    updates the per-user current-session and workspace-file pointers and
    clears cached backend state (conversation history, figure store and
    counters) so the new session starts without leftovers.
    """
    sessions_root = get_user_sessions_root(user_id)
    session_path = get_session_path(session_name, sessions_root)
    workspace_file = os.path.join(session_path, "workspace.json")
    if not os.path.exists(session_path):
        return False
    set_user_current_session(user_id, session_name)
    set_user_workspace_file(user_id, workspace_file)
    # FILES_ROOT stays at the original working directory (not session-specific)
    # Clear backend state when switching sessions
    # In multi-user mode, only clear keys belonging to this user
    if MULTI_USER and user_id:
        prefix = f"{user_id}:"
        keys_to_remove = [k for k in CONVERSATION_HISTORY if k.startswith(prefix)]
        for k in keys_to_remove:
            del CONVERSATION_HISTORY[k]
        # Clear figure store entries belonging to this user's tabs
        # (figure keys carry a "figure_T<user>:" prefix, counters just "<user>:")
        for k in [k for k in FIGURE_STORE if k.startswith(f"figure_T{prefix}")]:
            del FIGURE_STORE[k]
        for k in [k for k in FIGURE_COUNTERS if k.startswith(prefix)]:
            del FIGURE_COUNTERS[k]
    else:
        # Single-user mode: wipe all cached state unconditionally
        CONVERSATION_HISTORY.clear()
        FIGURE_STORE.clear()
        FIGURE_COUNTERS.clear()
    return True
def create_session(session_name: str, sessions_root: str = None) -> bool:
    """Create a session folder seeded with a default workspace.json.

    Returns False (leaving disk untouched) when the session already exists.
    """
    session_path = get_session_path(session_name, sessions_root)
    if os.path.exists(session_path):
        return False  # Session already exists
    os.makedirs(session_path, exist_ok=True)
    # Seed the session with the default workspace layout
    payload = json.dumps(get_default_workspace(), indent=2)
    with open(os.path.join(session_path, "workspace.json"), "w") as f:
        f.write(payload)
    return True
async def api_random_session_name():
    """Pick a random isotope name (e.g. "Fe-56") for a new session.

    Samples in two stages — element first, then one of its isotopes — so
    every element is equally likely regardless of its isotope count.
    """
    import random
    try:
        from .defaultnames import ISOTOPES
    except ImportError:
        from defaultnames import ISOTOPES
    element = random.choice(list(ISOTOPES))
    isotope = random.choice(ISOTOPES[element])
    return {"name": f"{element}-{isotope}"}
async def api_list_sessions(request: Request):
    """List this user's sessions plus which one is currently selected."""
    user_id = get_user_id(request)
    sessions_root = get_user_sessions_root(user_id)
    payload = {
        "sessions": list_sessions(sessions_root),
        "current": get_user_current_session(user_id),
        "sessionsRoot": sessions_root,
    }
    return payload
async def api_create_session(request: Request, data: dict):
    """Create a new session and immediately make it the active one."""
    user_id = get_user_id(request)
    sessions_root = get_user_sessions_root(user_id)
    name = data.get("name", "").strip()
    if not name:
        raise HTTPException(status_code=400, detail="Session name is required")
    # Keep only filesystem-safe characters (alphanumerics, dash, space, underscore)
    safe_name = "".join(ch for ch in name if ch.isalnum() or ch in "- _").strip()
    if not safe_name:
        raise HTTPException(status_code=400, detail="Invalid session name")
    if not create_session(safe_name, sessions_root):
        raise HTTPException(status_code=409, detail="Session already exists")
    select_session(safe_name, user_id)  # auto-activate the newly created session
    return {"success": True, "name": safe_name}
async def api_select_session(request: Request, data: dict):
    """Select an existing session.

    Security fix: the name was previously passed unsanitized into a path
    join, so a crafted name ("../users/other") could point select_session
    outside the sessions root.
    """
    user_id = get_user_id(request)
    name = data.get("name", "").strip()
    if not name:
        raise HTTPException(status_code=400, detail="Session name is required")
    # Reject path separators and parent references (path-traversal guard).
    if os.sep in name or (os.altsep and os.altsep in name) or name in (".", ".."):
        raise HTTPException(status_code=400, detail="Invalid session name")
    if not select_session(name, user_id):
        raise HTTPException(status_code=404, detail="Session not found")
    return {"success": True, "name": name}
async def api_rename_session(request: Request, data: dict):
    """Rename a session folder, updating per-user pointers if it was active.

    Security fix: oldName was used unsanitized in a path join while only
    newName was sanitized, allowing a traversal rename source like "../x".
    """
    user_id = get_user_id(request)
    sessions_root = get_user_sessions_root(user_id)
    old_name = data.get("oldName", "").strip()
    new_name = data.get("newName", "").strip()
    if not old_name or not new_name:
        raise HTTPException(status_code=400, detail="Both oldName and newName are required")
    # Path-traversal guard on the (otherwise unsanitized) old name.
    if os.sep in old_name or (os.altsep and os.altsep in old_name) or old_name in (".", ".."):
        raise HTTPException(status_code=400, detail="Invalid session name")
    # Sanitize new name for the filesystem
    safe_new_name = "".join(c for c in new_name if c.isalnum() or c in "- _").strip()
    if not safe_new_name:
        raise HTTPException(status_code=400, detail="Invalid new session name")
    old_path = get_session_path(old_name, sessions_root)
    new_path = get_session_path(safe_new_name, sessions_root)
    if not os.path.exists(old_path):
        raise HTTPException(status_code=404, detail="Session not found")
    if os.path.exists(new_path):
        raise HTTPException(status_code=409, detail="A session with that name already exists")
    os.rename(old_path, new_path)
    # Keep the per-user current-session pointer in sync if the active session moved
    if get_user_current_session(user_id) == old_name:
        set_user_current_session(user_id, safe_new_name)
        set_user_workspace_file(user_id, os.path.join(new_path, "workspace.json"))
    return {"success": True, "name": safe_new_name}
async def api_delete_session(request: Request, session_name: str):
    """Delete a (non-active) session folder.

    Security fix: session_name was joined into a path and passed straight to
    shutil.rmtree, so "../<dir>" could delete arbitrary directories. The
    resolved path must now stay inside the user's sessions root.
    """
    import shutil
    user_id = get_user_id(request)
    sessions_root = get_user_sessions_root(user_id)
    if not session_name:
        raise HTTPException(status_code=400, detail="Session name is required")
    # Path-traversal guard: resolved path must be strictly inside sessions_root.
    session_path = os.path.normpath(get_session_path(session_name, sessions_root))
    root = os.path.normpath(sessions_root)
    if not session_path.startswith(root + os.sep):
        raise HTTPException(status_code=400, detail="Invalid session name")
    if not os.path.exists(session_path):
        raise HTTPException(status_code=404, detail="Session not found")
    # Don't allow deleting the current session
    if get_user_current_session(user_id) == session_name:
        raise HTTPException(status_code=400, detail="Cannot delete the currently active session")
    # Delete the session folder
    shutil.rmtree(session_path)
    return {"success": True}
| # ============================================ | |
| # Workspace State Persistence | |
| # ============================================ | |
def get_default_workspace():
    """Return the initial workspace state: a single command-center tab."""
    command_tab = {
        "id": 0,
        "type": "command-center",
        "title": "COMMAND",
        "messages": []
    }
    return {
        "version": 1,
        "tabCounter": 1,
        "activeTabId": 0,
        "agentCounters": get_default_counters(),
        "tabs": [command_tab],
    }
async def get_workspace(request: Request):
    """Load the active session's workspace.json (defaults when absent)."""
    user_id = get_user_id(request)
    workspace_file = get_user_workspace_file(user_id)
    if workspace_file is None:
        raise HTTPException(status_code=400, detail="No session selected")
    if not os.path.exists(workspace_file):
        # Session exists but was never saved: hand back the default layout
        return get_default_workspace()
    try:
        with open(workspace_file, "r") as f:
            return json.load(f)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to read workspace: {str(e)}")
async def save_workspace(request: Request, workspace: dict):
    """Persist workspace state for the active session."""
    user_id = get_user_id(request)
    workspace_file = get_user_workspace_file(user_id)
    if workspace_file is None:
        raise HTTPException(status_code=400, detail="No session selected")
    try:
        with open(workspace_file, "w") as f:
            json.dump(workspace, f, indent=2)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to save workspace: {str(e)}")
    return {"success": True}
async def clear_workspace(request: Request):
    """Reset the active session's workspace to the default layout."""
    user_id = get_user_id(request)
    workspace_file = get_user_workspace_file(user_id)
    if workspace_file is None:
        raise HTTPException(status_code=400, detail="No session selected")
    fresh = get_default_workspace()
    try:
        with open(workspace_file, "w") as f:
            json.dump(fresh, f, indent=2)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to clear workspace: {str(e)}")
    return {"success": True, "workspace": fresh}
| # ============================================ | |
| # File Tree API | |
| # ============================================ | |
def build_file_tree(root_path: str, show_hidden: bool = False, files_root: str = None) -> list:
    """Recursively describe *root_path* as a list of file/folder dicts.

    Each node has "name", "type" ("file"/"folder") and "path" relative to
    *files_root* (defaults to the module-level FILES_ROOT); folders also
    carry a "children" list. Hidden entries (unless show_hidden) and names
    in FILES_EXCLUDE are skipped; unreadable directories yield [].
    """
    base_root = files_root or FILES_ROOT
    try:
        names = sorted(os.listdir(root_path))
    except PermissionError:
        return []
    nodes = []
    for name in names:
        # Skip hidden entries and excluded directory names
        if (name.startswith('.') and not show_hidden) or name in FILES_EXCLUDE:
            continue
        abs_path = os.path.join(root_path, name)
        node = {
            "name": name,
            "type": "folder" if os.path.isdir(abs_path) else "file",
            "path": os.path.relpath(abs_path, base_root),
        }
        if node["type"] == "folder":
            node["children"] = build_file_tree(abs_path, show_hidden, base_root)
        nodes.append(node)
    return nodes
def format_file_tree_text(tree: list, prefix: str = "", is_last: bool = True) -> str:
    """Render a file tree (from build_file_tree) as box-drawing text.

    *prefix* is the accumulated indentation for nested levels; *is_last*
    is kept for signature compatibility (it is forwarded recursively but
    never read).
    """
    rendered = []
    total = len(tree)
    for idx, node in enumerate(tree):
        last = idx == total - 1
        branch = "└── " if last else "├── "
        suffix = "/" if node['type'] == 'folder' else ""
        rendered.append(f"{prefix}{branch}{node['name']}{suffix}")
        # Recurse into non-empty folders with a deepened prefix
        if node['type'] == 'folder' and node.get('children'):
            deeper = prefix + (" " if last else "│ ")
            subtree = format_file_tree_text(node['children'], deeper, last)
            if subtree:
                rendered.append(subtree)
    return "\n".join(rendered)
def get_file_tree_for_prompt() -> str:
    """Render the working-directory tree as text for system prompts."""
    tree_text = format_file_tree_text(build_file_tree(FILES_ROOT, show_hidden=False))
    return f"Working Directory: {FILES_ROOT}\n{tree_text}"
def get_styling_context(theme: Optional[Dict] = None) -> str:
    """Build plot-styling guidance for code agents, optionally theme-aware.

    Without a theme dict, returns only the base aesthetic guidance; with
    one, appends the theme's colors (with defaults for missing keys).
    """
    # Base aesthetic guidance shared by every theme
    guidance = """## Visual Style Guidelines
The application has a minimalist, technical aesthetic with clean lines and muted colors. When generating plots or visualizations:
- Use white/light backgrounds to match the notebook style
- Prefer clean, simple chart styles without excessive decoration
- Use the theme accent color as the primary color for data series
- Use neutral grays (#666, #999, #ccc) for secondary elements, gridlines, and text
- Use 300 DPI for all figures unless the user specifies otherwise (e.g., plt.figure(figsize=..., dpi=300) or plt.savefig(..., dpi=300))"""
    if not theme:
        return guidance
    # Pull theme colors, defaulting to the "forest" palette
    name = theme.get('name', 'forest')
    accent = theme.get('accent', '#1b5e20')
    bg = theme.get('bg', '#e8f5e9')
    bg_primary = theme.get('bgPrimary', '#ffffff')
    text_primary = theme.get('textPrimary', '#1a1a1a')
    text_secondary = theme.get('textSecondary', '#666666')
    return guidance + f"""
Current theme: {name}
- Primary/accent color: {accent} (use for main data series, highlights)
- Light background: {bg} (use for fills, light accents)
- Chart background color: {bg_primary} (use for figure and axes facecolor)
- Text color: {text_primary} (use for titles, labels, tick labels)
- Secondary text color: {text_secondary} (use for gridlines, secondary labels)
- Set fig.patch.set_facecolor('{bg_primary}') and ax.set_facecolor('{bg_primary}') for all plots"""
def get_system_prompt(agent_type: str, frontend_context: Optional[Dict] = None) -> str:
    """Assemble the full system prompt: base agent prompt + file tree (+ styling).

    Falls back to the "command" prompt for unknown agent types; code agents
    additionally receive theme-aware plot styling guidance.
    """
    from .agents import get_system_prompt as _get_agent_prompt
    base_prompt = _get_agent_prompt(agent_type) or _get_agent_prompt("command")
    sections = [base_prompt, f"## Project Files\n{get_file_tree_for_prompt()}"]
    if agent_type == "code" and frontend_context:
        sections.append(get_styling_context(frontend_context.get('theme')))
    return "\n\n".join(sections)
async def get_file_tree(request: Request, show_hidden: bool = False):
    """Return the user's workspace file tree as nested dicts."""
    user_id = get_user_id(request)
    files_root = get_user_files_root(user_id)
    try:
        return {
            "root": files_root,
            "tree": build_file_tree(files_root, show_hidden, files_root),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to read file tree: {str(e)}")
async def check_file_paths(request: Request):
    """Check which of the posted relative paths exist inside the workspace.

    Body: {"paths": [...]}; returns {"existing": [subset of paths]}.
    Security fix: the previous plain startswith(root) containment check
    accepted sibling directories sharing the prefix (e.g. "/root-evil"
    passes for root "/root"); the check now requires a separator boundary.
    """
    user_id = get_user_id(request)
    files_root = get_user_files_root(user_id)
    data = await request.json()
    paths = data.get("paths", [])
    existing = []
    root = os.path.normpath(files_root)
    for p in paths:
        full = os.path.normpath(os.path.join(files_root, p))
        inside = full == root or full.startswith(root + os.sep)
        if inside and os.path.exists(full):
            existing.append(p)
    return {"existing": existing}
async def download_file(request: Request, path: str = Query(..., description="Relative path to file"), session_id: str = Query(None, description="Session ID for multi-user auth (used by window.open)")):
    """Download a file from the workspace to the browser.

    Raises:
        HTTPException: 403 when the resolved path escapes the user's files
            root, 404 when the target is not a regular file.
    """
    user_id = get_user_id(request)
    files_root = get_user_files_root(user_id)
    root = os.path.normpath(files_root)
    full_path = os.path.normpath(os.path.join(files_root, path))
    # Security fix: require a separator boundary so sibling directories that
    # merely share the prefix (e.g. "/data-x" vs root "/data") are rejected.
    if full_path != root and not full_path.startswith(root + os.sep):
        raise HTTPException(status_code=403, detail="Access denied")
    if not os.path.isfile(full_path):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(full_path, filename=os.path.basename(full_path))
async def upload_file(request: Request, file: UploadFile = File(...), folder: str = Query("", description="Relative folder path")):
    """Upload a file from the browser into the workspace (optionally in *folder*).

    Security fixes: the containment check now requires a separator boundary
    (plain startswith accepted sibling dirs like "<root>-evil"), and any
    client-supplied directory components are stripped from the filename so
    it cannot escape the target directory.
    """
    user_id = get_user_id(request)
    files_root = get_user_files_root(user_id)
    root = os.path.normpath(files_root)
    target_dir = os.path.normpath(os.path.join(files_root, folder)) if folder else files_root
    norm_target = os.path.normpath(target_dir)
    if norm_target != root and not norm_target.startswith(root + os.sep):
        raise HTTPException(status_code=403, detail="Access denied")
    if not os.path.isdir(target_dir):
        raise HTTPException(status_code=404, detail="Folder not found")
    # basename() drops any path components in the client-controlled filename
    safe_name = os.path.basename(file.filename or "upload")
    target_path = os.path.join(target_dir, safe_name)
    with open(target_path, "wb") as f:
        content = await file.read()
        f.write(content)
    return {"success": True, "path": os.path.relpath(target_path, files_root)}
| # ============================================ | |
| # Static File Serving (Frontend) | |
| # ============================================ | |
| FRONTEND_DIR = os.path.join(PROJECT_ROOT, "frontend") | |
async def serve_index():
    """Serve the frontend's index.html, or 404 when the build is missing."""
    index_path = os.path.join(FRONTEND_DIR, "index.html")
    if not os.path.exists(index_path):
        raise HTTPException(status_code=404, detail="index.html not found")
    return FileResponse(index_path, media_type="text/html")
# Serve static files (JS, CSS) - must be after API routes
# Mount at root to serve script.js, style.css, research-ui.js directly
# html=False: StaticFiles does not auto-serve index.html for "/";
# presumably serve_index above is registered as the "/" route — verify.
app.mount("/", StaticFiles(directory=FRONTEND_DIR, html=False), name="static")
def start():
    """Entry point for the 'start' command.

    Parses command-line options, applies directory overrides
    (--config-dir, --workspace-dir, --multi-user), optionally opens
    the UI in a browser, then runs the uvicorn server (blocking).
    """
    import argparse
    import webbrowser
    import uvicorn

    parser = argparse.ArgumentParser(description="AgentUI API Server")
    parser.add_argument("--clean", action="store_true", help="Clear workspace at startup")
    parser.add_argument("--port", type=int, default=8765, help="Port to run the server on (default: 8765)")
    parser.add_argument("--no-browser", action="store_true", help="Don't open browser automatically")
    parser.add_argument("--config-dir", type=str, help="Directory for config files (settings.json)")
    parser.add_argument("--workspace-dir", type=str, help="Working directory for workspace.json and file tree")
    parser.add_argument("--multi-user", action="store_true", help="Enable per-user session isolation")
    args = parser.parse_args()

    # Declared once for the whole function; assignments below rebind
    # the module-level globals only when the matching flag is given.
    global FILES_ROOT, SESSIONS_ROOT, MULTI_USER, USERS_ROOT

    # Set custom config directory if provided
    if args.config_dir:
        set_config_dir(args.config_dir)
        logger.info("Using config directory: %s", args.config_dir)

    # Set custom workspace directory if provided (relocates the file
    # tree and the sessions directory underneath it)
    if args.workspace_dir:
        FILES_ROOT = os.path.abspath(args.workspace_dir)
        SESSIONS_ROOT = os.path.join(FILES_ROOT, "sessions")

    # Enable multi-user mode: per-user isolation under <workspace>/users
    if args.multi_user:
        MULTI_USER = True
        USERS_ROOT = os.path.join(FILES_ROOT, "users")
        os.makedirs(USERS_ROOT, exist_ok=True)
        logger.info("Multi-user mode enabled, users root: %s", USERS_ROOT)

    # Ensure sessions directory exists before the server starts using it
    os.makedirs(SESSIONS_ROOT, exist_ok=True)

    url = f"http://localhost:{args.port}"
    logger.info("Starting AgentUI server...")
    logger.info("Config directory: %s", CONFIG_DIR)
    logger.info("Sessions directory: %s", SESSIONS_ROOT)

    # Open browser after a short delay to let the server start.
    # Only announce it when we will actually open it (the original
    # logged "Opening ..." even with --no-browser, which was misleading).
    if not args.no_browser:
        logger.info("Opening %s in your browser...", url)

        def open_browser():
            import time
            time.sleep(1)  # Wait for the server to start accepting connections
            webbrowser.open(url)

        threading.Thread(target=open_browser, daemon=True).start()

    uvicorn.run(app, host="0.0.0.0", port=args.port)
# Allow running this module directly (e.g. `python server.py`).
if __name__ == "__main__":
    start()