File size: 16,172 Bytes
f647629
 
40e1a91
f647629
 
 
 
 
04c12c4
f647629
1ec3391
 
f647629
 
 
 
2d37ec5
 
 
 
 
 
 
 
e2aaee8
0d796a8
ceeb737
04c12c4
0d796a8
04c12c4
 
 
 
1ec3391
04c12c4
 
 
 
 
f647629
1ec3391
 
a2dc155
2d37ec5
f647629
 
 
 
a643202
f647629
1ec3391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
561151f
 
 
 
ceeb737
1ec3391
0d796a8
 
 
 
04c12c4
 
 
 
 
 
 
 
 
 
 
 
a2dc155
 
 
1ec3391
a2dc155
 
 
 
 
 
 
04c12c4
40e1a91
 
 
 
 
04c12c4
 
1ec3391
04c12c4
 
1ec3391
 
 
40e1a91
 
 
 
 
 
1ec3391
40e1a91
1ec3391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40e1a91
0783971
40e4410
40e1a91
1ec3391
 
 
 
40e1a91
1ec3391
 
40e1a91
1ec3391
 
 
 
 
40e1a91
 
 
 
1ec3391
40e1a91
1ec3391
40e1a91
1ec3391
 
 
 
 
 
 
 
 
 
 
 
 
40e1a91
0783971
 
40e1a91
1ec3391
 
 
 
 
 
40e1a91
 
 
 
1ec3391
 
 
 
 
 
 
 
 
 
04c12c4
 
 
 
 
 
 
 
 
 
2d37ec5
 
1ec3391
04c12c4
2d37ec5
 
 
 
 
 
 
 
 
 
 
40e4410
 
 
 
 
 
 
 
 
0783971
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40e4410
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ec3391
a2dc155
 
1ec3391
 
a2dc155
04c12c4
2d37ec5
 
 
 
 
0d796a8
 
1ec3391
0d796a8
 
 
 
1ec3391
 
0d796a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d37ec5
 
 
04c12c4
 
 
 
 
 
a2dc155
 
1ec3391
 
 
 
 
 
2d37ec5
 
 
a643202
a2dc155
1ec3391
 
2d37ec5
 
04c12c4
40e4410
 
 
04c12c4
40e4410
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04c12c4
a643202
04c12c4
 
2d37ec5
04c12c4
 
 
2d37ec5
 
04c12c4
1ec3391
40e1a91
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
#!/usr/bin/env python3
"""
Thread-safe entry point for the Weights & Biases MCP Server.
"""

import os
import sys
import logging
import contextlib
from pathlib import Path
import threading
import wandb

# Make the bundled "src" directory importable so `wandb_mcp_server` resolves.
sys.path.insert(0, str(Path(__file__).parent / "src"))

# Point W&B / Weave at writable /tmp locations (for HF Spaces) and silence their console output.
# NOTE(review): `import wandb` already executes earlier in this file, so these variables
# only take effect if wandb reads them lazily (e.g. at login/init) — confirm, or move the
# wandb import below this block.
_HF_SPACES_ENV = {
    "WANDB_CACHE_DIR": "/tmp/.wandb_cache",
    "WANDB_CONFIG_DIR": "/tmp/.wandb_config",
    "WANDB_DATA_DIR": "/tmp/.wandb_data",
    "HOME": "/tmp",
    "WANDB_SILENT": "True",
    "WEAVE_SILENT": "True",
}
os.environ.update(_HF_SPACES_ENV)

from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse, Response
from fastapi.middleware.cors import CORSMiddleware
from mcp.server.fastmcp import FastMCP
import base64

# Import W&B setup functions
from wandb_mcp_server.server import (
    validate_and_get_api_key,
    validate_api_key,
    configure_wandb_logging,
    initialize_weave_tracing,
    register_tools,
    ServerMCPArgs
)

# Import the new API client manager
from wandb_mcp_server.api_client import WandBApiManager

# Root logging: INFO level with timestamp, logger name and level on every line.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Named logger used throughout this module.
logger = logging.getLogger("wandb-mcp-server")

# Request-scoped API keys are handled by WandBApiManager, whose context value is
# set and reset around each request by the authentication middleware.

# Per-thread cache of client records, keyed by API key.
# NOTE(review): under asyncio every request served on the event-loop thread shares
# this one cache, and entries are never evicted — it grows with each distinct key.
thread_local = threading.local()


def get_thread_local_wandb_client(api_key: str):
    """
    Return this thread's cached client record for ``api_key``, creating it on first use.

    The record is a plain dict ``{"api_key": api_key, "initialized": True}``; no W&B
    client object is constructed here. Calling again on the same thread with the same
    key returns the very same dict object.
    """
    cache = getattr(thread_local, "clients", None)
    if cache is None:
        # First call on this thread: attach an empty cache to the thread-local.
        cache = thread_local.clients = {}
    return cache.setdefault(api_key, {"api_key": api_key, "initialized": True})

# Landing-page HTML, read once at import time from index.html next to this file.
# Decoded explicitly as UTF-8: open() without encoding= uses the locale's preferred
# encoding, so non-ASCII characters in the page could fail or be garbled on some hosts.
INDEX_HTML_PATH = Path(__file__).parent / "index.html"
INDEX_HTML_CONTENT = INDEX_HTML_PATH.read_text(encoding="utf-8")

# W&B logo favicon: base64-encoded PNG data, decoded and served by the
# /favicon.ico and /favicon.png routes.
WANDB_FAVICON_BASE64 = """iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAMAAABEpIrGAAAAUVBMVEUAAAD/zzD/zzD/zzD/zjH/yzD/zDP/zDP/zTL/zDP/zTL/yzL/yzL/zDL/zDL/zDP/zDP/zDP/zDP/yzL/yzP/zDL/zDL/zDL/zDL/zDP/zDNs+ITNAAAAGnRSTlMAECAwP0BQX2BvcICPkJ+gr7C/wM/Q3+Dv8ORN9PUAAAEOSURBVBgZfcEJkpswAADBEVphB0EwzmJg/v/QcKbKC3E3FI/xN5fa8VEAjRq5ENUGaNXIhai2QBrsOJTf3yWHziHxw6AvPpl04pOsmXehfvksOYTAoXz6qgONi8hJdNEwuMicZBcvXGVOsit6FxWboq4LNpWLntLZFNj0+s0mTM5KSLmpAjtn7ELV5MQPnXZ8VJacxFvgUrhFZnc1cCGod6BTE7t7Xd/YJbUDKjWw6Zw92AS1AsK9SWyiq4JNau6BN8lV4n+Sq8Sb8PXri93gbOBNGtUnm6Kbpq7gUDDrXFRc6B0TuMqcJbWFyUXmLKoNtC4SmzyOmUMztAUUf9TMbtKRk8g/gw58UvZ9yZu/MeoYEFwSwuAAAAAASUVORK5CYII=""".strip()

# Alias read by the favicon route handlers.
FAVICON_BASE64 = WANDB_FAVICON_BASE64

# Server-side W&B setup: configure logging, then try to validate an optional
# server-wide API key (WANDB_API_KEY). Failure is non-fatal — clients can still
# authenticate with their own keys.
logger.info("Initializing W&B configuration...")
configure_wandb_logging()

args = ServerMCPArgs(
    transport="http",
    host="0.0.0.0",
    port=7860,
    wandb_api_key=os.environ.get("WANDB_API_KEY"),
)

# Flipped to True only when the server key validates and Weave tracing starts;
# reported by the /health endpoint.
wandb_configured = False
api_key = validate_and_get_api_key(args)
if not api_key:
    logger.info("No server W&B API key configured - clients will provide their own")
else:
    try:
        validate_api_key(api_key)
        initialize_weave_tracing()
    except Exception as e:
        logger.warning(f"Failed to configure server W&B API key: {e}")
    else:
        wandb_configured = True
        logger.info("Server W&B API key configured successfully")

# Stateless HTTP MCP server: nothing is persisted between requests, so every client
# (OpenAI, Cursor, etc.) must send its Bearer token on each request, and any session
# ID it sends serves only as a correlation ID.
logger.info("Creating W&B MCP server in stateless HTTP mode...")
mcp = FastMCP("wandb-mcp-server", stateless_http=True)

# Attach all W&B tools. They are expected to look up the calling request's key via
# WandBApiManager (presumably get_api_key()), which the auth middleware populates.
register_tools(mcp)

# Custom authentication middleware
async def thread_safe_auth_middleware(request: Request, call_next):
    """
    Stateless Bearer-token authentication for MCP endpoints.

    Every request to a ``/mcp`` path is expected to carry its own
    ``Authorization: Bearer <W&B API key>`` header. Nothing is stored between
    requests; an ``Mcp-Session-Id`` header is only logged as a correlation ID.

    When a key of plausible length (20-100 characters) is present, it is stored on
    ``request.state.wandb_api_key`` and published through
    ``WandBApiManager.set_context_api_key`` for the duration of the downstream
    call, then reset.

    Args:
        request: The incoming request.
        call_next: The next handler in the middleware chain.

    Returns:
        The downstream response, or a 401 ``JSONResponse`` when the bearer token
        is malformed or the auth headers cannot be processed.

    Exceptions raised by downstream handlers propagate unchanged; they are not
    converted into 401 authentication failures.
    """
    # Only apply auth to MCP endpoints
    if not request.url.path.startswith("/mcp"):
        return await call_next(request)

    # Skip auth if explicitly disabled (development only)
    if os.environ.get("MCP_AUTH_DISABLED", "false").lower() == "true":
        logger.warning("MCP authentication is disabled - endpoints are publicly accessible")
        env_key = os.environ.get("WANDB_API_KEY")
        if not env_key:
            return await call_next(request)
        token = WandBApiManager.set_context_api_key(env_key)
        try:
            return await call_next(request)
        finally:
            WandBApiManager.reset_context_api_key(token)

    # Only header parsing is guarded. call_next() stays outside this try so errors
    # raised by MCP handlers surface as server errors, not as bogus 401s.
    try:
        api_key = None

        # Check if request has MCP session ID (correlation ID only in stateless mode)
        session_id = request.headers.get("Mcp-Session-Id") or request.headers.get("mcp-session-id")
        if session_id:
            logger.debug(f"Request has correlation ID: {session_id[:8]}...")

        # RFC 7235: the auth-scheme name is case-insensitive ("Bearer", "bearer", ...).
        scheme, separator, credentials = request.headers.get("Authorization", "").partition(" ")
        if separator and scheme.lower() == "bearer":
            bearer_token = credentials.strip()

            # Basic length sanity check before trusting the token as a W&B key
            if not 20 <= len(bearer_token) <= 100:
                return JSONResponse(
                    status_code=401,
                    content={"error": "Invalid W&B API key format. Get your key at: https://wandb.ai/authorize"},
                    headers={"WWW-Authenticate": 'Bearer realm="W&B MCP", error="invalid_token"'},
                )

            api_key = bearer_token
            logger.info("Using Bearer token for authentication")
    except Exception as e:
        logger.error(f"Authentication error: {e}")
        return JSONResponse(
            status_code=401,
            content={"error": "Authentication failed"},
            headers={"WWW-Authenticate": 'Bearer realm="W&B MCP"'},
        )

    # Handle session cleanup (stateless mode - just acknowledge and pass through)
    if request.method == "DELETE" and session_id:
        logger.debug(f"Session cleanup: DELETE for {session_id[:8]}... (stateless - no action needed)")
        return await call_next(request)

    if not api_key:
        # NOTE(review): requests without a key are forwarded rather than rejected here;
        # whether they fail depends on the downstream MCP app/tools — confirm.
        logger.warning(f"No Bearer token provided for {request.url.path}")
        logger.debug(f"   Request method: {request.method}")
        logger.debug("   Passing to MCP (will likely return 401)")
        return await call_next(request)

    # Record the key on request state before setting the context variable, so a
    # failure here can never leave the context variable set without a reset.
    request.state.wandb_api_key = api_key
    token = WandBApiManager.set_context_api_key(api_key)
    try:
        response = await call_next(request)

        # In stateless mode, we don't store any session state
        response_session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id")
        if response_session_id:
            logger.debug(f"Response includes correlation ID: {response_session_id[:8]}...")

        return response
    finally:
        WandBApiManager.reset_context_api_key(token)

# Lifespan handler that ties the MCP session manager to the app's lifetime
@contextlib.asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Keep ``mcp.session_manager`` running for as long as the application serves.

    Startup enters the session manager's ``run()`` context and shutdown leaves it.
    ``app`` is part of FastAPI's lifespan signature and is not used here.
    """
    manager = mcp.session_manager
    async with manager.run():
        logger.info("MCP session manager started")
        yield
        logger.info("MCP session manager stopped")

# Main FastAPI application; the lifespan handler defined above runs the MCP
# session manager for the app's lifetime.
app = FastAPI(
    lifespan=lifespan,
    title="Weights & Biases MCP Server",
    description="Model Context Protocol server for W&B (Thread-Safe)",
)

# CORS: allow any origin, method and header, with credentials.
# NOTE(review): with allow_credentials=True Starlette may echo the caller's Origin
# instead of "*", letting any site make credentialed requests. Auth here is a Bearer
# header, which does not need credentials mode — confirm whether this can be False.
# NOTE(review): Starlette runs the most recently added middleware outermost, so the
# @app.middleware handlers registered later wrap this one; 401s from the auth
# middleware may therefore lack CORS headers for browser clients — confirm.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# Add request logging middleware for debugging
def _log_jsonrpc_payload(body_bytes: bytes) -> None:
    """
    Log the JSON-RPC method and id (plus the tool name for ``tools/call``) in a body.

    Accepts a single JSON-RPC message or a batch (JSON array); array entries that
    are not objects are skipped. Bodies that are not valid JSON are logged at DEBUG
    level (first 100 bytes) instead of raising.
    """
    import json

    try:
        payload = json.loads(body_bytes)
    except ValueError:  # JSONDecodeError, or UnicodeDecodeError for non-UTF bodies
        logger.debug(f"   Request body (non-JSON): {body_bytes[:100]}")
        return

    messages = payload if isinstance(payload, list) else [payload]
    for message in messages:
        # Scalars / arrays carry no JSON-RPC fields; calling .get() on them would raise.
        if not isinstance(message, dict):
            continue
        method = message.get("method", "unknown")
        request_id = message.get("id", "unknown")
        logger.info(f"   JSON-RPC request: method={method}, id={request_id}")
        if method == "tools/call":
            params = message.get("params")
            tool_name = params.get("name", "unknown") if isinstance(params, dict) else "unknown"
            logger.info(f"   Tool call request for: {tool_name}")


@app.middleware("http")
async def logging_middleware(request, call_next):
    """
    Log every request and its response status and latency, for debugging.

    For POST requests to ``/mcp`` or ``/`` the body is read so its JSON-RPC method
    can be logged. The request is then rebuilt around the already-read body for
    the downstream call. Exceptions from downstream are logged and re-raised.
    """
    import time
    from starlette.requests import Request as StarletteRequest

    # perf_counter is monotonic, so latency stays correct if the wall clock jumps.
    start_time = time.perf_counter()

    logger.info(f"Incoming request: {request.method} {request.url.path}")

    # Log MCP-specific headers
    mcp_session_id = request.headers.get("mcp-session-id")
    if mcp_session_id:
        logger.info(f"   MCP Session ID in request: {mcp_session_id[:8]}...")

    if request.method == "POST" and request.url.path in ("/mcp", "/"):
        try:
            body_bytes = await request.body()
        except Exception as e:
            logger.debug(f"   Could not read request body: {e}")
        else:
            if body_bytes:
                _log_jsonrpc_payload(body_bytes)

                # Hand downstream a request whose receive() replays the body that
                # was consumed above; the original ASGI scope is reused unmodified.
                async def receive():
                    return {"type": "http.request", "body": body_bytes}

                request = StarletteRequest(request.scope, receive)
            else:
                logger.debug("   No request body")

    is_mcp = request.url.path.startswith("/mcp") or request.url.path == "/"

    try:
        response = await call_next(request)
    except Exception as e:
        logger.error(f"Error processing {request.method} {request.url.path}: {e}")
        raise

    process_time = time.perf_counter() - start_time
    status_code = response.status_code

    # Server errors (5xx) -> ERROR, client errors (4xx) -> WARNING, otherwise SUCCESS.
    if status_code >= 500:
        status_label = "ERROR"
    elif status_code >= 400:
        status_label = "WARNING"
    else:
        status_label = "SUCCESS"
    logger.info(f"[{status_label}] Response: {request.method} {request.url.path} -> {status_code} ({process_time:.3f}s)")

    # Log detailed info for 404s
    if status_code == 404:
        logger.warning(f"404 Not Found for {request.url.path}")
        logger.debug(f"   Full URL: {request.url}")
        logger.debug("   Available routes: /, /health, /favicon.ico, /favicon.png, /mcp")
        if is_mcp:
            logger.debug("   This appears to be an MCP endpoint that wasn't handled")

    return response

# Register the Bearer-token authentication middleware on the app
@app.middleware("http")
async def auth_middleware(request, call_next):
    """Run thread_safe_auth_middleware (W&B API key as Bearer token, /mcp paths only) for each HTTP request."""
    response = await thread_safe_auth_middleware(request, call_next)
    return response

# Explicit routes, registered before the catch-all MCP mount so they match first
@app.get("/", response_class=HTMLResponse)
async def index():
    """Return the landing-page HTML that was loaded from index.html at startup."""
    page = INDEX_HTML_CONTENT
    return page

# Favicon bytes decoded once at import time rather than on every request.
_FAVICON_PNG_BYTES = base64.b64decode(FAVICON_BASE64)


def _favicon_response() -> Response:
    """Build the PNG favicon response shared by both favicon routes (cacheable for one year)."""
    # media_type alone sets "Content-Type: image/png"; a duplicate header is unnecessary.
    return Response(
        content=_FAVICON_PNG_BYTES,
        media_type="image/png",
        headers={"Cache-Control": "public, max-age=31536000"},
    )


@app.get("/favicon.ico")
async def favicon():
    """Serve the official W&B logo favicon."""
    return _favicon_response()


@app.get("/favicon.png")
async def favicon_png():
    """Alternative PNG favicon endpoint for better browser compatibility."""
    return _favicon_response()

@app.get("/health")
async def health():
    """
    Health check endpoint.

    Returns a JSON-serializable dict with:
      - a static ``"healthy"`` status and the service name,
      - whether the server-side W&B key was configured at startup,
      - the number of registered MCP tools (0 if they cannot be listed),
      - whether MCP auth is enabled (``MCP_AUTH_DISABLED`` env var),
      - the serving process id and current thread name.
    """
    try:
        tool_count = len(await mcp.list_tools())
    except Exception as e:
        # A bare ``except:`` would also swallow asyncio.CancelledError and
        # KeyboardInterrupt; only ordinary errors should degrade to a zero count.
        logger.warning(f"Health check could not list MCP tools: {e}")
        tool_count = 0

    auth_status = "disabled" if os.environ.get("MCP_AUTH_DISABLED", "false").lower() == "true" else "enabled"

    # Include worker information for debugging ("thread_id" holds the thread's name)
    worker_info = {
        "pid": os.getpid(),
        "thread_id": threading.current_thread().name,
    }

    return {
        "status": "healthy",
        "service": "wandb-mcp-server",
        "wandb_configured": wandb_configured,
        "tools_registered": tool_count,
        "authentication": auth_status,
        "worker_info": worker_info,
    }

# Build the MCP streamable-HTTP ASGI app. It is mounted at "/" further below, after
# the explicit routes (/, /health, favicons), so those routes are matched first and
# every other path falls through to the MCP app.
mcp_app = mcp.streamable_http_app()
for startup_note in (
    "Mounting MCP streamable HTTP app at root /",
    "Note: MCP will handle all unmatched routes, returning 404 for non-MCP requests",
):
    logger.info(startup_note)

# For debugging: Log incoming requests to understand routing
@app.middleware("http")
async def mcp_routing_debug(request, call_next):
    """
    Debug-log how each request is expected to be routed, then pass it on unchanged.

    A request "looks like MCP" when its Content-Type media type is
    ``application/json`` and its Accept header mentions ``text/event-stream`` or
    ``application/json``. Only DEBUG log lines are emitted; the request and the
    response are never modified.
    """
    path = request.url.path
    method = request.method

    # Compare only the media type: clients commonly append parameters such as
    # "; charset=utf-8", which an exact string comparison would reject.
    media_type = request.headers.get("Content-Type", "").partition(";")[0].strip().lower()
    accept = request.headers.get("Accept", "")
    is_mcp_request = media_type == "application/json" and (
        "text/event-stream" in accept or "application/json" in accept
    )

    if path == "/" and method == "GET":
        logger.debug("Root GET request - should show landing page")
    elif path == "/health" and method == "GET":
        logger.debug("Health check request")
    elif path in ("/", "/mcp") and is_mcp_request:
        logger.debug(f"MCP protocol request detected on {path}")
    elif path == "/" and method in ("POST", "GET") and not is_mcp_request:
        logger.debug(f"Non-MCP {method} request to root - may get 404 from MCP app")

    return await call_next(request)

# Mount the MCP app last so the explicit routes registered above take precedence
# over this catch-all mount at "/".
app.mount("/", mcp_app)

# Listening port: 7860 (HF Spaces default) unless overridden by the PORT env var.
PORT = int(os.environ.get("PORT", "7860"))

if __name__ == "__main__":
    import uvicorn

    startup_messages = (
        f"Starting server on 0.0.0.0:{PORT}",
        "Landing page: /",
        "Health check: /health",
        "MCP endpoint: /mcp",
        # Stateless mode keeps no per-session state, so no sticky routing is needed.
        "Starting server (stateless mode - supports horizontal scaling)",
    )
    for message in startup_messages:
        logger.info(message)

    # Single worker for HuggingFace Spaces.
    # NOTE(review): uvicorn honors workers > 1 only when given an import string such as
    # "module:app", not an app object, so raising this number alone will not work.
    uvicorn.run(app, host="0.0.0.0", port=PORT, workers=1)