# NOTE(review): removed stray page-scrape residue ("Spaces / Sleeping") — not Python source.
# MERGED: OpenCode (deep KV physics) + CC (surface coverage)
# All tests hermetic: no GPU, no TCP, no downloaded weights required
from __future__ import annotations

import logging
from datetime import datetime, timedelta, timezone

import numpy as np
import pytest

# Optional dep guard — skip entire module if fastapi not installed.
# Must run BEFORE any `fastapi.*` import below.
fastapi = pytest.importorskip("fastapi", reason="fastapi not installed — install with: pip install fastapi")
from fastapi.testclient import TestClient

from apohara_context_forge.mcp import server as srv
from apohara_context_forge.mcp.server import (
    app,
    get_compressor,
    get_coordinator,
    get_metrics,
    get_registry,
)
from apohara_context_forge.models import (
    CompressionDecision,
    ContextEntry,
    Degradation,
    MetricsSnapshot,
)
from apohara_context_forge.registry.context_registry import ContextRegistry
| # ---- Fakes (module-level so dependency_overrides + lifespan patches both work) ----- | |
class FakeMetrics:
    """Metrics-collector double: records calls and echoes snapshot inputs back.

    Attribute names (``register_calls``, ``decision_calls``,
    ``_snapshot_kwargs``) are part of this fake's test-facing contract.
    """

    def __init__(self, *, gpu_label: str = "cuda", raise_on_label: bool = False) -> None:
        self._gpu_label = gpu_label
        self._raise_on_label = raise_on_label
        self.register_calls: list[bool] = []
        self.decision_calls: list[CompressionDecision] = []
        self._snapshot_kwargs: dict | None = None

    def _resolve_gpu_label(self) -> str:
        """Return the configured label, or blow up to simulate a probe failure."""
        if self._raise_on_label:
            raise RuntimeError("gpu probe blew up")
        return self._gpu_label

    def record_register(self, matched: bool) -> None:
        self.register_calls.append(matched)

    def record_decision(self, decision: CompressionDecision) -> None:
        self.decision_calls.append(decision)

    async def snapshot(
        self, *, current_compressor_model, compressor_degradations
    ) -> MetricsSnapshot:
        # Capture the keyword arguments so tests can assert on endpoint wiring.
        self._snapshot_kwargs = {
            "current_compressor_model": current_compressor_model,
            "compressor_degradations": compressor_degradations,
        }
        return MetricsSnapshot(
            vram_source="psutil",
            compressor_model=current_compressor_model,
            vram_used_gb=1.0,
            vram_total_gb=8.0,
            ttft_ms=0.0,
            tokens_processed=0,
            tokens_saved=0,
            dedup_rate=0.0,
            compression_ratio=0.0,
            degradations=list(compressor_degradations),
        )
class FakeCompressor:
    """Compressor double: only exposes the model name and degradation list."""

    def __init__(
        self,
        current_model: str = "xlm-roberta-large",
        degradations: list[Degradation] | None = None,
    ) -> None:
        self.current_model = current_model
        # `or []` gives every instance a fresh list when none (or an empty
        # one) is supplied, so instances never share mutable state.
        self.degradations = degradations or []
class FakeRegistry:
    """Registry double: returns a canned entry when given one, else fabricates
    a minimal valid entry with a 5-minute TTL."""

    def __init__(self, entry: ContextEntry | None = None) -> None:
        self._entry = entry
        self.register_calls: list[tuple[str, str]] = []
        self.cleared = False

    async def register(self, agent_id: str, context: str) -> ContextEntry:
        self.register_calls.append((agent_id, context))
        if self._entry is not None:
            return self._entry
        now = datetime.now(timezone.utc)
        return ContextEntry(
            agent_id=agent_id,
            context=context,
            # Whitespace split is a good-enough token proxy for tests.
            token_count=len(context.split()),
            created_at=now,
            expires_at=now + timedelta(seconds=300),
        )

    async def clear(self) -> None:
        self.cleared = True
class FakeCoordinator:
    """Coordinator double: always returns a fixed decision, or raises it when
    the configured value is an Exception instance."""

    def __init__(self, decision: CompressionDecision | Exception) -> None:
        self._decision = decision
        self.decide_calls: list[tuple[str, str]] = []

    async def decide(self, agent_id: str, context: str) -> CompressionDecision:
        # Record first so the failure path is observable too.
        self.decide_calls.append((agent_id, context))
        if isinstance(self._decision, Exception):
            raise self._decision
        return self._decision
| # ---- FakeDedupEngine for the full-flow test (re-uses test_registry pattern) --------- | |
class FakeDedupEngine:
    """Hermetic dedup engine: identical texts map to identical embeddings,
    distinct texts to distinct ones, so similarity is exactly 0.0 or 1.0."""

    def __init__(self) -> None:
        self._key_for_text: dict[str, float] = {}
        self._next_key: float = 1.0

    def _key(self, text: str) -> float:
        """Assign (and memoize) a unique float key per distinct text."""
        key = self._key_for_text.get(text)
        if key is None:
            key = self._next_key
            self._key_for_text[text] = key
            self._next_key += 1.0
        return key

    async def embed(self, text: str) -> np.ndarray:
        # 8-dim vector with the text's key in slot 0; rest stays zero.
        vec = np.zeros(8, dtype=np.float32)
        vec[0] = self._key(text)
        return vec

    async def similarity(self, e1: np.ndarray, e2: np.ndarray) -> float:
        # Exact match on the key slot → 1.0, otherwise 0.0.
        return 1.0 if float(e1[0]) == float(e2[0]) else 0.0

    def find_shared_prefix(self, a: str, b: str) -> str:
        """Return the longest common character prefix of *a* and *b*."""
        matched = 0
        for ch_a, ch_b in zip(a, b):
            if ch_a != ch_b:
                break
            matched += 1
        return a[:matched]

    def count_prefix_tokens(self, prefix: str) -> int:
        """Whitespace-token count of *prefix*."""
        return len(prefix.split())
| # ---- Helpers ------------------------------------------------------------------------ | |
def _client_with_overrides(overrides: dict) -> TestClient:
    """Build a TestClient that bypasses the production lifespan by injecting
    only the dependency overrides. We do NOT enter the context manager so the
    lifespan never fires (which means no real ContextCompressor / VLLMClient
    construction). Keys must be the dependency function references themselves
    (e.g. ``get_registry``) — FastAPI matches by identity, not by name."""
    app.dependency_overrides.update(overrides)
    return TestClient(app)
@pytest.fixture(autouse=True)
def _clear_overrides():
    """Autouse fixture: wipe FastAPI dependency overrides after every test.

    Without the ``@pytest.fixture(autouse=True)`` decoration this generator
    is never invoked by pytest, so overrides installed by one test (e.g. the
    mid-test ``get_coordinator`` override in the redaction test) would leak
    into every subsequent test.

    Yields:
        None. Cleanup (``app.dependency_overrides.clear()``) runs on teardown.
    """
    yield
    app.dependency_overrides.clear()
| # ---- Tests -------------------------------------------------------------------------- | |
def test_health_returns_ok_with_gpu_label() -> None:
    """/health reports status=ok plus the resolved GPU label."""
    fake_metrics = FakeMetrics(gpu_label="cuda")
    client = _client_with_overrides({get_metrics: lambda: fake_metrics})

    response = client.get("/health")

    assert response.status_code == 200
    assert response.json() == {"status": "ok", "gpu": "cuda"}
def test_health_returns_degraded_on_internal_error() -> None:
    """/health degrades gracefully (still 200) when the GPU probe raises."""
    fake_metrics = FakeMetrics(raise_on_label=True)
    client = _client_with_overrides({get_metrics: lambda: fake_metrics})

    response = client.get("/health")

    assert response.status_code == 200
    assert response.json() == {"status": "degraded", "gpu": "unknown"}
def test_metrics_snapshot_returns_valid_pydantic() -> None:
    """/metrics/snapshot round-trips through the MetricsSnapshot model and
    passes the compressor's live state into snapshot()."""
    fake_metrics = FakeMetrics()
    fake_compressor = FakeCompressor(
        current_model="xlm-roberta-large",
        degradations=[Degradation(component="compressor", reason="OOM", fallback="cpu")],
    )
    client = _client_with_overrides(
        {get_metrics: lambda: fake_metrics, get_compressor: lambda: fake_compressor}
    )

    response = client.get("/metrics/snapshot")

    assert response.status_code == 200
    snapshot = MetricsSnapshot.model_validate(response.json())
    assert snapshot.compressor_model == "xlm-roberta-large"
    assert any(d.component == "compressor" for d in snapshot.degradations)
    # The endpoint must have forwarded the compressor state as kwargs.
    assert fake_metrics._snapshot_kwargs is not None
    assert fake_metrics._snapshot_kwargs["current_compressor_model"] == "xlm-roberta-large"
def test_register_context_happy_path() -> None:
    """POST /tools/register_context echoes the registry's entry and records
    one non-matched registration metric."""
    now = datetime.now(timezone.utc)
    canned_entry = ContextEntry(
        agent_id="alice",
        context="hello world",
        token_count=2,
        created_at=now,
        expires_at=now + timedelta(seconds=300),
    )
    fake_registry = FakeRegistry(entry=canned_entry)
    fake_metrics = FakeMetrics()
    client = _client_with_overrides(
        {get_registry: lambda: fake_registry, get_metrics: lambda: fake_metrics}
    )

    response = client.post(
        "/tools/register_context",
        json={"agent_id": "alice", "context": "hello world"},
    )

    assert response.status_code == 200
    entry = ContextEntry.model_validate_json(response.text)
    assert entry.agent_id == "alice"
    assert entry.context == "hello world"
    # Fresh agent → no dedup match recorded.
    assert fake_metrics.register_calls == [False]
    assert fake_registry.register_calls == [("alice", "hello world")]
def test_register_context_422_on_empty_agent_id() -> None:
    """An empty agent_id must be rejected by request validation."""
    client = _client_with_overrides(
        {get_registry: lambda: FakeRegistry(), get_metrics: lambda: FakeMetrics()}
    )

    response = client.post(
        "/tools/register_context",
        json={"agent_id": "", "context": "x"},
    )

    assert response.status_code == 422
def test_register_context_422_on_extra_field() -> None:
    """Unknown request fields must be rejected (strict schema)."""
    client = _client_with_overrides(
        {get_registry: lambda: FakeRegistry(), get_metrics: lambda: FakeMetrics()}
    )

    response = client.post(
        "/tools/register_context",
        json={"agent_id": "a", "context": "x", "hostile": 1},
    )

    assert response.status_code == 422
def test_register_context_422_on_missing_field() -> None:
    """Omitting the required ``context`` field must be rejected."""
    client = _client_with_overrides(
        {get_registry: lambda: FakeRegistry(), get_metrics: lambda: FakeMetrics()}
    )

    response = client.post("/tools/register_context", json={"agent_id": "a"})

    assert response.status_code == 422
def test_get_optimized_context_happy_path() -> None:
    """POST /tools/get_optimized_context relays the coordinator's decision
    verbatim and records exactly one decision metric."""
    expected_decision = CompressionDecision(
        strategy="compress",
        final_context="compressed body",
        shared_prefix="",
        original_tokens=1000,
        final_tokens=500,
        tokens_saved=500,
        rationale="ctx_tokens > threshold",
    )
    fake_coordinator = FakeCoordinator(decision=expected_decision)
    fake_metrics = FakeMetrics()
    client = _client_with_overrides(
        {get_coordinator: lambda: fake_coordinator, get_metrics: lambda: fake_metrics}
    )

    response = client.post(
        "/tools/get_optimized_context",
        json={"agent_id": "alice", "context": "hello"},
    )

    assert response.status_code == 200
    decision = CompressionDecision.model_validate(response.json())
    assert decision == expected_decision
    assert len(fake_metrics.decision_calls) == 1
    assert fake_coordinator.decide_calls == [("alice", "hello")]
def test_get_optimized_context_503_fallback_on_handler_exception() -> None:
    """A coordinator crash yields a 503 passthrough echoing the request body,
    with zeroed token accounting and no decision metric recorded."""
    failing_coordinator = FakeCoordinator(decision=RuntimeError("boom"))
    fake_metrics = FakeMetrics()
    client = _client_with_overrides(
        {get_coordinator: lambda: failing_coordinator, get_metrics: lambda: fake_metrics}
    )

    response = client.post(
        "/tools/get_optimized_context",
        json={"agent_id": "alice", "context": "the original body"},
    )

    assert response.status_code == 503
    fallback = CompressionDecision.model_validate(response.json())
    assert fallback.strategy == "passthrough"
    assert fallback.final_context == "the original body"
    assert fallback.original_tokens == 0
    assert fallback.final_tokens == 0
    assert fallback.tokens_saved == 0
    # No metric may be recorded for a failed optimization.
    assert fake_metrics.decision_calls == []
def test_get_optimized_context_422_on_malformed_body() -> None:
    """A body missing ``context`` is rejected before the coordinator runs."""
    unused_decision = CompressionDecision(
        strategy="passthrough",
        final_context="",
        shared_prefix="",
        original_tokens=0,
        final_tokens=0,
        tokens_saved=0,
        rationale="",
    )
    client = _client_with_overrides(
        {
            get_coordinator: lambda: FakeCoordinator(decision=unused_decision),
            get_metrics: lambda: FakeMetrics(),
        }
    )

    response = client.post("/tools/get_optimized_context", json={"agent_id": "a"})

    assert response.status_code == 422
def test_no_log_includes_request_body(caplog: pytest.LogCaptureFixture) -> None:
    """No log record may ever contain the raw request context (redaction)."""
    sentinel = "REDACTION-SENTINEL-XYZZY-9F3A2B7C-do-not-log"
    fake_registry = FakeRegistry()
    fake_metrics = FakeMetrics()
    client = _client_with_overrides(
        {get_registry: lambda: fake_registry, get_metrics: lambda: fake_metrics}
    )
    with caplog.at_level(logging.DEBUG):
        # Trigger both happy-path register AND the 503 warning path so any
        # mishandled log surface is exercised.
        client.post(
            "/tools/register_context",
            json={"agent_id": "alice", "context": sentinel},
        )
        # Now exercise the 503 path with the sentinel in the body.
        bad_coord = FakeCoordinator(decision=RuntimeError("boom"))
        app.dependency_overrides[get_coordinator] = lambda: bad_coord
        client.post(
            "/tools/get_optimized_context",
            json={"agent_id": "alice", "context": sentinel},
        )
    for record in caplog.records:
        assert sentinel not in record.getMessage()
        # Sweep every attribute of the record too (args, extras, ...).
        for attr_value in record.__dict__.values():
            assert sentinel not in str(attr_value)
def test_lifespan_constructs_and_disposes(monkeypatch: pytest.MonkeyPatch) -> None:
    """Entering the lifespan wires app.state; exiting runs cleanup.

    The heavy production classes the lifespan reaches for are replaced so
    ``with TestClient(app)`` never downloads model weights or touches the
    network.
    """

    class _StubRegistry:
        instances: list = []

        def __init__(self) -> None:
            self.cleared = False
            type(self).instances.append(self)

        async def clear(self) -> None:
            self.cleared = True

    class _StubCompressor:
        def __init__(self) -> None:
            pass

    class _StubCoordinator:
        def __init__(self, registry=None, compressor=None) -> None:
            self.registry = registry
            self.compressor = compressor

    class _StubMetrics:
        def __init__(self) -> None:
            pass

    class _StubVllm:
        instances: list = []

        def __init__(self) -> None:
            self.closed = False
            type(self).instances.append(self)

        async def aclose(self) -> None:
            self.closed = True

    monkeypatch.setattr(srv, "ContextRegistry", _StubRegistry)
    monkeypatch.setattr(srv, "ContextCompressor", _StubCompressor)
    monkeypatch.setattr(srv, "CompressionCoordinator", _StubCoordinator)
    monkeypatch.setattr(srv, "MetricsCollector", _StubMetrics)
    monkeypatch.setattr(srv, "VLLMClient", _StubVllm)

    with TestClient(app) as client:
        state = client.app.state
        assert isinstance(state.registry, _StubRegistry)
        assert isinstance(state.compressor, _StubCompressor)
        assert isinstance(state.coordinator, _StubCoordinator)
        assert isinstance(state.metrics, _StubMetrics)
        assert isinstance(state.vllm, _StubVllm)
        # Coordinator must be wired to the SAME registry+compressor instances.
        assert state.coordinator.registry is state.registry
        assert state.coordinator.compressor is state.compressor

    # On context exit the lifespan ran cleanup.
    assert _StubRegistry.instances and _StubRegistry.instances[-1].cleared is True
    assert _StubVllm.instances and _StubVllm.instances[-1].closed is True
def test_full_flow_register_then_optimize_passthrough() -> None:
    """End-to-end: register → optimize → snapshot, against a REAL
    ContextRegistry backed by the hermetic FakeDedupEngine (no model
    download) and a coordinator stub that always answers passthrough."""
    registry = ContextRegistry(dedup=FakeDedupEngine())
    fake_metrics = FakeMetrics()
    fake_compressor = FakeCompressor()
    short_ctx = "this is a short context"
    passthrough_decision = CompressionDecision(
        strategy="passthrough",
        final_context=short_ctx,
        shared_prefix="",
        original_tokens=5,
        final_tokens=5,
        tokens_saved=0,
        rationale="ctx_tokens <= threshold AND no long shared prefix",
    )
    coordinator = FakeCoordinator(decision=passthrough_decision)
    client = _client_with_overrides(
        {
            get_registry: lambda: registry,
            get_metrics: lambda: fake_metrics,
            get_compressor: lambda: fake_compressor,
            get_coordinator: lambda: coordinator,
        }
    )

    register_response = client.post(
        "/tools/register_context",
        json={"agent_id": "alice", "context": short_ctx},
    )
    assert register_response.status_code == 200
    registered = ContextEntry.model_validate_json(register_response.text)
    assert registered.agent_id == "alice"

    optimize_response = client.post(
        "/tools/get_optimized_context",
        json={"agent_id": "alice", "context": short_ctx},
    )
    assert optimize_response.status_code == 200
    decision = CompressionDecision.model_validate(optimize_response.json())
    assert decision.strategy == "passthrough"

    snapshot_response = client.get("/metrics/snapshot")
    assert snapshot_response.status_code == 200
    snapshot = MetricsSnapshot.model_validate(snapshot_response.json())
    # passthrough records (0,0) — tokens_processed stays 0; that's fine.
    assert snapshot.tokens_processed == 0
    assert fake_metrics.register_calls == [False]
    assert len(fake_metrics.decision_calls) == 1