SystemTruth / tests /test_ui_layer.py
Madhav189's picture
finalization: blog + README + execution rewrite, drop 3B + openclaw shim
0058c94
"""Tests for the BYOK provider routing + Gradio UI helpers.
We don't run the full Gradio server here (that's an integration test); these
tests exercise the building blocks app.py wires together:
- ``sre_gym/ui/router.py`` — model lists, find_entry, build_provider
- ``sre_gym/ui/providers.py`` — Provider auth checks + exception mapping
- ``sre_gym/ui/policies.py`` — JSON action extraction + fallback to escalate
"""
from __future__ import annotations
from typing import Any
import pytest
from sre_gym.exceptions import ActionParseError, ProviderAuthError, ProviderModelError
from sre_gym.tier import Tier
from sre_gym.ui.policies import _extract_json_object, make_policy
from sre_gym.ui.providers import (
AnthropicProvider,
HFInferenceProvider,
OpenAICompatibleProvider,
)
from sre_gym.ui.router import (
ADVANCED_MODELS,
BASIC_MODELS,
MAX_MODELS,
ProviderKind,
build_provider,
find_entry,
models_for_tier,
)
# ---------- Router / model catalogue ----------
def test_models_for_tier_returns_curated_list_for_each_tier() -> None:
for tier in Tier:
models = models_for_tier(tier)
assert len(models) >= 3, f"{tier.value} must have at least 3 curated models"
# Each entry must have a label, model_id, and provider kind.
for entry in models:
assert entry.label
assert entry.model_id
assert isinstance(entry.kind, ProviderKind)
def test_triage_tier_default_is_qwen25_7b() -> None:
"""The Triage tier's first entry should be the open-weight Qwen2.5-7B base."""
assert BASIC_MODELS[0].model_id == "Qwen/Qwen2.5-7B-Instruct"
def test_advanced_tier_default_is_long_horizon_model() -> None:
"""The Advanced tier's first entry must be a long-horizon-class model."""
assert "70B" in ADVANCED_MODELS[0].model_id or "Llama-3.3-70B" in ADVANCED_MODELS[0].label
def test_max_tier_default_is_claude_sonnet() -> None:
"""The Max tier's first entry must be Claude Sonnet (BYOK)."""
assert "Sonnet" in MAX_MODELS[0].label
assert MAX_MODELS[0].kind is ProviderKind.ANTHROPIC
def test_find_entry_resolves_label_or_model_id() -> None:
entry = find_entry(BASIC_MODELS[1].label, Tier.BASIC)
assert entry is BASIC_MODELS[1]
entry = find_entry(BASIC_MODELS[1].model_id, Tier.BASIC)
assert entry is BASIC_MODELS[1]
assert find_entry("nonexistent", Tier.BASIC) is None
# ---------- Providers — auth + exception mapping ----------
def test_hf_provider_rejects_empty_token() -> None:
with pytest.raises(ProviderAuthError):
HFInferenceProvider(hf_token="", model="Qwen/Qwen2.5-7B-Instruct")
def test_anthropic_provider_rejects_empty_key() -> None:
with pytest.raises(ProviderAuthError):
AnthropicProvider(api_key="", model="claude-sonnet-4-6")
def test_openai_compat_provider_rejects_empty_key() -> None:
with pytest.raises(ProviderAuthError):
OpenAICompatibleProvider(base_url="https://api.openai.com/v1", api_key="", model="gpt-5")
def test_openai_compat_provider_rejects_empty_base_url() -> None:
with pytest.raises(ProviderModelError):
OpenAICompatibleProvider(base_url="", api_key="sk-test", model="gpt-5")
# ---------- build_provider — credential dispatch ----------
def test_build_provider_for_basic_default_uses_hf_token() -> None:
"""Building the trained-3B provider with an HF token works."""
entry = BASIC_MODELS[0]
provider = build_provider(entry, hf_token="hf_test")
assert isinstance(provider, HFInferenceProvider)
assert provider.model == entry.model_id
def test_build_provider_anthropic_default() -> None:
sonnet = next(e for e in MAX_MODELS if "Sonnet" in e.label)
provider = build_provider(sonnet, anthropic_key="sk-ant-test")
assert isinstance(provider, AnthropicProvider)
def test_build_provider_openai_compat_default() -> None:
gpt5 = next(e for e in MAX_MODELS if "GPT-5" in e.label)
provider = build_provider(gpt5, openai_key="sk-test")
assert isinstance(provider, OpenAICompatibleProvider)
assert provider.model == "gpt-5"
def test_build_provider_missing_credential_raises_provider_auth_error() -> None:
"""No HF token → should raise ProviderAuthError, not silently proceed."""
entry = BASIC_MODELS[0]
with pytest.raises(ProviderAuthError):
build_provider(entry, hf_token="")
def test_build_provider_custom_model_override() -> None:
"""``custom_model_id`` overrides the entry's model_id."""
entry = BASIC_MODELS[0]
provider = build_provider(entry, hf_token="hf_test", custom_model_id="mistralai/Mistral-Small-Instruct")
assert provider.model == "mistralai/Mistral-Small-Instruct"
# ---------- Policy adapter — JSON extraction + fallbacks ----------
def test_extract_json_object_strips_markdown_fences() -> None:
text = '```json\n{"action_type": "query_logs", "service": "worker"}\n```'
obj = _extract_json_object(text)
assert obj == {"action_type": "query_logs", "service": "worker"}
def test_extract_json_object_handles_prose_around_json() -> None:
text = 'Sure thing! Here is the action:\n{"action_type": "escalate"}\nLet me know if you need more.'
obj = _extract_json_object(text)
assert obj == {"action_type": "escalate"}
def test_extract_json_object_normalizes_action_alias() -> None:
text = '{"action": "query_logs", "service": "worker"}'
obj = _extract_json_object(text)
assert obj == {"action_type": "query_logs", "service": "worker"}
def test_extract_json_object_raises_on_unterminated() -> None:
with pytest.raises(ActionParseError):
_extract_json_object('{"action_type": "query_logs"')
def test_extract_json_object_raises_on_no_json() -> None:
with pytest.raises(ActionParseError):
_extract_json_object('I am sorry but I cannot respond as JSON.')
class _FakeProvider:
"""Stand-in for tests — captures the messages and returns a configured response."""
name = "fake"
model = "fake-model"
def __init__(self, response: str) -> None:
self._response = response
self.last_messages: list[dict[str, str]] | None = None
def chat_sync(self, messages: list[dict[str, str]], **kwargs: Any) -> str:
self.last_messages = messages
return self._response
async def chat(self, messages: list[dict[str, str]], **kwargs: Any) -> str: # pragma: no cover
return self.chat_sync(messages, **kwargs)
def test_make_policy_returns_action_dict_for_well_formed_response() -> None:
fake = _FakeProvider('{"action_type":"query_logs","service":"worker"}')
policy = make_policy(fake, tier="basic")
class FakeObs:
prompt_text = "incident summary text"
action = policy(FakeObs())
assert action == {"action_type": "query_logs", "service": "worker"}
# System prompt + user observation must have been sent.
assert fake.last_messages[0]["role"] == "system"
assert fake.last_messages[1]["role"] == "user"
def test_make_policy_falls_back_to_escalate_on_garbage() -> None:
fake = _FakeProvider("not json at all, sorry")
policy = make_policy(fake, tier="basic")
class FakeObs:
prompt_text = "..."
action = policy(FakeObs())
assert action == {"action_type": "escalate"}
def test_make_policy_falls_back_to_escalate_on_provider_auth_error() -> None:
class FakeAuthBrokenProvider:
name = "anthropic"
model = "claude-sonnet-4-6"
def chat_sync(self, messages: list[dict[str, str]], **kwargs: Any) -> str:
raise ProviderAuthError("anthropic")
async def chat(self, messages: list[dict[str, str]], **kwargs: Any) -> str: # pragma: no cover
return self.chat_sync(messages, **kwargs)
policy = make_policy(FakeAuthBrokenProvider(), tier="basic")
class FakeObs:
prompt_text = "..."
action = policy(FakeObs())
assert action == {"action_type": "escalate"}
def test_make_policy_max_tier_uses_max_observation_renderer() -> None:
fake = _FakeProvider('{"action_type":"escalate"}')
policy = make_policy(fake, tier="max")
from dataclasses import dataclass
@dataclass
class FakeMaxObs:
family_id: str = "ecommerce_vibecoded_saas"
chaos: str = "deploy_regression"
tick_count: int = 1
max_ticks: int = 25
incident_summary: str = "test"
services: dict[str, dict[str, Any]] = None
cause_removed: bool = False
blast_radius: int = 0
last_log: str = "..."
obs = FakeMaxObs(services={"api-gateway": {"status": "healthy", "cpu_pct": 30.0, "error_rate_pct": 0.0, "latency_ms": 30.0}})
policy(obs)
user_text = fake.last_messages[1]["content"]
assert "FAMILY:" in user_text or "CHAOS:" in user_text