# mcp-receipts-server / tests / test_server.py
# betterwithage's picture
# fix: shorten short_description to ≤60 chars
# 8ad3ee3 verified
"""
pytest suite for MCP Governance Receipts Server.
Tests cover all 4 MCP tools:
- emit_receipt
- verify_receipt
- list_receipts
- get_attestation_chain
Spec: https://modelcontextprotocol.io
DSSE: https://slsa.dev/spec/v1.0/about
Thesis DOI: https://doi.org/10.5281/zenodo.20434276
SPDX-License-Identifier: Apache-2.0
Copyright 2026 SZL Holdings / Lutar, Stephen P. ORCID 0009-0001-0110-4173
"""
from __future__ import annotations
import base64
import hashlib
import importlib
import json
import sys
import types
import os
import pytest
# ---------------------------------------------------------------------------
# Isolate the server module so tests don't start the HTTP listener
# ---------------------------------------------------------------------------
# Patch FastMCP before importing server so no network socket is opened
class _FakeFastMCP:
def __init__(self, name="", description=""):
self.name = name
def tool(self):
def decorator(fn):
return fn
return decorator
def run(self, *args, **kwargs):
pass
# Inject fake 'mcp' package
# Assemble a stub ``mcp`` -> ``mcp.server`` -> ``mcp.server.fastmcp`` hierarchy
# whose only payload is the fake FastMCP class.
_mcp_pkg = types.ModuleType("mcp")
_mcp_server = types.ModuleType("mcp.server")
_mcp_fastmcp = types.ModuleType("mcp.server.fastmcp")
_mcp_fastmcp.FastMCP = _FakeFastMCP
_mcp_server.fastmcp = _mcp_fastmcp
_mcp_pkg.server = _mcp_server

# Register every stub under its dotted name.  setdefault() leaves any entry
# that is already present in sys.modules untouched.
for _stub in (_mcp_pkg, _mcp_server, _mcp_fastmcp):
    sys.modules.setdefault(_stub.__name__, _stub)
del _stub

# Now safe to import: put this file's parent directory first on sys.path and
# load the module under test.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import server as S
# Reset store between tests
@pytest.fixture(autouse=True)
def clear_store():
    """Empty the server's receipt store before and after every test.

    Declared ``autouse`` so each test starts with no receipts and leaves none
    behind for the next one.
    """
    S._RECEIPT_STORE.clear()  # setup: drop anything left by earlier code
    yield
    S._RECEIPT_STORE.clear()  # teardown: drop whatever this test emitted
# ---------------------------------------------------------------------------
# DSSE helpers
# ---------------------------------------------------------------------------
class TestDSSEHelpers:
    """Tests for the server's PAE, dev-signature, and DSSE envelope helpers."""

    def test_pae_structure(self):
        """A one-character type and one-byte payload give the expected bytes."""
        assert S._pae("a", b"b") == b"DSSEv1 1 a 1 b"

    def test_pae_empty_payload(self):
        """An empty payload still yields the header and a zero length field."""
        encoded = S._pae("application/vnd.in-toto+json", b"")
        assert encoded.startswith(b"DSSEv1 ")
        assert b" 0 " in encoded

    def test_sign_verify_roundtrip(self):
        """A signature verifies against the message it was made over."""
        message = b"hello governance"
        signature_b64 = base64.standard_b64encode(S._sign_dev(message)).decode()
        assert S._verify_dev(message, signature_b64)

    def test_verify_wrong_message_fails(self):
        """A signature does not verify against a different message."""
        signed_message = b"correct"
        signature_b64 = base64.standard_b64encode(
            S._sign_dev(signed_message)
        ).decode()
        assert not S._verify_dev(b"tampered", signature_b64)

    def test_dsse_wrap_structure(self):
        """The envelope has payload, payloadType, and exactly one signature."""
        envelope = S.dsse_wrap({"action": "test"})
        for field in ("payload", "payloadType"):
            assert field in envelope
        assert envelope["payloadType"] == S.SLSA_V1_PAYLOAD_TYPE
        assert len(envelope["signatures"]) == 1
        only_signature = envelope["signatures"][0]
        for field in ("keyid", "sig"):
            assert field in only_signature

    def test_dsse_wrap_deterministic(self):
        """Wrapping equal dicts gives equal envelopes regardless of key order."""
        forward = S.dsse_wrap({"a": 1, "b": 2})
        # Same mapping with the keys inserted in the opposite order; the
        # original author attributes the invariance to sort_keys in the server.
        backward = S.dsse_wrap({"b": 2, "a": 1})
        assert forward == backward

    def test_dsse_verify_valid(self):
        """A freshly wrapped envelope verifies."""
        assert S.dsse_verify_envelope(S.dsse_wrap({"test": True}))

    def test_dsse_verify_tampered_payload(self):
        """Swapping the payload after signing makes verification fail."""
        envelope = S.dsse_wrap({"test": True})
        forged_payload = base64.standard_b64encode(b'{"test":false}').decode()
        # Shallow copy with only the payload replaced; signatures are kept.
        tampered = {**envelope, "payload": forged_payload}
        assert not S.dsse_verify_envelope(tampered)

    def test_dsse_verify_tampered_type(self):
        """Changing payloadType after signing makes verification fail."""
        envelope = S.dsse_wrap({"test": True})
        tampered = {**envelope, "payloadType": "application/other"}
        assert not S.dsse_verify_envelope(tampered)
# ---------------------------------------------------------------------------
# Tool 1: emit_receipt
# ---------------------------------------------------------------------------
class TestEmitReceipt:
    """Tests for the ``emit_receipt`` tool."""

    def test_returns_required_fields(self):
        """The result carries id, envelope, timestamp, and a null parent."""
        receipt = S.emit_receipt(
            action_type="tool_call",
            agent_id="test-agent-001",
            payload={"tool": "bash", "command": "ls"},
        )
        for field in ("receipt_id", "envelope", "timestamp"):
            assert field in receipt
        assert receipt["parent_id"] is None

    def test_receipt_id_is_sha256_hex(self):
        """The receipt id is the SHA-256 hex digest of the decoded payload."""
        receipt = S.emit_receipt("file_write", "agent-x", {"file": "foo.py"})
        receipt_id = receipt["receipt_id"]
        assert len(receipt_id) == 64
        # Recompute the digest from the envelope's base64 payload.
        raw_payload = base64.standard_b64decode(receipt["envelope"]["payload"])
        assert receipt_id == hashlib.sha256(raw_payload).hexdigest()

    def test_envelope_is_valid_dsse(self):
        """The emitted envelope passes DSSE verification."""
        receipt = S.emit_receipt(
            "api_call", "agent-y", {"url": "https://example.com"}
        )
        assert S.dsse_verify_envelope(receipt["envelope"])

    def test_receipt_stored_in_store(self):
        """The emitted receipt is stored under its own id."""
        receipt = S.emit_receipt("code_edit", "agent-z", {"file": "main.py"})
        assert receipt["receipt_id"] in S._RECEIPT_STORE

    def test_parent_id_threaded(self):
        """A supplied ``parent_id`` is echoed back on the child receipt."""
        parent = S.emit_receipt("tool_call", "agent", {"step": 1})
        child = S.emit_receipt(
            "tool_call", "agent", {"step": 2}, parent_id=parent["receipt_id"]
        )
        assert child["parent_id"] == parent["receipt_id"]

    def test_different_payloads_different_ids(self):
        """Receipts that differ only in payload get different ids."""
        first = S.emit_receipt("tool_call", "agent", {"x": 1})
        second = S.emit_receipt("tool_call", "agent", {"x": 2})
        assert first["receipt_id"] != second["receipt_id"]

    def test_payload_too_large_raises(self):
        """A payload string one byte over 512 KiB raises ``ValueError``."""
        oversized = {"data": "x" * (512 * 1024 + 1)}
        with pytest.raises(ValueError, match="524288|512"):
            S.emit_receipt("tool_call", "agent", oversized)

    def test_timestamp_iso8601(self):
        """The timestamp parses as ISO 8601 and is timezone-aware."""
        import datetime

        stamp = S.emit_receipt("tool_call", "agent", {})["timestamp"]
        # fromisoformat raises if the string is not a parseable ISO timestamp.
        parsed = datetime.datetime.fromisoformat(stamp)
        assert parsed.tzinfo is not None
# ---------------------------------------------------------------------------
# Tool 2: verify_receipt
# ---------------------------------------------------------------------------
class TestVerifyReceipt:
    """Tests for the ``verify_receipt`` tool."""

    def test_verify_freshly_emitted(self):
        """A receipt that was just emitted verifies with no error and a key id."""
        receipt = S.emit_receipt("tool_call", "agent", {"k": "v"})
        verdict = S.verify_receipt(receipt["receipt_id"])
        assert verdict["valid"] is True
        assert verdict["error"] is None
        assert verdict["keyid"] is not None

    def test_verify_unknown_id(self):
        """An id that was never emitted is invalid with a 'not found' error."""
        verdict = S.verify_receipt("a" * 64)
        assert verdict["valid"] is False
        assert "not found" in verdict["error"]

    def test_verify_after_store_tamper(self):
        """Overwriting the stored payload makes verification fail."""
        receipt_id = S.emit_receipt("tool_call", "agent", {"k": "v"})["receipt_id"]
        forged_payload = base64.standard_b64encode(b'{"tampered":true}').decode()
        # Replace the payload in the store itself; the signature is untouched.
        S._RECEIPT_STORE[receipt_id]["envelope"]["payload"] = forged_payload
        assert S.verify_receipt(receipt_id)["valid"] is False
# ---------------------------------------------------------------------------
# Tool 3: list_receipts
# ---------------------------------------------------------------------------
class TestListReceipts:
    """Tests for the ``list_receipts`` tool: filters, pagination, limit cap."""

    def _emit_batch(self):
        """Emit five receipts: three for ``alice`` and two for ``bob``.

        Three of the five use the ``tool_call`` action type (two from alice,
        one from bob); the others are ``file_write`` and ``api_call``.
        """
        S.emit_receipt("tool_call", "alice", {"n": 1})
        S.emit_receipt("tool_call", "alice", {"n": 2})
        S.emit_receipt("file_write", "alice", {"n": 3})
        S.emit_receipt("api_call", "bob", {"n": 4})
        S.emit_receipt("tool_call", "bob", {"n": 5})

    def test_list_all(self):
        """With no filters, all five receipts are counted and returned."""
        self._emit_batch()
        result = S.list_receipts()
        assert result["total"] == 5
        assert len(result["receipts"]) == 5

    def test_filter_by_agent(self):
        """Filtering by agent returns only that agent's receipts."""
        self._emit_batch()
        result = S.list_receipts(agent_id="alice")
        assert result["total"] == 3
        assert all(r["agent_id"] == "alice" for r in result["receipts"])

    def test_filter_by_action_type(self):
        """Filtering by action type counts the three ``tool_call`` receipts."""
        self._emit_batch()
        result = S.list_receipts(action_type="tool_call")
        assert result["total"] == 3

    def test_filter_combined(self):
        """Agent and action-type filters apply together."""
        self._emit_batch()
        result = S.list_receipts(agent_id="alice", action_type="tool_call")
        assert result["total"] == 2

    def test_pagination(self):
        """Consecutive pages of size two are full and do not overlap."""
        self._emit_batch()
        p1 = S.list_receipts(limit=2, offset=0)
        p2 = S.list_receipts(limit=2, offset=2)
        assert len(p1["receipts"]) == 2
        assert len(p2["receipts"]) == 2
        ids1 = {r["receipt_id"] for r in p1["receipts"]}
        ids2 = {r["receipt_id"] for r in p2["receipts"]}
        assert ids1.isdisjoint(ids2)

    def test_limit_capped_at_500(self):
        """An oversized ``limit`` is accepted and returns at most 500 receipts.

        The cap of 500 comes from this test's name; the server's limit
        handling is not visible from this file.
        """
        # Store one receipt more than the cap so an uncapped listing is
        # detectable; against an empty store the check proves nothing.
        for n in range(501):
            S.emit_receipt("tool_call", "agent", {"n": n})
        # Should not error even with outrageous limit
        result = S.list_receipts(limit=99999)
        assert isinstance(result["receipts"], list)
        assert len(result["receipts"]) <= 500

    def test_empty_store(self):
        """An empty store lists zero receipts."""
        result = S.list_receipts()
        assert result["total"] == 0
        assert result["receipts"] == []
# ---------------------------------------------------------------------------
# Tool 4: get_attestation_chain
# ---------------------------------------------------------------------------
class TestGetAttestationChain:
    """Tests for the ``get_attestation_chain`` tool."""

    def test_single_node_chain(self):
        """A receipt with no parent gives a one-node chain rooted at itself."""
        r = S.emit_receipt("tool_call", "agent", {})
        result = S.get_attestation_chain(r["receipt_id"])
        assert len(result["chain"]) == 1
        assert result["chain"][0]["depth"] == 0
        assert result["root_id"] == r["receipt_id"]

    def test_three_node_chain(self):
        """The chain runs from the requested receipt (depth 0) to the root."""
        r1 = S.emit_receipt("tool_call", "agent", {"step": 1})
        r2 = S.emit_receipt(
            "tool_call", "agent", {"step": 2}, parent_id=r1["receipt_id"]
        )
        r3 = S.emit_receipt(
            "tool_call", "agent", {"step": 3}, parent_id=r2["receipt_id"]
        )
        result = S.get_attestation_chain(r3["receipt_id"])
        chain = result["chain"]
        assert len(chain) == 3
        assert chain[0]["receipt_id"] == r3["receipt_id"]
        assert chain[0]["depth"] == 0
        assert chain[2]["receipt_id"] == r1["receipt_id"]
        assert chain[2]["depth"] == 2
        assert result["root_id"] == r1["receipt_id"]

    def test_unknown_receipt_returns_empty(self):
        """An id that was never emitted gives an empty chain."""
        result = S.get_attestation_chain("b" * 64)
        assert result["chain"] == []

    def test_max_depth_respected(self):
        """``max_depth`` truncates the walk without dropping the start node."""
        # Build a chain of 20, each receipt parented on the previous one.
        prev_id = None
        for i in range(20):
            r = S.emit_receipt("tool_call", "agent", {"i": i}, parent_id=prev_id)
            prev_id = r["receipt_id"]
        chain = S.get_attestation_chain(prev_id, max_depth=5)["chain"]
        # A length bound alone would also pass on an empty chain, so pin the
        # starting node before checking the truncation.
        assert chain
        assert chain[0]["receipt_id"] == prev_id
        assert chain[0]["depth"] == 0
        assert len(chain) <= 6  # depth 0..5 inclusive
        assert all(node["depth"] <= 5 for node in chain)