| | """ |
| | Factor Agent - Enterprise-Grade FastAPI Backend |
| | =============================================== |
| | High-performance AI agent backend with: |
| | - Rate limiting and request throttling |
| | - Comprehensive monitoring and logging |
| | - Caching layer for improved performance |
| | - Health checks and graceful degradation |
| | - WebSocket support for real-time communication |
| | """ |
| |
|
| | import asyncio |
| | import json |
| | import logging |
| | import os |
| | import time |
| | import uuid |
| | from collections import defaultdict |
| | from contextlib import asynccontextmanager |
| | from datetime import datetime, timedelta |
| | from typing import Any, Dict, List, Optional |
| |
|
| | import redis.asyncio as redis |
| | from dotenv import load_dotenv |
| | from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect, status |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from fastapi.middleware.gzip import GZipMiddleware |
| | from fastapi.responses import JSONResponse |
| | from prometheus_client import Counter, Histogram, generate_latest, CONTENT_TYPE_LATEST |
| | from slowapi import Limiter, _rate_limit_exceeded_handler |
| | from slowapi.errors import RateLimitExceeded |
| | from slowapi.util import get_remote_address |
| | from websocket import manager as connection_manager |
| |
|
| | |
# Pull variables from a local .env file into the process environment before
# any configuration is read.
load_dotenv()

# Root logging setup: timestamped records including file/line provenance.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s'
)
logger = logging.getLogger("factor_agent")
| |
|
| | |
# Prometheus collectors. Registering the same metric name twice (e.g. module
# re-import under a dev reloader) raises ValueError, so on collision we fall
# back to the collectors already present in the default registry.
try:
    REQUEST_COUNT = Counter('factor_requests_total', 'Total requests', ['method', 'endpoint', 'status'])
    REQUEST_DURATION = Histogram('factor_request_duration_seconds', 'Request duration', ['method', 'endpoint'])
    ACTIVE_SESSIONS = Counter('factor_active_sessions', 'Number of active sessions')
    TOOL_EXECUTIONS = Counter('factor_tool_executions_total', 'Tool executions', ['tool_name', 'status'])
    WEBSOCKET_CONNECTIONS = Counter('factor_websocket_connections', 'WebSocket connections', ['event_type'])
except ValueError:
    # NOTE(review): _names_to_collectors is a private prometheus_client
    # attribute — confirm it still exists before upgrading the library.
    from prometheus_client import REGISTRY
    REQUEST_COUNT = REGISTRY._names_to_collectors.get('factor_requests_total')
    REQUEST_DURATION = REGISTRY._names_to_collectors.get('factor_request_duration_seconds')
    ACTIVE_SESSIONS = REGISTRY._names_to_collectors.get('factor_active_sessions')
    TOOL_EXECUTIONS = REGISTRY._names_to_collectors.get('factor_tool_executions_total')
    WEBSOCKET_CONNECTIONS = REGISTRY._names_to_collectors.get('factor_websocket_connections')

# Per-client-IP rate limiter used by the route decorators below.
limiter = Limiter(key_func=get_remote_address)

# Process-local TTL cache backing CacheManager: value store + insertion times.
_cache: Dict[str, Any] = {}
_cache_timestamps: Dict[str, float] = {}
| |
|
| |
|
class CacheManager:
    """Static facade over the module-level TTL cache dictionaries."""

    @staticmethod
    def get(key: str) -> Optional[Any]:
        """Return the cached value, or None when absent or expired."""
        if key not in _cache:
            return None
        age = time.time() - _cache_timestamps.get(key, 0)
        if age < get_config().cache.ttl_seconds:
            return _cache[key]
        # Expired: evict eagerly so stale entries don't linger.
        CacheManager.delete(key)
        return None

    @staticmethod
    def set(key: str, value: Any) -> None:
        """Store a value and stamp its insertion time."""
        _cache[key] = value
        _cache_timestamps[key] = time.time()

    @staticmethod
    def delete(key: str) -> None:
        """Remove a key from both maps; missing keys are ignored."""
        _cache.pop(key, None)
        _cache_timestamps.pop(key, None)

    @staticmethod
    def clear() -> None:
        """Empty the whole cache."""
        _cache.clear()
        _cache_timestamps.clear()
| |
|
| |
|
class MetricsCollector:
    """In-process counters mirrored into the Prometheus collectors."""

    def __init__(self):
        # Wall-clock instant the collector (≈ the app) started.
        self.start_time = time.time()
        self.request_counts = defaultdict(int)
        self.error_counts = defaultdict(int)
        self.tool_executions = defaultdict(int)

    def record_request(self, method: str, endpoint: str, status_code: int, duration: float):
        """Count one HTTP request and observe its latency."""
        key = f"{method}:{endpoint}"
        self.request_counts[key] += 1
        REQUEST_COUNT.labels(method=method, endpoint=endpoint, status=status_code).inc()
        REQUEST_DURATION.labels(method=method, endpoint=endpoint).observe(duration)

    def record_error(self, error_type: str):
        """Count one unhandled error by exception type name."""
        self.error_counts[error_type] += 1

    def record_tool_execution(self, tool_name: str, success: bool):
        """Count one tool run, bucketed by success/failure."""
        outcome = "success" if success else "failure"
        self.tool_executions[f"{tool_name}:{outcome}"] += 1
        TOOL_EXECUTIONS.labels(tool_name=tool_name, status=outcome).inc()

    def get_uptime(self) -> float:
        """Seconds elapsed since this collector was created."""
        return time.time() - self.start_time

    def get_stats(self) -> Dict[str, Any]:
        """Snapshot of all counters plus cache/websocket sizes."""
        return {
            "uptime_seconds": self.get_uptime(),
            "total_requests": dict(self.request_counts),
            "total_errors": dict(self.error_counts),
            "tool_executions": dict(self.tool_executions),
            "cache_size": len(_cache),
            "active_websockets": len(connection_manager.active_connections),
        }
| |
|
| |
|
# Singleton collector shared by the timing middleware and the endpoints.
metrics = MetricsCollector()
| |
|
| |
|
| | |
class ConnectionManager:
    """Registry of live WebSocket connections keyed by client id."""

    def __init__(self):
        # client_id -> live socket
        self.active_connections: Dict[str, WebSocket] = {}
        # client_id -> connection info (connect time, peer ip)
        self.connection_metadata: Dict[str, Dict[str, Any]] = {}

    async def connect(self, websocket: WebSocket, client_id: str):
        """Accept the handshake, register the socket, record metadata."""
        await websocket.accept()
        self.active_connections[client_id] = websocket
        peer_ip = websocket.client.host if websocket.client else "unknown"
        self.connection_metadata[client_id] = {
            "connected_at": datetime.now(timezone.utc).isoformat(),
            "ip": peer_ip,
        }
        WEBSOCKET_CONNECTIONS.labels(event_type="connect").inc()
        logger.info(f"WebSocket connected: {client_id}")

    def disconnect(self, client_id: str):
        """Drop a client's socket and metadata; unknown ids are a no-op."""
        if self.active_connections.pop(client_id, None) is None:
            return
        self.connection_metadata.pop(client_id, None)
        WEBSOCKET_CONNECTIONS.labels(event_type="disconnect").inc()
        logger.info(f"WebSocket disconnected: {client_id}")

    async def send_message(self, client_id: str, message: Dict[str, Any]):
        """Send a JSON payload to one client; evict it if the send fails."""
        websocket = self.active_connections.get(client_id)
        if websocket is None:
            return
        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.error(f"Failed to send message to {client_id}: {e}")
            self.disconnect(client_id)

    async def broadcast(self, message: Dict[str, Any]):
        """Send a JSON payload to all clients, pruning dead sockets after."""
        stale: List[str] = []
        for client_id, connection in list(self.active_connections.items()):
            try:
                await connection.send_json(message)
            except Exception as e:
                logger.error(f"Failed to broadcast to {client_id}: {e}")
                stale.append(client_id)

        # Deferred so we never mutate the registry mid-iteration.
        for client_id in stale:
            self.disconnect(client_id)
| |
|
| |
|
| |
|
| |
|
| | |
# Deferred imports: placed after the WebSocket/metrics plumbing above,
# presumably to avoid a circular import with the agent package — TODO
# confirm before reordering to the top of the file.
from datetime import timezone
from agent.config import get_config, FactorConfig
from agent.core.agent_loop import submission_loop, Handlers, process_submission
from agent.core.session import Session, Event, OpType

# Global application configuration, loaded once at import time.
config = get_config()
| |
|
| |
|
| | |
def invoke_agent(message_text: str) -> str:
    """
    Invoke the agent with a message and collect the response.

    Runs a fresh event loop in the calling (worker) thread, drives
    ``Handlers.run_agent`` to completion, and — when no direct response is
    returned — falls back to draining queued ``assistant_chunk`` events and
    finally to any logged error event.

    Args:
        message_text: The user message to hand to the agent.

    Returns:
        The agent's text response, the concatenated chunk stream, an error
        description, or ``"No response generated"``.
    """
    try:
        logger.info("invoke_agent called with: %s", message_text[:100])

        # Dedicated loop: this runs in a thread with no running loop of its own.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        try:
            event_queue: asyncio.Queue = asyncio.Queue()

            from agent.core.tools import ToolRouter
            tool_router = ToolRouter(config.mcp_servers)

            # Session expects an attribute-style config object; mirror the
            # fields it reads from the global config. Use the stdlib
            # SimpleNamespace instead of a hand-rolled empty class.
            from types import SimpleNamespace
            config_adapter = SimpleNamespace(
                model_name=config.model.name,
                openrouter_enabled=config.openrouter_enabled,
                openrouter_model=config.openrouter_model,
                save_sessions=config.save_sessions,
                session_dataset_repo=config.session_dataset_repo,
                auto_save_interval=config.auto_save_interval,
                security=config.security,
                mcp_servers=config.mcp_servers,
            )

            logger.info("Created config adapter with model: %s", config_adapter.model_name)

            session = Session(event_queue, config=config_adapter, tool_router=tool_router)
            logger.info("Created session: %s", session.session_id)

            logger.info("Calling Handlers.run_agent with message: %s", message_text[:50])
            response = loop.run_until_complete(Handlers.run_agent(session, message_text))

            logger.info("run_agent returned: %s = %s", type(response).__name__, str(response)[:100])

            if response:
                logger.info("Got direct response, returning it")
                return response

            # No direct return value: fall back to the streamed chunks.
            logger.info("No direct response, checking events...")
            logger.info("Event queue size: %s", event_queue.qsize())
            logger.info("Session logged events: %s", len(session.logged_events))

            chunks: List[str] = []
            while True:
                try:
                    event = event_queue.get_nowait()
                except asyncio.QueueEmpty:
                    break
                logger.debug("Got event: %s", event.event_type)
                if event.event_type == "assistant_chunk":
                    chunk_content = event.data.get("content", "")
                    chunks.append(chunk_content)
                    logger.debug("Collected chunk: %s", chunk_content)
            # join avoids quadratic string concatenation on long streams.
            collected_text = "".join(chunks)

            logger.info("Collected text length: %s", len(collected_text))

            if collected_text:
                logger.info("Returning collected text")
                return collected_text

            # Last resort: surface any error event the session recorded.
            for event_item in session.logged_events:
                logger.debug("Logged event: %s", event_item.get("event_type"))
                if event_item.get("event_type") == "error":
                    error_msg = event_item.get('data', {}).get('error', 'Unknown error')
                    logger.info("Found error event: %s", error_msg)
                    return f"Agent error: {error_msg}"

            logger.info("No response found, returning default message")
            return "No response generated"

        finally:
            # Detach the loop from this thread before closing it so a later
            # asyncio.get_event_loop() cannot hand back a closed loop.
            asyncio.set_event_loop(None)
            loop.close()

    except Exception as e:
        logger.error(f"Error invoking agent: {e}", exc_info=True)
        return f"Error: {str(e)}"
| |
|
| |
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler: startup banner, then shutdown cleanup."""
    banner = [
        f"🚀 Starting Factor Agent v{config.app_version}",
        f"📊 Environment: {config.environment}",
        f"🤖 Model: {config.model.name}",
        f"⚡ YOLO Mode: {config.security.yolo_mode}",
        f"🔒 Rate Limiting: {config.rate_limit.enabled}",
        f"💾 Caching: {config.cache.enabled}",
    ]
    for line in banner:
        logger.info(line)

    yield

    # Shutdown: drop everything held by the in-process cache.
    logger.info("🛑 Shutting down Factor Agent...")
    CacheManager.clear()
| |
|
| |
|
| | |
# Application object. API docs are disabled entirely in production.
app = FastAPI(
    title="Factor Agent API",
    description="Enterprise-Grade AI Agent Platform",
    version=config.app_version,
    docs_url="/docs" if config.environment != "production" else None,
    redoc_url="/redoc" if config.environment != "production" else None,
    lifespan=lifespan,
)

# Wire slowapi: the limiter instance must live on app.state, and rate-limit
# violations get a canned 429 response.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Compress responses larger than 1 KB.
app.add_middleware(GZipMiddleware, minimum_size=1000)

# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# disallowed by the CORS spec (browsers reject the wildcard when credentials
# are sent) — confirm whether credentialed cross-origin calls are needed and,
# if so, list explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)
| |
|
| |
|
| | |
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Time each request, expose it via X-Process-Time, and record metrics.

    Uses a monotonic clock (perf_counter) so measured durations cannot go
    negative or jump when the wall clock is adjusted mid-request.
    """
    start = time.perf_counter()
    response = await call_next(request)
    process_time = time.perf_counter() - start
    response.headers["X-Process-Time"] = str(process_time)

    # Mirror into both the in-process collector and Prometheus.
    metrics.record_request(
        method=request.method,
        endpoint=request.url.path,
        status_code=response.status_code,
        duration=process_time
    )

    return response
| |
|
| |
|
| | |
@app.middleware("http")
async def error_handling_middleware(request: Request, call_next):
    """Last-resort handler: log, count, and mask unhandled exceptions."""
    try:
        response = await call_next(request)
    except Exception as e:
        logger.exception(f"Unhandled error: {e}")
        metrics.record_error(type(e).__name__)
        # Only development builds leak the real message to the client.
        detail = str(e) if config.environment == "development" else "An unexpected error occurred"
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={
                "error": "Internal server error",
                "detail": detail,
            },
        )
    return response
| |
|
| |
|
| | |
| | |
| | |
| |
|
@app.get("/")
async def root():
    """Root endpoint: service identity and basic status."""
    in_production = config.environment == "production"
    return {
        "name": "Factor Agent API",
        "version": config.app_version,
        "environment": config.environment,
        "status": "operational",
        "yolo_mode": config.security.yolo_mode,
        # Docs are only exposed outside production (mirrors the app setup).
        "docs": None if in_production else "/docs",
    }
| |
|
| |
|
@app.get("/health")
async def health_check():
    """Liveness probe: always healthy while the process is serving."""
    payload = {
        "status": "healthy",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "version": config.app_version,
        "uptime": metrics.get_uptime(),
    }
    return payload
| |
|
| |
|
@app.get("/health/detailed")
async def detailed_health_check():
    """Readiness-style probe: per-component status plus metric snapshot."""
    ws_count = len(connection_manager.active_connections)
    websocket_state = "operational" if ws_count < 1000 else "degraded"
    cache_state = "operational" if len(_cache) < config.cache.max_size else "full"

    return {
        "status": "healthy",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "version": config.app_version,
        "uptime": metrics.get_uptime(),
        "components": {
            "api": "operational",
            "websocket": websocket_state,
            "cache": cache_state,
        },
        "metrics": metrics.get_stats(),
    }
| |
|
| |
|
@app.get("/metrics")
async def prometheus_metrics():
    """Prometheus exposition endpoint.

    Bug fix: the previous JSONResponse JSON-encoded (quoted and escaped)
    the exposition text, which breaks Prometheus scrapers. A plain Response
    returns the raw payload with the exposition content type.
    """
    from fastapi import Response
    return Response(
        content=generate_latest(),
        media_type=CONTENT_TYPE_LATEST,
    )
| |
|
| |
|
@app.get("/config")
async def get_configuration():
    """Expose a sanitized view of the running configuration."""
    model_info = {
        "name": config.model.name,
        "provider": config.model.provider,
        "temperature": config.model.temperature,
        "max_tokens": config.model.max_tokens,
    }
    security_info = {
        "yolo_mode": config.security.yolo_mode,
        "max_execution_time": config.security.max_execution_time,
    }
    return {
        "app_name": config.app_name,
        "version": config.app_version,
        "environment": config.environment,
        "model": model_info,
        "security": security_info,
        "features": config.features,
    }
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | |
# In-memory session registry (session_id -> metadata/messages).
# Not persisted: all sessions are lost on process restart.
sessions: Dict[str, Dict[str, Any]] = {}
| |
|
| |
|
@app.post("/api/session", status_code=status.HTTP_201_CREATED)
@limiter.limit("10/minute")
async def create_session(request: Request):
    """Create a new agent session.

    Returns 503 when the configured session cap is already reached.
    """
    # Capacity check first so rejected requests do no work.
    if len(sessions) >= config.max_total_sessions:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Maximum number of sessions reached"
        )

    session_id = f"sess_{uuid.uuid4().hex[:12]}"
    # Single timestamp so created_at and last_activity agree exactly.
    now = datetime.now(timezone.utc).isoformat()
    sessions[session_id] = {
        "id": session_id,
        "created_at": now,
        "last_activity": now,
        "messages": [],
        "status": "active",
        "metadata": {},
    }

    # NOTE(review): ACTIVE_SESSIONS is a prometheus Counter (monotonic), so
    # session deletion never decrements it; a Gauge would model "active"
    # correctly. Changing the registered metric type is out of scope here.
    ACTIVE_SESSIONS.inc()
    logger.info(f"Session created: {session_id}")

    return {
        "session_id": session_id,
        "status": "created",
        "yolo_mode": config.security.yolo_mode,
    }
| |
|
| |
|
@app.get("/api/sessions")
async def list_sessions():
    """List all active sessions (summaries only, no message bodies)."""
    def summarize(record):
        # One summary row per session; message bodies are deliberately omitted.
        return {
            "id": record["id"],
            "created_at": record["created_at"],
            "last_activity": record["last_activity"],
            "status": record["status"],
            "message_count": len(record["messages"]),
        }

    return {
        "sessions": [summarize(record) for record in sessions.values()],
        "total": len(sessions),
        "max": config.max_total_sessions,
    }
| |
|
| |
|
@app.get("/api/session/{session_id}")
async def get_session(session_id: str):
    """Return full details (including messages) for one session."""
    record = sessions.get(session_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Session not found")

    fields = ("id", "created_at", "last_activity", "status", "messages")
    return {field: record[field] for field in fields}
| |
|
| |
|
@app.delete("/api/session/{session_id}")
async def delete_session(session_id: str):
    """Remove a session from the registry; 404 if it does not exist."""
    if sessions.pop(session_id, None) is None:
        raise HTTPException(status_code=404, detail="Session not found")

    logger.info(f"Session deleted: {session_id}")
    return {"status": "deleted", "session_id": session_id}
| |
|
| |
|
| | |
| | |
| | |
| |
|
@app.post("/api/submit")
@limiter.limit("30/minute")
async def submit_message(request: Request):
    """Append a user message to an existing session."""
    try:
        payload = await request.json()
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid JSON")

    session_id = payload.get("session_id")
    text = payload.get("text", "").strip()

    # Guard clauses: unknown session, then empty text.
    if not session_id or session_id not in sessions:
        raise HTTPException(status_code=404, detail="Session not found")
    if not text:
        raise HTTPException(status_code=400, detail="Message text is required")

    message = {
        "id": f"msg_{uuid.uuid4().hex[:8]}",
        "role": "user",
        "content": text,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    record = sessions[session_id]
    record["messages"].append(message)
    record["last_activity"] = datetime.now(timezone.utc).isoformat()

    logger.info(f"Message submitted to session {session_id}: {text[:50]}...")

    return {
        "status": "submitted",
        "session_id": session_id,
        "message_id": message["id"],
    }
| |
|
| |
|
@app.post("/api/execute")
@limiter.limit("20/minute")
async def execute_code(request: Request):
    """Execute code (simplified version - full implementation in tools).

    Body: {"command": str, "timeout": seconds}. The timeout is clamped to
    the configured maximum execution time.
    """
    # Bug fix: hoisted out of the try block. Previously `import subprocess`
    # sat inside the try, so an HTTPException raised before it ran made the
    # `except subprocess.TimeoutExpired` clause NameError, turning every
    # validation error into a 500.
    import subprocess

    try:
        data = await request.json()
        command = data.get("command", "").strip()
        timeout = data.get("timeout", config.security.max_execution_time)

        if not command:
            raise HTTPException(status_code=400, detail="Command is required")

        # Denylist screen for obviously dangerous commands.
        for pattern in config.security.blocked_patterns:
            if pattern in command:
                raise HTTPException(
                    status_code=403,
                    detail=f"Command contains blocked pattern: {pattern}"
                )

        # SECURITY: shell=True on client-supplied text is intentionally
        # permissive (YOLO mode); the pattern denylist above is the only gate.
        result = subprocess.run(
            command,
            shell=True,
            capture_output=True,
            text=True,
            timeout=min(timeout, config.security.max_execution_time),
        )

        metrics.record_tool_execution("execute_code", result.returncode == 0)

        return {
            "success": result.returncode == 0,
            "output": result.stdout,
            "error": result.stderr if result.stderr else None,
            "exit_code": result.returncode,
        }

    except HTTPException:
        # Bug fix: let the deliberate 400/403 above through unchanged instead
        # of collapsing them into a generic 500 below.
        raise
    except subprocess.TimeoutExpired:
        metrics.record_tool_execution("execute_code", False)
        raise HTTPException(status_code=408, detail="Command execution timed out")
    except Exception as e:
        metrics.record_tool_execution("execute_code", False)
        raise HTTPException(status_code=500, detail=str(e))
| |
|
| |
|
| | |
| | |
| | |
| |
|
@app.post("/api/tools/web_search")
@limiter.limit("10/minute")
async def web_search_endpoint(request: Request):
    """Web search endpoint backed by DuckDuckGo, with TTL caching.

    Body: {"query": str, "max_results": int (capped at 10)}.
    """
    try:
        data = await request.json()
        query = data.get("query", "").strip()
        max_results = min(data.get("max_results", 5), 10)  # hard server-side cap

        if not query:
            raise HTTPException(status_code=400, detail="Query is required")

        # Cache key includes the result count so different limits don't collide.
        cache_key = f"web_search:{query}:{max_results}"
        cached = CacheManager.get(cache_key)
        if cached:
            logger.info(f"Cache hit for web search: {query}")
            return cached

        from duckduckgo_search import DDGS

        with DDGS() as ddgs:
            results = list(ddgs.text(query, max_results=max_results))

        response = {
            "query": query,
            "results": results,
            "count": len(results),
        }

        CacheManager.set(cache_key, response)
        metrics.record_tool_execution("web_search", True)

        return response

    except HTTPException:
        # Bug fix: the 400 above was previously swallowed by the generic
        # handler and re-raised as a 500.
        raise
    except Exception as e:
        metrics.record_tool_execution("web_search", False)
        raise HTTPException(status_code=500, detail=str(e))
| |
|
| |
|
@app.post("/api/tools/generate_image")
@limiter.limit("5/minute")
async def generate_image_endpoint(request: Request):
    """Generate an image via pollinations.ai and return it base64-inlined.

    Body: {"prompt": str, "size": "WxH" (default "512x512")}.
    """
    try:
        data = await request.json()
        prompt = data.get("prompt", "").strip()
        size = data.get("size", "512x512")

        if not prompt:
            raise HTTPException(status_code=400, detail="Prompt is required")

        try:
            width, height = map(int, size.split("x"))
        except ValueError:
            # Malformed size used to surface as a 500; report it as a client error.
            raise HTTPException(status_code=400, detail="Size must look like '512x512'")

        import httpx
        from urllib.parse import quote

        # Bug fix: replace(" ", "%20") only encoded spaces; prompts containing
        # '/', '?', '#' or '&' corrupted the URL. quote() percent-encodes all
        # reserved characters.
        encoded_prompt = quote(prompt, safe="")
        image_url = f"https://image.pollinations.ai/prompt/{encoded_prompt}?width={width}&height={height}&nologo=true"

        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.get(image_url)
            response.raise_for_status()
            image_data = response.content

        import base64
        image_b64 = base64.b64encode(image_data).decode()

        metrics.record_tool_execution("generate_image", True)

        return {
            "success": True,
            "image_data": f"data:image/png;base64,{image_b64}",
            "prompt": prompt,
            "size": size,
        }

    except HTTPException:
        # Bug fix: preserve intended 4xx codes instead of converting to 500.
        raise
    except Exception as e:
        metrics.record_tool_execution("generate_image", False)
        raise HTTPException(status_code=500, detail=str(e))
| |
|
| |
|
@app.post("/api/tools/execute_code")
@limiter.limit("20/minute")
async def execute_code_tool_endpoint(request: Request):
    """Execute code tool endpoint (same contract as /api/execute).

    Body: {"command": str, "timeout": seconds}, timeout clamped to the
    configured maximum.
    """
    # Bug fix: hoisted out of the try block so the TimeoutExpired clause
    # below cannot NameError when an HTTPException fires before the original
    # in-try import ran (which turned validation errors into 500s).
    import subprocess

    try:
        data = await request.json()
        command = data.get("command", "").strip()
        timeout = data.get("timeout", config.security.max_execution_time)

        if not command:
            raise HTTPException(status_code=400, detail="Command is required")

        # Denylist screen for obviously dangerous commands.
        for pattern in config.security.blocked_patterns:
            if pattern in command:
                raise HTTPException(
                    status_code=403,
                    detail=f"Command contains blocked pattern: {pattern}"
                )

        # SECURITY: shell=True on client-supplied text is intentionally
        # permissive (YOLO mode); the denylist above is the only gate.
        result = subprocess.run(
            command,
            shell=True,
            capture_output=True,
            text=True,
            timeout=min(timeout, config.security.max_execution_time),
        )

        metrics.record_tool_execution("execute_code", result.returncode == 0)

        return {
            "success": result.returncode == 0,
            "output": result.stdout,
            "error": result.stderr if result.stderr else None,
            "exit_code": result.returncode,
        }

    except HTTPException:
        # Bug fix: let the deliberate 400/403 through unchanged.
        raise
    except subprocess.TimeoutExpired:
        metrics.record_tool_execution("execute_code", False)
        raise HTTPException(status_code=408, detail="Command execution timed out")
    except Exception as e:
        metrics.record_tool_execution("execute_code", False)
        raise HTTPException(status_code=500, detail=str(e))
| |
|
| |
|
@app.post("/api/tools/create_slides")
@limiter.limit("5/minute")
async def create_slides_endpoint(request: Request):
    """Create slides/presentation via the slides tool handler.

    Body: {"title": str, "slides": list, "filename": str}.
    """
    try:
        data = await request.json()
        title = data.get("title", "Presentation")
        slides = data.get("slides", [])
        filename = data.get("filename", "presentation.pptx")

        if not slides:
            raise HTTPException(status_code=400, detail="Slides are required")

        from agent.tools.slides_tool import create_slides_handler

        result = await create_slides_handler({
            "title": title,
            "slides": slides,
            "filename": filename
        })

        # Handler returns (output, success) or (output, success, file_data).
        if len(result) >= 3:
            output, success, file_data = result[0], result[1], result[2]
        else:
            output, success = result[0], result[1]
            file_data = None

        metrics.record_tool_execution("create_slides", success)

        return {
            "success": success,
            "message": output,
            "file": file_data,
        }

    except HTTPException:
        # Bug fix: the 400 above was previously converted into a 500.
        raise
    except Exception as e:
        metrics.record_tool_execution("create_slides", False)
        raise HTTPException(status_code=500, detail=str(e))
| |
|
| |
|
@app.post("/api/tools/create_document")
@limiter.limit("5/minute")
async def create_document_endpoint(request: Request):
    """Create a document via the document tool handler.

    Body: {"title": str, "sections": list, "filename": str}.
    """
    try:
        data = await request.json()
        title = data.get("title", "Document")
        sections = data.get("sections", [])
        filename = data.get("filename", "document.docx")

        if not sections:
            raise HTTPException(status_code=400, detail="Sections are required")

        from agent.tools.document_tool import create_document_handler

        result = await create_document_handler({
            "title": title,
            "sections": sections,
            "filename": filename
        })

        # Handler returns (output, success) or (output, success, file_data).
        if len(result) >= 3:
            output, success, file_data = result[0], result[1], result[2]
        else:
            output, success = result[0], result[1]
            file_data = None

        metrics.record_tool_execution("create_document", success)

        return {
            "success": success,
            "message": output,
            "file": file_data,
        }

    except HTTPException:
        # Bug fix: the 400 above was previously converted into a 500.
        raise
    except Exception as e:
        metrics.record_tool_execution("create_document", False)
        raise HTTPException(status_code=500, detail=str(e))
| |
|
| |
|
@app.post("/api/tools/terminal")
@limiter.limit("10/minute")
async def terminal_endpoint(request: Request):
    """Execute a terminal command via the terminal tool handler.

    Body: {"command": str, "timeout": seconds (default 30)}.
    """
    try:
        data = await request.json()
        command = data.get("command", "").strip()
        timeout = data.get("timeout", 30)

        if not command:
            raise HTTPException(status_code=400, detail="Command is required")

        from agent.tools.terminal_tool import execute_terminal_handler

        result = await execute_terminal_handler({
            "command": command,
            "timeout": timeout
        })

        # Handler returns (output, success, ...).
        output, success = result[0], result[1]

        metrics.record_tool_execution("terminal", success)

        return {
            "success": success,
            "output": output,
        }

    except HTTPException:
        # Bug fix: the 400 above was previously converted into a 500.
        raise
    except Exception as e:
        metrics.record_tool_execution("terminal", False)
        raise HTTPException(status_code=500, detail=str(e))
| |
|
| |
|
@app.post("/api/tools/browser/screenshot")
@limiter.limit("10/minute")
async def browser_screenshot_endpoint(request: Request):
    """Take a browser screenshot via the browser tool handler.

    Body: {"url": str, "wait_for": optional selector, "viewport_width": int,
    "viewport_height": int, "full_page": bool}.
    """
    try:
        data = await request.json()
        url = data.get("url", "").strip()
        wait_for = data.get("wait_for")
        viewport_width = data.get("viewport_width", 1280)
        viewport_height = data.get("viewport_height", 720)
        full_page = data.get("full_page", False)

        if not url:
            raise HTTPException(status_code=400, detail="URL is required")

        from agent.tools.browser_tool import browser_screenshot_handler

        result = await browser_screenshot_handler({
            "url": url,
            "wait_for": wait_for,
            "viewport_width": viewport_width,
            "viewport_height": viewport_height,
            "full_page": full_page
        })

        # Handler returns (output, success) or (output, success, file_data).
        if len(result) >= 3:
            output, success, file_data = result[0], result[1], result[2]
        else:
            output, success = result[0], result[1]
            file_data = None

        metrics.record_tool_execution("browser_screenshot", success)

        return {
            "success": success,
            "output": output,
            "file": file_data,
        }

    except HTTPException:
        # Bug fix: the 400 above was previously converted into a 500.
        raise
    except Exception as e:
        metrics.record_tool_execution("browser_screenshot", False)
        raise HTTPException(status_code=500, detail=str(e))
| |
|
| |
|
@app.post("/api/tools/browser/scrape")
@limiter.limit("10/minute")
async def browser_scrape_endpoint(request: Request):
    """Scrape a page via the browser tool handler.

    Body: {"url": str, "extract_script": str, "wait_for": optional selector}.
    """
    try:
        data = await request.json()
        url = data.get("url", "").strip()
        extract_script = data.get("extract_script", "").strip()
        wait_for = data.get("wait_for")

        if not url:
            raise HTTPException(status_code=400, detail="URL is required")

        if not extract_script:
            raise HTTPException(status_code=400, detail="Extract script is required")

        from agent.tools.browser_tool import browser_scrape_handler

        result = await browser_scrape_handler({
            "url": url,
            "extract_script": extract_script,
            "wait_for": wait_for
        })

        # Handler returns (output, success, ...).
        output, success = result[0], result[1]

        metrics.record_tool_execution("browser_scrape", success)

        return {
            "success": success,
            "output": output,
        }

    except HTTPException:
        # Bug fix: the 400s above were previously converted into a 500.
        raise
    except Exception as e:
        metrics.record_tool_execution("browser_scrape", False)
        raise HTTPException(status_code=500, detail=str(e))
| |
|
| |
|
| | |
| | |
| | |
| |
|
@app.post("/api/sessions/{session_id}/files")
@limiter.limit("20/minute")
async def create_session_file_endpoint(session_id: str, request: Request):
    """Create a file inside a session's sandboxed folder.

    Body: {"path": str, "content": str, "folder": str (default "files")}.
    Also best-effort notifies the live agent session about the new file.
    """
    try:
        data = await request.json()
        path = data.get("path", "").strip()
        content = data.get("content", "")
        folder = data.get("folder", "files")

        if not path:
            raise HTTPException(status_code=400, detail="Path is required")

        from session_manager import session_manager
        agent_session = session_manager.sessions.get(session_id)

        if not agent_session:
            raise HTTPException(status_code=404, detail="Session not found")

        from agent.tools.file_system_tool import SessionFileOperator
        from pathlib import Path

        fs = SessionFileOperator(Path(agent_session.session.session_folder))
        result = await fs.create_file(path, content, folder)

        # Best-effort: record the file on the session and push an event to the
        # client. Failures here must not fail the file-creation request.
        try:
            from agent.core.session import Event

            file_data = {
                "path": result.path,
                "full_path": result.full_path,
                "size": result.size,
                "created": result.created,
            }

            if hasattr(agent_session.session, "add_generated_file"):
                agent_session.session.add_generated_file(file_data)

            if hasattr(agent_session.session, "send_event"):
                await agent_session.session.send_event(
                    Event(event_type="file_generated", data={"file": file_data})
                )
        except Exception:
            logger.exception("Failed to emit file_generated event from API endpoint")

        return {
            "success": True,
            "path": result.path,
            "size": result.size,
        }

    except HTTPException:
        # Bug fix: the 400/404 above were previously converted into a 500.
        raise
    except PermissionError as e:
        # Sandbox escape attempts surface as 403.
        raise HTTPException(status_code=403, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
| |
|
| |
|
@app.get("/api/sessions/{session_id}/files")
@limiter.limit("30/minute")
async def list_session_files_endpoint(session_id: str, request: Request, folder: str = "files"):
    """List files in one folder of a session's sandbox."""
    try:
        from session_manager import session_manager
        agent_session = session_manager.sessions.get(session_id)

        if not agent_session:
            raise HTTPException(status_code=404, detail="Session not found")

        from agent.tools.file_system_tool import SessionFileOperator
        from pathlib import Path

        fs = SessionFileOperator(Path(agent_session.session.session_folder))
        listing = await fs.list_files(folder)

        return {
            "path": listing.path,
            "items": listing.items,
        }

    except HTTPException:
        # Bug fix: the 404 above was previously converted into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
| |
|
| |
|
@app.get("/api/sessions/{session_id}/tree")
@limiter.limit("20/minute")
async def get_session_tree_endpoint(session_id: str, request: Request):
    """Return the file tree for a session.

    Args:
        session_id: Session whose file tree is returned.

    Returns:
        dict with ``session_id`` and the nested ``structure``.

    Raises:
        HTTPException: 404 if the session does not exist, 500 otherwise.
    """
    try:
        # Imported lazily to avoid circular imports at module load time.
        from session_manager import session_manager
        agent_session = session_manager.sessions.get(session_id)

        if not agent_session:
            raise HTTPException(status_code=404, detail="Session not found")

        from agent.tools.file_system_tool import SessionFileOperator
        from pathlib import Path

        fs = SessionFileOperator(Path(agent_session.session.session_folder))
        tree = await fs.get_file_tree()

        return {
            "session_id": tree.session_id,
            "structure": tree.structure,
        }

    except HTTPException:
        # BUG FIX: the 404 raised above was previously caught by the generic
        # handler below and rewritten as a 500.
        raise
    except Exception as e:
        logger.exception("Unexpected error building tree for session %s", session_id)
        raise HTTPException(status_code=500, detail=str(e))
| |
|
| |
|
@app.get("/api/sessions/{session_id}/file")
@limiter.limit("30/minute")
async def get_session_file_content_endpoint(session_id: str, request: Request, path: str):
    """Return the content of one file from a session.

    Args:
        session_id: Session to read from.
        path: File path relative to the session folder (query parameter).

    Returns:
        dict with the requested ``path`` and its ``content``.

    Raises:
        HTTPException: 400 if ``path`` is empty, 403 on permission errors,
            404 if the session or file does not exist, 500 otherwise.
    """
    try:
        if not path:
            raise HTTPException(status_code=400, detail="Path is required")

        # Imported lazily to avoid circular imports at module load time.
        from session_manager import session_manager
        agent_session = session_manager.sessions.get(session_id)

        if not agent_session:
            raise HTTPException(status_code=404, detail="Session not found")

        from agent.tools.file_system_tool import SessionFileOperator
        from pathlib import Path

        fs = SessionFileOperator(Path(agent_session.session.session_folder))
        content = await fs.read_file(path)

        return {
            "path": path,
            "content": content,
        }

    except HTTPException:
        # BUG FIX: deliberate 400/404 responses raised above were previously
        # caught by the generic handler below and rewritten as 500s.
        raise
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="File not found")
    except PermissionError as e:
        raise HTTPException(status_code=403, detail=str(e))
    except Exception as e:
        logger.exception("Unexpected error reading file in session %s", session_id)
        raise HTTPException(status_code=500, detail=str(e))
| |
|
| |
|
| | |
| | |
| | |
| |
|
@app.websocket("/api/ws/{session_id}")
async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """WebSocket endpoint for real-time communication.

    Protocol (as implemented below):
      - Rejects unknown sessions with close code 4004.
      - Emits a "ready" event on connect.
      - Incoming JSON messages: {"type": "ping"} is answered with "pong";
        otherwise "text" is required and is submitted to the agent session,
        which is created lazily on the first message if it does not exist.
    """

    # Imported lazily to avoid circular imports at module load time.
    from session_manager import session_manager

    # NOTE(review): `sessions` is presumably a module-level legacy registry
    # defined earlier in this file — confirm it exists, otherwise this line
    # raises NameError on every connection for unknown ids.
    if session_id not in sessions and session_id not in session_manager.sessions:
        # Must accept before close so the client receives the close code.
        await websocket.accept()
        await websocket.close(code=4004, reason="Session not found")
        return

    await connection_manager.connect(websocket, session_id)

    try:
        # Tell the client we are ready and whether tool auto-approval is on.
        await connection_manager.send_event(session_id, "ready", {
            "message": "Factor Agent connected",
            "yolo_mode": config.security.yolo_mode,
        })

        # Receive loop: runs until the client disconnects (raises
        # WebSocketDisconnect) or an unexpected error occurs.
        while True:
            data = await websocket.receive_json()
            message_text = data.get("text", "").strip()

            # Heartbeat: answer pings without touching the agent session.
            if data.get("type") == "ping":
                await connection_manager.send_event(session_id, "pong", {})
                continue

            if not message_text:
                await connection_manager.send_event(session_id, "error", {"message": "Message text is required"})
                continue

            logger.info(f"Received WebSocket message for session {session_id}: {message_text[:50]}...")

            # Acknowledge immediately; actual processing is asynchronous.
            await connection_manager.send_event(session_id, "message_acknowledged", {"message": "Processing your request..."})

            try:
                agent_session = session_manager.sessions.get(session_id)

                if not agent_session:
                    # Lazy initialization: build the agent session the first
                    # time this socket sends a real message.
                    logger.info(f"Creating agent session for {session_id} on first WebSocket message")
                    try:
                        from pathlib import Path

                        # Per-session on-disk workspace next to this module.
                        sessions_root = Path(__file__).parent / "sessions"
                        sessions_root.mkdir(parents=True, exist_ok=True)
                        session_folder = sessions_root / session_id
                        session_folder.mkdir(parents=True, exist_ok=True)
                        (session_folder / "files").mkdir(exist_ok=True)
                        (session_folder / "documents").mkdir(exist_ok=True)

                        # submission_queue: user input -> agent;
                        # event_queue: agent events -> clients.
                        submission_queue = asyncio.Queue()
                        event_queue = asyncio.Queue()

                        from agent.core.tools import ToolRouter
                        from agent.core.session import Session

                        def _create_session_sync():
                            # Runs in a worker thread (asyncio.to_thread) —
                            # presumably Session/ToolRouter construction
                            # blocks; confirm before inlining it.
                            tool_router = ToolRouter(config.mcp_servers)

                            # Minimal attribute adapter exposing only the
                            # config fields Session reads.
                            class ConfigAdapter:
                                model_name = config.model.name
                                save_sessions = config.save_sessions
                                session_dataset_repo = config.session_dataset_repo
                                auto_save_interval = config.auto_save_interval
                                mcp_servers = config.mcp_servers

                            session = Session(event_queue, config=ConfigAdapter(), tool_router=tool_router)
                            return tool_router, session

                        tool_router, session = await asyncio.to_thread(_create_session_sync)
                        session.session_folder = str(session_folder)

                        from session_manager import AgentSession

                        agent_session = AgentSession(
                            session_id=session_id,
                            session=session,
                            tool_router=tool_router,
                            submission_queue=submission_queue,
                            user_id="dev",
                        )

                        # Register before starting the loop so concurrent
                        # lookups see the session.
                        session_manager.sessions[session_id] = agent_session

                        # Background task driving the agent loop; handle kept
                        # on the session so it is not garbage-collected.
                        task = asyncio.create_task(
                            session_manager._run_session(session_id, submission_queue, event_queue, tool_router)
                        )
                        agent_session.task = task

                        logger.info(f"Created agent session {session_id} for WebSocket")
                    except Exception as e:
                        # Initialization failed: inform the client and end
                        # the handler (socket stays as connection_manager
                        # left it).
                        logger.error(f"Failed to create agent session: {e}", exc_info=True)
                        await connection_manager.send_event(session_id, "error", {"message": "Failed to initialize agent session"})
                        return

                if agent_session:
                    success = await session_manager.submit_user_input(session_id, message_text)
                    if not success:
                        await connection_manager.send_event(session_id, "error", {"message": "Failed to submit to session"})

            except Exception as e:
                # Per-message errors are reported to the client; the receive
                # loop keeps running.
                logger.error(f"Error processing message in WebSocket for session {session_id}: {e}")
                await connection_manager.send_event(session_id, "error", {"message": f"Error processing request: {str(e)}"})

    except WebSocketDisconnect:
        connection_manager.disconnect(session_id)
    except Exception as e:
        logger.error(f"WebSocket error for session {session_id}: {e}")
        connection_manager.disconnect(session_id)
| |
|
| |
|
| | |
| | |
| | |
| |
|
if __name__ == "__main__":
    import uvicorn

    # Bind address and port come from the environment, with local-dev
    # defaults (all interfaces, port 7860).
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", 7860))

    # Single worker keeps in-process session state consistent; auto-reload
    # is enabled only in the development environment.
    uvicorn.run(
        "main:app",
        host=host,
        port=port,
        reload=config.environment == "development",
        workers=1,
        log_level="info",
    )
| |
|