Spaces:
Running
Running
| import base64 | |
| import hashlib | |
| import hmac | |
| import json | |
| import os | |
| import secrets | |
| import time | |
| import uuid | |
| from typing import Any, Dict, Optional | |
| from urllib.parse import urlencode, parse_qs, urlsplit | |
| from fastmcp import Client, FastMCP | |
| from fastmcp.utilities.types import Image | |
| from mcp.types import CallToolResult | |
| from starlette.middleware.base import BaseHTTPMiddleware | |
| from starlette.requests import Request | |
| from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse, RedirectResponse, Response | |
APP_NAME = "mcp-bridge"


def _env_int(name: str, default: str) -> int:
    """Read an integer setting from the environment, falling back to *default*."""
    return int(os.getenv(name, default))


def _env_str(name: str, default: str = "") -> str:
    """Read a string setting from the environment with surrounding whitespace removed."""
    return os.getenv(name, default).strip()


# Token / authorization-code lifetimes, in seconds.
ACCESS_TOKEN_TTL_SECONDS = _env_int("ACCESS_TOKEN_TTL_SECONDS", "3600")
REFRESH_TOKEN_TTL_SECONDS = _env_int("REFRESH_TOKEN_TTL_SECONDS", "2592000")
AUTH_CODE_TTL_SECONDS = _env_int("AUTH_CODE_TTL_SECONDS", "300")

# Manually configured OAuth client; an empty secret means public-client mode.
OAUTH_STATIC_CLIENT_ID = _env_str("OAUTH_STATIC_CLIENT_ID", "chatgpt")
OAUTH_STATIC_CLIENT_SECRET = _env_str("OAUTH_STATIC_CLIENT_SECRET")

# Default credentials forwarded to remote MCP servers when the caller supplies none.
REMOTE_MCP_BEARER_TOKEN = _env_str("REMOTE_MCP_BEARER_TOKEN")
REMOTE_MCP_X_API_KEY = _env_str("REMOTE_MCP_X_API_KEY")

mcp = FastMCP(APP_NAME)

# Plain module-level dicts: all state lives only in this process's memory.
# session_id -> {"url": ..., "headers": ...}
SESSIONS: Dict[str, Dict[str, Any]] = {}
# client_id -> registration record created by /oauth/register
OAUTH_CLIENTS: Dict[str, Dict[str, Any]] = {}
# authorization code -> grant details (client, redirect_uri, PKCE, expiry)
AUTH_CODES: Dict[str, Dict[str, Any]] = {}
| def _normalize_remote_url(url: str) -> str: | |
| raw = str(url or "").strip() | |
| if not raw: | |
| raise ValueError("Remote MCP URL is required") | |
| parsed = urlsplit(raw) | |
| if not parsed.scheme: | |
| scheme = "http" if raw.startswith(("localhost", "127.0.0.1", "0.0.0.0", "[::1]")) else "https" | |
| raw = f"{scheme}://{raw}" | |
| parsed = urlsplit(raw) | |
| elif parsed.scheme not in {"http", "https"} and not parsed.netloc and ":" in raw: | |
| # urlsplit("localhost:4000/mcp") treats "localhost" as the scheme. | |
| raw = f"http://{raw}" | |
| parsed = urlsplit(raw) | |
| if parsed.scheme not in {"http", "https"} or not parsed.netloc: | |
| raise ValueError( | |
| "Invalid remote MCP URL. Use a full HTTP(S) URL such as " | |
| "https://example.com/sse or http://127.0.0.1:4000/mcp" | |
| ) | |
| return raw | |
| def _remote_headers( | |
| headers: Optional[Dict[str, str]] = None, | |
| bearer_token: str = "", | |
| x_api_key: str = "", | |
| ) -> Dict[str, str]: | |
| merged = {str(key): str(value) for key, value in (headers or {}).items() if value is not None} | |
| if bearer_token: | |
| merged["Authorization"] = f"Bearer {bearer_token}" | |
| elif REMOTE_MCP_BEARER_TOKEN and not any(key.lower() == "authorization" for key in merged): | |
| merged["Authorization"] = f"Bearer {REMOTE_MCP_BEARER_TOKEN}" | |
| if x_api_key: | |
| merged["x-api-key"] = x_api_key | |
| elif REMOTE_MCP_X_API_KEY and not any(key.lower() == "x-api-key" for key in merged): | |
| merged["x-api-key"] = REMOTE_MCP_X_API_KEY | |
| return merged | |
def _prioritize_image_content(result: CallToolResult) -> CallToolResult:
    """Return *result* with image content items moved ahead of all other items.

    Relative order inside the image group and the non-image group is kept.
    The original object is returned unchanged when there are fewer than two
    items, no images, or the images already come first; otherwise a copy with
    reordered ``content`` is returned via ``model_copy``.
    """
    items = list(result.content or [])
    if len(items) < 2:
        return result
    images: list = []
    others: list = []
    for item in items:
        bucket = images if getattr(item, "type", None) == "image" else others
        bucket.append(item)
    if not images:
        return result
    reordered = images + others
    if reordered == items:
        return result
    return result.model_copy(update={"content": reordered})
def _native_tool_outputs(result: CallToolResult) -> list[Any]:
    """Convert a remote tool result into values FastMCP can return natively.

    - If any image items exist, ONLY the images are returned, as FastMCP
      ``Image`` objects (base64 payload decoded, format taken from the MIME
      subtype, defaulting to png); all other content is dropped.
    - Otherwise text items become plain strings, objects with ``model_dump``
      become pretty-printed JSON, and anything else is ``str()``-ed.
    - If the result has no content at all, the whole result is dumped as JSON.
    """
    ordered = _prioritize_image_content(result)
    images: list[Any] = []
    others: list[Any] = []
    for item in ordered.content or []:
        kind = getattr(item, "type", None)
        if kind == "image":
            mime = getattr(item, "mimeType", None) or "image/png"
            _, slash, subtype = mime.partition("/")
            images.append(Image(data=base64.b64decode(item.data), format=subtype if slash else "png"))
        elif kind == "text":
            others.append(item.text)
        elif hasattr(item, "model_dump"):
            others.append(json.dumps(item.model_dump(), ensure_ascii=False, indent=2))
        else:
            others.append(str(item))
    if images:
        return images
    return others or [json.dumps(ordered.model_dump(), ensure_ascii=False, indent=2)]
| def _sanitized_tool_descriptor(tool: Any) -> Dict[str, Any]: | |
| if hasattr(tool, "model_dump"): | |
| descriptor = tool.model_dump() | |
| elif isinstance(tool, dict): | |
| descriptor = dict(tool) | |
| else: | |
| return {"name": str(tool)} | |
| descriptor.pop("outputSchema", None) | |
| return descriptor | |
| def _b64url(data: bytes) -> str: | |
| return base64.urlsafe_b64encode(data).decode("ascii").rstrip("=") | |
| def _b64url_decode(data: str) -> bytes: | |
| return base64.urlsafe_b64decode(data + "=" * (-len(data) % 4)) | |
| def _signing_secret() -> str: | |
| secret = os.getenv("MCP_OAUTH_SIGNING_KEY", "").strip() or os.getenv("MCP_API_KEY", "").strip() | |
| if not secret: | |
| # Keep metadata available even before secrets are configured, but tokens cannot be issued safely. | |
| secret = "development-only-change-me" | |
| return secret | |
def _sign(value: str) -> str:
    """Return the unpadded URL-safe base64 HMAC-SHA256 of *value* under the signing secret."""
    key = _signing_secret().encode("utf-8")
    mac = hmac.new(key, value.encode("utf-8"), hashlib.sha256)
    return _b64url(mac.digest())
def _make_token(kind: str, audience: str, ttl: int, scope: str = "mcp") -> str:
    """Mint a stateless signed token: ``<b64url(compact JSON payload)>.<b64url(HMAC)>``.

    Args:
        kind: stored as ``typ`` (the call sites use "access" and "refresh").
        audience: stored as ``aud``; access tokens are later accepted only
            when it matches this server's ``/sse`` resource URL or base URL.
        ttl: lifetime in seconds, added to the issue time to form ``exp``.
        scope: stored verbatim as ``scope``.
    """
    issued_at = int(time.time())
    claims = dict(
        typ=kind,
        # NOTE: "iss" is always null; nothing in this file fills or checks it.
        iss=None,
        sub="owner",
        aud=audience,
        scope=scope,
        iat=issued_at,
        exp=issued_at + ttl,
        jti=secrets.token_urlsafe(16),
    )
    encoded = _b64url(json.dumps(claims, separators=(",", ":")).encode("utf-8"))
    return encoded + "." + _sign(encoded)
def _decode_token(token: str) -> Optional[Dict[str, Any]]:
    """Verify and decode a token produced by ``_make_token``.

    Returns the payload dict, or None when the token is malformed, its
    signature does not match, or its ``exp`` lies in the past. Any error
    raised while parsing or verifying is turned into None.
    """
    try:
        body, signature = token.rsplit(".", 1)
        if hmac.compare_digest(signature, _sign(body)):
            claims = json.loads(_b64url_decode(body))
            if int(claims.get("exp", 0)) >= int(time.time()):
                return claims
        return None
    except Exception:
        return None
def _base_url(request: Request) -> str:
    """Return this server's public origin (``scheme://host``, no trailing slash).

    Resolution order:
      1. PUBLIC_BASE_URL, verbatim.
      2. SPACE_HOST, prefixed with ``https://`` unless it already has a scheme.
      3. SPACE_ID ``owner/name`` -> ``https://owner-name.hf.space``.
      4. The request's X-Forwarded-Proto/-Host, Host header, or URL.

    NOTE: step 4 trusts forwarded headers as received; set PUBLIC_BASE_URL when
    the deployment's proxy does not sanitize them.
    """

    def _first(value: Optional[str]) -> str:
        # Proxy chains append to X-Forwarded-* as a comma-separated list; the
        # first entry is the one the original client used.
        return (value or "").split(",", 1)[0].strip()

    explicit = os.getenv("PUBLIC_BASE_URL", "").strip().rstrip("/")
    if explicit:
        return explicit
    space_host = os.getenv("SPACE_HOST", "").strip().rstrip("/")
    if space_host:
        # Check for a real scheme separator: a bare host such as "httpbin.example"
        # starts with "http" but still needs https:// prepended.
        return space_host if "://" in space_host else f"https://{space_host}"
    space_id = os.getenv("SPACE_ID", "").strip()
    if "/" in space_id:
        owner, name = space_id.split("/", 1)
        return f"https://{owner}-{name}.hf.space"
    proto = _first(request.headers.get("x-forwarded-proto")) or request.url.scheme or "https"
    host = (
        _first(request.headers.get("x-forwarded-host"))
        or request.headers.get("host")
        or request.url.netloc
    )
    return f"{proto}://{host}".rstrip("/")
def _resource_uri(request: Request) -> str:
    """Return the protected MCP resource URL: the public base URL plus ``/sse``."""
    return _base_url(request) + "/sse"
def _metadata_url(request: Request) -> str:
    """Return the absolute URL of the OAuth protected-resource metadata document."""
    return _base_url(request) + "/.well-known/oauth-protected-resource"
def _auth_server_metadata(request: Request) -> Dict[str, Any]:
    """Build the OAuth authorization-server metadata document.

    Every endpoint URL is absolute and rooted at ``_base_url(request)``, which
    is also the advertised ``issuer``.
    """
    base = _base_url(request)

    def endpoint(suffix: str) -> str:
        return f"{base}/oauth/{suffix}"

    return {
        "issuer": base,
        "authorization_endpoint": endpoint("authorize"),
        "token_endpoint": endpoint("token"),
        "registration_endpoint": endpoint("register"),
        "revocation_endpoint": endpoint("revoke"),
        "jwks_uri": endpoint("jwks"),
        "response_types_supported": ["code"],
        "grant_types_supported": ["authorization_code", "refresh_token"],
        # NOTE(review): MCP clients are expected to use S256; advertising "plain"
        # is kept for compatibility — confirm it is still wanted.
        "code_challenge_methods_supported": ["S256", "plain"],
        "token_endpoint_auth_methods_supported": ["none", "client_secret_post", "client_secret_basic"],
        "scopes_supported": ["mcp"],
        # ChatGPT may show manual fields when it cannot use dynamic client registration.
        # In that mode use OAUTH_STATIC_CLIENT_ID (default: chatgpt) and token auth method "none".
        "client_registration_types_supported": ["dynamic", "manual"],
        "resource_indicators_supported": True,
    }
def _protected_resource_metadata(request: Request) -> Dict[str, Any]:
    """Build the protected-resource metadata for the ``/sse`` endpoint.

    The server names itself as its own authorization server and accepts
    bearer tokens only in the Authorization header.
    """
    base = _base_url(request)
    resource = _resource_uri(request)
    metadata: Dict[str, Any] = {"resource": resource}
    metadata["authorization_servers"] = [base]
    metadata["bearer_methods_supported"] = ["header"]
    metadata["scopes_supported"] = ["mcp"]
    metadata["resource_documentation"] = base
    return metadata
def _unauthorized(request: Request, error: str = "invalid_token") -> JSONResponse:
    """Return a 401 whose WWW-Authenticate header points at the protected-resource metadata."""
    challenge = 'Bearer resource_metadata="{}", error="{}"'.format(_metadata_url(request), error)
    return JSONResponse(
        {"error": "unauthorized"},
        status_code=401,
        headers={"WWW-Authenticate": challenge},
    )
| def _parse_bearer(auth_header: str) -> Optional[str]: | |
| if not auth_header: | |
| return None | |
| parts = auth_header.split(None, 1) | |
| if len(parts) == 2 and parts[0].lower() == "bearer": | |
| return parts[1].strip() | |
| return None | |
| def _constant_time_equal(a: str, b: str) -> bool: | |
| return hmac.compare_digest(a.encode("utf-8"), b.encode("utf-8")) | |
def _verify_mcp_auth(request: Request) -> bool:
    """Return True if the request carries valid credentials for the MCP endpoint.

    Accepted, in order:
      1. an ``x-api-key`` header equal to MCP_API_KEY;
      2. ``Authorization: Bearer <MCP_API_KEY>``;
      3. a bearer access token from ``_make_token`` that verifies, has not
         expired, and whose audience is this server's ``/sse`` resource URL
         or its base URL.
    """
    api_key = os.getenv("MCP_API_KEY", "").strip()
    bearer = _parse_bearer(request.headers.get("authorization", ""))
    header_key = request.headers.get("x-api-key", "").strip()

    if api_key:
        # Legacy/manual auth still works for local tools and non-ChatGPT clients.
        if header_key and _constant_time_equal(header_key, api_key):
            return True
        if bearer and _constant_time_equal(bearer, api_key):
            return True
    if not bearer:
        return False
    claims = _decode_token(bearer)
    if not claims or claims.get("typ") != "access":
        return False
    # Accept the exact SSE resource, plus the base URL for older clients.
    audience = str(claims.get("aud", "")).rstrip("/")
    accepted = {_resource_uri(request).rstrip("/"), _base_url(request).rstrip("/")}
    return audience in accepted
| def _is_oauth_or_public_path(path: str) -> bool: | |
| if path in {"/", "/health"}: | |
| return True | |
| if path.startswith("/.well-known/"): | |
| return True | |
| if path.startswith("/oauth/"): | |
| return True | |
| return False | |
class AuthMiddleware(BaseHTTPMiddleware):
    """Gate every non-public route behind MCP_API_KEY / OAuth access-token auth.

    - OPTIONS requests get an empty 204 and never reach the wrapped app.
      NOTE(review): that 204 carries no Access-Control-* headers, so it will
      not satisfy a browser CORS preflight — confirm this is intended.
    - Public paths (root, health, /.well-known/*, /oauth/*) pass through.
    - With MCP_API_KEY unset, protected paths fail closed with a 503.
    - Otherwise requests failing ``_verify_mcp_auth`` receive a 401.
    """

    async def dispatch(self, request: Request, call_next):
        if request.method == "OPTIONS":
            return Response(status_code=204)
        # Ignore trailing slashes so "/health/" is treated like "/health".
        normalized = request.url.path.rstrip("/") or "/"
        if _is_oauth_or_public_path(normalized):
            return await call_next(request)
        if not os.getenv("MCP_API_KEY", "").strip():
            return JSONResponse({"error": "server_auth_not_configured"}, status_code=503)
        if _verify_mcp_auth(request):
            return await call_next(request)
        return _unauthorized(request)
async def public_ok(request: Request) -> JSONResponse:
    """Liveness response for ``/`` and ``/health``.

    Intentionally boring: no transport, tools, URLs, secrets, versions, or config hints.
    """
    payload = {"ok": True}
    return JSONResponse(payload)
async def oauth_authorization_server(request: Request) -> JSONResponse:
    """Serve the OAuth authorization-server metadata document."""
    metadata = _auth_server_metadata(request)
    return JSONResponse(metadata)
async def oauth_openid_configuration(request: Request) -> JSONResponse:
    """Answer OIDC discovery probes with the same OAuth metadata.

    Some clients probe OIDC discovery even though this server is OAuth-only.
    """
    metadata = _auth_server_metadata(request)
    return JSONResponse(metadata)
async def oauth_protected_resource(request: Request) -> JSONResponse:
    """Serve the protected-resource metadata for the ``/sse`` endpoint."""
    metadata = _protected_resource_metadata(request)
    return JSONResponse(metadata)
async def oauth_jwks(request: Request) -> JSONResponse:
    """Serve an empty key set: tokens are opaque HMAC tokens, so no public keys exist."""
    empty_key_set: Dict[str, list] = {"keys": []}
    return JSONResponse(empty_key_set)
def _is_registrable_redirect_uri(uri: Any) -> bool:
    """Return True if *uri* may be registered as an OAuth redirect URI.

    Requires a non-blank string without surrounding whitespace, with a scheme
    and no fragment (the code is appended to the query, and RFC 6749 §3.1.2
    forbids fragments). http(s) URIs must have a host. Schemes that execute or
    embed content in a browser (javascript, data, vbscript, file) are rejected;
    other custom schemes (native apps) are allowed.
    """
    if not isinstance(uri, str) or not uri or uri != uri.strip():
        return False
    try:
        parts = urlsplit(uri)
    except ValueError:
        return False
    if not parts.scheme or "#" in uri:
        return False
    if parts.scheme in {"http", "https"} and not parts.netloc:
        return False
    return parts.scheme not in {"javascript", "data", "vbscript", "file"}


async def oauth_register(request: Request) -> JSONResponse:
    """Dynamic client registration endpoint.

    Reads a JSON body with at least a non-empty ``redirect_uris`` list, stores
    a new client in OAUTH_CLIENTS, and returns its generated ``client_id`` and
    ``client_secret`` with status 201.

    Returns 400 ``invalid_redirect_uris`` when ``redirect_uris`` is missing,
    not a list, empty, or contains an unusable URI. An unparsable or non-object
    JSON body is treated as an empty registration request.
    """
    try:
        data = await request.json()
    except Exception:
        data = {}
    # A valid JSON body that is not an object (e.g. a list) would otherwise
    # crash on .get() with a 500.
    if not isinstance(data, dict):
        data = {}
    redirect_uris = data.get("redirect_uris") or []
    if (
        not isinstance(redirect_uris, list)
        or not redirect_uris
        or not all(_is_registrable_redirect_uri(uri) for uri in redirect_uris)
    ):
        return JSONResponse({"error": "invalid_redirect_uris"}, status_code=400)
    client_id = "mcp_client_" + secrets.token_urlsafe(24)
    client_secret = secrets.token_urlsafe(32)
    now = int(time.time())
    client = {
        "client_id": client_id,
        "client_secret": client_secret,
        "client_name": data.get("client_name") or "MCP client",
        "redirect_uris": redirect_uris,
        "grant_types": data.get("grant_types") or ["authorization_code", "refresh_token"],
        "response_types": data.get("response_types") or ["code"],
        "scope": data.get("scope") or "mcp",
        "token_endpoint_auth_method": data.get("token_endpoint_auth_method") or "none",
        "created_at": now,
    }
    OAUTH_CLIENTS[client_id] = client
    # Return a secret too, but allow public-client 'none'. This maximizes client compatibility.
    return JSONResponse({
        "client_id": client_id,
        "client_id_issued_at": now,
        "client_secret": client_secret,
        "client_secret_expires_at": 0,
        "client_name": client["client_name"],
        "redirect_uris": redirect_uris,
        "grant_types": client["grant_types"],
        "response_types": client["response_types"],
        "scope": client["scope"],
        "token_endpoint_auth_method": client["token_endpoint_auth_method"],
    }, status_code=201)
| def _allowed_redirect_uris() -> list[str]: | |
| return [x.strip() for x in os.getenv("OAUTH_ALLOWED_REDIRECT_URIS", "").split(",") if x.strip()] | |
| def _allow_chatgpt_redirect_prefix() -> bool: | |
| # ChatGPT custom connector callback URLs contain a per-connector id, for example: | |
| # https://chatgpt.com/connector/oauth/<random-id> | |
| # If OAUTH_ALLOWED_REDIRECT_URIS is empty, allow this prefix so users do not | |
| # need to resync the Space every time ChatGPT creates a new callback URL. | |
| value = os.getenv("OAUTH_ALLOW_CHATGPT_REDIRECT_PREFIX", "true").strip().lower() | |
| return value not in {"0", "false", "no", "off"} | |
def _is_chatgpt_redirect_uri(redirect_uri: str) -> bool:
    """Return True if *redirect_uri* is an acceptable ChatGPT connector OAuth callback.

    Requires the prefix rule to be enabled and the URI to be
    ``https://chatgpt.com/connector/oauth/<id>[?query]`` with a non-empty id.
    The id must not contain a fragment, backslashes, or ``.``/``..`` path
    segments (also percent-encoded ones). Otherwise browser URL normalization
    could move the callback, and the auth code appended to it, outside
    ``/connector/oauth/``.
    """
    from urllib.parse import unquote

    prefix = "https://chatgpt.com/connector/oauth/"
    if not _allow_chatgpt_redirect_prefix() or not redirect_uri.startswith(prefix):
        return False
    suffix = redirect_uri[len(prefix):]
    if not suffix or "#" in suffix or "\\" in suffix:
        return False
    path = suffix.split("?", 1)[0]
    return not any(unquote(segment) in {".", ".."} for segment in path.split("/"))
def _is_static_client(client_id: str) -> bool:
    """Return True if *client_id* is the configured OAUTH_STATIC_CLIENT_ID (never when that is blank)."""
    if not OAUTH_STATIC_CLIENT_ID:
        return False
    return client_id == OAUTH_STATIC_CLIENT_ID
def _validate_redirect_uri(client_id: str, redirect_uri: str) -> bool:
    """Check *redirect_uri* against what *client_id* is allowed to use.

    - Dynamically registered client: exact match against its registered URIs.
    - Static client: exact match against OAUTH_ALLOWED_REDIRECT_URIS when that
      is set, otherwise the ChatGPT connector-callback prefix rule.
    - Any other client id: rejected.
    """
    registered = OAUTH_CLIENTS.get(client_id)
    if registered:
        return redirect_uri in registered.get("redirect_uris", [])
    if not _is_static_client(client_id):
        return False
    allowlist = _allowed_redirect_uris()
    if allowlist:
        return redirect_uri in allowlist
    # Manual ChatGPT setup without an exact allowlist: accept connector callbacks
    # by prefix. The user must still enter MCP_API_KEY on the authorize page
    # before any code/token is issued.
    return _is_chatgpt_redirect_uri(redirect_uri)
def _oauth_authorize_params_from_query(request: Request) -> Dict[str, str]:
    """Return the authorize request's query parameters as a plain dict."""
    return dict(request.query_params.items())
| def _html_escape(value: str) -> str: | |
| return ( | |
| value.replace("&", "&") | |
| .replace("<", "<") | |
| .replace(">", ">") | |
| .replace('"', """) | |
| .replace("'", "'") | |
| ) | |
def _remote_client(url: str, headers: Optional[Dict[str, str]] = None) -> Client:
    """Create a FastMCP client for the remote server at *url*.

    Without headers, a plain ``Client(url)`` is returned. With headers, an
    explicit transport carries them: SSE when the path ends in ``/sse``,
    streamable HTTP otherwise. If the transport classes cannot be imported,
    the headers are passed to ``Client`` directly instead.
    """
    target = _normalize_remote_url(url)
    extra_headers = {str(name): str(val) for name, val in (headers or {}).items() if val is not None}
    if not extra_headers:
        return Client(target)
    try:
        from fastmcp.client.transports import SSETransport, StreamableHttpTransport
    except ImportError:
        return Client(target, headers=extra_headers)
    transport_cls = SSETransport if target.rstrip("/").endswith("/sse") else StreamableHttpTransport
    return Client(transport_cls(url=target, headers=extra_headers))
async def oauth_authorize_get(request: Request) -> Response:
    """Render the approval page for an OAuth authorization-code request.

    Validates ``client_id``, ``redirect_uri`` and ``response_type``. It then
    renders a form that re-submits the query parameters as hidden fields,
    together with the server access key typed by the user, to
    ``POST /oauth/authorize``.

    Returns 400 JSON for missing parameters, a response_type other than
    "code", or a redirect_uri the client may not use. The page is sent with
    anti-framing and no-store headers because it collects a secret.
    """
    p = _oauth_authorize_params_from_query(request)
    missing = [x for x in ("client_id", "redirect_uri", "response_type") if not p.get(x)]
    if missing:
        return JSONResponse({"error": "invalid_request", "missing": missing}, status_code=400)
    if p.get("response_type") != "code":
        return JSONResponse({"error": "unsupported_response_type"}, status_code=400)
    if not _validate_redirect_uri(p["client_id"], p["redirect_uri"]):
        return JSONResponse({"error": "invalid_redirect_uri"}, status_code=400)
    # Never echo an "access_key" query parameter into the form: the key must
    # only ever come from what the user types into the password field.
    hidden = "\n".join(
        f'<input type="hidden" name="{_html_escape(k)}" value="{_html_escape(v)}">'
        for k, v in p.items()
        if k != "access_key"
    )
    html = f"""<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Authorize</title>
<style>
body {{ font-family: system-ui, sans-serif; max-width: 520px; margin: 4rem auto; padding: 0 1rem; }}
input, button {{ font: inherit; padding: .75rem; width: 100%; box-sizing: border-box; }}
button {{ margin-top: .75rem; cursor: pointer; }}
.muted {{ color: #666; font-size: .9rem; }}
</style>
</head>
<body>
<h1>Authorize MCP access</h1>
<p class="muted">Enter the server access key to approve this client.</p>
<form method="post" action="/oauth/authorize">
{hidden}
<input type="password" name="access_key" placeholder="Access key" autocomplete="current-password" autofocus required>
<button type="submit">Authorize</button>
</form>
</body>
</html>"""
    return HTMLResponse(
        html,
        headers={
            # Forbid framing (clickjacking on the key prompt) and caching.
            "X-Frame-Options": "DENY",
            "Content-Security-Policy": "frame-ancestors 'none'",
            "Cache-Control": "no-store",
        },
    )
async def _read_form_urlencoded(request: Request) -> Dict[str, str]:
    """Parse an ``application/x-www-form-urlencoded`` request body into a flat dict.

    Repeated keys keep their last value and blank values are kept as "".
    Raw bytes that are not valid UTF-8 are replaced with U+FFFD instead of
    raising UnicodeDecodeError, so a malformed body cannot cause a 500 here.
    """
    raw = await request.body()
    body = raw.decode("utf-8", errors="replace")
    parsed = parse_qs(body, keep_blank_values=True)
    return {key: values[-1] if values else "" for key, values in parsed.items()}
def _redirect_with_error(redirect_uri: str, state: str, error: str, description: str = "") -> RedirectResponse:
    """Send the user back to the client with an OAuth error (302).

    Appends ``error`` plus, when non-empty, ``error_description`` and ``state``
    to the redirect URI's query, joining with "&" if it already has one.
    """
    params = {"error": error}
    params.update({k: v for k, v in (("error_description", description), ("state", state)) if v})
    joiner = "&" if "?" in redirect_uri else "?"
    return RedirectResponse(redirect_uri + joiner + urlencode(params), status_code=302)
async def oauth_authorize_post(request: Request) -> Response:
    """Handle the approval form: check the access key and issue an authorization code.

    Checks, in order:
      1. 503 if MCP_API_KEY is not configured;
      2. 400 JSON if redirect_uri is missing or not allowed for client_id
         (never redirect to an unvalidated URI);
      3. a redirect with ``access_denied`` if the access key is wrong.
    On success a code bound to the client, redirect URI, scope, resource and
    PKCE challenge is stored in AUTH_CODES (the token endpoint pops it, so it
    is single-use). The user is then redirected back with ``code``/``state``.
    """
    p = await _read_form_urlencoded(request)
    expected = os.getenv("MCP_API_KEY", "").strip()
    if not expected:
        return PlainTextResponse("Server auth is not configured.", status_code=503)
    redirect_uri = p.get("redirect_uri", "")
    state = p.get("state", "")
    client_id = p.get("client_id", "")
    if not redirect_uri or not _validate_redirect_uri(client_id, redirect_uri):
        return JSONResponse({"error": "invalid_redirect_uri"}, status_code=400)
    if not _constant_time_equal(p.get("access_key", ""), expected):
        return _redirect_with_error(redirect_uri, state, "access_denied", "Invalid access key")
    resource = p.get("resource") or _resource_uri(request)
    now = int(time.time())
    # Codes that are never redeemed would otherwise stay in AUTH_CODES forever;
    # drop expired ones whenever a new code is minted.
    for stale in [c for c, info in AUTH_CODES.items() if int(info.get("expires_at", 0)) < now]:
        AUTH_CODES.pop(stale, None)
    code = secrets.token_urlsafe(32)
    AUTH_CODES[code] = {
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "scope": p.get("scope") or "mcp",
        "resource": resource.rstrip("/"),
        "code_challenge": p.get("code_challenge", ""),
        "code_challenge_method": p.get("code_challenge_method", "plain"),
        "expires_at": now + AUTH_CODE_TTL_SECONDS,
    }
    q = {"code": code}
    if state:
        q["state"] = state
    sep = "&" if "?" in redirect_uri else "?"
    return RedirectResponse(f"{redirect_uri}{sep}{urlencode(q)}", status_code=302)
| def _check_pkce(code_info: Dict[str, Any], verifier: str) -> bool: | |
| challenge = code_info.get("code_challenge") or "" | |
| method = (code_info.get("code_challenge_method") or "plain").upper() | |
| if not challenge: | |
| # Be lenient for older clients, but ChatGPT should send PKCE. | |
| return True | |
| if method == "S256": | |
| calculated = _b64url(hashlib.sha256(verifier.encode("ascii")).digest()) | |
| return hmac.compare_digest(calculated, challenge) | |
| return hmac.compare_digest(verifier, challenge) | |
def _oauth_client_credentials(request: Request, form: Dict[str, str]) -> tuple[str, str]:
    """Return ``(client_id, client_secret)`` from the form body and/or HTTP Basic auth.

    Supports ``client_secret_post`` (form fields) and ``client_secret_basic``.
    Per RFC 6749 §2.3.1, Basic carries form-urlencoded id/secret, base64-encoded
    in the Authorization header. Form values win; Basic only fills in what the
    form left empty. Missing values come back as "".
    """
    from urllib.parse import unquote_plus

    client_id = form.get("client_id", "")
    client_secret = form.get("client_secret", "")
    scheme, _, encoded = request.headers.get("authorization", "").strip().partition(" ")
    if scheme.lower() != "basic" or not encoded.strip():
        return client_id, client_secret
    try:
        decoded = base64.b64decode(encoded.strip(), validate=True).decode("utf-8")
    except ValueError:  # binascii.Error and UnicodeDecodeError are ValueErrors
        return client_id, client_secret
    basic_id, sep, basic_secret = decoded.partition(":")
    if not sep:
        return client_id, client_secret
    return client_id or unquote_plus(basic_id), client_secret or unquote_plus(basic_secret)


def _oauth_client_secret_ok(client_id: str, client_secret: str) -> bool:
    """Return False if the presented client credentials are wrong for *client_id*.

    - Static client with OAUTH_STATIC_CLIENT_SECRET configured: the secret is required.
    - Dynamically registered client: it may act as a public client and omit
      the secret, but a secret that IS presented must match the issued one.
    - Anything else: accepted (public-client mode, protected by PKCE).
    """
    if _is_static_client(client_id) and OAUTH_STATIC_CLIENT_SECRET:
        return _constant_time_equal(client_secret, OAUTH_STATIC_CLIENT_SECRET)
    registered = OAUTH_CLIENTS.get(client_id)
    if registered and client_secret:
        return _constant_time_equal(client_secret, str(registered.get("client_secret", "")))
    return True


def _oauth_token_response(resource: str, scope: str) -> JSONResponse:
    """Mint a fresh access/refresh token pair bound to *resource* and wrap it in a token response.

    Carries ``Cache-Control: no-store`` / ``Pragma: no-cache`` as RFC 6749 §5.1
    requires for responses containing tokens.
    """
    return JSONResponse(
        {
            "access_token": _make_token("access", resource, ACCESS_TOKEN_TTL_SECONDS, scope),
            "token_type": "Bearer",
            "expires_in": ACCESS_TOKEN_TTL_SECONDS,
            "refresh_token": _make_token("refresh", resource, REFRESH_TOKEN_TTL_SECONDS, scope),
            "scope": scope,
        },
        headers={"Cache-Control": "no-store", "Pragma": "no-cache"},
    )


async def oauth_token(request: Request) -> JSONResponse:
    """OAuth token endpoint for the ``authorization_code`` and ``refresh_token`` grants.

    authorization_code: the code is removed from AUTH_CODES first (single use).
    The request is then checked for expiry, a matching redirect_uri (when
    sent), the client id, client credentials and PKCE, and a token pair is
    issued for the requested or stored resource.

    refresh_token: the refresh token's signature, expiry and type are
    verified, and a new pair is issued with the same scope.

    Errors: 400 ``invalid_grant`` / ``unsupported_grant_type``; 401 ``invalid_client``.
    """
    p = await _read_form_urlencoded(request)
    grant_type = p.get("grant_type", "")

    if grant_type == "authorization_code":
        info = AUTH_CODES.pop(p.get("code", ""), None)
        if not info or int(info.get("expires_at", 0)) < int(time.time()):
            return JSONResponse({"error": "invalid_grant"}, status_code=400)
        if p.get("redirect_uri") and p.get("redirect_uri") != info.get("redirect_uri"):
            return JSONResponse({"error": "invalid_grant"}, status_code=400)
        client_id, client_secret = _oauth_client_credentials(request, p)
        submitted_client_id = client_id or info.get("client_id")
        if submitted_client_id != info.get("client_id"):
            return JSONResponse({"error": "invalid_client"}, status_code=401)
        # ChatGPT can use public-client mode ("none"); confidential clients are verified.
        if not _oauth_client_secret_ok(submitted_client_id, client_secret):
            return JSONResponse({"error": "invalid_client"}, status_code=401)
        if not _check_pkce(info, p.get("code_verifier", "")):
            return JSONResponse({"error": "invalid_grant", "error_description": "PKCE verification failed"}, status_code=400)
        resource = (p.get("resource") or info.get("resource") or _resource_uri(request)).rstrip("/")
        return _oauth_token_response(resource, info.get("scope", "mcp"))

    if grant_type == "refresh_token":
        payload = _decode_token(p.get("refresh_token", ""))
        if not payload or payload.get("typ") != "refresh":
            return JSONResponse({"error": "invalid_grant"}, status_code=400)
        resource = (p.get("resource") or payload.get("aud") or _resource_uri(request)).rstrip("/")
        return _oauth_token_response(resource, payload.get("scope", "mcp"))

    return JSONResponse({"error": "unsupported_grant_type"}, status_code=400)
async def oauth_revoke(request: Request) -> JSONResponse:
    """Token revocation endpoint; always answers ``{"ok": true}`` without revoking anything.

    Tokens are stateless, so individual revocation is not possible in this
    minimal implementation. Rotating the signing secret (MCP_OAUTH_SIGNING_KEY,
    or MCP_API_KEY when no signing key is set) invalidates all tokens at once.
    """
    acknowledgement = {"ok": True}
    return JSONResponse(acknowledgement)
async def mcp_connect(
    url: str,
    headers: Optional[Dict[str, str]] = None,
    bearer_token: str = "",
    x_api_key: str = "",
    verify: bool = True,
) -> Dict[str, Any]:
    """Register a remote MCP HTTP/SSE server URL and optionally verify connectivity immediately."""
    # The docstring above is kept verbatim because FastMCP may expose it as the
    # tool description. With verify=True a throwaway connection lists the remote
    # tools; any failure raises ValueError and nothing is registered.
    target = _normalize_remote_url(url)
    merged_headers = _remote_headers(headers, bearer_token=bearer_token, x_api_key=x_api_key)
    new_session_id = str(uuid.uuid4())
    discovered_tools: Optional[int] = None
    if verify:
        try:
            async with _remote_client(target, merged_headers) as client:
                discovered_tools = len(await client.list_tools())
        except Exception as exc:
            raise ValueError(f"Failed to connect to remote MCP server at {target}: {exc}") from exc
    SESSIONS[new_session_id] = {"url": target, "headers": merged_headers}
    return {
        "session_id": new_session_id,
        "status": "connected_registered",
        "url": target,
        "verified": verify,
        "tool_count": discovered_tools,
    }
async def mcp_tool_list(session_id: str) -> Dict[str, Any]:
    """List tools from a registered remote MCP server."""
    # Raises ValueError for an unknown session and wraps any connection or
    # listing failure in ValueError. Descriptors are returned without outputSchema.
    session = SESSIONS.get(session_id)
    if not session:
        raise ValueError("Unknown session_id")
    remote_url = session["url"]
    try:
        async with _remote_client(remote_url, session["headers"]) as client:
            tools = await client.list_tools()
    except Exception as exc:
        raise ValueError(f"Failed to connect to remote MCP server at {remote_url}: {exc}") from exc
    sanitized = [_sanitized_tool_descriptor(tool) for tool in tools]
    return {"session_id": session_id, "tools": sanitized}
async def mcp_tool_call(session_id: str, tool_name: str, arguments: Optional[Dict[str, Any]] = None) -> list[Any]:
    """Call a tool on a registered remote MCP server and emit native FastMCP tool outputs."""
    # Raises ValueError for:
    #   - an unknown session_id;
    #   - transport/connection failures ("Failed to connect to remote MCP server ...");
    #   - a tool that ran but reported isError ("Remote MCP tool ... returned an error ...").
    # Only the remote call sits inside the try, so tool errors and output
    # conversion errors are no longer mislabeled as connection failures.
    cfg = SESSIONS.get(session_id)
    if not cfg:
        raise ValueError("Unknown session_id")
    try:
        async with _remote_client(cfg["url"], cfg["headers"]) as client:
            result = await client.call_tool_mcp(tool_name, arguments or {})
    except Exception as exc:
        raise ValueError(f"Failed to connect to remote MCP server at {cfg['url']}: {exc}") from exc
    prioritized = _prioritize_image_content(result)
    if prioritized.isError:
        text_parts = [item.text for item in prioritized.content or [] if getattr(item, "type", None) == "text"]
        message = text_parts[0] if text_parts else json.dumps(prioritized.model_dump(), ensure_ascii=False, indent=2)
        raise ValueError(f"Remote MCP tool {tool_name} returned an error: {message}")
    return _native_tool_outputs(prioritized)
async def mcp_disconnect(session_id: str) -> Dict[str, Any]:
    """Forget a registered remote MCP session."""
    # Unknown ids are not an error; "disconnected" reports whether the session existed.
    absent = object()
    removed = SESSIONS.pop(session_id, absent)
    return {"session_id": session_id, "disconnected": removed is not absent}
async def mcp_session_list() -> Dict[str, Any]:
    """List registered session IDs without exposing headers."""
    # Only the URL is reported per session; stored headers may hold credentials.
    entries = []
    for sid, cfg in SESSIONS.items():
        entries.append({"session_id": sid, "url": cfg.get("url")})
    return {"count": len(SESSIONS), "sessions": entries}
# FastMCP returns a Starlette app. Keep everything Starlette-compatible.
app = mcp.http_app(path="/sse")
app.add_middleware(AuthMiddleware)

# NOTE(review): no registration of the mcp_* functions as tools (e.g. via
# ``mcp.tool``) is visible in this file view — confirm they are registered,
# otherwise /sse exposes no tools.

# Extra Starlette routes, registered in this exact order.
_EXTRA_ROUTES = (
    # Quiet public routes.
    ("/", public_ok, ["GET", "HEAD"]),
    ("/health", public_ok, ["GET", "HEAD"]),
    # OAuth / MCP authorization discovery.
    ("/.well-known/oauth-authorization-server", oauth_authorization_server, ["GET"]),
    ("/.well-known/oauth-authorization-server/sse", oauth_authorization_server, ["GET"]),
    ("/.well-known/openid-configuration", oauth_openid_configuration, ["GET"]),
    ("/.well-known/oauth-protected-resource", oauth_protected_resource, ["GET"]),
    ("/.well-known/oauth-protected-resource/sse", oauth_protected_resource, ["GET"]),
    # OAuth endpoints.
    ("/oauth/register", oauth_register, ["POST"]),
    ("/oauth/authorize", oauth_authorize_get, ["GET"]),
    ("/oauth/authorize", oauth_authorize_post, ["POST"]),
    ("/oauth/token", oauth_token, ["POST"]),
    ("/oauth/revoke", oauth_revoke, ["POST"]),
    ("/oauth/jwks", oauth_jwks, ["GET"]),
)
for _route_path, _route_endpoint, _route_methods in _EXTRA_ROUTES:
    app.add_route(_route_path, _route_endpoint, methods=_route_methods)
del _route_path, _route_endpoint, _route_methods