# compounding-test / test_diagnose.py
# apingali
# perf(hf-space): pre-load model at module init (Option 3 refactor)
# c673b37
"""Sanity tests for parse_response() — the deterministic JSON-extraction
contract between the Claude response and the Gradio UI. The LLM call
itself is exempt from unit tests (Principle VII); these tests cover
only the parser surface.
Run: pytest gradio-apps/compounding-test/test_diagnose.py -v
"""
from __future__ import annotations
import pytest
import json
import os
from unittest.mock import MagicMock
from app import (
ANTHROPIC_MODEL_ID,
HF_MODEL_ID,
MalformedResponseError,
PROVIDERS,
_call_anthropic,
_call_huggingface,
_call_model,
_call_zerogpu,
_detect_provider,
_zerogpu_available,
_zerogpu_invoke,
diagnose,
parse_response,
)
import app as app_module
# --- Fixtures ---------------------------------------------------------------
# Canonical well-formed model response: a fenced ```json block followed by
# the markdown writeup. The sad-path tests below derive their inputs by
# str.replace()-ing exact substrings of this text, so edits to the
# formatting here (spacing, key order) can silently change what they test.
VALID_JSON_BLOCK = """```json
{
"constraint": "Underwriting at the quote screen.",
"scores": {
"proprietary_data": { "score": 4, "rationale": "First-party policy data.", "quoted_span": "claim outcomes Progressive observes directly" },
"self_labeling": { "score": 4, "rationale": "Every policy term self-labels.", "quoted_span": "every policy term produces a claim" },
"decreasing_marginal_cost": { "score": 3, "rationale": "Pipeline amortized.", "quoted_span": "17 years of amortized pipeline" },
"defensible_asymmetry": { "score": 3, "rationale": "Integration depth capped.", "quoted_span": "behavior data integration depth" }
},
"quadrant": "compounder",
"closest_portrait": "progressive",
"closest_portrait_paragraph": "Your case tracks Progressive most closely because the labeling loop runs through claim outcomes you directly observe.",
"warnings": []
}
```
# The Verdict
Your initiative is a compounder. Here's why this lands cleanly on all
four conditions and what to do Monday morning to keep it compounding.
"""
# --- Happy path -------------------------------------------------------------
def test_valid_response_parses_top_level_fields():
    """Top-level scalar fields of the JSON block are surfaced verbatim."""
    r = parse_response(VALID_JSON_BLOCK)
    assert r.constraint == "Underwriting at the quote screen."
    assert r.quadrant == "compounder"
    assert r.closest_portrait == "progressive"
    # Pin the exact paragraph. A looser `"compounder" in ... or
    # "Progressive" in ...` check would also pass on a truncated or wrong
    # paragraph that merely mentioned either word.
    assert r.closest_portrait_paragraph == (
        "Your case tracks Progressive most closely because the labeling "
        "loop runs through claim outcomes you directly observe."
    )
    assert r.warnings == []
def test_valid_response_captures_writeup_after_json_block():
    """The markdown that follows the closing fence becomes the writeup."""
    writeup = parse_response(VALID_JSON_BLOCK).writeup
    for expected_fragment in ("# The Verdict", "Monday morning"):
        assert expected_fragment in writeup
def test_valid_response_extracts_all_four_scores():
    """Each of the four condition scores is parsed into a typed entry."""
    expected_conditions = {
        "proprietary_data",
        "self_labeling",
        "decreasing_marginal_cost",
        "defensible_asymmetry",
    }
    scores = parse_response(VALID_JSON_BLOCK).scores
    assert set(scores.keys()) == expected_conditions
    assert scores["proprietary_data"].score == 4
    assert scores["decreasing_marginal_cost"].score == 3
    assert (
        scores["defensible_asymmetry"].quoted_span
        == "behavior data integration depth"
    )
# --- Sad path: missing or malformed JSON ------------------------------------
def test_no_json_block_raises():
    """Pure prose with no fenced JSON block is rejected."""
    prose_only = "Hi there, no JSON here at all, just prose explaining the verdict."
    with pytest.raises(MalformedResponseError, match="json"):
        parse_response(prose_only)
def test_invalid_json_inside_block_raises():
    """A fenced block whose body is not valid JSON is rejected."""
    broken_block = "```json\n{ this is not valid json }\n```\n# Writeup"
    with pytest.raises(MalformedResponseError):
        parse_response(broken_block)
def test_missing_required_top_level_field_raises():
    """A block lacking closest_portrait, closest_portrait_paragraph and
    warnings is rejected, and the error mentions closest_portrait."""
    incomplete = (
        "```json\n"
        "{\n"
        '"constraint": "...",\n'
        '"scores": {},\n'
        '"quadrant": "compounder"\n'
        "}\n"
        "```"
    )
    with pytest.raises(MalformedResponseError, match="closest_portrait"):
        parse_response(incomplete)
# --- Sad path: enum validation ----------------------------------------------
def test_invalid_quadrant_raises():
    """A quadrant value outside the allowed set is rejected."""
    bad_response = VALID_JSON_BLOCK.replace(
        '"quadrant": "compounder"',
        '"quadrant": "bogus-quadrant"',
    )
    with pytest.raises(MalformedResponseError, match="quadrant"):
        parse_response(bad_response)
def test_invalid_closest_portrait_raises():
    """A closest_portrait value outside the allowed set is rejected."""
    valid_field = '"closest_portrait": "progressive"'
    unknown_field = '"closest_portrait": "wells-fargo"'
    bad_response = VALID_JSON_BLOCK.replace(valid_field, unknown_field)
    with pytest.raises(MalformedResponseError, match="closest_portrait"):
        parse_response(bad_response)
# --- Sad path: score-range validation --------------------------------------
def test_score_below_zero_raises():
    """A negative score is out of range."""
    original = '"score": 4, "rationale": "First-party policy data."'
    mutated = '"score": -1, "rationale": "First-party policy data."'
    bad_response = VALID_JSON_BLOCK.replace(original, mutated)
    with pytest.raises(MalformedResponseError, match="score"):
        parse_response(bad_response)
def test_score_above_four_raises():
    """A score above 4 is out of range."""
    original = '"score": 4, "rationale": "First-party policy data."'
    mutated = '"score": 7, "rationale": "First-party policy data."'
    bad_response = VALID_JSON_BLOCK.replace(original, mutated)
    with pytest.raises(MalformedResponseError, match="score"):
        parse_response(bad_response)
def test_score_not_integer_raises():
    """A non-integer (string) score is rejected."""
    original = '"score": 4, "rationale": "First-party policy data."'
    mutated = '"score": "high", "rationale": "First-party policy data."'
    bad_response = VALID_JSON_BLOCK.replace(original, mutated)
    with pytest.raises(MalformedResponseError, match="score"):
        parse_response(bad_response)
# --- Sad path: quoted_span validation --------------------------------------
def test_empty_quoted_span_raises():
    """An empty quoted_span is rejected."""
    original_span = '"quoted_span": "claim outcomes Progressive observes directly"'
    bad_response = VALID_JSON_BLOCK.replace(original_span, '"quoted_span": ""')
    with pytest.raises(MalformedResponseError, match="quoted_span"):
        parse_response(bad_response)
def test_quoted_span_over_400_chars_raises():
    """A 401-char quoted_span trips the runaway-output guard.

    The ceiling was raised from 200 to 400 because Phi-4-mini routinely
    emits ~200-220 char quoted_spans even when asked for 5-15 words; 400
    still catches genuinely runaway output.
    """
    too_long = "x" * 401
    original_span = '"quoted_span": "claim outcomes Progressive observes directly"'
    bad_response = VALID_JSON_BLOCK.replace(
        original_span, '"quoted_span": "' + too_long + '"'
    )
    with pytest.raises(MalformedResponseError, match="quoted_span"):
        parse_response(bad_response)
def test_quoted_span_up_to_400_chars_accepted():
    """Confirms the 400-char ceiling lets typical Phi-4-mini output through.

    Checks a typical long span (250 chars, above the prior 200-char cap)
    and the boundary itself (exactly 400 chars). Only 401+ is rejected,
    per test_quoted_span_over_400_chars_raises, so "up to 400" must
    include 400.
    """
    original_span = '"quoted_span": "claim outcomes Progressive observes directly"'
    for length in (250, 400):
        long_span = "x" * length
        raw = VALID_JSON_BLOCK.replace(original_span, f'"quoted_span": "{long_span}"')
        r = parse_response(raw)
        assert len(r.scores["proprietary_data"].quoted_span) == length
# --- Tolerance: forward-compat and whitespace ------------------------------
def test_extra_unknown_fields_tolerated():
    """Unknown top-level keys (forward-compat) are ignored, not rejected."""
    raw = VALID_JSON_BLOCK.replace(
        '"warnings": []',
        '"warnings": [], "future_field": "ignored", "another": 42',
    )
    # Guard against a vacuous pass. If the fixture's formatting ever changes
    # so the replace() matches nothing, this test would silently re-parse the
    # unmodified fixture and still succeed.
    assert raw != VALID_JSON_BLOCK
    r = parse_response(raw)  # should not raise
    assert r.quadrant == "compounder"
# --- Provider auto-detection (multi-backend support) ----------------------
def test_detect_provider_explicit_anthropic_wins():
    """An explicit MODEL_PROVIDER beats an HF token in the environment."""
    environment = dict(MODEL_PROVIDER="anthropic", HF_TOKEN="hf-xxx")
    chosen = _detect_provider(environment)
    assert chosen == "anthropic"
def test_detect_provider_explicit_huggingface_wins():
    """An explicit MODEL_PROVIDER beats an Anthropic key in the environment."""
    environment = dict(MODEL_PROVIDER="huggingface", ANTHROPIC_API_KEY="sk-xxx")
    chosen = _detect_provider(environment)
    assert chosen == "huggingface"
def test_detect_provider_case_insensitive():
    """MODEL_PROVIDER is matched regardless of letter case."""
    mixed_case = {"MODEL_PROVIDER": "HuggingFace"}
    assert _detect_provider(mixed_case) == "huggingface"
def test_detect_provider_invalid_explicit_falls_through():
    """An unrecognized MODEL_PROVIDER is ignored and auto-detection applies."""
    environment = dict(MODEL_PROVIDER="bogus", ANTHROPIC_API_KEY="sk-xxx")
    assert _detect_provider(environment) == "anthropic"
def test_detect_provider_anthropic_when_only_anthropic_key_set():
    """Only an Anthropic key present -> anthropic."""
    environment = dict(ANTHROPIC_API_KEY="sk-xxx")
    assert _detect_provider(environment) == "anthropic"
def test_detect_provider_huggingface_when_only_hf_token_set():
    """Only an HF token present -> huggingface."""
    environment = dict(HF_TOKEN="hf-xxx")
    assert _detect_provider(environment) == "huggingface"
def test_detect_provider_huggingface_when_running_on_hf_space_without_zerogpu(monkeypatch):
    """On a Space lacking the ZeroGPU deps, fall back to the inference API."""
    monkeypatch.setattr(app_module, "_zerogpu_available", lambda: False)
    space_env = dict(SPACE_ID="mile-hi-ai/compounding-test")
    assert _detect_provider(space_env) == "huggingface"
def test_detect_provider_prefers_zerogpu_on_pro_space_with_deps(monkeypatch):
    """On a Space with transformers + torch + spaces installed, default to the
    free GPU backend instead of spending inference credits."""
    monkeypatch.setattr(app_module, "_zerogpu_available", lambda: True)
    space_env = dict(SPACE_ID="mile-hi-ai/compounding-test")
    assert _detect_provider(space_env) == "zerogpu"
def test_detect_provider_explicit_anthropic_wins_over_zerogpu(monkeypatch):
    """An explicit MODEL_PROVIDER overrides ZeroGPU auto-detection on a Pro Space."""
    monkeypatch.setattr(app_module, "_zerogpu_available", lambda: True)
    environment = dict(
        MODEL_PROVIDER="anthropic",
        SPACE_ID="mile-hi-ai/compounding-test",
    )
    chosen = _detect_provider(environment)
    assert chosen == "anthropic"
def test_detect_provider_explicit_zerogpu_wins():
    """MODEL_PROVIDER=zerogpu is honored as-is."""
    environment = dict(MODEL_PROVIDER="zerogpu")
    assert _detect_provider(environment) == "zerogpu"
def test_zerogpu_is_in_providers_dict():
    """The zerogpu key is always registered, even without its deps installed.

    This lets the UI dropdown offer it; the stub raises a clear error if it
    is actually invoked without the deps.
    """
    registered_providers = set(PROVIDERS)
    assert "zerogpu" in registered_providers
def test_detect_provider_alt_hf_token_var():
    """The legacy HUGGING_FACE_HUB_TOKEN variable also selects huggingface."""
    environment = dict(HUGGING_FACE_HUB_TOKEN="hf-xxx")
    assert _detect_provider(environment) == "huggingface"
def test_detect_provider_default_when_nothing_set():
    """With no credentials at all, default to anthropic.

    That backend gives the clearest error at call time.
    """
    empty_env: dict = {}
    assert _detect_provider(empty_env) == "anthropic"
# --- Provider dispatch (_call_model routes to the right backend) -----------
def test_call_model_routes_to_anthropic_backend(monkeypatch):
    """_call_model forwards (system, user) to PROVIDERS["anthropic"]."""
    calls = []

    def fake_backend(s, u):
        calls.append(("anthropic", s, u))
        return "anth-out"

    monkeypatch.setitem(PROVIDERS, "anthropic", fake_backend)
    result = _call_model("system-text", "user-text", "anthropic")
    assert result == "anth-out"
    assert calls == [("anthropic", "system-text", "user-text")]
def test_call_model_routes_to_huggingface_backend(monkeypatch):
    """_call_model forwards (system, user) to PROVIDERS["huggingface"]."""
    calls = []

    def fake_backend(s, u):
        calls.append(("hf", s, u))
        return "hf-out"

    monkeypatch.setitem(PROVIDERS, "huggingface", fake_backend)
    result = _call_model("system-text", "user-text", "huggingface")
    assert result == "hf-out"
    assert calls == [("hf", "system-text", "user-text")]
def test_call_model_unknown_provider_raises():
    """An unregistered provider name raises ValueError mentioning 'provider'."""
    unregistered = "bogus-provider"
    with pytest.raises(ValueError, match="provider"):
        _call_model("s", "u", unregistered)
# --- diagnose() input validation -------------------------------------------
# Reused across diagnose() tests: a description long enough to pass the
# 200-word minimum. The actual content doesn't matter for these tests
# because we mock the backend.
_LONG_DESCRIPTION = " ".join(["word"] * 250)
def test_diagnose_empty_description_returns_friendly_error():
    """An empty description short-circuits with a prompt to describe the case."""
    outcome = diagnose("", None, None, None, provider="zerogpu")
    writeup, json_str = outcome
    assert "Please describe" in writeup
    assert json_str == ""
def test_diagnose_short_description_returns_word_count_error():
    """Below the 200-word minimum, the message states the rule and the count."""
    word_count = 50
    too_short = " ".join(["word"] * word_count)
    writeup, json_str = diagnose(too_short, None, None, None, provider="zerogpu")
    assert "at least 200 words" in writeup
    assert str(word_count) in writeup  # echoes the current word count
    assert json_str == ""
def test_diagnose_long_description_returns_word_count_error():
    """Above the 5000-word maximum, a friendly limit message is returned."""
    too_long = " ".join(["word"] * 5001)
    writeup, json_str = diagnose(too_long, None, None, None, provider="zerogpu")
    assert "under 5000 words" in writeup
    assert json_str == ""
def test_diagnose_unknown_provider_returns_friendly_error():
    """An unknown provider yields a readable message naming that provider."""
    writeup, json_str = diagnose(
        _LONG_DESCRIPTION, None, None, None, provider="bogus"
    )
    for fragment in ("Unknown model provider", "bogus"):
        assert fragment in writeup
    assert json_str == ""
# --- diagnose() Premium (Anthropic) path -----------------------------------
def test_diagnose_premium_without_any_key_returns_friendly_error(monkeypatch):
    """Premium with no form key and no env key explains that a key is needed."""
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
    outcome = diagnose(
        _LONG_DESCRIPTION,
        None,
        None,
        None,
        provider="anthropic",
        anthropic_api_key=None,
    )
    writeup, json_str = outcome
    for fragment in ("Premium", "API key"):
        assert fragment in writeup
    assert json_str == ""
def test_diagnose_premium_with_empty_string_key_returns_friendly_error(monkeypatch):
    """Neither an empty string nor a whitespace-only key counts as supplied.

    With no env key either, both must produce the friendly "API key needed"
    message rather than reaching the backend.
    """
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
    # "" is what the test name promises; " " is the whitespace-only case.
    for blank_key in ("", " "):
        writeup, _ = diagnose(
            _LONG_DESCRIPTION, None, None, None,
            provider="anthropic",
            anthropic_api_key=blank_key,
        )
        assert "Premium" in writeup, f"key={blank_key!r}"
        assert "API key" in writeup, f"key={blank_key!r}"
def test_diagnose_premium_with_env_key_dispatches_to_anthropic(monkeypatch):
    """With no form key, Premium falls back to the Space's env key.

    The parsed result must then reach the JSON tab.
    """
    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-env-xxx")
    seen = {}

    def fake_anthropic(system, user):
        seen.update(
            system=system,
            user=user,
            env_key_at_call_time=os.environ.get("ANTHROPIC_API_KEY"),
        )
        return _VALID_BACKEND_RESPONSE

    monkeypatch.setitem(PROVIDERS, "anthropic", fake_anthropic)
    _writeup, json_str = diagnose(
        _LONG_DESCRIPTION,
        "insurance",
        "enterprise",
        "$1M–$10M",
        provider="anthropic",
        anthropic_api_key=None,
    )
    # The backend ran (dispatch worked) and saw the env key.
    assert seen.get("env_key_at_call_time") == "sk-env-xxx"
    # The parser output reached the JSON tab.
    assert json_str
    assert json.loads(json_str)["quadrant"] == "compounder"
def test_diagnose_premium_user_key_passed_directly_not_via_env(monkeypatch):
    """The visitor's key goes to _call_anthropic as ``api_key=``.

    It overrides any owner-configured ANTHROPIC_API_KEY and is never written
    to os.environ. Shared process env would let concurrent visitors pick up
    each other's keys (see the _call_anthropic docstring).
    """
    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-space-owner-xxx")
    seen = {}

    def fake_call_anthropic(system, user, *, api_key=None):
        seen["api_key_kwarg"] = api_key
        seen["env_at_call_time"] = os.environ.get("ANTHROPIC_API_KEY")
        return _VALID_BACKEND_RESPONSE

    monkeypatch.setattr(app_module, "_call_anthropic", fake_call_anthropic)
    diagnose(
        _LONG_DESCRIPTION,
        None,
        None,
        None,
        provider="anthropic",
        anthropic_api_key="sk-user-yyy",
    )
    # Override travels via the kwarg ...
    assert seen["api_key_kwarg"] == "sk-user-yyy"
    # ... and the owner's env key is left untouched for any other request.
    assert seen["env_at_call_time"] == "sk-space-owner-xxx"
def test_diagnose_premium_does_not_mutate_env_with_user_key(monkeypatch):
    """Cross-tenant key-leak regression test.

    Concurrent Premium requests on a public Space must each use only their
    own key. Neither may observe the other's key via os.environ, so
    user-supplied keys must never be written there.
    """
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
    seen = {}

    def fake_call_anthropic(system, user, *, api_key=None):
        seen["api_key_kwarg"] = api_key
        seen["env_at_call_time"] = os.environ.get("ANTHROPIC_API_KEY")
        return _VALID_BACKEND_RESPONSE

    monkeypatch.setattr(app_module, "_call_anthropic", fake_call_anthropic)
    diagnose(
        _LONG_DESCRIPTION,
        None,
        None,
        None,
        provider="anthropic",
        anthropic_api_key="sk-visitor-A-secret",
    )
    # Delivered straight to the SDK wrapper ...
    assert seen["api_key_kwarg"] == "sk-visitor-A-secret"
    # ... env stayed empty during the call ...
    assert seen["env_at_call_time"] is None
    # ... and no residue remains afterwards for the next visitor.
    assert os.environ.get("ANTHROPIC_API_KEY") is None
def test_diagnose_redacts_user_key_from_error_messages(monkeypatch):
    """Defense-in-depth: a backend error echoing the visitor's key must have
    that key redacted by the F14 wrapper before it reaches the UI.

    This mirrors redactKey() in src/lib/anthropic-direct.ts.
    """
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
    secret = "sk-ant-very-secret-12345"

    class _LeakyError(Exception):
        pass

    def leaky_anthropic(system, user, *, api_key=None):
        # Worst case: the SDK includes the raw key in its error text.
        raise _LeakyError(f"auth fail with key {api_key} rejected")

    monkeypatch.setattr(app_module, "_call_anthropic", leaky_anthropic)
    writeup = diagnose(
        _LONG_DESCRIPTION,
        None,
        None,
        None,
        provider="anthropic",
        anthropic_api_key=secret,
    )[0]
    assert secret not in writeup
    assert "[redacted]" in writeup
    # The rest of the diagnostic (exception class) is still shown.
    assert "LeakyError" in writeup
def test_call_anthropic_passes_api_key_to_sdk_constructor(monkeypatch):
    """An explicit ``api_key=`` is handed to the Anthropic() constructor.

    It must not be stashed in os.environ or dropped.
    """
    import anthropic as anthropic_module

    ctor_kwargs = {}

    class _Block:
        text = "ok"

    class _Message:
        content = [_Block()]

    class _Messages:
        @staticmethod
        def create(**kwargs):
            return _Message()

    class _Client:
        messages = _Messages()

    def fake_anthropic_ctor(**kwargs):
        ctor_kwargs.update(kwargs)
        return _Client()

    monkeypatch.setattr(anthropic_module, "Anthropic", fake_anthropic_ctor)
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)

    _call_anthropic("sys", "usr", api_key="sk-direct-yyy")

    assert ctor_kwargs.get("api_key") == "sk-direct-yyy"
    assert os.environ.get("ANTHROPIC_API_KEY") is None  # env untouched
def test_call_anthropic_without_api_key_uses_env_via_sdk(monkeypatch):
    """Without ``api_key=``, Anthropic() is built with no api_key kwarg at all.

    The SDK then reads ANTHROPIC_API_KEY itself. Passing api_key=None
    explicitly is avoided because the SDK treats it differently from
    "not supplied".
    """
    import anthropic as anthropic_module

    ctor_kwargs = {}

    class _Block:
        text = "ok"

    class _Message:
        content = [_Block()]

    class _Messages:
        @staticmethod
        def create(**kwargs):
            return _Message()

    class _Client:
        messages = _Messages()

    def fake_anthropic_ctor(**kwargs):
        ctor_kwargs.update(kwargs)
        return _Client()

    monkeypatch.setattr(anthropic_module, "Anthropic", fake_anthropic_ctor)
    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-env-default")

    _call_anthropic("sys", "usr")  # deliberately no api_key kwarg

    assert "api_key" not in ctor_kwargs
def test_diagnose_premium_backend_exception_returns_friendly_error(monkeypatch):
    """A backend exception (auth, rate limit, network) is wrapped by F14.

    The markdown message names provider, model, exception class and detail,
    and never includes a raw stack trace.
    """
    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-test")

    class FakeAuthError(Exception):
        pass

    def failing_anthropic(system, user):
        raise FakeAuthError("invalid x-api-key header")

    monkeypatch.setitem(PROVIDERS, "anthropic", failing_anthropic)
    writeup, json_str = diagnose(
        _LONG_DESCRIPTION,
        None,
        None,
        None,
        provider="anthropic",
        anthropic_api_key=None,
    )
    expected_fragments = (
        "anthropic",
        ANTHROPIC_MODEL_ID,
        "FakeAuthError",
        "invalid x-api-key header",
    )
    for fragment in expected_fragments:
        assert fragment in writeup
    assert "stack" not in writeup.lower()  # no stack trace leaked
    assert json_str == ""
def test_diagnose_premium_backend_returns_malformed_response(monkeypatch):
    """Output that fails schema validation surfaces the parser's error.

    diagnose() must not crash in that case.
    """
    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-test")

    def garbage_anthropic(system, user):
        # A refusal with no JSON block at all.
        return "Sorry, I cannot help with that request."

    monkeypatch.setitem(PROVIDERS, "anthropic", garbage_anthropic)
    outcome = diagnose(
        _LONG_DESCRIPTION,
        None,
        None,
        None,
        provider="anthropic",
        anthropic_api_key=None,
    )
    writeup, json_str = outcome
    assert "malformed output" in writeup
    assert json_str == ""
# --- _call_anthropic: Anthropic SDK call shape -----------------------------
#
# Per Principle VII the actual API call is exempt from automated tests
# (the SDK and the remote API are not our code). But the SHAPE of the
# call we make IS our code: model id, system block, cache_control flag,
# messages structure, response-unwrap path. These are easy to typo and
# easy to miss in review. Shape tests catch that without hitting the
# network.
def test_call_anthropic_passes_system_block_with_cache_control(monkeypatch):
    """Pin the messages.create() call shape and the response-unwrap path.

    Covers model id, token budget, cached system block and user turn.
    """
    import anthropic as anthropic_module

    create_kwargs = {}

    class _Block:
        text = "raw response text"

    class _Message:
        content = [_Block()]

    class _Messages:  # mirrors the SDK's nested client.messages.create
        @staticmethod
        def create(**kwargs):
            create_kwargs.update(kwargs)
            return _Message()

    class _Client:
        messages = _Messages()

    def _make_client():
        return _Client()

    monkeypatch.setattr(anthropic_module, "Anthropic", _make_client)

    result = _call_anthropic("MY SYSTEM BLOCK", "MY USER PROMPT")

    # content[0].text was unwrapped.
    assert result == "raw response text"
    # Model and token budget.
    assert create_kwargs["model"] == ANTHROPIC_MODEL_ID
    assert create_kwargs["max_tokens"] == 2500
    # Exactly one text system block, marked for ephemeral prompt caching.
    system_blocks = create_kwargs["system"]
    assert isinstance(system_blocks, list)
    assert len(system_blocks) == 1
    only_block = system_blocks[0]
    assert only_block["type"] == "text"
    assert only_block["text"] == "MY SYSTEM BLOCK"
    assert only_block["cache_control"] == {"type": "ephemeral"}
    # User prompt is the sole message.
    assert create_kwargs["messages"] == [{"role": "user", "content": "MY USER PROMPT"}]
# --- _call_huggingface: token resolution + call shape ----------------------
def _install_fake_inference_client(monkeypatch, captured: dict, *,
                                   response_text: str = "hf response",
                                   raises: Exception | None = None):
    """Swap huggingface_hub.InferenceClient for a recording fake.

    The fake stores its constructor kwargs under ``captured["init_kwargs"]``
    and its chat_completion kwargs under ``captured["chat_kwargs"]``. When
    ``raises`` is given, chat_completion raises it. Otherwise it returns a
    response whose choices[0].message.content is ``response_text``.
    """
    import huggingface_hub

    class _FakeMsg:
        content = response_text

    class _FakeChoice:
        message = _FakeMsg()

    class _FakeResponse:
        choices = [_FakeChoice()]

    class _FakeClient:
        def __init__(self, **kwargs):
            captured["init_kwargs"] = kwargs

        def chat_completion(self, **kwargs):
            captured["chat_kwargs"] = kwargs
            if raises is None:
                return _FakeResponse()
            raise raises

    monkeypatch.setattr(huggingface_hub, "InferenceClient", _FakeClient)
def test_call_huggingface_no_token_anywhere_raises_actionable_error(monkeypatch):
    """With no env token and no CLI login, a clear RuntimeError is raised."""
    import huggingface_hub

    for var in ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"):
        monkeypatch.delenv(var, raising=False)
    monkeypatch.setattr(huggingface_hub, "get_token", lambda: None)
    with pytest.raises(RuntimeError, match="No HuggingFace token"):
        _call_huggingface("sys", "usr")
def test_call_huggingface_uses_HF_TOKEN_env(monkeypatch):
    """HF_TOKEN from the environment is passed to InferenceClient."""
    monkeypatch.setenv("HF_TOKEN", "hf_from_env")
    recorded = {}
    _install_fake_inference_client(monkeypatch, recorded)
    _call_huggingface("sys", "usr")
    token_used = recorded["init_kwargs"]["token"]
    assert token_used == "hf_from_env"
def test_call_huggingface_uses_HUGGING_FACE_HUB_TOKEN_env_as_fallback(monkeypatch):
    """Without HF_TOKEN, the legacy HUGGING_FACE_HUB_TOKEN variable is used."""
    monkeypatch.delenv("HF_TOKEN", raising=False)
    monkeypatch.setenv("HUGGING_FACE_HUB_TOKEN", "hf_legacy_var")
    recorded = {}
    _install_fake_inference_client(monkeypatch, recorded)
    _call_huggingface("sys", "usr")
    token_used = recorded["init_kwargs"]["token"]
    assert token_used == "hf_legacy_var"
def test_call_huggingface_uses_get_token_when_no_env(monkeypatch):
    """With no env token, huggingface_hub.get_token() (CLI login) is used."""
    import huggingface_hub

    for var in ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"):
        monkeypatch.delenv(var, raising=False)
    monkeypatch.setattr(huggingface_hub, "get_token", lambda: "hf_from_cli_login")
    recorded = {}
    _install_fake_inference_client(monkeypatch, recorded)
    _call_huggingface("sys", "usr")
    assert recorded["init_kwargs"]["token"] == "hf_from_cli_login"
def test_call_huggingface_HF_TOKEN_wins_over_other_sources(monkeypatch):
    """HF_TOKEN outranks both HUGGING_FACE_HUB_TOKEN and get_token()."""
    import huggingface_hub

    monkeypatch.setenv("HF_TOKEN", "hf_winner")
    monkeypatch.setenv("HUGGING_FACE_HUB_TOKEN", "hf_loser_1")
    monkeypatch.setattr(huggingface_hub, "get_token", lambda: "hf_loser_2")
    recorded = {}
    _install_fake_inference_client(monkeypatch, recorded)
    _call_huggingface("sys", "usr")
    assert recorded["init_kwargs"]["token"] == "hf_winner"
def test_call_huggingface_init_shape_model_provider_timeout(monkeypatch):
    """InferenceClient is built with the pinned model, provider and timeout."""
    monkeypatch.setenv("HF_TOKEN", "hf_test")
    recorded = {}
    _install_fake_inference_client(monkeypatch, recorded)
    _call_huggingface("sys", "usr")
    client_kwargs = recorded["init_kwargs"]
    assert client_kwargs["model"] == HF_MODEL_ID
    # provider="auto" opts into HF Inference Providers routing. Without it
    # the client falls back to the legacy hf-inference-only path, so guard
    # against this flag being dropped.
    assert client_kwargs["provider"] == "auto"
    assert client_kwargs["timeout"] == 120
def test_call_huggingface_chat_completion_call_shape(monkeypatch):
    """chat_completion gets system+user messages, token cap and temperature.

    The reply is unwrapped from choices[0].message.content.
    """
    monkeypatch.setenv("HF_TOKEN", "hf_test")
    recorded = {}
    _install_fake_inference_client(monkeypatch, recorded)
    reply = _call_huggingface("MY SYSTEM BLOCK", "MY USER PROMPT")
    request = recorded["chat_kwargs"]
    expected_messages = [
        {"role": "system", "content": "MY SYSTEM BLOCK"},
        {"role": "user", "content": "MY USER PROMPT"},
    ]
    assert request["messages"] == expected_messages
    assert request["max_tokens"] == 2500
    # Deliberately low: smaller open models emit looser JSON when hotter.
    assert request["temperature"] == 0.2
    assert reply == "hf response"
def test_call_huggingface_model_not_supported_error_wrapped(monkeypatch):
    """HF's model_not_supported error becomes an actionable RuntimeError."""
    monkeypatch.setenv("HF_TOKEN", "hf_test")
    upstream_message = (
        "Bad request: {'message': \"The requested model is not supported "
        "by any provider you have enabled.\", 'code': 'model_not_supported'}"
    )
    recorded = {}
    _install_fake_inference_client(
        monkeypatch, recorded, raises=Exception(upstream_message)
    )
    with pytest.raises(RuntimeError, match="isn't available through any"):
        _call_huggingface("sys", "usr")
def test_call_huggingface_model_not_supported_alternate_phrasing_wrapped(monkeypatch):
    """Only the model_not_supported code in the message triggers the wrapping."""
    monkeypatch.setenv("HF_TOKEN", "hf_test")
    upstream_error = Exception("...'code': 'model_not_supported'...")
    recorded = {}
    _install_fake_inference_client(monkeypatch, recorded, raises=upstream_error)
    with pytest.raises(RuntimeError, match="isn't available through any"):
        _call_huggingface("sys", "usr")
def test_call_huggingface_other_exception_passes_through(monkeypatch):
    """Other failures (auth, timeout, bad response) propagate unchanged.

    The F14 wrapper in diagnose() can then report the original class and
    detail.
    """
    monkeypatch.setenv("HF_TOKEN", "hf_test")
    recorded = {}
    _install_fake_inference_client(
        monkeypatch, recorded, raises=ValueError("Invalid API key")
    )
    with pytest.raises(ValueError, match="Invalid API key"):
        _call_huggingface("sys", "usr")
# --- _call_zerogpu: stub path + invocation shape --------------------------
def test_call_zerogpu_stub_raises_clear_error_when_deps_unavailable():
    """Without spaces/torch/transformers, _call_zerogpu is the stub.

    The stub raises a RuntimeError that points users to the other two
    backends.
    """
    if not _zerogpu_available():
        with pytest.raises(RuntimeError, match="ZeroGPU backend requires"):
            _call_zerogpu("sys", "usr")
        return
    pytest.skip("Test only meaningful when zerogpu deps are NOT installed")
def test_zerogpu_available_reflects_dep_state():
    """_zerogpu_available() is the sole gate for the zerogpu branch of
    _detect_provider().

    It must return the cached import-time boolean
    (_ZEROGPU_DEPS_AVAILABLE) rather than retrying the imports on every call.
    """
    # Uses the module-level `app_module` alias from this file's imports;
    # the local re-import it replaced was redundant and shadowed that name.
    assert _zerogpu_available() is app_module._ZEROGPU_DEPS_AVAILABLE
def _install_fake_zerogpu_model(monkeypatch, captured: dict, *,
                                prompt_len: int = 5,
                                decoded_text: str = "model output"):
    """Replace the module-level _zerogpu_tokenizer and _zerogpu_model with
    recording fakes.

    The fakes mimic transformers objects just enough for _zerogpu_invoke()
    to run end-to-end without torch installed.

    Keys written into ``captured``: "apply_chat_template", "decode",
    "generate_inputs", "generate_kwargs", "model_moved_to_device" and
    "inputs_moved_to_device".

    Args:
        prompt_len: length of the fake tokenized prompt (``inputs.shape[1]``).
            generate() returns ``prompt_len + 10`` token ids.
        decoded_text: value returned by the fake tokenizer's decode().

    There is no _load_zerogpu_model to patch. After the pre-load refactor,
    model load happens at module init rather than lazily, so only these two
    module attributes are swapped.
    """
    class _FakeInputs:
        def __init__(self):
            self.shape = (1, prompt_len)

        def to(self, device):
            captured["inputs_moved_to_device"] = device
            return self  # allow chaining like a real tensor's .to()

    fake_inputs = _FakeInputs()
    fake_outputs = [list(range(prompt_len + 10))]  # prompt tokens + 10 new tokens

    class _FakeTokenizer:
        eos_token_id = 99

        def apply_chat_template(self, messages, **kwargs):
            captured["apply_chat_template"] = {
                "messages": messages,
                "kwargs": kwargs,
            }
            return fake_inputs

        def decode(self, token_ids, **kwargs):
            captured["decode"] = {"token_ids": list(token_ids), "kwargs": kwargs}
            return decoded_text

    class _FakeModel:
        device = "cpu"  # starts on CPU; _zerogpu_invoke is expected to move it to cuda

        def to(self, device):
            captured["model_moved_to_device"] = device
            self.device = device
            return self

        def generate(self, inputs, **kwargs):
            captured["generate_inputs"] = inputs
            captured["generate_kwargs"] = kwargs
            return fake_outputs

    # `app_module` is the module-level alias imported at the top of this file.
    monkeypatch.setattr(app_module, "_zerogpu_tokenizer", _FakeTokenizer())
    monkeypatch.setattr(app_module, "_zerogpu_model", _FakeModel())
def test_zerogpu_invoke_builds_chat_template_with_system_and_user(monkeypatch):
    """The chat template gets system+user turns and asks for PyTorch tensors.

    It is also asked to append the generation prompt.
    """
    recorded = {}
    _install_fake_zerogpu_model(monkeypatch, recorded)
    _zerogpu_invoke("MY SYSTEM BLOCK", "MY USER PROMPT")
    template_call = recorded["apply_chat_template"]
    expected_messages = [
        {"role": "system", "content": "MY SYSTEM BLOCK"},
        {"role": "user", "content": "MY USER PROMPT"},
    ]
    assert template_call["messages"] == expected_messages
    options = template_call["kwargs"]
    assert options["return_tensors"] == "pt"
    assert options["add_generation_prompt"] is True
def test_zerogpu_invoke_moves_model_and_inputs_to_cuda(monkeypatch):
    """Under the pre-load pattern the model sits on CPU after module init.

    _zerogpu_invoke must therefore move both the model and the tokenized
    inputs to cuda inside the @spaces.GPU context.
    """
    recorded = {}
    _install_fake_zerogpu_model(monkeypatch, recorded)
    _zerogpu_invoke("sys", "usr")
    for key in ("model_moved_to_device", "inputs_moved_to_device"):
        assert recorded[key] == "cuda"
def test_zerogpu_invoke_generate_call_shape(monkeypatch):
    """Pin the .generate() kwargs, which are easy to typo and all matter.

    * max_new_tokens=2500 caps output length.
    * temperature=0.2 keeps small-model JSON stable.
    * do_sample=True makes a non-zero temperature take effect.
    * pad_token_id=eos_token_id avoids warning spam on short prompts.
    """
    recorded = {}
    _install_fake_zerogpu_model(monkeypatch, recorded)
    _zerogpu_invoke("sys", "usr")
    generate_kwargs = recorded["generate_kwargs"]
    assert generate_kwargs["max_new_tokens"] == 2500
    assert generate_kwargs["temperature"] == 0.2
    assert generate_kwargs["do_sample"] is True
    fake_eos_token_id = 99  # _FakeTokenizer.eos_token_id
    assert generate_kwargs["pad_token_id"] == fake_eos_token_id
def test_zerogpu_invoke_strips_prompt_tokens_before_decode(monkeypatch):
    """Only newly generated tokens are decoded; the prompt is not echoed.

    The invoke slices outputs[0][prompt_len:] before decode().
    """
    recorded = {}
    prompt_len = 5
    # The fake generate() returns range(prompt_len + 10): 5 prompt + 10 new.
    _install_fake_zerogpu_model(monkeypatch, recorded, prompt_len=prompt_len)
    _zerogpu_invoke("sys", "usr")
    decode_call = recorded["decode"]
    assert decode_call["token_ids"] == list(range(prompt_len, prompt_len + 10))
    # Special tokens such as </s> are dropped from the text.
    assert decode_call["kwargs"]["skip_special_tokens"] is True
def test_zerogpu_invoke_returns_decoded_text(monkeypatch):
    """Whatever the tokenizer decodes is returned verbatim."""
    expected = "my generated answer"
    _install_fake_zerogpu_model(monkeypatch, {}, decoded_text=expected)
    assert _zerogpu_invoke("sys", "usr") == expected
# --- Integration test (opt-in; hits the real Anthropic API) ----------------
#
# Skipped unless ANTHROPIC_API_KEY is set AND ANTHROPIC_INTEGRATION=1 is
# set. Costs money to run (~$0.05 per call to Opus 4.7). Use this when
# you want to verify end-to-end that the key works and the model is
# reachable; routine CI should leave this skipped.
import pytest as _pytest # already imported above as pytest, but kept explicit
@pytest.mark.skipif(
    not (
        os.environ.get("ANTHROPIC_API_KEY")
        and os.environ.get("ANTHROPIC_INTEGRATION") == "1"
    ),
    reason="needs ANTHROPIC_API_KEY + ANTHROPIC_INTEGRATION=1 to hit the real API",
)
def test_call_anthropic_real_api_returns_text():
    """Opt-in live call: the real API returns a short, non-empty string."""
    reply = _call_anthropic(
        "You are a one-word echo. Reply with exactly one word.",
        "Say hello.",
    )
    assert isinstance(reply, str)
    assert reply  # non-empty
    assert len(reply.split()) < 20  # one-word reply, generously bounded
# ---------------------------------------------------------------------------
# Fixture: a backend response that satisfies parse_response(), so the
# diagnose() Premium happy-path tests can assert on parser output without
# repeating the full JSON shape in every test.
#
# NOTE: those tests are defined earlier in this file. That works because
# the name is looked up when a test runs, by which time the whole module,
# including this assignment, has been imported.
# ---------------------------------------------------------------------------
_VALID_BACKEND_RESPONSE = """```json
{
"constraint": "Underwriting at the quote screen.",
"scores": {
"proprietary_data": { "score": 4, "rationale": "first-party.", "quoted_span": "12-year claims database" },
"self_labeling": { "score": 4, "rationale": "policies self-label.", "quoted_span": "Every policy we write" },
"decreasing_marginal_cost": { "score": 3, "rationale": "amortized pipeline.", "quoted_span": "feed back into the next quarter" },
"defensible_asymmetry": { "score": 3, "rationale": "no data sharing.", "quoted_span": "we don't share data with" }
},
"quadrant": "compounder",
"closest_portrait": "progressive",
"closest_portrait_paragraph": "Your case tracks Progressive most closely.",
"warnings": []
}
```
# The Verdict
Your initiative is a compounder. Here's why.
"""
def test_warnings_populated_for_failure_quadrant():
    """A failure quadrant carries structured, cited warnings through the parser."""
    warnings_json = (
        '"warnings": [{"text": "Wrong place, conditions weak.", '
        '"citation_source": "Buffett 2007", '
        '"citation_url": "https://www.berkshirehathaway.com/letters/2007ltr.pdf"}]'
    )
    failing_response = (
        VALID_JSON_BLOCK
        .replace('"quadrant": "compounder"', '"quadrant": "roman-candle"')
        .replace('"warnings": []', warnings_json)
    )
    parsed = parse_response(failing_response)
    assert parsed.quadrant == "roman-candle"
    assert len(parsed.warnings) == 1
    first_warning = parsed.warnings[0]
    assert first_warning.citation_source == "Buffett 2007"
    assert "berkshirehathaway" in first_warning.citation_url