Spaces:
Running on Zero
Running on Zero
| """Inference layer tests: SSE parsing, fixtures, backend selection.""" | |
| import json | |
| from pathlib import Path | |
| import httpx | |
| import pytest | |
| from scrypt.inference import build_backend | |
| from scrypt.inference.api import OpenAIChatBackend | |
| from scrypt.inference.backend import ( | |
| RecordingBackend, | |
| ReplayBackend, | |
| ScriptedBackend, | |
| complete, | |
| ) | |
| from scrypt.inference.local import LocalSetupError, LlamaServer, preflight | |
def sse_response(*texts: str) -> bytes:
    """Encode *texts* as the body of a streamed chat-completion reply.

    Each text becomes one ``data:`` line carrying a JSON chunk of the form
    ``{"choices": [{"delta": {"content": ...}}]}``; a final ``data: [DONE]``
    line closes the stream. Every line, including the last, ends with a
    newline, and the result is the encoded byte string.
    """
    events = [
        "data: " + json.dumps({"choices": [{"delta": {"content": chunk}}]})
        for chunk in texts
    ]
    events.append("data: [DONE]")
    body = "".join(f"{event}\n" for event in events)
    return body.encode()
async def test_openai_backend_parses_sse_stream():
    """OpenAIChatBackend sends a streaming request and joins the SSE deltas."""

    def handler(request: httpx.Request) -> httpx.Response:
        # Validate the outgoing request before serving the canned stream.
        body = json.loads(request.content)
        assert body["stream"] is True
        assert body["chat_template_kwargs"] == {"enable_thinking": False}
        assert request.headers["authorization"] == "Bearer k"
        return httpx.Response(
            200,
            content=sse_response("The ", "scale ", "tips."),
            headers={"content-type": "text/event-stream"},
        )

    transport = httpx.MockTransport(handler)
    # The test creates the client, so it is also responsible for closing it;
    # the context manager does so even when an assertion or the call fails.
    async with httpx.AsyncClient(transport=transport) as client:
        backend = OpenAIChatBackend("http://test/v1", api_key="k", client=client)
        text = await complete(backend, [{"role": "user", "content": "hi"}])
    assert text == "The scale tips."
async def test_record_then_replay_roundtrip(tmp_path: Path):
    """A completion captured via RecordingBackend can be replayed from its fixture."""
    fixture_path = tmp_path / "fixtures.jsonl"
    prompt = [{"role": "user", "content": "moment"}]

    # Record: wrap a scripted backend and run one completion through it.
    recording = RecordingBackend(ScriptedBackend(default="recorded line"), fixture_path)
    recorded_text = await complete(recording, prompt)
    assert recorded_text == "recorded line"

    # Replay: a backend built only from the fixture path gives the same text.
    replaying = ReplayBackend(fixture_path)
    replayed_text = await complete(replaying, prompt)
    assert replayed_text == "recorded line"

    # A prompt that was never recorded is expected to raise KeyError.
    unseen_prompt = [{"role": "user", "content": "unseen"}]
    with pytest.raises(KeyError):
        await complete(replaying, unseen_prompt)
def test_preflight_reports_missing_pieces(monkeypatch, tmp_path):
    """With no binary, no llama_cpp and an empty home, preflight names both gaps."""
    patches = (
        ("scrypt.inference.local.find_binary", lambda: None),
        ("scrypt.inference.local.llama_cpp_available", lambda: False),
        ("scrypt.inference.local.SCRYPT_HOME", tmp_path),
    )
    for target, replacement in patches:
        monkeypatch.setattr(target, replacement)

    report = preflight()

    # One entry must mention the server binary and one must mention the model.
    for keyword in ("llama-server", "model"):
        assert any(keyword in entry for entry in report)
def test_installed_model_finds_any_gguf_name(monkeypatch, tmp_path):
    """A manually downloaded .gguf with an arbitrary file name is still found."""
    import scrypt.inference.local as local

    monkeypatch.setattr(local, "SCRYPT_HOME", tmp_path)

    # Place a small file with a non-standard name under <home>/models.
    model_dir = tmp_path / "models"
    model_dir.mkdir()
    custom_file = model_dir / "my-cool-quant-q4.gguf"
    custom_file.write_bytes(b"x" * 10)

    found = local.installed_model()
    assert found.name == "my-cool-quant-q4.gguf"
def test_server_command_falls_back_to_llama_cpp_python(monkeypatch, tmp_path):
    """Without a binary, server_command uses llama_cpp.server if available, else None."""
    import sys
    import scrypt.inference.local as local

    model_path = tmp_path / "m.gguf"
    monkeypatch.setattr(local, "find_binary", lambda: None)

    # llama_cpp reported as available: the command runs the module with this interpreter.
    monkeypatch.setattr(local, "llama_cpp_available", lambda: True)
    command = local.server_command(model_path, 8731, 8192)
    assert command[:3] == [sys.executable, "-m", "llama_cpp.server"]

    # Neither option available: there is no command to run.
    monkeypatch.setattr(local, "llama_cpp_available", lambda: False)
    command = local.server_command(model_path, 8731, 8192)
    assert command is None
def test_llama_server_start_refuses_without_setup(monkeypatch):
    """LlamaServer.start() raises LocalSetupError when no server can be run."""
    monkeypatch.setattr("scrypt.inference.local.find_binary", lambda: None)
    # Also rule out the llama_cpp Python package: on a machine that has it
    # installed, a missing binary alone may not count as "no setup", which
    # would make this test's outcome depend on the host environment.
    monkeypatch.setattr("scrypt.inference.local.llama_cpp_available", lambda: False)
    with pytest.raises(LocalSetupError):
        LlamaServer().start()
def test_build_backend_falls_back_to_scripted(monkeypatch):
    """In auto mode with no API key and a failing preflight, the scripted backend is used."""
    monkeypatch.delenv("SCRYPT_API_KEY", raising=False)
    monkeypatch.setenv("SCRYPT_BACKEND", "auto")
    # A non-empty problem list stands in for an unusable local setup.
    monkeypatch.setattr("scrypt.inference.preflight", lambda: ["no model"])

    backend, server, mode = build_backend()

    assert mode == "scripted"
    assert server is None
    assert isinstance(backend, ScriptedBackend)
def test_build_backend_api_mode(monkeypatch):
    """SCRYPT_BACKEND=api together with an API key selects the OpenAI chat backend."""
    for variable, value in (("SCRYPT_BACKEND", "api"), ("SCRYPT_API_KEY", "sk-test")):
        monkeypatch.setenv(variable, value)

    # The middle element (the server handle) is not checked by this test.
    backend, _server, mode = build_backend()

    assert mode == "api"
    assert isinstance(backend, OpenAIChatBackend)
def test_quant_ladder_tiers(monkeypatch):
    """choose_quant maps each RAM figure to the expected quantization tier."""
    from scrypt.inference.local import choose_quant

    # Make sure an environment override cannot mask the ladder.
    monkeypatch.delenv("SCRYPT_QUANT", raising=False)

    ladder = (
        (128, "Q8_0"),
        (96, "Q8_0"),
        (64, "Q5_K_M"),
        (48, "Q4_K_S"),
        (32, "Q3_K_S"),
        (16, None),  # booted to API mode
        (0, "Q4_K_S"),  # unknown RAM -> safe default
    )
    for ram, expected in ladder:
        chosen = choose_quant(ram)
        if expected is None:
            assert chosen is None
        else:
            assert chosen == expected
def test_quant_env_override(monkeypatch):
    """A SCRYPT_QUANT value in the environment is returned instead of the ladder's pick."""
    from scrypt.inference.local import choose_quant

    forced = "Q6_K"
    monkeypatch.setenv("SCRYPT_QUANT", forced)

    chosen = choose_quant(32)
    assert chosen == "Q6_K"
def test_preflight_names_machine_tier(monkeypatch, tmp_path):
    """preflight's messages reflect the quant tier implied by the reported RAM."""
    import scrypt.inference.local as local

    monkeypatch.delenv("SCRYPT_QUANT", raising=False)
    monkeypatch.setattr(local, "SCRYPT_HOME", tmp_path)
    monkeypatch.setattr(local, "find_binary", lambda: None)

    # 33.0 reported: some message names the Q3_K_S quant.
    monkeypatch.setattr(local, "system_ram_gb", lambda: 33.0)
    report = local.preflight()
    assert any("Q3_K_S" in entry for entry in report)

    # 16.0 reported: messages point to API mode and no longer name Q3_K_S.
    monkeypatch.setattr(local, "system_ram_gb", lambda: 16.0)
    report = local.preflight()
    assert any("API mode" in entry for entry in report)
    assert all("Q3_K_S" not in entry for entry in report)
def test_installed_model_prefers_heaviest(monkeypatch, tmp_path):
    """With Q3_K_S and Q5_K_M files both present, installed_model returns the Q5_K_M one."""
    import scrypt.inference.local as local

    monkeypatch.setattr(local, "SCRYPT_HOME", tmp_path)
    model_dir = tmp_path / "models"
    model_dir.mkdir()

    # Both files are empty, so the choice cannot be based on size on disk.
    for quant in ("Q3_K_S", "Q5_K_M"):
        (model_dir / local.model_file(quant)).touch()

    chosen = local.installed_model()
    assert chosen.name == local.model_file("Q5_K_M")