Spaces:
Sleeping
test(hf-space): cover _call_huggingface (full) + _call_zerogpu (refactor)
Browse filesSymmetric coverage with the Premium path. 59 unit tests now pass
(was 42). One small production refactor to make ZeroGPU testable.
app.py refactor:
- Extracted _zerogpu_invoke() with the actual model-invocation logic
(chat template build → device move → generate → prompt-strip →
decode). _call_zerogpu became a one-line @spaces.GPU wrapper.
- This lets tests exercise the invocation path without needing torch
or @spaces.GPU runtime, by monkeypatching the module-level
_zerogpu_tokenizer / _zerogpu_model / _load_zerogpu_model.
- Stub fallback (deps unavailable) is unchanged.
test_diagnose.py: 17 new tests
_call_huggingface (10):
Token resolution:
- No token anywhere → RuntimeError with actionable message
- HF_TOKEN env wins (primary)
- HUGGING_FACE_HUB_TOKEN env as fallback
- get_token() from `hf auth login` cache as last fallback
- HF_TOKEN wins over the other two sources
InferenceClient init shape:
- model = HF_MODEL_ID
- provider="auto" — catches regressions that would re-break the
modern HF Inference Providers routing (the bug we fixed earlier)
- timeout=120
chat_completion shape:
- messages = [system, user] with correct role/content
- max_tokens=2500
- temperature=0.2 — intentionally low for small-model JSON
adherence; catch drift
- response unwrap via choices[0].message.content
Error handling:
- model_not_supported → RuntimeError with billing guidance
- alternate phrasing also triggers the wrap
- Other exceptions (ValueError, etc.) pass through so F14 can
format them in diagnose()
_call_zerogpu / _zerogpu_invoke (7):
Stub path:
- When deps unavailable, _call_zerogpu raises clear RuntimeError
- _zerogpu_available() reflects _ZEROGPU_DEPS_AVAILABLE
Invocation shape (via _zerogpu_invoke with mocked tokenizer/model):
- Builds chat template with system + user roles, return_tensors="pt",
add_generation_prompt=True
- Moves inputs to model.device (the .to() chain)
- generate() called with max_new_tokens=2500, temperature=0.2,
do_sample=True (required for non-zero temp), pad_token_id=eos_token_id
- Prompt tokens are stripped before decode (outputs[0][prompt_len:])
- skip_special_tokens=True on decode
- Returns the decoded string
What this catches in practice:
- Bumping HF_MODEL_ID without re-validating it gets passed correctly
- Accidentally removing provider="auto" (the model_not_supported bug)
- SDK arg name changes (max_new_tokens vs max_tokens for HF chat_completion
vs generate — easy to confuse)
- Forgetting do_sample=True when setting temperature
- Wrong response-unwrap path (HF uses .choices[].message.content,
Anthropic uses .content[0].text — easy to mix up)
- Forgetting to strip prompt tokens (would echo back the system prompt)
All 59 unit tests pass + 1 skipped opt-in integration test.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- app.py +32 -23
- test_diagnose.py +282 -0
|
@@ -377,35 +377,44 @@ def _load_zerogpu_model():
|
|
| 377 |
)
|
| 378 |
|
| 379 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
if _ZEROGPU_DEPS_AVAILABLE:
|
| 381 |
|
| 382 |
@_spaces.GPU(duration=ZEROGPU_DURATION_SECONDS)
|
| 383 |
def _call_zerogpu(system_block: str, user_prompt: str) -> str:
|
| 384 |
"""ZeroGPU backend. Loads Phi-4-mini-instruct (or whatever
|
| 385 |
ZEROGPU_MODEL_ID points at) into the Space's allocated GPU and
|
| 386 |
-
runs chat-template inference.
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
messages = [
|
| 390 |
-
{"role": "system", "content": system_block},
|
| 391 |
-
{"role": "user", "content": user_prompt},
|
| 392 |
-
]
|
| 393 |
-
inputs = _zerogpu_tokenizer.apply_chat_template(
|
| 394 |
-
messages,
|
| 395 |
-
return_tensors="pt",
|
| 396 |
-
add_generation_prompt=True,
|
| 397 |
-
).to(_zerogpu_model.device)
|
| 398 |
-
outputs = _zerogpu_model.generate(
|
| 399 |
-
inputs,
|
| 400 |
-
max_new_tokens=2500,
|
| 401 |
-
temperature=0.2,
|
| 402 |
-
do_sample=True,
|
| 403 |
-
pad_token_id=_zerogpu_tokenizer.eos_token_id,
|
| 404 |
-
)
|
| 405 |
-
prompt_len = inputs.shape[1]
|
| 406 |
-
return _zerogpu_tokenizer.decode(
|
| 407 |
-
outputs[0][prompt_len:], skip_special_tokens=True
|
| 408 |
-
)
|
| 409 |
|
| 410 |
else:
|
| 411 |
|
|
|
|
| 377 |
)
|
| 378 |
|
| 379 |
|
| 380 |
+
def _zerogpu_invoke(system_block: str, user_prompt: str) -> str:
|
| 381 |
+
"""Model invocation logic for the ZeroGPU backend. Separated from
|
| 382 |
+
the `@spaces.GPU` decoration below so it can be unit-tested without
|
| 383 |
+
actually allocating a GPU. The function reads module-level globals
|
| 384 |
+
(`_zerogpu_tokenizer`, `_zerogpu_model`) which tests can monkeypatch
|
| 385 |
+
to fake the transformers types."""
|
| 386 |
+
_load_zerogpu_model()
|
| 387 |
+
messages = [
|
| 388 |
+
{"role": "system", "content": system_block},
|
| 389 |
+
{"role": "user", "content": user_prompt},
|
| 390 |
+
]
|
| 391 |
+
inputs = _zerogpu_tokenizer.apply_chat_template(
|
| 392 |
+
messages,
|
| 393 |
+
return_tensors="pt",
|
| 394 |
+
add_generation_prompt=True,
|
| 395 |
+
).to(_zerogpu_model.device)
|
| 396 |
+
outputs = _zerogpu_model.generate(
|
| 397 |
+
inputs,
|
| 398 |
+
max_new_tokens=2500,
|
| 399 |
+
temperature=0.2,
|
| 400 |
+
do_sample=True,
|
| 401 |
+
pad_token_id=_zerogpu_tokenizer.eos_token_id,
|
| 402 |
+
)
|
| 403 |
+
prompt_len = inputs.shape[1]
|
| 404 |
+
return _zerogpu_tokenizer.decode(
|
| 405 |
+
outputs[0][prompt_len:], skip_special_tokens=True
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
|
| 409 |
if _ZEROGPU_DEPS_AVAILABLE:
|
| 410 |
|
| 411 |
@_spaces.GPU(duration=ZEROGPU_DURATION_SECONDS)
|
| 412 |
def _call_zerogpu(system_block: str, user_prompt: str) -> str:
|
| 413 |
"""ZeroGPU backend. Loads Phi-4-mini-instruct (or whatever
|
| 414 |
ZEROGPU_MODEL_ID points at) into the Space's allocated GPU and
|
| 415 |
+
runs chat-template inference. Thin wrapper around the testable
|
| 416 |
+
`_zerogpu_invoke` so the decorator stays at module load time."""
|
| 417 |
+
return _zerogpu_invoke(system_block, user_prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
|
| 419 |
else:
|
| 420 |
|
|
@@ -15,11 +15,16 @@ from unittest.mock import MagicMock
|
|
| 15 |
|
| 16 |
from app import (
|
| 17 |
ANTHROPIC_MODEL_ID,
|
|
|
|
| 18 |
MalformedResponseError,
|
| 19 |
PROVIDERS,
|
| 20 |
_call_anthropic,
|
|
|
|
| 21 |
_call_model,
|
|
|
|
| 22 |
_detect_provider,
|
|
|
|
|
|
|
| 23 |
diagnose,
|
| 24 |
parse_response,
|
| 25 |
)
|
|
@@ -480,6 +485,283 @@ def test_call_anthropic_passes_system_block_with_cache_control(monkeypatch):
|
|
| 480 |
assert captured["messages"] == [{"role": "user", "content": "MY USER PROMPT"}]
|
| 481 |
|
| 482 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 483 |
# --- Integration test (opt-in; hits the real Anthropic API) ----------------
|
| 484 |
#
|
| 485 |
# Skipped unless ANTHROPIC_API_KEY is set AND ANTHROPIC_INTEGRATION=1 is
|
|
|
|
| 15 |
|
| 16 |
from app import (
|
| 17 |
ANTHROPIC_MODEL_ID,
|
| 18 |
+
HF_MODEL_ID,
|
| 19 |
MalformedResponseError,
|
| 20 |
PROVIDERS,
|
| 21 |
_call_anthropic,
|
| 22 |
+
_call_huggingface,
|
| 23 |
_call_model,
|
| 24 |
+
_call_zerogpu,
|
| 25 |
_detect_provider,
|
| 26 |
+
_zerogpu_available,
|
| 27 |
+
_zerogpu_invoke,
|
| 28 |
diagnose,
|
| 29 |
parse_response,
|
| 30 |
)
|
|
|
|
| 485 |
assert captured["messages"] == [{"role": "user", "content": "MY USER PROMPT"}]
|
| 486 |
|
| 487 |
|
| 488 |
+
# --- _call_huggingface: token resolution + call shape ----------------------
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
def _install_fake_inference_client(monkeypatch, captured: dict, *,
|
| 492 |
+
response_text: str = "hf response",
|
| 493 |
+
raises: Exception | None = None):
|
| 494 |
+
"""Replace huggingface_hub.InferenceClient with a fake that records
|
| 495 |
+
its init kwargs and chat_completion kwargs into `captured`. Optionally
|
| 496 |
+
have chat_completion raise an exception instead of returning."""
|
| 497 |
+
|
| 498 |
+
class _FakeMsg:
|
| 499 |
+
content = response_text
|
| 500 |
+
|
| 501 |
+
class _FakeChoice:
|
| 502 |
+
message = _FakeMsg()
|
| 503 |
+
|
| 504 |
+
class _FakeResponse:
|
| 505 |
+
choices = [_FakeChoice()]
|
| 506 |
+
|
| 507 |
+
class _FakeClient:
|
| 508 |
+
def __init__(self, **kwargs):
|
| 509 |
+
captured["init_kwargs"] = kwargs
|
| 510 |
+
|
| 511 |
+
def chat_completion(self, **kwargs):
|
| 512 |
+
captured["chat_kwargs"] = kwargs
|
| 513 |
+
if raises is not None:
|
| 514 |
+
raise raises
|
| 515 |
+
return _FakeResponse()
|
| 516 |
+
|
| 517 |
+
import huggingface_hub
|
| 518 |
+
monkeypatch.setattr(huggingface_hub, "InferenceClient", _FakeClient)
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
def test_call_huggingface_no_token_anywhere_raises_actionable_error(monkeypatch):
|
| 522 |
+
monkeypatch.delenv("HF_TOKEN", raising=False)
|
| 523 |
+
monkeypatch.delenv("HUGGING_FACE_HUB_TOKEN", raising=False)
|
| 524 |
+
import huggingface_hub
|
| 525 |
+
monkeypatch.setattr(huggingface_hub, "get_token", lambda: None)
|
| 526 |
+
|
| 527 |
+
with pytest.raises(RuntimeError, match="No HuggingFace token"):
|
| 528 |
+
_call_huggingface("sys", "usr")
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
def test_call_huggingface_uses_HF_TOKEN_env(monkeypatch):
|
| 532 |
+
monkeypatch.setenv("HF_TOKEN", "hf_from_env")
|
| 533 |
+
captured = {}
|
| 534 |
+
_install_fake_inference_client(monkeypatch, captured)
|
| 535 |
+
_call_huggingface("sys", "usr")
|
| 536 |
+
assert captured["init_kwargs"]["token"] == "hf_from_env"
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
def test_call_huggingface_uses_HUGGING_FACE_HUB_TOKEN_env_as_fallback(monkeypatch):
|
| 540 |
+
monkeypatch.delenv("HF_TOKEN", raising=False)
|
| 541 |
+
monkeypatch.setenv("HUGGING_FACE_HUB_TOKEN", "hf_legacy_var")
|
| 542 |
+
captured = {}
|
| 543 |
+
_install_fake_inference_client(monkeypatch, captured)
|
| 544 |
+
_call_huggingface("sys", "usr")
|
| 545 |
+
assert captured["init_kwargs"]["token"] == "hf_legacy_var"
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
def test_call_huggingface_uses_get_token_when_no_env(monkeypatch):
|
| 549 |
+
monkeypatch.delenv("HF_TOKEN", raising=False)
|
| 550 |
+
monkeypatch.delenv("HUGGING_FACE_HUB_TOKEN", raising=False)
|
| 551 |
+
import huggingface_hub
|
| 552 |
+
monkeypatch.setattr(huggingface_hub, "get_token", lambda: "hf_from_cli_login")
|
| 553 |
+
captured = {}
|
| 554 |
+
_install_fake_inference_client(monkeypatch, captured)
|
| 555 |
+
_call_huggingface("sys", "usr")
|
| 556 |
+
assert captured["init_kwargs"]["token"] == "hf_from_cli_login"
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def test_call_huggingface_HF_TOKEN_wins_over_other_sources(monkeypatch):
|
| 560 |
+
monkeypatch.setenv("HF_TOKEN", "hf_winner")
|
| 561 |
+
monkeypatch.setenv("HUGGING_FACE_HUB_TOKEN", "hf_loser_1")
|
| 562 |
+
import huggingface_hub
|
| 563 |
+
monkeypatch.setattr(huggingface_hub, "get_token", lambda: "hf_loser_2")
|
| 564 |
+
captured = {}
|
| 565 |
+
_install_fake_inference_client(monkeypatch, captured)
|
| 566 |
+
_call_huggingface("sys", "usr")
|
| 567 |
+
assert captured["init_kwargs"]["token"] == "hf_winner"
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
def test_call_huggingface_init_shape_model_provider_timeout(monkeypatch):
|
| 571 |
+
monkeypatch.setenv("HF_TOKEN", "hf_test")
|
| 572 |
+
captured = {}
|
| 573 |
+
_install_fake_inference_client(monkeypatch, captured)
|
| 574 |
+
_call_huggingface("sys", "usr")
|
| 575 |
+
init = captured["init_kwargs"]
|
| 576 |
+
assert init["model"] == HF_MODEL_ID
|
| 577 |
+
# provider="auto" is the critical config that enables the modern HF
|
| 578 |
+
# Inference Providers routing layer — without it, the client falls
|
| 579 |
+
# back to the legacy hf-inference-only path. Catch any regression
|
| 580 |
+
# that removes this flag.
|
| 581 |
+
assert init["provider"] == "auto"
|
| 582 |
+
assert init["timeout"] == 120
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
def test_call_huggingface_chat_completion_call_shape(monkeypatch):
|
| 586 |
+
monkeypatch.setenv("HF_TOKEN", "hf_test")
|
| 587 |
+
captured = {}
|
| 588 |
+
_install_fake_inference_client(monkeypatch, captured)
|
| 589 |
+
result = _call_huggingface("MY SYSTEM BLOCK", "MY USER PROMPT")
|
| 590 |
+
chat = captured["chat_kwargs"]
|
| 591 |
+
assert chat["messages"] == [
|
| 592 |
+
{"role": "system", "content": "MY SYSTEM BLOCK"},
|
| 593 |
+
{"role": "user", "content": "MY USER PROMPT"},
|
| 594 |
+
]
|
| 595 |
+
assert chat["max_tokens"] == 2500
|
| 596 |
+
# Low temperature is intentional — smaller open models can produce
|
| 597 |
+
# looser JSON at higher temperatures. Catch any drift.
|
| 598 |
+
assert chat["temperature"] == 0.2
|
| 599 |
+
# Response unwrap: choices[0].message.content
|
| 600 |
+
assert result == "hf response"
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
def test_call_huggingface_model_not_supported_error_wrapped(monkeypatch):
|
| 604 |
+
monkeypatch.setenv("HF_TOKEN", "hf_test")
|
| 605 |
+
fake_hf_error = Exception(
|
| 606 |
+
"Bad request: {'message': \"The requested model is not supported "
|
| 607 |
+
"by any provider you have enabled.\", 'code': 'model_not_supported'}"
|
| 608 |
+
)
|
| 609 |
+
captured = {}
|
| 610 |
+
_install_fake_inference_client(monkeypatch, captured, raises=fake_hf_error)
|
| 611 |
+
with pytest.raises(RuntimeError, match="isn't available through any"):
|
| 612 |
+
_call_huggingface("sys", "usr")
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
def test_call_huggingface_model_not_supported_alternate_phrasing_wrapped(monkeypatch):
|
| 616 |
+
monkeypatch.setenv("HF_TOKEN", "hf_test")
|
| 617 |
+
fake_hf_error = Exception("...'code': 'model_not_supported'...")
|
| 618 |
+
captured = {}
|
| 619 |
+
_install_fake_inference_client(monkeypatch, captured, raises=fake_hf_error)
|
| 620 |
+
with pytest.raises(RuntimeError, match="isn't available through any"):
|
| 621 |
+
_call_huggingface("sys", "usr")
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
def test_call_huggingface_other_exception_passes_through(monkeypatch):
|
| 625 |
+
"""Errors that aren't the model_not_supported case (auth fail,
|
| 626 |
+
network timeout, malformed response) should propagate up so the
|
| 627 |
+
F14 wrapper in diagnose() can surface them with the original class
|
| 628 |
+
name and detail."""
|
| 629 |
+
monkeypatch.setenv("HF_TOKEN", "hf_test")
|
| 630 |
+
fake_other_error = ValueError("Invalid API key")
|
| 631 |
+
captured = {}
|
| 632 |
+
_install_fake_inference_client(monkeypatch, captured, raises=fake_other_error)
|
| 633 |
+
with pytest.raises(ValueError, match="Invalid API key"):
|
| 634 |
+
_call_huggingface("sys", "usr")
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
# --- _call_zerogpu: stub path + invocation shape --------------------------
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
def test_call_zerogpu_stub_raises_clear_error_when_deps_unavailable():
|
| 641 |
+
"""In a local environment without spaces/torch/transformers installed,
|
| 642 |
+
_ZEROGPU_DEPS_AVAILABLE is False and _call_zerogpu is the stub that
|
| 643 |
+
raises a RuntimeError pointing the user to the other two backends."""
|
| 644 |
+
if _zerogpu_available():
|
| 645 |
+
pytest.skip("Test only meaningful when zerogpu deps are NOT installed")
|
| 646 |
+
with pytest.raises(RuntimeError, match="ZeroGPU backend requires"):
|
| 647 |
+
_call_zerogpu("sys", "usr")
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
def test_zerogpu_available_reflects_dep_state():
|
| 651 |
+
"""_zerogpu_available() is the sole gating function for the zerogpu
|
| 652 |
+
branch in _detect_provider; it must return the cached import-time
|
| 653 |
+
boolean rather than re-trying imports on every call."""
|
| 654 |
+
import app as app_module
|
| 655 |
+
assert _zerogpu_available() is app_module._ZEROGPU_DEPS_AVAILABLE
|
| 656 |
+
|
| 657 |
+
|
| 658 |
+
def _install_fake_zerogpu_model(monkeypatch, captured: dict, *,
|
| 659 |
+
prompt_len: int = 5,
|
| 660 |
+
decoded_text: str = "model output"):
|
| 661 |
+
"""Replace the module-level _zerogpu_tokenizer and _zerogpu_model
|
| 662 |
+
with fakes that record their calls. Simulates transformers types
|
| 663 |
+
just enough for _zerogpu_invoke() to run end-to-end without torch
|
| 664 |
+
actually installed."""
|
| 665 |
+
import app as app_module
|
| 666 |
+
|
| 667 |
+
class _FakeInputs:
|
| 668 |
+
def __init__(self):
|
| 669 |
+
self.shape = (1, prompt_len)
|
| 670 |
+
|
| 671 |
+
def to(self, device):
|
| 672 |
+
captured["inputs_moved_to_device"] = device
|
| 673 |
+
return self # chain .to() back into self for further use
|
| 674 |
+
|
| 675 |
+
fake_inputs = _FakeInputs()
|
| 676 |
+
fake_outputs = [list(range(prompt_len + 10))] # prompt tokens + 10 new tokens
|
| 677 |
+
|
| 678 |
+
class _FakeTokenizer:
|
| 679 |
+
eos_token_id = 99
|
| 680 |
+
|
| 681 |
+
def apply_chat_template(self, messages, **kwargs):
|
| 682 |
+
captured["apply_chat_template"] = {
|
| 683 |
+
"messages": messages,
|
| 684 |
+
"kwargs": kwargs,
|
| 685 |
+
}
|
| 686 |
+
return fake_inputs
|
| 687 |
+
|
| 688 |
+
def decode(self, token_ids, **kwargs):
|
| 689 |
+
captured["decode"] = {"token_ids": list(token_ids), "kwargs": kwargs}
|
| 690 |
+
return decoded_text
|
| 691 |
+
|
| 692 |
+
class _FakeModel:
|
| 693 |
+
device = "cuda:0"
|
| 694 |
+
|
| 695 |
+
def generate(self, inputs, **kwargs):
|
| 696 |
+
captured["generate_inputs"] = inputs
|
| 697 |
+
captured["generate_kwargs"] = kwargs
|
| 698 |
+
return fake_outputs
|
| 699 |
+
|
| 700 |
+
monkeypatch.setattr(app_module, "_zerogpu_tokenizer", _FakeTokenizer())
|
| 701 |
+
monkeypatch.setattr(app_module, "_zerogpu_model", _FakeModel())
|
| 702 |
+
# Skip the real model-load path; we've already populated the globals.
|
| 703 |
+
monkeypatch.setattr(app_module, "_load_zerogpu_model", lambda: None)
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
def test_zerogpu_invoke_builds_chat_template_with_system_and_user(monkeypatch):
|
| 707 |
+
captured = {}
|
| 708 |
+
_install_fake_zerogpu_model(monkeypatch, captured)
|
| 709 |
+
_zerogpu_invoke("MY SYSTEM BLOCK", "MY USER PROMPT")
|
| 710 |
+
chat = captured["apply_chat_template"]
|
| 711 |
+
assert chat["messages"] == [
|
| 712 |
+
{"role": "system", "content": "MY SYSTEM BLOCK"},
|
| 713 |
+
{"role": "user", "content": "MY USER PROMPT"},
|
| 714 |
+
]
|
| 715 |
+
assert chat["kwargs"]["return_tensors"] == "pt"
|
| 716 |
+
assert chat["kwargs"]["add_generation_prompt"] is True
|
| 717 |
+
|
| 718 |
+
|
| 719 |
+
def test_zerogpu_invoke_moves_inputs_to_model_device(monkeypatch):
|
| 720 |
+
captured = {}
|
| 721 |
+
_install_fake_zerogpu_model(monkeypatch, captured)
|
| 722 |
+
_zerogpu_invoke("sys", "usr")
|
| 723 |
+
assert captured["inputs_moved_to_device"] == "cuda:0"
|
| 724 |
+
|
| 725 |
+
|
| 726 |
+
def test_zerogpu_invoke_generate_call_shape(monkeypatch):
|
| 727 |
+
"""The .generate() kwargs are easy to typo and carry real semantics:
|
| 728 |
+
max_new_tokens=2500 caps output length
|
| 729 |
+
temperature=0.2 keeps JSON output stable for small models
|
| 730 |
+
do_sample=True is needed for non-zero temperature to have effect
|
| 731 |
+
pad_token_id=eos_token_id avoids warning spam on short prompts
|
| 732 |
+
Catch regressions in any of these."""
|
| 733 |
+
captured = {}
|
| 734 |
+
_install_fake_zerogpu_model(monkeypatch, captured)
|
| 735 |
+
_zerogpu_invoke("sys", "usr")
|
| 736 |
+
gen = captured["generate_kwargs"]
|
| 737 |
+
assert gen["max_new_tokens"] == 2500
|
| 738 |
+
assert gen["temperature"] == 0.2
|
| 739 |
+
assert gen["do_sample"] is True
|
| 740 |
+
assert gen["pad_token_id"] == 99 # _FakeTokenizer.eos_token_id
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
def test_zerogpu_invoke_strips_prompt_tokens_before_decode(monkeypatch):
|
| 744 |
+
"""The decoded output must be the GENERATED text only, not echo back
|
| 745 |
+
the prompt. The function does this by slicing outputs[0][prompt_len:]
|
| 746 |
+
before calling decode. Verify the slice happens correctly."""
|
| 747 |
+
captured = {}
|
| 748 |
+
# prompt_len=5 → fake_outputs returns range(15) (5 prompt + 10 generated)
|
| 749 |
+
# so decode should be called with tokens [5..15)
|
| 750 |
+
_install_fake_zerogpu_model(monkeypatch, captured, prompt_len=5)
|
| 751 |
+
_zerogpu_invoke("sys", "usr")
|
| 752 |
+
decoded_tokens = captured["decode"]["token_ids"]
|
| 753 |
+
assert decoded_tokens == list(range(5, 15))
|
| 754 |
+
# And skip_special_tokens is on so we don't include things like </s>
|
| 755 |
+
assert captured["decode"]["kwargs"]["skip_special_tokens"] is True
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
def test_zerogpu_invoke_returns_decoded_text(monkeypatch):
|
| 759 |
+
captured = {}
|
| 760 |
+
_install_fake_zerogpu_model(monkeypatch, captured, decoded_text="my generated answer")
|
| 761 |
+
result = _zerogpu_invoke("sys", "usr")
|
| 762 |
+
assert result == "my generated answer"
|
| 763 |
+
|
| 764 |
+
|
| 765 |
# --- Integration test (opt-in; hits the real Anthropic API) ----------------
|
| 766 |
#
|
| 767 |
# Skipped unless ANTHROPIC_API_KEY is set AND ANTHROPIC_INTEGRATION=1 is
|