""" pytest suite for MCP Governance Receipts Server. Tests cover all 4 MCP tools: - emit_receipt - verify_receipt - list_receipts - get_attestation_chain Spec: https://modelcontextprotocol.io DSSE: https://slsa.dev/spec/v1.0/about Thesis DOI: https://doi.org/10.5281/zenodo.20434276 SPDX-License-Identifier: Apache-2.0 Copyright 2026 SZL Holdings / Lutar, Stephen P. ORCID 0009-0001-0110-4173 """ from __future__ import annotations import base64 import hashlib import importlib import json import sys import types import os import pytest # --------------------------------------------------------------------------- # Isolate the server module so tests don't start the HTTP listener # --------------------------------------------------------------------------- # Patch FastMCP before importing server so no network socket is opened class _FakeFastMCP: def __init__(self, name="", description=""): self.name = name def tool(self): def decorator(fn): return fn return decorator def run(self, *args, **kwargs): pass # Inject fake 'mcp' package _mcp_pkg = types.ModuleType("mcp") _mcp_server = types.ModuleType("mcp.server") _mcp_fastmcp = types.ModuleType("mcp.server.fastmcp") _mcp_fastmcp.FastMCP = _FakeFastMCP _mcp_pkg.server = _mcp_server _mcp_server.fastmcp = _mcp_fastmcp sys.modules.setdefault("mcp", _mcp_pkg) sys.modules.setdefault("mcp.server", _mcp_server) sys.modules.setdefault("mcp.server.fastmcp", _mcp_fastmcp) # Now safe to import sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) import server as S # Reset store between tests @pytest.fixture(autouse=True) def clear_store(): S._RECEIPT_STORE.clear() yield S._RECEIPT_STORE.clear() # --------------------------------------------------------------------------- # DSSE helpers # --------------------------------------------------------------------------- class TestDSSEHelpers: def test_pae_structure(self): pae = S._pae("a", b"b") assert pae == b"DSSEv1 1 a 1 b" def test_pae_empty_payload(self): pae = 
S._pae("application/vnd.in-toto+json", b"") assert pae.startswith(b"DSSEv1 ") assert b" 0 " in pae def test_sign_verify_roundtrip(self): msg = b"hello governance" sig = S._sign_dev(msg) sig_b64 = base64.standard_b64encode(sig).decode() assert S._verify_dev(msg, sig_b64) def test_verify_wrong_message_fails(self): msg = b"correct" sig = S._sign_dev(msg) sig_b64 = base64.standard_b64encode(sig).decode() assert not S._verify_dev(b"tampered", sig_b64) def test_dsse_wrap_structure(self): env = S.dsse_wrap({"action": "test"}) assert "payload" in env assert "payloadType" in env assert env["payloadType"] == S.SLSA_V1_PAYLOAD_TYPE assert len(env["signatures"]) == 1 assert "keyid" in env["signatures"][0] assert "sig" in env["signatures"][0] def test_dsse_wrap_deterministic(self): env1 = S.dsse_wrap({"a": 1, "b": 2}) env2 = S.dsse_wrap({"b": 2, "a": 1}) # key-order invariant via sort_keys assert env1 == env2 def test_dsse_verify_valid(self): env = S.dsse_wrap({"test": True}) assert S.dsse_verify_envelope(env) def test_dsse_verify_tampered_payload(self): env = S.dsse_wrap({"test": True}) bad = dict(env) bad["payload"] = base64.standard_b64encode(b'{"test":false}').decode() assert not S.dsse_verify_envelope(bad) def test_dsse_verify_tampered_type(self): env = S.dsse_wrap({"test": True}) bad = dict(env) bad["payloadType"] = "application/other" assert not S.dsse_verify_envelope(bad) # --------------------------------------------------------------------------- # Tool 1: emit_receipt # --------------------------------------------------------------------------- class TestEmitReceipt: def test_returns_required_fields(self): result = S.emit_receipt( action_type="tool_call", agent_id="test-agent-001", payload={"tool": "bash", "command": "ls"}, ) assert "receipt_id" in result assert "envelope" in result assert "timestamp" in result assert result["parent_id"] is None def test_receipt_id_is_sha256_hex(self): result = S.emit_receipt("file_write", "agent-x", {"file": "foo.py"}) assert 
len(result["receipt_id"]) == 64 # Verify it's actually sha256 of decoded payload payload_bytes = base64.standard_b64decode(result["envelope"]["payload"]) expected = hashlib.sha256(payload_bytes).hexdigest() assert result["receipt_id"] == expected def test_envelope_is_valid_dsse(self): result = S.emit_receipt("api_call", "agent-y", {"url": "https://example.com"}) assert S.dsse_verify_envelope(result["envelope"]) def test_receipt_stored_in_store(self): result = S.emit_receipt("code_edit", "agent-z", {"file": "main.py"}) rid = result["receipt_id"] assert rid in S._RECEIPT_STORE def test_parent_id_threaded(self): r1 = S.emit_receipt("tool_call", "agent", {"step": 1}) r2 = S.emit_receipt("tool_call", "agent", {"step": 2}, parent_id=r1["receipt_id"]) assert r2["parent_id"] == r1["receipt_id"] def test_different_payloads_different_ids(self): r1 = S.emit_receipt("tool_call", "agent", {"x": 1}) r2 = S.emit_receipt("tool_call", "agent", {"x": 2}) assert r1["receipt_id"] != r2["receipt_id"] def test_payload_too_large_raises(self): giant = {"data": "x" * (512 * 1024 + 1)} with pytest.raises(ValueError, match="524288|512"): S.emit_receipt("tool_call", "agent", giant) def test_timestamp_iso8601(self): result = S.emit_receipt("tool_call", "agent", {}) ts = result["timestamp"] # Must parse without error import datetime dt = datetime.datetime.fromisoformat(ts) assert dt.tzinfo is not None # UTC-aware # --------------------------------------------------------------------------- # Tool 2: verify_receipt # --------------------------------------------------------------------------- class TestVerifyReceipt: def test_verify_freshly_emitted(self): r = S.emit_receipt("tool_call", "agent", {"k": "v"}) v = S.verify_receipt(r["receipt_id"]) assert v["valid"] is True assert v["error"] is None assert v["keyid"] is not None def test_verify_unknown_id(self): v = S.verify_receipt("a" * 64) assert v["valid"] is False assert "not found" in v["error"] def test_verify_after_store_tamper(self): r = 
S.emit_receipt("tool_call", "agent", {"k": "v"}) rid = r["receipt_id"] # Tamper with stored envelope payload S._RECEIPT_STORE[rid]["envelope"]["payload"] = base64.standard_b64encode( b'{"tampered":true}' ).decode() v = S.verify_receipt(rid) assert v["valid"] is False # --------------------------------------------------------------------------- # Tool 3: list_receipts # --------------------------------------------------------------------------- class TestListReceipts: def _emit_batch(self): S.emit_receipt("tool_call", "alice", {"n": 1}) S.emit_receipt("tool_call", "alice", {"n": 2}) S.emit_receipt("file_write", "alice", {"n": 3}) S.emit_receipt("api_call", "bob", {"n": 4}) S.emit_receipt("tool_call", "bob", {"n": 5}) def test_list_all(self): self._emit_batch() result = S.list_receipts() assert result["total"] == 5 assert len(result["receipts"]) == 5 def test_filter_by_agent(self): self._emit_batch() result = S.list_receipts(agent_id="alice") assert result["total"] == 3 assert all(r["agent_id"] == "alice" for r in result["receipts"]) def test_filter_by_action_type(self): self._emit_batch() result = S.list_receipts(action_type="tool_call") assert result["total"] == 3 def test_filter_combined(self): self._emit_batch() result = S.list_receipts(agent_id="alice", action_type="tool_call") assert result["total"] == 2 def test_pagination(self): self._emit_batch() p1 = S.list_receipts(limit=2, offset=0) p2 = S.list_receipts(limit=2, offset=2) assert len(p1["receipts"]) == 2 assert len(p2["receipts"]) == 2 ids1 = {r["receipt_id"] for r in p1["receipts"]} ids2 = {r["receipt_id"] for r in p2["receipts"]} assert ids1.isdisjoint(ids2) def test_limit_capped_at_500(self): # Should not error even with outrageous limit result = S.list_receipts(limit=99999) assert isinstance(result["receipts"], list) def test_empty_store(self): result = S.list_receipts() assert result["total"] == 0 assert result["receipts"] == [] # 
# ---------------------------------------------------------------------------
# Tool 4: get_attestation_chain
# ---------------------------------------------------------------------------

class TestGetAttestationChain:
    """Tests for the ``get_attestation_chain`` tool.

    The chain is expected to start at the requested receipt (depth 0) and
    follow ``parent_id`` links, with ``root_id`` naming the last ancestor.
    """

    def test_single_node_chain(self):
        r = S.emit_receipt("tool_call", "agent", {})
        result = S.get_attestation_chain(r["receipt_id"])
        assert len(result["chain"]) == 1
        assert result["chain"][0]["depth"] == 0
        assert result["root_id"] == r["receipt_id"]

    def test_three_node_chain(self):
        r1 = S.emit_receipt("tool_call", "agent", {"step": 1})
        r2 = S.emit_receipt("tool_call", "agent", {"step": 2}, parent_id=r1["receipt_id"])
        r3 = S.emit_receipt("tool_call", "agent", {"step": 3}, parent_id=r2["receipt_id"])
        result = S.get_attestation_chain(r3["receipt_id"])
        chain = result["chain"]
        assert len(chain) == 3
        assert chain[0]["receipt_id"] == r3["receipt_id"]
        assert chain[0]["depth"] == 0
        assert chain[2]["receipt_id"] == r1["receipt_id"]
        assert chain[2]["depth"] == 2
        assert result["root_id"] == r1["receipt_id"]

    def test_unknown_receipt_returns_empty(self):
        result = S.get_attestation_chain("b" * 64)
        assert result["chain"] == []

    def test_max_depth_respected(self):
        # Build a chain of 20
        prev_id = None
        for i in range(20):
            r = S.emit_receipt("tool_call", "agent", {"i": i}, parent_id=prev_id)
            prev_id = r["receipt_id"]
        result = S.get_attestation_chain(prev_id, max_depth=5)
        chain = result["chain"]
        # The lower bound keeps an empty chain from satisfying the depth check.
        assert 1 <= len(chain) <= 6  # depth 0..5 inclusive
        assert chain[0]["receipt_id"] == prev_id