Spaces:
Paused
Paused
#!/usr/bin/env python3
"""
Thread-safe entry point for the Weights & Biases MCP Server.
"""
import base64
import contextlib
import logging
import os
import sys
import threading
from pathlib import Path

# Add the src directory to Python path
sys.path.insert(0, str(Path(__file__).parent / "src"))

# Configure W&B directories for HF Spaces (must be done before importing wandb).
# Everything is redirected to /tmp; these assignments intentionally override
# any value already present in the environment.
os.environ["WANDB_CACHE_DIR"] = "/tmp/.wandb_cache"
os.environ["WANDB_CONFIG_DIR"] = "/tmp/.wandb_config"
os.environ["WANDB_DATA_DIR"] = "/tmp/.wandb_data"
os.environ["HOME"] = "/tmp"
os.environ["WANDB_SILENT"] = "True"
os.environ["WEAVE_SILENT"] = "True"

# wandb (and everything that imports it) is loaded only now, so that it sees
# the environment configured above.
import wandb  # noqa: E402

from fastapi import FastAPI, Request  # noqa: E402
from fastapi.responses import HTMLResponse, JSONResponse, Response  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from mcp.server.fastmcp import FastMCP  # noqa: E402

# Import W&B setup functions
from wandb_mcp_server.server import (  # noqa: E402
    validate_and_get_api_key,
    validate_api_key,
    configure_wandb_logging,
    initialize_weave_tracing,
    register_tools,
    ServerMCPArgs,
)

# Import the new API client manager
from wandb_mcp_server.api_client import WandBApiManager  # noqa: E402
# Process-wide logging setup: INFO level, timestamped records that include
# the logger name and severity.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Logger shared by everything in this module.
logger = logging.getLogger("wandb-mcp-server")

# Per-request API keys are not kept in this module: WandBApiManager owns
# them and exposes a context-scoped setter/reset pair (used by the auth
# middleware below).
# Per-thread cache of client records, keyed by API key, so a thread does not
# rebuild its record on every request.
thread_local = threading.local()


def get_thread_local_wandb_client(api_key: str):
    """Return the calling thread's client record for ``api_key``.

    The record is a plain dict ``{'api_key': ..., 'initialized': True}``.
    It is created on first use in the current thread and the very same dict
    object is handed back on later calls from that thread.
    """
    try:
        per_thread_cache = thread_local.clients
    except AttributeError:
        # First call on this thread: start with an empty cache.
        per_thread_cache = thread_local.clients = {}

    # setdefault keeps an existing record untouched and only inserts the
    # fresh one when the key has not been seen on this thread.
    fresh_record = {'api_key': api_key, 'initialized': True}
    return per_thread_cache.setdefault(api_key, fresh_record)
# Read the index.html file content once at import time; it is served verbatim
# as the landing page.
INDEX_HTML_PATH = Path(__file__).parent / "index.html"
# Decode explicitly as UTF-8 so the result does not depend on the platform's
# default locale encoding.
INDEX_HTML_CONTENT = INDEX_HTML_PATH.read_text(encoding="utf-8")
# W&B logo favicon: a PNG image stored as a single base64 string.
WANDB_FAVICON_BASE64 = (
    "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAMAAABEpIrGAAAAUVBMVEUAAAD/zzD/zzD/zzD/zjH/yzD/zDP/zDP/zTL/zDP/zTL/yzL/yzL/zDL/zDL/zDP/zDP/zDP/zDP/yzL/yzP/zDL/zDL/zDL/zDL/zDP/zDNs+ITNAAAAGnRSTlMAECAwP0BQX2BvcICPkJ+gr7C/wM/Q3+Dv8ORN9PUAAAEOSURBVBgZfcEJkpswAADBEVphB0EwzmJg/v/QcKbKC3E3FI/xN5fa8VEAjRq5ENUGaNXIhai2QBrsOJTf3yWHziHxw6AvPpl04pOsmXehfvksOYTAoXz6qgONi8hJdNEwuMicZBcvXGVOsit6FxWboq4LNpWLntLZFNj0+s0mTM5KSLmpAjtn7ELV5MQPnXZ8VJacxFvgUrhFZnc1cCGod6BTE7t7Xd/YJbUDKjWw6Zw92AS1AsK9SWyiq4JNau6BN8lV4n+Sq8Sb8PXri93gbOBNGtUnm6Kbpq7gUDDrXFRc6B0TuMqcJbWFyUXmLKoNtC4SmzyOmUMztAUUf9TMbtKRk8g/gw58UvZ9yZu/MeoYEFwSwuAAAAAASUVORK5CYII="
)

# Name used by the favicon route handlers.
FAVICON_BASE64 = WANDB_FAVICON_BASE64
# Initialize W&B
logger.info("Initializing W&B configuration...")
configure_wandb_logging()

# Server-level arguments; the API key is optional and comes from the
# WANDB_API_KEY environment variable when present.
args = ServerMCPArgs(
    wandb_api_key=os.environ.get("WANDB_API_KEY"),
    transport="http",
    host="0.0.0.0",
    port=7860,
)

# Stays False unless a server-side key is both present and accepted; the
# value is reported by the /health handler.
wandb_configured = False
api_key = validate_and_get_api_key(args)
if not api_key:
    logger.info("No server W&B API key configured - clients will provide their own")
else:
    try:
        validate_api_key(api_key)
        initialize_weave_tracing()
    except Exception as exc:
        # A bad server key is not fatal: the module keeps loading and
        # wandb_configured simply remains False.
        logger.warning(f"Failed to configure server W&B API key: {exc}")
    else:
        wandb_configured = True
        logger.info("Server W&B API key configured successfully")
# Build the MCP server with stateless HTTP enabled.
# - Clients (OpenAI, Cursor, etc.) are expected to send a Bearer token on
#   every request.
# - Session IDs serve purely as correlation IDs; nothing is persisted
#   between requests.
logger.info("Creating W&B MCP server in stateless HTTP mode...")
mcp = FastMCP(
    "wandb-mcp-server",
    stateless_http=True,
)

# Attach every W&B tool to the server. Per the original note, the tools
# obtain the current request's key through WandBApiManager.get_api_key().
register_tools(mcp)
# Custom authentication middleware
async def thread_safe_auth_middleware(request: Request, call_next):
    """
    Stateless authentication middleware for MCP endpoints.

    Pure stateless operation - every request must include authentication:
    - Session IDs are only used as correlation IDs
    - No session state is stored between requests
    - Each request must include Bearer token authentication

    Behaviour, in order:
    1. Paths that do not start with ``/mcp`` are passed through untouched.
    2. If ``MCP_AUTH_DISABLED`` is ``true`` the request is passed through;
       when ``WANDB_API_KEY`` is set it is installed as the context API key
       for the duration of the request.
    3. An ``Authorization: Bearer <key>`` header (scheme matched
       case-insensitively) is read. Keys shorter than 20 or longer than 100
       characters are rejected with a 401 response.
    4. ``DELETE`` requests carrying a session ID are passed through without
       installing a key (there is no session state to clean up).
    5. With a key, it is installed via ``WandBApiManager`` and stored on
       ``request.state.wandb_api_key`` while the request is processed, then
       reset. Without a key, the request is passed through as-is.

    Args:
        request: The incoming request.
        call_next: Awaitable callable that forwards the request downstream
            and returns the response.

    Returns:
        The downstream response, or a 401 ``JSONResponse`` when the token is
        malformed or the authentication steps themselves fail.

    Raises:
        Any exception raised by the downstream application. Only failures in
        the authentication steps are converted into a 401; errors from
        ``call_next`` are not reported as authentication failures.
    """
    # Only apply auth to MCP endpoints
    if not request.url.path.startswith("/mcp"):
        return await call_next(request)

    # Skip auth if explicitly disabled (development only)
    if os.environ.get("MCP_AUTH_DISABLED", "false").lower() == "true":
        logger.warning("MCP authentication is disabled - endpoints are publicly accessible")
        env_key = os.environ.get("WANDB_API_KEY")
        if not env_key:
            return await call_next(request)
        token = WandBApiManager.set_context_api_key(env_key)
        try:
            return await call_next(request)
        finally:
            WandBApiManager.reset_context_api_key(token)

    token = None
    context_installed = False
    try:
        api_key = None

        # Correlation ID only in stateless mode. Header lookup is
        # case-insensitive, so one spelling is enough.
        session_id = request.headers.get("mcp-session-id")
        if session_id:
            logger.debug(f"Request has correlation ID: {session_id[:8]}...")

        # Check for Bearer token. The auth scheme name is case-insensitive,
        # so "bearer <key>" is accepted as well as "Bearer <key>".
        authorization = request.headers.get("Authorization", "")
        scheme, separator, credentials = authorization.partition(" ")
        if separator and scheme.lower() == "bearer":
            bearer_token = credentials.strip()

            # Basic validation
            if len(bearer_token) < 20 or len(bearer_token) > 100:
                return JSONResponse(
                    status_code=401,
                    content={"error": "Invalid W&B API key format. Get your key at: https://wandb.ai/authorize"},
                    headers={"WWW-Authenticate": 'Bearer realm="W&B MCP", error="invalid_token"'},
                )

            # Use Bearer token
            api_key = bearer_token
            logger.info("Using Bearer token for authentication")

        if request.method == "DELETE" and session_id:
            # Session cleanup (stateless mode - just acknowledge and pass through)
            logger.debug(f"Session cleanup: DELETE for {session_id[:8]}... (stateless - no action needed)")
        elif api_key:
            # Also store in request state
            request.state.wandb_api_key = api_key
            # Set the API key in context variable (thread-safe). Done last so
            # that a failure above never leaves a key installed.
            token = WandBApiManager.set_context_api_key(api_key)
            context_installed = True
        else:
            # No API key available - in stateless mode, this is expected to fail
            logger.warning(f"No Bearer token provided for {request.url.path}")
            logger.debug(f" Request method: {request.method}")
            logger.debug(" Passing to MCP (will likely return 401)")
    except Exception as e:
        logger.error(f"Authentication error: {e}")
        return JSONResponse(
            status_code=401,
            content={"error": "Authentication failed"},
            headers={"WWW-Authenticate": 'Bearer realm="W&B MCP"'},
        )

    # From here on, errors come from the downstream application and must
    # propagate instead of being reported as authentication failures.
    if not context_installed:
        return await call_next(request)

    try:
        # Process the request
        response = await call_next(request)

        # In stateless mode, we don't store any session state
        response_session_id = response.headers.get("mcp-session-id")
        if response_session_id:
            logger.debug(f"Response includes correlation ID: {response_session_id[:8]}...")

        return response
    finally:
        # Reset context variable
        WandBApiManager.reset_context_api_key(token)
# Create lifespan context manager for session management
@contextlib.asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage MCP session lifecycle.

    Runs the MCP session manager for as long as the application is up: the
    manager is entered before the app starts serving and exited on shutdown.

    Wrapped with ``contextlib.asynccontextmanager`` so that it is a proper
    async context manager factory rather than a bare async generator when
    passed as the application's ``lifespan``.

    Args:
        app: The FastAPI application (unused, required by the lifespan
            signature).
    """
    async with mcp.session_manager.run():
        logger.info("MCP session manager started")
        yield
        logger.info("MCP session manager stopped")
# Main FastAPI application; `lifespan` ties the MCP session manager to the
# application's startup and shutdown.
app = FastAPI(
    lifespan=lifespan,
    title="Weights & Biases MCP Server",
    description="Model Context Protocol server for W&B (Thread-Safe)",
)

# CORS: every origin, method and header is allowed, with credentials.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive - confirm that credentialed cross-origin requests are really
# needed here.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# Add request logging middleware for debugging
@app.middleware("http")
async def logging_middleware(request, call_next):
    """Log all incoming requests for debugging.

    Logs the method and path of every request, the MCP session ID header
    when present, and - for POST requests to ``/mcp`` or ``/`` - a short
    summary of the JSON-RPC payload (method, id, and the tool name for
    ``tools/call``). After the downstream handler returns, the status code
    and elapsed time are logged, with extra detail for 404 responses.

    Body inspection is best-effort: any failure while reading or summarising
    the body is logged at debug level and never blocks the request.

    Args:
        request: The incoming request.
        call_next: Awaitable callable that forwards the request downstream.

    Returns:
        The downstream response, unchanged.

    Raises:
        Re-raises any exception from the downstream handler after logging it.
    """
    import json
    import time

    # Monotonic clock: elapsed time is not affected by wall-clock changes.
    start_time = time.perf_counter()

    # Log request details
    logger.info(f"Incoming request: {request.method} {request.url.path}")

    # Log MCP-specific headers
    mcp_session_id = request.headers.get("mcp-session-id")
    if mcp_session_id:
        logger.info(f" MCP Session ID in request: {mcp_session_id[:8]}...")

    # Try to log request body for POST requests
    if request.method == "POST" and request.url.path in ["/mcp", "/"]:
        try:
            body_bytes = await request.body()
            if body_bytes:
                try:
                    payload = json.loads(body_bytes)
                except ValueError:
                    # Covers both invalid JSON and undecodable bytes.
                    logger.debug(f" Request body (non-JSON): {body_bytes[:100]}")
                else:
                    # A JSON-RPC body may be a single object or a batch
                    # (list); anything that is not an object is skipped.
                    messages = payload if isinstance(payload, list) else [payload]
                    for message in messages:
                        if not isinstance(message, dict):
                            continue
                        method = message.get("method", "unknown")
                        request_id = message.get("id", "unknown")
                        logger.info(f" JSON-RPC request: method={method}, id={request_id}")

                        if method == "tools/call":
                            params = message.get("params")
                            tool_name = (
                                params.get("name", "unknown")
                                if isinstance(params, dict)
                                else "unknown"
                            )
                            logger.info(f" Tool call request for: {tool_name}")

                # Rebuild the request around the bytes already consumed so
                # the body can be read again further down the chain.
                from starlette.requests import Request as StarletteRequest

                scope = request.scope
                scope["body"] = body_bytes

                async def receive():
                    return {"type": "http.request", "body": body_bytes}

                request = StarletteRequest(scope, receive)
            else:
                logger.debug(" No request body")
        except Exception as e:
            logger.debug(f" Could not read request body: {e}")

    # Track if this is an MCP endpoint
    is_mcp = request.url.path.startswith("/mcp") or request.url.path == "/"

    try:
        response = await call_next(request)

        # Calculate response time
        process_time = time.perf_counter() - start_time

        # Log response details
        status_label = "SUCCESS" if response.status_code < 400 else "ERROR"
        logger.info(f"[{status_label}] Response: {request.method} {request.url.path} -> {response.status_code} ({process_time:.3f}s)")

        # Log detailed info for 404s
        if response.status_code == 404:
            logger.warning(f"404 Not Found for {request.url.path}")
            logger.debug(f" Full URL: {request.url}")
            logger.debug(" Available routes: /, /health, /favicon.ico, /favicon.png, /mcp")
            if is_mcp:
                logger.debug(" This appears to be an MCP endpoint that wasn't handled")

        return response
    except Exception as e:
        logger.error(f"Error processing {request.method} {request.url.path}: {e}")
        raise
# Add authentication middleware
@app.middleware("http")
async def auth_middleware(request, call_next):
    """Add thread-safe OAuth 2.1 Bearer token authentication for MCP endpoints.

    Thin adapter that registers ``thread_safe_auth_middleware`` on the
    application; all of the logic lives in that function.

    Args:
        request: The incoming request.
        call_next: Awaitable callable that forwards the request downstream.

    Returns:
        Whatever ``thread_safe_auth_middleware`` returns.
    """
    return await thread_safe_auth_middleware(request, call_next)
# Add custom routes
@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the landing page.

    Returns the contents of ``index.html`` loaded at import time. The route
    declares ``HTMLResponse`` so the string is sent as HTML rather than being
    JSON-encoded.
    """
    return INDEX_HTML_CONTENT
@app.get("/favicon.ico")
async def favicon():
    """Serve the official W&B logo favicon.

    Returns:
        A PNG ``Response`` decoded from ``FAVICON_BASE64``, cacheable for one
        year (``max-age=31536000``).
    """
    return Response(
        content=base64.b64decode(FAVICON_BASE64),
        media_type="image/png",
        headers={
            "Cache-Control": "public, max-age=31536000",
            "Content-Type": "image/png",
        },
    )
@app.get("/favicon.png")
async def favicon_png():
    """Alternative PNG favicon endpoint for better browser compatibility.

    Returns:
        A PNG ``Response`` decoded from ``FAVICON_BASE64``, cacheable for one
        year (``max-age=31536000``).
    """
    return Response(
        content=base64.b64decode(FAVICON_BASE64),
        media_type="image/png",
        headers={
            "Cache-Control": "public, max-age=31536000",
            "Content-Type": "image/png",
        },
    )
@app.get("/health")
async def health():
    """Health check endpoint.

    Returns:
        A dict with a fixed ``"healthy"`` status, the service name, whether a
        server-side W&B key was configured, the number of registered tools
        (0 if they cannot be listed), whether MCP authentication is enabled,
        and the worker's process ID and current thread name.
    """
    try:
        tools = await mcp.list_tools()
        tool_count = len(tools)
    except Exception as e:
        # Only ordinary errors are absorbed; cancellation and interpreter
        # shutdown are left to propagate.
        logger.debug(f"Health check could not list tools: {e}")
        tool_count = 0

    auth_status = "disabled" if os.environ.get("MCP_AUTH_DISABLED", "false").lower() == "true" else "enabled"

    # Include worker information for debugging.
    # Note: "thread_id" carries the thread's *name*, not a numeric identifier.
    worker_info = {
        "pid": os.getpid(),
        "thread_id": threading.current_thread().name,
    }

    return {
        "status": "healthy",
        "service": "wandb-mcp-server",
        "wandb_configured": wandb_configured,
        "tools_registered": tool_count,
        "authentication": auth_status,
        "worker_info": worker_info,
    }
# Build the MCP streamable HTTP application.
# NOTE: it gets mounted at root "/" so it can serve all MCP protocol
# requests. A root mount also receives every route nothing else matched,
# which is why the custom routes (/, /health, etc.) are defined BEFORE the
# mount happens.
mcp_app = mcp.streamable_http_app()

for _mount_note in (
    "Mounting MCP streamable HTTP app at root /",
    "Note: MCP will handle all unmatched routes, returning 404 for non-MCP requests",
):
    logger.info(_mount_note)
# For debugging: Log incoming requests to understand routing
@app.middleware("http")
async def mcp_routing_debug(request, call_next):
    """Debug middleware to understand MCP routing issues.

    Classifies the request purely for debug logging and then forwards it
    unchanged. A request is treated as an MCP protocol request when its
    ``Content-Type`` media type is ``application/json`` (parameters such as
    ``; charset=utf-8`` are ignored) and its ``Accept`` header mentions
    ``text/event-stream`` or ``application/json``.

    Args:
        request: The incoming request.
        call_next: Awaitable callable that forwards the request downstream.

    Returns:
        The downstream response, unchanged.
    """
    path = request.url.path
    method = request.method

    # Compare only the media type, so "application/json; charset=utf-8"
    # is recognised too.
    content_type = request.headers.get("Content-Type", "")
    media_type = content_type.split(";", 1)[0].strip().lower()
    accept = request.headers.get("Accept", "")

    # Check if this should be an MCP request
    is_mcp_request = media_type == "application/json" and (
        "text/event-stream" in accept or "application/json" in accept
    )

    if path == "/" and method == "GET":
        logger.debug("Root GET request - should show landing page")
    elif path == "/health" and method == "GET":
        logger.debug("Health check request")
    elif path in ["/", "/mcp"] and is_mcp_request:
        logger.debug(f"MCP protocol request detected on {path}")
    elif path == "/" and method in ["POST", "GET"] and not is_mcp_request:
        logger.debug(f"Non-MCP {method} request to root - may get 404 from MCP app")

    return await call_next(request)
# Attach the MCP application at the root path; routes registered on `app`
# before this point are matched first.
app.mount("/", app=mcp_app)
# Listening port for HF Spaces: taken from the PORT environment variable,
# falling back to 7860 when it is not set.
PORT = int(os.getenv("PORT", "7860"))
if __name__ == "__main__":
    import uvicorn

    # Startup banner: bind address plus the main endpoints.
    for _banner_line in (
        f"Starting server on 0.0.0.0:{PORT}",
        "Landing page: /",
        "Health check: /health",
        "MCP endpoint: /mcp",
    ):
        logger.info(_banner_line)

    # Stateless mode allows horizontal scaling across several workers, but a
    # single worker is used on HuggingFace Spaces for simplicity.
    # NOTE(review): uvicorn only accepts workers > 1 when the app is given as
    # an import string rather than an object - confirm before raising it.
    logger.info("Starting server (stateless mode - supports horizontal scaling)")
    uvicorn.run(app, host="0.0.0.0", port=PORT, workers=1)