| """Tests for tools/mcp_oauth.py — thin OAuth adapter over MCP SDK.""" |
|
|
| import json |
| import os |
| from pathlib import Path |
| from unittest.mock import patch, MagicMock, AsyncMock |
|
|
| import pytest |
|
|
| from tools.mcp_oauth import ( |
| HermesTokenStorage, |
| build_oauth_auth, |
| remove_oauth_tokens, |
| _find_free_port, |
| _can_open_browser, |
| ) |
|
|
|
|
| |
| |
| |
|
|
class TestHermesTokenStorage:
    """Exercise HermesTokenStorage with HERMES_HOME pointed at a temp dir."""

    def test_roundtrip_tokens(self, tmp_path, monkeypatch):
        """Tokens are absent at first and written to disk by set_tokens."""
        import asyncio

        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        storage = HermesTokenStorage("test-server")

        # Nothing has been stored yet.
        initial = asyncio.run(storage.get_tokens())
        assert initial is None

        # Stand-in token: only model_dump() is stubbed.
        payload = {
            "access_token": "abc123",
            "token_type": "Bearer",
            "refresh_token": "ref456",
        }
        fake_token = MagicMock()
        fake_token.model_dump.return_value = payload
        asyncio.run(storage.set_tokens(fake_token))

        # The token file must now exist and carry the access token.
        written = tmp_path / "mcp-tokens" / "test-server.json"
        assert written.exists()
        stored = json.loads(written.read_text())
        assert stored["access_token"] == "abc123"

    def test_roundtrip_client_info(self, tmp_path, monkeypatch):
        """Client info is absent at first and written to disk once set."""
        import asyncio

        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        storage = HermesTokenStorage("test-server")

        initial = asyncio.run(storage.get_client_info())
        assert initial is None

        # Stand-in client registration: only model_dump() is stubbed.
        fake_client = MagicMock()
        fake_client.model_dump.return_value = {
            "client_id": "hermes-123",
            "client_secret": "secret",
        }
        asyncio.run(storage.set_client_info(fake_client))

        written = tmp_path / "mcp-tokens" / "test-server.client.json"
        assert written.exists()

    def test_remove_cleans_up(self, tmp_path, monkeypatch):
        """remove() deletes both the token file and the client-info file."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        storage = HermesTokenStorage("test-server")

        # Seed both files by hand so remove() has something to delete.
        token_dir = tmp_path / "mcp-tokens"
        token_dir.mkdir(parents=True)
        seeded = [
            token_dir / "test-server.json",
            token_dir / "test-server.client.json",
        ]
        for file_path in seeded:
            file_path.write_text("{}")

        storage.remove()

        for file_path in seeded:
            assert not file_path.exists()
|
|
|
|
| |
| |
| |
|
|
class TestBuildOAuthAuth:
    """Tests for build_oauth_auth with and without the MCP SDK auth module."""

    def test_returns_oauth_provider(self):
        """With the SDK importable, the result is an OAuthClientProvider."""
        try:
            from mcp.client.auth import OAuthClientProvider
        except ImportError:
            pytest.skip("MCP SDK auth not available")

        auth = build_oauth_auth("test", "https://example.com/mcp")
        assert isinstance(auth, OAuthClientProvider)

    def test_returns_none_without_sdk(self, monkeypatch):
        """Smoke test: build_oauth_auth must not raise when the SDK import fails.

        Any import whose name contains ``mcp.client.auth`` is made to raise
        ImportError for the duration of the call. The return value is
        deliberately not asserted: the previous assertion
        (``result is None or result is not None``) was always true, so the
        only thing this test ever verified is that the call completes.
        NOTE(review): if build_oauth_auth is guaranteed to return None when
        the import fails, tighten this to ``assert result is None``.
        """
        # Use the public ``builtins`` module rather than ``__builtins__``,
        # which is an implementation detail (a dict in imported modules).
        import builtins

        real_import = builtins.__import__

        def _block_import(name, *args, **kwargs):
            if "mcp.client.auth" in name:
                raise ImportError("blocked")
            return real_import(name, *args, **kwargs)

        # Scope the override to the call itself; it is undone on exit.
        with monkeypatch.context() as scoped:
            scoped.setattr(builtins, "__import__", _block_import)
            build_oauth_auth("test", "https://example.com")
|
|
|
|
| |
| |
| |
|
|
class TestUtilities:
    """Tests for the _find_free_port and _can_open_browser helpers."""

    def test_find_free_port_returns_int(self):
        """The port is an int in the non-privileged range 1024-65535."""
        port = _find_free_port()
        assert isinstance(port, int)
        assert 1024 <= port <= 65535

    def test_can_open_browser_false_in_ssh(self, monkeypatch):
        """With SSH_CLIENT set, no browser is considered available."""
        monkeypatch.setenv("SSH_CLIENT", "1.2.3.4 1234 22")
        assert _can_open_browser() is False

    def test_can_open_browser_false_without_display(self, monkeypatch):
        """On a simulated Linux host with no DISPLAY, the result is False."""
        from types import SimpleNamespace

        for var in ("SSH_CLIENT", "SSH_TTY", "DISPLAY"):
            monkeypatch.delenv(var, raising=False)

        # Present the host as POSIX/Linux regardless of the real platform.
        monkeypatch.setattr(os, "name", "posix")
        # raising=False: os.uname does not exist on every platform (e.g.
        # Windows), and setattr would otherwise raise AttributeError before
        # the assertion below is ever reached.
        monkeypatch.setattr(
            os,
            "uname",
            lambda: SimpleNamespace(sysname="Linux"),
            raising=False,
        )
        assert _can_open_browser() is False
|
|
|
|
| |
| |
| |
|
|
class TestPathTraversal:
    """Check that server names cannot steer the token path out of mcp-tokens."""

    def test_path_traversal_blocked(self, tmp_path, monkeypatch):
        """A ``../../.ssh/config`` name stays under mcp-tokens, away from .ssh."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        token_file = HermesTokenStorage("../../.ssh/config")._tokens_path()

        as_given = str(token_file)
        as_resolved = str(token_file.resolve())
        assert "mcp-tokens" in as_given
        assert ".ssh" not in as_resolved

    def test_dots_and_slashes_sanitized(self, tmp_path, monkeypatch):
        """The resolved path for a traversal-style name is inside mcp-tokens."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        token_file = HermesTokenStorage("../../../etc/passwd")._tokens_path()

        target = token_file.resolve()
        token_root = (tmp_path / "mcp-tokens").resolve()
        assert target.is_relative_to(token_root)

    def test_normal_name_unchanged(self, tmp_path, monkeypatch):
        """A plain name appears verbatim in the token file name."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        token_file = HermesTokenStorage("my-mcp-server")._tokens_path()
        assert "my-mcp-server.json" in str(token_file)

    def test_special_chars_sanitized(self, tmp_path, monkeypatch):
        """``@``, ``:`` and ``/`` do not survive into the file name."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        token_file = HermesTokenStorage("server@host:8080/path")._tokens_path()

        file_name = token_file.name
        for forbidden in ("@", ":"):
            assert forbidden not in file_name
        assert "/" not in token_file.stem
|
|
|
|
class TestCallbackHandlerIsolation:
    """Each _make_callback_handler() call must yield its own result dict."""

    def test_independent_result_dicts(self):
        """Writing to one factory result leaves the other untouched."""
        from tools.mcp_oauth import _make_callback_handler

        _handler_a, state_a = _make_callback_handler()
        _handler_b, state_b = _make_callback_handler()

        state_a["auth_code"] = "code_A"
        state_b["auth_code"] = "code_B"

        # If the dicts were shared, the second write would clobber the first.
        observed = (state_a["auth_code"], state_b["auth_code"])
        assert observed[0] == "code_A"
        assert observed[1] == "code_B"

    def test_handler_writes_to_own_result(self):
        """do_GET stores the query's code and state in the paired dict."""
        from io import BytesIO

        from tools.mcp_oauth import _make_callback_handler

        handler_cls, captured = _make_callback_handler()
        assert captured["auth_code"] is None

        # Build the handler without running __init__, then attach the
        # attributes this test stubs before invoking do_GET.
        handler = handler_cls.__new__(handler_cls)
        handler.path = "/callback?code=test123&state=mystate"
        handler.wfile = BytesIO()
        for stubbed in ("send_response", "send_header", "end_headers"):
            setattr(handler, stubbed, MagicMock())

        handler.do_GET()

        assert captured["auth_code"] == "test123"
        assert captured["state"] == "mystate"
|
|
|
|
class TestOAuthPortSharing:
    """Verify build_oauth_auth and _wait_for_callback use the same port."""

    def test_port_stored_globally(self):
        """build_oauth_auth leaves a valid port in the module's ``_oauth_port``.

        Only the module-level value is inspected here; the test is skipped
        when ``mcp.client.auth`` cannot be imported.
        """
        import tools.mcp_oauth as mod

        # Decide on skipping before touching any module state.
        try:
            from mcp.client.auth import OAuthClientProvider  # noqa: F401
        except ImportError:
            pytest.skip("MCP SDK auth not available")

        # Reset the global so a value left by an earlier test cannot satisfy
        # the assertions, and restore it afterwards so this test does not
        # leak state into later ones.
        saved_port = getattr(mod, "_oauth_port", None)
        mod._oauth_port = None
        try:
            build_oauth_auth("test-port", "https://example.com/mcp")
            port = mod._oauth_port
        finally:
            mod._oauth_port = saved_port

        assert port is not None
        assert isinstance(port, int)
        assert 1024 <= port <= 65535
|
|
|
|
class TestRemoveOAuthTokens:
    """Tests for the module-level remove_oauth_tokens helper."""

    def test_removes_files(self, tmp_path, monkeypatch):
        """Both files for the named server are gone after removal."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))

        token_dir = tmp_path / "mcp-tokens"
        token_dir.mkdir()
        seeded = [
            token_dir / "myserver.json",
            token_dir / "myserver.client.json",
        ]
        for file_path in seeded:
            file_path.write_text("{}")

        remove_oauth_tokens("myserver")

        for file_path in seeded:
            assert not file_path.exists()

    def test_no_error_when_files_missing(self, tmp_path, monkeypatch):
        """Removing a server with no stored files completes without raising."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        remove_oauth_tokens("nonexistent")
|
|