File size: 5,929 Bytes
9fca766
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""Inference layer tests: SSE parsing, fixtures, backend selection."""

import json
from pathlib import Path

import httpx
import pytest

from scrypt.inference import build_backend
from scrypt.inference.api import OpenAIChatBackend
from scrypt.inference.backend import (
    RecordingBackend,
    ReplayBackend,
    ScriptedBackend,
    complete,
)
from scrypt.inference.local import LocalSetupError, LlamaServer, preflight


def sse_response(*texts: str) -> bytes:
    """Encode *texts* as an OpenAI-style streaming chat response body.

    Each text becomes one ``data:`` line whose JSON payload carries it as
    ``choices[0].delta.content``. A final ``data: [DONE]`` line closes the
    stream, and every line (including the last) ends with a newline.
    """
    events = [
        "data: " + json.dumps({"choices": [{"delta": {"content": chunk}}]})
        for chunk in texts
    ]
    events.append("data: [DONE]")
    return "".join(f"{event}\n" for event in events).encode()


async def test_openai_backend_parses_sse_stream():
    """OpenAIChatBackend sends a streaming request and joins the SSE chunks.

    The mock transport checks the outgoing request (stream flag, thinking
    disabled via ``chat_template_kwargs``, bearer auth) and replies with three
    ``delta.content`` chunks that must come back as one string.
    """

    def handler(request: httpx.Request) -> httpx.Response:
        body = json.loads(request.content)
        assert body["stream"] is True
        assert body["chat_template_kwargs"] == {"enable_thinking": False}
        assert request.headers["authorization"] == "Bearer k"
        return httpx.Response(
            200,
            content=sse_response("The ", "scale ", "tips."),
            headers={"content-type": "text/event-stream"},
        )

    # Close the client when the test is done instead of leaking it.
    async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
        backend = OpenAIChatBackend("http://test/v1", api_key="k", client=client)
        text = await complete(backend, [{"role": "user", "content": "hi"}])
    assert text == "The scale tips."


async def test_record_then_replay_roundtrip(tmp_path: Path):
    """Whatever RecordingBackend captures, ReplayBackend serves back.

    Messages that were never recorded must raise KeyError on replay.
    """
    fixture_path = tmp_path / "fixtures.jsonl"
    prompt = [{"role": "user", "content": "moment"}]

    recording = RecordingBackend(ScriptedBackend(default="recorded line"), fixture_path)
    assert await complete(recording, prompt) == "recorded line"

    replaying = ReplayBackend(fixture_path)
    assert await complete(replaying, prompt) == "recorded line"

    unseen_prompt = [{"role": "user", "content": "unseen"}]
    with pytest.raises(KeyError):
        await complete(replaying, unseen_prompt)


def test_preflight_reports_missing_pieces(monkeypatch, tmp_path):
    """preflight() lists both a missing server and a missing model.

    The setup has no llama-server binary, no llama-cpp-python and an empty
    SCRYPT_HOME.
    """
    # Keep the result independent of the host environment. A user-set
    # SCRYPT_QUANT or the machine's real RAM would change which advice
    # preflight gives; low RAM yields the API-mode message.
    monkeypatch.delenv("SCRYPT_QUANT", raising=False)
    monkeypatch.setattr("scrypt.inference.local.system_ram_gb", lambda: 64.0)
    monkeypatch.setattr("scrypt.inference.local.find_binary", lambda: None)
    monkeypatch.setattr("scrypt.inference.local.llama_cpp_available", lambda: False)
    monkeypatch.setattr("scrypt.inference.local.SCRYPT_HOME", tmp_path)
    problems = preflight()
    assert any("llama-server" in p for p in problems)
    assert any("model" in p for p in problems)


def test_installed_model_finds_any_gguf_name(monkeypatch, tmp_path):
    """installed_model() accepts a .gguf file with an arbitrary name.

    An example is a hand-downloaded quant that doesn't follow model_file().
    """
    import scrypt.inference.local as local

    monkeypatch.setattr(local, "SCRYPT_HOME", tmp_path)
    model_dir = tmp_path / "models"
    model_dir.mkdir()
    custom_model = model_dir / "my-cool-quant-q4.gguf"
    custom_model.write_bytes(b"x" * 10)

    found = local.installed_model()
    assert found.name == "my-cool-quant-q4.gguf"


def test_server_command_falls_back_to_llama_cpp_python(monkeypatch, tmp_path):
    """Without a llama-server binary, server_command runs ``llama_cpp.server``.

    The command uses the current interpreter. If llama-cpp-python is
    unavailable as well, server_command returns None.
    """
    import sys

    import scrypt.inference.local as local

    model_path = tmp_path / "m.gguf"
    monkeypatch.setattr(local, "find_binary", lambda: None)

    # Fallback available: the command is "<python> -m llama_cpp.server ...".
    monkeypatch.setattr(local, "llama_cpp_available", lambda: True)
    fallback_cmd = local.server_command(model_path, 8731, 8192)
    assert fallback_cmd[:3] == [sys.executable, "-m", "llama_cpp.server"]

    # Nothing available: no command at all.
    monkeypatch.setattr(local, "llama_cpp_available", lambda: False)
    assert local.server_command(model_path, 8731, 8192) is None


def test_llama_server_start_refuses_without_setup(monkeypatch, tmp_path):
    """LlamaServer.start() raises LocalSetupError when local setup is missing."""
    monkeypatch.setattr("scrypt.inference.local.find_binary", lambda: None)
    # server_command falls back to llama-cpp-python's server when that
    # package is available. Rule it out explicitly so the test never depends
    # on (or launches) whatever the host has installed.
    monkeypatch.setattr("scrypt.inference.local.llama_cpp_available", lambda: False)
    # Point at an empty home so a real installed model cannot be picked up.
    monkeypatch.setattr("scrypt.inference.local.SCRYPT_HOME", tmp_path)
    with pytest.raises(LocalSetupError):
        LlamaServer().start()


def test_build_backend_falls_back_to_scripted(monkeypatch):
    """In auto mode, build_backend falls back to ScriptedBackend.

    This happens when there is no API key and preflight reports a problem.
    No server is returned in that case.
    """
    monkeypatch.setenv("SCRYPT_BACKEND", "auto")
    monkeypatch.delenv("SCRYPT_API_KEY", raising=False)
    monkeypatch.setattr("scrypt.inference.preflight", lambda: ["no model"])

    backend, server, mode = build_backend()

    assert mode == "scripted"
    assert server is None
    assert isinstance(backend, ScriptedBackend)


def test_build_backend_api_mode(monkeypatch):
    """SCRYPT_BACKEND=api with an API key selects OpenAIChatBackend."""
    monkeypatch.setenv("SCRYPT_API_KEY", "sk-test")
    monkeypatch.setenv("SCRYPT_BACKEND", "api")

    backend, _server, mode = build_backend()

    assert mode == "api"
    assert isinstance(backend, OpenAIChatBackend)


def test_quant_ladder_tiers(monkeypatch):
    """choose_quant maps RAM in GB to a quant tier.

    None means the machine is too small for local inference.
    """
    from scrypt.inference.local import choose_quant

    monkeypatch.delenv("SCRYPT_QUANT", raising=False)
    ladder = [
        (128, "Q8_0"),
        (96, "Q8_0"),
        (64, "Q5_K_M"),
        (48, "Q4_K_S"),
        (32, "Q3_K_S"),
        (16, None),       # booted to API mode
        (0, "Q4_K_S"),    # unknown RAM -> safe default
    ]
    for ram_gb, expected in ladder:
        chosen = choose_quant(ram_gb)
        if expected is None:
            assert chosen is None, ram_gb
        else:
            assert chosen == expected, ram_gb


def test_quant_env_override(monkeypatch):
    """SCRYPT_QUANT takes precedence over the RAM-based tier choice."""
    from scrypt.inference.local import choose_quant

    monkeypatch.setenv("SCRYPT_QUANT", "Q6_K")
    forced = choose_quant(32)
    assert forced == "Q6_K"


def test_preflight_names_machine_tier(monkeypatch, tmp_path):
    """preflight's advice follows the machine's RAM tier.

    At 33 GB it names Q3_K_S. At 16 GB it points at API mode and does not
    mention Q3_K_S.
    """
    import scrypt.inference.local as local

    monkeypatch.delenv("SCRYPT_QUANT", raising=False)
    monkeypatch.setattr(local, "SCRYPT_HOME", tmp_path)
    monkeypatch.setattr(local, "find_binary", lambda: None)

    def problems_with_ram(gb):
        # Fake the detected RAM, then rerun preflight.
        monkeypatch.setattr(local, "system_ram_gb", lambda: gb)
        return local.preflight()

    mid_tier = problems_with_ram(33.0)
    assert any("Q3_K_S" in p for p in mid_tier)

    low_tier = problems_with_ram(16.0)
    assert any("API mode" in p for p in low_tier)
    assert all("Q3_K_S" not in p for p in low_tier)


def test_installed_model_prefers_heaviest(monkeypatch, tmp_path):
    """With several quants installed, installed_model() picks the heaviest.

    Here that means Q5_K_M is chosen over Q3_K_S.
    """
    import scrypt.inference.local as local

    monkeypatch.setattr(local, "SCRYPT_HOME", tmp_path)
    model_dir = tmp_path / "models"
    model_dir.mkdir()
    for quant in ("Q3_K_S", "Q5_K_M"):
        (model_dir / local.model_file(quant)).touch()

    chosen = local.installed_model()
    assert chosen.name == local.model_file("Q5_K_M")