mcp-server / app.py
NiWaRe's picture
refactor for stateless: turn stateless on for FastMCP to work with OpenAI client etc
40e1a91
#!/usr/bin/env python3
"""
Thread-safe entry point for the Weights & Biases MCP Server.
"""
import os
import sys
import logging
import contextlib
from pathlib import Path
import threading
import wandb
# Add the src directory to Python path
sys.path.insert(0, str(Path(__file__).parent / "src"))
# Configure W&B directories for HF Spaces (must be done before importing wandb)
os.environ["WANDB_CACHE_DIR"] = "/tmp/.wandb_cache"
os.environ["WANDB_CONFIG_DIR"] = "/tmp/.wandb_config"
os.environ["WANDB_DATA_DIR"] = "/tmp/.wandb_data"
os.environ["HOME"] = "/tmp"
os.environ["WANDB_SILENT"] = "True"
os.environ["WEAVE_SILENT"] = "True"
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse, Response
from fastapi.middleware.cors import CORSMiddleware
from mcp.server.fastmcp import FastMCP
import base64
# Import W&B setup functions
from wandb_mcp_server.server import (
validate_and_get_api_key,
validate_api_key,
configure_wandb_logging,
initialize_weave_tracing,
register_tools,
ServerMCPArgs
)
# Import the new API client manager
from wandb_mcp_server.api_client import WandBApiManager
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("wandb-mcp-server")
# API key management is now handled by WandBApiManager
# which provides thread-safe context storage
# Thread-local storage for W&B client instances
# This prevents recreating clients for each request
thread_local = threading.local()
def get_thread_local_wandb_client(api_key: str):
"""Get or create a thread-local W&B client for the given API key."""
if not hasattr(thread_local, 'clients'):
thread_local.clients = {}
if api_key not in thread_local.clients:
# Store the API key for this thread's client
thread_local.clients[api_key] = {
'api_key': api_key,
'initialized': True
}
return thread_local.clients[api_key]
# Read the index.html file content
INDEX_HTML_PATH = Path(__file__).parent / "index.html"
with open(INDEX_HTML_PATH, "r") as f:
INDEX_HTML_CONTENT = f.read()
# W&B Logo Favicon
WANDB_FAVICON_BASE64 = """iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAMAAABEpIrGAAAAUVBMVEUAAAD/zzD/zzD/zzD/zjH/yzD/zDP/zDP/zTL/zDP/zTL/yzL/yzL/zDL/zDL/zDP/zDP/zDP/zDP/yzL/yzP/zDL/zDL/zDL/zDL/zDP/zDNs+ITNAAAAGnRSTlMAECAwP0BQX2BvcICPkJ+gr7C/wM/Q3+Dv8ORN9PUAAAEOSURBVBgZfcEJkpswAADBEVphB0EwzmJg/v/QcKbKC3E3FI/xN5fa8VEAjRq5ENUGaNXIhai2QBrsOJTf3yWHziHxw6AvPpl04pOsmXehfvksOYTAoXz6qgONi8hJdNEwuMicZBcvXGVOsit6FxWboq4LNpWLntLZFNj0+s0mTM5KSLmpAjtn7ELV5MQPnXZ8VJacxFvgUrhFZnc1cCGod6BTE7t7Xd/YJbUDKjWw6Zw92AS1AsK9SWyiq4JNau6BN8lV4n+Sq8Sb8PXri93gbOBNGtUnm6Kbpq7gUDDrXFRc6B0TuMqcJbWFyUXmLKoNtC4SmzyOmUMztAUUf9TMbtKRk8g/gw58UvZ9yZu/MeoYEFwSwuAAAAAASUVORK5CYII=""".strip()
FAVICON_BASE64 = WANDB_FAVICON_BASE64
# Initialize W&B
logger.info("Initializing W&B configuration...")
configure_wandb_logging()
args = ServerMCPArgs(
transport="http",
host="0.0.0.0",
port=7860,
wandb_api_key=os.environ.get("WANDB_API_KEY")
)
wandb_configured = False
api_key = validate_and_get_api_key(args)
if api_key:
try:
validate_api_key(api_key)
initialize_weave_tracing()
wandb_configured = True
logger.info("Server W&B API key configured successfully")
except Exception as e:
logger.warning(f"Failed to configure server W&B API key: {e}")
else:
logger.info("No server W&B API key configured - clients will provide their own")
# Create the MCP server in stateless mode
# All clients (OpenAI, Cursor, etc.) must provide Bearer token with each request
# Session IDs are used only as correlation IDs, no state is persisted
logger.info("Creating W&B MCP server in stateless HTTP mode...")
mcp = FastMCP("wandb-mcp-server", stateless_http=True)
# Register all W&B tools
# The tools will use WandBApiManager.get_api_key() to get the current request's API key
register_tools(mcp)
# Custom authentication middleware
async def thread_safe_auth_middleware(request: Request, call_next):
"""
Stateless authentication middleware for MCP endpoints.
Pure stateless operation - every request must include authentication:
- Session IDs are only used as correlation IDs
- No session state is stored between requests
- Each request must include Bearer token authentication
This works with all clients (OpenAI, Cursor, etc.) that support MCP.
"""
# Only apply auth to MCP endpoints
if not request.url.path.startswith("/mcp"):
return await call_next(request)
# Skip auth if explicitly disabled (development only)
if os.environ.get("MCP_AUTH_DISABLED", "false").lower() == "true":
logger.warning("MCP authentication is disabled - endpoints are publicly accessible")
env_key = os.environ.get("WANDB_API_KEY")
if env_key:
token = WandBApiManager.set_context_api_key(env_key)
try:
response = await call_next(request)
return response
finally:
WandBApiManager.reset_context_api_key(token)
return await call_next(request)
try:
api_key = None
# Check if request has MCP session ID (correlation ID only in stateless mode)
session_id = request.headers.get("Mcp-Session-Id") or request.headers.get("mcp-session-id")
if session_id:
logger.debug(f"Request has correlation ID: {session_id[:8]}...")
# Check for Bearer token (for new sessions or explicit auth)
authorization = request.headers.get("Authorization", "")
if authorization.startswith("Bearer "):
bearer_token = authorization[7:].strip()
# Basic validation
if len(bearer_token) < 20 or len(bearer_token) > 100:
return JSONResponse(
status_code=401,
content={"error": f"Invalid W&B API key format. Get your key at: https://wandb.ai/authorize"},
headers={"WWW-Authenticate": 'Bearer realm="W&B MCP", error="invalid_token"'}
)
# Use Bearer token
api_key = bearer_token
logger.info(f"Using Bearer token for authentication")
# Handle session cleanup (stateless mode - just acknowledge and pass through)
if request.method == "DELETE" and session_id:
logger.debug(f"Session cleanup: DELETE for {session_id[:8]}... (stateless - no action needed)")
return await call_next(request)
if api_key:
# Set the API key in context variable (thread-safe)
token = WandBApiManager.set_context_api_key(api_key)
# Also store in request state
request.state.wandb_api_key = api_key
try:
# Process the request
response = await call_next(request)
# In stateless mode, we don't store any session state
response_session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id")
if response_session_id:
logger.debug(f"Response includes correlation ID: {response_session_id[:8]}...")
return response
finally:
# Reset context variable
WandBApiManager.reset_context_api_key(token)
else:
# No API key available - in stateless mode, this is expected to fail
logger.warning(f"No Bearer token provided for {request.url.path}")
logger.debug(f" Request method: {request.method}")
logger.debug(" Passing to MCP (will likely return 401)")
return await call_next(request)
except Exception as e:
logger.error(f"Authentication error: {e}")
return JSONResponse(
status_code=401,
content={"error": "Authentication failed"},
headers={"WWW-Authenticate": 'Bearer realm="W&B MCP"'}
)
# Create lifespan context manager for session management
@contextlib.asynccontextmanager
async def lifespan(app: FastAPI):
"""Manage MCP session lifecycle."""
async with mcp.session_manager.run():
logger.info("MCP session manager started")
yield
logger.info("MCP session manager stopped")
# Create the main FastAPI app with lifespan
app = FastAPI(
title="Weights & Biases MCP Server",
description="Model Context Protocol server for W&B (Thread-Safe)",
lifespan=lifespan
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Add request logging middleware for debugging
@app.middleware("http")
async def logging_middleware(request, call_next):
"""Log all incoming requests for debugging."""
import time
start_time = time.time()
# Log request details
logger.info(f"Incoming request: {request.method} {request.url.path}")
# Log MCP-specific headers
mcp_session_id = request.headers.get("mcp-session-id")
if mcp_session_id:
logger.info(f" MCP Session ID in request: {mcp_session_id[:8]}...")
# Try to log request body for POST requests
if request.method == "POST" and request.url.path in ["/mcp", "/"]:
try:
# Clone the request body so we can read it
body_bytes = await request.body()
if body_bytes:
import json
try:
body_json = json.loads(body_bytes)
method = body_json.get("method", "unknown")
request_id = body_json.get("id", "unknown")
logger.info(f" JSON-RPC request: method={method}, id={request_id}")
if method == "tools/call":
tool_name = body_json.get("params", {}).get("name", "unknown")
logger.info(f" Tool call request for: {tool_name}")
except json.JSONDecodeError:
logger.debug(f" Request body (non-JSON): {body_bytes[:100]}")
# Reconstruct the request with the body we read
from starlette.datastructures import Headers
from starlette.requests import Request as StarletteRequest
# Create a new request with the body we read
scope = request.scope
scope["body"] = body_bytes
async def receive():
return {"type": "http.request", "body": body_bytes}
request = StarletteRequest(scope, receive)
else:
logger.debug(" No request body")
except Exception as e:
logger.debug(f" Could not read request body: {e}")
# Track if this is an MCP endpoint
is_mcp = request.url.path.startswith("/mcp") or request.url.path == "/"
try:
response = await call_next(request)
# Calculate response time
process_time = time.time() - start_time
# Log response details
status_label = "SUCCESS" if response.status_code < 400 else "ERROR" if response.status_code >= 400 else "WARNING"
logger.info(f"[{status_label}] Response: {request.method} {request.url.path} -> {response.status_code} ({process_time:.3f}s)")
# Log detailed info for 404s
if response.status_code == 404:
logger.warning(f"404 Not Found for {request.url.path}")
logger.debug(f" Full URL: {request.url}")
logger.debug(f" Available routes: /, /health, /favicon.ico, /favicon.png, /mcp")
if is_mcp:
logger.debug(" This appears to be an MCP endpoint that wasn't handled")
return response
except Exception as e:
logger.error(f"Error processing {request.method} {request.url.path}: {e}")
raise
# Add authentication middleware
@app.middleware("http")
async def auth_middleware(request, call_next):
"""Add thread-safe OAuth 2.1 Bearer token authentication for MCP endpoints."""
return await thread_safe_auth_middleware(request, call_next)
# Add custom routes
@app.get("/", response_class=HTMLResponse)
async def index():
"""Serve the landing page."""
return INDEX_HTML_CONTENT
@app.get("/favicon.ico")
async def favicon():
"""Serve the official W&B logo favicon."""
return Response(
content=base64.b64decode(FAVICON_BASE64),
media_type="image/png",
headers={
"Cache-Control": "public, max-age=31536000",
"Content-Type": "image/png"
}
)
@app.get("/favicon.png")
async def favicon_png():
"""Alternative PNG favicon endpoint for better browser compatibility."""
return Response(
content=base64.b64decode(FAVICON_BASE64),
media_type="image/png",
headers={
"Cache-Control": "public, max-age=31536000",
"Content-Type": "image/png"
}
)
@app.get("/health")
async def health():
"""Health check endpoint."""
try:
tools = await mcp.list_tools()
tool_count = len(tools)
except:
tool_count = 0
auth_status = "disabled" if os.environ.get("MCP_AUTH_DISABLED", "false").lower() == "true" else "enabled"
# Include worker information for debugging
worker_info = {
"pid": os.getpid(),
"thread_id": threading.current_thread().name
}
return {
"status": "healthy",
"service": "wandb-mcp-server",
"wandb_configured": wandb_configured,
"tools_registered": tool_count,
"authentication": auth_status,
"worker_info": worker_info
}
# Mount the MCP streamable HTTP app
# NOTE: MCP app is mounted at root "/" to handle all MCP protocol requests
# This means it will catch all unhandled routes, which is why we define our
# custom routes (/, /health, etc.) BEFORE mounting the MCP app
mcp_app = mcp.streamable_http_app()
logger.info("Mounting MCP streamable HTTP app at root /")
logger.info("Note: MCP will handle all unmatched routes, returning 404 for non-MCP requests")
# For debugging: Log incoming requests to understand routing
@app.middleware("http")
async def mcp_routing_debug(request, call_next):
"""Debug middleware to understand MCP routing issues."""
path = request.url.path
method = request.method
# Check if this should be an MCP request
is_mcp_request = (
request.headers.get("Content-Type") == "application/json" and
(request.headers.get("Accept", "").find("text/event-stream") >= 0 or
request.headers.get("Accept", "").find("application/json") >= 0)
)
if path == "/" and method == "GET":
logger.debug("Root GET request - should show landing page")
elif path == "/health" and method == "GET":
logger.debug("Health check request")
elif path in ["/", "/mcp"] and is_mcp_request:
logger.debug(f"MCP protocol request detected on {path}")
elif path == "/" and method in ["POST", "GET"] and not is_mcp_request:
logger.debug(f"Non-MCP {method} request to root - may get 404 from MCP app")
return await call_next(request)
app.mount("/", mcp_app)
# Port for HF Spaces
PORT = int(os.environ.get("PORT", "7860"))
if __name__ == "__main__":
import uvicorn
logger.info(f"Starting server on 0.0.0.0:{PORT}")
logger.info("Landing page: /")
logger.info("Health check: /health")
logger.info("MCP endpoint: /mcp")
# In stateless mode, we can scale horizontally with multiple workers
# However, for HuggingFace Spaces we use single worker for simplicity
logger.info("Starting server (stateless mode - supports horizontal scaling)")
uvicorn.run(app, host="0.0.0.0", port=PORT, workers=1) # Can increase workers if needed