Spaces:
Paused
Paused
File size: 16,172 Bytes
f647629 40e1a91 f647629 04c12c4 f647629 1ec3391 f647629 2d37ec5 e2aaee8 0d796a8 ceeb737 04c12c4 0d796a8 04c12c4 1ec3391 04c12c4 f647629 1ec3391 a2dc155 2d37ec5 f647629 a643202 f647629 1ec3391 561151f ceeb737 1ec3391 0d796a8 04c12c4 a2dc155 1ec3391 a2dc155 04c12c4 40e1a91 04c12c4 1ec3391 04c12c4 1ec3391 40e1a91 1ec3391 40e1a91 1ec3391 40e1a91 0783971 40e4410 40e1a91 1ec3391 40e1a91 1ec3391 40e1a91 1ec3391 40e1a91 1ec3391 40e1a91 1ec3391 40e1a91 1ec3391 40e1a91 0783971 40e1a91 1ec3391 40e1a91 1ec3391 04c12c4 2d37ec5 1ec3391 04c12c4 2d37ec5 40e4410 0783971 40e4410 1ec3391 a2dc155 1ec3391 a2dc155 04c12c4 2d37ec5 0d796a8 1ec3391 0d796a8 1ec3391 0d796a8 2d37ec5 04c12c4 a2dc155 1ec3391 2d37ec5 a643202 a2dc155 1ec3391 2d37ec5 04c12c4 40e4410 04c12c4 40e4410 04c12c4 a643202 04c12c4 2d37ec5 04c12c4 2d37ec5 04c12c4 1ec3391 40e1a91 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 
319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 | #!/usr/bin/env python3
"""
Thread-safe entry point for the Weights & Biases MCP Server.
"""
import os
import sys
import logging
import contextlib
import threading
from pathlib import Path

# Add the local src/ directory to the Python path so the project imports below resolve.
sys.path.insert(0, str(Path(__file__).parent / "src"))

# Configure W&B/Weave directories and silence their console output for HF Spaces.
# These variables are set BEFORE `import wandb` (at the end of this block) so
# they are already in place when wandb is first imported.
os.environ["WANDB_CACHE_DIR"] = "/tmp/.wandb_cache"
os.environ["WANDB_CONFIG_DIR"] = "/tmp/.wandb_config"
os.environ["WANDB_DATA_DIR"] = "/tmp/.wandb_data"
os.environ["HOME"] = "/tmp"
os.environ["WANDB_SILENT"] = "True"
os.environ["WEAVE_SILENT"] = "True"

import wandb  # noqa: E402  -- deliberately imported after the environment is configured
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse, Response
from fastapi.middleware.cors import CORSMiddleware
from mcp.server.fastmcp import FastMCP
import base64
# Import W&B setup functions
from wandb_mcp_server.server import (
validate_and_get_api_key,
validate_api_key,
configure_wandb_logging,
initialize_weave_tracing,
register_tools,
ServerMCPArgs
)
# Import the new API client manager
from wandb_mcp_server.api_client import WandBApiManager
# Root logging setup: INFO and above, formatted as "time - logger - level - message".
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger("wandb-mcp-server")
# Per-request API keys are tracked by WandBApiManager (context-variable based).
# Separately, this module keeps a per-thread cache of lightweight client
# records so they are not rebuilt for every request on the same thread.
thread_local = threading.local()


def get_thread_local_wandb_client(api_key: str):
    """Return the calling thread's client record for *api_key*, creating it on first use.

    A record is the dict ``{'api_key': api_key, 'initialized': True}``. Records are
    cached per thread (in ``thread_local.clients``) and per key, so repeated calls
    from the same thread with the same key return the very same dict object.
    """
    clients = getattr(thread_local, 'clients', None)
    if clients is None:
        clients = thread_local.clients = {}
    try:
        return clients[api_key]
    except KeyError:
        record = clients[api_key] = {'api_key': api_key, 'initialized': True}
        return record
# Read the landing-page HTML once at import time; the "/" route serves it.
# The encoding is explicit so decoding does not depend on the platform's
# locale default (open() without encoding= uses that default).
INDEX_HTML_PATH = Path(__file__).parent / "index.html"
INDEX_HTML_CONTENT = INDEX_HTML_PATH.read_text(encoding="utf-8")
# W&B logo favicon: a base64-encoded PNG. FAVICON_BASE64 is the name the
# favicon routes decode and serve; it aliases the W&B-specific constant.
WANDB_FAVICON_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAMAAABEpIrGAAAAUVBMVEUAAAD/zzD/zzD/zzD/zjH/yzD/zDP/zDP/zTL/zDP/zTL/yzL/yzL/zDL/zDL/zDP/zDP/zDP/zDP/yzL/yzP/zDL/zDL/zDL/zDL/zDP/zDNs+ITNAAAAGnRSTlMAECAwP0BQX2BvcICPkJ+gr7C/wM/Q3+Dv8ORN9PUAAAEOSURBVBgZfcEJkpswAADBEVphB0EwzmJg/v/QcKbKC3E3FI/xN5fa8VEAjRq5ENUGaNXIhai2QBrsOJTf3yWHziHxw6AvPpl04pOsmXehfvksOYTAoXz6qgONi8hJdNEwuMicZBcvXGVOsit6FxWboq4LNpWLntLZFNj0+s0mTM5KSLmpAjtn7ELV5MQPnXZ8VJacxFvgUrhFZnc1cCGod6BTE7t7Xd/YJbUDKjWw6Zw92AS1AsK9SWyiq4JNau6BN8lV4n+Sq8Sb8PXri93gbOBNGtUnm6Kbpq7gUDDrXFRc6B0TuMqcJbWFyUXmLKoNtC4SmzyOmUMztAUUf9TMbtKRk8g/gw58UvZ9yZu/MeoYEFwSwuAAAAAASUVORK5CYII="
FAVICON_BASE64 = WANDB_FAVICON_BASE64
# Initialize W&B
logger.info("Initializing W&B configuration...")
configure_wandb_logging()

# Server-side MCP arguments. The server-wide W&B key comes from the environment
# and is optional: when it is absent, clients provide their own key.
args = ServerMCPArgs(
    transport="http",
    host="0.0.0.0",
    port=7860,
    wandb_api_key=os.environ.get("WANDB_API_KEY"),
)

wandb_configured = False
api_key = validate_and_get_api_key(args)
if not api_key:
    logger.info("No server W&B API key configured - clients will provide their own")
else:
    try:
        validate_api_key(api_key)
        initialize_weave_tracing()
    except Exception as e:
        # Best effort: a bad server key is logged; the server still starts.
        logger.warning(f"Failed to configure server W&B API key: {e}")
    else:
        wandb_configured = True
        logger.info("Server W&B API key configured successfully")

# Stateless MCP server: every client (OpenAI, Cursor, ...) sends a Bearer token
# with each request; session IDs serve only as correlation IDs and no
# per-session state is persisted.
logger.info("Creating W&B MCP server in stateless HTTP mode...")
mcp = FastMCP("wandb-mcp-server", stateless_http=True)

# Register all W&B tools. The tools are expected to read the current request's
# key through WandBApiManager.get_api_key().
register_tools(mcp)
# Custom authentication middleware
# Accepted length range for a Bearer token to be treated as a W&B API key.
_MIN_API_KEY_LENGTH = 20
_MAX_API_KEY_LENGTH = 100


async def _call_next_with_api_key(request: Request, call_next, api_key: str):
    """Run ``call_next(request)`` with *api_key* set as WandBApiManager's context key.

    The context key is reset when ``call_next`` returns or raises.
    """
    token = WandBApiManager.set_context_api_key(api_key)
    try:
        return await call_next(request)
    finally:
        WandBApiManager.reset_context_api_key(token)


async def thread_safe_auth_middleware(request: Request, call_next):
    """
    Stateless authentication middleware for MCP endpoints.

    Only paths starting with ``/mcp`` are inspected; all other requests pass
    straight through. For MCP paths:

    - If ``MCP_AUTH_DISABLED=true``, no token is required. The server's own
      ``WANDB_API_KEY`` (if set) is used as the request's context API key.
    - Otherwise the ``Authorization: Bearer <W&B API key>`` header is read. The
      scheme name is matched case-insensitively. A token shorter than 20 or
      longer than 100 characters is rejected with 401. A valid token is set as
      the context API key via WandBApiManager and stored on
      ``request.state.wandb_api_key`` for the duration of the request.
    - ``DELETE`` requests carrying a session ID are forwarded as-is, since
      there is no session state to clean up.
    - Requests without a Bearer token are forwarded unauthenticated. The MCP
      app will likely reject them.
    - ``Mcp-Session-Id`` values are logged as correlation IDs only. No state
      is stored between requests.

    Only failures while parsing the credentials are turned into a 401.
    Exceptions raised by the downstream application propagate unchanged, so
    server errors are not misreported as authentication failures.
    """
    # Only apply auth to MCP endpoints
    if not request.url.path.startswith("/mcp"):
        return await call_next(request)

    # Skip auth if explicitly disabled (development only)
    if os.environ.get("MCP_AUTH_DISABLED", "false").lower() == "true":
        logger.warning("MCP authentication is disabled - endpoints are publicly accessible")
        env_key = os.environ.get("WANDB_API_KEY")
        if env_key:
            return await _call_next_with_api_key(request, call_next, env_key)
        return await call_next(request)

    try:
        # Starlette header lookups are case-insensitive, so this single lookup
        # covers both "Mcp-Session-Id" and "mcp-session-id".
        session_id = request.headers.get("mcp-session-id")
        if session_id:
            logger.debug(f"Request has correlation ID: {session_id[:8]}...")

        api_key = None
        # HTTP auth scheme names are case-insensitive (RFC 7235 section 2.1).
        scheme, separator, credentials = request.headers.get("Authorization", "").partition(" ")
        if separator and scheme.lower() == "bearer":
            bearer_token = credentials.strip()
            # Basic validation
            if not (_MIN_API_KEY_LENGTH <= len(bearer_token) <= _MAX_API_KEY_LENGTH):
                return JSONResponse(
                    status_code=401,
                    content={"error": "Invalid W&B API key format. Get your key at: https://wandb.ai/authorize"},
                    headers={"WWW-Authenticate": 'Bearer realm="W&B MCP", error="invalid_token"'},
                )
            api_key = bearer_token
            logger.info("Using Bearer token for authentication")
    except Exception as e:
        logger.error(f"Authentication error: {e}")
        return JSONResponse(
            status_code=401,
            content={"error": "Authentication failed"},
            headers={"WWW-Authenticate": 'Bearer realm="W&B MCP"'},
        )

    # Session cleanup: stateless mode has nothing to tear down, so the DELETE is
    # forwarded as-is without binding an API key.
    if request.method == "DELETE" and session_id:
        logger.debug(f"Session cleanup: DELETE for {session_id[:8]}... (stateless - no action needed)")
        return await call_next(request)

    if not api_key:
        # No API key available - in stateless mode, this is expected to fail
        logger.warning(f"No Bearer token provided for {request.url.path}")
        logger.debug(f" Request method: {request.method}")
        logger.debug(" Passing to MCP (will likely return 401)")
        return await call_next(request)

    # Also store the key in request state
    request.state.wandb_api_key = api_key
    response = await _call_next_with_api_key(request, call_next, api_key)
    # In stateless mode, the response session ID is a correlation ID only.
    response_session_id = response.headers.get("mcp-session-id")
    if response_session_id:
        logger.debug(f"Response includes correlation ID: {response_session_id[:8]}...")
    return response
# Lifespan handler: ties the MCP session manager to the application lifetime.
@contextlib.asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the MCP session manager while the FastAPI app is serving.

    The manager's ``run()`` context is entered before the app starts handling
    requests and exited when the lifespan context closes.
    """
    session_manager = mcp.session_manager
    async with session_manager.run():
        logger.info("MCP session manager started")
        yield
        logger.info("MCP session manager stopped")
# The main FastAPI application; the MCP session manager runs via `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="Weights & Biases MCP Server",
    description="Model Context Protocol server for W&B (Thread-Safe)",
)

# Fully permissive CORS: any origin, method and header is allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive. Confirm this is intended; request auth here uses the
# Authorization Bearer header rather than cookies.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Add request logging middleware for debugging
def _log_jsonrpc_body(body_bytes: bytes) -> None:
    """Log the JSON-RPC method/id, plus the tool name for ``tools/call``, from a request body.

    Bodies that are not valid JSON are logged (truncated to 100 bytes) at DEBUG.
    JSON that is not a single object (e.g. a batch array) is noted at DEBUG
    without field extraction.
    """
    import json

    try:
        body_json = json.loads(body_bytes)
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        logger.debug(f" Request body (non-JSON): {body_bytes[:100]}")
        return
    if not isinstance(body_json, dict):
        logger.debug(f" JSON request body is a {type(body_json).__name__}, not an object")
        return
    method = body_json.get("method", "unknown")
    request_id = body_json.get("id", "unknown")
    logger.info(f" JSON-RPC request: method={method}, id={request_id}")
    if method == "tools/call":
        params = body_json.get("params")
        tool_name = params.get("name", "unknown") if isinstance(params, dict) else "unknown"
        logger.info(f" Tool call request for: {tool_name}")


def _replay_body_receive(body_bytes: bytes, original_receive):
    """Return an ASGI ``receive`` that yields *body_bytes* once, then defers to *original_receive*.

    Deferring afterwards, instead of returning the body on every call, lets
    downstream code still observe ``http.disconnect`` from the real connection.
    """
    body_pending = True

    async def receive():
        nonlocal body_pending
        if body_pending:
            body_pending = False
            return {"type": "http.request", "body": body_bytes, "more_body": False}
        return await original_receive()

    return receive


@app.middleware("http")
async def logging_middleware(request, call_next):
    """Log every request and its response status and latency.

    For POST requests to ``/mcp`` or ``/``, the body is read so the JSON-RPC
    method (and tool name) can be logged. The request passed on is rebuilt
    with a receive channel that replays that body. 4xx responses are labelled
    WARNING and 5xx responses ERROR. Exceptions from downstream are logged and
    re-raised.
    """
    import time
    from starlette.requests import Request as StarletteRequest

    start_time = time.time()
    logger.info(f"Incoming request: {request.method} {request.url.path}")

    # Log MCP-specific headers
    mcp_session_id = request.headers.get("mcp-session-id")
    if mcp_session_id:
        logger.info(f" MCP Session ID in request: {mcp_session_id[:8]}...")

    if request.method == "POST" and request.url.path in ["/mcp", "/"]:
        body_bytes = None
        try:
            body_bytes = await request.body()
            if body_bytes:
                _log_jsonrpc_body(body_bytes)
            else:
                logger.debug(" No request body")
        except Exception as e:
            logger.debug(f" Could not read request body: {e}")
        if body_bytes is not None:
            # Once read, the body is gone from the underlying stream. Depending
            # on the Starlette version, call_next may read the body through this
            # request's receive channel, so pass on a request that replays the
            # body once and then defers to the real channel.
            request = StarletteRequest(request.scope, _replay_body_receive(body_bytes, request.receive))

    # Track if this is an MCP endpoint
    is_mcp = request.url.path.startswith("/mcp") or request.url.path == "/"
    try:
        response = await call_next(request)
    except Exception as e:
        logger.error(f"Error processing {request.method} {request.url.path}: {e}")
        raise

    process_time = time.time() - start_time
    status = response.status_code
    if status < 400:
        status_label = "SUCCESS"
    elif status < 500:
        status_label = "WARNING"  # client errors (4xx)
    else:
        status_label = "ERROR"  # server errors (5xx)
    logger.info(f"[{status_label}] Response: {request.method} {request.url.path} -> {status} ({process_time:.3f}s)")

    # Log detailed info for 404s
    if status == 404:
        logger.warning(f"404 Not Found for {request.url.path}")
        logger.debug(f" Full URL: {request.url}")
        logger.debug(" Available routes: /, /health, /favicon.ico, /favicon.png, /mcp")
        if is_mcp:
            logger.debug(" This appears to be an MCP endpoint that wasn't handled")
    return response
# Register Bearer-token authentication as HTTP middleware.
@app.middleware("http")
async def auth_middleware(request, call_next):
    """Apply Bearer-token authentication to MCP endpoints.

    Thin adapter that hooks ``thread_safe_auth_middleware`` into the app's HTTP
    middleware stack; all of the authentication logic lives there.
    """
    handler = thread_safe_auth_middleware
    response = await handler(request, call_next)
    return response
# Custom routes. They are registered before the MCP app is mounted at "/",
# so they are matched first.
@app.get("/", response_class=HTMLResponse)
async def index():
    """Return the landing page HTML (contents of index.html, read at startup)."""
    html = INDEX_HTML_CONTENT
    return html
# Favicon bytes, decoded once at import time instead of on every request.
_FAVICON_BYTES = base64.b64decode(FAVICON_BASE64)


def _favicon_response() -> Response:
    """Build the PNG favicon response shared by /favicon.ico and /favicon.png.

    ``media_type`` already sets the Content-Type header, so it is not repeated
    in ``headers``. The response is cacheable for one year.
    """
    return Response(
        content=_FAVICON_BYTES,
        media_type="image/png",
        headers={"Cache-Control": "public, max-age=31536000"},
    )


@app.get("/favicon.ico")
async def favicon():
    """Serve the official W&B logo favicon."""
    return _favicon_response()


@app.get("/favicon.png")
async def favicon_png():
    """Alternative PNG favicon endpoint for better browser compatibility."""
    return _favicon_response()
@app.get("/health")
async def health():
    """Health check endpoint.

    Returns a dict with these keys:

    - ``status``: always "healthy".
    - ``service``: the service name.
    - ``wandb_configured``: whether the server-wide W&B key was validated at startup.
    - ``tools_registered``: number of MCP tools, or 0 if listing them fails.
    - ``authentication``: "disabled" when MCP_AUTH_DISABLED=true, otherwise "enabled".
    - ``worker_info``: the process id and current thread name.
    """
    try:
        tool_count = len(await mcp.list_tools())
    except Exception as e:
        # Keep the health check answering when tool listing fails, but record
        # why. A bare `except:` would also swallow task cancellation
        # (asyncio.CancelledError is a BaseException) and KeyboardInterrupt.
        logger.warning(f"Health check could not list MCP tools: {e}")
        tool_count = 0

    auth_status = "disabled" if os.environ.get("MCP_AUTH_DISABLED", "false").lower() == "true" else "enabled"

    # Include worker information for debugging
    worker_info = {
        "pid": os.getpid(),
        "thread_id": threading.current_thread().name,
    }
    return {
        "status": "healthy",
        "service": "wandb-mcp-server",
        "wandb_configured": wandb_configured,
        "tools_registered": tool_count,
        "authentication": auth_status,
        "worker_info": worker_info,
    }
# Build the MCP streamable-HTTP ASGI app. It is mounted at "/" further down,
# after the custom routes (/, /health, favicons) are registered, so it only
# receives requests that none of those routes match.
mcp_app = mcp.streamable_http_app()
logger.info("Mounting MCP streamable HTTP app at root /")
logger.info("Note: MCP will handle all unmatched routes, returning 404 for non-MCP requests")
# For debugging: Log incoming requests to understand routing
@app.middleware("http")
async def mcp_routing_debug(request, call_next):
    """Log, at DEBUG level, how each request is expected to be routed; the request passes through unchanged.

    A request is classified as an MCP protocol request when its media type is
    ``application/json`` and its Accept header mentions ``text/event-stream``
    or ``application/json``.
    """
    path = request.url.path
    method = request.method

    # Compare only the media type, so parameters such as "; charset=utf-8" do
    # not prevent a JSON request from being recognised.
    content_type = request.headers.get("Content-Type", "").split(";", 1)[0].strip().lower()
    accept = request.headers.get("Accept", "")
    is_mcp_request = content_type == "application/json" and (
        "text/event-stream" in accept or "application/json" in accept
    )

    if path == "/" and method == "GET":
        logger.debug("Root GET request - should show landing page")
    elif path == "/health" and method == "GET":
        logger.debug("Health check request")
    elif path in ("/", "/mcp") and is_mcp_request:
        logger.debug(f"MCP protocol request detected on {path}")
    elif path == "/" and method == "POST" and not is_mcp_request:
        # GET "/" is always caught by the first branch, so only POST can get here.
        logger.debug(f"Non-MCP {method} request to root - may get 404 from MCP app")
    return await call_next(request)
# Mount the MCP app last. Starlette matches routes in registration order, so
# the explicit routes registered earlier take precedence over this catch-all.
app.mount("/", mcp_app)

# Port for HF Spaces (overridable through the PORT environment variable).
PORT = int(os.environ.get("PORT", "7860"))

if __name__ == "__main__":
    import uvicorn

    startup_messages = (
        f"Starting server on 0.0.0.0:{PORT}",
        "Landing page: /",
        "Health check: /health",
        "MCP endpoint: /mcp",
        "Starting server (stateless mode - supports horizontal scaling)",
    )
    for message in startup_messages:
        logger.info(message)

    # Single worker for HF Spaces. NOTE: uvicorn only honours workers > 1 when
    # the application is passed as an import string ("module:app"), not as an
    # app object as done here.
    uvicorn.run(app, host="0.0.0.0", port=PORT, workers=1)
|