# water3/main.py — uploaded by onewayto ("Upload 187 files", commit db41152, verified)
"""
Factor Agent - Enterprise-Grade FastAPI Backend
===============================================
High-performance AI agent backend with:
- Rate limiting and request throttling
- Comprehensive monitoring and logging
- Caching layer for improved performance
- Health checks and graceful degradation
- WebSocket support for real-time communication
"""
import asyncio
import json
import logging
import os
import time
import uuid
from collections import defaultdict
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
import redis.asyncio as redis
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import JSONResponse
from prometheus_client import Counter, Histogram, generate_latest, CONTENT_TYPE_LATEST
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from websocket import manager as connection_manager
# Load environment variables
load_dotenv()
# Configure structured logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s'
)
# Module-wide logger; every component below logs through this name.
logger = logging.getLogger("factor_agent")
# Prometheus metrics - use try/except to handle re-registration in multiprocess mode
from prometheus_client import Gauge

try:
    REQUEST_COUNT = Counter('factor_requests_total', 'Total requests', ['method', 'endpoint', 'status'])
    REQUEST_DURATION = Histogram('factor_request_duration_seconds', 'Request duration', ['method', 'endpoint'])
    # Bug fix: "active sessions" is a level that can go up and down, so it must
    # be a Gauge. A Counter can never be decremented when a session ends, and
    # newer prometheus_client versions also rename Counters to '<name>_total',
    # which would silently break any dashboard scraping 'factor_active_sessions'.
    ACTIVE_SESSIONS = Gauge('factor_active_sessions', 'Number of active sessions')
    TOOL_EXECUTIONS = Counter('factor_tool_executions_total', 'Tool executions', ['tool_name', 'status'])
    WEBSOCKET_CONNECTIONS = Counter('factor_websocket_connections', 'WebSocket connections', ['event_type'])
except ValueError:
    # Metrics already registered (module re-imported, e.g. by a reloader or a
    # multiprocess worker): recover the existing collectors from the default
    # registry instead of crashing. _names_to_collectors is a private attribute;
    # there is no public lookup API for this in prometheus_client.
    from prometheus_client import REGISTRY
    REQUEST_COUNT = REGISTRY._names_to_collectors.get('factor_requests_total')
    REQUEST_DURATION = REGISTRY._names_to_collectors.get('factor_request_duration_seconds')
    ACTIVE_SESSIONS = REGISTRY._names_to_collectors.get('factor_active_sessions')
    TOOL_EXECUTIONS = REGISTRY._names_to_collectors.get('factor_tool_executions_total')
    WEBSOCKET_CONNECTIONS = REGISTRY._names_to_collectors.get('factor_websocket_connections')
# Rate limiter keyed on the client's remote address; applied per-route
# below via the @limiter.limit("N/minute") decorators.
limiter = Limiter(key_func=get_remote_address)
# In-memory cache (can be replaced with Redis)
_cache: Dict[str, Any] = {}
_cache_timestamps: Dict[str, float] = {}


class CacheManager:
    """Process-local key/value cache whose entries expire after the configured TTL."""

    @staticmethod
    def get(key: str) -> Optional[Any]:
        """Return the cached value for *key*, or None when absent or expired."""
        if key not in _cache:
            return None
        age = time.time() - _cache_timestamps.get(key, 0)
        if age < get_config().cache.ttl_seconds:
            return _cache[key]
        # Entry outlived its TTL: evict it lazily on access.
        CacheManager.delete(key)
        return None

    @staticmethod
    def set(key: str, value: Any) -> None:
        """Store *value* under *key*, stamped with the current time."""
        _cache[key] = value
        _cache_timestamps[key] = time.time()

    @staticmethod
    def delete(key: str) -> None:
        """Remove *key* (value and timestamp); unknown keys are ignored."""
        _cache.pop(key, None)
        _cache_timestamps.pop(key, None)

    @staticmethod
    def clear() -> None:
        """Drop every cached entry."""
        _cache.clear()
        _cache_timestamps.clear()
class MetricsCollector:
    """Track request/error/tool counters in-process and mirror them to Prometheus."""

    def __init__(self):
        self.start_time = time.time()                # process start, for uptime
        self.request_counts = defaultdict(int)       # "METHOD:path" -> hits
        self.error_counts = defaultdict(int)         # exception type name -> count
        self.tool_executions = defaultdict(int)      # "tool:status" -> runs

    def record_request(self, method: str, endpoint: str, status_code: int, duration: float):
        """Count one HTTP request and record its latency."""
        self.request_counts[f"{method}:{endpoint}"] += 1
        REQUEST_COUNT.labels(method=method, endpoint=endpoint, status=status_code).inc()
        REQUEST_DURATION.labels(method=method, endpoint=endpoint).observe(duration)

    def record_error(self, error_type: str):
        """Count one unhandled error by exception type name."""
        self.error_counts[error_type] += 1

    def record_tool_execution(self, tool_name: str, success: bool):
        """Count one tool run, labelled success or failure."""
        outcome = "success" if success else "failure"
        self.tool_executions[f"{tool_name}:{outcome}"] += 1
        TOOL_EXECUTIONS.labels(tool_name=tool_name, status=outcome).inc()

    def get_uptime(self) -> float:
        """Seconds elapsed since this collector was created."""
        return time.time() - self.start_time

    def get_stats(self) -> Dict[str, Any]:
        """Snapshot of all in-process counters plus cache/websocket sizes."""
        return {
            "uptime_seconds": self.get_uptime(),
            "total_requests": dict(self.request_counts),
            "total_errors": dict(self.error_counts),
            "tool_executions": dict(self.tool_executions),
            "cache_size": len(_cache),
            "active_websockets": len(connection_manager.active_connections),
        }
# Singleton collector shared by the middleware and every endpoint below.
metrics = MetricsCollector()
# WebSocket connection manager
class ConnectionManager:
    """Track live WebSocket clients and push JSON messages to one or all of them.

    NOTE(review): the module imports `connection_manager` from `websocket` at the
    top of the file and uses that everywhere — confirm whether this local class
    is still referenced anywhere or is dead code.
    """

    def __init__(self):
        # client_id -> live socket, plus per-client metadata (connect time, IP).
        self.active_connections: Dict[str, WebSocket] = {}
        self.connection_metadata: Dict[str, Dict[str, Any]] = {}

    async def connect(self, websocket: WebSocket, client_id: str):
        """Accept the handshake and register the socket under *client_id*."""
        await websocket.accept()
        self.active_connections[client_id] = websocket
        self.connection_metadata[client_id] = {
            "connected_at": datetime.now(timezone.utc).isoformat(),
            "ip": websocket.client.host if websocket.client else "unknown",
        }
        WEBSOCKET_CONNECTIONS.labels(event_type="connect").inc()
        logger.info(f"WebSocket connected: {client_id}")

    def disconnect(self, client_id: str):
        """Forget *client_id*; a no-op when the id is unknown."""
        if client_id in self.active_connections:
            del self.active_connections[client_id]
            self.connection_metadata.pop(client_id, None)
            WEBSOCKET_CONNECTIONS.labels(event_type="disconnect").inc()
            logger.info(f"WebSocket disconnected: {client_id}")

    async def send_message(self, client_id: str, message: Dict[str, Any]):
        """Send *message* to one client; drop the connection on send failure."""
        socket = self.active_connections.get(client_id)
        if socket is None:
            return
        try:
            await socket.send_json(message)
        except Exception as e:
            logger.error(f"Failed to send message to {client_id}: {e}")
            self.disconnect(client_id)

    async def broadcast(self, message: Dict[str, Any]):
        """Send *message* to every client, dropping any that fail mid-broadcast."""
        failed = []
        for client_id, socket in self.active_connections.items():
            try:
                await socket.send_json(message)
            except Exception as e:
                logger.error(f"Failed to broadcast to {client_id}: {e}")
                failed.append(client_id)
        for client_id in failed:
            self.disconnect(client_id)
# Import Factor Agent modules
from datetime import timezone
from agent.config import get_config, FactorConfig
from agent.core.agent_loop import submission_loop, Handlers, process_submission
from agent.core.session import Session, Event, OpType
# Get configuration (singleton FactorConfig; read throughout this module).
config = get_config()
# Simple wrapper to invoke agent with a message
def invoke_agent(message_text: str) -> str:
    """
    Invoke the agent with a message and collect the response.
    Returns the agent's text response. This runs in a thread.
    """
    try:
        logger.info(f"invoke_agent called with: {message_text[:100]}")
        # Get event loop (might be different thread)
        # A dedicated loop is created (rather than asyncio.get_event_loop())
        # because this is expected to run in a worker thread with no loop yet.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            # Create event queue to collect responses
            event_queue: asyncio.Queue = asyncio.Queue()
            # Create a tool router with MCP servers from config
            from agent.core.tools import ToolRouter
            tool_router = ToolRouter(config.mcp_servers)
            # Create a config adapter that provides the expected interface
            # (a bare attribute bag; only the fields assigned below are copied
            # from the global config).
            class SimpleNamespace:
                pass
            config_adapter = SimpleNamespace()
            config_adapter.model_name = config.model.name
            config_adapter.openrouter_enabled = config.openrouter_enabled
            config_adapter.openrouter_model = config.openrouter_model
            config_adapter.save_sessions = config.save_sessions
            config_adapter.session_dataset_repo = config.session_dataset_repo
            config_adapter.auto_save_interval = config.auto_save_interval
            config_adapter.security = config.security
            config_adapter.mcp_servers = config.mcp_servers
            logger.info(f"Created config adapter with model: {config_adapter.model_name}")
            session = Session(event_queue, config=config_adapter, tool_router=tool_router)
            logger.info(f"Created session: {session.session_id}")
            # Run the agent with the message
            logger.info(f"Calling Handlers.run_agent with message: {message_text[:50]}")
            response = loop.run_until_complete(Handlers.run_agent(session, message_text))
            logger.info(f"run_agent returned: {type(response).__name__} = {str(response)[:100]}")
            # If we got a direct response, return it
            if response:
                logger.info(f"Got direct response, returning it")
                return response
            # Otherwise, collect assistant chunks from events
            logger.info(f"No direct response, checking events...")
            logger.info(f"Event queue size: {event_queue.qsize()}")
            logger.info(f"Session logged events: {len(session.logged_events)}")
            collected_text = ""
            # Drain the queue synchronously; get_nowait() is safe here because
            # the loop is no longer running.
            while not event_queue.empty():
                try:
                    event = event_queue.get_nowait()
                    logger.debug(f"Got event: {event.event_type}")
                    if event.event_type == "assistant_chunk":
                        chunk_content = event.data.get("content", "")
                        collected_text += chunk_content
                        logger.debug(f"Collected chunk: {chunk_content}")
                except asyncio.QueueEmpty:
                    break
            logger.info(f"Collected text length: {len(collected_text)}")
            # If we collected chunks, return them
            if collected_text:
                logger.info(f"Returning collected text")
                return collected_text
            # If nothing collected, check for error events in logged events
            # (assumes logged_events entries are dicts with event_type/data keys
            # — TODO confirm against Session implementation)
            for event_item in session.logged_events:
                logger.debug(f"Logged event: {event_item.get('event_type')}")
                if event_item.get("event_type") == "error":
                    error_msg = event_item.get('data', {}).get('error', 'Unknown error')
                    logger.info(f"Found error event: {error_msg}")
                    return f"Agent error: {error_msg}"
            logger.info(f"No response found, returning default message")
            return "No response generated"
        finally:
            # Always close the private loop so per-thread loop state is freed.
            loop.close()
    except Exception as e:
        logger.error(f"Error invoking agent: {e}", exc_info=True)
        return f"Error: {str(e)}"
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Log the startup configuration, serve the app, then clear caches on shutdown."""
    startup_banner = [
        f"🚀 Starting Factor Agent v{config.app_version}",
        f"📊 Environment: {config.environment}",
        f"🤖 Model: {config.model.name}",
        f"⚡ YOLO Mode: {config.security.yolo_mode}",
        f"🔒 Rate Limiting: {config.rate_limit.enabled}",
        f"💾 Caching: {config.cache.enabled}",
    ]
    for line in startup_banner:
        logger.info(line)
    yield  # application handles requests while suspended here
    # Shutdown tasks
    logger.info("🛑 Shutting down Factor Agent...")
    CacheManager.clear()
# Create FastAPI app
app = FastAPI(
    title="Factor Agent API",
    description="Enterprise-Grade AI Agent Platform",
    version=config.app_version,
    # Interactive API docs are disabled in production.
    docs_url="/docs" if config.environment != "production" else None,
    redoc_url="/redoc" if config.environment != "production" else None,
    lifespan=lifespan,
)
# Add rate limiter
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Add middleware
app.add_middleware(GZipMiddleware, minimum_size=1000)
# CORS middleware - Allow all origins for HF Spaces compatibility
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers for credentialed requests per the CORS spec (a wildcard
# origin may not be used with credentials) — confirm credentials are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)
# Request timing middleware
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Time each request, expose it via X-Process-Time, and feed the metrics collector."""
    started = time.time()
    response = await call_next(request)
    elapsed = time.time() - started
    response.headers["X-Process-Time"] = str(elapsed)
    # Record metrics for both the in-process collector and Prometheus.
    metrics.record_request(
        method=request.method,
        endpoint=request.url.path,
        status_code=response.status_code,
        duration=elapsed,
    )
    return response
# Error handling middleware
@app.middleware("http")
async def error_handling_middleware(request: Request, call_next):
    """Convert any unhandled exception into a sanitized 500 JSON response."""
    try:
        return await call_next(request)
    except Exception as e:
        logger.exception(f"Unhandled error: {e}")
        metrics.record_error(type(e).__name__)
        # Only leak exception details in the development environment.
        detail = str(e) if config.environment == "development" else "An unexpected error occurred"
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"error": "Internal server error", "detail": detail},
        )
# ============================================================================
# API Routes
# ============================================================================
@app.get("/")
async def root():
    """Service identity and status summary."""
    docs_path = "/docs" if config.environment != "production" else None
    return {
        "name": "Factor Agent API",
        "version": config.app_version,
        "environment": config.environment,
        "status": "operational",
        "yolo_mode": config.security.yolo_mode,
        "docs": docs_path,
    }
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy",
"timestamp": datetime.now(timezone.utc).isoformat(),
"version": config.app_version,
"uptime": metrics.get_uptime(),
}
@app.get("/health/detailed")
async def detailed_health_check():
"""Detailed health check with component status"""
return {
"status": "healthy",
"timestamp": datetime.now(timezone.utc).isoformat(),
"version": config.app_version,
"uptime": metrics.get_uptime(),
"components": {
"api": "operational",
"websocket": "operational" if len(connection_manager.active_connections) < 1000 else "degraded",
"cache": "operational" if len(_cache) < config.cache.max_size else "full",
},
"metrics": metrics.get_stats(),
}
@app.get("/metrics")
async def prometheus_metrics():
"""Prometheus metrics endpoint"""
from prometheus_client import generate_latest, CONTENT_TYPE_LATEST
return JSONResponse(
content=generate_latest().decode(),
media_type=CONTENT_TYPE_LATEST
)
@app.get("/config")
async def get_configuration():
"""Get current configuration (sanitized)"""
return {
"app_name": config.app_name,
"version": config.app_version,
"environment": config.environment,
"model": {
"name": config.model.name,
"provider": config.model.provider,
"temperature": config.model.temperature,
"max_tokens": config.model.max_tokens,
},
"security": {
"yolo_mode": config.security.yolo_mode,
"max_execution_time": config.security.max_execution_time,
},
"features": config.features,
}
# ============================================================================
# Session Management Routes
# ============================================================================
# In-memory session store (replace with Redis in production)
# Maps session_id -> {id, created_at, last_activity, messages, status, metadata}.
# NOTE(review): lives in process memory only — lost on restart and not shared
# across workers.
sessions: Dict[str, Dict[str, Any]] = {}
@app.post("/api/session", status_code=status.HTTP_201_CREATED)
@limiter.limit("10/minute")
async def create_session(request: Request):
"""Create a new agent session"""
session_id = f"sess_{uuid.uuid4().hex[:12]}"
if len(sessions) >= config.max_total_sessions:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Maximum number of sessions reached"
)
sessions[session_id] = {
"id": session_id,
"created_at": datetime.now(timezone.utc).isoformat(),
"last_activity": datetime.now(timezone.utc).isoformat(),
"messages": [],
"status": "active",
"metadata": {},
}
ACTIVE_SESSIONS.inc()
logger.info(f"Session created: {session_id}")
return {
"session_id": session_id,
"status": "created",
"yolo_mode": config.security.yolo_mode,
}
@app.get("/api/sessions")
async def list_sessions():
"""List all active sessions"""
return {
"sessions": [
{
"id": s["id"],
"created_at": s["created_at"],
"last_activity": s["last_activity"],
"status": s["status"],
"message_count": len(s["messages"]),
}
for s in sessions.values()
],
"total": len(sessions),
"max": config.max_total_sessions,
}
@app.get("/api/session/{session_id}")
async def get_session(session_id: str):
"""Get session details"""
if session_id not in sessions:
raise HTTPException(status_code=404, detail="Session not found")
session = sessions[session_id]
return {
"id": session["id"],
"created_at": session["created_at"],
"last_activity": session["last_activity"],
"status": session["status"],
"messages": session["messages"],
}
@app.delete("/api/session/{session_id}")
async def delete_session(session_id: str):
"""Delete a session"""
if session_id not in sessions:
raise HTTPException(status_code=404, detail="Session not found")
del sessions[session_id]
logger.info(f"Session deleted: {session_id}")
return {"status": "deleted", "session_id": session_id}
# ============================================================================
# Agent Interaction Routes
# ============================================================================
@app.post("/api/submit")
@limiter.limit("30/minute")
async def submit_message(request: Request):
    """Append a user message to an existing session.

    400 on malformed JSON or empty text, 404 when the session is unknown.
    """
    # Only the body parse can raise JSONDecodeError; keep the handler tight.
    try:
        data = await request.json()
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid JSON")
    session_id = data.get("session_id")
    text = data.get("text", "").strip()
    if not session_id or session_id not in sessions:
        raise HTTPException(status_code=404, detail="Session not found")
    if not text:
        raise HTTPException(status_code=400, detail="Message text is required")
    # Add message to session
    message = {
        "id": f"msg_{uuid.uuid4().hex[:8]}",
        "role": "user",
        "content": text,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    session = sessions[session_id]
    session["messages"].append(message)
    session["last_activity"] = datetime.now(timezone.utc).isoformat()
    logger.info(f"Message submitted to session {session_id}: {text[:50]}...")
    return {
        "status": "submitted",
        "session_id": session_id,
        "message_id": message["id"],
    }
@app.post("/api/execute")
@limiter.limit("20/minute")
async def execute_code(request: Request):
"""Execute code (simplified version - full implementation in tools)"""
try:
data = await request.json()
command = data.get("command", "").strip()
timeout = data.get("timeout", config.security.max_execution_time)
if not command:
raise HTTPException(status_code=400, detail="Command is required")
# Security check - block dangerous commands
for pattern in config.security.blocked_patterns:
if pattern in command:
raise HTTPException(
status_code=403,
detail=f"Command contains blocked pattern: {pattern}"
)
# Execute command (simplified - full implementation would use subprocess)
import subprocess
result = subprocess.run(
command,
shell=True,
capture_output=True,
text=True,
timeout=min(timeout, config.security.max_execution_time),
)
metrics.record_tool_execution("execute_code", result.returncode == 0)
return {
"success": result.returncode == 0,
"output": result.stdout,
"error": result.stderr if result.stderr else None,
"exit_code": result.returncode,
}
except subprocess.TimeoutExpired:
metrics.record_tool_execution("execute_code", False)
raise HTTPException(status_code=408, detail="Command execution timed out")
except Exception as e:
metrics.record_tool_execution("execute_code", False)
raise HTTPException(status_code=500, detail=str(e))
# ============================================================================
# Tool Routes
# ============================================================================
@app.post("/api/tools/web_search")
@limiter.limit("10/minute")
async def web_search_endpoint(request: Request):
    """DuckDuckGo text search with per-query caching (max 10 results)."""
    try:
        data = await request.json()
        query = data.get("query", "").strip()
        max_results = min(data.get("max_results", 5), 10)
        if not query:
            raise HTTPException(status_code=400, detail="Query is required")
        # Serve from cache when an identical, still-fresh query exists.
        cache_key = f"web_search:{query}:{max_results}"
        cached = CacheManager.get(cache_key)
        if cached:
            logger.info(f"Cache hit for web search: {query}")
            return cached
        # Perform search
        from duckduckgo_search import DDGS
        with DDGS() as ddgs:
            results = list(ddgs.text(query, max_results=max_results))
        response = {
            "query": query,
            "results": results,
            "count": len(results),
        }
        CacheManager.set(cache_key, response)
        metrics.record_tool_execution("web_search", True)
        return response
    except HTTPException:
        # Bug fix: preserve the deliberate 400 above; previously the generic
        # Exception handler re-wrapped it as a 500.
        raise
    except Exception as e:
        metrics.record_tool_execution("web_search", False)
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/tools/generate_image")
@limiter.limit("5/minute")
async def generate_image_endpoint(request: Request):
"""Generate image endpoint"""
try:
data = await request.json()
prompt = data.get("prompt", "").strip()
size = data.get("size", "512x512")
if not prompt:
raise HTTPException(status_code=400, detail="Prompt is required")
import httpx
width, height = map(int, size.split("x"))
encoded_prompt = prompt.replace(" ", "%20")
image_url = f"https://image.pollinations.ai/prompt/{encoded_prompt}?width={width}&height={height}&nologo=true"
async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.get(image_url)
response.raise_for_status()
image_data = response.content
import base64
image_b64 = base64.b64encode(image_data).decode()
metrics.record_tool_execution("generate_image", True)
return {
"success": True,
"image_data": f"data:image/png;base64,{image_b64}",
"prompt": prompt,
"size": size,
}
except Exception as e:
metrics.record_tool_execution("generate_image", False)
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/tools/execute_code")
@limiter.limit("20/minute")
async def execute_code_tool_endpoint(request: Request):
"""Execute code tool endpoint"""
try:
data = await request.json()
command = data.get("command", "").strip()
timeout = data.get("timeout", config.security.max_execution_time)
if not command:
raise HTTPException(status_code=400, detail="Command is required")
# Security check - block dangerous commands
for pattern in config.security.blocked_patterns:
if pattern in command:
raise HTTPException(
status_code=403,
detail=f"Command contains blocked pattern: {pattern}"
)
# Execute command
import subprocess
result = subprocess.run(
command,
shell=True,
capture_output=True,
text=True,
timeout=min(timeout, config.security.max_execution_time),
)
metrics.record_tool_execution("execute_code", result.returncode == 0)
return {
"success": result.returncode == 0,
"output": result.stdout,
"error": result.stderr if result.stderr else None,
"exit_code": result.returncode,
}
except subprocess.TimeoutExpired:
metrics.record_tool_execution("execute_code", False)
raise HTTPException(status_code=408, detail="Command execution timed out")
except Exception as e:
metrics.record_tool_execution("execute_code", False)
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/tools/create_slides")
@limiter.limit("5/minute")
async def create_slides_endpoint(request: Request):
"""Create slides/presentation endpoint"""
try:
data = await request.json()
title = data.get("title", "Presentation")
slides = data.get("slides", [])
filename = data.get("filename", "presentation.pptx")
if not slides:
raise HTTPException(status_code=400, detail="Slides are required")
# Import and call the slides tool
from agent.tools.slides_tool import create_slides_handler
result = await create_slides_handler({
"title": title,
"slides": slides,
"filename": filename
})
if len(result) >= 3:
output, success, file_data = result[0], result[1], result[2]
else:
output, success = result[0], result[1]
file_data = None
metrics.record_tool_execution("create_slides", success)
return {
"success": success,
"message": output,
"file": file_data,
}
except Exception as e:
metrics.record_tool_execution("create_slides", False)
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/tools/create_document")
@limiter.limit("5/minute")
async def create_document_endpoint(request: Request):
"""Create document endpoint"""
try:
data = await request.json()
title = data.get("title", "Document")
sections = data.get("sections", [])
filename = data.get("filename", "document.docx")
if not sections:
raise HTTPException(status_code=400, detail="Sections are required")
# Import and call the document tool
from agent.tools.document_tool import create_document_handler
result = await create_document_handler({
"title": title,
"sections": sections,
"filename": filename
})
if len(result) >= 3:
output, success, file_data = result[0], result[1], result[2]
else:
output, success = result[0], result[1]
file_data = None
metrics.record_tool_execution("create_document", success)
return {
"success": success,
"message": output,
"file": file_data,
}
except Exception as e:
metrics.record_tool_execution("create_document", False)
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/tools/terminal")
@limiter.limit("10/minute")
async def terminal_endpoint(request: Request):
"""Execute terminal command endpoint"""
try:
data = await request.json()
command = data.get("command", "").strip()
timeout = data.get("timeout", 30)
if not command:
raise HTTPException(status_code=400, detail="Command is required")
# Import and call the terminal tool
from agent.tools.terminal_tool import execute_terminal_handler
result = await execute_terminal_handler({
"command": command,
"timeout": timeout
})
output, success = result[0], result[1]
metrics.record_tool_execution("terminal", success)
return {
"success": success,
"output": output,
}
except Exception as e:
metrics.record_tool_execution("terminal", False)
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/tools/browser/screenshot")
@limiter.limit("10/minute")
async def browser_screenshot_endpoint(request: Request):
"""Browser screenshot endpoint"""
try:
data = await request.json()
url = data.get("url", "").strip()
wait_for = data.get("wait_for")
viewport_width = data.get("viewport_width", 1280)
viewport_height = data.get("viewport_height", 720)
full_page = data.get("full_page", False)
if not url:
raise HTTPException(status_code=400, detail="URL is required")
# Import and call the browser tool
from agent.tools.browser_tool import browser_screenshot_handler
result = await browser_screenshot_handler({
"url": url,
"wait_for": wait_for,
"viewport_width": viewport_width,
"viewport_height": viewport_height,
"full_page": full_page
})
if len(result) >= 3:
output, success, file_data = result[0], result[1], result[2]
else:
output, success = result[0], result[1]
file_data = None
metrics.record_tool_execution("browser_screenshot", success)
return {
"success": success,
"output": output,
"file": file_data,
}
except Exception as e:
metrics.record_tool_execution("browser_screenshot", False)
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/tools/browser/scrape")
@limiter.limit("10/minute")
async def browser_scrape_endpoint(request: Request):
"""Browser scrape endpoint"""
try:
data = await request.json()
url = data.get("url", "").strip()
extract_script = data.get("extract_script", "").strip()
wait_for = data.get("wait_for")
if not url:
raise HTTPException(status_code=400, detail="URL is required")
if not extract_script:
raise HTTPException(status_code=400, detail="Extract script is required")
# Import and call the browser tool
from agent.tools.browser_tool import browser_scrape_handler
result = await browser_scrape_handler({
"url": url,
"extract_script": extract_script,
"wait_for": wait_for
})
output, success = result[0], result[1]
metrics.record_tool_execution("browser_scrape", success)
return {
"success": success,
"output": output,
}
except Exception as e:
metrics.record_tool_execution("browser_scrape", False)
raise HTTPException(status_code=500, detail=str(e))
# ============================================================================
# Session File Management Endpoints
# ============================================================================
@app.post("/api/sessions/{session_id}/files")
@limiter.limit("20/minute")
async def create_session_file_endpoint(session_id: str, request: Request):
    """Create a file inside a session's sandboxed folder.

    Body: {"path": str, "content": str, "folder": str = "files"}.
    400 when path is missing, 404 when the session is unknown, 403 on
    path/permission violations from the file operator.
    """
    try:
        data = await request.json()
        path = data.get("path", "").strip()
        content = data.get("content", "")
        folder = data.get("folder", "files")
        if not path:
            raise HTTPException(status_code=400, detail="Path is required")
        # Resolve the per-session sandbox folder.
        from session_manager import session_manager
        agent_session = session_manager.sessions.get(session_id)
        if not agent_session:
            raise HTTPException(status_code=404, detail="Session not found")
        from agent.tools.file_system_tool import SessionFileOperator
        from pathlib import Path
        fs = SessionFileOperator(Path(agent_session.session.session_folder))
        result = await fs.create_file(path, content, folder)
        # Emit file_generated event to session so frontend updates immediately.
        # Best-effort: a failure here must not fail the file creation itself.
        try:
            from agent.core.session import Event
            file_data = {
                "path": result.path,
                "full_path": result.full_path,
                "size": result.size,
                "created": result.created,
            }
            if hasattr(agent_session.session, "add_generated_file"):
                agent_session.session.add_generated_file(file_data)
            if hasattr(agent_session.session, "send_event"):
                await agent_session.session.send_event(
                    Event(event_type="file_generated", data={"file": file_data})
                )
        except Exception:
            logger.exception("Failed to emit file_generated event from API endpoint")
        return {
            "success": True,
            "path": result.path,
            "size": result.size,
        }
    except HTTPException:
        # Bug fix: keep the deliberate 400/404 statuses; previously the generic
        # Exception handler re-wrapped them as 500s.
        raise
    except PermissionError as e:
        raise HTTPException(status_code=403, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/sessions/{session_id}/files")
@limiter.limit("30/minute")
async def list_session_files_endpoint(session_id: str, request: Request, folder: str = "files"):
"""List files in session folder"""
try:
from session_manager import session_manager
agent_session = session_manager.sessions.get(session_id)
if not agent_session:
raise HTTPException(status_code=404, detail="Session not found")
from agent.tools.file_system_tool import SessionFileOperator
from pathlib import Path
fs = SessionFileOperator(Path(agent_session.session.session_folder))
listing = await fs.list_files(folder)
return {
"path": listing.path,
"items": listing.items,
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/sessions/{session_id}/tree")
@limiter.limit("20/minute")
async def get_session_tree_endpoint(session_id: str, request: Request):
"""Get file tree for session"""
try:
from session_manager import session_manager
agent_session = session_manager.sessions.get(session_id)
if not agent_session:
raise HTTPException(status_code=404, detail="Session not found")
from agent.tools.file_system_tool import SessionFileOperator
from pathlib import Path
fs = SessionFileOperator(Path(agent_session.session.session_folder))
tree = await fs.get_file_tree()
return {
"session_id": tree.session_id,
"structure": tree.structure,
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/sessions/{session_id}/file")
@limiter.limit("30/minute")
async def get_session_file_content_endpoint(session_id: str, request: Request, path: str):
"""Get file content from session"""
try:
if not path:
raise HTTPException(status_code=400, detail="Path is required")
from session_manager import session_manager
agent_session = session_manager.sessions.get(session_id)
if not agent_session:
raise HTTPException(status_code=404, detail="Session not found")
from agent.tools.file_system_tool import SessionFileOperator
from pathlib import Path
fs = SessionFileOperator(Path(agent_session.session.session_folder))
content = await fs.read_file(path)
return {
"path": path,
"content": content,
}
except FileNotFoundError:
raise HTTPException(status_code=404, detail="File not found")
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# ============================================================================
# WebSocket Endpoint
# ============================================================================
@app.websocket("/api/ws/{session_id}")
async def websocket_endpoint(websocket: WebSocket, session_id: str):
"""WebSocket endpoint for real-time communication"""
# Use the global session_manager if available
from session_manager import session_manager
# Accept only if session exists in either local sessions or session_manager
if session_id not in sessions and session_id not in session_manager.sessions:
await websocket.accept()
await websocket.close(code=4004, reason="Session not found")
return
# Register the WebSocket under the canonical session_id so event forwarding works
await connection_manager.connect(websocket, session_id)
try:
# Send initial ready event via the shared manager
await connection_manager.send_event(session_id, "ready", {
"message": "Factor Agent connected",
"yolo_mode": config.security.yolo_mode,
})
while True:
data = await websocket.receive_json()
message_text = data.get("text", "").strip()
if data.get("type") == "ping":
await connection_manager.send_event(session_id, "pong", {})
continue
if not message_text:
await connection_manager.send_event(session_id, "error", {"message": "Message text is required"})
continue
logger.info(f"Received WebSocket message for session {session_id}: {message_text[:50]}...")
# Acknowledge receipt
await connection_manager.send_event(session_id, "message_acknowledged", {"message": "Processing your request..."})
try:
# Get or create a managed AgentSession so events flow through the forwarder
agent_session = session_manager.sessions.get(session_id)
if not agent_session:
# If no agent session exists, create one in session_manager
logger.info(f"Creating agent session for {session_id} on first WebSocket message")
try:
# Register this session_id in session_manager by creating a new session
# and mapping it. For now, create in session_manager and track the mapping.
# Simple approach: if session is in local dict but not in manager,
# we need to initialize the session_manager session with the correct ID
# session_manager.create_session generates a UUID, so we can't control it.
# Instead, directly create AgentSession in session_manager.sessions
from pathlib import Path
# Create session folder for file isolation
sessions_root = Path(__file__).parent / "sessions"
sessions_root.mkdir(parents=True, exist_ok=True)
session_folder = sessions_root / session_id
session_folder.mkdir(parents=True, exist_ok=True)
(session_folder / "files").mkdir(exist_ok=True)
(session_folder / "documents").mkdir(exist_ok=True)
# Create event queue
submission_queue = asyncio.Queue()
event_queue = asyncio.Queue()
# Create session and tool router
from agent.core.tools import ToolRouter
from agent.core.session import Session
def _create_session_sync():
tool_router = ToolRouter(config.mcp_servers)
# Create a simple config adapter with attributes Session expects
class ConfigAdapter:
model_name = config.model.name
save_sessions = config.save_sessions
session_dataset_repo = config.session_dataset_repo
auto_save_interval = config.auto_save_interval
mcp_servers = config.mcp_servers
session = Session(event_queue, config=ConfigAdapter(), tool_router=tool_router)
return tool_router, session
tool_router, session = await asyncio.to_thread(_create_session_sync)
session.session_folder = str(session_folder)
# Create agent session wrapper and register it
from session_manager import AgentSession
agent_session = AgentSession(
session_id=session_id,
session=session,
tool_router=tool_router,
submission_queue=submission_queue,
user_id="dev",
)
session_manager.sessions[session_id] = agent_session
# Start the agent loop task
task = asyncio.create_task(
session_manager._run_session(session_id, submission_queue, event_queue, tool_router)
)
agent_session.task = task
logger.info(f"Created agent session {session_id} for WebSocket")
except Exception as e:
logger.error(f"Failed to create agent session: {e}", exc_info=True)
await connection_manager.send_event(session_id, "error", {"message": "Failed to initialize agent session"})
return
# Submit the message to the agent session (events will flow through forwarder)
if agent_session:
success = await session_manager.submit_user_input(session_id, message_text)
if not success:
await connection_manager.send_event(session_id, "error", {"message": "Failed to submit to session"})
except Exception as e:
logger.error(f"Error processing message in WebSocket for session {session_id}: {e}")
await connection_manager.send_event(session_id, "error", {"message": f"Error processing request: {str(e)}"})
except WebSocketDisconnect:
connection_manager.disconnect(session_id)
except Exception as e:
logger.error(f"WebSocket error for session {session_id}: {e}")
connection_manager.disconnect(session_id)
# ============================================================================
# Main Entry Point
# ============================================================================
if __name__ == "__main__":
import uvicorn
port = int(os.environ.get("PORT", 7860))
host = os.environ.get("HOST", "0.0.0.0")
uvicorn.run(
"main:app",
host=host,
port=port,
reload=config.environment == "development",
workers=1, # Use single worker to avoid Prometheus registry conflicts
log_level="info",
)