Spaces:
Running
Running
| """ | |
| pytest suite for MCP Governance Receipts Server. | |
| Tests cover all 4 MCP tools: | |
| - emit_receipt | |
| - verify_receipt | |
| - list_receipts | |
| - get_attestation_chain | |
| Spec: https://modelcontextprotocol.io | |
| DSSE: https://slsa.dev/spec/v1.0/about | |
| Thesis DOI: https://doi.org/10.5281/zenodo.20434276 | |
| SPDX-License-Identifier: Apache-2.0 | |
| Copyright 2026 SZL Holdings / Lutar, Stephen P. ORCID 0009-0001-0110-4173 | |
| """ | |
| from __future__ import annotations | |
| import base64 | |
| import hashlib | |
| import importlib | |
| import json | |
| import sys | |
| import types | |
| import os | |
| import pytest | |
| # --------------------------------------------------------------------------- | |
| # Isolate the server module so tests don't start the HTTP listener | |
| # --------------------------------------------------------------------------- | |
| # Patch FastMCP before importing server so no network socket is opened | |
class _FakeFastMCP:
    """Inert stand-in for ``mcp.server.fastmcp.FastMCP`` used by this test suite.

    Every tool decorator returns the decorated function unchanged, so the
    server's tools stay plain callables the tests can invoke directly.
    ``run()`` does nothing, so no listener is started.
    """

    def __init__(self, name="", description="", *args, **kwargs):
        # Extra positional/keyword arguments (e.g. ``instructions=``) are
        # accepted and ignored so the stub keeps working if the server's
        # constructor call gains options.
        self.name = name
        self.description = description

    def tool(self, fn=None, **kwargs):
        """Return the decorated function unchanged.

        Supports both decorator forms:
        ``@mcp.tool`` (called with the function) and
        ``@mcp.tool()`` / ``@mcp.tool(name=...)`` (returns a decorator).
        A non-callable first argument (e.g. a tool name string) is treated
        as decorator configuration and ignored.
        """
        if callable(fn):
            return fn

        def decorator(inner):
            return inner

        return decorator

    def run(self, *args, **kwargs):
        """No-op replacement for the real ``run``; starts nothing."""
        return None
# Register a stub ``mcp`` -> ``mcp.server`` -> ``mcp.server.fastmcp`` package
# chain in sys.modules, so that importing ``mcp.server.fastmcp`` yields
# _FakeFastMCP. setdefault leaves any module already present untouched.
_mcp_pkg = types.ModuleType("mcp")
_mcp_server = types.ModuleType("mcp.server")
_mcp_fastmcp = types.ModuleType("mcp.server.fastmcp")
_mcp_fastmcp.FastMCP = _FakeFastMCP
_mcp_server.fastmcp = _mcp_fastmcp
_mcp_pkg.server = _mcp_server
for _stub in (_mcp_pkg, _mcp_server, _mcp_fastmcp):
    sys.modules.setdefault(_stub.__name__, _stub)
del _stub
# With the stubs registered it is safe to import ``server``; first make the
# directory one level above this test file importable.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), os.pardir))
| import server as S | |
# Reset store between tests
@pytest.fixture(autouse=True)
def clear_store():
    """Empty ``S._RECEIPT_STORE`` before and after every test in this module.

    ``autouse=True`` applies the fixture to every test without it being
    requested explicitly. Several tests assert exact receipt counts, so each
    one must start from an empty store regardless of execution order.
    """
    S._RECEIPT_STORE.clear()
    yield
    # Teardown: runs after the test body, whether it passed or failed.
    S._RECEIPT_STORE.clear()
| # --------------------------------------------------------------------------- | |
| # DSSE helpers | |
| # --------------------------------------------------------------------------- | |
class TestDSSEHelpers:
    """Unit tests for the DSSE helpers in ``server``: PAE encoding, dev-key
    sign/verify, and envelope wrap/verify."""

    def test_pae_structure(self):
        # PAE = "DSSEv1" SP len(type) SP type SP len(body) SP body
        pae = S._pae("a", b"b")
        assert pae == b"DSSEv1 1 a 1 b"

    def test_pae_empty_payload(self):
        # An empty body still contributes its length ("0") and the trailing
        # separator, so the whole encoding is determined exactly.
        # len("application/vnd.in-toto+json") == 28
        pae = S._pae("application/vnd.in-toto+json", b"")
        assert pae == b"DSSEv1 28 application/vnd.in-toto+json 0 "

    def test_sign_verify_roundtrip(self):
        msg = b"hello governance"
        sig = S._sign_dev(msg)
        sig_b64 = base64.standard_b64encode(sig).decode()
        assert S._verify_dev(msg, sig_b64)

    def test_verify_wrong_message_fails(self):
        msg = b"correct"
        sig = S._sign_dev(msg)
        sig_b64 = base64.standard_b64encode(sig).decode()
        assert not S._verify_dev(b"tampered", sig_b64)

    def test_dsse_wrap_structure(self):
        env = S.dsse_wrap({"action": "test"})
        assert "payload" in env
        assert "payloadType" in env
        assert env["payloadType"] == S.SLSA_V1_PAYLOAD_TYPE
        assert len(env["signatures"]) == 1
        assert "keyid" in env["signatures"][0]
        assert "sig" in env["signatures"][0]

    def test_dsse_wrap_deterministic(self):
        env1 = S.dsse_wrap({"a": 1, "b": 2})
        env2 = S.dsse_wrap({"b": 2, "a": 1})  # key-order invariant via sort_keys
        assert env1 == env2

    def test_dsse_verify_valid(self):
        env = S.dsse_wrap({"test": True})
        assert S.dsse_verify_envelope(env)

    def test_dsse_verify_tampered_payload(self):
        env = S.dsse_wrap({"test": True})
        bad = dict(env)  # shallow copy is enough: only a top-level key changes
        bad["payload"] = base64.standard_b64encode(b'{"test":false}').decode()
        assert not S.dsse_verify_envelope(bad)

    def test_dsse_verify_tampered_type(self):
        # payloadType is part of the signed PAE, so changing it must
        # invalidate the signature.
        env = S.dsse_wrap({"test": True})
        bad = dict(env)
        bad["payloadType"] = "application/other"
        assert not S.dsse_verify_envelope(bad)
| # --------------------------------------------------------------------------- | |
| # Tool 1: emit_receipt | |
| # --------------------------------------------------------------------------- | |
class TestEmitReceipt:
    """Behaviour of the ``emit_receipt`` tool: result shape, receipt-id
    derivation, envelope validity, storage, parent threading, payload size
    limit, and timestamp format."""

    def test_returns_required_fields(self):
        result = S.emit_receipt(
            action_type="tool_call",
            agent_id="test-agent-001",
            payload={"tool": "bash", "command": "ls"},
        )
        assert "receipt_id" in result
        assert "envelope" in result
        assert "timestamp" in result
        assert result["parent_id"] is None

    def test_receipt_id_is_sha256_hex(self):
        result = S.emit_receipt("file_write", "agent-x", {"file": "foo.py"})
        assert len(result["receipt_id"]) == 64
        # Verify it's actually sha256 of the decoded envelope payload
        payload_bytes = base64.standard_b64decode(result["envelope"]["payload"])
        expected = hashlib.sha256(payload_bytes).hexdigest()
        assert result["receipt_id"] == expected

    def test_envelope_is_valid_dsse(self):
        result = S.emit_receipt("api_call", "agent-y", {"url": "https://example.com"})
        assert S.dsse_verify_envelope(result["envelope"])

    def test_receipt_stored_in_store(self):
        result = S.emit_receipt("code_edit", "agent-z", {"file": "main.py"})
        rid = result["receipt_id"]
        assert rid in S._RECEIPT_STORE

    def test_parent_id_threaded(self):
        r1 = S.emit_receipt("tool_call", "agent", {"step": 1})
        r2 = S.emit_receipt("tool_call", "agent", {"step": 2}, parent_id=r1["receipt_id"])
        assert r2["parent_id"] == r1["receipt_id"]

    def test_different_payloads_different_ids(self):
        r1 = S.emit_receipt("tool_call", "agent", {"x": 1})
        r2 = S.emit_receipt("tool_call", "agent", {"x": 2})
        assert r1["receipt_id"] != r2["receipt_id"]

    def test_payload_too_large_raises(self):
        # One character over 512 KiB of payload data; the error message is
        # expected to mention the limit (as 524288 or 512).
        giant = {"data": "x" * (512 * 1024 + 1)}
        with pytest.raises(ValueError, match="524288|512"):
            S.emit_receipt("tool_call", "agent", giant)

    def test_timestamp_iso8601(self):
        import datetime

        result = S.emit_receipt("tool_call", "agent", {})
        ts = result["timestamp"]
        # datetime.fromisoformat accepts a trailing "Z" only from Python 3.11
        # on; normalise it so the check does not depend on the interpreter.
        if ts.endswith("Z"):
            ts = ts[:-1] + "+00:00"
        # Must parse without error
        dt = datetime.datetime.fromisoformat(ts)
        assert dt.tzinfo is not None  # timezone-aware ...
        assert dt.utcoffset() == datetime.timedelta(0)  # ... and specifically UTC
| # --------------------------------------------------------------------------- | |
| # Tool 2: verify_receipt | |
| # --------------------------------------------------------------------------- | |
class TestVerifyReceipt:
    """Behaviour of ``verify_receipt`` for valid, unknown and tampered receipts."""

    def test_verify_freshly_emitted(self):
        emitted = S.emit_receipt("tool_call", "agent", {"k": "v"})
        outcome = S.verify_receipt(emitted["receipt_id"])
        assert outcome["valid"] is True
        assert outcome["error"] is None
        assert outcome["keyid"] is not None

    def test_verify_unknown_id(self):
        missing_id = "a" * 64  # receipt-id length, but never emitted
        outcome = S.verify_receipt(missing_id)
        assert outcome["valid"] is False
        assert "not found" in outcome["error"]

    def test_verify_after_store_tamper(self):
        receipt_id = S.emit_receipt("tool_call", "agent", {"k": "v"})["receipt_id"]
        forged = base64.standard_b64encode(b'{"tampered":true}').decode()
        # Swap the payload held in the store; the stored signature was made
        # over the original payload, so verification is expected to fail.
        stored_envelope = S._RECEIPT_STORE[receipt_id]["envelope"]
        stored_envelope["payload"] = forged
        assert S.verify_receipt(receipt_id)["valid"] is False
| # --------------------------------------------------------------------------- | |
| # Tool 3: list_receipts | |
| # --------------------------------------------------------------------------- | |
class TestListReceipts:
    """Filtering, pagination and limit handling of ``list_receipts``."""

    def _emit_batch(self):
        # Fixture data: alice has 3 receipts (2 tool_call, 1 file_write) and
        # bob has 2 (1 api_call, 1 tool_call), so 3 tool_call receipts overall.
        S.emit_receipt("tool_call", "alice", {"n": 1})
        S.emit_receipt("tool_call", "alice", {"n": 2})
        S.emit_receipt("file_write", "alice", {"n": 3})
        S.emit_receipt("api_call", "bob", {"n": 4})
        S.emit_receipt("tool_call", "bob", {"n": 5})

    def test_list_all(self):
        self._emit_batch()
        result = S.list_receipts()
        assert result["total"] == 5
        assert len(result["receipts"]) == 5

    def test_filter_by_agent(self):
        self._emit_batch()
        result = S.list_receipts(agent_id="alice")
        assert result["total"] == 3
        assert all(r["agent_id"] == "alice" for r in result["receipts"])

    def test_filter_by_action_type(self):
        self._emit_batch()
        result = S.list_receipts(action_type="tool_call")
        assert result["total"] == 3

    def test_filter_combined(self):
        self._emit_batch()
        result = S.list_receipts(agent_id="alice", action_type="tool_call")
        assert result["total"] == 2

    def test_pagination(self):
        self._emit_batch()
        p1 = S.list_receipts(limit=2, offset=0)
        p2 = S.list_receipts(limit=2, offset=2)
        assert len(p1["receipts"]) == 2
        assert len(p2["receipts"]) == 2
        ids1 = {r["receipt_id"] for r in p1["receipts"]}
        ids2 = {r["receipt_id"] for r in p2["receipts"]}
        assert ids1.isdisjoint(ids2)

    def test_limit_capped_at_500(self):
        # Store more receipts than the cap so a missing clamp is observable
        # (on an empty store this test could never fail). Distinct payloads
        # keep the receipt ids distinct.
        for i in range(501):
            S.emit_receipt("tool_call", "agent", {"i": i})
        # Should not error even with an outrageous limit, and must clamp it.
        result = S.list_receipts(limit=99999)
        assert isinstance(result["receipts"], list)
        assert len(result["receipts"]) <= 500

    def test_empty_store(self):
        result = S.list_receipts()
        assert result["total"] == 0
        assert result["receipts"] == []
| # --------------------------------------------------------------------------- | |
| # Tool 4: get_attestation_chain | |
| # --------------------------------------------------------------------------- | |
class TestGetAttestationChain:
    """Parent-chain walking behaviour of ``get_attestation_chain``."""

    def test_single_node_chain(self):
        r = S.emit_receipt("tool_call", "agent", {})
        result = S.get_attestation_chain(r["receipt_id"])
        assert len(result["chain"]) == 1
        assert result["chain"][0]["depth"] == 0
        assert result["root_id"] == r["receipt_id"]

    def test_three_node_chain(self):
        r1 = S.emit_receipt("tool_call", "agent", {"step": 1})
        r2 = S.emit_receipt("tool_call", "agent", {"step": 2}, parent_id=r1["receipt_id"])
        r3 = S.emit_receipt("tool_call", "agent", {"step": 3}, parent_id=r2["receipt_id"])
        result = S.get_attestation_chain(r3["receipt_id"])
        chain = result["chain"]
        # The walk starts at the requested leaf (depth 0) and ends at the root.
        assert len(chain) == 3
        assert chain[0]["receipt_id"] == r3["receipt_id"]
        assert chain[0]["depth"] == 0
        assert chain[2]["receipt_id"] == r1["receipt_id"]
        assert chain[2]["depth"] == 2
        assert result["root_id"] == r1["receipt_id"]

    def test_unknown_receipt_returns_empty(self):
        result = S.get_attestation_chain("b" * 64)
        assert result["chain"] == []

    def test_max_depth_respected(self):
        # Build a linear chain of 20 receipts, each pointing at the previous one.
        prev_id = None
        for i in range(20):
            r = S.emit_receipt("tool_call", "agent", {"i": i}, parent_id=prev_id)
            prev_id = r["receipt_id"]
        result = S.get_attestation_chain(prev_id, max_depth=5)
        chain = result["chain"]
        # A truncated walk must still start at the requested leaf; without
        # these checks an empty chain would satisfy the upper bound below.
        assert chain, "depth-limited walk must not be empty"
        assert chain[0]["receipt_id"] == prev_id
        assert len(chain) <= 6  # depth 0..5 inclusive
        # Depths increase by one per hop from the leaf.
        assert [node["depth"] for node in chain] == list(range(len(chain)))