mcp-bridge / test_oauth_flow.py
patdev's picture
Update mcp-bridge OAuth server
082d529 verified
Raw
History Blame Contribute Delete
9.11 kB
import asyncio
import os
import re
from urllib.parse import parse_qs, urlparse
# Test-only environment defaults. These have to be present before `app` is
# imported below (presumably it reads them at import time — confirm).
# `setdefault` leaves any value that is already set in the environment untouched.
_TEST_ENV_DEFAULTS = {
    "MCP_API_KEY": "test-key",
    "PUBLIC_BASE_URL": "http://testserver",
}
for _env_name, _env_value in _TEST_ENV_DEFAULTS.items():
    os.environ.setdefault(_env_name, _env_value)
import pytest
from fastmcp import Client
from mcp.types import CallToolResult, ImageContent, TextContent
from starlette.testclient import TestClient
import app as bridge_app
from app import SESSIONS, app, _normalize_remote_url, _remote_client, _remote_headers, mcp_connect, mcp_tool_list
class FakeRemoteClient:
    """In-memory stand-in for the client that ``app._remote_client`` returns.

    Usable as an async context manager. Every ``call_tool_mcp`` invocation is
    recorded in ``calls`` and answered with the canned ``result`` given at
    construction time. ``list_tools`` always reports no tools.
    """

    def __init__(self, result):
        self.calls = []  # (tool name, arguments) tuples, in call order
        self.result = result

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Returning False never suppresses an exception raised in the body.
        return False

    async def call_tool_mcp(self, name, arguments):
        record = (name, arguments)
        self.calls.append(record)
        return self.result

    async def list_tools(self):
        return list()
def test_oauth_metadata_and_flow():
    """Walk the whole OAuth flow against the in-process app.

    Steps covered, in order:
    - discovery metadata;
    - dynamic client registration;
    - GET and POST of ``/oauth/authorize``;
    - the PKCE ``plain`` code-for-token exchange;
    - ``/sse``, which rejects anonymous requests but accepts the issued bearer token.
    """
    # Read the values the module-level defaults established instead of repeating
    # the literals. `setdefault` keeps any value that was already in the
    # environment, and hard-coded copies would then disagree with it
    # (presumably the app is configured from these same variables).
    base_url = os.environ["PUBLIC_BASE_URL"]
    access_key = os.environ["MCP_API_KEY"]
    redirect_uri = "https://chatgpt.com/oauth/callback"
    resource = f"{base_url}/sse"
    # With code_challenge_method "plain", the challenge is the verifier itself.
    pkce_verifier = "verifier"

    with TestClient(app) as c:
        # Liveness and discovery metadata.
        assert c.get("/").json() == {"ok": True}
        r = c.get("/.well-known/oauth-protected-resource")
        assert r.status_code == 200
        assert r.json()["authorization_servers"] == [base_url]
        r = c.get("/.well-known/oauth-authorization-server")
        assert r.status_code == 200
        assert "registration_endpoint" in r.json()

        # Dynamic client registration.
        r = c.post("/oauth/register", json={"redirect_uris": [redirect_uri]})
        assert r.status_code == 201
        client_id = r.json()["client_id"]

        # Authorization: GET must succeed; POST with the access key redirects with a code.
        params = {
            "response_type": "code",
            "client_id": client_id,
            "redirect_uri": redirect_uri,
            "state": "abc",
            "resource": resource,
            "code_challenge": pkce_verifier,
            "code_challenge_method": "plain",
        }
        r = c.get("/oauth/authorize", params=params)
        assert r.status_code == 200
        r = c.post(
            "/oauth/authorize",
            data={**params, "access_key": access_key},
            follow_redirects=False,
        )
        assert r.status_code in (302, 307)
        code = parse_qs(urlparse(r.headers["location"]).query)["code"][0]

        # Exchange the authorization code for an access token.
        r = c.post(
            "/oauth/token",
            data={
                "grant_type": "authorization_code",
                "code": code,
                "redirect_uri": redirect_uri,
                "client_id": client_id,
                "code_verifier": pkce_verifier,
                "resource": resource,
            },
        )
        assert r.status_code == 200
        access_token = r.json()["access_token"]

        # Protected resource: anonymous -> 401 plus an auth challenge; bearer -> not 401.
        r = c.get("/sse")
        assert r.status_code == 401
        assert "WWW-Authenticate" in r.headers
        r = c.get("/sse", headers={"Authorization": f"Bearer {access_token}"})
        assert r.status_code != 401
def test_remote_client_preserves_headers_for_http_transport():
    """``_remote_client`` must hand the URL and headers to its transport unchanged."""
    target = "https://example.test/sse"
    transport = _remote_client(target, {"Authorization": "Bearer token"}).transport
    assert transport.headers == {"Authorization": "Bearer token"}
    assert transport.url == target
def test_normalize_remote_url_adds_https_for_bare_remote_host():
    """A scheme-less public host is given an ``https://`` prefix."""
    normalized = _normalize_remote_url("patdev-mcp-bridge.hf.space/sse")
    assert normalized == "https://patdev-mcp-bridge.hf.space/sse"
def test_normalize_remote_url_adds_http_for_localhost_target():
    """A scheme-less ``localhost`` target is given a plain ``http://`` prefix."""
    normalized = _normalize_remote_url("localhost:4000/mcp")
    assert normalized == "http://localhost:4000/mcp"
def test_remote_headers_adds_bearer_and_api_key():
    """``_remote_headers`` turns a bearer token and an API key into request headers."""
    merged = _remote_headers({}, bearer_token="secret-token", x_api_key="secret-key")
    expected = {"Authorization": "Bearer secret-token", "x-api-key": "secret-key"}
    for header_name, header_value in expected.items():
        assert merged[header_name] == header_value
def test_mcp_connect_stores_normalized_url(monkeypatch):
    """``mcp_connect`` normalizes a bare ``host:port/path`` before reporting and storing it."""
    stub = FakeRemoteClient(None)

    def fake_remote_client(url, headers=None):
        return stub

    monkeypatch.setattr("app._remote_client", fake_remote_client)
    outcome = asyncio.run(mcp_connect.fn("localhost:4000/mcp"))
    session_id = outcome["session_id"]
    expected_url = "http://localhost:4000/mcp"
    try:
        assert outcome["url"] == expected_url
        assert outcome["verified"] is True
        assert outcome["tool_count"] == 0
        assert SESSIONS[session_id]["url"] == expected_url
    finally:
        # mcp_connect registered this session in the shared SESSIONS dict; drop it.
        SESSIONS.pop(session_id, None)
def test_mcp_connect_merges_remote_auth_headers(monkeypatch):
    """Bearer token and API key given to ``mcp_connect`` reach the remote client and the session."""
    seen = {}
    stub = FakeRemoteClient(None)

    def recording_remote_client(url, headers=None):
        # Remember exactly what the bridge asked for, then hand back the fake.
        seen.update(url=url, headers=headers)
        return stub

    monkeypatch.setattr("app._remote_client", recording_remote_client)
    outcome = asyncio.run(
        mcp_connect.fn("https://example.test/sse", bearer_token="token-123", x_api_key="key-456")
    )
    session_id = outcome["session_id"]
    try:
        assert seen["url"] == "https://example.test/sse"
        sent_headers = seen["headers"]
        assert sent_headers["Authorization"] == "Bearer token-123"
        assert sent_headers["x-api-key"] == "key-456"
        assert SESSIONS[session_id]["headers"] == sent_headers
    finally:
        SESSIONS.pop(session_id, None)
def test_mcp_connect_reports_target_url_on_verify_failure(monkeypatch):
    """A failed verification handshake must name the target URL in the ValueError.

    ``FailingClient`` raises as soon as it is entered, standing in for a remote
    server that rejects the supplied bearer token.
    """
    target_url = "https://patdev-mcp-bridge.hf.space/sse"

    class FailingClient:
        async def __aenter__(self):
            raise RuntimeError("401 Unauthorized")

        async def __aexit__(self, exc_type, exc, tb):
            return False

    monkeypatch.setattr("app._remote_client", lambda url, headers=None: FailingClient())
    # `match` is a regular expression searched in str(exc). Escape the URL so its
    # dots match literal dots rather than any character.
    with pytest.raises(ValueError, match=re.escape(target_url)):
        asyncio.run(mcp_connect.fn(target_url, bearer_token="bad-token"))
def test_mcp_tool_call_emits_native_image_content_end_to_end(monkeypatch):
    """Calling ``mcp_tool_call`` through a real MCP client yields native image content.

    The fake remote answers with a text block followed by a PNG image block. The
    bridge's response is expected to contain only the image block, with its type
    and MIME type intact. The fake must also see exactly one forwarded call.
    """
    remote_result = CallToolResult(
        content=[
            TextContent(type="text", text="ok"),
            ImageContent(type="image", data="AAAA", mimeType="image/png"),
        ],
        isError=False,
    )
    fake_client = FakeRemoteClient(remote_result)
    session_id = "test-session"
    # monkeypatch.setitem restores (or removes) the entry at teardown. This holds even
    # if later setup fails, and it never discards a pre-existing entry under this key.
    monkeypatch.setitem(SESSIONS, session_id, {"url": "https://example.test/sse", "headers": {}})
    monkeypatch.setattr("app._remote_client", lambda url, headers=None: fake_client)

    async def call_through_bridge():
        client = Client(bridge_app.mcp)
        async with client:
            return await client.call_tool_mcp(
                "mcp_tool_call",
                {"session_id": session_id, "tool_name": "Screenshot", "arguments": {}},
            )

    returned = asyncio.run(call_through_bridge())
    assert returned.content[0].type == "image"
    assert returned.content[0].mimeType == "image/png"
    assert len(returned.content) == 1
    assert returned.isError is False
    assert fake_client.calls == [("Screenshot", {})]
def test_mcp_tool_list_strips_remote_output_schema(monkeypatch):
    """``mcp_tool_list`` must drop ``outputSchema`` from remote tool definitions.

    The name and ``inputSchema`` of the listed tool must pass through unchanged.
    """

    class FakeTool:
        def model_dump(self):
            return {
                "name": "mcp_windows_Screenshot",
                "description": "fake screenshot tool",
                "inputSchema": {"type": "object"},
                "outputSchema": {"type": "object", "properties": {"image": {"type": "string"}}},
            }

    class ListingClient:
        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return False

        async def list_tools(self):
            return [FakeTool()]

    session_id = "tool-list-session"
    # Restored (or removed) automatically at teardown, even if setup below fails.
    monkeypatch.setitem(SESSIONS, session_id, {"url": "https://example.test/sse", "headers": {}})
    monkeypatch.setattr("app._remote_client", lambda url, headers=None: ListingClient())

    returned = asyncio.run(mcp_tool_list.fn(session_id))
    tool = returned["tools"][0]
    assert tool["name"] == "mcp_windows_Screenshot"
    assert tool["inputSchema"] == {"type": "object"}
    assert "outputSchema" not in tool
def test_mcp_tool_list_reports_target_url_on_connect_failure(monkeypatch):
    """A connection failure in ``mcp_tool_list`` surfaces as a ValueError.

    The error message must include both the session's target URL and the
    underlying OS error text.
    """
    session_id = "broken-session"
    target_url = "https://missing.example.test/sse"
    # Restored (or removed) automatically at teardown, even if setup below fails.
    monkeypatch.setitem(SESSIONS, session_id, {"url": target_url, "headers": {}})

    class FailingClient:
        async def __aenter__(self):
            raise OSError("[Errno -2] Name or service not known")

        async def __aexit__(self, exc_type, exc, tb):
            return False

    monkeypatch.setattr("app._remote_client", lambda url, headers=None: FailingClient())

    with pytest.raises(ValueError) as excinfo:
        asyncio.run(mcp_tool_list.fn(session_id))
    message = str(excinfo.value)
    assert target_url in message
    assert "Name or service not known" in message
if __name__ == "__main__":
    # Run a quick smoke subset directly, without going through pytest.
    for smoke_test in (
        test_oauth_metadata_and_flow,
        test_remote_client_preserves_headers_for_http_transport,
    ):
        smoke_test()
    print("OAuth smoke test passed")