Spaces:
Paused
Paused
| from fastapi import APIRouter, HTTPException, Depends, Request, Body, File, UploadFile, Form, Query | |
| from fastapi.responses import StreamingResponse | |
| import asyncio | |
| import json | |
| import traceback | |
| import base64 | |
| from datetime import datetime, timezone | |
| import uuid | |
| from typing import Optional, List, Dict, Any | |
| import jwt | |
| from pydantic import BaseModel | |
| import tempfile | |
| import os | |
| from agentpress.thread_manager import ThreadManager | |
| from services.supabase import DBConnection | |
| from services import redis | |
| from utils.auth_utils import get_current_user_id_from_jwt, get_user_id_from_stream_auth, verify_thread_access, verify_admin_api_key | |
| from utils.logger import logger, structlog | |
| from services.billing import check_billing_status, can_use_model | |
| from utils.config import config | |
| from sandbox.sandbox import create_sandbox, delete_sandbox, get_or_start_sandbox | |
| from services.llm import make_llm_api_call | |
| from run_agent_background import run_agent_background, _cleanup_redis_response_list, update_agent_run_status | |
| from models import model_manager | |
| from flags.flags import is_enabled | |
| from .config_helper import extract_agent_config, build_unified_config | |
| from .utils import check_agent_run_limit | |
| from .versioning.version_service import get_version_service | |
| from .versioning.api import router as version_router, initialize as initialize_versioning | |
| from .tools.sb_presentation_tool import SandboxPresentationTool | |
| # Helper for version service | |
| async def _get_version_service(): | |
| return await get_version_service() | |
| router = APIRouter() | |
| router.include_router(version_router) | |
| db = None | |
| instance_id = None # Global instance ID for this backend instance | |
| # TTL for Redis response lists (24 hours) | |
| REDIS_RESPONSE_LIST_TTL = 3600 * 24 | |
| class AgentStartRequest(BaseModel): | |
| model_name: Optional[str] = None # Will be set to default model in the endpoint | |
| enable_thinking: Optional[bool] = False | |
| reasoning_effort: Optional[str] = 'low' | |
| stream: Optional[bool] = True | |
| enable_context_manager: Optional[bool] = False | |
| agent_id: Optional[str] = None # Custom agent to use | |
| class InitiateAgentResponse(BaseModel): | |
| thread_id: str | |
| agent_run_id: Optional[str] = None | |
| class CreateThreadResponse(BaseModel): | |
| thread_id: str | |
| project_id: str | |
| class MessageCreateRequest(BaseModel): | |
| type: str | |
| content: str | |
| is_llm_message: bool = True | |
| class AgentCreateRequest(BaseModel): | |
| name: str | |
| description: Optional[str] = None | |
| system_prompt: Optional[str] = None # Make optional to allow defaulting to Suna's system prompt | |
| configured_mcps: Optional[List[Dict[str, Any]]] = [] | |
| custom_mcps: Optional[List[Dict[str, Any]]] = [] | |
| agentpress_tools: Optional[Dict[str, Any]] = {} | |
| is_default: Optional[bool] = False | |
| # Deprecated, kept for backward-compat | |
| avatar: Optional[str] = None | |
| avatar_color: Optional[str] = None | |
| # New profile image url (can be external or Supabase storage URL) | |
| profile_image_url: Optional[str] = None | |
| class AgentVersionResponse(BaseModel): | |
| version_id: str | |
| agent_id: str | |
| version_number: int | |
| version_name: str | |
| system_prompt: str | |
| model: Optional[str] = None # Add model field | |
| configured_mcps: List[Dict[str, Any]] | |
| custom_mcps: List[Dict[str, Any]] | |
| agentpress_tools: Dict[str, Any] | |
| is_active: bool | |
| created_at: str | |
| updated_at: str | |
| created_by: Optional[str] = None | |
| class AgentVersionCreateRequest(BaseModel): | |
| system_prompt: str | |
| configured_mcps: Optional[List[Dict[str, Any]]] = [] | |
| custom_mcps: Optional[List[Dict[str, Any]]] = [] | |
| agentpress_tools: Optional[Dict[str, Any]] = {} | |
| version_name: Optional[str] = None # Custom version name | |
| description: Optional[str] = None # Version description | |
| class AgentUpdateRequest(BaseModel): | |
| name: Optional[str] = None | |
| description: Optional[str] = None | |
| system_prompt: Optional[str] = None | |
| configured_mcps: Optional[List[Dict[str, Any]]] = None | |
| custom_mcps: Optional[List[Dict[str, Any]]] = None | |
| agentpress_tools: Optional[Dict[str, Any]] = None | |
| is_default: Optional[bool] = None | |
| # Deprecated, kept for backward-compat | |
| avatar: Optional[str] = None | |
| avatar_color: Optional[str] = None | |
| # New profile image url | |
| profile_image_url: Optional[str] = None | |
| class AgentResponse(BaseModel): | |
| agent_id: str | |
| name: str | |
| description: Optional[str] = None | |
| system_prompt: str | |
| configured_mcps: List[Dict[str, Any]] | |
| custom_mcps: List[Dict[str, Any]] | |
| agentpress_tools: Dict[str, Any] | |
| is_default: bool | |
| # Deprecated | |
| avatar: Optional[str] = None | |
| avatar_color: Optional[str] = None | |
| # New | |
| profile_image_url: Optional[str] = None | |
| created_at: str | |
| updated_at: Optional[str] = None | |
| is_public: Optional[bool] = False | |
| tags: Optional[List[str]] = [] | |
| current_version_id: Optional[str] = None | |
| version_count: Optional[int] = 1 | |
| current_version: Optional[AgentVersionResponse] = None | |
| metadata: Optional[Dict[str, Any]] = None | |
| class PaginationInfo(BaseModel): | |
| page: int | |
| limit: int | |
| total: int | |
| pages: int | |
| class AgentsResponse(BaseModel): | |
| agents: List[AgentResponse] | |
| pagination: PaginationInfo | |
| class ThreadAgentResponse(BaseModel): | |
| agent: Optional[AgentResponse] | |
| source: str # "thread", "default", "none", "missing" | |
| message: str | |
| class AgentExportData(BaseModel): | |
| """Exportable agent configuration data""" | |
| name: str | |
| description: Optional[str] = None | |
| system_prompt: str | |
| agentpress_tools: Dict[str, Any] | |
| configured_mcps: List[Dict[str, Any]] | |
| custom_mcps: List[Dict[str, Any]] | |
| # Deprecated | |
| avatar: Optional[str] = None | |
| avatar_color: Optional[str] = None | |
| # New | |
| profile_image_url: Optional[str] = None | |
| tags: Optional[List[str]] = [] | |
| metadata: Optional[Dict[str, Any]] = None | |
| export_version: str = "1.1" | |
| exported_at: str | |
| exported_by: Optional[str] = None | |
| class AgentImportRequest(BaseModel): | |
| """Request to import an agent from JSON""" | |
| import_data: AgentExportData | |
| import_as_new: bool = True # Always true, only creating new agents is supported | |
| def initialize( | |
| _db: DBConnection, | |
| _instance_id: Optional[str] = None | |
| ): | |
| """Initialize the agent API with resources from the main API.""" | |
| global db, instance_id | |
| db = _db | |
| # Initialize the versioning module with the same database connection | |
| initialize_versioning(_db) | |
| # Use provided instance_id or generate a new one | |
| if _instance_id: | |
| instance_id = _instance_id | |
| else: | |
| # Generate instance ID | |
| instance_id = str(uuid.uuid4())[:8] | |
| logger.debug(f"Initialized agent API with instance ID: {instance_id}") | |
| async def cleanup(): | |
| """Clean up resources and stop running agents on shutdown.""" | |
| logger.debug("Starting cleanup of agent API resources") | |
| # Use the instance_id to find and clean up this instance's keys | |
| try: | |
| if instance_id: # Ensure instance_id is set | |
| running_keys = await redis.keys(f"active_run:{instance_id}:*") | |
| logger.debug(f"Found {len(running_keys)} running agent runs for instance {instance_id} to clean up") | |
| for key in running_keys: | |
| # Key format: active_run:{instance_id}:{agent_run_id} | |
| parts = key.split(":") | |
| if len(parts) == 3: | |
| agent_run_id = parts[2] | |
| await stop_agent_run(agent_run_id, error_message=f"Instance {instance_id} shutting down") | |
| else: | |
| logger.warning(f"Unexpected key format found: {key}") | |
| else: | |
| logger.warning("Instance ID not set, cannot clean up instance-specific agent runs.") | |
| except Exception as e: | |
| logger.error(f"Failed to clean up running agent runs: {str(e)}") | |
| # Close Redis connection | |
| await redis.close() | |
| logger.debug("Completed cleanup of agent API resources") | |
| async def stop_agent_run(agent_run_id: str, error_message: Optional[str] = None): | |
| """Update database and publish stop signal to Redis.""" | |
| logger.debug(f"Stopping agent run: {agent_run_id}") | |
| client = await db.client | |
| final_status = "failed" if error_message else "stopped" | |
| # Attempt to fetch final responses from Redis | |
| response_list_key = f"agent_run:{agent_run_id}:responses" | |
| all_responses = [] | |
| try: | |
| all_responses_json = await redis.lrange(response_list_key, 0, -1) | |
| all_responses = [json.loads(r) for r in all_responses_json] | |
| logger.debug(f"Fetched {len(all_responses)} responses from Redis for DB update on stop/fail: {agent_run_id}") | |
| except Exception as e: | |
| logger.error(f"Failed to fetch responses from Redis for {agent_run_id} during stop/fail: {e}") | |
| # Try fetching from DB as a fallback? Or proceed without responses? Proceeding without for now. | |
| # Update the agent run status in the database | |
| update_success = await update_agent_run_status( | |
| client, agent_run_id, final_status, error=error_message | |
| ) | |
| if not update_success: | |
| logger.error(f"Failed to update database status for stopped/failed run {agent_run_id}") | |
| raise HTTPException(status_code=500, detail="Failed to update agent run status in database") | |
| # Send STOP signal to the global control channel | |
| global_control_channel = f"agent_run:{agent_run_id}:control" | |
| try: | |
| await redis.publish(global_control_channel, "STOP") | |
| logger.debug(f"Published STOP signal to global channel {global_control_channel}") | |
| except Exception as e: | |
| logger.error(f"Failed to publish STOP signal to global channel {global_control_channel}: {str(e)}") | |
| # Find all instances handling this agent run and send STOP to instance-specific channels | |
| try: | |
| instance_keys = await redis.keys(f"active_run:*:{agent_run_id}") | |
| logger.debug(f"Found {len(instance_keys)} active instance keys for agent run {agent_run_id}") | |
| for key in instance_keys: | |
| # Key format: active_run:{instance_id}:{agent_run_id} | |
| parts = key.split(":") | |
| if len(parts) == 3: | |
| instance_id_from_key = parts[1] | |
| instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id_from_key}" | |
| try: | |
| await redis.publish(instance_control_channel, "STOP") | |
| logger.debug(f"Published STOP signal to instance channel {instance_control_channel}") | |
| except Exception as e: | |
| logger.warning(f"Failed to publish STOP signal to instance channel {instance_control_channel}: {str(e)}") | |
| else: | |
| logger.warning(f"Unexpected key format found: {key}") | |
| # Clean up the response list immediately on stop/fail | |
| await _cleanup_redis_response_list(agent_run_id) | |
| except Exception as e: | |
| logger.error(f"Failed to find or signal active instances for {agent_run_id}: {str(e)}") | |
| logger.debug(f"Successfully initiated stop process for agent run: {agent_run_id}") | |
| async def get_agent_run_with_access_check(client, agent_run_id: str, user_id: str): | |
| agent_run = await client.table('agent_runs').select('*, threads(account_id)').eq('id', agent_run_id).execute() | |
| if not agent_run.data: | |
| raise HTTPException(status_code=404, detail="Agent run not found") | |
| agent_run_data = agent_run.data[0] | |
| thread_id = agent_run_data['thread_id'] | |
| account_id = agent_run_data['threads']['account_id'] | |
| if account_id == user_id: | |
| return agent_run_data | |
| await verify_thread_access(client, thread_id, user_id) | |
| return agent_run_data | |
| async def start_agent( | |
| thread_id: str, | |
| body: AgentStartRequest = Body(...), | |
| user_id: str = Depends(get_current_user_id_from_jwt) | |
| ): | |
| """Start an agent for a specific thread in the background""" | |
| structlog.contextvars.bind_contextvars( | |
| thread_id=thread_id, | |
| ) | |
| global instance_id # Ensure instance_id is accessible | |
| if not instance_id: | |
| raise HTTPException(status_code=500, detail="Agent API not initialized with instance ID") | |
| # Use model from config if not specified in the request | |
| model_name = body.model_name | |
| logger.debug(f"Original model_name from request: {model_name}") | |
| # Log the model name after alias resolution using new model manager | |
| from models import model_manager | |
| resolved_model = model_manager.resolve_model_id(model_name) | |
| logger.debug(f"Resolved model name: {resolved_model}") | |
| # Update model_name to use the resolved version | |
| model_name = resolved_model | |
| logger.debug(f"Starting new agent for thread: {thread_id} with config: model={model_name}, thinking={body.enable_thinking}, effort={body.reasoning_effort}, stream={body.stream}, context_manager={body.enable_context_manager} (Instance: {instance_id})") | |
| client = await db.client | |
| thread_result = await client.table('threads').select('project_id', 'account_id', 'metadata').eq('thread_id', thread_id).execute() | |
| if not thread_result.data: | |
| raise HTTPException(status_code=404, detail="Thread not found") | |
| thread_data = thread_result.data[0] | |
| project_id = thread_data.get('project_id') | |
| account_id = thread_data.get('account_id') | |
| thread_metadata = thread_data.get('metadata', {}) | |
| if account_id != user_id: | |
| await verify_thread_access(client, thread_id, user_id) | |
| structlog.contextvars.bind_contextvars( | |
| project_id=project_id, | |
| account_id=account_id, | |
| thread_metadata=thread_metadata, | |
| ) | |
| # Load agent configuration with version support | |
| agent_config = None | |
| effective_agent_id = body.agent_id # Optional agent ID from request | |
| logger.debug(f"[AGENT LOAD] Agent loading flow:") | |
| logger.debug(f" - body.agent_id: {body.agent_id}") | |
| logger.debug(f" - effective_agent_id: {effective_agent_id}") | |
| if effective_agent_id: | |
| logger.debug(f"[AGENT LOAD] Querying for agent: {effective_agent_id}") | |
| # Get agent | |
| agent_result = await client.table('agents').select('*').eq('agent_id', effective_agent_id).eq('account_id', account_id).execute() | |
| logger.debug(f"[AGENT LOAD] Query result: found {len(agent_result.data) if agent_result.data else 0} agents") | |
| if not agent_result.data: | |
| if body.agent_id: | |
| raise HTTPException(status_code=404, detail="Agent not found or access denied") | |
| else: | |
| logger.warning(f"Stored agent_id {effective_agent_id} not found, falling back to default") | |
| effective_agent_id = None | |
| else: | |
| agent_data = agent_result.data[0] | |
| version_data = None | |
| if agent_data.get('current_version_id'): | |
| try: | |
| version_service = await _get_version_service() | |
| version_obj = await version_service.get_version( | |
| agent_id=effective_agent_id, | |
| version_id=agent_data['current_version_id'], | |
| user_id=user_id | |
| ) | |
| version_data = version_obj.to_dict() | |
| logger.debug(f"[AGENT LOAD] Got version data from version manager: {version_data.get('version_name')}") | |
| except Exception as e: | |
| logger.warning(f"[AGENT LOAD] Failed to get version data: {e}") | |
| logger.debug(f"[AGENT LOAD] About to call extract_agent_config with agent_data keys: {list(agent_data.keys())}") | |
| logger.debug(f"[AGENT LOAD] version_data type: {type(version_data)}, has data: {version_data is not None}") | |
| agent_config = extract_agent_config(agent_data, version_data) | |
| if version_data: | |
| logger.debug(f"Using agent {agent_config['name']} ({effective_agent_id}) version {agent_config.get('version_name', 'v1')}") | |
| else: | |
| logger.debug(f"Using agent {agent_config['name']} ({effective_agent_id}) - no version data") | |
| source = "request" if body.agent_id else "fallback" | |
| else: | |
| logger.debug(f"[AGENT LOAD] No effective_agent_id, will try default agent") | |
| if not agent_config: | |
| logger.debug(f"[AGENT LOAD] No agent config yet, querying for default agent") | |
| default_agent_result = await client.table('agents').select('*').eq('account_id', account_id).eq('is_default', True).execute() | |
| logger.debug(f"[AGENT LOAD] Default agent query result: found {len(default_agent_result.data) if default_agent_result.data else 0} default agents") | |
| if default_agent_result.data: | |
| agent_data = default_agent_result.data[0] | |
| # Use versioning system to get current version | |
| version_data = None | |
| if agent_data.get('current_version_id'): | |
| try: | |
| version_service = await _get_version_service() | |
| version_obj = await version_service.get_version( | |
| agent_id=agent_data['agent_id'], | |
| version_id=agent_data['current_version_id'], | |
| user_id=user_id | |
| ) | |
| version_data = version_obj.to_dict() | |
| logger.debug(f"[AGENT LOAD] Got default agent version from version manager: {version_data.get('version_name')}") | |
| except Exception as e: | |
| logger.warning(f"[AGENT LOAD] Failed to get default agent version data: {e}") | |
| logger.debug(f"[AGENT LOAD] About to call extract_agent_config for DEFAULT agent with version data: {version_data is not None}") | |
| agent_config = extract_agent_config(agent_data, version_data) | |
| if version_data: | |
| logger.debug(f"Using default agent: {agent_config['name']} ({agent_config['agent_id']}) version {agent_config.get('version_name', 'v1')}") | |
| else: | |
| logger.debug(f"Using default agent: {agent_config['name']} ({agent_config['agent_id']}) - no version data") | |
| else: | |
| logger.warning(f"[AGENT LOAD] No default agent found for account {account_id}") | |
| logger.debug(f"[AGENT LOAD] Final agent_config: {agent_config is not None}") | |
| if agent_config: | |
| logger.debug(f"[AGENT LOAD] Agent config keys: {list(agent_config.keys())}") | |
| logger.debug(f"Using agent {agent_config['agent_id']} for this agent run (thread remains agent-agnostic)") | |
| # Run all checks concurrently | |
| model_check_task = asyncio.create_task(can_use_model(client, account_id, model_name)) | |
| billing_check_task = asyncio.create_task(check_billing_status(client, account_id)) | |
| limit_check_task = asyncio.create_task(check_agent_run_limit(client, account_id)) | |
| # Wait for all checks to complete | |
| (can_use, model_message, allowed_models), (can_run, message, subscription), limit_check = await asyncio.gather( | |
| model_check_task, billing_check_task, limit_check_task | |
| ) | |
| # Check results and raise appropriate errors | |
| if not can_use: | |
| raise HTTPException(status_code=403, detail={"message": model_message, "allowed_models": allowed_models}) | |
| if not can_run: | |
| raise HTTPException(status_code=402, detail={"message": message, "subscription": subscription}) | |
| if not limit_check['can_start']: | |
| error_detail = { | |
| "message": f"Maximum of {config.MAX_PARALLEL_AGENT_RUNS} parallel agent runs allowed within 24 hours. You currently have {limit_check['running_count']} running.", | |
| "running_thread_ids": limit_check['running_thread_ids'], | |
| "running_count": limit_check['running_count'], | |
| "limit": config.MAX_PARALLEL_AGENT_RUNS | |
| } | |
| logger.warning(f"Agent run limit exceeded for account {account_id}: {limit_check['running_count']} running agents") | |
| raise HTTPException(status_code=429, detail=error_detail) | |
| effective_model = model_name | |
| if not model_name and agent_config and agent_config.get('model'): | |
| effective_model = agent_config['model'] | |
| logger.debug(f"No model specified by user, using agent's configured model: {effective_model}") | |
| elif model_name: | |
| logger.debug(f"Using user-selected model: {effective_model}") | |
| else: | |
| logger.debug(f"Using default model: {effective_model}") | |
| agent_run = await client.table('agent_runs').insert({ | |
| "thread_id": thread_id, | |
| "status": "running", | |
| "started_at": datetime.now(timezone.utc).isoformat(), | |
| "agent_id": agent_config.get('agent_id') if agent_config else None, | |
| "agent_version_id": agent_config.get('current_version_id') if agent_config else None, | |
| "metadata": { | |
| "model_name": effective_model, | |
| "requested_model": model_name, | |
| "enable_thinking": body.enable_thinking, | |
| "reasoning_effort": body.reasoning_effort, | |
| "enable_context_manager": body.enable_context_manager | |
| } | |
| }).execute() | |
| agent_run_id = agent_run.data[0]['id'] | |
| structlog.contextvars.bind_contextvars( | |
| agent_run_id=agent_run_id, | |
| ) | |
| logger.debug(f"Created new agent run: {agent_run_id}") | |
| instance_key = f"active_run:{instance_id}:{agent_run_id}" | |
| try: | |
| await redis.set(instance_key, "running", ex=redis.REDIS_KEY_TTL) | |
| except Exception as e: | |
| logger.warning(f"Failed to register agent run in Redis ({instance_key}): {str(e)}") | |
| request_id = structlog.contextvars.get_contextvars().get('request_id') | |
| run_agent_background.send( | |
| agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id, | |
| project_id=project_id, | |
| model_name=model_name, # Already resolved above | |
| enable_thinking=body.enable_thinking, reasoning_effort=body.reasoning_effort, | |
| stream=body.stream, enable_context_manager=body.enable_context_manager, | |
| agent_config=agent_config, # Pass agent configuration | |
| request_id=request_id, | |
| ) | |
| return {"agent_run_id": agent_run_id, "status": "running"} | |
| async def stop_agent(agent_run_id: str, user_id: str = Depends(get_current_user_id_from_jwt)): | |
| """Stop a running agent.""" | |
| structlog.contextvars.bind_contextvars( | |
| agent_run_id=agent_run_id, | |
| ) | |
| logger.debug(f"Received request to stop agent run: {agent_run_id}") | |
| client = await db.client | |
| await get_agent_run_with_access_check(client, agent_run_id, user_id) | |
| await stop_agent_run(agent_run_id) | |
| return {"status": "stopped"} | |
| async def get_agent_runs(thread_id: str, user_id: str = Depends(get_current_user_id_from_jwt)): | |
| """Get all agent runs for a thread.""" | |
| structlog.contextvars.bind_contextvars( | |
| thread_id=thread_id, | |
| ) | |
| logger.debug(f"Fetching agent runs for thread: {thread_id}") | |
| client = await db.client | |
| await verify_thread_access(client, thread_id, user_id) | |
| agent_runs = await client.table('agent_runs').select('id, thread_id, status, started_at, completed_at, error, created_at, updated_at').eq("thread_id", thread_id).order('created_at', desc=True).execute() | |
| logger.debug(f"Found {len(agent_runs.data)} agent runs for thread: {thread_id}") | |
| return {"agent_runs": agent_runs.data} | |
| async def get_agent_run(agent_run_id: str, user_id: str = Depends(get_current_user_id_from_jwt)): | |
| """Get agent run status and responses.""" | |
| structlog.contextvars.bind_contextvars( | |
| agent_run_id=agent_run_id, | |
| ) | |
| logger.debug(f"Fetching agent run details: {agent_run_id}") | |
| client = await db.client | |
| agent_run_data = await get_agent_run_with_access_check(client, agent_run_id, user_id) | |
| # Note: Responses are not included here by default, they are in the stream or DB | |
| return { | |
| "id": agent_run_data['id'], | |
| "threadId": agent_run_data['thread_id'], | |
| "status": agent_run_data['status'], | |
| "startedAt": agent_run_data['started_at'], | |
| "completedAt": agent_run_data['completed_at'], | |
| "error": agent_run_data['error'] | |
| } | |
| async def get_thread_agent(thread_id: str, user_id: str = Depends(get_current_user_id_from_jwt)): | |
| """Get the agent details for a specific thread. Since threads are fully agent-agnostic, | |
| this returns the most recently used agent from agent_runs only.""" | |
| structlog.contextvars.bind_contextvars( | |
| thread_id=thread_id, | |
| ) | |
| logger.debug(f"Fetching agent details for thread: {thread_id}") | |
| client = await db.client | |
| try: | |
| # Verify thread access and get thread data | |
| await verify_thread_access(client, thread_id, user_id) | |
| thread_result = await client.table('threads').select('account_id').eq('thread_id', thread_id).execute() | |
| if not thread_result.data: | |
| raise HTTPException(status_code=404, detail="Thread not found") | |
| thread_data = thread_result.data[0] | |
| account_id = thread_data.get('account_id') | |
| effective_agent_id = None | |
| agent_source = "none" | |
| # Get the most recently used agent from agent_runs | |
| recent_agent_result = await client.table('agent_runs').select('agent_id', 'agent_version_id').eq('thread_id', thread_id).not_.is_('agent_id', 'null').order('created_at', desc=True).limit(1).execute() | |
| if recent_agent_result.data: | |
| effective_agent_id = recent_agent_result.data[0]['agent_id'] | |
| recent_version_id = recent_agent_result.data[0].get('agent_version_id') | |
| agent_source = "recent" | |
| logger.debug(f"Found most recently used agent: {effective_agent_id} (version: {recent_version_id})") | |
| # If no agent found in agent_runs | |
| if not effective_agent_id: | |
| return { | |
| "agent": None, | |
| "source": "none", | |
| "message": "No agent has been used in this thread yet. Threads are agent-agnostic - use /agent/start to select an agent." | |
| } | |
| # Fetch the agent details | |
| agent_result = await client.table('agents').select('*').eq('agent_id', effective_agent_id).eq('account_id', account_id).execute() | |
| if not agent_result.data: | |
| # Agent was deleted or doesn't exist | |
| return { | |
| "agent": None, | |
| "source": "missing", | |
| "message": f"Agent {effective_agent_id} not found or was deleted. You can select a different agent." | |
| } | |
| agent_data = agent_result.data[0] | |
| # Use versioning system to get current version data | |
| version_data = None | |
| current_version = None | |
| if agent_data.get('current_version_id'): | |
| try: | |
| version_service = await _get_version_service() | |
| current_version_obj = await version_service.get_version( | |
| agent_id=effective_agent_id, | |
| version_id=agent_data['current_version_id'], | |
| user_id=user_id | |
| ) | |
| current_version_data = current_version_obj.to_dict() | |
| version_data = current_version_data | |
| # Create AgentVersionResponse from version data | |
| current_version = AgentVersionResponse( | |
| version_id=current_version_data['version_id'], | |
| agent_id=current_version_data['agent_id'], | |
| version_number=current_version_data['version_number'], | |
| version_name=current_version_data['version_name'], | |
| system_prompt=current_version_data['system_prompt'], | |
| model=current_version_data.get('model'), | |
| configured_mcps=current_version_data.get('configured_mcps', []), | |
| custom_mcps=current_version_data.get('custom_mcps', []), | |
| agentpress_tools=current_version_data.get('agentpress_tools', {}), | |
| is_active=current_version_data.get('is_active', True), | |
| created_at=current_version_data['created_at'], | |
| updated_at=current_version_data.get('updated_at', current_version_data['created_at']), | |
| created_by=current_version_data.get('created_by') | |
| ) | |
| logger.debug(f"Using agent {agent_data['name']} version {current_version_data.get('version_name', 'v1')}") | |
| except Exception as e: | |
| logger.warning(f"Failed to get version data for agent {effective_agent_id}: {e}") | |
| version_data = None | |
| if current_version: | |
| version_data = { | |
| 'version_id': current_version.version_id, | |
| 'agent_id': current_version.agent_id, | |
| 'version_number': current_version.version_number, | |
| 'version_name': current_version.version_name, | |
| 'system_prompt': current_version.system_prompt, | |
| 'model': current_version.model, | |
| 'configured_mcps': current_version.configured_mcps, | |
| 'custom_mcps': current_version.custom_mcps, | |
| 'agentpress_tools': current_version.agentpress_tools, | |
| 'is_active': current_version.is_active, | |
| 'created_at': current_version.created_at, | |
| 'updated_at': current_version.updated_at, | |
| 'created_by': current_version.created_by | |
| } | |
| from agent.config_helper import extract_agent_config | |
| agent_config = extract_agent_config(agent_data, version_data) | |
| system_prompt = agent_config['system_prompt'] | |
| configured_mcps = agent_config['configured_mcps'] | |
| custom_mcps = agent_config['custom_mcps'] | |
| agentpress_tools = agent_config['agentpress_tools'] | |
| return { | |
| "agent": AgentResponse( | |
| agent_id=agent_data['agent_id'], | |
| name=agent_data['name'], | |
| description=agent_data.get('description'), | |
| system_prompt=system_prompt, | |
| configured_mcps=configured_mcps, | |
| custom_mcps=custom_mcps, | |
| agentpress_tools=agentpress_tools, | |
| is_default=agent_data.get('is_default', False), | |
| is_public=agent_data.get('is_public', False), | |
| tags=agent_data.get('tags', []), | |
| avatar=agent_config.get('avatar'), | |
| avatar_color=agent_config.get('avatar_color'), | |
| profile_image_url=agent_config.get('profile_image_url'), | |
| created_at=agent_data['created_at'], | |
| updated_at=agent_data.get('updated_at', agent_data['created_at']), | |
| current_version_id=agent_data.get('current_version_id'), | |
| version_count=agent_data.get('version_count', 1), | |
| current_version=current_version, | |
| metadata=agent_data.get('metadata') | |
| ), | |
| "source": agent_source, | |
| "message": f"Using {agent_source} agent: {agent_data['name']}. Threads are agent-agnostic - you can change agents anytime." | |
| } | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error fetching agent for thread {thread_id}: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Failed to fetch thread agent: {str(e)}") | |
| async def stream_agent_run( | |
| agent_run_id: str, | |
| token: Optional[str] = None, | |
| request: Request = None | |
| ): | |
| """Stream the responses of an agent run using Redis Lists and Pub/Sub.""" | |
| logger.debug(f"Starting stream for agent run: {agent_run_id}") | |
| client = await db.client | |
| user_id = await get_user_id_from_stream_auth(request, token) # practically instant | |
| agent_run_data = await get_agent_run_with_access_check(client, agent_run_id, user_id) # 1 db query | |
| structlog.contextvars.bind_contextvars( | |
| agent_run_id=agent_run_id, | |
| user_id=user_id, | |
| ) | |
| response_list_key = f"agent_run:{agent_run_id}:responses" | |
| response_channel = f"agent_run:{agent_run_id}:new_response" | |
| control_channel = f"agent_run:{agent_run_id}:control" # Global control channel | |
| async def stream_generator(agent_run_data): | |
| logger.debug(f"Streaming responses for {agent_run_id} using Redis list {response_list_key} and channel {response_channel}") | |
| last_processed_index = -1 | |
| pubsub_response = None | |
| pubsub_control = None | |
| listener_task = None | |
| terminate_stream = False | |
| initial_yield_complete = False | |
| try: | |
| # 1. Fetch and yield initial responses from Redis list | |
| initial_responses_json = await redis.lrange(response_list_key, 0, -1) | |
| initial_responses = [] | |
| if initial_responses_json: | |
| initial_responses = [json.loads(r) for r in initial_responses_json] | |
| logger.debug(f"Sending {len(initial_responses)} initial responses for {agent_run_id}") | |
| for response in initial_responses: | |
| yield f"data: {json.dumps(response)}\n\n" | |
| last_processed_index = len(initial_responses) - 1 | |
| initial_yield_complete = True | |
| # 2. Check run status | |
| current_status = agent_run_data.get('status') if agent_run_data else None | |
| if current_status != 'running': | |
| logger.debug(f"Agent run {agent_run_id} is not running (status: {current_status}). Ending stream.") | |
| yield f"data: {json.dumps({'type': 'status', 'status': 'completed'})}\n\n" | |
| return | |
| structlog.contextvars.bind_contextvars( | |
| thread_id=agent_run_data.get('thread_id'), | |
| ) | |
| # 3. Set up Pub/Sub listeners for new responses and control signals concurrently | |
| pubsub_response_task = asyncio.create_task(redis.create_pubsub()) | |
| pubsub_control_task = asyncio.create_task(redis.create_pubsub()) | |
| pubsub_response, pubsub_control = await asyncio.gather(pubsub_response_task, pubsub_control_task) | |
| # Subscribe to channels concurrently | |
| response_subscribe_task = asyncio.create_task(pubsub_response.subscribe(response_channel)) | |
| control_subscribe_task = asyncio.create_task(pubsub_control.subscribe(control_channel)) | |
| await asyncio.gather(response_subscribe_task, control_subscribe_task) | |
| logger.debug(f"Subscribed to response channel: {response_channel}") | |
| logger.debug(f"Subscribed to control channel: {control_channel}") | |
| # Queue to communicate between listeners and the main generator loop | |
| message_queue = asyncio.Queue() | |
| async def listen_messages(): | |
| response_reader = pubsub_response.listen() | |
| control_reader = pubsub_control.listen() | |
| tasks = [asyncio.create_task(response_reader.__anext__()), asyncio.create_task(control_reader.__anext__())] | |
| while not terminate_stream: | |
| done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) | |
| for task in done: | |
| try: | |
| message = task.result() | |
| if message and isinstance(message, dict) and message.get("type") == "message": | |
| channel = message.get("channel") | |
| data = message.get("data") | |
| if isinstance(data, bytes): data = data.decode('utf-8') | |
| if channel == response_channel and data == "new": | |
| await message_queue.put({"type": "new_response"}) | |
| elif channel == control_channel and data in ["STOP", "END_STREAM", "ERROR"]: | |
| logger.debug(f"Received control signal '{data}' for {agent_run_id}") | |
| await message_queue.put({"type": "control", "data": data}) | |
| return # Stop listening on control signal | |
| except StopAsyncIteration: | |
| logger.warning(f"Listener {task} stopped.") | |
| # Decide how to handle listener stopping, maybe terminate? | |
| await message_queue.put({"type": "error", "data": "Listener stopped unexpectedly"}) | |
| return | |
| except Exception as e: | |
| logger.error(f"Error in listener for {agent_run_id}: {e}") | |
| await message_queue.put({"type": "error", "data": "Listener failed"}) | |
| return | |
| finally: | |
| # Reschedule the completed listener task | |
| if task in tasks: | |
| tasks.remove(task) | |
| if message and isinstance(message, dict) and message.get("channel") == response_channel: | |
| tasks.append(asyncio.create_task(response_reader.__anext__())) | |
| elif message and isinstance(message, dict) and message.get("channel") == control_channel: | |
| tasks.append(asyncio.create_task(control_reader.__anext__())) | |
| # Cancel pending listener tasks on exit | |
| for p_task in pending: p_task.cancel() | |
| for task in tasks: task.cancel() | |
| listener_task = asyncio.create_task(listen_messages()) | |
| # 4. Main loop to process messages from the queue | |
| while not terminate_stream: | |
| try: | |
| queue_item = await message_queue.get() | |
| if queue_item["type"] == "new_response": | |
| # Fetch new responses from Redis list starting after the last processed index | |
| new_start_index = last_processed_index + 1 | |
| new_responses_json = await redis.lrange(response_list_key, new_start_index, -1) | |
| if new_responses_json: | |
| new_responses = [json.loads(r) for r in new_responses_json] | |
| num_new = len(new_responses) | |
| # logger.debug(f"Received {num_new} new responses for {agent_run_id} (index {new_start_index} onwards)") | |
| for response in new_responses: | |
| yield f"data: {json.dumps(response)}\n\n" | |
| # Check if this response signals completion | |
| if response.get('type') == 'status' and response.get('status') in ['completed', 'failed', 'stopped']: | |
| logger.debug(f"Detected run completion via status message in stream: {response.get('status')}") | |
| terminate_stream = True | |
| break # Stop processing further new responses | |
| last_processed_index += num_new | |
| if terminate_stream: break | |
| elif queue_item["type"] == "control": | |
| control_signal = queue_item["data"] | |
| terminate_stream = True # Stop the stream on any control signal | |
| yield f"data: {json.dumps({'type': 'status', 'status': control_signal})}\n\n" | |
| break | |
| elif queue_item["type"] == "error": | |
| logger.error(f"Listener error for {agent_run_id}: {queue_item['data']}") | |
| terminate_stream = True | |
| yield f"data: {json.dumps({'type': 'status', 'status': 'error'})}\n\n" | |
| break | |
| except asyncio.CancelledError: | |
| logger.debug(f"Stream generator main loop cancelled for {agent_run_id}") | |
| terminate_stream = True | |
| break | |
| except Exception as loop_err: | |
| logger.error(f"Error in stream generator main loop for {agent_run_id}: {loop_err}", exc_info=True) | |
| terminate_stream = True | |
| yield f"data: {json.dumps({'type': 'status', 'status': 'error', 'message': f'Stream failed: {loop_err}'})}\n\n" | |
| break | |
| except Exception as e: | |
| logger.error(f"Error setting up stream for agent run {agent_run_id}: {e}", exc_info=True) | |
| # Only yield error if initial yield didn't happen | |
| if not initial_yield_complete: | |
| yield f"data: {json.dumps({'type': 'status', 'status': 'error', 'message': f'Failed to start stream: {e}'})}\n\n" | |
| finally: | |
| terminate_stream = True | |
| # Graceful shutdown order: unsubscribe → close → cancel | |
| if pubsub_response: await pubsub_response.unsubscribe(response_channel) | |
| if pubsub_control: await pubsub_control.unsubscribe(control_channel) | |
| if pubsub_response: await pubsub_response.close() | |
| if pubsub_control: await pubsub_control.close() | |
| if listener_task: | |
| listener_task.cancel() | |
| try: | |
| await listener_task # Reap inner tasks & swallow their errors | |
| except asyncio.CancelledError: | |
| pass | |
| except Exception as e: | |
| logger.debug(f"listener_task ended with: {e}") | |
| # Wait briefly for tasks to cancel | |
| await asyncio.sleep(0.1) | |
| logger.debug(f"Streaming cleanup complete for agent run: {agent_run_id}") | |
| return StreamingResponse(stream_generator(agent_run_data), media_type="text/event-stream", headers={ | |
| "Cache-Control": "no-cache, no-transform", "Connection": "keep-alive", | |
| "X-Accel-Buffering": "no", "Content-Type": "text/event-stream", | |
| "Access-Control-Allow-Origin": "*" | |
| }) | |
| async def generate_and_update_project_name(project_id: str, prompt: str): | |
| """Generates a project name using an LLM and updates the database.""" | |
| logger.debug(f"Starting background task to generate name for project: {project_id}") | |
| try: | |
| db_conn = DBConnection() | |
| client = await db_conn.client | |
| model_name = "openai/gpt-4o-mini" | |
| system_prompt = "You are a helpful assistant that generates extremely concise titles (2-4 words maximum) for chat threads based on the user's message. Respond with only the title, no other text or punctuation." | |
| user_message = f"Generate an extremely brief title (2-4 words only) for a chat thread that starts with this message: \"{prompt}\"" | |
| messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_message}] | |
| logger.debug(f"Calling LLM ({model_name}) for project {project_id} naming.") | |
| response = await make_llm_api_call(messages=messages, model_name=model_name, max_tokens=20, temperature=0.7) | |
| generated_name = None | |
| if response and response.get('choices') and response['choices'][0].get('message'): | |
| raw_name = response['choices'][0]['message'].get('content', '').strip() | |
| cleaned_name = raw_name.strip('\'" \n\t') | |
| if cleaned_name: | |
| generated_name = cleaned_name | |
| logger.debug(f"LLM generated name for project {project_id}: '{generated_name}'") | |
| else: | |
| logger.warning(f"LLM returned an empty name for project {project_id}.") | |
| else: | |
| logger.warning(f"Failed to get valid response from LLM for project {project_id} naming. Response: {response}") | |
| if generated_name: | |
| update_result = await client.table('projects').update({"name": generated_name}).eq("project_id", project_id).execute() | |
| if hasattr(update_result, 'data') and update_result.data: | |
| logger.debug(f"Successfully updated project {project_id} name to '{generated_name}'") | |
| else: | |
| logger.error(f"Failed to update project {project_id} name in database. Update result: {update_result}") | |
| else: | |
| logger.warning(f"No generated name, skipping database update for project {project_id}.") | |
| except Exception as e: | |
| logger.error(f"Error in background naming task for project {project_id}: {str(e)}\n{traceback.format_exc()}") | |
| finally: | |
| # No need to disconnect DBConnection singleton instance here | |
| logger.debug(f"Finished background naming task for project: {project_id}") | |
| async def initiate_agent_with_files( | |
| prompt: str = Form(...), | |
| model_name: Optional[str] = Form(None), # Default to None to use default model | |
| enable_thinking: Optional[bool] = Form(False), | |
| reasoning_effort: Optional[str] = Form("low"), | |
| stream: Optional[bool] = Form(True), | |
| enable_context_manager: Optional[bool] = Form(False), | |
| agent_id: Optional[str] = Form(None), # Add agent_id parameter | |
| files: List[UploadFile] = File(default=[]), | |
| user_id: str = Depends(get_current_user_id_from_jwt) | |
| ): | |
| """ | |
| Initiate a new agent session with optional file attachments. | |
| [WARNING] Keep in sync with create thread endpoint. | |
| """ | |
| global instance_id # Ensure instance_id is accessible | |
| if not instance_id: | |
| raise HTTPException(status_code=500, detail="Agent API not initialized with instance ID") | |
| # Use model from config if not specified in the request | |
| logger.debug(f"Original model_name from request: {model_name}") | |
| if model_name is None: | |
| model_name = "openai/gpt-5-mini" | |
| logger.debug(f"Using default model: {model_name}") | |
| from models import model_manager | |
| # Log the model name after alias resolution using new model manager | |
| resolved_model = model_manager.resolve_model_id(model_name) | |
| logger.debug(f"Resolved model name: {resolved_model}") | |
| # Update model_name to use the resolved version | |
| model_name = resolved_model | |
| logger.debug(f"[\033[91mDEBUG\033[0m] Initiating new agent with prompt and {len(files)} files (Instance: {instance_id}), model: {model_name}, enable_thinking: {enable_thinking}") | |
| client = await db.client | |
| account_id = user_id # In Basejump, personal account_id is the same as user_id | |
| # Load agent configuration with version support (same as start_agent endpoint) | |
| agent_config = None | |
| logger.debug(f"[AGENT INITIATE] Agent loading flow:") | |
| logger.debug(f" - agent_id param: {agent_id}") | |
| if agent_id: | |
| logger.debug(f"[AGENT INITIATE] Querying for specific agent: {agent_id}") | |
| # Get agent | |
| agent_result = await client.table('agents').select('*').eq('agent_id', agent_id).eq('account_id', account_id).execute() | |
| logger.debug(f"[AGENT INITIATE] Query result: found {len(agent_result.data) if agent_result.data else 0} agents") | |
| if not agent_result.data: | |
| raise HTTPException(status_code=404, detail="Agent not found or access denied") | |
| agent_data = agent_result.data[0] | |
| # Use versioning system to get current version | |
| version_data = None | |
| if agent_data.get('current_version_id'): | |
| try: | |
| version_service = await _get_version_service() | |
| version_obj = await version_service.get_version( | |
| agent_id=agent_id, | |
| version_id=agent_data['current_version_id'], | |
| user_id=user_id | |
| ) | |
| version_data = version_obj.to_dict() | |
| logger.debug(f"[AGENT INITIATE] Got version data from version manager: {version_data.get('version_name')}") | |
| logger.debug(f"[AGENT INITIATE] Version data: {version_data}") | |
| except Exception as e: | |
| logger.warning(f"[AGENT INITIATE] Failed to get version data: {e}") | |
| logger.debug(f"[AGENT INITIATE] About to call extract_agent_config with version data: {version_data is not None}") | |
| agent_config = extract_agent_config(agent_data, version_data) | |
| if version_data: | |
| logger.debug(f"Using custom agent: {agent_config['name']} ({agent_id}) version {agent_config.get('version_name', 'v1')}") | |
| else: | |
| logger.debug(f"Using custom agent: {agent_config['name']} ({agent_id}) - no version data") | |
| else: | |
| logger.debug(f"[AGENT INITIATE] No agent_id provided, querying for default agent") | |
| # Try to get default agent for the account | |
| default_agent_result = await client.table('agents').select('*').eq('account_id', account_id).eq('is_default', True).execute() | |
| logger.debug(f"[AGENT INITIATE] Default agent query result: found {len(default_agent_result.data) if default_agent_result.data else 0} default agents") | |
| if default_agent_result.data: | |
| agent_data = default_agent_result.data[0] | |
| # Use versioning system to get current version | |
| version_data = None | |
| if agent_data.get('current_version_id'): | |
| try: | |
| version_service = await _get_version_service() | |
| version_obj = await version_service.get_version( | |
| agent_id=agent_data['agent_id'], | |
| version_id=agent_data['current_version_id'], | |
| user_id=user_id | |
| ) | |
| version_data = version_obj.to_dict() | |
| logger.debug(f"[AGENT INITIATE] Got default agent version from version manager: {version_data.get('version_name')}") | |
| except Exception as e: | |
| logger.warning(f"[AGENT INITIATE] Failed to get default agent version data: {e}") | |
| logger.debug(f"[AGENT INITIATE] About to call extract_agent_config for DEFAULT agent with version data: {version_data is not None}") | |
| agent_config = extract_agent_config(agent_data, version_data) | |
| if version_data: | |
| logger.debug(f"Using default agent: {agent_config['name']} ({agent_config['agent_id']}) version {agent_config.get('version_name', 'v1')}") | |
| else: | |
| logger.debug(f"Using default agent: {agent_config['name']} ({agent_config['agent_id']}) - no version data") | |
| else: | |
| logger.warning(f"[AGENT INITIATE] No default agent found for account {account_id}") | |
| logger.debug(f"[AGENT INITIATE] Final agent_config: {agent_config is not None}") | |
| if agent_config: | |
| logger.debug(f"[AGENT INITIATE] Agent config keys: {list(agent_config.keys())}") | |
| # Run all checks concurrently | |
| model_check_task = asyncio.create_task(can_use_model(client, account_id, model_name)) | |
| billing_check_task = asyncio.create_task(check_billing_status(client, account_id)) | |
| limit_check_task = asyncio.create_task(check_agent_run_limit(client, account_id)) | |
| # Wait for all checks to complete | |
| (can_use, model_message, allowed_models), (can_run, message, subscription), limit_check = await asyncio.gather( | |
| model_check_task, billing_check_task, limit_check_task | |
| ) | |
| # Check results and raise appropriate errors | |
| if not can_use: | |
| raise HTTPException(status_code=403, detail={"message": model_message, "allowed_models": allowed_models}) | |
| can_run, message, subscription = await check_billing_status(client, account_id) | |
| if not can_run: | |
| raise HTTPException(status_code=402, detail={"message": message, "subscription": subscription}) | |
| # Check agent run limit (maximum parallel runs in past 24 hours) | |
| limit_check = await check_agent_run_limit(client, account_id) | |
| if not limit_check['can_start']: | |
| error_detail = { | |
| "message": f"Maximum of {config.MAX_PARALLEL_AGENT_RUNS} parallel agent runs allowed within 24 hours. You currently have {limit_check['running_count']} running.", | |
| "running_thread_ids": limit_check['running_thread_ids'], | |
| "running_count": limit_check['running_count'], | |
| "limit": config.MAX_PARALLEL_AGENT_RUNS | |
| } | |
| logger.warning(f"Agent run limit exceeded for account {account_id}: {limit_check['running_count']} running agents") | |
| raise HTTPException(status_code=429, detail=error_detail) | |
| try: | |
| # 1. Create Project | |
| placeholder_name = f"{prompt[:30]}..." if len(prompt) > 30 else prompt | |
| project = await client.table('projects').insert({ | |
| "project_id": str(uuid.uuid4()), "account_id": account_id, "name": placeholder_name, | |
| "created_at": datetime.now(timezone.utc).isoformat() | |
| }).execute() | |
| project_id = project.data[0]['project_id'] | |
| logger.info(f"Created new project: {project_id}") | |
| # 2. Create Sandbox (lazy): only create now if files were uploaded and need the | |
| # sandbox immediately. Otherwise leave sandbox creation to `_ensure_sandbox()` | |
| # which will create it lazily when tools require it. | |
| sandbox_id = None | |
| sandbox = None | |
| sandbox_pass = None | |
| vnc_url = None | |
| website_url = None | |
| token = None | |
| if files: | |
| # 3. Create Sandbox (lazy): only create now if files were uploaded and need the | |
| try: | |
| sandbox_pass = str(uuid.uuid4()) | |
| sandbox = await create_sandbox(sandbox_pass, project_id) | |
| sandbox_id = sandbox.id | |
| logger.info(f"Created new sandbox {sandbox_id} for project {project_id}") | |
| # Get preview links | |
| vnc_link = await sandbox.get_preview_link(6080) | |
| website_link = await sandbox.get_preview_link(8080) | |
| vnc_url = vnc_link.url if hasattr(vnc_link, 'url') else str(vnc_link).split("url='")[1].split("'")[0] | |
| website_url = website_link.url if hasattr(website_link, 'url') else str(website_link).split("url='")[1].split("'")[0] | |
| token = None | |
| if hasattr(vnc_link, 'token'): | |
| token = vnc_link.token | |
| elif "token='" in str(vnc_link): | |
| token = str(vnc_link).split("token='")[1].split("'")[0] | |
| # Update project with sandbox info | |
| update_result = await client.table('projects').update({ | |
| 'sandbox': { | |
| 'id': sandbox_id, 'pass': sandbox_pass, 'vnc_preview': vnc_url, | |
| 'sandbox_url': website_url, 'token': token | |
| } | |
| }).eq('project_id', project_id).execute() | |
| if not update_result.data: | |
| logger.error(f"Failed to update project {project_id} with new sandbox {sandbox_id}") | |
| if sandbox_id: | |
| try: await delete_sandbox(sandbox_id) | |
| except Exception as e: logger.error(f"Error deleting sandbox: {str(e)}") | |
| raise Exception("Database update failed") | |
| except Exception as e: | |
| logger.error(f"Error creating sandbox: {str(e)}") | |
| await client.table('projects').delete().eq('project_id', project_id).execute() | |
| if sandbox_id: | |
| try: await delete_sandbox(sandbox_id) | |
| except Exception: | |
| pass | |
| raise Exception("Failed to create sandbox") | |
| # 3. Create Thread | |
| thread_data = { | |
| "thread_id": str(uuid.uuid4()), | |
| "project_id": project_id, | |
| "account_id": account_id, | |
| "created_at": datetime.now(timezone.utc).isoformat() | |
| } | |
| structlog.contextvars.bind_contextvars( | |
| thread_id=thread_data["thread_id"], | |
| project_id=project_id, | |
| account_id=account_id, | |
| ) | |
| # Don't store agent_id in thread since threads are now agent-agnostic | |
| # The agent selection will be handled per message/agent run | |
| if agent_config: | |
| logger.debug(f"Using agent {agent_config['agent_id']} for this conversation (thread remains agent-agnostic)") | |
| structlog.contextvars.bind_contextvars( | |
| agent_id=agent_config['agent_id'], | |
| ) | |
| thread = await client.table('threads').insert(thread_data).execute() | |
| thread_id = thread.data[0]['thread_id'] | |
| logger.debug(f"Created new thread: {thread_id}") | |
| # Trigger Background Naming Task | |
| asyncio.create_task(generate_and_update_project_name(project_id=project_id, prompt=prompt)) | |
| # 4. Upload Files to Sandbox (if any) | |
| message_content = prompt | |
| if files: | |
| successful_uploads = [] | |
| failed_uploads = [] | |
| for file in files: | |
| if file.filename: | |
| try: | |
| safe_filename = file.filename.replace('/', '_').replace('\\', '_') | |
| target_path = f"/workspace/{safe_filename}" | |
| logger.debug(f"Attempting to upload {safe_filename} to {target_path} in sandbox {sandbox_id}") | |
| content = await file.read() | |
| upload_successful = False | |
| try: | |
| if hasattr(sandbox, 'fs') and hasattr(sandbox.fs, 'upload_file'): | |
| await sandbox.fs.upload_file(content, target_path) | |
| logger.debug(f"Called sandbox.fs.upload_file for {target_path}") | |
| upload_successful = True | |
| else: | |
| raise NotImplementedError("Suitable upload method not found on sandbox object.") | |
| except Exception as upload_error: | |
| logger.error(f"Error during sandbox upload call for {safe_filename}: {str(upload_error)}", exc_info=True) | |
| if upload_successful: | |
| try: | |
| await asyncio.sleep(0.2) | |
| parent_dir = os.path.dirname(target_path) | |
| files_in_dir = await sandbox.fs.list_files(parent_dir) | |
| file_names_in_dir = [f.name for f in files_in_dir] | |
| if safe_filename in file_names_in_dir: | |
| successful_uploads.append(target_path) | |
| logger.debug(f"Successfully uploaded and verified file {safe_filename} to sandbox path {target_path}") | |
| else: | |
| logger.error(f"Verification failed for {safe_filename}: File not found in {parent_dir} after upload attempt.") | |
| failed_uploads.append(safe_filename) | |
| except Exception as verify_error: | |
| logger.error(f"Error verifying file {safe_filename} after upload: {str(verify_error)}", exc_info=True) | |
| failed_uploads.append(safe_filename) | |
| else: | |
| failed_uploads.append(safe_filename) | |
| except Exception as file_error: | |
| logger.error(f"Error processing file {file.filename}: {str(file_error)}", exc_info=True) | |
| failed_uploads.append(file.filename) | |
| finally: | |
| await file.close() | |
| if successful_uploads: | |
| message_content += "\n\n" if message_content else "" | |
| for file_path in successful_uploads: message_content += f"[Uploaded File: {file_path}]\n" | |
| if failed_uploads: | |
| message_content += "\n\nThe following files failed to upload:\n" | |
| for failed_file in failed_uploads: message_content += f"- {failed_file}\n" | |
| # 5. Add initial user message to thread | |
| message_id = str(uuid.uuid4()) | |
| message_payload = {"role": "user", "content": message_content} | |
| await client.table('messages').insert({ | |
| "message_id": message_id, "thread_id": thread_id, "type": "user", | |
| "is_llm_message": True, "content": json.dumps(message_payload), | |
| "created_at": datetime.now(timezone.utc).isoformat() | |
| }).execute() | |
| effective_model = model_name | |
| if not model_name and agent_config and agent_config.get('model'): | |
| effective_model = agent_config['model'] | |
| logger.debug(f"No model specified by user, using agent's configured model: {effective_model}") | |
| elif model_name: | |
| logger.debug(f"Using user-selected model: {effective_model}") | |
| else: | |
| logger.debug(f"Using default model: {effective_model}") | |
| agent_run = await client.table('agent_runs').insert({ | |
| "thread_id": thread_id, "status": "running", | |
| "started_at": datetime.now(timezone.utc).isoformat(), | |
| "agent_id": agent_config.get('agent_id') if agent_config else None, | |
| "agent_version_id": agent_config.get('current_version_id') if agent_config else None, | |
| "metadata": { | |
| "model_name": effective_model, | |
| "requested_model": model_name, | |
| "enable_thinking": enable_thinking, | |
| "reasoning_effort": reasoning_effort, | |
| "enable_context_manager": enable_context_manager | |
| } | |
| }).execute() | |
| agent_run_id = agent_run.data[0]['id'] | |
| logger.debug(f"Created new agent run: {agent_run_id}") | |
| structlog.contextvars.bind_contextvars( | |
| agent_run_id=agent_run_id, | |
| ) | |
| # Register run in Redis | |
| instance_key = f"active_run:{instance_id}:{agent_run_id}" | |
| try: | |
| await redis.set(instance_key, "running", ex=redis.REDIS_KEY_TTL) | |
| except Exception as e: | |
| logger.warning(f"Failed to register agent run in Redis ({instance_key}): {str(e)}") | |
| request_id = structlog.contextvars.get_contextvars().get('request_id') | |
| # Run agent in background | |
| run_agent_background.send( | |
| agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id, | |
| project_id=project_id, | |
| model_name=model_name, # Already resolved above | |
| enable_thinking=enable_thinking, reasoning_effort=reasoning_effort, | |
| stream=stream, enable_context_manager=enable_context_manager, | |
| agent_config=agent_config, # Pass agent configuration | |
| request_id=request_id, | |
| ) | |
| return {"thread_id": thread_id, "agent_run_id": agent_run_id} | |
| except Exception as e: | |
| logger.error(f"Error in agent initiation: {str(e)}\n{traceback.format_exc()}") | |
| # TODO: Clean up created project/thread if initiation fails mid-way | |
| raise HTTPException(status_code=500, detail=f"Failed to initiate agent session: {str(e)}") | |
| # Custom agents | |
| async def get_agents( | |
| user_id: str = Depends(get_current_user_id_from_jwt), | |
| page: Optional[int] = Query(1, ge=1, description="Page number (1-based)"), | |
| limit: Optional[int] = Query(20, ge=1, le=100, description="Number of items per page"), | |
| search: Optional[str] = Query(None, description="Search in name and description"), | |
| sort_by: Optional[str] = Query("created_at", description="Sort field: name, created_at, updated_at, tools_count"), | |
| sort_order: Optional[str] = Query("desc", description="Sort order: asc, desc"), | |
| has_default: Optional[bool] = Query(None, description="Filter by default agents"), | |
| has_mcp_tools: Optional[bool] = Query(None, description="Filter by agents with MCP tools"), | |
| has_agentpress_tools: Optional[bool] = Query(None, description="Filter by agents with AgentPress tools"), | |
| tools: Optional[str] = Query(None, description="Comma-separated list of tools to filter by") | |
| ): | |
| """Get agents for the current user with pagination, search, sort, and filter support.""" | |
| if not await is_enabled("custom_agents"): | |
| raise HTTPException( | |
| status_code=403, | |
| detail="Custom agents currently disabled. This feature is not available at the moment." | |
| ) | |
| logger.debug(f"Fetching agents for user: {user_id} with page={page}, limit={limit}, search='{search}', sort_by={sort_by}, sort_order={sort_order}") | |
| client = await db.client | |
| try: | |
| # Calculate offset | |
| offset = (page - 1) * limit | |
| # Start building the query | |
| query = client.table('agents').select('*', count='exact').eq("account_id", user_id) | |
| # Apply search filter | |
| if search: | |
| search_term = f"%{search}%" | |
| query = query.or_(f"name.ilike.{search_term},description.ilike.{search_term}") | |
| # Apply filters | |
| if has_default is not None: | |
| query = query.eq("is_default", has_default) | |
| # For MCP and AgentPress tools filtering, we'll need to do post-processing | |
| # since Supabase doesn't have great JSON array/object filtering | |
| # Apply sorting | |
| if sort_by == "name": | |
| query = query.order("name", desc=(sort_order == "desc")) | |
| elif sort_by == "updated_at": | |
| query = query.order("updated_at", desc=(sort_order == "desc")) | |
| elif sort_by == "created_at": | |
| query = query.order("created_at", desc=(sort_order == "desc")) | |
| else: | |
| # Default to created_at | |
| query = query.order("created_at", desc=(sort_order == "desc")) | |
| # Get paginated data and total count in one request | |
| query = query.range(offset, offset + limit - 1) | |
| agents_result = await query.execute() | |
| total_count = agents_result.count if agents_result.count is not None else 0 | |
| if not agents_result.data: | |
| logger.debug(f"No agents found for user: {user_id}") | |
| return { | |
| "agents": [], | |
| "pagination": { | |
| "page": page, | |
| "limit": limit, | |
| "total": 0, | |
| "pages": 0 | |
| } | |
| } | |
| # Post-process for tool filtering and tools_count sorting | |
| agents_data = agents_result.data | |
| # First, fetch version data for all agents to ensure we have correct tool info | |
| # Do this in a single batched query instead of per-agent service calls | |
| agent_version_map = {} | |
| version_ids = list({agent['current_version_id'] for agent in agents_data if agent.get('current_version_id')}) | |
| if version_ids: | |
| try: | |
| from utils.query_utils import batch_query_in | |
| versions_data = await batch_query_in( | |
| client=client, | |
| table_name='agent_versions', | |
| select_fields='version_id, agent_id, version_number, version_name, is_active, created_at, updated_at, created_by, config', | |
| in_field='version_id', | |
| in_values=version_ids | |
| ) | |
| for row in versions_data: | |
| config = row.get('config') or {} | |
| tools = config.get('tools') or {} | |
| version_dict = { | |
| 'version_id': row['version_id'], | |
| 'agent_id': row['agent_id'], | |
| 'version_number': row['version_number'], | |
| 'version_name': row['version_name'], | |
| 'system_prompt': config.get('system_prompt', ''), | |
| 'configured_mcps': tools.get('mcp', []), | |
| 'custom_mcps': tools.get('custom_mcp', []), | |
| 'agentpress_tools': tools.get('agentpress', {}), | |
| 'is_active': row.get('is_active', False), | |
| 'created_at': row.get('created_at'), | |
| 'updated_at': row.get('updated_at') or row.get('created_at'), | |
| 'created_by': row.get('created_by'), | |
| } | |
| agent_version_map[row['agent_id']] = version_dict | |
| except Exception as e: | |
| logger.warning(f"Failed to batch load versions for agents: {e}") | |
| # Apply tool-based filters using version data | |
| if has_mcp_tools is not None or has_agentpress_tools is not None or tools: | |
| filtered_agents = [] | |
| tools_filter = [] | |
| if tools: | |
| # Handle case where tools might be passed as dict instead of string | |
| if isinstance(tools, str): | |
| tools_filter = [tool.strip() for tool in tools.split(',') if tool.strip()] | |
| elif isinstance(tools, dict): | |
| # If tools is a dict, log the issue and skip filtering | |
| logger.warning(f"Received tools parameter as dict instead of string: {tools}") | |
| tools_filter = [] | |
| elif isinstance(tools, list): | |
| # If tools is a list, use it directly | |
| tools_filter = [str(tool).strip() for tool in tools if str(tool).strip()] | |
| else: | |
| logger.warning(f"Unexpected tools parameter type: {type(tools)}, value: {tools}") | |
| tools_filter = [] | |
| for agent in agents_data: | |
| # Get version data if available and extract configuration | |
| version_data = agent_version_map.get(agent['agent_id']) | |
| from agent.config_helper import extract_agent_config | |
| agent_config = extract_agent_config(agent, version_data) | |
| configured_mcps = agent_config['configured_mcps'] | |
| agentpress_tools = agent_config['agentpress_tools'] | |
| # Check MCP tools filter | |
| if has_mcp_tools is not None: | |
| has_mcp = bool(configured_mcps and len(configured_mcps) > 0) | |
| if has_mcp_tools != has_mcp: | |
| continue | |
| # Check AgentPress tools filter | |
| if has_agentpress_tools is not None: | |
| has_enabled_tools = any( | |
| tool_data and isinstance(tool_data, dict) and tool_data.get('enabled', False) | |
| for tool_data in agentpress_tools.values() | |
| ) | |
| if has_agentpress_tools != has_enabled_tools: | |
| continue | |
| # Check specific tools filter | |
| if tools_filter: | |
| agent_tools = set() | |
| # Add MCP tools | |
| for mcp in configured_mcps: | |
| if isinstance(mcp, dict) and 'name' in mcp: | |
| agent_tools.add(f"mcp:{mcp['name']}") | |
| # Add enabled AgentPress tools | |
| for tool_name, tool_data in agentpress_tools.items(): | |
| if tool_data and isinstance(tool_data, dict) and tool_data.get('enabled', False): | |
| agent_tools.add(f"agentpress:{tool_name}") | |
| # Check if any of the requested tools are present | |
| if not any(tool in agent_tools for tool in tools_filter): | |
| continue | |
| filtered_agents.append(agent) | |
| agents_data = filtered_agents | |
| # Handle tools_count sorting (post-processing required) | |
| if sort_by == "tools_count": | |
| def get_tools_count(agent): | |
| # Get version data if available | |
| version_data = agent_version_map.get(agent['agent_id']) | |
| # Use version data for tools if available, otherwise fallback to agent data | |
| if version_data: | |
| configured_mcps = version_data.get('configured_mcps', []) | |
| agentpress_tools = version_data.get('agentpress_tools', {}) | |
| else: | |
| configured_mcps = agent.get('configured_mcps', []) | |
| agentpress_tools = agent.get('agentpress_tools', {}) | |
| mcp_count = len(configured_mcps) | |
| agentpress_count = sum( | |
| 1 for tool_data in agentpress_tools.values() | |
| if tool_data and isinstance(tool_data, dict) and tool_data.get('enabled', False) | |
| ) | |
| return mcp_count + agentpress_count | |
| agents_data.sort(key=get_tools_count, reverse=(sort_order == "desc")) | |
| # Apply pagination to filtered results if we did post-processing | |
| if has_mcp_tools is not None or has_agentpress_tools is not None or tools or sort_by == "tools_count": | |
| total_count = len(agents_data) | |
| agents_data = agents_data[offset:offset + limit] | |
| # Format the response | |
| agent_list = [] | |
| for agent in agents_data: | |
| current_version = None | |
| # Use already fetched version data from agent_version_map | |
| version_dict = agent_version_map.get(agent['agent_id']) | |
| if version_dict: | |
| try: | |
| current_version = AgentVersionResponse( | |
| version_id=version_dict['version_id'], | |
| agent_id=version_dict['agent_id'], | |
| version_number=version_dict['version_number'], | |
| version_name=version_dict['version_name'], | |
| system_prompt=version_dict['system_prompt'], | |
| model=version_dict.get('model'), | |
| configured_mcps=version_dict.get('configured_mcps', []), | |
| custom_mcps=version_dict.get('custom_mcps', []), | |
| agentpress_tools=version_dict.get('agentpress_tools', {}), | |
| is_active=version_dict.get('is_active', True), | |
| created_at=version_dict['created_at'], | |
| updated_at=version_dict.get('updated_at', version_dict['created_at']), | |
| created_by=version_dict.get('created_by') | |
| ) | |
| except Exception as e: | |
| logger.warning(f"Failed to get version data for agent {agent['agent_id']}: {e}") | |
| # Extract configuration using the unified config approach | |
| from agent.config_helper import extract_agent_config | |
| agent_config = extract_agent_config(agent, version_dict) | |
| system_prompt = agent_config['system_prompt'] | |
| configured_mcps = agent_config['configured_mcps'] | |
| custom_mcps = agent_config['custom_mcps'] | |
| agentpress_tools = agent_config['agentpress_tools'] | |
| agent_list.append(AgentResponse( | |
| agent_id=agent['agent_id'], | |
| name=agent['name'], | |
| description=agent.get('description'), | |
| system_prompt=system_prompt, | |
| configured_mcps=configured_mcps, | |
| custom_mcps=custom_mcps, | |
| agentpress_tools=agentpress_tools, | |
| is_default=agent.get('is_default', False), | |
| is_public=agent.get('is_public', False), | |
| tags=agent.get('tags', []), | |
| avatar=agent_config.get('avatar'), | |
| avatar_color=agent_config.get('avatar_color'), | |
| profile_image_url=agent_config.get('profile_image_url'), | |
| created_at=agent['created_at'], | |
| updated_at=agent['updated_at'], | |
| current_version_id=agent.get('current_version_id'), | |
| version_count=agent.get('version_count', 1), | |
| current_version=current_version, | |
| metadata=agent.get('metadata') | |
| )) | |
| total_pages = (total_count + limit - 1) // limit | |
| logger.debug(f"Found {len(agent_list)} agents for user: {user_id} (page {page}/{total_pages})") | |
| return { | |
| "agents": agent_list, | |
| "pagination": { | |
| "page": page, | |
| "limit": limit, | |
| "total": total_count, | |
| "pages": total_pages | |
| } | |
| } | |
| except Exception as e: | |
| logger.error(f"Error fetching agents for user {user_id}: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Failed to fetch agents: {str(e)}") | |
| async def get_agent(agent_id: str, user_id: str = Depends(get_current_user_id_from_jwt)): | |
| """Get a specific agent by ID with current version information. Only the owner can access non-public agents.""" | |
| if not await is_enabled("custom_agents"): | |
| raise HTTPException( | |
| status_code=403, | |
| detail="Custom agents currently disabled. This feature is not available at the moment." | |
| ) | |
| logger.debug(f"Fetching agent {agent_id} for user: {user_id}") | |
| client = await db.client | |
| try: | |
| # Get agent | |
| agent = await client.table('agents').select('*').eq("agent_id", agent_id).execute() | |
| if not agent.data: | |
| raise HTTPException(status_code=404, detail="Agent not found") | |
| agent_data = agent.data[0] | |
| # Check ownership - only owner can access non-public agents | |
| if agent_data['account_id'] != user_id and not agent_data.get('is_public', False): | |
| raise HTTPException(status_code=403, detail="Access denied") | |
| # Use versioning system to get current version data | |
| current_version = None | |
| if agent_data.get('current_version_id'): | |
| try: | |
| version_service = await _get_version_service() | |
| current_version_obj = await version_service.get_version( | |
| agent_id=agent_id, | |
| version_id=agent_data['current_version_id'], | |
| user_id=user_id | |
| ) | |
| current_version_data = current_version_obj.to_dict() | |
| version_data = current_version_data | |
| # Create AgentVersionResponse from version data | |
| current_version = AgentVersionResponse( | |
| version_id=current_version_data['version_id'], | |
| agent_id=current_version_data['agent_id'], | |
| version_number=current_version_data['version_number'], | |
| version_name=current_version_data['version_name'], | |
| system_prompt=current_version_data['system_prompt'], | |
| model=current_version_data.get('model'), | |
| configured_mcps=current_version_data.get('configured_mcps', []), | |
| custom_mcps=current_version_data.get('custom_mcps', []), | |
| agentpress_tools=current_version_data.get('agentpress_tools', {}), | |
| is_active=current_version_data.get('is_active', True), | |
| created_at=current_version_data['created_at'], | |
| updated_at=current_version_data.get('updated_at', current_version_data['created_at']), | |
| created_by=current_version_data.get('created_by') | |
| ) | |
| logger.debug(f"Using agent {agent_data['name']} version {current_version_data.get('version_name', 'v1')}") | |
| except Exception as e: | |
| logger.warning(f"Failed to get version data for agent {agent_id}: {e}") | |
| # Extract configuration using the unified config approach | |
| version_data = None | |
| if current_version: | |
| version_data = { | |
| 'version_id': current_version.version_id, | |
| 'agent_id': current_version.agent_id, | |
| 'version_number': current_version.version_number, | |
| 'version_name': current_version.version_name, | |
| 'system_prompt': current_version.system_prompt, | |
| 'model': current_version.model, | |
| 'configured_mcps': current_version.configured_mcps, | |
| 'custom_mcps': current_version.custom_mcps, | |
| 'agentpress_tools': current_version.agentpress_tools, | |
| 'is_active': current_version.is_active, | |
| 'created_at': current_version.created_at, | |
| 'updated_at': current_version.updated_at, | |
| 'created_by': current_version.created_by | |
| } | |
| from agent.config_helper import extract_agent_config | |
| agent_config = extract_agent_config(agent_data, version_data) | |
| system_prompt = agent_config['system_prompt'] | |
| configured_mcps = agent_config['configured_mcps'] | |
| custom_mcps = agent_config['custom_mcps'] | |
| agentpress_tools = agent_config['agentpress_tools'] | |
| return AgentResponse( | |
| agent_id=agent_data['agent_id'], | |
| name=agent_data['name'], | |
| description=agent_data.get('description'), | |
| system_prompt=system_prompt, | |
| configured_mcps=configured_mcps, | |
| custom_mcps=custom_mcps, | |
| agentpress_tools=agentpress_tools, | |
| is_default=agent_data.get('is_default', False), | |
| is_public=agent_data.get('is_public', False), | |
| tags=agent_data.get('tags', []), | |
| avatar=agent_config.get('avatar'), | |
| avatar_color=agent_config.get('avatar_color'), | |
| profile_image_url=agent_config.get('profile_image_url'), | |
| created_at=agent_data['created_at'], | |
| updated_at=agent_data.get('updated_at', agent_data['created_at']), | |
| current_version_id=agent_data.get('current_version_id'), | |
| version_count=agent_data.get('version_count', 1), | |
| current_version=current_version, | |
| metadata=agent_data.get('metadata') | |
| ) | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error fetching agent {agent_id} for user {user_id}: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Failed to fetch agent: {str(e)}") | |
| async def export_agent(agent_id: str, user_id: str = Depends(get_current_user_id_from_jwt)): | |
| """Export an agent configuration as JSON""" | |
| logger.debug(f"Exporting agent {agent_id} for user: {user_id}") | |
| try: | |
| client = await db.client | |
| # Get agent data | |
| agent_result = await client.table('agents').select('*').eq('agent_id', agent_id).eq('account_id', user_id).execute() | |
| if not agent_result.data: | |
| raise HTTPException(status_code=404, detail="Agent not found") | |
| agent = agent_result.data[0] | |
| # Get current version data if available | |
| current_version = None | |
| if agent.get('current_version_id'): | |
| version_result = await client.table('agent_versions').select('*').eq('version_id', agent['current_version_id']).execute() | |
| if version_result.data: | |
| current_version = version_result.data[0] | |
| from agent.config_helper import extract_agent_config | |
| config = extract_agent_config(agent, current_version) | |
| from templates.template_service import TemplateService | |
| template_service = TemplateService(db) | |
| full_config = { | |
| 'system_prompt': config.get('system_prompt', ''), | |
| 'tools': { | |
| 'agentpress': config.get('agentpress_tools', {}), | |
| 'mcp': config.get('configured_mcps', []), | |
| 'custom_mcp': config.get('custom_mcps', []) | |
| }, | |
| 'metadata': { | |
| # keep backward compat metadata | |
| 'avatar': config.get('avatar'), | |
| 'avatar_color': config.get('avatar_color'), | |
| # include profile image url in metadata for completeness | |
| 'profile_image_url': agent.get('profile_image_url') | |
| } | |
| } | |
| sanitized_config = template_service._fallback_sanitize_config(full_config) | |
| export_metadata = {} | |
| if agent.get('metadata'): | |
| export_metadata = {k: v for k, v in agent['metadata'].items() | |
| if k not in ['is_suna_default', 'centrally_managed', 'installation_date', 'last_central_update']} | |
| export_data = { | |
| "tools": sanitized_config['tools'], | |
| "metadata": sanitized_config['metadata'], | |
| "system_prompt": sanitized_config['system_prompt'], | |
| "name": config.get('name', ''), | |
| "description": config.get('description', ''), | |
| # Deprecated | |
| "avatar": config.get('avatar'), | |
| "avatar_color": config.get('avatar_color'), | |
| # New | |
| "profile_image_url": agent.get('profile_image_url'), | |
| "tags": agent.get('tags', []), | |
| "export_metadata": export_metadata, | |
| "exported_at": datetime.now(timezone.utc).isoformat() | |
| } | |
| logger.debug(f"Successfully exported agent {agent_id}") | |
| return export_data | |
| except Exception as e: | |
| logger.error(f"Error exporting agent {agent_id}: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Failed to export agent: {str(e)}") | |
| # JSON Import endpoints - similar to template installation flow | |
| class JsonAnalysisRequest(BaseModel): | |
| """Request to analyze JSON for import requirements""" | |
| json_data: Dict[str, Any] | |
| class JsonAnalysisResponse(BaseModel): | |
| """Response from JSON analysis""" | |
| requires_setup: bool | |
| missing_regular_credentials: List[Dict[str, Any]] = [] | |
| missing_custom_configs: List[Dict[str, Any]] = [] | |
| agent_info: Dict[str, Any] = {} | |
| class JsonImportRequestModel(BaseModel): | |
| """Request to import agent from JSON""" | |
| json_data: Dict[str, Any] | |
| instance_name: Optional[str] = None | |
| custom_system_prompt: Optional[str] = None | |
| profile_mappings: Optional[Dict[str, str]] = None | |
| custom_mcp_configs: Optional[Dict[str, Dict[str, Any]]] = None | |
| class JsonImportResponse(BaseModel): | |
| """Response from JSON import""" | |
| status: str | |
| instance_id: Optional[str] = None | |
| name: Optional[str] = None | |
| missing_regular_credentials: List[Dict[str, Any]] = [] | |
| missing_custom_configs: List[Dict[str, Any]] = [] | |
| agent_info: Dict[str, Any] = {} | |
| async def analyze_json_for_import( | |
| request: JsonAnalysisRequest, | |
| user_id: str = Depends(get_current_user_id_from_jwt) | |
| ): | |
| """Analyze imported JSON to determine required credentials and configurations""" | |
| logger.debug(f"Analyzing JSON for import - user: {user_id}") | |
| if not await is_enabled("custom_agents"): | |
| raise HTTPException( | |
| status_code=403, | |
| detail="Custom agents currently disabled. This feature is not available at the moment." | |
| ) | |
| try: | |
| from agent.json_import_service import JsonImportService | |
| import_service = JsonImportService(db) | |
| analysis = await import_service.analyze_json(request.json_data, user_id) | |
| return JsonAnalysisResponse( | |
| requires_setup=analysis.requires_setup, | |
| missing_regular_credentials=analysis.missing_regular_credentials, | |
| missing_custom_configs=analysis.missing_custom_configs, | |
| agent_info=analysis.agent_info | |
| ) | |
| except Exception as e: | |
| logger.error(f"Error analyzing JSON: {str(e)}") | |
| raise HTTPException(status_code=400, detail=f"Failed to analyze JSON: {str(e)}") | |
| async def import_agent_from_json( | |
| request: JsonImportRequestModel, | |
| user_id: str = Depends(get_current_user_id_from_jwt) | |
| ): | |
| logger.debug(f"Importing agent from JSON - user: {user_id}") | |
| if not await is_enabled("custom_agents"): | |
| raise HTTPException( | |
| status_code=403, | |
| detail="Custom agents currently disabled. This feature is not available at the moment." | |
| ) | |
| client = await db.client | |
| from .utils import check_agent_count_limit | |
| limit_check = await check_agent_count_limit(client, user_id) | |
| if not limit_check['can_create']: | |
| error_detail = { | |
| "message": f"Maximum of {limit_check['limit']} agents allowed for your current plan. You have {limit_check['current_count']} agents.", | |
| "current_count": limit_check['current_count'], | |
| "limit": limit_check['limit'], | |
| "tier_name": limit_check['tier_name'], | |
| "error_code": "AGENT_LIMIT_EXCEEDED" | |
| } | |
| logger.warning(f"Agent limit exceeded for account {user_id}: {limit_check['current_count']}/{limit_check['limit']} agents") | |
| raise HTTPException(status_code=402, detail=error_detail) | |
| try: | |
| from agent.json_import_service import JsonImportService, JsonImportRequest | |
| import_service = JsonImportService(db) | |
| import_request = JsonImportRequest( | |
| json_data=request.json_data, | |
| account_id=user_id, | |
| instance_name=request.instance_name, | |
| custom_system_prompt=request.custom_system_prompt, | |
| profile_mappings=request.profile_mappings, | |
| custom_mcp_configs=request.custom_mcp_configs | |
| ) | |
| result = await import_service.import_json(import_request) | |
| return JsonImportResponse( | |
| status=result.status, | |
| instance_id=result.instance_id, | |
| name=result.name, | |
| missing_regular_credentials=result.missing_regular_credentials, | |
| missing_custom_configs=result.missing_custom_configs, | |
| agent_info=result.agent_info | |
| ) | |
| except Exception as e: | |
| logger.error(f"Error importing agent from JSON: {str(e)}") | |
| raise HTTPException(status_code=400, detail=f"Failed to import agent: {str(e)}") | |
| async def create_agent( | |
| agent_data: AgentCreateRequest, | |
| user_id: str = Depends(get_current_user_id_from_jwt) | |
| ): | |
| logger.debug(f"Creating new agent for user: {user_id}") | |
| if not await is_enabled("custom_agents"): | |
| raise HTTPException( | |
| status_code=403, | |
| detail="Custom agents currently disabled. This feature is not available at the moment." | |
| ) | |
| client = await db.client | |
| from .utils import check_agent_count_limit | |
| limit_check = await check_agent_count_limit(client, user_id) | |
| if not limit_check['can_create']: | |
| error_detail = { | |
| "message": f"Maximum of {limit_check['limit']} agents allowed for your current plan. You have {limit_check['current_count']} agents.", | |
| "current_count": limit_check['current_count'], | |
| "limit": limit_check['limit'], | |
| "tier_name": limit_check['tier_name'], | |
| "error_code": "AGENT_LIMIT_EXCEEDED" | |
| } | |
| logger.warning(f"Agent limit exceeded for account {user_id}: {limit_check['current_count']}/{limit_check['limit']} agents") | |
| raise HTTPException(status_code=402, detail=error_detail) | |
| try: | |
| if agent_data.is_default: | |
| await client.table('agents').update({"is_default": False}).eq("account_id", user_id).eq("is_default", True).execute() | |
| insert_data = { | |
| "account_id": user_id, | |
| "name": agent_data.name, | |
| "description": agent_data.description, | |
| # Deprecated fields still populated if sent by older clients | |
| "avatar": agent_data.avatar, | |
| "avatar_color": agent_data.avatar_color, | |
| # New profile image url field | |
| "profile_image_url": agent_data.profile_image_url, | |
| "is_default": agent_data.is_default or False, | |
| "version_count": 1 | |
| } | |
| new_agent = await client.table('agents').insert(insert_data).execute() | |
| if not new_agent.data: | |
| raise HTTPException(status_code=500, detail="Failed to create agent") | |
| agent = new_agent.data[0] | |
| try: | |
| version_service = await _get_version_service() | |
| from agent.suna_config import SUNA_CONFIG | |
| from agent.config_helper import _get_default_agentpress_tools | |
| from models import model_manager | |
| system_prompt = SUNA_CONFIG["system_prompt"] | |
| agentpress_tools = agent_data.agentpress_tools if agent_data.agentpress_tools else _get_default_agentpress_tools() | |
| default_model = await model_manager.get_default_model_for_user(client, user_id) | |
| version = await version_service.create_version( | |
| agent_id=agent['agent_id'], | |
| user_id=user_id, | |
| system_prompt=system_prompt, | |
| model=default_model, | |
| configured_mcps=agent_data.configured_mcps or [], | |
| custom_mcps=agent_data.custom_mcps or [], | |
| agentpress_tools=agentpress_tools, | |
| version_name="v1", | |
| change_description="Initial version" | |
| ) | |
| agent['current_version_id'] = version.version_id | |
| agent['version_count'] = 1 | |
| current_version = AgentVersionResponse( | |
| version_id=version.version_id, | |
| agent_id=version.agent_id, | |
| version_number=version.version_number, | |
| version_name=version.version_name, | |
| system_prompt=version.system_prompt, | |
| model=version.model, | |
| configured_mcps=version.configured_mcps, | |
| custom_mcps=version.custom_mcps, | |
| agentpress_tools=version.agentpress_tools, | |
| is_active=version.is_active, | |
| created_at=version.created_at.isoformat(), | |
| updated_at=version.updated_at.isoformat(), | |
| created_by=version.created_by | |
| ) | |
| except Exception as e: | |
| logger.error(f"Error creating initial version: {str(e)}") | |
| await client.table('agents').delete().eq('agent_id', agent['agent_id']).execute() | |
| raise HTTPException(status_code=500, detail="Failed to create initial version") | |
| from utils.cache import Cache | |
| await Cache.invalidate(f"agent_count_limit:{user_id}") | |
| logger.debug(f"Created agent {agent['agent_id']} with v1 for user: {user_id}") | |
| return AgentResponse( | |
| agent_id=agent['agent_id'], | |
| name=agent['name'], | |
| description=agent.get('description'), | |
| system_prompt=version.system_prompt, | |
| model=version.model, | |
| configured_mcps=version.configured_mcps, | |
| custom_mcps=version.custom_mcps, | |
| agentpress_tools=version.agentpress_tools, | |
| is_default=agent.get('is_default', False), | |
| is_public=agent.get('is_public', False), | |
| tags=agent.get('tags', []), | |
| avatar=agent.get('avatar'), | |
| avatar_color=agent.get('avatar_color'), | |
| profile_image_url=agent.get('profile_image_url'), | |
| created_at=agent['created_at'], | |
| updated_at=agent.get('updated_at', agent['created_at']), | |
| current_version_id=agent.get('current_version_id'), | |
| version_count=agent.get('version_count', 1), | |
| current_version=current_version, | |
| metadata=agent.get('metadata') | |
| ) | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error creating agent for user {user_id}: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Failed to create agent: {str(e)}") | |
| def merge_custom_mcps(existing_mcps: List[Dict[str, Any]], new_mcps: List[Dict[str, Any]]) -> List[Dict[str, Any]]: | |
| if not new_mcps: | |
| return existing_mcps | |
| merged_mcps = existing_mcps.copy() | |
| for new_mcp in new_mcps: | |
| new_mcp_name = new_mcp.get('name') | |
| existing_index = None | |
| for i, existing_mcp in enumerate(merged_mcps): | |
| if existing_mcp.get('name') == new_mcp_name: | |
| existing_index = i | |
| break | |
| if existing_index is not None: | |
| merged_mcps[existing_index] = new_mcp | |
| else: | |
| merged_mcps.append(new_mcp) | |
| return merged_mcps | |
| async def update_agent( | |
| agent_id: str, | |
| agent_data: AgentUpdateRequest, | |
| user_id: str = Depends(get_current_user_id_from_jwt) | |
| ): | |
| if not await is_enabled("custom_agents"): | |
| raise HTTPException( | |
| status_code=403, | |
| detail="Custom agent currently disabled. This feature is not available at the moment." | |
| ) | |
| logger.debug(f"Updating agent {agent_id} for user: {user_id}") | |
| client = await db.client | |
| try: | |
| existing_agent = await client.table('agents').select('*').eq("agent_id", agent_id).eq("account_id", user_id).maybe_single().execute() | |
| if not existing_agent.data: | |
| raise HTTPException(status_code=404, detail="Agent not found") | |
| existing_data = existing_agent.data | |
| agent_metadata = existing_data.get('metadata', {}) | |
| is_suna_agent = agent_metadata.get('is_suna_default', False) | |
| restrictions = agent_metadata.get('restrictions', {}) | |
| if is_suna_agent: | |
| logger.warning(f"Update attempt on Suna default agent {agent_id} by user {user_id}") | |
| if (agent_data.name is not None and | |
| agent_data.name != existing_data.get('name') and | |
| restrictions.get('name_editable') == False): | |
| logger.error(f"User {user_id} attempted to modify restricted name of Suna agent {agent_id}") | |
| raise HTTPException( | |
| status_code=403, | |
| detail="Suna's name cannot be modified. This restriction is managed centrally." | |
| ) | |
| if (agent_data.description is not None and | |
| agent_data.description != existing_data.get('description') and | |
| restrictions.get('description_editable') == False): | |
| logger.error(f"User {user_id} attempted to modify restricted description of Suna agent {agent_id}") | |
| raise HTTPException( | |
| status_code=403, | |
| detail="Suna's description cannot be modified." | |
| ) | |
| if (agent_data.system_prompt is not None and | |
| restrictions.get('system_prompt_editable') == False): | |
| logger.error(f"User {user_id} attempted to modify restricted system prompt of Suna agent {agent_id}") | |
| raise HTTPException( | |
| status_code=403, | |
| detail="Suna's system prompt cannot be modified. This is managed centrally to ensure optimal performance." | |
| ) | |
| if (agent_data.agentpress_tools is not None and | |
| restrictions.get('tools_editable') == False): | |
| logger.error(f"User {user_id} attempted to modify restricted tools of Suna agent {agent_id}") | |
| raise HTTPException( | |
| status_code=403, | |
| detail="Suna's default tools cannot be modified. These tools are optimized for Suna's capabilities." | |
| ) | |
| if ((agent_data.configured_mcps is not None or agent_data.custom_mcps is not None) and | |
| restrictions.get('mcps_editable') == False): | |
| logger.error(f"User {user_id} attempted to modify restricted MCPs of Suna agent {agent_id}") | |
| raise HTTPException( | |
| status_code=403, | |
| detail="Suna's integrations cannot be modified." | |
| ) | |
| logger.debug(f"Suna agent update validation passed for agent {agent_id} by user {user_id}") | |
| current_version_data = None | |
| if existing_data.get('current_version_id'): | |
| try: | |
| version_service = await _get_version_service() | |
| current_version_obj = await version_service.get_version( | |
| agent_id=agent_id, | |
| version_id=existing_data['current_version_id'], | |
| user_id=user_id | |
| ) | |
| current_version_data = current_version_obj.to_dict() | |
| except Exception as e: | |
| logger.warning(f"Failed to get current version data for agent {agent_id}: {e}") | |
| if current_version_data is None: | |
| logger.debug(f"Agent {agent_id} has no version data, creating initial version") | |
| try: | |
| workflows_result = await client.table('agent_workflows').select('*').eq('agent_id', agent_id).execute() | |
| workflows = workflows_result.data if workflows_result.data else [] | |
| # Fetch triggers for the agent | |
| triggers_result = await client.table('agent_triggers').select('*').eq('agent_id', agent_id).execute() | |
| triggers = [] | |
| if triggers_result.data: | |
| import json | |
| for trigger in triggers_result.data: | |
| # Parse the config string if it's a string | |
| trigger_copy = trigger.copy() | |
| if 'config' in trigger_copy and isinstance(trigger_copy['config'], str): | |
| try: | |
| trigger_copy['config'] = json.loads(trigger_copy['config']) | |
| except json.JSONDecodeError: | |
| logger.warning(f"Failed to parse trigger config for {trigger_copy.get('trigger_id')}") | |
| trigger_copy['config'] = {} | |
| triggers.append(trigger_copy) | |
| initial_version_data = { | |
| "agent_id": agent_id, | |
| "version_number": 1, | |
| "version_name": "v1", | |
| "system_prompt": existing_data.get('system_prompt', ''), | |
| "configured_mcps": existing_data.get('configured_mcps', []), | |
| "custom_mcps": existing_data.get('custom_mcps', []), | |
| "agentpress_tools": existing_data.get('agentpress_tools', {}), | |
| "is_active": True, | |
| "created_by": user_id | |
| } | |
| initial_config = build_unified_config( | |
| system_prompt=initial_version_data["system_prompt"], | |
| agentpress_tools=initial_version_data["agentpress_tools"], | |
| configured_mcps=initial_version_data["configured_mcps"], | |
| custom_mcps=initial_version_data["custom_mcps"], | |
| avatar=None, | |
| avatar_color=None, | |
| workflows=workflows, | |
| triggers=triggers | |
| ) | |
| initial_version_data["config"] = initial_config | |
| version_result = await client.table('agent_versions').insert(initial_version_data).execute() | |
| if version_result.data: | |
| version_id = version_result.data[0]['version_id'] | |
| await client.table('agents').update({ | |
| 'current_version_id': version_id, | |
| 'version_count': 1 | |
| }).eq('agent_id', agent_id).execute() | |
| current_version_data = initial_version_data | |
| logger.debug(f"Created initial version for agent {agent_id}") | |
| else: | |
| current_version_data = { | |
| 'system_prompt': existing_data.get('system_prompt', ''), | |
| 'configured_mcps': existing_data.get('configured_mcps', []), | |
| 'custom_mcps': existing_data.get('custom_mcps', []), | |
| 'agentpress_tools': existing_data.get('agentpress_tools', {}) | |
| } | |
| except Exception as e: | |
| logger.warning(f"Failed to create initial version for agent {agent_id}: {e}") | |
| current_version_data = { | |
| 'system_prompt': existing_data.get('system_prompt', ''), | |
| 'configured_mcps': existing_data.get('configured_mcps', []), | |
| 'custom_mcps': existing_data.get('custom_mcps', []), | |
| 'agentpress_tools': existing_data.get('agentpress_tools', {}) | |
| } | |
| needs_new_version = False | |
| version_changes = {} | |
| def values_different(new_val, old_val): | |
| if new_val is None: | |
| return False | |
| import json | |
| try: | |
| new_json = json.dumps(new_val, sort_keys=True) if new_val is not None else None | |
| old_json = json.dumps(old_val, sort_keys=True) if old_val is not None else None | |
| return new_json != old_json | |
| except (TypeError, ValueError): | |
| return new_val != old_val | |
| if values_different(agent_data.system_prompt, current_version_data.get('system_prompt')): | |
| needs_new_version = True | |
| version_changes['system_prompt'] = agent_data.system_prompt | |
| if values_different(agent_data.configured_mcps, current_version_data.get('configured_mcps', [])): | |
| needs_new_version = True | |
| version_changes['configured_mcps'] = agent_data.configured_mcps | |
| if values_different(agent_data.custom_mcps, current_version_data.get('custom_mcps', [])): | |
| needs_new_version = True | |
| if agent_data.custom_mcps is not None: | |
| merged_custom_mcps = merge_custom_mcps( | |
| current_version_data.get('custom_mcps', []), | |
| agent_data.custom_mcps | |
| ) | |
| version_changes['custom_mcps'] = merged_custom_mcps | |
| else: | |
| version_changes['custom_mcps'] = current_version_data.get('custom_mcps', []) | |
| if values_different(agent_data.agentpress_tools, current_version_data.get('agentpress_tools', {})): | |
| needs_new_version = True | |
| version_changes['agentpress_tools'] = agent_data.agentpress_tools | |
| update_data = {} | |
| if agent_data.name is not None: | |
| update_data["name"] = agent_data.name | |
| if agent_data.description is not None: | |
| update_data["description"] = agent_data.description | |
| if agent_data.is_default is not None: | |
| update_data["is_default"] = agent_data.is_default | |
| if agent_data.is_default: | |
| await client.table('agents').update({"is_default": False}).eq("account_id", user_id).eq("is_default", True).neq("agent_id", agent_id).execute() | |
| if agent_data.avatar is not None: | |
| update_data["avatar"] = agent_data.avatar | |
| if agent_data.avatar_color is not None: | |
| update_data["avatar_color"] = agent_data.avatar_color | |
| if agent_data.profile_image_url is not None: | |
| update_data["profile_image_url"] = agent_data.profile_image_url | |
| current_system_prompt = agent_data.system_prompt if agent_data.system_prompt is not None else current_version_data.get('system_prompt', '') | |
| current_configured_mcps = agent_data.configured_mcps if agent_data.configured_mcps is not None else current_version_data.get('configured_mcps', []) | |
| if agent_data.custom_mcps is not None: | |
| current_custom_mcps = merge_custom_mcps( | |
| current_version_data.get('custom_mcps', []), | |
| agent_data.custom_mcps | |
| ) | |
| else: | |
| current_custom_mcps = current_version_data.get('custom_mcps', []) | |
| current_agentpress_tools = agent_data.agentpress_tools if agent_data.agentpress_tools is not None else current_version_data.get('agentpress_tools', {}) | |
| current_avatar = agent_data.avatar if agent_data.avatar is not None else existing_data.get('avatar') | |
| current_avatar_color = agent_data.avatar_color if agent_data.avatar_color is not None else existing_data.get('avatar_color') | |
| new_version_id = None | |
| if needs_new_version: | |
| try: | |
| version_service = await _get_version_service() | |
| new_version = await version_service.create_version( | |
| agent_id=agent_id, | |
| user_id=user_id, | |
| system_prompt=current_system_prompt, | |
| configured_mcps=current_configured_mcps, | |
| custom_mcps=current_custom_mcps, | |
| agentpress_tools=current_agentpress_tools, | |
| change_description="Configuration updated" | |
| ) | |
| new_version_id = new_version.version_id | |
| update_data['current_version_id'] = new_version_id | |
| update_data['version_count'] = new_version.version_number | |
| logger.debug(f"Created new version {new_version.version_name} for agent {agent_id}") | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error creating new version for agent {agent_id}: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Failed to create new agent version: {str(e)}") | |
| if update_data: | |
| try: | |
| update_result = await client.table('agents').update(update_data).eq("agent_id", agent_id).eq("account_id", user_id).execute() | |
| if not update_result.data: | |
| raise HTTPException(status_code=500, detail="Failed to update agent - no rows affected") | |
| except Exception as e: | |
| logger.error(f"Error updating agent {agent_id}: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Failed to update agent: {str(e)}") | |
| updated_agent = await client.table('agents').select('*').eq("agent_id", agent_id).eq("account_id", user_id).maybe_single().execute() | |
| if not updated_agent.data: | |
| raise HTTPException(status_code=500, detail="Failed to fetch updated agent") | |
| agent = updated_agent.data | |
| current_version = None | |
| if agent.get('current_version_id'): | |
| try: | |
| version_service = await _get_version_service() | |
| current_version_obj = await version_service.get_version( | |
| agent_id=agent_id, | |
| version_id=agent['current_version_id'], | |
| user_id=user_id | |
| ) | |
| current_version_data = current_version_obj.to_dict() | |
| version_data = current_version_data | |
| current_version = AgentVersionResponse( | |
| version_id=current_version_data['version_id'], | |
| agent_id=current_version_data['agent_id'], | |
| version_number=current_version_data['version_number'], | |
| version_name=current_version_data['version_name'], | |
| system_prompt=current_version_data['system_prompt'], | |
| model=current_version_data.get('model'), | |
| configured_mcps=current_version_data.get('configured_mcps', []), | |
| custom_mcps=current_version_data.get('custom_mcps', []), | |
| agentpress_tools=current_version_data.get('agentpress_tools', {}), | |
| is_active=current_version_data.get('is_active', True), | |
| created_at=current_version_data['created_at'], | |
| updated_at=current_version_data.get('updated_at', current_version_data['created_at']), | |
| created_by=current_version_data.get('created_by') | |
| ) | |
| logger.debug(f"Using agent {agent['name']} version {current_version_data.get('version_name', 'v1')}") | |
| except Exception as e: | |
| logger.warning(f"Failed to get version data for updated agent {agent_id}: {e}") | |
| version_data = None | |
| if current_version: | |
| version_data = { | |
| 'version_id': current_version.version_id, | |
| 'agent_id': current_version.agent_id, | |
| 'version_number': current_version.version_number, | |
| 'version_name': current_version.version_name, | |
| 'system_prompt': current_version.system_prompt, | |
| 'model': current_version.model, | |
| 'configured_mcps': current_version.configured_mcps, | |
| 'custom_mcps': current_version.custom_mcps, | |
| 'agentpress_tools': current_version.agentpress_tools, | |
| 'is_active': current_version.is_active, | |
| } | |
| from agent.config_helper import extract_agent_config | |
| agent_config = extract_agent_config(agent, version_data) | |
| system_prompt = agent_config['system_prompt'] | |
| configured_mcps = agent_config['configured_mcps'] | |
| custom_mcps = agent_config['custom_mcps'] | |
| agentpress_tools = agent_config['agentpress_tools'] | |
| return AgentResponse( | |
| agent_id=agent['agent_id'], | |
| name=agent['name'], | |
| description=agent.get('description'), | |
| system_prompt=system_prompt, | |
| configured_mcps=configured_mcps, | |
| custom_mcps=custom_mcps, | |
| agentpress_tools=agentpress_tools, | |
| is_default=agent.get('is_default', False), | |
| is_public=agent.get('is_public', False), | |
| tags=agent.get('tags', []), | |
| avatar=agent_config.get('avatar'), | |
| avatar_color=agent_config.get('avatar_color'), | |
| profile_image_url=agent_config.get('profile_image_url'), | |
| created_at=agent['created_at'], | |
| updated_at=agent.get('updated_at', agent['created_at']), | |
| current_version_id=agent.get('current_version_id'), | |
| version_count=agent.get('version_count', 1), | |
| current_version=current_version, | |
| metadata=agent.get('metadata') | |
| ) | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error updating agent {agent_id} for user {user_id}: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Failed to update agent: {str(e)}") | |
| async def delete_agent(agent_id: str, user_id: str = Depends(get_current_user_id_from_jwt)): | |
| if not await is_enabled("custom_agents"): | |
| raise HTTPException( | |
| status_code=403, | |
| detail="Custom agent currently disabled. This feature is not available at the moment." | |
| ) | |
| logger.debug(f"Deleting agent: {agent_id}") | |
| client = await db.client | |
| try: | |
| agent_result = await client.table('agents').select('*').eq('agent_id', agent_id).execute() | |
| if not agent_result.data: | |
| raise HTTPException(status_code=404, detail="Agent not found") | |
| agent = agent_result.data[0] | |
| if agent['account_id'] != user_id: | |
| raise HTTPException(status_code=403, detail="Access denied") | |
| if agent['is_default']: | |
| raise HTTPException(status_code=400, detail="Cannot delete default agent") | |
| if agent.get('metadata', {}).get('is_suna_default', False): | |
| raise HTTPException(status_code=400, detail="Cannot delete Suna default agent") | |
| # Clean up triggers before deleting agent to ensure proper remote cleanup | |
| try: | |
| from triggers.trigger_service import get_trigger_service | |
| trigger_service = get_trigger_service(db) | |
| # Get all triggers for this agent | |
| triggers_result = await client.table('agent_triggers').select('trigger_id').eq('agent_id', agent_id).execute() | |
| if triggers_result.data: | |
| logger.debug(f"Cleaning up {len(triggers_result.data)} triggers for agent {agent_id}") | |
| # Delete each trigger properly (this handles remote cleanup) | |
| for trigger_record in triggers_result.data: | |
| trigger_id = trigger_record['trigger_id'] | |
| try: | |
| await trigger_service.delete_trigger(trigger_id) | |
| logger.debug(f"Successfully cleaned up trigger {trigger_id}") | |
| except Exception as e: | |
| logger.warning(f"Failed to clean up trigger {trigger_id}: {str(e)}") | |
| # Continue with other triggers even if one fails | |
| except Exception as e: | |
| logger.warning(f"Failed to clean up triggers for agent {agent_id}: {str(e)}") | |
| # Continue with agent deletion even if trigger cleanup fails | |
| delete_result = await client.table('agents').delete().eq('agent_id', agent_id).execute() | |
| if not delete_result.data: | |
| logger.warning(f"No agent was deleted for agent_id: {agent_id}, user_id: {user_id}") | |
| raise HTTPException(status_code=403, detail="Unable to delete agent - permission denied or agent not found") | |
| try: | |
| from utils.cache import Cache | |
| await Cache.invalidate(f"agent_count_limit:{user_id}") | |
| except Exception as cache_error: | |
| logger.warning(f"Cache invalidation failed for user {user_id}: {str(cache_error)}") | |
| logger.debug(f"Successfully deleted agent: {agent_id}") | |
| return {"message": "Agent deleted successfully"} | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error deleting agent {agent_id}: {str(e)}") | |
| raise HTTPException(status_code=500, detail="Internal server error") | |
| async def get_pipedream_tools_for_agent( | |
| agent_id: str, | |
| profile_id: str, | |
| user_id: str = Depends(get_current_user_id_from_jwt), | |
| version: Optional[str] = Query(None, description="Version ID to get tools from specific version") | |
| ): | |
| logger.debug(f"Getting tools for agent {agent_id}, profile {profile_id}, user {user_id}, version {version}") | |
| try: | |
| from pipedream import profile_service, mcp_service | |
| from uuid import UUID | |
| profile = await profile_service.get_profile(UUID(user_id), UUID(profile_id)) | |
| if not profile: | |
| logger.error(f"Profile {profile_id} not found for user {user_id}") | |
| try: | |
| all_profiles = await profile_service.get_profiles(UUID(user_id)) | |
| pipedream_profiles = [p for p in all_profiles if 'pipedream' in p.mcp_qualified_name] | |
| logger.debug(f"User {user_id} has {len(pipedream_profiles)} pipedream profiles: {[p.profile_id for p in pipedream_profiles]}") | |
| except Exception as debug_e: | |
| logger.warning(f"Could not check user's profiles: {str(debug_e)}") | |
| raise HTTPException(status_code=404, detail=f"Profile {profile_id} not found or access denied") | |
| if not profile.is_connected: | |
| raise HTTPException(status_code=400, detail="Profile is not connected") | |
| enabled_tools = [] | |
| try: | |
| client = await db.client | |
| agent_row = await client.table('agents')\ | |
| .select('current_version_id')\ | |
| .eq('agent_id', agent_id)\ | |
| .eq('account_id', user_id)\ | |
| .maybe_single()\ | |
| .execute() | |
| if agent_row.data and agent_row.data.get('current_version_id'): | |
| if version: | |
| version_result = await client.table('agent_versions')\ | |
| .select('config')\ | |
| .eq('version_id', version)\ | |
| .maybe_single()\ | |
| .execute() | |
| else: | |
| version_result = await client.table('agent_versions')\ | |
| .select('config')\ | |
| .eq('version_id', agent_row.data['current_version_id'])\ | |
| .maybe_single()\ | |
| .execute() | |
| if version_result.data and version_result.data.get('config'): | |
| agent_config = version_result.data['config'] | |
| tools = agent_config.get('tools', {}) | |
| custom_mcps = tools.get('custom_mcp', []) or [] | |
| for mcp in custom_mcps: | |
| mcp_profile_id = mcp.get('config', {}).get('profile_id') | |
| if mcp_profile_id == profile_id: | |
| enabled_tools = mcp.get('enabledTools', mcp.get('enabled_tools', [])) | |
| logger.debug(f"Found enabled tools for profile {profile_id}: {enabled_tools}") | |
| break | |
| if not enabled_tools: | |
| logger.debug(f"No enabled tools found for profile {profile_id} in agent {agent_id}") | |
| except Exception as e: | |
| logger.error(f"Error retrieving enabled tools for profile {profile_id}: {str(e)}") | |
| logger.debug(f"Using {len(enabled_tools)} enabled tools for profile {profile_id}: {enabled_tools}") | |
| try: | |
| from pipedream.mcp_service import ExternalUserId, AppSlug | |
| external_user_id = ExternalUserId(profile.external_user_id) | |
| app_slug_obj = AppSlug(profile.app_slug) | |
| logger.debug(f"Discovering servers for user {external_user_id.value} and app {app_slug_obj.value}") | |
| servers = await mcp_service.discover_servers_for_user(external_user_id, app_slug_obj) | |
| logger.debug(f"Found {len(servers)} servers: {[s.app_slug for s in servers]}") | |
| server = servers[0] if servers else None | |
| logger.debug(f"Selected server: {server.app_slug if server else 'None'} with {len(server.available_tools) if server else 0} tools") | |
| if not server: | |
| return { | |
| 'profile_id': profile_id, | |
| 'app_name': profile.app_name, | |
| 'profile_name': profile.profile_name, | |
| 'tools': [], | |
| 'has_mcp_config': len(enabled_tools) > 0 | |
| } | |
| available_tools = server.available_tools | |
| formatted_tools = [] | |
| def tools_match(api_tool_name, stored_tool_name): | |
| api_normalized = api_tool_name.lower().replace('-', '_') | |
| stored_normalized = stored_tool_name.lower().replace('-', '_') | |
| return api_normalized == stored_normalized | |
| for tool in available_tools: | |
| is_enabled = any(tools_match(tool.name, stored_tool) for stored_tool in enabled_tools) | |
| formatted_tools.append({ | |
| 'name': tool.name, | |
| 'description': tool.description or f"Tool from {profile.app_name}", | |
| 'enabled': is_enabled | |
| }) | |
| return { | |
| 'profile_id': profile_id, | |
| 'app_name': profile.app_name, | |
| 'profile_name': profile.profile_name, | |
| 'tools': formatted_tools, | |
| 'has_mcp_config': len(enabled_tools) > 0 | |
| } | |
| except Exception as e: | |
| logger.error(f"Error discovering tools: {e}", exc_info=True) | |
| return { | |
| 'profile_id': profile_id, | |
| 'app_name': getattr(profile, 'app_name', 'Unknown'), | |
| 'profile_name': getattr(profile, 'profile_name', 'Unknown'), | |
| 'tools': [], | |
| 'has_mcp_config': len(enabled_tools) > 0, | |
| 'error': str(e) | |
| } | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error getting Pipedream tools for agent {agent_id}: {e}") | |
| raise HTTPException(status_code=500, detail="Internal server error") | |
| async def update_pipedream_tools_for_agent( | |
| agent_id: str, | |
| profile_id: str, | |
| request: dict, | |
| user_id: str = Depends(get_current_user_id_from_jwt) | |
| ): | |
| try: | |
| client = await db.client | |
| agent_row = await client.table('agents')\ | |
| .select('current_version_id')\ | |
| .eq('agent_id', agent_id)\ | |
| .eq('account_id', user_id)\ | |
| .maybe_single()\ | |
| .execute() | |
| if not agent_row.data: | |
| raise HTTPException(status_code=404, detail="Agent not found") | |
| agent_config = {} | |
| if agent_row.data.get('current_version_id'): | |
| version_result = await client.table('agent_versions')\ | |
| .select('config')\ | |
| .eq('version_id', agent_row.data['current_version_id'])\ | |
| .maybe_single()\ | |
| .execute() | |
| if version_result.data and version_result.data.get('config'): | |
| agent_config = version_result.data['config'] | |
| tools = agent_config.get('tools', {}) | |
| custom_mcps = tools.get('custom_mcp', []) or [] | |
| if any(mcp.get('config', {}).get('profile_id') == profile_id for mcp in custom_mcps): | |
| raise HTTPException(status_code=400, detail="This profile is already added to this agent") | |
| enabled_tools = request.get('enabled_tools', []) | |
| updated = False | |
| for mcp in custom_mcps: | |
| mcp_profile_id = mcp.get('config', {}).get('profile_id') | |
| if mcp_profile_id == profile_id: | |
| mcp['enabledTools'] = enabled_tools | |
| mcp['enabled_tools'] = enabled_tools | |
| updated = True | |
| logger.debug(f"Updated enabled tools for profile {profile_id}: {enabled_tools}") | |
| break | |
| if not updated: | |
| logger.warning(f"Profile {profile_id} not found in agent {agent_id} custom_mcps configuration") | |
| if updated: | |
| agent_config['tools']['custom_mcp'] = custom_mcps | |
| await client.table('agent_versions')\ | |
| .update({'config': agent_config})\ | |
| .eq('version_id', agent_row.data['current_version_id'])\ | |
| .execute() | |
| logger.debug(f"Successfully updated agent configuration for {agent_id}") | |
| result = { | |
| 'success': updated, | |
| 'enabled_tools': enabled_tools, | |
| 'total_tools': len(enabled_tools), | |
| 'profile_id': profile_id | |
| } | |
| logger.debug(f"Successfully updated Pipedream tools for agent {agent_id}, profile {profile_id}") | |
| return result | |
| except ValueError as e: | |
| logger.error(f"Validation error updating Pipedream tools: {e}") | |
| raise HTTPException(status_code=404, detail=str(e)) | |
| except Exception as e: | |
| logger.error(f"Error updating Pipedream tools for agent {agent_id}: {e}") | |
| raise HTTPException(status_code=500, detail="Internal server error") | |
| async def get_custom_mcp_tools_for_agent( | |
| agent_id: str, | |
| request: Request, | |
| user_id: str = Depends(get_current_user_id_from_jwt) | |
| ): | |
| logger.debug(f"Getting custom MCP tools for agent {agent_id}, user {user_id}") | |
| try: | |
| client = await db.client | |
| agent_result = await client.table('agents').select('current_version_id').eq('agent_id', agent_id).eq('account_id', user_id).execute() | |
| if not agent_result.data: | |
| raise HTTPException(status_code=404, detail="Agent not found") | |
| agent = agent_result.data[0] | |
| agent_config = {} | |
| if agent.get('current_version_id'): | |
| version_result = await client.table('agent_versions')\ | |
| .select('config')\ | |
| .eq('version_id', agent['current_version_id'])\ | |
| .maybe_single()\ | |
| .execute() | |
| if version_result.data and version_result.data.get('config'): | |
| agent_config = version_result.data['config'] | |
| tools = agent_config.get('tools', {}) | |
| custom_mcps = tools.get('custom_mcp', []) | |
| mcp_url = request.headers.get('X-MCP-URL') | |
| mcp_type = request.headers.get('X-MCP-Type', 'sse') | |
| if not mcp_url: | |
| raise HTTPException(status_code=400, detail="X-MCP-URL header is required") | |
| mcp_config = { | |
| 'url': mcp_url, | |
| 'type': mcp_type | |
| } | |
| if 'X-MCP-Headers' in request.headers: | |
| import json | |
| try: | |
| mcp_config['headers'] = json.loads(request.headers['X-MCP-Headers']) | |
| except json.JSONDecodeError: | |
| logger.warning("Failed to parse X-MCP-Headers as JSON") | |
| from mcp_module import mcp_service | |
| discovery_result = await mcp_service.discover_custom_tools(mcp_type, mcp_config) | |
| existing_mcp = None | |
| for mcp in custom_mcps: | |
| if (mcp.get('type') == mcp_type and | |
| mcp.get('config', {}).get('url') == mcp_url): | |
| existing_mcp = mcp | |
| break | |
| tools = [] | |
| enabled_tools = existing_mcp.get('enabledTools', []) if existing_mcp else [] | |
| for tool in discovery_result.tools: | |
| tools.append({ | |
| 'name': tool['name'], | |
| 'description': tool.get('description', f'Tool from {mcp_type.upper()} MCP server'), | |
| 'enabled': tool['name'] in enabled_tools | |
| }) | |
| return { | |
| 'tools': tools, | |
| 'has_mcp_config': existing_mcp is not None, | |
| 'server_type': mcp_type, | |
| 'server_url': mcp_url | |
| } | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error getting custom MCP tools for agent {agent_id}: {e}") | |
| raise HTTPException(status_code=500, detail="Internal server error") | |
| async def update_custom_mcp_tools_for_agent( | |
| agent_id: str, | |
| request: dict, | |
| user_id: str = Depends(get_current_user_id_from_jwt) | |
| ): | |
| logger.debug(f"Updating custom MCP tools for agent {agent_id}, user {user_id}") | |
| try: | |
| client = await db.client | |
| agent_result = await client.table('agents').select('current_version_id').eq('agent_id', agent_id).eq('account_id', user_id).execute() | |
| if not agent_result.data: | |
| raise HTTPException(status_code=404, detail="Agent not found") | |
| agent = agent_result.data[0] | |
| agent_config = {} | |
| if agent.get('current_version_id'): | |
| version_result = await client.table('agent_versions')\ | |
| .select('config')\ | |
| .eq('version_id', agent['current_version_id'])\ | |
| .maybe_single()\ | |
| .execute() | |
| if version_result.data and version_result.data.get('config'): | |
| agent_config = version_result.data['config'] | |
| tools = agent_config.get('tools', {}) | |
| custom_mcps = tools.get('custom_mcp', []) | |
| mcp_url = request.get('url') | |
| mcp_type = request.get('type', 'sse') | |
| enabled_tools = request.get('enabled_tools', []) | |
| if not mcp_url: | |
| raise HTTPException(status_code=400, detail="MCP URL is required") | |
| updated = False | |
| for i, mcp in enumerate(custom_mcps): | |
| if mcp_type == 'composio': | |
| # For Composio, match by profile_id | |
| if (mcp.get('type') == 'composio' and | |
| mcp.get('config', {}).get('profile_id') == mcp_url): | |
| custom_mcps[i]['enabledTools'] = enabled_tools | |
| updated = True | |
| break | |
| else: | |
| if (mcp.get('customType') == mcp_type and | |
| mcp.get('config', {}).get('url') == mcp_url): | |
| custom_mcps[i]['enabledTools'] = enabled_tools | |
| updated = True | |
| break | |
| if not updated: | |
| if mcp_type == 'composio': | |
| try: | |
| from composio_integration.composio_profile_service import ComposioProfileService | |
| from services.supabase import DBConnection | |
| profile_service = ComposioProfileService(DBConnection()) | |
| profile_id = mcp_url | |
| mcp_config = await profile_service.get_mcp_config_for_agent(profile_id) | |
| mcp_config['enabledTools'] = enabled_tools | |
| custom_mcps.append(mcp_config) | |
| except Exception as e: | |
| logger.error(f"Failed to get Composio profile config: {e}") | |
| raise HTTPException(status_code=400, detail=f"Failed to get Composio profile: {str(e)}") | |
| else: | |
| new_mcp_config = { | |
| "name": f"Custom MCP ({mcp_type.upper()})", | |
| "customType": mcp_type, | |
| "type": mcp_type, | |
| "config": { | |
| "url": mcp_url | |
| }, | |
| "enabledTools": enabled_tools | |
| } | |
| custom_mcps.append(new_mcp_config) | |
| tools['custom_mcp'] = custom_mcps | |
| agent_config['tools'] = tools | |
| from agent.versioning.version_service import get_version_service | |
| try: | |
| version_service = await get_version_service() | |
| new_version = await version_service.create_version( | |
| agent_id=agent_id, | |
| user_id=user_id, | |
| system_prompt=agent_config.get('system_prompt', ''), | |
| configured_mcps=agent_config.get('tools', {}).get('mcp', []), | |
| custom_mcps=custom_mcps, | |
| agentpress_tools=agent_config.get('tools', {}).get('agentpress', {}), | |
| change_description=f"Updated custom MCP tools for {mcp_type}" | |
| ) | |
| logger.debug(f"Created version {new_version.version_id} for custom MCP tools update on agent {agent_id}") | |
| except Exception as e: | |
| logger.error(f"Failed to create version for custom MCP tools update: {e}") | |
| raise HTTPException(status_code=500, detail="Failed to save changes") | |
| return { | |
| 'success': True, | |
| 'enabled_tools': enabled_tools, | |
| 'total_tools': len(enabled_tools) | |
| } | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error updating custom MCP tools for agent {agent_id}: {e}") | |
| raise HTTPException(status_code=500, detail="Internal server error") | |
| async def get_agent_tools( | |
| agent_id: str, | |
| user_id: str = Depends(get_current_user_id_from_jwt) | |
| ): | |
| if not await is_enabled("custom_agents"): | |
| raise HTTPException(status_code=403, detail="Custom agents currently disabled") | |
| logger.debug(f"Fetching enabled tools for agent: {agent_id} by user: {user_id}") | |
| client = await db.client | |
| agent_result = await client.table('agents').select('*').eq('agent_id', agent_id).execute() | |
| if not agent_result.data: | |
| raise HTTPException(status_code=404, detail="Agent not found") | |
| agent = agent_result.data[0] | |
| if agent['account_id'] != user_id and not agent.get('is_public', False): | |
| raise HTTPException(status_code=403, detail="Access denied") | |
| version_data = None | |
| if agent.get('current_version_id'): | |
| try: | |
| version_service = await _get_version_service() | |
| version_obj = await version_service.get_version( | |
| agent_id=agent_id, | |
| version_id=agent['current_version_id'], | |
| user_id=user_id | |
| ) | |
| version_data = version_obj.to_dict() | |
| except Exception as e: | |
| logger.warning(f"Failed to fetch version data for tools endpoint: {e}") | |
| from agent.config_helper import extract_agent_config | |
| agent_config = extract_agent_config(agent, version_data) | |
| agentpress_tools_config = agent_config['agentpress_tools'] | |
| configured_mcps = agent_config['configured_mcps'] | |
| custom_mcps = agent_config['custom_mcps'] | |
| agentpress_tools = [] | |
| for name, enabled in agentpress_tools_config.items(): | |
| is_enabled_tool = bool(enabled.get('enabled', False)) if isinstance(enabled, dict) else bool(enabled) | |
| agentpress_tools.append({"name": name, "enabled": is_enabled_tool}) | |
| mcp_tools = [] | |
| for mcp in configured_mcps + custom_mcps: | |
| server = mcp.get('name') | |
| enabled_tools = mcp.get('enabledTools') or mcp.get('enabled_tools') or [] | |
| for tool_name in enabled_tools: | |
| mcp_tools.append({"name": tool_name, "server": server, "enabled": True}) | |
| return {"agentpress_tools": agentpress_tools, "mcp_tools": mcp_tools} | |
| async def get_user_threads( | |
| user_id: str = Depends(get_current_user_id_from_jwt), | |
| page: Optional[int] = Query(1, ge=1, description="Page number (1-based)"), | |
| limit: Optional[int] = Query(1000, ge=1, le=1000, description="Number of items per page (max 1000)") | |
| ): | |
| """Get all threads for the current user with associated project data.""" | |
| logger.debug(f"Fetching threads with project data for user: {user_id} (page={page}, limit={limit})") | |
| client = await db.client | |
| try: | |
| offset = (page - 1) * limit | |
| # First, get threads for the user | |
| threads_result = await client.table('threads').select('*').eq('account_id', user_id).order('created_at', desc=True).execute() | |
| if not threads_result.data: | |
| logger.debug(f"No threads found for user: {user_id}") | |
| return { | |
| "threads": [], | |
| "pagination": { | |
| "page": page, | |
| "limit": limit, | |
| "total": 0, | |
| "pages": 0 | |
| } | |
| } | |
| total_count = len(threads_result.data) | |
| # Apply pagination to threads | |
| paginated_threads = threads_result.data[offset:offset + limit] | |
| # Extract unique project IDs from threads that have them | |
| project_ids = [ | |
| thread['project_id'] for thread in paginated_threads | |
| if thread.get('project_id') | |
| ] | |
| unique_project_ids = list(set(project_ids)) if project_ids else [] | |
| # Fetch projects if we have project IDs | |
| projects_by_id = {} | |
| if unique_project_ids: | |
| from utils.query_utils import batch_query_in | |
| projects_data = await batch_query_in( | |
| client=client, | |
| table_name='projects', | |
| select_fields='*', | |
| in_field='project_id', | |
| in_values=unique_project_ids | |
| ) | |
| logger.debug(f"[API] Retrieved {len(projects_data)} projects") | |
| # Create a lookup map of projects by ID | |
| projects_by_id = { | |
| project['project_id']: project | |
| for project in projects_data | |
| } | |
| # Map threads with their associated projects | |
| mapped_threads = [] | |
| for thread in paginated_threads: | |
| project_data = None | |
| if thread.get('project_id') and thread['project_id'] in projects_by_id: | |
| project = projects_by_id[thread['project_id']] | |
| project_data = { | |
| "project_id": project['project_id'], | |
| "name": project.get('name', ''), | |
| "description": project.get('description', ''), | |
| "sandbox": project.get('sandbox', {}), | |
| "is_public": project.get('is_public', False), | |
| "created_at": project['created_at'], | |
| "updated_at": project['updated_at'] | |
| } | |
| mapped_thread = { | |
| "thread_id": thread['thread_id'], | |
| "project_id": thread.get('project_id'), | |
| "metadata": thread.get('metadata', {}), | |
| "is_public": thread.get('is_public', False), | |
| "created_at": thread['created_at'], | |
| "updated_at": thread['updated_at'], | |
| "project": project_data | |
| } | |
| mapped_threads.append(mapped_thread) | |
| total_pages = (total_count + limit - 1) // limit if total_count else 0 | |
| logger.debug(f"[API] Mapped threads for frontend: {len(mapped_threads)} threads, {len(projects_by_id)} unique projects") | |
| return { | |
| "threads": mapped_threads, | |
| "pagination": { | |
| "page": page, | |
| "limit": limit, | |
| "total": total_count, | |
| "pages": total_pages | |
| } | |
| } | |
| except Exception as e: | |
| logger.error(f"Error fetching threads for user {user_id}: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Failed to fetch threads: {str(e)}") | |
| async def get_thread( | |
| thread_id: str, | |
| user_id: str = Depends(get_current_user_id_from_jwt) | |
| ): | |
| """Get a specific thread by ID with complete related data.""" | |
| logger.debug(f"Fetching thread: {thread_id}") | |
| client = await db.client | |
| try: | |
| await verify_thread_access(client, thread_id, user_id) | |
| # Get the thread data | |
| thread_result = await client.table('threads').select('*').eq('thread_id', thread_id).execute() | |
| if not thread_result.data: | |
| raise HTTPException(status_code=404, detail="Thread not found") | |
| thread = thread_result.data[0] | |
| # Get associated project if thread has a project_id | |
| project_data = None | |
| if thread.get('project_id'): | |
| project_result = await client.table('projects').select('*').eq('project_id', thread['project_id']).execute() | |
| if project_result.data: | |
| project = project_result.data[0] | |
| logger.debug(f"[API] Raw project from DB for thread {thread_id}") | |
| project_data = { | |
| "project_id": project['project_id'], | |
| "name": project.get('name', ''), | |
| "description": project.get('description', ''), | |
| "sandbox": project.get('sandbox', {}), | |
| "is_public": project.get('is_public', False), | |
| "created_at": project['created_at'], | |
| "updated_at": project['updated_at'] | |
| } | |
| # Get message count for the thread | |
| message_count_result = await client.table('messages').select('message_id', count='exact').eq('thread_id', thread_id).execute() | |
| message_count = message_count_result.count if message_count_result.count is not None else 0 | |
| # Get recent agent runs for the thread | |
| agent_runs_result = await client.table('agent_runs').select('*').eq('thread_id', thread_id).order('created_at', desc=True).execute() | |
| agent_runs_data = [] | |
| if agent_runs_result.data: | |
| agent_runs_data = [{ | |
| "id": run['id'], | |
| "status": run.get('status', ''), | |
| "started_at": run.get('started_at'), | |
| "completed_at": run.get('completed_at'), | |
| "error": run.get('error'), | |
| "agent_id": run.get('agent_id'), | |
| "agent_version_id": run.get('agent_version_id'), | |
| "created_at": run['created_at'] | |
| } for run in agent_runs_result.data] | |
| # Map thread data for frontend (matching actual DB structure) | |
| mapped_thread = { | |
| "thread_id": thread['thread_id'], | |
| "project_id": thread.get('project_id'), | |
| "metadata": thread.get('metadata', {}), | |
| "is_public": thread.get('is_public', False), | |
| "created_at": thread['created_at'], | |
| "updated_at": thread['updated_at'], | |
| "project": project_data, | |
| "message_count": message_count, | |
| "recent_agent_runs": agent_runs_data | |
| } | |
| logger.debug(f"[API] Mapped thread for frontend: {thread_id} with {message_count} messages and {len(agent_runs_data)} recent runs") | |
| return mapped_thread | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error fetching thread {thread_id}: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Failed to fetch thread: {str(e)}") | |
| async def create_thread( | |
| name: Optional[str] = Form(None), | |
| user_id: str = Depends(get_current_user_id_from_jwt) | |
| ): | |
| """ | |
| Create a new thread without starting an agent run. | |
| [WARNING] Keep in sync with initiate endpoint. | |
| """ | |
| if not name: | |
| name = "New Project" | |
| logger.debug(f"Creating new thread with name: {name}") | |
| client = await db.client | |
| account_id = user_id # In Basejump, personal account_id is the same as user_id | |
| try: | |
| # 1. Create Project | |
| project_name = name or "New Project" | |
| project = await client.table('projects').insert({ | |
| "project_id": str(uuid.uuid4()), | |
| "account_id": account_id, | |
| "name": project_name, | |
| "created_at": datetime.now(timezone.utc).isoformat() | |
| }).execute() | |
| project_id = project.data[0]['project_id'] | |
| logger.debug(f"Created new project: {project_id}") | |
| # 2. Create Sandbox | |
| sandbox_id = None | |
| try: | |
| sandbox_pass = str(uuid.uuid4()) | |
| sandbox = await create_sandbox(sandbox_pass, project_id) | |
| sandbox_id = sandbox.id | |
| logger.debug(f"Created new sandbox {sandbox_id} for project {project_id}") | |
| # Get preview links | |
| vnc_link = await sandbox.get_preview_link(6080) | |
| website_link = await sandbox.get_preview_link(8080) | |
| vnc_url = vnc_link.url if hasattr(vnc_link, 'url') else str(vnc_link).split("url='")[1].split("'")[0] | |
| website_url = website_link.url if hasattr(website_link, 'url') else str(website_link).split("url='")[1].split("'")[0] | |
| token = None | |
| if hasattr(vnc_link, 'token'): | |
| token = vnc_link.token | |
| elif "token='" in str(vnc_link): | |
| token = str(vnc_link).split("token='")[1].split("'")[0] | |
| except Exception as e: | |
| logger.error(f"Error creating sandbox: {str(e)}") | |
| await client.table('projects').delete().eq('project_id', project_id).execute() | |
| if sandbox_id: | |
| try: | |
| await delete_sandbox(sandbox_id) | |
| except Exception as e: | |
| logger.error(f"Error deleting sandbox: {str(e)}") | |
| raise Exception("Failed to create sandbox") | |
| # Update project with sandbox info | |
| update_result = await client.table('projects').update({ | |
| 'sandbox': { | |
| 'id': sandbox_id, | |
| 'pass': sandbox_pass, | |
| 'vnc_preview': vnc_url, | |
| 'sandbox_url': website_url, | |
| 'token': token | |
| } | |
| }).eq('project_id', project_id).execute() | |
| if not update_result.data: | |
| logger.error(f"Failed to update project {project_id} with new sandbox {sandbox_id}") | |
| if sandbox_id: | |
| try: | |
| await delete_sandbox(sandbox_id) | |
| except Exception as e: | |
| logger.error(f"Error deleting sandbox: {str(e)}") | |
| raise Exception("Database update failed") | |
| # 3. Create Thread | |
| thread_data = { | |
| "thread_id": str(uuid.uuid4()), | |
| "project_id": project_id, | |
| "account_id": account_id, | |
| "created_at": datetime.now(timezone.utc).isoformat() | |
| } | |
| structlog.contextvars.bind_contextvars( | |
| thread_id=thread_data["thread_id"], | |
| project_id=project_id, | |
| account_id=account_id, | |
| ) | |
| thread = await client.table('threads').insert(thread_data).execute() | |
| thread_id = thread.data[0]['thread_id'] | |
| logger.debug(f"Created new thread: {thread_id}") | |
| logger.debug(f"Successfully created thread {thread_id} with project {project_id}") | |
| return {"thread_id": thread_id, "project_id": project_id} | |
| except Exception as e: | |
| logger.error(f"Error creating thread: {str(e)}\n{traceback.format_exc()}") | |
| # TODO: Clean up created project/thread if creation fails mid-way | |
| raise HTTPException(status_code=500, detail=f"Failed to create thread: {str(e)}") | |
| async def get_thread_messages( | |
| thread_id: str, | |
| user_id: str = Depends(get_current_user_id_from_jwt), | |
| order: str = Query("desc", description="Order by created_at: 'asc' or 'desc'") | |
| ): | |
| """Get all messages for a thread, fetching in batches of 1000 from the DB to avoid large queries.""" | |
| logger.debug(f"Fetching all messages for thread: {thread_id}, order={order}") | |
| client = await db.client | |
| await verify_thread_access(client, thread_id, user_id) | |
| try: | |
| batch_size = 1000 | |
| offset = 0 | |
| all_messages = [] | |
| while True: | |
| query = client.table('messages').select('*').eq('thread_id', thread_id) | |
| query = query.order('created_at', desc=(order == "desc")) | |
| query = query.range(offset, offset + batch_size - 1) | |
| messages_result = await query.execute() | |
| batch = messages_result.data or [] | |
| all_messages.extend(batch) | |
| logger.debug(f"Fetched batch of {len(batch)} messages (offset {offset})") | |
| if len(batch) < batch_size: | |
| break | |
| offset += batch_size | |
| return {"messages": all_messages} | |
| except Exception as e: | |
| logger.error(f"Error fetching messages for thread {thread_id}: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Failed to fetch messages: {str(e)}") | |
| async def get_agent_run( | |
| agent_run_id: str, | |
| user_id: str = Depends(get_current_user_id_from_jwt), | |
| ): | |
| """ | |
| [DEPRECATED] Get an agent run by ID. | |
| This endpoint is deprecated and may be removed in future versions. | |
| """ | |
| logger.warning(f"[DEPRECATED] Fetching agent run: {agent_run_id}") | |
| client = await db.client | |
| try: | |
| agent_run_result = await client.table('agent_runs').select('*').eq('agent_run_id', agent_run_id).eq('account_id', user_id).execute() | |
| if not agent_run_result.data: | |
| raise HTTPException(status_code=404, detail="Agent run not found") | |
| return agent_run_result.data[0] | |
| except Exception as e: | |
| logger.error(f"Error fetching agent run {agent_run_id}: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Failed to fetch agent run: {str(e)}") | |
| async def add_message_to_thread( | |
| thread_id: str, | |
| message: str, | |
| user_id: str = Depends(get_current_user_id_from_jwt), | |
| ): | |
| """Add a message to a thread""" | |
| logger.debug(f"Adding message to thread: {thread_id}") | |
| client = await db.client | |
| await verify_thread_access(client, thread_id, user_id) | |
| try: | |
| message_result = await client.table('messages').insert({ | |
| 'thread_id': thread_id, | |
| 'type': 'user', | |
| 'is_llm_message': True, | |
| 'content': { | |
| "role": "user", | |
| "content": message | |
| } | |
| }).execute() | |
| return message_result.data[0] | |
| except Exception as e: | |
| logger.error(f"Error adding message to thread {thread_id}: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Failed to add message: {str(e)}") | |
| async def create_message( | |
| thread_id: str, | |
| message_data: MessageCreateRequest, | |
| user_id: str = Depends(get_current_user_id_from_jwt) | |
| ): | |
| """Create a new message in a thread.""" | |
| logger.debug(f"Creating message in thread: {thread_id}") | |
| client = await db.client | |
| try: | |
| await verify_thread_access(client, thread_id, user_id) | |
| message_payload = { | |
| "role": "user" if message_data.type == "user" else "assistant", | |
| "content": message_data.content | |
| } | |
| insert_data = { | |
| "message_id": str(uuid.uuid4()), | |
| "thread_id": thread_id, | |
| "type": message_data.type, | |
| "is_llm_message": message_data.is_llm_message, | |
| "content": message_payload, # Store as JSONB object, not JSON string | |
| "created_at": datetime.now(timezone.utc).isoformat() | |
| } | |
| message_result = await client.table('messages').insert(insert_data).execute() | |
| if not message_result.data: | |
| raise HTTPException(status_code=500, detail="Failed to create message") | |
| logger.debug(f"Created message: {message_result.data[0]['message_id']}") | |
| return message_result.data[0] | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error creating message in thread {thread_id}: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Failed to create message: {str(e)}") | |
| async def delete_message( | |
| thread_id: str, | |
| message_id: str, | |
| user_id: str = Depends(get_current_user_id_from_jwt) | |
| ): | |
| """Delete a message from a thread.""" | |
| logger.debug(f"Deleting message from thread: {thread_id}") | |
| client = await db.client | |
| await verify_thread_access(client, thread_id, user_id) | |
| try: | |
| # Don't allow users to delete the "status" messages | |
| await client.table('messages').delete().eq('message_id', message_id).eq('is_llm_message', True).eq('thread_id', thread_id).execute() | |
| return {"message": "Message deleted successfully"} | |
| except Exception as e: | |
| logger.error(f"Error deleting message {message_id} from thread {thread_id}: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Failed to delete message: {str(e)}") | |
| async def update_agent_custom_mcps( | |
| agent_id: str, | |
| request: dict, | |
| user_id: str = Depends(get_current_user_id_from_jwt) | |
| ): | |
| logger.debug(f"Updating agent {agent_id} custom MCPs for user {user_id}") | |
| try: | |
| client = await db.client | |
| agent_result = await client.table('agents').select('current_version_id').eq('agent_id', agent_id).eq('account_id', user_id).execute() | |
| if not agent_result.data: | |
| raise HTTPException(status_code=404, detail="Agent not found") | |
| agent = agent_result.data[0] | |
| agent_config = {} | |
| if agent.get('current_version_id'): | |
| version_result = await client.table('agent_versions')\ | |
| .select('config')\ | |
| .eq('version_id', agent['current_version_id'])\ | |
| .maybe_single()\ | |
| .execute() | |
| if version_result.data and version_result.data.get('config'): | |
| agent_config = version_result.data['config'] | |
| new_custom_mcps = request.get('custom_mcps', []) | |
| if not new_custom_mcps: | |
| raise HTTPException(status_code=400, detail="custom_mcps array is required") | |
| tools = agent_config.get('tools', {}) | |
| existing_custom_mcps = tools.get('custom_mcp', []) | |
| updated = False | |
| for new_mcp in new_custom_mcps: | |
| mcp_type = new_mcp.get('type', '') | |
| if mcp_type == 'composio': | |
| profile_id = new_mcp.get('config', {}).get('profile_id') | |
| if not profile_id: | |
| continue | |
| for i, existing_mcp in enumerate(existing_custom_mcps): | |
| if (existing_mcp.get('type') == 'composio' and | |
| existing_mcp.get('config', {}).get('profile_id') == profile_id): | |
| existing_custom_mcps[i] = new_mcp | |
| updated = True | |
| break | |
| if not updated: | |
| existing_custom_mcps.append(new_mcp) | |
| updated = True | |
| else: | |
| mcp_url = new_mcp.get('config', {}).get('url') | |
| mcp_name = new_mcp.get('name', '') | |
| for i, existing_mcp in enumerate(existing_custom_mcps): | |
| if (existing_mcp.get('config', {}).get('url') == mcp_url or | |
| (mcp_name and existing_mcp.get('name') == mcp_name)): | |
| existing_custom_mcps[i] = new_mcp | |
| updated = True | |
| break | |
| if not updated: | |
| existing_custom_mcps.append(new_mcp) | |
| updated = True | |
| tools['custom_mcp'] = existing_custom_mcps | |
| agent_config['tools'] = tools | |
| from agent.versioning.version_service import get_version_service | |
| import datetime | |
| try: | |
| version_service = await get_version_service() | |
| timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") | |
| change_description = f"MCP tools update {timestamp}" | |
| new_version = await version_service.create_version( | |
| agent_id=agent_id, | |
| user_id=user_id, | |
| system_prompt=agent_config.get('system_prompt', ''), | |
| configured_mcps=agent_config.get('tools', {}).get('mcp', []), | |
| custom_mcps=existing_custom_mcps, | |
| agentpress_tools=agent_config.get('tools', {}).get('agentpress', {}), | |
| change_description=change_description | |
| ) | |
| logger.debug(f"Created version {new_version.version_id} for agent {agent_id}") | |
| total_enabled_tools = sum(len(mcp.get('enabledTools', [])) for mcp in new_custom_mcps) | |
| except Exception as e: | |
| logger.error(f"Failed to create version for custom MCP tools update: {e}") | |
| raise HTTPException(status_code=500, detail="Failed to save changes") | |
| return { | |
| 'success': True, | |
| 'data': { | |
| 'custom_mcps': existing_custom_mcps, | |
| 'total_enabled_tools': total_enabled_tools | |
| } | |
| } | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error updating agent custom MCPs: {e}") | |
| raise HTTPException(status_code=500, detail="Internal server error") | |
| async def export_presentation( | |
| request: Dict[str, Any] = Body(...), | |
| user_id: str = Depends(get_current_user_id_from_jwt) | |
| ): | |
| try: | |
| presentation_name = request.get("presentation_name") | |
| export_format = request.get("format", "pptx") | |
| project_id = request.get("project_id") | |
| if not presentation_name: | |
| raise HTTPException(status_code=400, detail="presentation_name is required") | |
| if not project_id: | |
| raise HTTPException(status_code=400, detail="project_id is required") | |
| if db is None: | |
| db_conn = DBConnection() | |
| client = await db_conn.client | |
| else: | |
| client = await db.client | |
| project_result = await client.table('projects').select('sandbox').eq('project_id', project_id).execute() | |
| if not project_result.data: | |
| raise HTTPException(status_code=404, detail="Project not found") | |
| sandbox_data = project_result.data[0].get('sandbox', {}) | |
| sandbox_id = sandbox_data.get('id') | |
| if not sandbox_id: | |
| raise HTTPException(status_code=400, detail="No sandbox found for this project") | |
| thread_manager = ThreadManager() | |
| presentation_tool = SandboxPresentationTool( | |
| project_id=project_id, | |
| thread_manager=thread_manager | |
| ) | |
| result = await presentation_tool.export_presentation( | |
| presentation_name=presentation_name, | |
| format=export_format | |
| ) | |
| if result.success: | |
| import json | |
| import urllib.parse | |
| data = json.loads(result.output) | |
| export_file = data.get("export_file") | |
| logger.debug(f"Export file from tool: {export_file}") | |
| logger.debug(f"Sandbox ID: {sandbox_id}") | |
| if export_file: | |
| from fastapi.responses import Response | |
| from sandbox.api import get_sandbox_by_id_safely, verify_sandbox_access | |
| try: | |
| file_path = export_file.replace("/workspace/", "").lstrip("/") | |
| full_path = f"/workspace/{file_path}" | |
| sandbox = await get_sandbox_by_id_safely(client, sandbox_id) | |
| file_content = await sandbox.fs.download_file(full_path) | |
| return { | |
| "success": True, | |
| "message": data.get("message"), | |
| "file_content": base64.b64encode(file_content).decode('utf-8'), | |
| "filename": export_file.split('/')[-1], | |
| "export_file": data.get("export_file"), | |
| "format": data.get("format"), | |
| "file_size": data.get("file_size") | |
| } | |
| except Exception as e: | |
| logger.error(f"Failed to read exported file: {str(e)}") | |
| return { | |
| "success": False, | |
| "error": f"Failed to read exported file: {str(e)}" | |
| } | |
| else: | |
| return { | |
| "success": True, | |
| "message": data.get("message"), | |
| "download_url": data.get("download_url"), | |
| "export_file": data.get("export_file"), | |
| "format": data.get("format"), | |
| "file_size": data.get("file_size") | |
| } | |
| else: | |
| raise HTTPException(status_code=400, detail=result.output or "Export failed") | |
| except Exception as e: | |
| logger.error(f"Export presentation error: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Failed to export presentation: {str(e)}") | |
| async def upload_agent_profile_image( | |
| file: UploadFile = File(...), | |
| user_id: str = Depends(get_current_user_id_from_jwt) | |
| ): | |
| try: | |
| content_type = file.content_type or "image/png" | |
| image_bytes = await file.read() | |
| from utils.s3_upload_utils import upload_image_bytes | |
| public_url = await upload_image_bytes(image_bytes=image_bytes, content_type=content_type, bucket_name="agent-profile-images") | |
| return {"url": public_url} | |
| except Exception as e: | |
| logger.error(f"Failed to upload agent profile image for user {user_id}: {e}") | |
| raise HTTPException(status_code=500, detail="Failed to upload profile image") | |