File size: 4,026 Bytes
414dc55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""In-process llama.cpp backend via ``llama-cpp-python``.

The shipped runtime: fully in-process on the CPU - no server, no GPU, no network. The
model is loaded once (lazily, on first use) and reused for every call.
"""

from __future__ import annotations

import json
from collections.abc import Iterator
from pathlib import Path

from ..config import Settings
from .backend import GenParams, LLMError


class LlamaCppBackend:
    """Wraps a single ``llama_cpp.Llama`` instance, loaded lazily on first use.

    ``generate`` and ``stream`` both send the prompt as a single user chat message
    and build their request from one shared helper (``_completion_kwargs``), so the
    blocking and streaming paths always use the same sampling settings.
    """

    def __init__(self, model_path: Path, *, n_ctx: int, n_threads: int) -> None:
        """Store load-time settings; the model is not loaded until first use.

        Args:
            model_path: Weights file handed to ``llama_cpp.Llama`` as ``model_path``.
            n_ctx: Value passed to ``llama_cpp.Llama`` as ``n_ctx``.
            n_threads: Used for both ``n_threads`` and ``n_threads_batch``.
        """
        self.model_path = model_path
        self._n_ctx = n_ctx
        self._n_threads = n_threads
        self._llama: object | None = None

    @classmethod
    def from_settings(cls, settings: Settings) -> LlamaCppBackend:
        """Build a backend from application settings.

        Checks up front that ``llama_cpp`` is importable and that the weights file
        exists, so misconfiguration surfaces here rather than on the first call.

        Raises:
            LLMError: if ``llama-cpp-python`` is not installed or the model file
                at ``settings.llm_model_path`` does not exist.
        """
        try:
            import llama_cpp  # noqa: F401
        except ImportError as exc:  # pragma: no cover
            raise LLMError(
                "llama-cpp-python is not installed. Install it with "
                "'pip install -r requirements.txt'."
            ) from exc

        model_path = settings.llm_model_path
        if not model_path.exists():
            raise LLMError(
                f"model weights not found at {model_path}. Run scripts/fetch_models.py."
            )
        return cls(model_path, n_ctx=settings.llm_n_ctx, n_threads=settings.llm_n_threads)

    def _ensure(self) -> object:
        """Return the ``Llama`` instance, loading the model on the first call.

        NOTE(review): there is no lock here, so two threads making their first call
        concurrently could each load the model. Confirm callers are single-threaded.
        """
        if self._llama is None:
            from llama_cpp import Llama

            self._llama = Llama(
                model_path=str(self.model_path),
                n_ctx=self._n_ctx,
                n_threads=self._n_threads,
                n_threads_batch=self._n_threads,
                n_gpu_layers=0,  # CPU-only runtime: keep every layer off the GPU.
                # RAM is plentiful on the Space (model is ~1GB of 16GB); lock the weights
                # resident so they are never paged out mid-game. Ignored if unsupported.
                use_mlock=True,
                verbose=False,
            )
        return self._llama

    def _grammar(self, params: GenParams) -> object | None:
        """Compile the constraint grammar for ``params``, or return ``None``.

        An explicit GBNF ``params.grammar`` takes precedence over
        ``params.json_schema``. ``llama_cpp`` is only imported when a constraint
        is actually requested.
        """
        if not (params.grammar or params.json_schema):
            return None

        from llama_cpp import LlamaGrammar

        if params.grammar:
            return LlamaGrammar.from_string(params.grammar, verbose=False)
        return LlamaGrammar.from_json_schema(json.dumps(params.json_schema), verbose=False)

    def _messages(self, prompt: str) -> list[dict[str, str]]:
        """Wrap ``prompt`` as a single-turn chat with one user message."""
        return [{"role": "user", "content": prompt}]

    def _completion_kwargs(self, prompt: str, params: GenParams) -> dict[str, object]:
        """Build the ``create_chat_completion`` arguments shared by both call paths."""
        return {
            "messages": self._messages(prompt),
            "max_tokens": params.max_tokens,
            "temperature": params.temperature,
            "top_p": params.top_p,
            # An absent (None) or empty ``stop`` is sent as None, never as [].
            "stop": list(params.stop or ()) or None,
            "grammar": self._grammar(params),
            "seed": params.seed,
            "repeat_penalty": params.repeat_penalty,
            "frequency_penalty": params.frequency_penalty,
            "presence_penalty": params.presence_penalty,
        }

    def generate(self, prompt: str, params: GenParams) -> str:
        """Run a blocking completion and return the assistant text.

        Returns ``""`` when the model's message content is ``None``.
        """
        llama = self._ensure()
        result = llama.create_chat_completion(  # type: ignore[attr-defined]
            **self._completion_kwargs(prompt, params)
        )
        return result["choices"][0]["message"]["content"] or ""

    def stream(self, prompt: str, params: GenParams) -> Iterator[str]:
        """Yield the assistant text incrementally as the model produces it.

        Chunks with no ``content`` in their delta (for example role-only or final
        chunks) are skipped, so only non-empty text fragments are yielded.
        """
        llama = self._ensure()
        chunks = llama.create_chat_completion(  # type: ignore[attr-defined]
            **self._completion_kwargs(prompt, params),
            stream=True,
        )
        for chunk in chunks:
            delta = chunk["choices"][0].get("delta", {})
            text = delta.get("content")
            if text:
                yield text