hymenjj commited on
Commit
6f8aedf
·
verified ·
1 Parent(s): 7d58434

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +6 -0
  2. llama_cpp/__init__.py +4 -0
  3. llama_cpp/__pycache__/__init__.cpython-311.pyc +0 -0
  4. llama_cpp/__pycache__/_ctypes_extensions.cpython-311.pyc +0 -0
  5. llama_cpp/__pycache__/_ggml.cpython-311.pyc +0 -0
  6. llama_cpp/__pycache__/_internals.cpython-311.pyc +0 -0
  7. llama_cpp/__pycache__/_logger.cpython-311.pyc +0 -0
  8. llama_cpp/__pycache__/_utils.cpython-311.pyc +0 -0
  9. llama_cpp/__pycache__/llama.cpython-311.pyc +0 -0
  10. llama_cpp/__pycache__/llama_cache.cpython-311.pyc +0 -0
  11. llama_cpp/__pycache__/llama_chat_format.cpython-311.pyc +3 -0
  12. llama_cpp/__pycache__/llama_cpp.cpython-311.pyc +3 -0
  13. llama_cpp/__pycache__/llama_grammar.cpython-311.pyc +0 -0
  14. llama_cpp/__pycache__/llama_speculative.cpython-311.pyc +0 -0
  15. llama_cpp/__pycache__/llama_tokenizer.cpython-311.pyc +0 -0
  16. llama_cpp/__pycache__/llama_types.cpython-311.pyc +0 -0
  17. llama_cpp/__pycache__/llava_cpp.cpython-311.pyc +0 -0
  18. llama_cpp/__pycache__/mtmd_cpp.cpython-311.pyc +0 -0
  19. llama_cpp/_ctypes_extensions.py +131 -0
  20. llama_cpp/_ggml.py +12 -0
  21. llama_cpp/_internals.py +856 -0
  22. llama_cpp/_logger.py +47 -0
  23. llama_cpp/_utils.py +78 -0
  24. llama_cpp/lib/libggml-base.so +3 -0
  25. llama_cpp/lib/libggml-cpu.so +3 -0
  26. llama_cpp/lib/libggml.so +0 -0
  27. llama_cpp/lib/libllama.so +3 -0
  28. llama_cpp/lib/libmtmd.so +3 -0
  29. llama_cpp/llama.py +2422 -0
  30. llama_cpp/llama_cache.py +155 -0
  31. llama_cpp/llama_chat_format.py +0 -0
  32. llama_cpp/llama_cpp.py +0 -0
  33. llama_cpp/llama_grammar.py +953 -0
  34. llama_cpp/llama_speculative.py +64 -0
  35. llama_cpp/llama_tokenizer.py +120 -0
  36. llama_cpp/llama_types.py +316 -0
  37. llama_cpp/llava_cpp.py +158 -0
  38. llama_cpp/mtmd_cpp.py +280 -0
  39. llama_cpp/py.typed +0 -0
  40. llama_cpp/server/__init__.py +0 -0
  41. llama_cpp/server/__main__.py +100 -0
  42. llama_cpp/server/__pycache__/__init__.cpython-311.pyc +0 -0
  43. llama_cpp/server/__pycache__/__main__.cpython-311.pyc +0 -0
  44. llama_cpp/server/__pycache__/app.cpython-311.pyc +0 -0
  45. llama_cpp/server/__pycache__/cli.cpython-311.pyc +0 -0
  46. llama_cpp/server/__pycache__/errors.cpython-311.pyc +0 -0
  47. llama_cpp/server/__pycache__/model.cpython-311.pyc +0 -0
  48. llama_cpp/server/__pycache__/settings.cpython-311.pyc +0 -0
  49. llama_cpp/server/__pycache__/types.cpython-311.pyc +0 -0
  50. llama_cpp/server/app.py +597 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ llama_cpp/__pycache__/llama_chat_format.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
37
+ llama_cpp/__pycache__/llama_cpp.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
38
+ llama_cpp/lib/libggml-base.so filter=lfs diff=lfs merge=lfs -text
39
+ llama_cpp/lib/libggml-cpu.so filter=lfs diff=lfs merge=lfs -text
40
+ llama_cpp/lib/libllama.so filter=lfs diff=lfs merge=lfs -text
41
+ llama_cpp/lib/libmtmd.so filter=lfs diff=lfs merge=lfs -text
llama_cpp/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .llama_cpp import *
2
+ from .llama import *
3
+
4
+ __version__ = "0.3.16"
llama_cpp/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (259 Bytes). View file
 
llama_cpp/__pycache__/_ctypes_extensions.cpython-311.pyc ADDED
Binary file (6.13 kB). View file
 
llama_cpp/__pycache__/_ggml.cpython-311.pyc ADDED
Binary file (800 Bytes). View file
 
llama_cpp/__pycache__/_internals.cpython-311.pyc ADDED
Binary file (51.3 kB). View file
 
llama_cpp/__pycache__/_logger.cpython-311.pyc ADDED
Binary file (1.74 kB). View file
 
llama_cpp/__pycache__/_utils.cpython-311.pyc ADDED
Binary file (4.31 kB). View file
 
llama_cpp/__pycache__/llama.cpython-311.pyc ADDED
Binary file (93.7 kB). View file
 
llama_cpp/__pycache__/llama_cache.cpython-311.pyc ADDED
Binary file (9.74 kB). View file
 
llama_cpp/__pycache__/llama_chat_format.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1112a8090af3509b71ef46f87b57b23dbf0d410dda331b3612b506983a312b9
3
+ size 138859
llama_cpp/__pycache__/llama_cpp.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff68308c4b1951f3e8bdc79275ad24faa6cd18669939a1233be626713764d8a1
3
+ size 100194
llama_cpp/__pycache__/llama_grammar.cpython-311.pyc ADDED
Binary file (42.8 kB). View file
 
llama_cpp/__pycache__/llama_speculative.cpython-311.pyc ADDED
Binary file (3.39 kB). View file
 
llama_cpp/__pycache__/llama_tokenizer.cpython-311.pyc ADDED
Binary file (6.47 kB). View file
 
llama_cpp/__pycache__/llama_types.cpython-311.pyc ADDED
Binary file (16.8 kB). View file
 
llama_cpp/__pycache__/llava_cpp.cpython-311.pyc ADDED
Binary file (4.49 kB). View file
 
llama_cpp/__pycache__/mtmd_cpp.cpython-311.pyc ADDED
Binary file (8.75 kB). View file
 
llama_cpp/_ctypes_extensions.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+ import os
5
+ import ctypes
6
+ import functools
7
+ import pathlib
8
+
9
+ from typing import (
10
+ Any,
11
+ Callable,
12
+ List,
13
+ Union,
14
+ Optional,
15
+ TYPE_CHECKING,
16
+ TypeVar,
17
+ Generic,
18
+ )
19
+ from typing_extensions import TypeAlias
20
+
21
+
22
+ # Load the library
23
+ def load_shared_library(lib_base_name: str, base_path: pathlib.Path):
24
+ """Platform independent shared library loader"""
25
+ # Searching for the library in the current directory under the name "libllama" (default name
26
+ # for llamacpp) and "llama" (default name for this repo)
27
+ lib_paths: List[pathlib.Path] = []
28
+ # Determine the file extension based on the platform
29
+ if sys.platform.startswith("linux") or sys.platform.startswith("freebsd"):
30
+ lib_paths += [
31
+ base_path / f"lib{lib_base_name}.so",
32
+ ]
33
+ elif sys.platform == "darwin":
34
+ lib_paths += [
35
+ base_path / f"lib{lib_base_name}.so",
36
+ base_path / f"lib{lib_base_name}.dylib",
37
+ ]
38
+ elif sys.platform == "win32":
39
+ lib_paths += [
40
+ base_path / f"{lib_base_name}.dll",
41
+ base_path / f"lib{lib_base_name}.dll",
42
+ ]
43
+ else:
44
+ raise RuntimeError("Unsupported platform")
45
+
46
+ cdll_args = dict() # type: ignore
47
+
48
+ # Add the library directory to the DLL search path on Windows (if needed)
49
+ if sys.platform == "win32":
50
+ os.add_dll_directory(str(base_path))
51
+ os.environ["PATH"] = str(base_path) + os.pathsep + os.environ["PATH"]
52
+
53
+ if sys.platform == "win32" and sys.version_info >= (3, 8):
54
+ os.add_dll_directory(str(base_path))
55
+ if "CUDA_PATH" in os.environ:
56
+ os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "bin"))
57
+ os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "lib"))
58
+ if "HIP_PATH" in os.environ:
59
+ os.add_dll_directory(os.path.join(os.environ["HIP_PATH"], "bin"))
60
+ os.add_dll_directory(os.path.join(os.environ["HIP_PATH"], "lib"))
61
+ cdll_args["winmode"] = ctypes.RTLD_GLOBAL
62
+
63
+ # Try to load the shared library, handling potential errors
64
+ for lib_path in lib_paths:
65
+ if lib_path.exists():
66
+ try:
67
+ return ctypes.CDLL(str(lib_path), **cdll_args) # type: ignore
68
+ except Exception as e:
69
+ raise RuntimeError(f"Failed to load shared library '{lib_path}': {e}")
70
+
71
+ raise FileNotFoundError(
72
+ f"Shared library with base name '{lib_base_name}' not found"
73
+ )
74
+
75
+
76
+ # ctypes sane type hint helpers
77
+ #
78
+ # - Generic Pointer and Array types
79
+ # - PointerOrRef type with a type hinted byref function
80
+ #
81
+ # NOTE: Only use these for static type checking not for runtime checks
82
+ # no good will come of that
83
+
84
+ if TYPE_CHECKING:
85
+ CtypesCData = TypeVar("CtypesCData", bound=ctypes._CData) # type: ignore
86
+
87
+ CtypesArray: TypeAlias = ctypes.Array[CtypesCData] # type: ignore
88
+
89
+ CtypesPointer: TypeAlias = ctypes._Pointer[CtypesCData] # type: ignore
90
+
91
+ CtypesVoidPointer: TypeAlias = ctypes.c_void_p
92
+
93
+ class CtypesRef(Generic[CtypesCData]):
94
+ pass
95
+
96
+ CtypesPointerOrRef: TypeAlias = Union[
97
+ CtypesPointer[CtypesCData], CtypesRef[CtypesCData]
98
+ ]
99
+
100
+ CtypesFuncPointer: TypeAlias = ctypes._FuncPointer # type: ignore
101
+
102
+ F = TypeVar("F", bound=Callable[..., Any])
103
+
104
+
105
+ def ctypes_function_for_shared_library(lib: ctypes.CDLL):
106
+ """Decorator for defining ctypes functions with type hints"""
107
+
108
+ def ctypes_function(
109
+ name: str, argtypes: List[Any], restype: Any, enabled: bool = True
110
+ ):
111
+ def decorator(f: F) -> F:
112
+ if enabled:
113
+ func = getattr(lib, name)
114
+ func.argtypes = argtypes
115
+ func.restype = restype
116
+ functools.wraps(f)(func)
117
+ return func
118
+ else:
119
+ return f
120
+
121
+ return decorator
122
+
123
+ return ctypes_function
124
+
125
+
126
+ def _byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCData]:
127
+ """Type-annotated version of ctypes.byref"""
128
+ ...
129
+
130
+
131
+ byref = _byref if TYPE_CHECKING else ctypes.byref
llama_cpp/_ggml.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Internal module use at your own risk
2
+
3
+ This module provides a minimal interface for working with ggml tensors from llama-cpp-python
4
+ """
5
+ import os
6
+ import pathlib
7
+
8
+ import llama_cpp._ctypes_extensions as ctypes_ext
9
+
10
+ libggml_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib"
11
+ libggml = ctypes_ext.load_shared_library("ggml", libggml_base_path)
12
+
llama_cpp/_internals.py ADDED
@@ -0,0 +1,856 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import ctypes
5
+
6
+ from typing import (
7
+ Dict,
8
+ List,
9
+ Tuple,
10
+ Optional,
11
+ Sequence,
12
+ Callable,
13
+ Union,
14
+ )
15
+ from dataclasses import dataclass, field
16
+ from contextlib import ExitStack
17
+
18
+ import numpy as np
19
+ import numpy.typing as npt
20
+
21
+ from .llama_types import *
22
+ from .llama_grammar import LlamaGrammar
23
+ from ._utils import suppress_stdout_stderr
24
+
25
+ import llama_cpp.llama_cpp as llama_cpp
26
+
27
+
28
+ # Python wrappers over llama.h structs
29
+
30
+
31
+ class LlamaModel:
32
+ """Intermediate Python wrapper for a llama.cpp llama_model.
33
+ NOTE: For stability it's recommended you use the Llama class instead."""
34
+
35
+ def __init__(
36
+ self,
37
+ *,
38
+ path_model: str,
39
+ params: llama_cpp.llama_model_params,
40
+ verbose: bool = True,
41
+ ):
42
+ self.path_model = path_model
43
+ self.params = params
44
+ self.verbose = verbose
45
+ self._exit_stack = ExitStack()
46
+
47
+ model = None
48
+
49
+ if not os.path.exists(path_model):
50
+ raise ValueError(f"Model path does not exist: {path_model}")
51
+
52
+ with suppress_stdout_stderr(disable=verbose):
53
+ model = llama_cpp.llama_model_load_from_file(
54
+ self.path_model.encode("utf-8"), self.params
55
+ )
56
+
57
+ if model is None:
58
+ raise ValueError(f"Failed to load model from file: {path_model}")
59
+
60
+ vocab = llama_cpp.llama_model_get_vocab(model)
61
+
62
+ if vocab is None:
63
+ raise ValueError(f"Failed to get vocab from model: {path_model}")
64
+
65
+ self.model = model
66
+ self.vocab = vocab
67
+ self.sampler = None # LlamaModel doesn't use samplers, but some cleanup code expects this attribute
68
+
69
+ def free_model():
70
+ if self.model is None:
71
+ return
72
+ llama_cpp.llama_model_free(self.model)
73
+ self.model = None
74
+
75
+ self._exit_stack.callback(free_model)
76
+
77
+ def close(self):
78
+ if self.sampler is not None:
79
+ # NOTE: Must remove custom samplers before free or llama.cpp will try to free them
80
+ for i, _ in reversed(self.custom_samplers):
81
+ llama_cpp.llama_sampler_chain_remove(self.sampler, i)
82
+ self.custom_samplers.clear()
83
+ self._exit_stack.close()
84
+
85
+ def __del__(self):
86
+ self.close()
87
+
88
+ def vocab_type(self) -> int:
89
+ return llama_cpp.llama_vocab_type(self.vocab)
90
+
91
+ def n_vocab(self) -> int:
92
+ return llama_cpp.llama_vocab_n_tokens(self.vocab)
93
+
94
+ def n_ctx_train(self) -> int:
95
+ return llama_cpp.llama_model_n_ctx_train(self.model)
96
+
97
+ def n_embd(self) -> int:
98
+ return llama_cpp.llama_model_n_embd(self.model)
99
+
100
+ def rope_freq_scale_train(self) -> float:
101
+ return llama_cpp.llama_model_rope_freq_scale_train(self.model)
102
+
103
+ def desc(self) -> str:
104
+ buf = ctypes.create_string_buffer(1024)
105
+ llama_cpp.llama_model_desc(self.model, buf, 1024)
106
+ return buf.value.decode("utf-8")
107
+
108
+ def size(self) -> int:
109
+ return llama_cpp.llama_model_size(self.model)
110
+
111
+ def n_params(self) -> int:
112
+ return llama_cpp.llama_model_n_params(self.model)
113
+
114
+ def get_tensor(self, name: str) -> ctypes.c_void_p:
115
+ raise NotImplementedError("get_tensor is not implemented in llama.cpp")
116
+
117
+ # Vocab
118
+
119
+ def token_get_text(self, token: int) -> str:
120
+ return llama_cpp.llama_vocab_get_text(self.vocab, token).decode("utf-8")
121
+
122
+ def token_get_score(self, token: int) -> float:
123
+ return llama_cpp.llama_vocab_get_score(self.vocab, token)
124
+
125
+ def token_get_attr(self, token: int) -> int:
126
+ return llama_cpp.llama_vocab_get_attr(self.vocab, token)
127
+
128
+ # Special tokens
129
+
130
+ def token_bos(self) -> int:
131
+ return llama_cpp.llama_vocab_bos(self.vocab)
132
+
133
+ def token_eos(self) -> int:
134
+ return llama_cpp.llama_vocab_eos(self.vocab)
135
+
136
+ def token_cls(self) -> int:
137
+ return llama_cpp.llama_vocab_cls(self.vocab)
138
+
139
+ def token_sep(self) -> int:
140
+ return llama_cpp.llama_vocab_sep(self.vocab)
141
+
142
+ def token_nl(self) -> int:
143
+ return llama_cpp.llama_vocab_nl(self.vocab)
144
+
145
+ def token_prefix(self) -> int:
146
+ return llama_cpp.llama_vocab_fim_pre(self.vocab)
147
+
148
+ def token_middle(self) -> int:
149
+ return llama_cpp.llama_vocab_fim_mid(self.vocab)
150
+
151
+ def token_suffix(self) -> int:
152
+ return llama_cpp.llama_vocab_fim_suf(self.vocab)
153
+
154
+ def token_eot(self) -> int:
155
+ return llama_cpp.llama_vocab_eot(self.vocab)
156
+
157
+ def add_bos_token(self) -> bool:
158
+ return llama_cpp.llama_vocab_get_add_bos(self.vocab)
159
+
160
+ def add_eos_token(self) -> bool:
161
+ return llama_cpp.llama_vocab_get_add_eos(self.vocab)
162
+
163
+ # Tokenization
164
+
165
+ def tokenize(self, text: bytes, add_bos: bool, special: bool):
166
+ n_ctx = self.n_ctx_train()
167
+ tokens = (llama_cpp.llama_token * n_ctx)()
168
+ n_tokens = llama_cpp.llama_tokenize(
169
+ self.vocab, text, len(text), tokens, n_ctx, add_bos, special
170
+ )
171
+ if n_tokens < 0:
172
+ n_tokens = abs(n_tokens)
173
+ tokens = (llama_cpp.llama_token * n_tokens)()
174
+ n_tokens = llama_cpp.llama_tokenize(
175
+ self.vocab, text, len(text), tokens, n_tokens, add_bos, special
176
+ )
177
+ if n_tokens < 0:
178
+ raise RuntimeError(
179
+ f'Failed to tokenize: text="{text}" n_tokens={n_tokens}'
180
+ )
181
+ return list(tokens[:n_tokens])
182
+
183
+ def token_to_piece(self, token: int, special: bool = False) -> bytes:
184
+ buf = ctypes.create_string_buffer(32)
185
+ llama_cpp.llama_token_to_piece(self.vocab, token, buf, 32, 0, special)
186
+ return bytes(buf)
187
+
188
+ def detokenize(self, tokens: List[int], special: bool = False) -> bytes:
189
+ output = b""
190
+ size = 32
191
+ buffer = (ctypes.c_char * size)()
192
+ for token in tokens:
193
+ n = llama_cpp.llama_token_to_piece(
194
+ self.vocab, llama_cpp.llama_token(token), buffer, size, 0, special
195
+ )
196
+ assert n <= size
197
+ output += bytes(buffer[:n])
198
+ # NOTE: Llama1 models automatically added a space at the start of the prompt
199
+ # this line removes a leading space if the first token is a beginning of sentence token
200
+ return (
201
+ output[1:]
202
+ if len(tokens) > 0 and tokens[0] == self.token_bos() and output[0:1] == b" "
203
+ else output
204
+ )
205
+
206
+ # Extra
207
+ def metadata(self) -> Dict[str, str]:
208
+ metadata: Dict[str, str] = {}
209
+ buffer_size = 1024
210
+ buffer = ctypes.create_string_buffer(buffer_size)
211
+ # zero the buffer
212
+ buffer.value = b"\0" * buffer_size
213
+ # iterate over model keys
214
+ for i in range(llama_cpp.llama_model_meta_count(self.model)):
215
+ nbytes = llama_cpp.llama_model_meta_key_by_index(
216
+ self.model, i, buffer, buffer_size
217
+ )
218
+ if nbytes > buffer_size:
219
+ buffer_size = nbytes + 1
220
+ buffer = ctypes.create_string_buffer(buffer_size)
221
+ nbytes = llama_cpp.llama_model_meta_key_by_index(
222
+ self.model, i, buffer, buffer_size
223
+ )
224
+ key = buffer.value.decode("utf-8")
225
+ nbytes = llama_cpp.llama_model_meta_val_str_by_index(
226
+ self.model, i, buffer, buffer_size
227
+ )
228
+ if nbytes > buffer_size:
229
+ buffer_size = nbytes + 1
230
+ buffer = ctypes.create_string_buffer(buffer_size)
231
+ nbytes = llama_cpp.llama_model_meta_val_str_by_index(
232
+ self.model, i, buffer, buffer_size
233
+ )
234
+ value = buffer.value.decode("utf-8")
235
+ metadata[key] = value
236
+ return metadata
237
+
238
+ @staticmethod
239
+ def default_params():
240
+ """Get the default llama_model_params."""
241
+ return llama_cpp.llama_model_default_params()
242
+
243
+
244
+ class LlamaContext:
245
+ """Intermediate Python wrapper for a llama.cpp llama_context.
246
+ NOTE: For stability it's recommended you use the Llama class instead."""
247
+
248
+ def __init__(
249
+ self,
250
+ *,
251
+ model: LlamaModel,
252
+ params: llama_cpp.llama_context_params,
253
+ verbose: bool = True,
254
+ ):
255
+ self.model = model
256
+ self.params = params
257
+ self.verbose = verbose
258
+ self._exit_stack = ExitStack()
259
+
260
+ ctx = llama_cpp.llama_init_from_model(self.model.model, self.params)
261
+
262
+ if ctx is None:
263
+ raise ValueError("Failed to create llama_context")
264
+
265
+ self.ctx = ctx
266
+ self.memory = llama_cpp.llama_get_memory(self.ctx)
267
+ self.sampler = None # LlamaContext doesn't manage samplers directly, but some cleanup code expects this attribute
268
+
269
+ def free_ctx():
270
+ if self.ctx is None:
271
+ return
272
+ llama_cpp.llama_free(self.ctx)
273
+ self.ctx = None
274
+
275
+ self._exit_stack.callback(free_ctx)
276
+
277
+ def close(self):
278
+ self._exit_stack.close()
279
+
280
+ def __del__(self):
281
+ self.close()
282
+
283
+ def n_ctx(self) -> int:
284
+ return llama_cpp.llama_n_ctx(self.ctx)
285
+
286
+ def pooling_type(self) -> int:
287
+ return llama_cpp.llama_pooling_type(self.ctx)
288
+
289
+ def kv_cache_clear(self):
290
+ assert self.memory is not None, "Memory is not initialized"
291
+ llama_cpp.llama_memory_clear(self.memory, True)
292
+
293
+ def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int):
294
+ assert self.memory is not None, "Memory is not initialized"
295
+ seq_id = seq_id if seq_id >= 0 else 0
296
+ llama_cpp.llama_memory_seq_rm(self.memory, seq_id, p0, p1)
297
+
298
+ def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int):
299
+ assert self.memory is not None, "Memory is not initialized"
300
+ llama_cpp.llama_memory_seq_cp(self.memory, seq_id_src, seq_id_dst, p0, p1)
301
+
302
+ def kv_cache_seq_keep(self, seq_id: int):
303
+ assert self.memory is not None, "Memory is not initialized"
304
+ llama_cpp.llama_memory_seq_keep(self.memory, seq_id)
305
+
306
+ def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int):
307
+ assert self.memory is not None, "Memory is not initialized"
308
+ llama_cpp.llama_memory_seq_add(self.memory, seq_id, p0, p1, shift)
309
+
310
+ def get_state_size(self) -> int:
311
+ return llama_cpp.llama_state_get_size(self.ctx)
312
+
313
+ # TODO: copy_state_data
314
+
315
+ # TODO: set_state_data
316
+
317
+ # TODO: llama_load_session_file
318
+
319
+ # TODO: llama_save_session_file
320
+
321
+ def decode(self, batch: LlamaBatch):
322
+ return_code = llama_cpp.llama_decode(
323
+ self.ctx,
324
+ batch.batch,
325
+ )
326
+ if return_code != 0:
327
+ raise RuntimeError(f"llama_decode returned {return_code}")
328
+
329
+ def encode(self, batch: LlamaBatch):
330
+ return_code = llama_cpp.llama_encode(
331
+ self.ctx,
332
+ batch.batch,
333
+ )
334
+ if return_code != 0:
335
+ raise RuntimeError(f"llama_encode returned {return_code}")
336
+
337
+ def set_n_threads(self, n_threads: int, n_threads_batch: int):
338
+ llama_cpp.llama_set_n_threads(self.ctx, n_threads, n_threads_batch)
339
+
340
+ def get_logits(self):
341
+ return llama_cpp.llama_get_logits(self.ctx)
342
+
343
+ def get_logits_ith(self, i: int):
344
+ return llama_cpp.llama_get_logits_ith(self.ctx, i)
345
+
346
+ def get_embeddings(self):
347
+ return llama_cpp.llama_get_embeddings(self.ctx)
348
+
349
+ def get_embeddings_ith(self, i: int):
350
+ return llama_cpp.llama_get_embeddings_ith(self.ctx, i)
351
+
352
+ def get_embeddings_seq(self, seq_id: int):
353
+ return llama_cpp.llama_get_embeddings_seq(self.ctx, seq_id)
354
+
355
+ # Sampling functions - deprecated, use LlamaSampler instead
356
+
357
+ def set_rng_seed(self, seed: int):
358
+ raise NotImplementedError("set_rng_seed is deprecated, use LlamaSampler instead")
359
+
360
+ def sample_repetition_penalties(
361
+ self,
362
+ candidates: "_LlamaTokenDataArray",
363
+ last_tokens_data: "llama_cpp.Array[llama_cpp.llama_token]",
364
+ penalty_last_n: int,
365
+ penalty_repeat: float,
366
+ penalty_freq: float,
367
+ penalty_present: float,
368
+ ):
369
+ raise NotImplementedError("sample_repetition_penalties is deprecated, use LlamaSampler instead")
370
+
371
+ def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
372
+ raise NotImplementedError("sample_softmax is deprecated, use LlamaSampler instead")
373
+
374
+ def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
375
+ raise NotImplementedError("sample_top_k is deprecated, use LlamaSampler instead")
376
+
377
+ def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
378
+ raise NotImplementedError("sample_top_p is deprecated, use LlamaSampler instead")
379
+
380
+ def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
381
+ raise NotImplementedError("sample_min_p is deprecated, use LlamaSampler instead")
382
+
383
+ def sample_typical(
384
+ self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int
385
+ ):
386
+ raise NotImplementedError("sample_typical is deprecated, use LlamaSampler instead")
387
+
388
+ def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
389
+ raise NotImplementedError("sample_temp is deprecated, use LlamaSampler instead")
390
+
391
+ def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
392
+ raise NotImplementedError("sample_grammar is deprecated, use LlamaSampler instead")
393
+
394
+ def sample_token_mirostat(
395
+ self,
396
+ candidates: "_LlamaTokenDataArray",
397
+ tau: float,
398
+ eta: float,
399
+ m: int,
400
+ mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
401
+ ) -> int:
402
+ raise NotImplementedError("sample_token_mirostat is deprecated, use LlamaSampler instead")
403
+
404
+ def sample_token_mirostat_v2(
405
+ self,
406
+ candidates: "_LlamaTokenDataArray",
407
+ tau: float,
408
+ eta: float,
409
+ mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
410
+ ) -> int:
411
+ raise NotImplementedError("sample_token_mirostat_v2 is deprecated, use LlamaSampler instead")
412
+
413
+ def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int:
414
+ raise NotImplementedError("sample_token_greedy is deprecated, use LlamaSampler instead")
415
+
416
+ def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
417
+ raise NotImplementedError("sample_token is deprecated, use LlamaSampler instead")
418
+
419
+ # Grammar
420
+ def grammar_accept_token(self, grammar: LlamaGrammar, token: int):
421
+ raise NotImplementedError("grammar_accept_token is deprecated, use LlamaSampler instead")
422
+
423
+ def reset_timings(self):
424
+ llama_cpp.llama_perf_context_reset(self.ctx)
425
+
426
+ def print_timings(self):
427
+ llama_cpp.llama_perf_context_print(self.ctx)
428
+
429
+ # Utility functions
430
+ @staticmethod
431
+ def default_params():
432
+ """Get the default llama_context_params."""
433
+ return llama_cpp.llama_context_default_params()
434
+
435
+
436
+ class LlamaBatch:
437
+ def __init__(
438
+ self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True
439
+ ):
440
+ self._n_tokens = n_tokens
441
+ self.embd = embd
442
+ self.n_seq_max = n_seq_max
443
+ self.verbose = verbose
444
+ self._exit_stack = ExitStack()
445
+
446
+ batch = llama_cpp.llama_batch_init(self._n_tokens, self.embd, self.n_seq_max)
447
+
448
+ if batch is None:
449
+ raise ValueError("Failed to create llama_batch")
450
+
451
+ self.batch = batch
452
+ self.sampler = None # LlamaBatch doesn't use samplers, but some cleanup code expects this attribute
453
+
454
+ def free_batch():
455
+ if self.batch is None:
456
+ return
457
+ llama_cpp.llama_batch_free(self.batch)
458
+ self.batch = None
459
+
460
+ self._exit_stack.callback(free_batch)
461
+
462
+ def close(self):
463
+ self._exit_stack.close()
464
+
465
+ def __del__(self):
466
+ self.close()
467
+
468
+ def n_tokens(self) -> int:
469
+ return self.batch.n_tokens
470
+
471
+ def reset(self):
472
+ self.batch.n_tokens = 0
473
+
474
+ def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
475
+ n_tokens = len(batch)
476
+ self.batch.n_tokens = n_tokens
477
+ for i in range(n_tokens):
478
+ self.batch.token[i] = batch[i]
479
+ self.batch.pos[i] = n_past + i
480
+ self.batch.seq_id[i][0] = 0
481
+ self.batch.n_seq_id[i] = 1
482
+ self.batch.logits[i] = logits_all
483
+ self.batch.logits[n_tokens - 1] = True
484
+
485
+ def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
486
+ n_tokens = len(batch)
487
+ n_tokens0 = self.batch.n_tokens
488
+ self.batch.n_tokens += n_tokens
489
+ for i in range(n_tokens):
490
+ j = n_tokens0 + i
491
+ self.batch.token[j] = batch[i]
492
+ self.batch.pos[j] = i
493
+ self.batch.seq_id[j][0] = seq_id
494
+ self.batch.n_seq_id[j] = 1
495
+ self.batch.logits[j] = logits_all
496
+ self.batch.logits[n_tokens - 1] = True
497
+
498
+
499
+ class LlamaTokenDataArray:
500
+ def __init__(self, *, n_vocab: int):
501
+ self.n_vocab = n_vocab
502
+ self.candidates_data = np.recarray(
503
+ (self.n_vocab,),
504
+ dtype=np.dtype(
505
+ [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
506
+ ),
507
+ )
508
+ self.candidates = llama_cpp.llama_token_data_array(
509
+ data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
510
+ size=self.n_vocab,
511
+ sorted=False,
512
+ )
513
+ self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc) # type: ignore
514
+ self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single)
515
+ self.sampler = None # LlamaTokenDataArray doesn't use samplers, but some cleanup code expects this attribute
516
+
517
+ def copy_logits(self, logits: npt.NDArray[np.single]):
518
+ self.candidates_data.id[:] = self.default_candidates_data_id
519
+ self.candidates_data.logit[:] = logits
520
+ self.candidates_data.p[:] = self.default_candidates_data_p
521
+ self.candidates.sorted = False
522
+ self.candidates.size = self.n_vocab
523
+
524
+
525
+ # Embedding functions
526
+
527
+
528
+ def normalize_embedding(embedding):
529
+ norm = float(np.linalg.norm(embedding))
530
+ if norm == 0.0:
531
+ return embedding
532
+ return [v / norm for v in embedding]
533
+
534
+
535
+ # Python wrappers over common/sampling structs
536
+
537
+
538
+ @dataclass
539
+ class LlamaSamplingParams:
540
+ n_prev: int = 64
541
+ n_probs: int = 0
542
+ top_k: int = 40
543
+ top_p: float = 0.95
544
+ min_p: float = 0.05
545
+ tfs_z: float = 1.00
546
+ typical_p: float = 1.00
547
+ temp: float = 0.80
548
+ penalty_last_n: int = 64
549
+ penalty_repeat: float = 1.0
550
+ penalty_freq: float = 0.00
551
+ penalty_present: float = 0.00
552
+ mirostat: int = 0
553
+ mirostat_tau: float = 5.00
554
+ mirostat_eta: float = 0.10
555
+ penalize_nl: bool = True
556
+
557
+ grammar: str = ""
558
+
559
+ cfg_negative_prompt: str = ""
560
+ cfg_scale: float = 1.00
561
+
562
+ logit_bias: dict[int, float] = field(default_factory=dict)
563
+
564
+
565
+ @dataclass
566
+ class LlamaSamplingContext:
567
+ params: LlamaSamplingParams = field(default_factory=LlamaSamplingParams)
568
+ mirostat_mu: ctypes.c_float = field(default_factory=ctypes.c_float)
569
+ grammar: Optional[LlamaGrammar] = None
570
+ # NOTE: Missing parsed_grammar
571
+ prev: list[int] = field(default_factory=list)
572
+ cur: list[llama_cpp.llama_token_data] = field(default_factory=list)
573
+
574
+ def reset(self):
575
+ self.prev = []
576
+ self.cur = []
577
+ if self.grammar is not None:
578
+ self.grammar.reset()
579
+
580
+ def cp(self):
581
+ return LlamaSamplingContext(
582
+ params=self.params,
583
+ mirostat_mu=self.mirostat_mu,
584
+ grammar=self.grammar,
585
+ prev=self.prev.copy(),
586
+ cur=self.cur.copy(),
587
+ )
588
+
589
+ def last(self) -> Optional[int]:
590
+ if len(self.prev) > 0:
591
+ return self.prev[-1]
592
+ else:
593
+ return None
594
+
595
+ def prev_str(self, ctx_main: LlamaContext, n: int) -> str:
596
+ return ctx_main.model.detokenize(self.prev[-n:]).decode("utf-8")
597
+
598
+ def sample(
599
+ self,
600
+ ctx_main: LlamaContext,
601
+ idx: int = 0,
602
+ logits_array: Optional[npt.NDArray[np.single]] = None,
603
+ ):
604
+ # This method is deprecated in favor of using LlamaSampler directly
605
+ raise NotImplementedError("LlamaSamplingContext.sample is deprecated, use LlamaSampler instead")
606
+
607
+ def accept(self, ctx_main: LlamaContext, id: int, apply_grammar: bool):
608
+ self.prev.append(id)
609
+
610
+
611
+ class CustomSampler:
612
+ def __init__(
613
+ self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]
614
+ ):
615
+ self.apply_func = apply_func
616
+
617
+ def apply_wrapper(
618
+ sampler: llama_cpp.llama_sampler_p,
619
+ cur_p: llama_cpp.llama_token_data_array_p,
620
+ ):
621
+ self.apply_func(cur_p)
622
+
623
+ def free_wrapper(sampler: llama_cpp.llama_sampler_p):
624
+ pass
625
+
626
+ sampler_i = llama_cpp.llama_sampler_i()
627
+ sampler_i.apply = llama_cpp.llama_sampler_i_apply(apply_wrapper)
628
+ self._apply_wrapper_ref = apply_wrapper
629
+
630
+ sampler_i.name = llama_cpp.llama_sampler_i_name(0)
631
+ sampler_i.accept = llama_cpp.llama_sampler_i_accept(0)
632
+ sampler_i.reset = llama_cpp.llama_sampler_i_reset(0)
633
+ sampler_i.clone = llama_cpp.llama_sampler_i_clone(0)
634
+ sampler_i.free = llama_cpp.llama_sampler_i_free(0)
635
+
636
+ self.sampler = llama_cpp.llama_sampler()
637
+ self.sampler.iface = ctypes.pointer(sampler_i)
638
+ self.sampler.ctx = None
639
+
640
+ def get_sampler(self) -> llama_cpp.llama_sampler_p:
641
+ return ctypes.pointer(self.sampler)
642
+
643
+
644
+ class LlamaSampler:
645
+ def __init__(self):
646
+ params = llama_cpp.llama_sampler_chain_default_params()
647
+ self.sampler = llama_cpp.llama_sampler_chain_init(params)
648
+ self.custom_samplers: List[Tuple[int, CustomSampler]] = []
649
+ self._exit_stack = ExitStack()
650
+
651
+ def free_sampler():
652
+ if self.sampler is not None:
653
+ # NOTE: Must remove custom samplers before free or llama.cpp will try to free them
654
+ for i, _ in reversed(self.custom_samplers):
655
+ llama_cpp.llama_sampler_chain_remove(self.sampler, i)
656
+ llama_cpp.llama_sampler_free(self.sampler)
657
+ self.sampler = None
658
+
659
+ self._exit_stack.callback(free_sampler)
660
+
661
+ def close(self):
662
+ self._exit_stack.close()
663
+
664
+ def __del__(self):
665
+ self.close()
666
+
667
+ def add_greedy(self):
668
+ sampler = llama_cpp.llama_sampler_init_greedy()
669
+ llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
670
+
671
+ def add_dist(self, seed: int):
672
+ sampler = llama_cpp.llama_sampler_init_dist(seed)
673
+ llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
674
+
675
+ def add_softmax(self):
676
+ sampler = llama_cpp.llama_sampler_init_softmax()
677
+ llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
678
+
679
+ def add_top_k(self, k: int):
680
+ sampler = llama_cpp.llama_sampler_init_top_k(k)
681
+ llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
682
+
683
+ def add_top_p(self, p: float, min_keep: int = 1):
684
+ sampler = llama_cpp.llama_sampler_init_top_p(p, min_keep)
685
+ llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
686
+
687
+ def add_min_p(self, p: float, min_keep: int = 1):
688
+ sampler = llama_cpp.llama_sampler_init_min_p(p, min_keep)
689
+ llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
690
+
691
+ def add_typical(self, p: float, min_keep: int = 1):
692
+ sampler = llama_cpp.llama_sampler_init_typical(p, min_keep)
693
+ llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
694
+
695
+ def add_temp(self, temp: float):
696
+ sampler = llama_cpp.llama_sampler_init_temp(temp)
697
+ llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
698
+
699
+ def add_temp_ext(self, t: float, delta: float, exponent: float):
700
+ sampler = llama_cpp.llama_sampler_init_temp_ext(t, delta, exponent)
701
+ llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
702
+
703
+ def add_xtc(self, p: float, t: float, min_keep: int, seed: int):
704
+ sampler = llama_cpp.llama_sampler_init_xtc(p, t, min_keep, seed)
705
+ llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
706
+
707
+ def add_top_n_sigma(self, n: float):
708
+ sampler = llama_cpp.llama_sampler_init_top_n_sigma(n)
709
+ llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
710
+
711
+ def add_mirostat(self, n_vocab: int, seed: int, tau: float, eta: float, m: int):
712
+ sampler = llama_cpp.llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m)
713
+ llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
714
+
715
+ def add_mirostat_v2(self, seed: int, tau: float, eta: float):
716
+ sampler = llama_cpp.llama_sampler_init_mirostat_v2(seed, tau, eta)
717
+ llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
718
+
719
+ def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
720
+ sampler = llama_cpp.llama_sampler_init_grammar(
721
+ model.vocab, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8")
722
+ )
723
+ llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
724
+
725
+ def add_grammar_lazy_patterns(
726
+ self,
727
+ model: LlamaModel,
728
+ grammar: LlamaGrammar,
729
+ trigger_patterns: List[str],
730
+ trigger_tokens: List[int]
731
+ ):
732
+ # Convert patterns to C array
733
+ pattern_ptrs = (ctypes.c_char_p * len(trigger_patterns))()
734
+ for i, pattern in enumerate(trigger_patterns):
735
+ pattern_ptrs[i] = pattern.encode("utf-8")
736
+
737
+ # Convert tokens to C array
738
+ token_array = (llama_cpp.llama_token * len(trigger_tokens))(*trigger_tokens)
739
+
740
+ sampler = llama_cpp.llama_sampler_init_grammar_lazy_patterns(
741
+ model.vocab,
742
+ grammar._grammar.encode("utf-8"),
743
+ grammar._root.encode("utf-8"),
744
+ pattern_ptrs,
745
+ len(trigger_patterns),
746
+ token_array,
747
+ len(trigger_tokens)
748
+ )
749
+ llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
750
+
751
+ def add_penalties(
752
+ self,
753
+ penalty_last_n: int,
754
+ penalty_repeat: float,
755
+ penalty_freq: float,
756
+ penalty_present: float,
757
+ ):
758
+ sampler = llama_cpp.llama_sampler_init_penalties(
759
+ penalty_last_n,
760
+ penalty_repeat,
761
+ penalty_freq,
762
+ penalty_present,
763
+ )
764
+ llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
765
+
766
+ def add_dry(
767
+ self,
768
+ model: LlamaModel,
769
+ n_ctx_train: int,
770
+ dry_multiplier: float,
771
+ dry_base: float,
772
+ dry_allowed_length: int,
773
+ dry_penalty_last_n: int,
774
+ seq_breakers: List[str]
775
+ ):
776
+ # Convert seq_breakers to C array
777
+ breaker_ptrs = (ctypes.c_char_p * len(seq_breakers))()
778
+ for i, breaker in enumerate(seq_breakers):
779
+ breaker_ptrs[i] = breaker.encode("utf-8")
780
+
781
+ sampler = llama_cpp.llama_sampler_init_dry(
782
+ model.vocab,
783
+ n_ctx_train,
784
+ dry_multiplier,
785
+ dry_base,
786
+ dry_allowed_length,
787
+ dry_penalty_last_n,
788
+ breaker_ptrs,
789
+ len(seq_breakers)
790
+ )
791
+ llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
792
+
793
+ def add_logit_bias(
794
+ self,
795
+ n_vocab: int,
796
+ logit_bias: Dict[int, float]
797
+ ):
798
+ # Convert logit_bias dict to C array
799
+ bias_array = (llama_cpp.llama_logit_bias * len(logit_bias))()
800
+ for i, (token, bias) in enumerate(logit_bias.items()):
801
+ bias_array[i].token = token
802
+ bias_array[i].bias = bias
803
+
804
+ sampler = llama_cpp.llama_sampler_init_logit_bias(
805
+ n_vocab,
806
+ len(logit_bias),
807
+ bias_array
808
+ )
809
+ llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
810
+
811
+ def add_infill(self, model: LlamaModel):
812
+ sampler = llama_cpp.llama_sampler_init_infill(model.vocab)
813
+ llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
814
+
815
+ def add_custom(
816
+ self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]
817
+ ):
818
+ custom_sampler = CustomSampler(apply_func)
819
+ sampler = custom_sampler.get_sampler()
820
+ llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
821
+ # NOTE: Must remove custom samplers before free or llama.cpp will try to free them
822
+ self.custom_samplers.append(
823
+ (llama_cpp.llama_sampler_chain_n(self.sampler) - 1, custom_sampler)
824
+ )
825
+
826
+ def get_seed(self) -> int:
827
+ return llama_cpp.llama_sampler_get_seed(self.sampler)
828
+
829
+ def sample(self, ctx: LlamaContext, idx: int = -1) -> int:
830
+ return llama_cpp.llama_sampler_sample(self.sampler, ctx.ctx, idx)
831
+
832
+ def accept(self, token: int):
833
+ llama_cpp.llama_sampler_accept(self.sampler, token)
834
+
835
+ def reset(self):
836
+ llama_cpp.llama_sampler_reset(self.sampler)
837
+
838
+ def clone(self):
839
+ # NOTE: Custom samplers cannot be cloned due to Python callback limitations
840
+ if self.custom_samplers:
841
+ raise NotImplementedError("Cannot clone LlamaSampler that contains custom samplers")
842
+
843
+ cloned_sampler = llama_cpp.llama_sampler_clone(self.sampler)
844
+ # Create a new wrapper around the cloned sampler
845
+ new_sampler = LlamaSampler.__new__(LlamaSampler)
846
+ new_sampler.sampler = cloned_sampler
847
+ new_sampler.custom_samplers = []
848
+ new_sampler._exit_stack = ExitStack()
849
+
850
+ def free_sampler():
851
+ if new_sampler.sampler is not None:
852
+ llama_cpp.llama_sampler_free(new_sampler.sampler)
853
+ new_sampler.sampler = None
854
+
855
+ new_sampler._exit_stack.callback(free_sampler)
856
+ return new_sampler
llama_cpp/_logger.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import ctypes
3
+ import logging
4
+
5
+ import llama_cpp
6
+
7
+ # enum ggml_log_level {
8
+ # GGML_LOG_LEVEL_NONE = 0,
9
+ # GGML_LOG_LEVEL_INFO = 1,
10
+ # GGML_LOG_LEVEL_WARN = 2,
11
+ # GGML_LOG_LEVEL_ERROR = 3,
12
+ # GGML_LOG_LEVEL_DEBUG = 4,
13
+ # GGML_LOG_LEVEL_CONT = 5, // continue previous log
14
+ # };
15
+ GGML_LOG_LEVEL_TO_LOGGING_LEVEL = {
16
+ 0: logging.CRITICAL,
17
+ 1: logging.INFO,
18
+ 2: logging.WARNING,
19
+ 3: logging.ERROR,
20
+ 4: logging.DEBUG,
21
+ 5: logging.DEBUG,
22
+ }
23
+
24
+ logger = logging.getLogger("llama-cpp-python")
25
+
26
+ _last_log_level = GGML_LOG_LEVEL_TO_LOGGING_LEVEL[0]
27
+
28
+ # typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
29
+ @llama_cpp.llama_log_callback
30
+ def llama_log_callback(
31
+ level: int,
32
+ text: bytes,
33
+ user_data: ctypes.c_void_p,
34
+ ):
35
+ # TODO: Correctly implement continue previous log
36
+ global _last_log_level
37
+ log_level = GGML_LOG_LEVEL_TO_LOGGING_LEVEL[level] if level != 5 else _last_log_level
38
+ if logger.level <= GGML_LOG_LEVEL_TO_LOGGING_LEVEL[level]:
39
+ print(text.decode("utf-8"), end="", flush=True, file=sys.stderr)
40
+ _last_log_level = log_level
41
+
42
+
43
+ llama_cpp.llama_log_set(llama_log_callback, ctypes.c_void_p(0))
44
+
45
+
46
+ def set_verbose(verbose: bool):
47
+ logger.setLevel(logging.DEBUG if verbose else logging.ERROR)
llama_cpp/_utils.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ from typing import Any, Dict
5
+
6
+ # Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor
7
+ outnull_file = open(os.devnull, "w")
8
+ errnull_file = open(os.devnull, "w")
9
+
10
+ STDOUT_FILENO = 1
11
+ STDERR_FILENO = 2
12
+
13
+
14
+ class suppress_stdout_stderr(object):
15
+ # NOTE: these must be "saved" here to avoid exceptions when using
16
+ # this context manager inside of a __del__ method
17
+ sys = sys
18
+ os = os
19
+
20
+ def __init__(self, disable: bool = True):
21
+ self.disable = disable
22
+
23
+ # Oddly enough this works better than the contextlib version
24
+ def __enter__(self):
25
+ if self.disable:
26
+ return self
27
+
28
+ self.old_stdout_fileno_undup = STDOUT_FILENO
29
+ self.old_stderr_fileno_undup = STDERR_FILENO
30
+
31
+ self.old_stdout_fileno = self.os.dup(self.old_stdout_fileno_undup)
32
+ self.old_stderr_fileno = self.os.dup(self.old_stderr_fileno_undup)
33
+
34
+ self.old_stdout = self.sys.stdout
35
+ self.old_stderr = self.sys.stderr
36
+
37
+ self.os.dup2(outnull_file.fileno(), self.old_stdout_fileno_undup)
38
+ self.os.dup2(errnull_file.fileno(), self.old_stderr_fileno_undup)
39
+
40
+ self.sys.stdout = outnull_file
41
+ self.sys.stderr = errnull_file
42
+ return self
43
+
44
+ def __exit__(self, *_):
45
+ if self.disable:
46
+ return
47
+
48
+ # Check if sys.stdout and sys.stderr have fileno method
49
+ self.sys.stdout = self.old_stdout
50
+ self.sys.stderr = self.old_stderr
51
+
52
+ self.os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
53
+ self.os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
54
+
55
+ self.os.close(self.old_stdout_fileno)
56
+ self.os.close(self.old_stderr_fileno)
57
+
58
+
59
+ class MetaSingleton(type):
60
+ """
61
+ Metaclass for implementing the Singleton pattern.
62
+ """
63
+
64
+ _instances: Dict[type, Any] = {}
65
+
66
+ def __call__(cls, *args: Any, **kwargs: Any) -> Any:
67
+ if cls not in cls._instances:
68
+ cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
69
+ return cls._instances[cls]
70
+
71
+
72
+ class Singleton(object, metaclass=MetaSingleton):
73
+ """
74
+ Base class for implementing the Singleton pattern.
75
+ """
76
+
77
+ def __init__(self):
78
+ super(Singleton, self).__init__()
llama_cpp/lib/libggml-base.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a33aad406e2803734808d1a05e8a45834fa726a4066f03eb6aaff5b7a3c155a7
3
+ size 615864
llama_cpp/lib/libggml-cpu.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57833bdcf97a60d84e9ae3089678ab06771fb1c7d4affbdd7360283ccd8e8e16
3
+ size 791480
llama_cpp/lib/libggml.so ADDED
Binary file (47.6 kB). View file
 
llama_cpp/lib/libllama.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9369b069f7df66e1bfb7afc0ff31c18117aba769fbf9cbe7b209ab6254de90cf
3
+ size 2150632
llama_cpp/lib/libmtmd.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:58d6765b2f621fc5feb7e2e188810129be5ff8d5151e1bfef35231f22f6b9b08
3
+ size 722296
llama_cpp/llama.py ADDED
@@ -0,0 +1,2422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import sys
5
+ import uuid
6
+ import time
7
+ import json
8
+ import ctypes
9
+ import typing
10
+ import random
11
+ import fnmatch
12
+ import warnings
13
+ import contextlib
14
+ import multiprocessing
15
+
16
+ from typing import (
17
+ Any,
18
+ List,
19
+ Literal,
20
+ Optional,
21
+ Union,
22
+ Generator,
23
+ Sequence,
24
+ Iterator,
25
+ Deque,
26
+ Callable,
27
+ Dict,
28
+ )
29
+ from collections import deque
30
+ from pathlib import Path
31
+
32
+
33
+ from .llama_types import *
34
+ from .llama_grammar import LlamaGrammar
35
+ from .llama_cache import (
36
+ BaseLlamaCache,
37
+ LlamaCache, # type: ignore
38
+ LlamaDiskCache, # type: ignore
39
+ LlamaRAMCache, # type: ignore
40
+ )
41
+ from .llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer
42
+ import llama_cpp.llama_cpp as llama_cpp
43
+ import llama_cpp.llama_chat_format as llama_chat_format
44
+
45
+ from llama_cpp.llama_speculative import LlamaDraftModel
46
+
47
+ import numpy as np
48
+ import numpy.typing as npt
49
+
50
+ import llama_cpp._internals as internals
51
+ from ._logger import set_verbose
52
+ from ._utils import suppress_stdout_stderr
53
+
54
+
55
+ class Llama:
56
+ """High-level Python wrapper for a llama.cpp model."""
57
+
58
+ __backend_initialized = False
59
+
60
+ def __init__(
61
+ self,
62
+ model_path: str,
63
+ *,
64
+ # Model Params
65
+ n_gpu_layers: int = 0,
66
+ split_mode: int = llama_cpp.LLAMA_SPLIT_MODE_LAYER,
67
+ main_gpu: int = 0,
68
+ tensor_split: Optional[List[float]] = None,
69
+ vocab_only: bool = False,
70
+ use_mmap: bool = True,
71
+ use_mlock: bool = False,
72
+ kv_overrides: Optional[Dict[str, Union[bool, int, float, str]]] = None,
73
+ # Context Params
74
+ seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
75
+ n_ctx: int = 512,
76
+ n_batch: int = 512,
77
+ n_ubatch: int = 512,
78
+ n_threads: Optional[int] = None,
79
+ n_threads_batch: Optional[int] = None,
80
+ rope_scaling_type: Optional[
81
+ int
82
+ ] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
83
+ pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED,
84
+ rope_freq_base: float = 0.0,
85
+ rope_freq_scale: float = 0.0,
86
+ yarn_ext_factor: float = -1.0,
87
+ yarn_attn_factor: float = 1.0,
88
+ yarn_beta_fast: float = 32.0,
89
+ yarn_beta_slow: float = 1.0,
90
+ yarn_orig_ctx: int = 0,
91
+ logits_all: bool = False,
92
+ embedding: bool = False,
93
+ offload_kqv: bool = True,
94
+ flash_attn: bool = False,
95
+ op_offload: Optional[bool] = None,
96
+ swa_full: Optional[bool] = None,
97
+ # Sampling Params
98
+ no_perf: bool = False,
99
+ last_n_tokens_size: int = 64,
100
+ # LoRA Params
101
+ lora_base: Optional[str] = None,
102
+ lora_scale: float = 1.0,
103
+ lora_path: Optional[str] = None,
104
+ # Backend Params
105
+ numa: Union[bool, int] = False,
106
+ # Chat Format Params
107
+ chat_format: Optional[str] = None,
108
+ chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
109
+ # Speculative Decoding
110
+ draft_model: Optional[LlamaDraftModel] = None,
111
+ # Tokenizer Override
112
+ tokenizer: Optional[BaseLlamaTokenizer] = None,
113
+ # KV cache quantization
114
+ type_k: Optional[int] = None,
115
+ type_v: Optional[int] = None,
116
+ # Misc
117
+ spm_infill: bool = False,
118
+ verbose: bool = True,
119
+ # Extra Params
120
+ **kwargs, # type: ignore
121
+ ):
122
+ """Load a llama.cpp model from `model_path`.
123
+
124
+ Examples:
125
+ Basic usage
126
+
127
+ >>> import llama_cpp
128
+ >>> model = llama_cpp.Llama(
129
+ ... model_path="path/to/model",
130
+ ... )
131
+ >>> print(model("The quick brown fox jumps ", stop=["."])["choices"][0]["text"])
132
+ the lazy dog
133
+
134
+ Loading a chat model
135
+
136
+ >>> import llama_cpp
137
+ >>> model = llama_cpp.Llama(
138
+ ... model_path="path/to/model",
139
+ ... chat_format="llama-2",
140
+ ... )
141
+ >>> print(model.create_chat_completion(
142
+ ... messages=[{
143
+ ... "role": "user",
144
+ ... "content": "what is the meaning of life?"
145
+ ... }]
146
+ ... ))
147
+
148
+ Args:
149
+ model_path: Path to the model.
150
+ n_gpu_layers: Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded.
151
+ split_mode: How to split the model across GPUs. See llama_cpp.LLAMA_SPLIT_* for options.
152
+ main_gpu: main_gpu interpretation depends on split_mode: LLAMA_SPLIT_MODE_NONE: the GPU that is used for the entire model. LLAMA_SPLIT_MODE_ROW: the GPU that is used for small tensors and intermediate results. LLAMA_SPLIT_MODE_LAYER: ignored
153
+ tensor_split: How split tensors should be distributed across GPUs. If None, the model is not split.
154
+ vocab_only: Only load the vocabulary no weights.
155
+ use_mmap: Use mmap if possible.
156
+ use_mlock: Force the system to keep the model in RAM.
157
+ kv_overrides: Key-value overrides for the model.
158
+ seed: RNG seed, -1 for random
159
+ n_ctx: Text context, 0 = from model
160
+ n_batch: Prompt processing maximum batch size
161
+ n_ubatch: Physical batch size
162
+ n_threads: Number of threads to use for generation
163
+ n_threads_batch: Number of threads to use for batch processing
164
+ rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
165
+ pooling_type: Pooling type, from `enum llama_pooling_type`.
166
+ rope_freq_base: RoPE base frequency, 0 = from model
167
+ rope_freq_scale: RoPE frequency scaling factor, 0 = from model
168
+ yarn_ext_factor: YaRN extrapolation mix factor, negative = from model
169
+ yarn_attn_factor: YaRN magnitude scaling factor
170
+ yarn_beta_fast: YaRN low correction dim
171
+ yarn_beta_slow: YaRN high correction dim
172
+ yarn_orig_ctx: YaRN original context size
173
+ logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
174
+ embedding: Embedding mode only.
175
+ offload_kqv: Offload K, Q, V to GPU.
176
+ flash_attn: Use flash attention.
177
+ op_offload: offload host tensor operations to device
178
+ swa_full: use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
179
+ no_perf: Measure performance timings.
180
+ last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
181
+ lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
182
+ lora_path: Path to a LoRA file to apply to the model.
183
+ numa: numa policy
184
+ chat_format: String specifying the chat format to use when calling create_chat_completion.
185
+ chat_handler: Optional chat handler to use when calling create_chat_completion.
186
+ draft_model: Optional draft model to use for speculative decoding.
187
+ tokenizer: Optional tokenizer to override the default tokenizer from llama.cpp.
188
+ verbose: Print verbose output to stderr.
189
+ type_k: KV cache data type for K (default: f16)
190
+ type_v: KV cache data type for V (default: f16)
191
+ spm_infill: Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this.
192
+
193
+ Raises:
194
+ ValueError: If the model path does not exist.
195
+
196
+ Returns:
197
+ A Llama instance.
198
+ """
199
+ self.verbose = verbose
200
+ self._stack = contextlib.ExitStack()
201
+
202
+ set_verbose(verbose)
203
+
204
+ if not Llama.__backend_initialized:
205
+ with suppress_stdout_stderr(disable=verbose):
206
+ llama_cpp.llama_backend_init()
207
+ Llama.__backend_initialized = True
208
+
209
+ if isinstance(numa, bool):
210
+ self.numa = (
211
+ llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE
212
+ if numa
213
+ else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
214
+ )
215
+ else:
216
+ self.numa = numa
217
+
218
+ if self.numa != llama_cpp.GGML_NUMA_STRATEGY_DISABLED:
219
+ with suppress_stdout_stderr(disable=verbose):
220
+ llama_cpp.llama_numa_init(self.numa)
221
+
222
+ self.model_path = model_path
223
+
224
+ # Model Params
225
+ self.model_params = llama_cpp.llama_model_default_params()
226
+ self.model_params.n_gpu_layers = (
227
+ 0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers
228
+ ) # 0x7FFFFFFF is INT32 max, will be auto set to all layers
229
+ self.model_params.split_mode = split_mode
230
+ self.model_params.main_gpu = main_gpu
231
+ self.tensor_split = tensor_split
232
+ self._c_tensor_split = None
233
+ if self.tensor_split is not None:
234
+ if len(self.tensor_split) > llama_cpp.LLAMA_MAX_DEVICES:
235
+ raise ValueError(
236
+ f"Attempt to split tensors that exceed maximum supported devices. Current LLAMA_MAX_DEVICES={llama_cpp.LLAMA_MAX_DEVICES}"
237
+ )
238
+ # Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
239
+ FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES
240
+ self._c_tensor_split = FloatArray(
241
+ *tensor_split # type: ignore
242
+ ) # keep a reference to the array so it is not gc'd
243
+ self.model_params.tensor_split = self._c_tensor_split
244
+ self.model_params.vocab_only = vocab_only
245
+ self.model_params.use_mmap = use_mmap if lora_path is None else False
246
+ self.model_params.use_mlock = use_mlock
247
+
248
+ # kv_overrides is the original python dict
249
+ self.kv_overrides = kv_overrides
250
+ if kv_overrides is not None:
251
+ # _kv_overrides_array is a ctypes.Array of llama_model_kv_override Structs
252
+ kvo_array_len = len(kv_overrides) + 1 # for sentinel element
253
+ self._kv_overrides_array = (
254
+ llama_cpp.llama_model_kv_override * kvo_array_len
255
+ )()
256
+
257
+ for i, (k, v) in enumerate(kv_overrides.items()):
258
+ self._kv_overrides_array[i].key = k.encode("utf-8")
259
+ if isinstance(v, bool):
260
+ self._kv_overrides_array[
261
+ i
262
+ ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
263
+ self._kv_overrides_array[i].value.val_bool = v
264
+ elif isinstance(v, int):
265
+ self._kv_overrides_array[
266
+ i
267
+ ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
268
+ self._kv_overrides_array[i].value.val_i64 = v
269
+ elif isinstance(v, float):
270
+ self._kv_overrides_array[
271
+ i
272
+ ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
273
+ self._kv_overrides_array[i].value.val_f64 = v
274
+ elif isinstance(v, str): # type: ignore
275
+ v_bytes = v.encode("utf-8")
276
+ if len(v_bytes) > 128: # TODO: Make this a constant
277
+ raise ValueError(f"Value for {k} is too long: {v}")
278
+ v_bytes = v_bytes.ljust(128, b"\0")
279
+ self._kv_overrides_array[
280
+ i
281
+ ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
282
+ # copy min(v_bytes, 128) to str_value
283
+ address = typing.cast(
284
+ int,
285
+ ctypes.addressof(self._kv_overrides_array[i].value)
286
+ + llama_cpp.llama_model_kv_override_value.val_str.offset,
287
+ )
288
+ buffer_start = ctypes.cast(address, ctypes.POINTER(ctypes.c_char))
289
+ ctypes.memmove(
290
+ buffer_start,
291
+ v_bytes,
292
+ 128,
293
+ )
294
+ else:
295
+ raise ValueError(f"Unknown value type for {k}: {v}")
296
+
297
+ self._kv_overrides_array[
298
+ -1
299
+ ].key = b"\0" # ensure sentinel element is zeroed
300
+ self.model_params.kv_overrides = self._kv_overrides_array
301
+
302
+ self.n_batch = min(n_ctx, n_batch) # ???
303
+ self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
304
+ self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count()
305
+
306
+ # Used by the sampler
307
+ self._seed = seed or llama_cpp.LLAMA_DEFAULT_SEED
308
+
309
+ # Context Params
310
+ self.context_params = llama_cpp.llama_context_default_params()
311
+ self.context_params.n_ctx = n_ctx
312
+ self.context_params.n_batch = self.n_batch
313
+ self.context_params.n_ubatch = min(self.n_batch, n_ubatch)
314
+ self.context_params.n_threads = self.n_threads
315
+ self.context_params.n_threads_batch = self.n_threads_batch
316
+ self.context_params.rope_scaling_type = (
317
+ rope_scaling_type
318
+ if rope_scaling_type is not None
319
+ else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
320
+ )
321
+ self.context_params.pooling_type = pooling_type
322
+ self.context_params.rope_freq_base = (
323
+ rope_freq_base if rope_freq_base != 0.0 else 0
324
+ )
325
+ self.context_params.rope_freq_scale = (
326
+ rope_freq_scale if rope_freq_scale != 0.0 else 0
327
+ )
328
+ self.context_params.yarn_ext_factor = (
329
+ yarn_ext_factor if yarn_ext_factor != 0.0 else 0
330
+ )
331
+ self.context_params.yarn_attn_factor = (
332
+ yarn_attn_factor if yarn_attn_factor != 0.0 else 0
333
+ )
334
+ self.context_params.yarn_beta_fast = (
335
+ yarn_beta_fast if yarn_beta_fast != 0.0 else 0
336
+ )
337
+ self.context_params.yarn_beta_slow = (
338
+ yarn_beta_slow if yarn_beta_slow != 0.0 else 0
339
+ )
340
+ self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
341
+ self._logits_all = logits_all if draft_model is None else True
342
+ self.context_params.embeddings = embedding # TODO: Rename to embeddings
343
+ self.context_params.offload_kqv = offload_kqv
344
+ self.context_params.flash_attn = flash_attn
345
+
346
+ if op_offload is not None:
347
+ self.context_params.op_offload = op_offload
348
+
349
+ if swa_full is not None:
350
+ self.context_params.swa_full = swa_full
351
+
352
+ # KV cache quantization
353
+ if type_k is not None:
354
+ self.context_params.type_k = type_k
355
+ if type_v is not None:
356
+ self.context_params.type_v = type_v
357
+ # Sampling Params
358
+ self.context_params.no_perf = no_perf
359
+ self.last_n_tokens_size = last_n_tokens_size
360
+
361
+ self.cache: Optional[BaseLlamaCache] = None
362
+
363
+ self.lora_base = lora_base
364
+ self.lora_scale = lora_scale
365
+ self.lora_path = lora_path
366
+
367
+ self.spm_infill = spm_infill
368
+
369
+ if not os.path.exists(model_path):
370
+ raise ValueError(f"Model path does not exist: {model_path}")
371
+
372
+ self._model = self._stack.enter_context(
373
+ contextlib.closing(
374
+ internals.LlamaModel(
375
+ path_model=self.model_path,
376
+ params=self.model_params,
377
+ verbose=self.verbose,
378
+ )
379
+ )
380
+ )
381
+
382
+ # Override tokenizer
383
+ self.tokenizer_ = tokenizer or LlamaTokenizer(self)
384
+
385
+ # Set the default value for the context and correct the batch
386
+ if n_ctx == 0:
387
+ n_ctx = self._model.n_ctx_train()
388
+ self.n_batch = min(n_ctx, n_batch)
389
+ self.context_params.n_ctx = self._model.n_ctx_train()
390
+ self.context_params.n_batch = self.n_batch
391
+ self.context_params.n_ubatch = min(self.n_batch, n_ubatch)
392
+
393
+ self._ctx = self._stack.enter_context(
394
+ contextlib.closing(
395
+ internals.LlamaContext(
396
+ model=self._model,
397
+ params=self.context_params,
398
+ verbose=self.verbose,
399
+ )
400
+ )
401
+ )
402
+
403
+ self._batch = self._stack.enter_context(
404
+ contextlib.closing(
405
+ internals.LlamaBatch(
406
+ n_tokens=self.n_batch,
407
+ embd=0,
408
+ n_seq_max=self.context_params.n_ctx,
409
+ verbose=self.verbose,
410
+ )
411
+ )
412
+ )
413
+
414
+ self._lora_adapter: Optional[llama_cpp.llama_adapter_lora_p] = None
415
+
416
+ if self.lora_path:
417
+ self._lora_adapter = llama_cpp.llama_adapter_lora_init(
418
+ self._model.model,
419
+ self.lora_path.encode("utf-8"),
420
+ )
421
+ if self._lora_adapter is None:
422
+ raise RuntimeError(
423
+ f"Failed to initialize LoRA adapter from lora path: {self.lora_path}"
424
+ )
425
+
426
+ def free_lora_adapter():
427
+ if self._lora_adapter is None:
428
+ return
429
+ llama_cpp.llama_adapter_lora_free(self._lora_adapter)
430
+ self._lora_adapter = None
431
+
432
+ self._stack.callback(free_lora_adapter)
433
+
434
+ if llama_cpp.llama_set_adapter_lora(
435
+ self._ctx.ctx, self._lora_adapter, self.lora_scale
436
+ ):
437
+ raise RuntimeError(
438
+ f"Failed to set LoRA adapter from lora path: {self.lora_path}"
439
+ )
440
+
441
+ if self.verbose:
442
+ print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
443
+
444
+ self.chat_format = chat_format
445
+ self.chat_handler = chat_handler
446
+ self._chat_handlers: Dict[
447
+ str, llama_chat_format.LlamaChatCompletionHandler
448
+ ] = {}
449
+
450
+ self.draft_model = draft_model
451
+
452
+ self._n_vocab = self.n_vocab()
453
+ self._n_ctx = self.n_ctx()
454
+
455
+ self._token_nl = self.token_nl()
456
+ self._token_eos = self.token_eos()
457
+
458
+ self._candidates = internals.LlamaTokenDataArray(n_vocab=self._n_vocab)
459
+
460
+ self.n_tokens = 0
461
+ self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
462
+ self.scores: npt.NDArray[np.single] = np.ndarray(
463
+ (n_ctx if logits_all == True else n_batch, self._n_vocab), dtype=np.single
464
+ )
465
+
466
+ self._mirostat_mu = ctypes.c_float(
467
+ 2.0 * 5.0
468
+ ) # TODO: Move this to sampling context
469
+
470
+ try:
471
+ self.metadata = self._model.metadata()
472
+ except Exception as e:
473
+ self.metadata = {}
474
+ if self.verbose:
475
+ print(f"Failed to load metadata: {e}", file=sys.stderr)
476
+
477
+ if self.verbose:
478
+ print(f"Model metadata: {self.metadata}", file=sys.stderr)
479
+
480
+ eos_token_id = self.token_eos()
481
+ bos_token_id = self.token_bos()
482
+
483
+ eos_token = (
484
+ self._model.token_get_text(eos_token_id) if eos_token_id != -1 else ""
485
+ )
486
+ bos_token = (
487
+ self._model.token_get_text(bos_token_id) if bos_token_id != -1 else ""
488
+ )
489
+
490
+ # Unfortunately the llama.cpp API does not return metadata arrays, so we can't get template names from tokenizer.chat_templates
491
+ template_choices = dict(
492
+ (name[10:], template)
493
+ for name, template in self.metadata.items()
494
+ if name.startswith("tokenizer.chat_template.")
495
+ )
496
+
497
+ if "tokenizer.chat_template" in self.metadata:
498
+ template_choices["chat_template.default"] = self.metadata[
499
+ "tokenizer.chat_template"
500
+ ]
501
+
502
+ if self.verbose and template_choices:
503
+ print(
504
+ f"Available chat formats from metadata: {', '.join(template_choices.keys())}",
505
+ file=sys.stderr,
506
+ )
507
+
508
+ for name, template in template_choices.items():
509
+ self._chat_handlers[name] = llama_chat_format.Jinja2ChatFormatter(
510
+ template=template,
511
+ eos_token=eos_token,
512
+ bos_token=bos_token,
513
+ stop_token_ids=[eos_token_id],
514
+ ).to_chat_handler()
515
+
516
+ if (
517
+ self.chat_format is None
518
+ and self.chat_handler is None
519
+ and "chat_template.default" in template_choices
520
+ ):
521
+ chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
522
+ self.metadata
523
+ )
524
+
525
+ if chat_format is not None:
526
+ self.chat_format = chat_format
527
+ if self.verbose:
528
+ print(f"Guessed chat format: {chat_format}", file=sys.stderr)
529
+ else:
530
+ if self.verbose:
531
+ print(
532
+ f"Using gguf chat template: {template_choices['chat_template.default']}",
533
+ file=sys.stderr,
534
+ )
535
+ print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
536
+ print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
537
+
538
+ self.chat_format = "chat_template.default"
539
+
540
+ if self.chat_format is None and self.chat_handler is None:
541
+ self.chat_format = "llama-2"
542
+ if self.verbose:
543
+ print(
544
+ f"Using fallback chat format: {self.chat_format}", file=sys.stderr
545
+ )
546
+
547
+ self._sampler = None
548
+
549
+ @property
550
+ def ctx(self) -> llama_cpp.llama_context_p:
551
+ return self._ctx.ctx
552
+
553
+ @property
554
+ def model(self) -> llama_cpp.llama_model_p:
555
+ return self._model.model
556
+
557
+ @property
558
+ def _input_ids(self) -> npt.NDArray[np.intc]:
559
+ return self.input_ids[: self.n_tokens]
560
+
561
+ @property
562
+ def _scores(self) -> npt.NDArray[np.single]:
563
+ return self.scores[: self.n_tokens, :]
564
+
565
+ @property
566
+ def eval_tokens(self) -> Deque[int]:
567
+ return deque(self.input_ids[: self.n_tokens].tolist(), maxlen=self._n_ctx)
568
+
569
+ @property
570
+ def eval_logits(self) -> Deque[List[float]]:
571
+ return deque(
572
+ self.scores[: self.n_tokens, :].tolist(),
573
+ maxlen=self._n_ctx if self._logits_all else 1,
574
+ )
575
+
576
+ def tokenize(
577
+ self, text: bytes, add_bos: bool = True, special: bool = False
578
+ ) -> List[int]:
579
+ """Tokenize a string.
580
+
581
+ Args:
582
+ text: The utf-8 encoded string to tokenize.
583
+ add_bos: Whether to add a beginning of sequence token.
584
+ special: Whether to tokenize special tokens.
585
+
586
+ Raises:
587
+ RuntimeError: If the tokenization failed.
588
+
589
+ Returns:
590
+ A list of tokens.
591
+ """
592
+ return self.tokenizer_.tokenize(text, add_bos, special)
593
+
594
+ def detokenize(
595
+ self,
596
+ tokens: List[int],
597
+ prev_tokens: Optional[List[int]] = None,
598
+ special: bool = False,
599
+ ) -> bytes:
600
+ """Detokenize a list of tokens.
601
+
602
+ Args:
603
+ tokens: The list of tokens to detokenize.
604
+ prev_tokens: The list of previous tokens. Offset mapping will be performed if provided.
605
+ special: Whether to detokenize special tokens.
606
+
607
+ Returns:
608
+ The detokenized string.
609
+ """
610
+ return self.tokenizer_.detokenize(
611
+ tokens, prev_tokens=prev_tokens, special=special
612
+ )
613
+
614
+ def set_cache(self, cache: Optional[BaseLlamaCache]):
615
+ """Set the cache.
616
+
617
+ Args:
618
+ cache: The cache to set.
619
+ """
620
+ self.cache = cache
621
+
622
+ def set_seed(self, seed: int):
623
+ """Set the random seed.
624
+
625
+ Args:
626
+ seed: The random seed.
627
+ """
628
+ self._seed = seed
629
+
630
+ def reset(self):
631
+ """Reset the model state."""
632
+ self.n_tokens = 0
633
+
634
+ def eval(self, tokens: Sequence[int]):
635
+ """Evaluate a list of tokens.
636
+
637
+ Args:
638
+ tokens: The list of tokens to evaluate.
639
+ """
640
+ self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
641
+ for i in range(0, len(tokens), self.n_batch):
642
+ batch = tokens[i : min(len(tokens), i + self.n_batch)]
643
+ n_past = self.n_tokens
644
+ n_tokens = len(batch)
645
+ self._batch.set_batch(
646
+ batch=batch, n_past=n_past, logits_all=self._logits_all
647
+ )
648
+ self._ctx.decode(self._batch)
649
+ # Save tokens
650
+ self.input_ids[n_past : n_past + n_tokens] = batch
651
+ # Save logits
652
+ if self._logits_all:
653
+ rows = n_tokens
654
+ cols = self._n_vocab
655
+ logits = np.ctypeslib.as_array(
656
+ self._ctx.get_logits(), shape=(rows * cols,)
657
+ )
658
+ self.scores[n_past : n_past + n_tokens, :].reshape(-1)[::] = logits
659
+ else:
660
+ # rows = 1
661
+ # cols = self._n_vocab
662
+ # logits = np.ctypeslib.as_array(
663
+ # self._ctx.get_logits(), shape=(rows * cols,)
664
+ # )
665
+ # self.scores[n_past + n_tokens - 1, :].reshape(-1)[::] = logits
666
+ # NOTE: Now that sampling is done inside the sampler, logits are only needed for logprobs which requires logits_all
667
+ pass
668
+ # Update n_tokens
669
+ self.n_tokens += n_tokens
670
+
671
+ def _init_sampler(
672
+ self,
673
+ top_k: int = 40,
674
+ top_p: float = 0.95,
675
+ min_p: float = 0.05,
676
+ typical_p: float = 1.0,
677
+ temp: float = 0.80,
678
+ repeat_penalty: float = 1.0,
679
+ frequency_penalty: float = 0.0,
680
+ presence_penalty: float = 0.0,
681
+ tfs_z: float = 1.0,
682
+ mirostat_mode: int = 0,
683
+ mirostat_eta: float = 0.1,
684
+ mirostat_tau: float = 5.0,
685
+ penalize_nl: bool = True,
686
+ logits_processor: Optional[LogitsProcessorList] = None,
687
+ grammar: Optional[LlamaGrammar] = None,
688
+ ):
689
+ sampler = internals.LlamaSampler()
690
+
691
+ if logits_processor is not None:
692
+ # Create and add a custom sampler
693
+ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
694
+ size = token_data_array.contents.size
695
+ data_soa = token_data_array.contents.data
696
+ data_soa_address = ctypes.addressof(data_soa.contents)
697
+ # NOTE: This is probably broken
698
+ recarray = np.recarray(
699
+ shape=(size,),
700
+ dtype=np.dtype(
701
+ [("id", np.intc), ("logit", np.single), ("p", np.single)],
702
+ align=True,
703
+ ),
704
+ buf=(llama_cpp.llama_token_data * size).from_address(
705
+ data_soa_address
706
+ ),
707
+ )
708
+ for logit_processor in logits_processor:
709
+ recarray.logit[:] = logit_processor(self._input_ids, recarray.logit)
710
+
711
+ sampler.add_custom(apply_func)
712
+
713
+ sampler.add_penalties(
714
+ # n_vocab=self._n_vocab,
715
+ # special_eos_id=self._token_eos,
716
+ # linefeed_id=self._token_nl,
717
+ penalty_last_n=self.last_n_tokens_size,
718
+ penalty_repeat=repeat_penalty,
719
+ penalty_freq=frequency_penalty,
720
+ penalty_present=presence_penalty,
721
+ # penalize_nl=penalize_nl,
722
+ # ignore_eos=False,
723
+ )
724
+
725
+ if grammar is not None:
726
+ sampler.add_grammar(self._model, grammar)
727
+
728
+ if temp < 0.0:
729
+ sampler.add_softmax()
730
+ sampler.add_dist(self._seed)
731
+ elif temp == 0.0:
732
+ sampler.add_greedy()
733
+ else:
734
+ if mirostat_mode == 1:
735
+ mirostat_m = 100
736
+ sampler.add_mirostat(
737
+ self._n_vocab,
738
+ self._seed,
739
+ mirostat_tau,
740
+ mirostat_eta,
741
+ mirostat_m,
742
+ )
743
+ elif mirostat_mode == 2:
744
+ sampler.add_mirostat_v2(
745
+ self._seed,
746
+ mirostat_tau,
747
+ mirostat_eta,
748
+ )
749
+ else:
750
+ n_probs = 0
751
+ min_keep = max(1, n_probs)
752
+ sampler.add_top_k(top_k)
753
+ sampler.add_typical(typical_p, min_keep)
754
+ sampler.add_top_p(top_p, min_keep)
755
+ sampler.add_min_p(min_p, min_keep)
756
+ sampler.add_temp(temp)
757
+ sampler.add_dist(self._seed)
758
+ return sampler
759
+
760
+ def sample(
761
+ self,
762
+ top_k: int = 40,
763
+ top_p: float = 0.95,
764
+ min_p: float = 0.05,
765
+ typical_p: float = 1.0,
766
+ temp: float = 0.80,
767
+ repeat_penalty: float = 1.0,
768
+ frequency_penalty: float = 0.0,
769
+ presence_penalty: float = 0.0,
770
+ tfs_z: float = 1.0,
771
+ mirostat_mode: int = 0,
772
+ mirostat_eta: float = 0.1,
773
+ mirostat_tau: float = 5.0,
774
+ penalize_nl: bool = True,
775
+ logits_processor: Optional[LogitsProcessorList] = None,
776
+ grammar: Optional[LlamaGrammar] = None,
777
+ idx: Optional[int] = None,
778
+ ):
779
+ """Sample a token from the model.
780
+
781
+ Args:
782
+ top_k: The top-k sampling parameter.
783
+ top_p: The top-p sampling parameter.
784
+ temp: The temperature parameter.
785
+ repeat_penalty: The repeat penalty parameter.
786
+
787
+ Returns:
788
+ The sampled token.
789
+ """
790
+ assert self.n_tokens > 0
791
+
792
+ tmp_sampler = False
793
+
794
+ if self._sampler is None:
795
+ tmp_sampler = True
796
+ self._sampler = self._init_sampler(
797
+ top_k=top_k,
798
+ top_p=top_p,
799
+ min_p=min_p,
800
+ typical_p=typical_p,
801
+ temp=temp,
802
+ repeat_penalty=repeat_penalty,
803
+ frequency_penalty=frequency_penalty,
804
+ presence_penalty=presence_penalty,
805
+ tfs_z=tfs_z,
806
+ mirostat_mode=mirostat_mode,
807
+ mirostat_tau=mirostat_tau,
808
+ mirostat_eta=mirostat_eta,
809
+ penalize_nl=penalize_nl,
810
+ logits_processor=logits_processor,
811
+ grammar=grammar,
812
+ )
813
+
814
+ ridx = idx - self.n_tokens if idx is not None else -1
815
+
816
+ assert self.ctx is not None
817
+ token = self._sampler.sample(self._ctx, ridx)
818
+ if tmp_sampler:
819
+ self._sampler = None
820
+ return token
821
+
822
+ def generate(
823
+ self,
824
+ tokens: Sequence[int],
825
+ top_k: int = 40,
826
+ top_p: float = 0.95,
827
+ min_p: float = 0.05,
828
+ typical_p: float = 1.0,
829
+ temp: float = 0.80,
830
+ repeat_penalty: float = 1.0,
831
+ reset: bool = True,
832
+ frequency_penalty: float = 0.0,
833
+ presence_penalty: float = 0.0,
834
+ tfs_z: float = 1.0,
835
+ mirostat_mode: int = 0,
836
+ mirostat_tau: float = 5.0,
837
+ mirostat_eta: float = 0.1,
838
+ penalize_nl: bool = True,
839
+ logits_processor: Optional[LogitsProcessorList] = None,
840
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
841
+ grammar: Optional[LlamaGrammar] = None,
842
+ ) -> Generator[int, Optional[Sequence[int]], None]:
843
+ """Create a generator of tokens from a prompt.
844
+
845
+ Examples:
846
+ >>> llama = Llama("models/ggml-7b.bin")
847
+ >>> tokens = llama.tokenize(b"Hello, world!")
848
+ >>> for token in llama.generate(tokens, top_k=40, top_p=0.95, temp=1.0, repeat_penalty=1.0):
849
+ ... print(llama.detokenize([token]))
850
+
851
+ Args:
852
+ tokens: The prompt tokens.
853
+ top_k: The top-k sampling parameter.
854
+ top_p: The top-p sampling parameter.
855
+ temp: The temperature parameter.
856
+ repeat_penalty: The repeat penalty parameter.
857
+ reset: Whether to reset the model state.
858
+
859
+ Yields:
860
+ The generated tokens.
861
+ """
862
+ # Reset mirostat sampling
863
+ self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau)
864
+ self._sampler = self._init_sampler(
865
+ top_k=top_k,
866
+ top_p=top_p,
867
+ min_p=min_p,
868
+ typical_p=typical_p,
869
+ temp=temp,
870
+ repeat_penalty=repeat_penalty,
871
+ frequency_penalty=frequency_penalty,
872
+ presence_penalty=presence_penalty,
873
+ tfs_z=tfs_z,
874
+ mirostat_mode=mirostat_mode,
875
+ mirostat_tau=mirostat_tau,
876
+ mirostat_eta=mirostat_eta,
877
+ penalize_nl=penalize_nl,
878
+ logits_processor=logits_processor,
879
+ grammar=grammar,
880
+ )
881
+
882
+ # Check for kv cache prefix match
883
+ if reset and self.n_tokens > 0:
884
+ longest_prefix = 0
885
+ for a, b in zip(self._input_ids, tokens[:-1]):
886
+ if a == b:
887
+ longest_prefix += 1
888
+ else:
889
+ break
890
+ if longest_prefix > 0:
891
+ reset = False
892
+ tokens = tokens[longest_prefix:]
893
+ self.n_tokens = longest_prefix
894
+ if self.verbose:
895
+ print(
896
+ f"Llama.generate: {longest_prefix} prefix-match hit, "
897
+ f"remaining {len(tokens)} prompt tokens to eval",
898
+ file=sys.stderr,
899
+ )
900
+
901
+ # Reset the model state
902
+ if reset:
903
+ self.reset()
904
+
905
+ # # Reset the grammar
906
+ # if grammar is not None:
907
+ # grammar.reset()
908
+
909
+ sample_idx = self.n_tokens + len(tokens) - 1
910
+ tokens = list(tokens)
911
+
912
+ # Eval and sample
913
+ while True:
914
+ self.eval(tokens)
915
+ while sample_idx < self.n_tokens:
916
+ token = self.sample(
917
+ top_k=top_k,
918
+ top_p=top_p,
919
+ min_p=min_p,
920
+ typical_p=typical_p,
921
+ temp=temp,
922
+ repeat_penalty=repeat_penalty,
923
+ frequency_penalty=frequency_penalty,
924
+ presence_penalty=presence_penalty,
925
+ tfs_z=tfs_z,
926
+ mirostat_mode=mirostat_mode,
927
+ mirostat_tau=mirostat_tau,
928
+ mirostat_eta=mirostat_eta,
929
+ logits_processor=logits_processor,
930
+ grammar=grammar,
931
+ penalize_nl=penalize_nl,
932
+ idx=sample_idx,
933
+ )
934
+
935
+ sample_idx += 1
936
+ if stopping_criteria is not None and stopping_criteria(
937
+ self._input_ids[: sample_idx], self._scores[sample_idx - self.n_tokens, :]
938
+ ):
939
+ return
940
+ tokens_or_none = yield token
941
+ tokens.clear()
942
+ tokens.append(token)
943
+ if tokens_or_none is not None:
944
+ tokens.extend(tokens_or_none)
945
+
946
+ if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]:
947
+ self.n_tokens = sample_idx
948
+ self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
949
+ break
950
+
951
+ if self.draft_model is not None:
952
+ self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
953
+ draft_tokens = self.draft_model(
954
+ self.input_ids[: self.n_tokens + len(tokens)]
955
+ )
956
+ tokens.extend(
957
+ draft_tokens.astype(int)[
958
+ : self._n_ctx - self.n_tokens - len(tokens)
959
+ ]
960
+ )
961
+
962
+ def create_embedding(
963
+ self, input: Union[str, List[str]], model: Optional[str] = None
964
+ ) -> CreateEmbeddingResponse:
965
+ """Embed a string.
966
+
967
+ Args:
968
+ input: The utf-8 encoded string to embed.
969
+
970
+ Returns:
971
+ An embedding object.
972
+ """
973
+ model_name: str = model if model is not None else self.model_path
974
+
975
+ input = input if isinstance(input, list) else [input]
976
+
977
+ # get numeric embeddings
978
+ embeds: Union[List[List[float]], List[List[List[float]]]]
979
+ total_tokens: int
980
+ embeds, total_tokens = self.embed(input, return_count=True) # type: ignore
981
+
982
+ # convert to CreateEmbeddingResponse
983
+ data: List[Embedding] = [
984
+ {
985
+ "object": "embedding",
986
+ "embedding": emb,
987
+ "index": idx,
988
+ }
989
+ for idx, emb in enumerate(embeds)
990
+ ]
991
+
992
+ return {
993
+ "object": "list",
994
+ "data": data,
995
+ "model": model_name,
996
+ "usage": {
997
+ "prompt_tokens": total_tokens,
998
+ "total_tokens": total_tokens,
999
+ },
1000
+ }
1001
+
1002
+ def embed(
1003
+ self,
1004
+ input: Union[str, List[str]],
1005
+ normalize: bool = False,
1006
+ truncate: bool = True,
1007
+ return_count: bool = False,
1008
+ ):
1009
+ """Embed a string.
1010
+
1011
+ Args:
1012
+ input: The utf-8 encoded string to embed.
1013
+
1014
+ Returns:
1015
+ A list of embeddings
1016
+ """
1017
+ n_embd = self.n_embd()
1018
+ n_batch = self.n_batch
1019
+
1020
+ # get pooling information
1021
+ pooling_type = self.pooling_type()
1022
+ logits_all = pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE
1023
+
1024
+ if self.context_params.embeddings is False:
1025
+ raise RuntimeError(
1026
+ "Llama model must be created with embedding=True to call this method"
1027
+ )
1028
+
1029
+ if self.verbose:
1030
+ llama_cpp.llama_perf_context_reset(self._ctx.ctx)
1031
+
1032
+ if isinstance(input, str):
1033
+ inputs = [input]
1034
+ else:
1035
+ inputs = input
1036
+
1037
+ # reset batch
1038
+ self._batch.reset()
1039
+
1040
+ # decode and fetch embeddings
1041
+ data: Union[List[List[float]], List[List[List[float]]]] = []
1042
+
1043
+ def decode_batch(seq_sizes: List[int]):
1044
+ llama_cpp.llama_kv_self_clear(self._ctx.ctx)
1045
+ self._ctx.decode(self._batch)
1046
+ self._batch.reset()
1047
+
1048
+ # store embeddings
1049
+ if pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE:
1050
+ pos: int = 0
1051
+ for i, size in enumerate(seq_sizes):
1052
+ ptr = llama_cpp.llama_get_embeddings(self._ctx.ctx)
1053
+ embedding: List[List[float]] = [
1054
+ ptr[pos + j * n_embd : pos + (j + 1) * n_embd]
1055
+ for j in range(size)
1056
+ ]
1057
+ if normalize:
1058
+ embedding = [
1059
+ internals.normalize_embedding(e) for e in embedding
1060
+ ]
1061
+ data.append(embedding)
1062
+ pos += size
1063
+ else:
1064
+ for i in range(len(seq_sizes)):
1065
+ ptr = llama_cpp.llama_get_embeddings_seq(self._ctx.ctx, i)
1066
+ embedding: List[float] = ptr[:n_embd]
1067
+ if normalize:
1068
+ embedding = internals.normalize_embedding(embedding)
1069
+ data.append(embedding)
1070
+
1071
+ # init state
1072
+ total_tokens = 0
1073
+ s_batch = []
1074
+ t_batch = 0
1075
+ p_batch = 0
1076
+
1077
+ # accumulate batches and encode
1078
+ for text in inputs:
1079
+ tokens = self.tokenize(text.encode("utf-8"))
1080
+ if truncate:
1081
+ tokens = tokens[:n_batch]
1082
+
1083
+ n_tokens = len(tokens)
1084
+ total_tokens += n_tokens
1085
+
1086
+ # check for overrun
1087
+ if n_tokens > n_batch:
1088
+ raise ValueError(
1089
+ f"Requested tokens ({n_tokens}) exceed batch size of {n_batch}"
1090
+ )
1091
+
1092
+ # time to eval batch
1093
+ if t_batch + n_tokens > n_batch:
1094
+ decode_batch(s_batch)
1095
+ s_batch = []
1096
+ t_batch = 0
1097
+ p_batch = 0
1098
+
1099
+ # add to batch
1100
+ self._batch.add_sequence(tokens, p_batch, logits_all)
1101
+
1102
+ # update batch stats
1103
+ s_batch.append(n_tokens)
1104
+ t_batch += n_tokens
1105
+ p_batch += 1
1106
+
1107
+ # hanlde last batch
1108
+ decode_batch(s_batch)
1109
+
1110
+ if self.verbose:
1111
+ llama_cpp.llama_perf_context_print(self._ctx.ctx)
1112
+
1113
+ output = data[0] if isinstance(input, str) else data
1114
+
1115
+ llama_cpp.llama_kv_self_clear(self._ctx.ctx)
1116
+ self.reset()
1117
+
1118
+ if return_count:
1119
+ return output, total_tokens
1120
+ else:
1121
+ return output
1122
+
1123
+ def _create_completion(
1124
+ self,
1125
+ prompt: Union[str, List[int]],
1126
+ suffix: Optional[str] = None,
1127
+ max_tokens: Optional[int] = 16,
1128
+ temperature: float = 0.8,
1129
+ top_p: float = 0.95,
1130
+ min_p: float = 0.05,
1131
+ typical_p: float = 1.0,
1132
+ logprobs: Optional[int] = None,
1133
+ echo: bool = False,
1134
+ stop: Optional[Union[str, List[str]]] = [],
1135
+ frequency_penalty: float = 0.0,
1136
+ presence_penalty: float = 0.0,
1137
+ repeat_penalty: float = 1.0,
1138
+ top_k: int = 40,
1139
+ stream: bool = False,
1140
+ seed: Optional[int] = None,
1141
+ tfs_z: float = 1.0,
1142
+ mirostat_mode: int = 0,
1143
+ mirostat_tau: float = 5.0,
1144
+ mirostat_eta: float = 0.1,
1145
+ model: Optional[str] = None,
1146
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
1147
+ logits_processor: Optional[LogitsProcessorList] = None,
1148
+ grammar: Optional[LlamaGrammar] = None,
1149
+ logit_bias: Optional[Dict[int, float]] = None,
1150
+ ) -> Union[
1151
+ Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse]
1152
+ ]:
1153
+ assert suffix is None or suffix.__class__ is str
1154
+
1155
+ completion_id: str = f"cmpl-{str(uuid.uuid4())}"
1156
+ created: int = int(time.time())
1157
+ bos_token_id: int = self.token_bos()
1158
+ cls_token_id: int = self._model.token_cls()
1159
+ sep_token_id: int = self._model.token_sep()
1160
+ prefix_token_id: int = 0 # self._model.token_prefix() # TODO: Fix
1161
+ middle_token_id: int = 0 # self._model.token_middle() # TODO: Fix
1162
+ suffix_token_id: int = 0 # self._model.token_suffix() # TODO: Fix
1163
+ add_space_prefix: bool = (
1164
+ self.metadata.get("tokenizer.ggml.add_space_prefix", "true") == "true"
1165
+ )
1166
+ bos_tokens: List[int] = [cls_token_id if cls_token_id != -1 else bos_token_id]
1167
+ eos_tokens: List[int] = [
1168
+ sep_token_id if sep_token_id != -1 else self.token_eos()
1169
+ ]
1170
+
1171
+ if (
1172
+ (isinstance(prompt, list) and suffix is None)
1173
+ or not self._model.add_bos_token()
1174
+ or bos_tokens[:1] == [-1]
1175
+ ):
1176
+ bos_tokens = []
1177
+
1178
+ if (isinstance(prompt, list) and suffix is None) or (
1179
+ not self._model.add_eos_token() and sep_token_id == -1
1180
+ ):
1181
+ eos_tokens = []
1182
+
1183
+ suffix_space_prefix: int = 0
1184
+ # Tokenizer hack to remove leading space
1185
+ if add_space_prefix and suffix_token_id >= 0 and suffix:
1186
+ suffix = "☺" + suffix
1187
+ suffix_space_prefix = 2
1188
+
1189
+ # If prompt is empty, initialize completion with BOS token to avoid
1190
+ # detokenization including a space at the beginning of the completion
1191
+ completion_tokens: List[int] = [] if len(prompt) > 0 else [bos_token_id]
1192
+ # Add blank space to start of prompt to match OG llama tokenizer
1193
+ prefix_tokens: List[int] = (
1194
+ [prefix_token_id] if prefix_token_id >= 0 and suffix is not None else []
1195
+ ) + (
1196
+ (
1197
+ self.tokenize(
1198
+ prompt.encode("utf-8"),
1199
+ add_bos=False,
1200
+ special=(prefix_token_id < 0 or suffix is None),
1201
+ )
1202
+ if prompt != ""
1203
+ else []
1204
+ )
1205
+ if isinstance(prompt, str)
1206
+ else prompt
1207
+ )
1208
+ suffix_tokens: List[int] = (
1209
+ (
1210
+ [suffix_token_id]
1211
+ + (
1212
+ self.tokenize(suffix.encode("utf-8"), add_bos=False, special=False)[
1213
+ suffix_space_prefix:
1214
+ ]
1215
+ if suffix
1216
+ else []
1217
+ )
1218
+ )
1219
+ if suffix_token_id >= 0 and suffix is not None
1220
+ else []
1221
+ )
1222
+ middle_tokens: List[int] = (
1223
+ [middle_token_id] if middle_token_id >= 0 and suffix is not None else []
1224
+ )
1225
+ prompt_tokens: List[int] = (
1226
+ bos_tokens
1227
+ + (
1228
+ (suffix_tokens + prefix_tokens + middle_tokens)
1229
+ if self.spm_infill
1230
+ else (prefix_tokens + suffix_tokens + middle_tokens)
1231
+ )
1232
+ + eos_tokens
1233
+ )
1234
+ text: bytes = b""
1235
+ returned_tokens: int = 0
1236
+ stop = (
1237
+ stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
1238
+ )
1239
+ model_name: str = model if model is not None else self.model_path
1240
+
1241
+ if prompt_tokens[:2] == [self.token_bos()] * 2:
1242
+ warnings.warn(
1243
+ f'Detected duplicate leading "{self._model.token_get_text(self.token_bos())}" in prompt, this will likely reduce response quality, consider removing it...',
1244
+ RuntimeWarning,
1245
+ )
1246
+
1247
+ # NOTE: This likely doesn't work correctly for the first token in the prompt
1248
+ # because of the extra space added to the start of the prompt_tokens
1249
+ if logit_bias is not None:
1250
+ logit_bias_map = {int(k): float(v) for k, v in logit_bias.items()}
1251
+
1252
+ def logit_bias_processor(
1253
+ input_ids: npt.NDArray[np.intc],
1254
+ scores: npt.NDArray[np.single],
1255
+ ) -> npt.NDArray[np.single]:
1256
+ new_scores = np.copy(
1257
+ scores
1258
+ ) # Does it make sense to copy the whole array or can we just overwrite the original one?
1259
+ for input_id, score in logit_bias_map.items():
1260
+ new_scores[input_id] = score + scores[input_id]
1261
+ return new_scores
1262
+
1263
+ _logit_bias_processor = LogitsProcessorList([logit_bias_processor])
1264
+ if logits_processor is None:
1265
+ logits_processor = _logit_bias_processor
1266
+ else:
1267
+ logits_processor = logits_processor.extend(_logit_bias_processor)
1268
+
1269
+ if self.verbose:
1270
+ self._ctx.reset_timings()
1271
+
1272
+ if len(prompt_tokens) >= self._n_ctx:
1273
+ raise ValueError(
1274
+ f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
1275
+ )
1276
+
1277
+ if max_tokens is None or max_tokens <= 0:
1278
+ # Unlimited, depending on n_ctx.
1279
+ max_tokens = self._n_ctx - len(prompt_tokens)
1280
+
1281
+ # Truncate max_tokens if requested tokens would exceed the context window
1282
+ max_tokens = (
1283
+ max_tokens
1284
+ if max_tokens + len(prompt_tokens) < self._n_ctx
1285
+ else (self._n_ctx - len(prompt_tokens))
1286
+ )
1287
+
1288
+ if stop != []:
1289
+ stop_sequences = [s.encode("utf-8") for s in stop]
1290
+ else:
1291
+ stop_sequences = []
1292
+
1293
+ if logprobs is not None and self._logits_all is False:
1294
+ raise ValueError(
1295
+ "logprobs is not supported for models created with logits_all=False"
1296
+ )
1297
+
1298
+ if self.cache:
1299
+ try:
1300
+ cache_item = self.cache[prompt_tokens]
1301
+ cache_prefix_len = Llama.longest_token_prefix(
1302
+ cache_item.input_ids.tolist(), prompt_tokens
1303
+ )
1304
+ eval_prefix_len = Llama.longest_token_prefix(
1305
+ self._input_ids.tolist(), prompt_tokens
1306
+ )
1307
+ if cache_prefix_len > eval_prefix_len:
1308
+ self.load_state(cache_item)
1309
+ if self.verbose:
1310
+ print("Llama._create_completion: cache hit", file=sys.stderr)
1311
+ except KeyError:
1312
+ if self.verbose:
1313
+ print("Llama._create_completion: cache miss", file=sys.stderr)
1314
+
1315
+ if seed is not None:
1316
+ self.set_seed(seed)
1317
+ else:
1318
+ self.set_seed(random.Random(self._seed).randint(0, 2 ** 32))
1319
+
1320
+ finish_reason = "length"
1321
+ multibyte_fix = 0
1322
+ for token in self.generate(
1323
+ prompt_tokens,
1324
+ top_k=top_k,
1325
+ top_p=top_p,
1326
+ min_p=min_p,
1327
+ typical_p=typical_p,
1328
+ temp=temperature,
1329
+ tfs_z=tfs_z,
1330
+ mirostat_mode=mirostat_mode,
1331
+ mirostat_tau=mirostat_tau,
1332
+ mirostat_eta=mirostat_eta,
1333
+ frequency_penalty=frequency_penalty,
1334
+ presence_penalty=presence_penalty,
1335
+ repeat_penalty=repeat_penalty,
1336
+ stopping_criteria=stopping_criteria,
1337
+ logits_processor=logits_processor,
1338
+ grammar=grammar,
1339
+ ):
1340
+ if llama_cpp.llama_token_is_eog(self._model.vocab, token):
1341
+ text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
1342
+ finish_reason = "stop"
1343
+ break
1344
+
1345
+ completion_tokens.append(token)
1346
+
1347
+ all_text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
1348
+
1349
+ # Contains multi-byte UTF8
1350
+ for k, char in enumerate(all_text[-3:]):
1351
+ k = 3 - k
1352
+ for num, pattern in [(2, 192), (3, 224), (4, 240)]:
1353
+ # Bitwise AND check
1354
+ if num > k and pattern & char == pattern:
1355
+ multibyte_fix = num - k
1356
+
1357
+ # Stop incomplete bytes from passing
1358
+ if multibyte_fix > 0:
1359
+ multibyte_fix -= 1
1360
+ continue
1361
+
1362
+ any_stop = [s for s in stop_sequences if s in all_text]
1363
+ if len(any_stop) > 0:
1364
+ first_stop = any_stop[0]
1365
+ text = all_text[: all_text.index(first_stop)]
1366
+ finish_reason = "stop"
1367
+ break
1368
+
1369
+ if stream:
1370
+ remaining_tokens = completion_tokens[returned_tokens:]
1371
+ remaining_text = self.detokenize(
1372
+ remaining_tokens,
1373
+ prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
1374
+ )
1375
+ remaining_length = len(remaining_text)
1376
+
1377
+ # We want to avoid yielding any characters from
1378
+ # the generated text if they are part of a stop
1379
+ # sequence.
1380
+ first_stop_position = 0
1381
+ for s in stop_sequences:
1382
+ for i in range(min(len(s), remaining_length), 0, -1):
1383
+ if remaining_text.endswith(s[:i]):
1384
+ if i > first_stop_position:
1385
+ first_stop_position = i
1386
+ break
1387
+
1388
+ token_end_position = 0
1389
+
1390
+ if logprobs is not None:
1391
+ # not sure how to handle this branch when dealing
1392
+ # with CJK output, so keep it unchanged
1393
+ for token in remaining_tokens:
1394
+ if token == bos_token_id:
1395
+ continue
1396
+ token_end_position += len(
1397
+ self.detokenize(
1398
+ [token],
1399
+ prev_tokens=prompt_tokens
1400
+ + completion_tokens[:returned_tokens],
1401
+ )
1402
+ )
1403
+ # Check if stop sequence is in the token
1404
+ if token_end_position > (
1405
+ remaining_length - first_stop_position
1406
+ ):
1407
+ break
1408
+ token_str = self.detokenize(
1409
+ [token],
1410
+ prev_tokens=prompt_tokens
1411
+ + completion_tokens[:returned_tokens],
1412
+ ).decode("utf-8", errors="ignore")
1413
+ text_offset = len(prompt) + len(
1414
+ self.detokenize(
1415
+ completion_tokens[:returned_tokens],
1416
+ prev_tokens=prompt_tokens
1417
+ + completion_tokens[:returned_tokens],
1418
+ ).decode("utf-8", errors="ignore")
1419
+ )
1420
+ token_offset = len(prompt_tokens) + returned_tokens
1421
+ logits = self._scores[token_offset - 1, :]
1422
+ current_logprobs = Llama.logits_to_logprobs(logits).tolist()
1423
+ sorted_logprobs = list(
1424
+ sorted(
1425
+ zip(current_logprobs, range(len(current_logprobs))),
1426
+ reverse=True,
1427
+ )
1428
+ )
1429
+ top_logprob = {
1430
+ self.detokenize([i]).decode(
1431
+ "utf-8", errors="ignore"
1432
+ ): logprob
1433
+ for logprob, i in sorted_logprobs[:logprobs]
1434
+ }
1435
+ top_logprob.update({token_str: current_logprobs[int(token)]})
1436
+ logprobs_or_none = {
1437
+ "tokens": [
1438
+ self.detokenize(
1439
+ [token],
1440
+ prev_tokens=prompt_tokens
1441
+ + completion_tokens[:returned_tokens],
1442
+ ).decode("utf-8", errors="ignore")
1443
+ ],
1444
+ "text_offset": [text_offset],
1445
+ "token_logprobs": [current_logprobs[int(token)]],
1446
+ "top_logprobs": [top_logprob],
1447
+ }
1448
+ returned_tokens += 1
1449
+ yield {
1450
+ "id": completion_id,
1451
+ "object": "text_completion",
1452
+ "created": created,
1453
+ "model": model_name,
1454
+ "choices": [
1455
+ {
1456
+ "text": self.detokenize(
1457
+ [token],
1458
+ prev_tokens=prompt_tokens
1459
+ + completion_tokens[:returned_tokens],
1460
+ ).decode("utf-8", errors="ignore"),
1461
+ "index": 0,
1462
+ "logprobs": logprobs_or_none,
1463
+ "finish_reason": None,
1464
+ }
1465
+ ],
1466
+ }
1467
+ else:
1468
+ while len(remaining_tokens) > 0:
1469
+ decode_success = False
1470
+ for i in range(1, len(remaining_tokens) + 1):
1471
+ try:
1472
+ bs = self.detokenize(
1473
+ remaining_tokens[:i],
1474
+ prev_tokens=prompt_tokens
1475
+ + completion_tokens[:returned_tokens],
1476
+ )
1477
+ ts = bs.decode("utf-8")
1478
+ decode_success = True
1479
+ break
1480
+ except UnicodeError:
1481
+ pass
1482
+ else:
1483
+ break
1484
+ if not decode_success:
1485
+ # all remaining tokens cannot be decoded to a UTF-8 character
1486
+ break
1487
+ token_end_position += len(bs)
1488
+ if token_end_position > (
1489
+ remaining_length - first_stop_position
1490
+ ):
1491
+ break
1492
+ remaining_tokens = remaining_tokens[i:]
1493
+ returned_tokens += i
1494
+
1495
+ yield {
1496
+ "id": completion_id,
1497
+ "object": "text_completion",
1498
+ "created": created,
1499
+ "model": model_name,
1500
+ "choices": [
1501
+ {
1502
+ "text": ts,
1503
+ "index": 0,
1504
+ "logprobs": None,
1505
+ "finish_reason": None,
1506
+ }
1507
+ ],
1508
+ }
1509
+
1510
+ if len(completion_tokens) >= max_tokens:
1511
+ text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
1512
+ finish_reason = "length"
1513
+ break
1514
+
1515
+ if stopping_criteria is not None and stopping_criteria(
1516
+ self._input_ids, self._scores[-1, :]
1517
+ ):
1518
+ text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
1519
+ finish_reason = "stop"
1520
+
1521
+ if self.verbose:
1522
+ self._ctx.print_timings()
1523
+
1524
+ if stream:
1525
+ remaining_tokens = completion_tokens[returned_tokens:]
1526
+ remaining_text = self.detokenize(
1527
+ remaining_tokens,
1528
+ prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
1529
+ )
1530
+ any_stop = [s for s in stop_sequences if s in remaining_text]
1531
+ if len(any_stop) > 0:
1532
+ end = min(remaining_text.index(stop) for stop in any_stop)
1533
+ else:
1534
+ end = len(remaining_text)
1535
+
1536
+ token_end_position = 0
1537
+ for token in remaining_tokens:
1538
+ token_end_position += len(
1539
+ self.detokenize(
1540
+ [token],
1541
+ prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
1542
+ )
1543
+ )
1544
+
1545
+ logprobs_or_none: Optional[CompletionLogprobs] = None
1546
+ if logprobs is not None:
1547
+ if token == bos_token_id:
1548
+ continue
1549
+ token_str = self.detokenize([token]).decode(
1550
+ "utf-8", errors="ignore"
1551
+ )
1552
+ text_offset = len(prompt) + len(
1553
+ self.detokenize(
1554
+ completion_tokens[:returned_tokens],
1555
+ prev_tokens=prompt_tokens
1556
+ + completion_tokens[:returned_tokens],
1557
+ )
1558
+ )
1559
+ token_offset = len(prompt_tokens) + returned_tokens - 1
1560
+ logits = self._scores[token_offset, :]
1561
+ current_logprobs = Llama.logits_to_logprobs(logits).tolist()
1562
+ sorted_logprobs = list(
1563
+ sorted(
1564
+ zip(current_logprobs, range(len(current_logprobs))),
1565
+ reverse=True,
1566
+ )
1567
+ )
1568
+ top_logprob = {
1569
+ self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
1570
+ for logprob, i in sorted_logprobs[:logprobs]
1571
+ }
1572
+ top_logprob.update({token_str: current_logprobs[int(token)]})
1573
+ logprobs_or_none = {
1574
+ "tokens": [
1575
+ self.detokenize([token]).decode("utf-8", errors="ignore")
1576
+ ],
1577
+ "text_offset": [text_offset],
1578
+ "token_logprobs": [current_logprobs[int(token)]],
1579
+ "top_logprobs": [top_logprob],
1580
+ }
1581
+
1582
+ if token_end_position >= end:
1583
+ last_text = self.detokenize([token])
1584
+ if token_end_position == end - 1:
1585
+ break
1586
+ returned_tokens += 1
1587
+ yield {
1588
+ "id": completion_id,
1589
+ "object": "text_completion",
1590
+ "created": created,
1591
+ "model": model_name,
1592
+ "choices": [
1593
+ {
1594
+ "text": last_text[
1595
+ : len(last_text) - (token_end_position - end)
1596
+ ].decode("utf-8", errors="ignore"),
1597
+ "index": 0,
1598
+ "logprobs": logprobs_or_none,
1599
+ "finish_reason": None,
1600
+ }
1601
+ ],
1602
+ }
1603
+ break
1604
+ returned_tokens += 1
1605
+ yield {
1606
+ "id": completion_id,
1607
+ "object": "text_completion",
1608
+ "created": created,
1609
+ "model": model_name,
1610
+ "choices": [
1611
+ {
1612
+ "text": self.detokenize([token]).decode(
1613
+ "utf-8", errors="ignore"
1614
+ ),
1615
+ "index": 0,
1616
+ "logprobs": logprobs_or_none,
1617
+ "finish_reason": None,
1618
+ }
1619
+ ],
1620
+ }
1621
+ yield {
1622
+ "id": completion_id,
1623
+ "object": "text_completion",
1624
+ "created": created,
1625
+ "model": model_name,
1626
+ "choices": [
1627
+ {
1628
+ "text": "",
1629
+ "index": 0,
1630
+ "logprobs": None,
1631
+ "finish_reason": finish_reason,
1632
+ }
1633
+ ],
1634
+ }
1635
+ if self.cache:
1636
+ if self.verbose:
1637
+ print("Llama._create_completion: cache save", file=sys.stderr)
1638
+ self.cache[prompt_tokens + completion_tokens] = self.save_state()
1639
+ if self.verbose:
1640
+ print("Llama._create_completion: cache saved", file=sys.stderr)
1641
+ return
1642
+
1643
+ if self.cache:
1644
+ if self.verbose:
1645
+ print("Llama._create_completion: cache save", file=sys.stderr)
1646
+ self.cache[prompt_tokens + completion_tokens] = self.save_state()
1647
+
1648
+ text_str = text.decode("utf-8", errors="ignore")
1649
+
1650
+ if echo:
1651
+ text_str = prompt + text_str
1652
+
1653
+ if suffix_token_id < 0 and suffix is not None:
1654
+ text_str = text_str + suffix
1655
+
1656
+ logprobs_or_none: Optional[CompletionLogprobs] = None
1657
+ if logprobs is not None:
1658
+ text_offset = 0 if echo else len(prompt)
1659
+ token_offset = 0 if echo else len(prompt_tokens[1:])
1660
+ text_offsets: List[int] = []
1661
+ token_logprobs: List[Optional[float]] = []
1662
+ tokens: List[str] = []
1663
+ top_logprobs: List[Optional[Dict[str, float]]] = []
1664
+
1665
+ if echo:
1666
+ # Remove leading BOS token if exists
1667
+ all_tokens = (
1668
+ prompt_tokens[1 if prompt_tokens[0] == self.token_bos() else 0 :]
1669
+ + completion_tokens
1670
+ )
1671
+ else:
1672
+ all_tokens = completion_tokens
1673
+
1674
+ all_token_strs = [
1675
+ self.detokenize([token], prev_tokens=all_tokens[:i]).decode(
1676
+ "utf-8", errors="ignore"
1677
+ )
1678
+ for i, token in enumerate(all_tokens)
1679
+ ]
1680
+ all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
1681
+ # TODO: may be able to change this loop to use np.take_along_dim
1682
+ for idx, (token, token_str, logprobs_token) in enumerate(
1683
+ zip(all_tokens, all_token_strs, all_logprobs)
1684
+ ):
1685
+ if token == bos_token_id:
1686
+ continue
1687
+ text_offsets.append(
1688
+ text_offset
1689
+ + len(
1690
+ self.detokenize(all_tokens[:idx]).decode(
1691
+ "utf-8", errors="ignore"
1692
+ )
1693
+ )
1694
+ )
1695
+ tokens.append(token_str)
1696
+ sorted_logprobs = list(
1697
+ sorted(
1698
+ zip(logprobs_token, range(len(logprobs_token))), reverse=True
1699
+ )
1700
+ )
1701
+ token_logprobs.append(logprobs_token[int(token)])
1702
+ top_logprob: Optional[Dict[str, float]] = {
1703
+ self.detokenize([i], prev_tokens=all_tokens[:idx]).decode(
1704
+ "utf-8", errors="ignore"
1705
+ ): logprob
1706
+ for logprob, i in sorted_logprobs[:logprobs]
1707
+ }
1708
+ top_logprob.update({token_str: logprobs_token[int(token)]})
1709
+ top_logprobs.append(top_logprob)
1710
+ # Weird idosincracy of the OpenAI API where
1711
+ # token_logprobs and top_logprobs are null for
1712
+ # the first token.
1713
+ if echo and len(all_tokens) > 0:
1714
+ token_logprobs[0] = None
1715
+ top_logprobs[0] = None
1716
+ logprobs_or_none = {
1717
+ "tokens": tokens,
1718
+ "text_offset": text_offsets,
1719
+ "token_logprobs": token_logprobs,
1720
+ "top_logprobs": top_logprobs,
1721
+ }
1722
+
1723
+ yield {
1724
+ "id": completion_id,
1725
+ "object": "text_completion",
1726
+ "created": created,
1727
+ "model": model_name,
1728
+ "choices": [
1729
+ {
1730
+ "text": text_str,
1731
+ "index": 0,
1732
+ "logprobs": logprobs_or_none,
1733
+ "finish_reason": finish_reason,
1734
+ }
1735
+ ],
1736
+ "usage": {
1737
+ "prompt_tokens": len(prompt_tokens),
1738
+ "completion_tokens": len(completion_tokens),
1739
+ "total_tokens": len(prompt_tokens) + len(completion_tokens),
1740
+ },
1741
+ }
1742
+
1743
+ def create_completion(
1744
+ self,
1745
+ prompt: Union[str, List[int]],
1746
+ suffix: Optional[str] = None,
1747
+ max_tokens: Optional[int] = 16,
1748
+ temperature: float = 0.8,
1749
+ top_p: float = 0.95,
1750
+ min_p: float = 0.05,
1751
+ typical_p: float = 1.0,
1752
+ logprobs: Optional[int] = None,
1753
+ echo: bool = False,
1754
+ stop: Optional[Union[str, List[str]]] = [],
1755
+ frequency_penalty: float = 0.0,
1756
+ presence_penalty: float = 0.0,
1757
+ repeat_penalty: float = 1.0,
1758
+ top_k: int = 40,
1759
+ stream: bool = False,
1760
+ seed: Optional[int] = None,
1761
+ tfs_z: float = 1.0,
1762
+ mirostat_mode: int = 0,
1763
+ mirostat_tau: float = 5.0,
1764
+ mirostat_eta: float = 0.1,
1765
+ model: Optional[str] = None,
1766
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
1767
+ logits_processor: Optional[LogitsProcessorList] = None,
1768
+ grammar: Optional[LlamaGrammar] = None,
1769
+ logit_bias: Optional[Dict[int, float]] = None,
1770
+ ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
1771
+ """Generate text from a prompt.
1772
+
1773
+ Args:
1774
+ prompt: The prompt to generate text from.
1775
+ suffix: A suffix to append to the generated text. If None, no suffix is appended.
1776
+ max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None, the maximum number of tokens to generate is unlimited and depends on n_ctx.
1777
+ temperature: The temperature to use for sampling.
1778
+ top_p: The top-p value to use for nucleus sampling. Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
1779
+ min_p: The min-p value to use for minimum p sampling. Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
1780
+ typical_p: The typical-p value to use for sampling. Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
1781
+ logprobs: The number of logprobs to return. If None, no logprobs are returned.
1782
+ echo: Whether to echo the prompt.
1783
+ stop: A list of strings to stop generation when encountered.
1784
+ frequency_penalty: The penalty to apply to tokens based on their frequency in the prompt.
1785
+ presence_penalty: The penalty to apply to tokens based on their presence in the prompt.
1786
+ repeat_penalty: The penalty to apply to repeated tokens.
1787
+ top_k: The top-k value to use for sampling. Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
1788
+ stream: Whether to stream the results.
1789
+ seed: The seed to use for sampling.
1790
+ tfs_z: The tail-free sampling parameter. Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
1791
+ mirostat_mode: The mirostat sampling mode.
1792
+ mirostat_tau: The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
1793
+ mirostat_eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
1794
+ model: The name to use for the model in the completion object.
1795
+ stopping_criteria: A list of stopping criteria to use.
1796
+ logits_processor: A list of logits processors to use.
1797
+ grammar: A grammar to use for constrained sampling.
1798
+ logit_bias: A logit bias to use.
1799
+
1800
+ Raises:
1801
+ ValueError: If the requested tokens exceed the context window.
1802
+ RuntimeError: If the prompt fails to tokenize or the model fails to evaluate the prompt.
1803
+
1804
+ Returns:
1805
+ Response object containing the generated text.
1806
+ """
1807
+ completion_or_chunks = self._create_completion(
1808
+ prompt=prompt,
1809
+ suffix=suffix,
1810
+ max_tokens=-1 if max_tokens is None else max_tokens,
1811
+ temperature=temperature,
1812
+ top_p=top_p,
1813
+ min_p=min_p,
1814
+ typical_p=typical_p,
1815
+ logprobs=logprobs,
1816
+ echo=echo,
1817
+ stop=stop,
1818
+ frequency_penalty=frequency_penalty,
1819
+ presence_penalty=presence_penalty,
1820
+ repeat_penalty=repeat_penalty,
1821
+ top_k=top_k,
1822
+ stream=stream,
1823
+ seed=seed,
1824
+ tfs_z=tfs_z,
1825
+ mirostat_mode=mirostat_mode,
1826
+ mirostat_tau=mirostat_tau,
1827
+ mirostat_eta=mirostat_eta,
1828
+ model=model,
1829
+ stopping_criteria=stopping_criteria,
1830
+ logits_processor=logits_processor,
1831
+ grammar=grammar,
1832
+ logit_bias=logit_bias,
1833
+ )
1834
+ if stream:
1835
+ chunks: Iterator[CreateCompletionStreamResponse] = completion_or_chunks
1836
+ return chunks
1837
+ completion: Completion = next(completion_or_chunks) # type: ignore
1838
+ return completion
1839
+
1840
+ def __call__(
1841
+ self,
1842
+ prompt: str,
1843
+ suffix: Optional[str] = None,
1844
+ max_tokens: Optional[int] = 16,
1845
+ temperature: float = 0.8,
1846
+ top_p: float = 0.95,
1847
+ min_p: float = 0.05,
1848
+ typical_p: float = 1.0,
1849
+ logprobs: Optional[int] = None,
1850
+ echo: bool = False,
1851
+ stop: Optional[Union[str, List[str]]] = [],
1852
+ frequency_penalty: float = 0.0,
1853
+ presence_penalty: float = 0.0,
1854
+ repeat_penalty: float = 1.0,
1855
+ top_k: int = 40,
1856
+ stream: bool = False,
1857
+ seed: Optional[int] = None,
1858
+ tfs_z: float = 1.0,
1859
+ mirostat_mode: int = 0,
1860
+ mirostat_tau: float = 5.0,
1861
+ mirostat_eta: float = 0.1,
1862
+ model: Optional[str] = None,
1863
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
1864
+ logits_processor: Optional[LogitsProcessorList] = None,
1865
+ grammar: Optional[LlamaGrammar] = None,
1866
+ logit_bias: Optional[Dict[int, float]] = None,
1867
+ ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
1868
+ """Generate text from a prompt.
1869
+
1870
+ Args:
1871
+ prompt: The prompt to generate text from.
1872
+ suffix: A suffix to append to the generated text. If None, no suffix is appended.
1873
+ max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None, the maximum number of tokens to generate is unlimited and depends on n_ctx.
1874
+ temperature: The temperature to use for sampling.
1875
+ top_p: The top-p value to use for nucleus sampling. Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
1876
+ min_p: The min-p value to use for minimum p sampling. Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
1877
+ typical_p: The typical-p value to use for sampling. Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
1878
+ logprobs: The number of logprobs to return. If None, no logprobs are returned.
1879
+ echo: Whether to echo the prompt.
1880
+ stop: A list of strings to stop generation when encountered.
1881
+ frequency_penalty: The penalty to apply to tokens based on their frequency in the prompt.
1882
+ presence_penalty: The penalty to apply to tokens based on their presence in the prompt.
1883
+ repeat_penalty: The penalty to apply to repeated tokens.
1884
+ top_k: The top-k value to use for sampling. Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
1885
+ stream: Whether to stream the results.
1886
+ seed: The seed to use for sampling.
1887
+ tfs_z: The tail-free sampling parameter. Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
1888
+ mirostat_mode: The mirostat sampling mode.
1889
+ mirostat_tau: The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
1890
+ mirostat_eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
1891
+ model: The name to use for the model in the completion object.
1892
+ stopping_criteria: A list of stopping criteria to use.
1893
+ logits_processor: A list of logits processors to use.
1894
+ grammar: A grammar to use for constrained sampling.
1895
+ logit_bias: A logit bias to use.
1896
+
1897
+ Raises:
1898
+ ValueError: If the requested tokens exceed the context window.
1899
+ RuntimeError: If the prompt fails to tokenize or the model fails to evaluate the prompt.
1900
+
1901
+ Returns:
1902
+ Response object containing the generated text.
1903
+ """
1904
+ return self.create_completion(
1905
+ prompt=prompt,
1906
+ suffix=suffix,
1907
+ max_tokens=max_tokens,
1908
+ temperature=temperature,
1909
+ top_p=top_p,
1910
+ min_p=min_p,
1911
+ typical_p=typical_p,
1912
+ logprobs=logprobs,
1913
+ echo=echo,
1914
+ stop=stop,
1915
+ frequency_penalty=frequency_penalty,
1916
+ presence_penalty=presence_penalty,
1917
+ repeat_penalty=repeat_penalty,
1918
+ top_k=top_k,
1919
+ stream=stream,
1920
+ seed=seed,
1921
+ tfs_z=tfs_z,
1922
+ mirostat_mode=mirostat_mode,
1923
+ mirostat_tau=mirostat_tau,
1924
+ mirostat_eta=mirostat_eta,
1925
+ model=model,
1926
+ stopping_criteria=stopping_criteria,
1927
+ logits_processor=logits_processor,
1928
+ grammar=grammar,
1929
+ logit_bias=logit_bias,
1930
+ )
1931
+
1932
+ def create_chat_completion(
1933
+ self,
1934
+ messages: List[ChatCompletionRequestMessage],
1935
+ functions: Optional[List[ChatCompletionFunction]] = None,
1936
+ function_call: Optional[ChatCompletionRequestFunctionCall] = None,
1937
+ tools: Optional[List[ChatCompletionTool]] = None,
1938
+ tool_choice: Optional[ChatCompletionToolChoiceOption] = None,
1939
+ temperature: float = 0.2,
1940
+ top_p: float = 0.95,
1941
+ top_k: int = 40,
1942
+ min_p: float = 0.05,
1943
+ typical_p: float = 1.0,
1944
+ stream: bool = False,
1945
+ stop: Optional[Union[str, List[str]]] = [],
1946
+ seed: Optional[int] = None,
1947
+ response_format: Optional[ChatCompletionRequestResponseFormat] = None,
1948
+ max_tokens: Optional[int] = None,
1949
+ presence_penalty: float = 0.0,
1950
+ frequency_penalty: float = 0.0,
1951
+ repeat_penalty: float = 1.0,
1952
+ tfs_z: float = 1.0,
1953
+ mirostat_mode: int = 0,
1954
+ mirostat_tau: float = 5.0,
1955
+ mirostat_eta: float = 0.1,
1956
+ model: Optional[str] = None,
1957
+ logits_processor: Optional[LogitsProcessorList] = None,
1958
+ grammar: Optional[LlamaGrammar] = None,
1959
+ logit_bias: Optional[Dict[int, float]] = None,
1960
+ logprobs: Optional[bool] = None,
1961
+ top_logprobs: Optional[int] = None,
1962
+ ) -> Union[
1963
+ CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
1964
+ ]:
1965
+ """Generate a chat completion from a list of messages.
1966
+
1967
+ Args:
1968
+ messages: A list of messages to generate a response for.
1969
+ functions: A list of functions to use for the chat completion.
1970
+ function_call: A function call to use for the chat completion.
1971
+ tools: A list of tools to use for the chat completion.
1972
+ tool_choice: A tool choice to use for the chat completion.
1973
+ temperature: The temperature to use for sampling.
1974
+ top_p: The top-p value to use for nucleus sampling. Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
1975
+ top_k: The top-k value to use for sampling. Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
1976
+ min_p: The min-p value to use for minimum p sampling. Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
1977
+ typical_p: The typical-p value to use for sampling. Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
1978
+ stream: Whether to stream the results.
1979
+ stop: A list of strings to stop generation when encountered.
1980
+ seed: The seed to use for sampling.
1981
+ response_format: The response format to use for the chat completion. Use { "type": "json_object" } to contstrain output to only valid json.
1982
+ max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None, the maximum number of tokens to generate is unlimited and depends on n_ctx.
1983
+ presence_penalty: The penalty to apply to tokens based on their presence in the prompt.
1984
+ frequency_penalty: The penalty to apply to tokens based on their frequency in the prompt.
1985
+ repeat_penalty: The penalty to apply to repeated tokens.
1986
+ tfs_z: The tail-free sampling parameter.
1987
+ mirostat_mode: The mirostat sampling mode.
1988
+ mirostat_tau: The mirostat sampling tau parameter.
1989
+ mirostat_eta: The mirostat sampling eta parameter.
1990
+ model: The name to use for the model in the completion object.
1991
+ logits_processor: A list of logits processors to use.
1992
+ grammar: A grammar to use.
1993
+ logit_bias: A logit bias to use.
1994
+
1995
+ Returns:
1996
+ Generated chat completion or a stream of chat completion chunks.
1997
+ """
1998
+ handler = (
1999
+ self.chat_handler
2000
+ or self._chat_handlers.get(self.chat_format)
2001
+ or llama_chat_format.get_chat_completion_handler(self.chat_format)
2002
+ )
2003
+ return handler(
2004
+ llama=self,
2005
+ messages=messages,
2006
+ functions=functions,
2007
+ function_call=function_call,
2008
+ tools=tools,
2009
+ tool_choice=tool_choice,
2010
+ temperature=temperature,
2011
+ top_p=top_p,
2012
+ top_k=top_k,
2013
+ min_p=min_p,
2014
+ typical_p=typical_p,
2015
+ logprobs=logprobs,
2016
+ top_logprobs=top_logprobs,
2017
+ stream=stream,
2018
+ stop=stop,
2019
+ seed=seed,
2020
+ response_format=response_format,
2021
+ max_tokens=max_tokens,
2022
+ presence_penalty=presence_penalty,
2023
+ frequency_penalty=frequency_penalty,
2024
+ repeat_penalty=repeat_penalty,
2025
+ tfs_z=tfs_z,
2026
+ mirostat_mode=mirostat_mode,
2027
+ mirostat_tau=mirostat_tau,
2028
+ mirostat_eta=mirostat_eta,
2029
+ model=model,
2030
+ logits_processor=logits_processor,
2031
+ grammar=grammar,
2032
+ logit_bias=logit_bias,
2033
+ )
2034
+
2035
+ def create_chat_completion_openai_v1(
2036
+ self,
2037
+ *args: Any,
2038
+ **kwargs: Any,
2039
+ ):
2040
+ """Generate a chat completion with return type based on the the OpenAI v1 API.
2041
+
2042
+ OpenAI python package is required to use this method.
2043
+
2044
+ You can install it with `pip install openai`.
2045
+
2046
+ Args:
2047
+ *args: Positional arguments to pass to create_chat_completion.
2048
+ **kwargs: Keyword arguments to pass to create_chat_completion.
2049
+
2050
+ Returns:
2051
+ Generated chat completion or a stream of chat completion chunks.
2052
+ """
2053
+ try:
2054
+ from openai.types.chat import ChatCompletion, ChatCompletionChunk
2055
+
2056
+ stream = kwargs.get("stream", False) # type: ignore
2057
+ assert isinstance(stream, bool)
2058
+ if stream:
2059
+ return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs)) # type: ignore
2060
+ else:
2061
+ return ChatCompletion(**self.create_chat_completion(*args, **kwargs)) # type: ignore
2062
+ except ImportError:
2063
+ raise ImportError(
2064
+ "To use create_chat_completion_openai_v1, you must install the openai package."
2065
+ "You can install it with `pip install openai`."
2066
+ )
2067
+
2068
+ def __getstate__(self):
2069
+ return dict(
2070
+ model_path=self.model_path,
2071
+ # Model Params
2072
+ n_gpu_layers=self.model_params.n_gpu_layers,
2073
+ split_mode=self.model_params.split_mode,
2074
+ main_gpu=self.model_params.main_gpu,
2075
+ tensor_split=self.tensor_split,
2076
+ vocab_only=self.model_params.vocab_only,
2077
+ use_mmap=self.model_params.use_mmap,
2078
+ use_mlock=self.model_params.use_mlock,
2079
+ kv_overrides=self.kv_overrides,
2080
+ # Context Params
2081
+ seed=self._seed,
2082
+ n_ctx=self.context_params.n_ctx,
2083
+ n_batch=self.n_batch,
2084
+ n_ubatch=self.context_params.n_ubatch,
2085
+ n_threads=self.context_params.n_threads,
2086
+ n_threads_batch=self.context_params.n_threads_batch,
2087
+ rope_scaling_type=self.context_params.rope_scaling_type,
2088
+ pooling_type=self.context_params.pooling_type,
2089
+ rope_freq_base=self.context_params.rope_freq_base,
2090
+ rope_freq_scale=self.context_params.rope_freq_scale,
2091
+ yarn_ext_factor=self.context_params.yarn_ext_factor,
2092
+ yarn_attn_factor=self.context_params.yarn_attn_factor,
2093
+ yarn_beta_fast=self.context_params.yarn_beta_fast,
2094
+ yarn_beta_slow=self.context_params.yarn_beta_slow,
2095
+ yarn_orig_ctx=self.context_params.yarn_orig_ctx,
2096
+ logits_all=self._logits_all,
2097
+ embedding=self.context_params.embeddings,
2098
+ offload_kqv=self.context_params.offload_kqv,
2099
+ flash_attn=self.context_params.flash_attn,
2100
+ op_offload=self.context_params.op_offload,
2101
+ swa_full=self.context_params.swa_full,
2102
+ # Sampling Params
2103
+ no_perf=self.context_params.no_perf,
2104
+ last_n_tokens_size=self.last_n_tokens_size,
2105
+ # LoRA Params
2106
+ lora_base=self.lora_base,
2107
+ lora_scale=self.lora_scale,
2108
+ lora_path=self.lora_path,
2109
+ # Backend Params
2110
+ numa=self.numa,
2111
+ # Chat Format Params
2112
+ chat_format=self.chat_format,
2113
+ chat_handler=self.chat_handler,
2114
+ # Speculative Decidng
2115
+ draft_model=self.draft_model,
2116
+ # KV cache quantization
2117
+ type_k=self.context_params.type_k,
2118
+ type_v=self.context_params.type_v,
2119
+ # Misc
2120
+ spm_infill=self.spm_infill,
2121
+ verbose=self.verbose,
2122
+ )
2123
+
2124
+ def __setstate__(self, state):
2125
+ self.__init__(**state)
2126
+
2127
+ def save_state(self) -> LlamaState:
2128
+ if self.verbose:
2129
+ print("Llama.save_state: saving llama state", file=sys.stderr)
2130
+ state_size = llama_cpp.llama_get_state_size(self._ctx.ctx)
2131
+ if self.verbose:
2132
+ print(f"Llama.save_state: got state size: {state_size}", file=sys.stderr)
2133
+ llama_state = (ctypes.c_uint8 * int(state_size))()
2134
+ if self.verbose:
2135
+ print("Llama.save_state: allocated state", file=sys.stderr)
2136
+ n_bytes = llama_cpp.llama_copy_state_data(self._ctx.ctx, llama_state)
2137
+ if self.verbose:
2138
+ print(f"Llama.save_state: copied llama state: {n_bytes}", file=sys.stderr)
2139
+ if int(n_bytes) > int(state_size):
2140
+ raise RuntimeError("Failed to copy llama state data")
2141
+ llama_state_compact = (ctypes.c_uint8 * int(n_bytes))()
2142
+ llama_cpp.ctypes.memmove(llama_state_compact, llama_state, int(n_bytes))
2143
+ if self.verbose:
2144
+ print(
2145
+ f"Llama.save_state: saving {n_bytes} bytes of llama state",
2146
+ file=sys.stderr,
2147
+ )
2148
+ return LlamaState(
2149
+ scores=self._scores.copy(),
2150
+ input_ids=self.input_ids.copy(),
2151
+ n_tokens=self.n_tokens,
2152
+ llama_state=bytes(llama_state_compact),
2153
+ llama_state_size=n_bytes,
2154
+ seed=self._seed,
2155
+ )
2156
+
2157
+ def load_state(self, state: LlamaState) -> None:
2158
+ # Only filling in up to `n_tokens` and then zero-ing out the rest
2159
+ self.scores[: state.n_tokens, :] = state.scores.copy()
2160
+ rest = self.scores[state.n_tokens :, :]
2161
+ rest[rest > 0] = 0.0
2162
+ self.input_ids = state.input_ids.copy()
2163
+ self.n_tokens = state.n_tokens
2164
+ self._seed = state.seed
2165
+ state_size = state.llama_state_size
2166
+ LLamaStateArrayType = ctypes.c_uint8 * state_size
2167
+ llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
2168
+
2169
+ if llama_cpp.llama_set_state_data(self._ctx.ctx, llama_state) != state_size:
2170
+ raise RuntimeError("Failed to set llama state data")
2171
+
2172
+ def n_ctx(self) -> int:
2173
+ """Return the context window size."""
2174
+ return self._ctx.n_ctx()
2175
+
2176
+ def n_embd(self) -> int:
2177
+ """Return the embedding size."""
2178
+ return self._model.n_embd()
2179
+
2180
+ def n_vocab(self) -> int:
2181
+ """Return the vocabulary size."""
2182
+ return self._model.n_vocab()
2183
+
2184
+ def tokenizer(self) -> LlamaTokenizer:
2185
+ """Return the llama tokenizer for this model."""
2186
+ return LlamaTokenizer(self)
2187
+
2188
+ def token_eos(self) -> int:
2189
+ """Return the end-of-sequence token."""
2190
+ return self._model.token_eos()
2191
+
2192
+ def token_bos(self) -> int:
2193
+ """Return the beginning-of-sequence token."""
2194
+ return self._model.token_bos()
2195
+
2196
+ def token_nl(self) -> int:
2197
+ """Return the newline token."""
2198
+ return self._model.token_nl()
2199
+
2200
+ def pooling_type(self) -> str:
2201
+ """Return the pooling type."""
2202
+ return self._ctx.pooling_type()
2203
+
2204
+ def close(self) -> None:
2205
+ """Explicitly free the model from memory."""
2206
+ self._stack.close()
2207
+
2208
+ def __del__(self) -> None:
2209
+ self.close()
2210
+
2211
+ @staticmethod
2212
+ def logits_to_logprobs(
2213
+ logits: Union[npt.NDArray[np.single], List], axis: int = -1
2214
+ ) -> npt.NDArray[np.single]:
2215
+ # https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.log_softmax.html
2216
+ logits_maxs: np.ndarray = np.amax(logits, axis=axis, keepdims=True)
2217
+ if logits_maxs.ndim > 0:
2218
+ logits_maxs[~np.isfinite(logits_maxs)] = 0
2219
+ elif not np.isfinite(logits_maxs):
2220
+ logits_maxs = 0
2221
+ subtract_maxs = np.subtract(logits, logits_maxs, dtype=np.single)
2222
+ exp = np.exp(subtract_maxs)
2223
+ # Suppress warnings about log of zero
2224
+ with np.errstate(divide="ignore"):
2225
+ summed = np.sum(exp, axis=axis, keepdims=True)
2226
+ out = np.log(summed)
2227
+ return subtract_maxs - out
2228
+
2229
+ @staticmethod
2230
+ def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
2231
+ longest_prefix = 0
2232
+ for _a, _b in zip(a, b):
2233
+ if _a == _b:
2234
+ longest_prefix += 1
2235
+ else:
2236
+ break
2237
+ return longest_prefix
2238
+
2239
+ @classmethod
2240
+ def from_pretrained(
2241
+ cls,
2242
+ repo_id: str,
2243
+ filename: Optional[str],
2244
+ additional_files: Optional[List] = None,
2245
+ local_dir: Optional[Union[str, os.PathLike[str]]] = None,
2246
+ local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
2247
+ cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
2248
+ **kwargs: Any,
2249
+ ) -> "Llama":
2250
+ """Create a Llama model from a pretrained model name or path.
2251
+ This method requires the huggingface-hub package.
2252
+ You can install it with `pip install huggingface-hub`.
2253
+
2254
+ Args:
2255
+ repo_id: The model repo id.
2256
+ filename: A filename or glob pattern to match the model file in the repo.
2257
+ additional_files: A list of filenames or glob patterns to match additional model files in the repo.
2258
+ local_dir: The local directory to save the model to.
2259
+ local_dir_use_symlinks: Whether to use symlinks when downloading the model.
2260
+ **kwargs: Additional keyword arguments to pass to the Llama constructor.
2261
+
2262
+ Returns:
2263
+ A Llama model."""
2264
+ try:
2265
+ from huggingface_hub import hf_hub_download, HfFileSystem
2266
+ from huggingface_hub.utils import validate_repo_id
2267
+ except ImportError:
2268
+ raise ImportError(
2269
+ "Llama.from_pretrained requires the huggingface-hub package. "
2270
+ "You can install it with `pip install huggingface-hub`."
2271
+ )
2272
+
2273
+ validate_repo_id(repo_id)
2274
+
2275
+ hffs = HfFileSystem()
2276
+
2277
+ files = [
2278
+ file["name"] if isinstance(file, dict) else file
2279
+ for file in hffs.ls(repo_id, recursive=True)
2280
+ ]
2281
+
2282
+ # split each file into repo_id, subfolder, filename
2283
+ file_list: List[str] = []
2284
+ for file in files:
2285
+ rel_path = Path(file).relative_to(repo_id)
2286
+ file_list.append(str(rel_path))
2287
+
2288
+ # find the only/first shard file:
2289
+ matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)] # type: ignore
2290
+
2291
+ if len(matching_files) == 0:
2292
+ raise ValueError(
2293
+ f"No file found in {repo_id} that match {filename}\n\n"
2294
+ f"Available Files:\n{json.dumps(file_list)}"
2295
+ )
2296
+
2297
+ if len(matching_files) > 1:
2298
+ raise ValueError(
2299
+ f"Multiple files found in {repo_id} matching {filename}\n\n"
2300
+ f"Available Files:\n{json.dumps(files)}"
2301
+ )
2302
+
2303
+ (matching_file,) = matching_files
2304
+
2305
+ subfolder = str(Path(matching_file).parent)
2306
+ filename = Path(matching_file).name
2307
+
2308
+ # download the file
2309
+ hf_hub_download(
2310
+ repo_id=repo_id,
2311
+ filename=filename,
2312
+ subfolder=subfolder,
2313
+ local_dir=local_dir,
2314
+ local_dir_use_symlinks=local_dir_use_symlinks,
2315
+ cache_dir=cache_dir,
2316
+ )
2317
+
2318
+ if additional_files:
2319
+ for additonal_file_name in additional_files:
2320
+ # find the additional shard file:
2321
+ matching_additional_files = [file for file in file_list if fnmatch.fnmatch(file, additonal_file_name)]
2322
+
2323
+ if len(matching_additional_files) == 0:
2324
+ raise ValueError(
2325
+ f"No file found in {repo_id} that match {additonal_file_name}\n\n"
2326
+ f"Available Files:\n{json.dumps(file_list)}"
2327
+ )
2328
+
2329
+ if len(matching_additional_files) > 1:
2330
+ raise ValueError(
2331
+ f"Multiple files found in {repo_id} matching {additonal_file_name}\n\n"
2332
+ f"Available Files:\n{json.dumps(files)}"
2333
+ )
2334
+
2335
+ (matching_additional_file,) = matching_additional_files
2336
+
2337
+ # download the additional file
2338
+ hf_hub_download(
2339
+ repo_id=repo_id,
2340
+ filename=matching_additional_file,
2341
+ subfolder=subfolder,
2342
+ local_dir=local_dir,
2343
+ local_dir_use_symlinks=local_dir_use_symlinks,
2344
+ cache_dir=cache_dir,
2345
+ )
2346
+
2347
+ if local_dir is None:
2348
+ model_path = hf_hub_download(
2349
+ repo_id=repo_id,
2350
+ filename=filename,
2351
+ subfolder=subfolder,
2352
+ local_dir=local_dir,
2353
+ local_dir_use_symlinks=local_dir_use_symlinks,
2354
+ cache_dir=cache_dir,
2355
+ local_files_only=True,
2356
+ )
2357
+ else:
2358
+ model_path = os.path.join(local_dir, filename)
2359
+
2360
+ # loading the first file of a sharded GGUF loads all remaining shard files in the subfolder
2361
+ return cls(
2362
+ model_path=model_path,
2363
+ **kwargs,
2364
+ )
2365
+
2366
+
2367
+ class LlamaState:
2368
+ def __init__(
2369
+ self,
2370
+ input_ids: npt.NDArray[np.intc],
2371
+ scores: npt.NDArray[np.single],
2372
+ n_tokens: int,
2373
+ llama_state: bytes,
2374
+ llama_state_size: int,
2375
+ seed: int,
2376
+ ):
2377
+ self.input_ids = input_ids
2378
+ self.scores = scores
2379
+ self.n_tokens = n_tokens
2380
+ self.llama_state = llama_state
2381
+ self.llama_state_size = llama_state_size
2382
+ self.seed = seed
2383
+
2384
+
2385
+ LogitsProcessor = Callable[
2386
+ [npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]
2387
+ ]
2388
+
2389
+
2390
+ class LogitsProcessorList(List[LogitsProcessor]):
2391
+ def __call__(
2392
+ self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
2393
+ ) -> npt.NDArray[np.single]:
2394
+ for processor in self:
2395
+ scores = processor(input_ids, scores)
2396
+ return scores
2397
+
2398
+
2399
+ StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]
2400
+
2401
+
2402
+ class StoppingCriteriaList(List[StoppingCriteria]):
2403
+ def __call__(
2404
+ self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
2405
+ ) -> bool:
2406
+ return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
2407
+
2408
+
2409
+ class MinTokensLogitsProcessor(LogitsProcessor):
2410
+ def __init__(self, min_tokens: int, token_eos: int):
2411
+ self.min_tokens = min_tokens
2412
+ self.token_eos = token_eos
2413
+ self.prompt_tokens = None
2414
+
2415
+ def __call__(
2416
+ self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
2417
+ ) -> npt.NDArray[np.single]:
2418
+ if self.prompt_tokens is None:
2419
+ self.prompt_tokens = len(input_ids)
2420
+ if len(input_ids) - self.prompt_tokens < self.min_tokens:
2421
+ scores[self.token_eos] = -np.inf
2422
+ return scores
llama_cpp/llama_cache.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from abc import ABC, abstractmethod
3
+ from typing import (
4
+ Optional,
5
+ Sequence,
6
+ Tuple,
7
+ )
8
+ from collections import OrderedDict
9
+
10
+ import diskcache
11
+
12
+ import llama_cpp.llama
13
+
14
+ from .llama_types import *
15
+
16
+
17
+ class BaseLlamaCache(ABC):
18
+ """Base cache class for a llama.cpp model."""
19
+
20
+ def __init__(self, capacity_bytes: int = (2 << 30)):
21
+ self.capacity_bytes = capacity_bytes
22
+
23
+ @property
24
+ @abstractmethod
25
+ def cache_size(self) -> int:
26
+ raise NotImplementedError
27
+
28
+ def _find_longest_prefix_key(
29
+ self,
30
+ key: Tuple[int, ...],
31
+ ) -> Optional[Tuple[int, ...]]:
32
+ pass
33
+
34
+ @abstractmethod
35
+ def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
36
+ raise NotImplementedError
37
+
38
+ @abstractmethod
39
+ def __contains__(self, key: Sequence[int]) -> bool:
40
+ raise NotImplementedError
41
+
42
+ @abstractmethod
43
+ def __setitem__(
44
+ self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"
45
+ ) -> None:
46
+ raise NotImplementedError
47
+
48
+
49
+ class LlamaRAMCache(BaseLlamaCache):
50
+ """Cache for a llama.cpp model using RAM."""
51
+
52
+ def __init__(self, capacity_bytes: int = (2 << 30)):
53
+ super().__init__(capacity_bytes)
54
+ self.capacity_bytes = capacity_bytes
55
+ self.cache_state: OrderedDict[
56
+ Tuple[int, ...], "llama_cpp.llama.LlamaState"
57
+ ] = OrderedDict()
58
+
59
+ @property
60
+ def cache_size(self):
61
+ return sum([state.llama_state_size for state in self.cache_state.values()])
62
+
63
+ def _find_longest_prefix_key(
64
+ self,
65
+ key: Tuple[int, ...],
66
+ ) -> Optional[Tuple[int, ...]]:
67
+ min_len = 0
68
+ min_key = None
69
+ keys = (
70
+ (k, llama_cpp.llama.Llama.longest_token_prefix(k, key))
71
+ for k in self.cache_state.keys()
72
+ )
73
+ for k, prefix_len in keys:
74
+ if prefix_len > min_len:
75
+ min_len = prefix_len
76
+ min_key = k
77
+ return min_key
78
+
79
+ def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
80
+ key = tuple(key)
81
+ _key = self._find_longest_prefix_key(key)
82
+ if _key is None:
83
+ raise KeyError("Key not found")
84
+ value = self.cache_state[_key]
85
+ self.cache_state.move_to_end(_key)
86
+ return value
87
+
88
+ def __contains__(self, key: Sequence[int]) -> bool:
89
+ return self._find_longest_prefix_key(tuple(key)) is not None
90
+
91
+ def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
92
+ key = tuple(key)
93
+ if key in self.cache_state:
94
+ del self.cache_state[key]
95
+ self.cache_state[key] = value
96
+ while self.cache_size > self.capacity_bytes and len(self.cache_state) > 0:
97
+ self.cache_state.popitem(last=False)
98
+
99
+
100
+ # Alias for backwards compatibility
101
+ LlamaCache = LlamaRAMCache
102
+
103
+
104
+ class LlamaDiskCache(BaseLlamaCache):
105
+ """Cache for a llama.cpp model using disk."""
106
+
107
+ def __init__(
108
+ self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30)
109
+ ):
110
+ super().__init__(capacity_bytes)
111
+ self.cache = diskcache.Cache(cache_dir)
112
+
113
+ @property
114
+ def cache_size(self):
115
+ return int(self.cache.volume()) # type: ignore
116
+
117
+ def _find_longest_prefix_key(
118
+ self,
119
+ key: Tuple[int, ...],
120
+ ) -> Optional[Tuple[int, ...]]:
121
+ min_len = 0
122
+ min_key: Optional[Tuple[int, ...]] = None
123
+ for k in self.cache.iterkeys(): # type: ignore
124
+ prefix_len = llama_cpp.llama.Llama.longest_token_prefix(k, key)
125
+ if prefix_len > min_len:
126
+ min_len = prefix_len
127
+ min_key = k # type: ignore
128
+ return min_key
129
+
130
+ def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
131
+ key = tuple(key)
132
+ _key = self._find_longest_prefix_key(key)
133
+ if _key is None:
134
+ raise KeyError("Key not found")
135
+ value: "llama_cpp.llama.LlamaState" = self.cache.pop(_key) # type: ignore
136
+ # NOTE: This puts an integer as key in cache, which breaks,
137
+ # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
138
+ # self.cache.push(_key, side="front") # type: ignore
139
+ return value
140
+
141
+ def __contains__(self, key: Sequence[int]) -> bool:
142
+ return self._find_longest_prefix_key(tuple(key)) is not None
143
+
144
+ def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
145
+ print("LlamaDiskCache.__setitem__: called", file=sys.stderr)
146
+ key = tuple(key)
147
+ if key in self.cache:
148
+ print("LlamaDiskCache.__setitem__: delete", file=sys.stderr)
149
+ del self.cache[key]
150
+ self.cache[key] = value
151
+ print("LlamaDiskCache.__setitem__: set", file=sys.stderr)
152
+ while self.cache_size > self.capacity_bytes and len(self.cache) > 0:
153
+ key_to_remove = next(iter(self.cache))
154
+ del self.cache[key_to_remove]
155
+ print("LlamaDiskCache.__setitem__: trim", file=sys.stderr)
llama_cpp/llama_chat_format.py ADDED
The diff for this file is too large to render. See raw diff
 
llama_cpp/llama_cpp.py ADDED
The diff for this file is too large to render. See raw diff
 
llama_cpp/llama_grammar.py ADDED
@@ -0,0 +1,953 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Python implementation of llama grammar parser directly translated from C++ source file in vendor/llama.cpp/common/grammar-parser.cpp."""
2
+
3
+ # flake8: noqa
4
+ from pathlib import Path
5
+
6
+ from itertools import groupby
7
+ from typing import (
8
+ Any,
9
+ Set,
10
+ List,
11
+ Optional,
12
+ Tuple,
13
+ Union,
14
+ )
15
+
16
+ LLAMA_GRAMMAR_DEFAULT_ROOT = "root"
17
+
18
+
19
+ class LlamaGrammar:
20
+ def __init__(self, *args, _grammar: str, **kwargs):
21
+ self._grammar = _grammar
22
+ self._root = LLAMA_GRAMMAR_DEFAULT_ROOT
23
+
24
+ @classmethod
25
+ def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
26
+ return cls(_grammar=grammar)
27
+
28
+ @classmethod
29
+ def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar":
30
+ try:
31
+ with open(file) as f:
32
+ grammar = f.read()
33
+ except Exception as err:
34
+ raise Exception(
35
+ f"{cls.from_file.__name__}: error reading grammar file: {err}"
36
+ )
37
+
38
+ if grammar:
39
+ return cls.from_string(grammar, verbose=verbose)
40
+
41
+ raise ValueError(
42
+ f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty"
43
+ )
44
+
45
+ @classmethod
46
+ def from_json_schema(cls, json_schema: str, verbose: bool = True) -> "LlamaGrammar":
47
+ return cls.from_string(json_schema_to_gbnf(json_schema), verbose=verbose)
48
+
49
+
50
+ """llama.cpp gbnf rules from vendor/llama.cpp/grammars"""
51
+
52
+ ARITHMETIC_GBNF = r"""
53
+ root ::= (expr "=" ws term "\n")+
54
+ expr ::= term ([-+*/] term)*
55
+ term ::= ident | num | "(" ws expr ")" ws
56
+ ident ::= [a-z] [a-z0-9_]* ws
57
+ num ::= [0-9]+ ws
58
+ ws ::= [ \t\n]*
59
+ """
60
+
61
+ C_GBNF = r"""
62
+ root ::= (declaration)*
63
+
64
+ declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}"
65
+
66
+ dataType ::= "int" ws | "float" ws | "char" ws
67
+ identifier ::= [a-zA-Z_] [a-zA-Z_0-9]*
68
+
69
+ parameter ::= dataType identifier
70
+
71
+ statement ::=
72
+ ( dataType identifier ws "=" ws expression ";" ) |
73
+ ( identifier ws "=" ws expression ";" ) |
74
+ ( identifier ws "(" argList? ")" ";" ) |
75
+ ( "return" ws expression ";" ) |
76
+ ( "while" "(" condition ")" "{" statement* "}" ) |
77
+ ( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) |
78
+ ( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? ) |
79
+ ( singleLineComment ) |
80
+ ( multiLineComment )
81
+
82
+ forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression
83
+ forUpdate ::= identifier ws "=" ws expression
84
+
85
+ condition ::= expression relationOperator expression
86
+ relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">")
87
+
88
+ expression ::= term (("+" | "-") term)*
89
+ term ::= factor(("*" | "/") factor)*
90
+
91
+ factor ::= identifier | number | unaryTerm | funcCall | parenExpression
92
+ unaryTerm ::= "-" factor
93
+ funcCall ::= identifier "(" argList? ")"
94
+ parenExpression ::= "(" ws expression ws ")"
95
+
96
+ argList ::= expression ("," ws expression)*
97
+
98
+ number ::= [0-9]+
99
+
100
+ singleLineComment ::= "//" [^\n]* "\n"
101
+ multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/"
102
+
103
+ ws ::= ([ \t\n]+)
104
+ """
105
+
106
+ CHESS_GBNF = r"""
107
+ root ::= object
108
+ value ::= object | array | string | number | ("true" | "false" | "null") ws
109
+
110
+ object ::=
111
+ "{" ws (
112
+ string ":" ws value
113
+ ("," ws string ":" ws value)*
114
+ )? "}" ws
115
+
116
+ array ::=
117
+ "[" ws (
118
+ value
119
+ ("," ws value)*
120
+ )? "]" ws
121
+
122
+ string ::=
123
+ "\"" (
124
+ [^"\\] |
125
+ "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
126
+ )* "\"" ws
127
+
128
+ number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
129
+
130
+ # Optional space: by convention, applied in this grammar after literal chars when allowed
131
+ ws ::= ([ \t\n] ws)?
132
+ """
133
+
134
+ JAPANESE_GBNF = r"""
135
+ root ::= object
136
+ value ::= object | array | string | number | ("true" | "false" | "null") ws
137
+
138
+ object ::=
139
+ "{" ws (
140
+ string ":" ws value
141
+ ("," ws string ":" ws value)*
142
+ )? "}" ws
143
+
144
+ array ::=
145
+ "[" ws (
146
+ value
147
+ ("," ws value)*
148
+ )? "]" ws
149
+
150
+ string ::=
151
+ "\"" (
152
+ [^"\\] |
153
+ "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
154
+ )* "\"" ws
155
+
156
+ number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
157
+
158
+ # Optional space: by convention, applied in this grammar after literal chars when allowed
159
+ ws ::= ([ \t\n] ws)?
160
+ """
161
+
162
+ JSON_ARR_GBNF = r"""
163
+ # This is the same as json.gbnf but we restrict whitespaces at the end of the root array
164
+ # Useful for generating JSON arrays
165
+
166
+ root ::= arr
167
+ value ::= object | array | string | number | ("true" | "false" | "null") ws
168
+
169
+ arr ::=
170
+ "[\n" ws (
171
+ value
172
+ (",\n" ws value)*
173
+ )? "]"
174
+
175
+ object ::=
176
+ "{" ws (
177
+ string ":" ws value
178
+ ("," ws string ":" ws value)*
179
+ )? "}" ws
180
+
181
+ array ::=
182
+ "[" ws (
183
+ value
184
+ ("," ws value)*
185
+ )? "]" ws
186
+
187
+ string ::=
188
+ "\"" (
189
+ [^"\\\x7F\x00-\x1F] |
190
+ "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
191
+ )* "\"" ws
192
+
193
+ number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
194
+
195
+ # Optional space: by convention, applied in this grammar after literal chars when allowed
196
+ ws ::= ([ \t\n] ws)?
197
+ """
198
+
199
+
200
+ JSON_GBNF = r"""
201
+ root ::= object
202
+ value ::= object | array | string | number | ("true" | "false" | "null") ws
203
+
204
+ object ::=
205
+ "{" ws (
206
+ string ":" ws value
207
+ ("," ws string ":" ws value)*
208
+ )? "}" ws
209
+
210
+ array ::=
211
+ "[" ws (
212
+ value
213
+ ("," ws value)*
214
+ )? "]" ws
215
+
216
+ string ::=
217
+ "\"" (
218
+ [^"\\\x7F\x00-\x1F] |
219
+ "\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4}) # escapes
220
+ )* "\"" ws
221
+
222
+ number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [0-9] [1-9]{0,15})? ws
223
+
224
+ # Optional space: by convention, applied in this grammar after literal chars when allowed
225
+ ws ::= | " " | "\n" [ \t]{0,20}
226
+ """
227
+
228
+ LIST_GBNF = r"""
229
+ root ::= item+
230
+
231
+ # Excludes various line break characters
232
+ item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n"
233
+ """
234
+
235
+ """llama.cpp json-schema to grammar converter from vendor/llama.cpp/examples/json-schema-to-grammar.py"""
236
+ import json
237
+ import re
238
+ from typing import List, Optional
239
+
240
+ # whitespace is constrained to a single space char to prevent model "running away" in
241
+ # whitespace. Also maybe improves generation quality?
242
+ SPACE_RULE = '" "?'
243
+
244
+
245
+ INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+")
246
+ GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
247
+ GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'}
248
+
249
+ # whitespace is constrained to a single space char to prevent model "running away" in
250
+ # whitespace. Also maybe improves generation quality?
251
+ SPACE_RULE = '" "?'
252
+
253
+
254
+ def _build_repetition(
255
+ item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False
256
+ ):
257
+ if not separator_rule:
258
+ if min_items == 0 and max_items == 1:
259
+ return f"{item_rule}?"
260
+ elif min_items == 1 and max_items is None:
261
+ return f"{item_rule}+"
262
+
263
+ result = ""
264
+
265
+ if min_items > 0:
266
+ if item_rule_is_literal and separator_rule is None:
267
+ result = '"' + (item_rule[1:-1] * min_items) + '"'
268
+ else:
269
+ result = (f" {separator_rule} " if separator_rule else " ").join(
270
+ [item_rule] * min_items
271
+ )
272
+
273
+ def opt_repetitions(up_to_n, prefix_with_sep=False):
274
+ """
275
+ - n=4, no sep: '(a (a (a (a)?)?)?)?'
276
+ - n=4, sep=',', prefix: '("," a ("," a ("," a ("," a)?)?)?)?'
277
+ - n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?'
278
+ """
279
+
280
+ content = (
281
+ f"{separator_rule} {item_rule}"
282
+ if prefix_with_sep and separator_rule
283
+ else item_rule
284
+ )
285
+ if up_to_n == 0:
286
+ return ""
287
+ elif up_to_n == 1:
288
+ return f"({content})?"
289
+ elif separator_rule and not prefix_with_sep:
290
+ return f"({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?"
291
+ else:
292
+ return (f"({content} " * up_to_n).rstrip() + (")?" * up_to_n)
293
+
294
+ if min_items > 0 and max_items != min_items:
295
+ result += " "
296
+
297
+ if max_items is not None:
298
+ result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0)
299
+ else:
300
+ item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})'
301
+
302
+ if min_items == 0 and separator_rule:
303
+ result = f"({item_rule} {item_operator}*)?"
304
+ else:
305
+ result += f"{item_operator}*"
306
+
307
+ return result
308
+
309
+
310
+ class BuiltinRule:
311
+ def __init__(self, content: str, deps: list = None):
312
+ self.content = content
313
+ self.deps = deps or []
314
+
315
+
316
+ _up_to_15_digits = _build_repetition("[0-9]", 0, 15)
317
+
318
+ PRIMITIVE_RULES = {
319
+ "boolean": BuiltinRule('("true" | "false") space', []),
320
+ "decimal-part": BuiltinRule("[0-9] " + _up_to_15_digits, []),
321
+ "integral-part": BuiltinRule("[0-9] | [1-9] " + _up_to_15_digits, []),
322
+ "number": BuiltinRule(
323
+ '("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space',
324
+ ["integral-part", "decimal-part"],
325
+ ),
326
+ "integer": BuiltinRule('("-"? integral-part) space', ["integral-part"]),
327
+ "value": BuiltinRule(
328
+ "object | array | string | number | boolean | null",
329
+ ["object", "array", "string", "number", "boolean", "null"],
330
+ ),
331
+ "object": BuiltinRule(
332
+ '"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space',
333
+ ["string", "value"],
334
+ ),
335
+ "array": BuiltinRule(
336
+ '"[" space ( value ("," space value)* )? "]" space', ["value"]
337
+ ),
338
+ "uuid": BuiltinRule(
339
+ r'"\"" '
340
+ + ' "-" '.join("[0-9a-fA-F]" * n for n in [8, 4, 4, 4, 12])
341
+ + r' "\"" space',
342
+ [],
343
+ ),
344
+ "char": BuiltinRule(
345
+ r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])',
346
+ [],
347
+ ),
348
+ "string": BuiltinRule(r'"\"" char* "\"" space', ["char"]),
349
+ "null": BuiltinRule('"null" space', []),
350
+ }
351
+
352
+ # TODO: support "uri", "email" string formats
353
+ STRING_FORMAT_RULES = {
354
+ "date": BuiltinRule(
355
+ '[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( "0" [1-9] | [1-2] [0-9] | "3" [0-1] )',
356
+ [],
357
+ ),
358
+ "time": BuiltinRule(
359
+ '([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )',
360
+ [],
361
+ ),
362
+ "date-time": BuiltinRule('date "T" time', ["date", "time"]),
363
+ "date-string": BuiltinRule('"\\"" date "\\"" space', ["date"]),
364
+ "time-string": BuiltinRule('"\\"" time "\\"" space', ["time"]),
365
+ "date-time-string": BuiltinRule('"\\"" date-time "\\"" space', ["date-time"]),
366
+ }
367
+
368
+ DOTALL = "[\\U00000000-\\U0010FFFF]"
369
+ DOT = "[^\\x0A\\x0D]"
370
+
371
+ RESERVED_NAMES = set(
372
+ ["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()]
373
+ )
374
+
375
+
376
+ NON_LITERAL_SET = set("|.()[]{}*+?")
377
+ ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set("[]()|{}*+?")
378
+
379
+
380
+ class SchemaConverter:
381
+ def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
382
+ self._prop_order = prop_order
383
+ self._allow_fetch = allow_fetch
384
+ self._dotall = dotall
385
+ self._raw_pattern = raw_pattern
386
+ self._rules = {
387
+ "space": SPACE_RULE,
388
+ }
389
+ self._refs = {}
390
+ self._refs_being_resolved = set()
391
+
392
+ def _format_literal(self, literal):
393
+ escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
394
+ lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
395
+ )
396
+ return f'"{escaped}"'
397
+
398
+ def not_literal(
399
+ self, literal: str, dotall: bool = True, maybe_escaped_underscores=False
400
+ ) -> str:
401
+ """
402
+ not_literal('a') -> '[^a]'
403
+ not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?'
404
+ """
405
+ assert len(literal) > 0, "Empty literal not supported"
406
+
407
+ def recurse(i: int):
408
+ c = literal[i]
409
+ if maybe_escaped_underscores and c == "_":
410
+ yield f"[^{c}\\\\]"
411
+ yield " | "
412
+ yield f'"\\\\"? "{c}"'
413
+ else:
414
+ yield f"[^{c}]"
415
+ if i < len(literal) - 1:
416
+ yield " | "
417
+ yield self._format_literal(c)
418
+ yield " ("
419
+ yield from recurse(i + 1)
420
+ yield ")?"
421
+
422
+ return "".join(("(", *recurse(0), ")"))
423
+
424
+ def _add_rule(self, name, rule):
425
+ esc_name = INVALID_RULE_CHARS_RE.sub("-", name)
426
+ if esc_name not in self._rules or self._rules[esc_name] == rule:
427
+ key = esc_name
428
+ else:
429
+ i = 0
430
+ while (
431
+ f"{esc_name}{i}" in self._rules
432
+ and self._rules[f"{esc_name}{i}"] != rule
433
+ ):
434
+ i += 1
435
+ key = f"{esc_name}{i}"
436
+ self._rules[key] = rule
437
+ return key
438
+
439
+ def resolve_refs(self, schema: dict, url: str):
440
+ """
441
+ Resolves all $ref fields in the given schema, fetching any remote schemas,
442
+ replacing $ref with absolute reference URL and populating self._refs with the
443
+ respective referenced (sub)schema dictionaries.
444
+ """
445
+
446
+ def visit(n: dict):
447
+ if isinstance(n, list):
448
+ return [visit(x) for x in n]
449
+ elif isinstance(n, dict):
450
+ ref = n.get("$ref")
451
+ if ref is not None and ref not in self._refs:
452
+ if ref.startswith("https://"):
453
+ assert (
454
+ self._allow_fetch
455
+ ), "Fetching remote schemas is not allowed (use --allow-fetch for force)"
456
+ import requests
457
+
458
+ frag_split = ref.split("#")
459
+ base_url = frag_split[0]
460
+
461
+ target = self._refs.get(base_url)
462
+ if target is None:
463
+ target = self.resolve_refs(
464
+ requests.get(ref).json(), base_url
465
+ )
466
+ self._refs[base_url] = target
467
+
468
+ if len(frag_split) == 1 or frag_split[-1] == "":
469
+ return target
470
+ elif ref.startswith("#/"):
471
+ target = schema
472
+ ref = f"{url}{ref}"
473
+ n["$ref"] = ref
474
+ else:
475
+ raise ValueError(f"Unsupported ref {ref}")
476
+
477
+ for sel in ref.split("#")[-1].split("/")[1:]:
478
+ assert (
479
+ target is not None and sel in target
480
+ ), f"Error resolving ref {ref}: {sel} not in {target}"
481
+ target = target[sel]
482
+
483
+ self._refs[ref] = target
484
+ else:
485
+ for v in n.values():
486
+ visit(v)
487
+
488
+ return n
489
+
490
+ return visit(schema)
491
+
492
+ def _generate_union_rule(self, name, alt_schemas):
493
+ return " | ".join(
494
+ (
495
+ self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
496
+ for i, alt_schema in enumerate(alt_schemas)
497
+ )
498
+ )
499
+
500
+ def _visit_pattern(self, pattern, name):
501
+ """
502
+ Transforms a regular expression pattern into a GBNF rule.
503
+
504
+ Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
505
+ Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
506
+
507
+ Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.
508
+
509
+ Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
510
+ we define sub-rules to keep the output lean.
511
+ """
512
+
513
+ assert pattern.startswith("^") and pattern.endswith(
514
+ "$"
515
+ ), 'Pattern must start with "^" and end with "$"'
516
+ pattern = pattern[1:-1]
517
+ sub_rule_ids = {}
518
+
519
+ i = 0
520
+ length = len(pattern)
521
+
522
+ def to_rule(s: Tuple[str, bool]) -> str:
523
+ (txt, is_literal) = s
524
+ return '"' + txt + '"' if is_literal else txt
525
+
526
+ def transform() -> Tuple[str, bool]:
527
+ """
528
+ Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
529
+ """
530
+ nonlocal i
531
+ nonlocal pattern
532
+ nonlocal sub_rule_ids
533
+
534
+ start = i
535
+ # For each component of this sequence, store its string representation and whether it's a literal.
536
+ # We only need a flat structure here to apply repetition operators to the last item, and
537
+ # to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
538
+ # (GBNF's syntax is luckily very close to regular expressions!)
539
+ seq: list[Tuple[str, bool]] = []
540
+
541
+ def get_dot():
542
+ if self._dotall:
543
+ rule = DOTALL
544
+ else:
545
+ # Accept any character... except \n and \r line break chars (\x0A and \xOD)
546
+ rule = DOT
547
+ return self._add_rule(f"dot", rule)
548
+
549
+ def join_seq():
550
+ nonlocal seq
551
+ ret = []
552
+ for is_literal, g in groupby(seq, lambda x: x[1]):
553
+ if is_literal:
554
+ ret.append(("".join(x[0] for x in g), True))
555
+ else:
556
+ ret.extend(g)
557
+ if len(ret) == 1:
558
+ return ret[0]
559
+ return (" ".join(to_rule(x) for x in seq), False)
560
+
561
+ while i < length:
562
+ c = pattern[i]
563
+ if c == ".":
564
+ seq.append((get_dot(), False))
565
+ i += 1
566
+ elif c == "(":
567
+ i += 1
568
+ if i < length:
569
+ assert (
570
+ pattern[i] != "?"
571
+ ), f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
572
+ seq.append((f"({to_rule(transform())})", False))
573
+ elif c == ")":
574
+ i += 1
575
+ assert (
576
+ start > 0 and pattern[start - 1] == "("
577
+ ), f"Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}"
578
+ return join_seq()
579
+ elif c == "[":
580
+ square_brackets = c
581
+ i += 1
582
+ while i < length and pattern[i] != "]":
583
+ if pattern[i] == "\\":
584
+ square_brackets += pattern[i : i + 2]
585
+ i += 2
586
+ else:
587
+ square_brackets += pattern[i]
588
+ i += 1
589
+ assert (
590
+ i < length
591
+ ), f"Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}"
592
+ square_brackets += "]"
593
+ i += 1
594
+ seq.append((square_brackets, False))
595
+ elif c == "|":
596
+ seq.append(("|", False))
597
+ i += 1
598
+ elif c in ("*", "+", "?"):
599
+ seq[-1] = (to_rule(seq[-1]) + c, False)
600
+ i += 1
601
+ elif c == "{":
602
+ curly_brackets = c
603
+ i += 1
604
+ while i < length and pattern[i] != "}":
605
+ curly_brackets += pattern[i]
606
+ i += 1
607
+ assert (
608
+ i < length
609
+ ), f"Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}"
610
+ curly_brackets += "}"
611
+ i += 1
612
+ nums = [s.strip() for s in curly_brackets[1:-1].split(",")]
613
+ min_times = 0
614
+ max_times = None
615
+ try:
616
+ if len(nums) == 1:
617
+ min_times = int(nums[0])
618
+ max_times = min_times
619
+ else:
620
+ assert len(nums) == 2
621
+ min_times = int(nums[0]) if nums[0] else 0
622
+ max_times = int(nums[1]) if nums[1] else None
623
+ except ValueError:
624
+ raise ValueError(
625
+ f"Invalid quantifier {curly_brackets} in /{pattern}/"
626
+ )
627
+
628
+ (sub, sub_is_literal) = seq[-1]
629
+
630
+ if not sub_is_literal:
631
+ id = sub_rule_ids.get(sub)
632
+ if id is None:
633
+ id = self._add_rule(f"{name}-{len(sub_rule_ids) + 1}", sub)
634
+ sub_rule_ids[sub] = id
635
+ sub = id
636
+
637
+ seq[-1] = (
638
+ _build_repetition(
639
+ f'"{sub}"' if sub_is_literal else sub,
640
+ min_times,
641
+ max_times,
642
+ item_rule_is_literal=sub_is_literal,
643
+ ),
644
+ False,
645
+ )
646
+ else:
647
+ literal = ""
648
+ while i < length:
649
+ if pattern[i] == "\\" and i < length - 1:
650
+ next = pattern[i + 1]
651
+ if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
652
+ i += 1
653
+ literal += pattern[i]
654
+ i += 1
655
+ else:
656
+ literal += pattern[i : i + 2]
657
+ i += 2
658
+ elif pattern[i] == '"' and not self._raw_pattern:
659
+ literal += '\\"'
660
+ i += 1
661
+ elif pattern[i] not in NON_LITERAL_SET and (
662
+ i == length - 1
663
+ or literal == ""
664
+ or pattern[i + 1] == "."
665
+ or pattern[i + 1] not in NON_LITERAL_SET
666
+ ):
667
+ literal += pattern[i]
668
+ i += 1
669
+ else:
670
+ break
671
+ if literal:
672
+ seq.append((literal, True))
673
+
674
+ return join_seq()
675
+
676
+ return self._add_rule(
677
+ name,
678
+ (
679
+ to_rule(transform())
680
+ if self._raw_pattern
681
+ else '"\\"" ' + to_rule(transform()) + ' "\\"" space'
682
+ ),
683
+ )
684
+
685
+ def _resolve_ref(self, ref):
686
+ ref_name = ref.split("/")[-1]
687
+ if ref_name not in self._rules and ref not in self._refs_being_resolved:
688
+ self._refs_being_resolved.add(ref)
689
+ resolved = self._refs[ref]
690
+ ref_name = self.visit(resolved, ref_name)
691
+ self._refs_being_resolved.remove(ref)
692
+ return ref_name
693
+
694
+ def _generate_constant_rule(self, value):
695
+ return self._format_literal(json.dumps(value))
696
+
697
+ def visit(self, schema, name):
698
+ schema_type = schema.get("type")
699
+ schema_format = schema.get("format")
700
+ rule_name = name + "-" if name in RESERVED_NAMES else name or "root"
701
+
702
+ if (ref := schema.get("$ref")) is not None:
703
+ return self._add_rule(rule_name, self._resolve_ref(ref))
704
+
705
+ elif "oneOf" in schema or "anyOf" in schema:
706
+ return self._add_rule(
707
+ rule_name,
708
+ self._generate_union_rule(name, schema.get("oneOf") or schema["anyOf"]),
709
+ )
710
+
711
+ elif isinstance(schema_type, list):
712
+ return self._add_rule(
713
+ rule_name,
714
+ self._generate_union_rule(name, [{"type": t} for t in schema_type]),
715
+ )
716
+
717
+ elif "const" in schema:
718
+ return self._add_rule(
719
+ rule_name, self._generate_constant_rule(schema["const"])
720
+ )
721
+
722
+ elif "enum" in schema:
723
+ rule = " | ".join((self._generate_constant_rule(v) for v in schema["enum"]))
724
+ return self._add_rule(rule_name, rule)
725
+
726
+ elif schema_type in (None, "object") and (
727
+ "properties" in schema
728
+ or (
729
+ "additionalProperties" in schema
730
+ and schema["additionalProperties"] is not True
731
+ )
732
+ ):
733
+ required = set(schema.get("required", []))
734
+ properties = list(schema.get("properties", {}).items())
735
+ return self._add_rule(
736
+ rule_name,
737
+ self._build_object_rule(
738
+ properties, required, name, schema.get("additionalProperties")
739
+ ),
740
+ )
741
+
742
+ elif schema_type in (None, "object") and "allOf" in schema:
743
+ required = set()
744
+ properties = []
745
+ hybrid_name = name
746
+
747
+ def add_component(comp_schema, is_required):
748
+ if (ref := comp_schema.get("$ref")) is not None:
749
+ comp_schema = self._refs[ref]
750
+
751
+ if "properties" in comp_schema:
752
+ for prop_name, prop_schema in comp_schema["properties"].items():
753
+ properties.append((prop_name, prop_schema))
754
+ if is_required:
755
+ required.add(prop_name)
756
+
757
+ for t in schema["allOf"]:
758
+ if "anyOf" in t:
759
+ for tt in t["anyOf"]:
760
+ add_component(tt, is_required=False)
761
+ else:
762
+ add_component(t, is_required=True)
763
+
764
+ return self._add_rule(
765
+ rule_name,
766
+ self._build_object_rule(
767
+ properties, required, hybrid_name, additional_properties=[]
768
+ ),
769
+ )
770
+
771
+ elif schema_type in (None, "array") and (
772
+ "items" in schema or "prefixItems" in schema
773
+ ):
774
+ items = schema.get("items") or schema["prefixItems"]
775
+ if isinstance(items, list):
776
+ return self._add_rule(
777
+ rule_name,
778
+ '"[" space '
779
+ + ' "," space '.join(
780
+ self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
781
+ for i, item in enumerate(items)
782
+ )
783
+ + ' "]" space',
784
+ )
785
+ else:
786
+ item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
787
+ min_items = schema.get("minItems", 0)
788
+ max_items = schema.get("maxItems")
789
+ return self._add_rule(
790
+ rule_name,
791
+ '"[" space '
792
+ + _build_repetition(
793
+ item_rule_name, min_items, max_items, separator_rule='"," space'
794
+ )
795
+ + ' "]" space',
796
+ )
797
+
798
+ elif schema_type in (None, "string") and "pattern" in schema:
799
+ return self._visit_pattern(schema["pattern"], rule_name)
800
+
801
+ elif schema_type in (None, "string") and re.match(
802
+ r"^uuid[1-5]?$", schema_format or ""
803
+ ):
804
+ return self._add_primitive(
805
+ "root" if rule_name == "root" else schema_format,
806
+ PRIMITIVE_RULES["uuid"],
807
+ )
808
+
809
+ elif (
810
+ schema_type in (None, "string")
811
+ and f"{schema_format}-string" in STRING_FORMAT_RULES
812
+ ):
813
+ prim_name = f"{schema_format}-string"
814
+ return self._add_rule(
815
+ rule_name,
816
+ self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]),
817
+ )
818
+
819
+ elif schema_type == "string" and (
820
+ "minLength" in schema or "maxLength" in schema
821
+ ):
822
+ char_rule = self._add_primitive("char", PRIMITIVE_RULES["char"])
823
+ min_len = schema.get("minLength", 0)
824
+ max_len = schema.get("maxLength")
825
+
826
+ return self._add_rule(
827
+ rule_name,
828
+ r'"\"" '
829
+ + _build_repetition(char_rule, min_len, max_len)
830
+ + r' "\"" space',
831
+ )
832
+
833
+ elif (schema_type == "object") or (len(schema) == 0):
834
+ return self._add_rule(
835
+ rule_name, self._add_primitive("object", PRIMITIVE_RULES["object"])
836
+ )
837
+
838
+ else:
839
+ assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}"
840
+ # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
841
+ return self._add_primitive(
842
+ "root" if rule_name == "root" else schema_type,
843
+ PRIMITIVE_RULES[schema_type],
844
+ )
845
+
846
+ def _add_primitive(self, name: str, rule: BuiltinRule):
847
+ n = self._add_rule(name, rule.content)
848
+
849
+ for dep in rule.deps:
850
+ dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep)
851
+ assert dep_rule, f"Rule {dep} not known"
852
+ if dep not in self._rules:
853
+ self._add_primitive(dep, dep_rule)
854
+ return n
855
+
856
+ def _build_object_rule(
857
+ self,
858
+ properties: List[Tuple[str, Any]],
859
+ required: Set[str],
860
+ name: str,
861
+ additional_properties: Union[bool, Any],
862
+ ):
863
+ prop_order = self._prop_order
864
+ # sort by position in prop_order (if specified) then by original order
865
+ sorted_props = [
866
+ kv[0]
867
+ for _, kv in sorted(
868
+ enumerate(properties),
869
+ key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]),
870
+ )
871
+ ]
872
+
873
+ prop_kv_rule_names = {}
874
+ for prop_name, prop_schema in properties:
875
+ prop_rule_name = self.visit(
876
+ prop_schema, f'{name}{"-" if name else ""}{prop_name}'
877
+ )
878
+ prop_kv_rule_names[prop_name] = self._add_rule(
879
+ f'{name}{"-" if name else ""}{prop_name}-kv',
880
+ rf'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}',
881
+ )
882
+ required_props = [k for k in sorted_props if k in required]
883
+ optional_props = [k for k in sorted_props if k not in required]
884
+
885
+ if additional_properties == True or isinstance(additional_properties, dict):
886
+ sub_name = f'{name}{"-" if name else ""}additional'
887
+ value_rule = self.visit(
888
+ {} if additional_properties == True else additional_properties,
889
+ f"{sub_name}-value",
890
+ )
891
+ prop_kv_rule_names["*"] = self._add_rule(
892
+ f"{sub_name}-kv",
893
+ self._add_primitive("string", PRIMITIVE_RULES["string"])
894
+ + f' ":" space {value_rule}',
895
+ )
896
+ optional_props.append("*")
897
+
898
+ rule = '"{" space '
899
+ rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props)
900
+
901
+ if optional_props:
902
+ rule += " ("
903
+ if required_props:
904
+ rule += ' "," space ( '
905
+
906
+ def get_recursive_refs(ks, first_is_optional):
907
+ [k, *rest] = ks
908
+ kv_rule_name = prop_kv_rule_names[k]
909
+ if k == "*":
910
+ res = self._add_rule(
911
+ f'{name}{"-" if name else ""}additional-kvs',
912
+ f'{kv_rule_name} ( "," space ' + kv_rule_name + " )*",
913
+ )
914
+ elif first_is_optional:
915
+ res = f'( "," space {kv_rule_name} )?'
916
+ else:
917
+ res = kv_rule_name
918
+ if len(rest) > 0:
919
+ res += " " + self._add_rule(
920
+ f'{name}{"-" if name else ""}{k}-rest',
921
+ get_recursive_refs(rest, first_is_optional=True),
922
+ )
923
+ return res
924
+
925
+ rule += " | ".join(
926
+ get_recursive_refs(optional_props[i:], first_is_optional=False)
927
+ for i in range(len(optional_props))
928
+ )
929
+ if required_props:
930
+ rule += " )"
931
+ rule += " )?"
932
+
933
+ rule += ' "}" space'
934
+
935
+ return rule
936
+
937
+ def format_grammar(self):
938
+ return "\n".join(
939
+ f"{name} ::= {rule}"
940
+ for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0])
941
+ )
942
+
943
+
944
+ def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None):
945
+ prop_order = prop_order or []
946
+ schema = json.loads(schema)
947
+ prop_order = {name: idx for idx, name in enumerate(prop_order)}
948
+ converter = SchemaConverter(
949
+ prop_order=prop_order, allow_fetch=False, dotall=False, raw_pattern=False
950
+ )
951
+ schema = converter.resolve_refs(schema, "stdin")
952
+ converter.visit(schema, "")
953
+ return converter.format_grammar()
llama_cpp/llama_speculative.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+
3
+ from typing import Any
4
+
5
+ import numpy as np
6
+ import numpy.typing as npt
7
+
8
+
9
+ class LlamaDraftModel(abc.ABC):
10
+ @abc.abstractmethod
11
+ def __call__(
12
+ self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
13
+ ) -> npt.NDArray[np.intc]:
14
+ raise NotImplementedError()
15
+
16
+
17
+ class LlamaPromptLookupDecoding(LlamaDraftModel):
18
+ """Based on https://github.com/apoorvumang/prompt-lookup-decoding"""
19
+
20
+ def __init__(self, max_ngram_size: int = 2, num_pred_tokens: int = 10):
21
+ self.max_ngram_size = max_ngram_size
22
+ self.num_pred_tokens = num_pred_tokens
23
+
24
+ @staticmethod
25
+ def find_candidate_pred_tokens(
26
+ input_ids: npt.NDArray[np.intc],
27
+ max_ngram_size: int,
28
+ num_pred_tokens: int,
29
+ ):
30
+ input_length = input_ids.shape[0]
31
+
32
+ for ngram_size in range(min(max_ngram_size, input_length - 1), 0, -1):
33
+ # Create sliding windows of size ngram_size
34
+ windows = np.lib.stride_tricks.sliding_window_view(input_ids, (ngram_size,))
35
+
36
+ # Convert ngram to an array for comparison
37
+ ngram_array = input_ids[-ngram_size:]
38
+
39
+ # Find where the windows match the ngram
40
+ matches = np.all(windows == ngram_array, axis=1)
41
+
42
+ # Get the indices of matches
43
+ match_indices = np.nonzero(matches)[0]
44
+
45
+ # Iterate through match indices to find a valid continuation
46
+ for idx in match_indices:
47
+ start_idx = idx + ngram_size
48
+ end_idx = start_idx + num_pred_tokens
49
+ end_idx = min(end_idx, input_length)
50
+
51
+ if start_idx < end_idx:
52
+ return input_ids[start_idx:end_idx]
53
+
54
+ # If no match is found, return an empty array
55
+ return np.array([], dtype=np.intc)
56
+
57
+ def __call__(
58
+ self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
59
+ ) -> npt.NDArray[np.intc]:
60
+ return self.find_candidate_pred_tokens(
61
+ input_ids=input_ids,
62
+ max_ngram_size=self.max_ngram_size,
63
+ num_pred_tokens=self.num_pred_tokens,
64
+ )
llama_cpp/llama_tokenizer.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import abc
4
+ from typing import (
5
+ List,
6
+ Optional,
7
+ Any,
8
+ )
9
+
10
+ import llama_cpp
11
+ from llama_cpp.llama_types import List
12
+
13
+
14
+ class BaseLlamaTokenizer(abc.ABC):
15
+ @abc.abstractmethod
16
+ def tokenize(
17
+ self, text: bytes, add_bos: bool = True, special: bool = True
18
+ ) -> List[int]:
19
+ """Tokenize the text into tokens.
20
+
21
+ Args:
22
+ text: The utf-8 encoded string to tokenize.
23
+ add_bos: Whether to add a beginning of sequence token.
24
+ special: Whether to tokenize special tokens.
25
+ """
26
+ raise NotImplementedError
27
+
28
+ @abc.abstractmethod
29
+ def detokenize(
30
+ self,
31
+ tokens: List[int],
32
+ prev_tokens: Optional[List[int]] = None,
33
+ special: bool = False,
34
+ ) -> bytes:
35
+ """Detokenize the tokens into text.
36
+
37
+ Args:
38
+ tokens: The list of tokens to detokenize.
39
+ prev_tokens: The list of previous tokens. Offset mapping will be performed if provided.
40
+ special: Whether to detokenize special tokens.
41
+ """
42
+ raise NotImplementedError
43
+
44
+
45
+ class LlamaTokenizer(BaseLlamaTokenizer):
46
+ def __init__(self, llama: llama_cpp.Llama):
47
+ self._model = llama._model # type: ignore
48
+
49
+ def tokenize(
50
+ self, text: bytes, add_bos: bool = True, special: bool = True
51
+ ) -> List[int]:
52
+ return self._model.tokenize(text, add_bos=add_bos, special=special)
53
+
54
+ def detokenize(
55
+ self,
56
+ tokens: List[int],
57
+ prev_tokens: Optional[List[int]] = None,
58
+ special: bool = False,
59
+ ) -> bytes:
60
+ return self._model.detokenize(tokens, special=special)
61
+
62
+ def encode(
63
+ self, text: str, add_bos: bool = True, special: bool = True
64
+ ) -> List[int]:
65
+ return self.tokenize(
66
+ text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special
67
+ )
68
+
69
+ def decode(self, tokens: List[int]) -> str:
70
+ return self.detokenize(tokens).decode("utf-8", errors="ignore")
71
+
72
+ @classmethod
73
+ def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
74
+ return cls(llama_cpp.Llama(model_path=path, vocab_only=True))
75
+
76
+
77
+ class LlamaHFTokenizer(BaseLlamaTokenizer):
78
+ def __init__(self, hf_tokenizer: Any):
79
+ self.hf_tokenizer = hf_tokenizer
80
+
81
+ def tokenize(
82
+ self, text: bytes, add_bos: bool = True, special: bool = True
83
+ ) -> List[int]:
84
+ return self.hf_tokenizer.encode(
85
+ text.decode("utf-8", errors="ignore"), add_special_tokens=special
86
+ )
87
+
88
+ def detokenize(
89
+ self,
90
+ tokens: List[int],
91
+ prev_tokens: Optional[List[int]] = None,
92
+ special: bool = False,
93
+ ) -> bytes:
94
+ skip_special_tokens = not special
95
+ if prev_tokens is not None:
96
+ text = self.hf_tokenizer.decode(
97
+ prev_tokens + tokens, skip_special_tokens=skip_special_tokens
98
+ ).encode("utf-8", errors="ignore")
99
+ prev_text = self.hf_tokenizer.decode(
100
+ prev_tokens, skip_special_tokens=skip_special_tokens
101
+ ).encode("utf-8", errors="ignore")
102
+ return text[len(prev_text) :]
103
+ else:
104
+ return self.hf_tokenizer.decode(
105
+ tokens, skip_special_tokens=skip_special_tokens
106
+ ).encode("utf-8", errors="ignore")
107
+
108
+ @classmethod
109
+ def from_pretrained(cls, pretrained_model_name_or_path: str) -> "LlamaHFTokenizer":
110
+ try:
111
+ from transformers import AutoTokenizer
112
+ except ImportError:
113
+ raise ImportError(
114
+ "The `transformers` library is required to use the `HFTokenizer`."
115
+ "You can install it with `pip install transformers`."
116
+ )
117
+ hf_tokenizer = AutoTokenizer.from_pretrained(
118
+ pretrained_model_name_or_path=pretrained_model_name_or_path
119
+ )
120
+ return cls(hf_tokenizer)
llama_cpp/llama_types.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Types and request signatures for OpenAI compatibility
2
+
3
+ NOTE: These types may change to match the OpenAI OpenAPI specification.
4
+
5
+ Based on the OpenAI OpenAPI specification:
6
+ https://github.com/openai/openai-openapi/blob/master/openapi.yaml
7
+
8
+ """
9
+
10
+ from typing import Any, List, Optional, Dict, Union
11
+ from typing_extensions import TypedDict, NotRequired, Literal
12
+
13
+
14
+ # NOTE: Defining this correctly using annotations seems to break pydantic validation.
15
+ # This is a workaround until we can figure out how to do this correctly
16
+ # JsonType = Union[None, int, str, bool, List["JsonType"], Dict[str, "JsonType"]]
17
+ JsonType = Union[None, int, str, bool, List[Any], Dict[str, Any]]
18
+
19
+
20
+ class EmbeddingUsage(TypedDict):
21
+ prompt_tokens: int
22
+ total_tokens: int
23
+
24
+
25
+ class Embedding(TypedDict):
26
+ index: int
27
+ object: str
28
+ embedding: Union[List[float], List[List[float]]]
29
+
30
+
31
+ class CreateEmbeddingResponse(TypedDict):
32
+ object: Literal["list"]
33
+ model: str
34
+ data: List[Embedding]
35
+ usage: EmbeddingUsage
36
+
37
+
38
+ class CompletionLogprobs(TypedDict):
39
+ text_offset: List[int]
40
+ token_logprobs: List[Optional[float]]
41
+ tokens: List[str]
42
+ top_logprobs: List[Optional[Dict[str, float]]]
43
+
44
+
45
+ class CompletionChoice(TypedDict):
46
+ text: str
47
+ index: int
48
+ logprobs: Optional[CompletionLogprobs]
49
+ finish_reason: Optional[Literal["stop", "length"]]
50
+
51
+
52
+ class CompletionUsage(TypedDict):
53
+ prompt_tokens: int
54
+ completion_tokens: int
55
+ total_tokens: int
56
+
57
+
58
+ class CreateCompletionResponse(TypedDict):
59
+ id: str
60
+ object: Literal["text_completion"]
61
+ created: int
62
+ model: str
63
+ choices: List[CompletionChoice]
64
+ usage: NotRequired[CompletionUsage]
65
+
66
+
67
+ class ChatCompletionResponseFunctionCall(TypedDict):
68
+ name: str
69
+ arguments: str
70
+
71
+
72
+ class ChatCompletionResponseMessage(TypedDict):
73
+ content: Optional[str]
74
+ tool_calls: NotRequired["ChatCompletionMessageToolCalls"]
75
+ role: Literal["assistant", "function"] # NOTE: "function" may be incorrect here
76
+ function_call: NotRequired[ChatCompletionResponseFunctionCall] # DEPRECATED
77
+
78
+
79
+ class ChatCompletionFunction(TypedDict):
80
+ name: str
81
+ description: NotRequired[str]
82
+ parameters: Dict[str, JsonType] # TODO: make this more specific
83
+
84
+
85
+ class ChatCompletionTopLogprobToken(TypedDict):
86
+ token: str
87
+ logprob: float
88
+ bytes: Optional[List[int]]
89
+
90
+
91
+ class ChatCompletionLogprobToken(ChatCompletionTopLogprobToken):
92
+ token: str
93
+ logprob: float
94
+ bytes: Optional[List[int]]
95
+ top_logprobs: List[ChatCompletionTopLogprobToken]
96
+
97
+
98
+ class ChatCompletionLogprobs(TypedDict):
99
+ content: Optional[List[ChatCompletionLogprobToken]]
100
+ refusal: Optional[List[ChatCompletionLogprobToken]]
101
+
102
+
103
+ class ChatCompletionResponseChoice(TypedDict):
104
+ index: int
105
+ message: "ChatCompletionResponseMessage"
106
+ logprobs: Optional[ChatCompletionLogprobs]
107
+ finish_reason: Optional[str]
108
+
109
+
110
+ class CreateChatCompletionResponse(TypedDict):
111
+ id: str
112
+ object: Literal["chat.completion"]
113
+ created: int
114
+ model: str
115
+ choices: List["ChatCompletionResponseChoice"]
116
+ usage: CompletionUsage
117
+
118
+
119
+ class ChatCompletionMessageToolCallChunkFunction(TypedDict):
120
+ name: Optional[str]
121
+ arguments: str
122
+
123
+
124
+ class ChatCompletionMessageToolCallChunk(TypedDict):
125
+ index: int
126
+ id: NotRequired[str]
127
+ type: Literal["function"]
128
+ function: ChatCompletionMessageToolCallChunkFunction
129
+
130
+
131
+ class ChatCompletionStreamResponseDeltaEmpty(TypedDict):
132
+ pass
133
+
134
+
135
+ class ChatCompletionStreamResponseDeltaFunctionCall(TypedDict):
136
+ name: str
137
+ arguments: str
138
+
139
+
140
+ class ChatCompletionStreamResponseDelta(TypedDict):
141
+ content: NotRequired[Optional[str]]
142
+ function_call: NotRequired[
143
+ Optional[ChatCompletionStreamResponseDeltaFunctionCall]
144
+ ] # DEPRECATED
145
+ tool_calls: NotRequired[Optional[List[ChatCompletionMessageToolCallChunk]]]
146
+ role: NotRequired[Optional[Literal["system", "user", "assistant", "tool"]]]
147
+
148
+
149
+ class ChatCompletionStreamResponseChoice(TypedDict):
150
+ index: int
151
+ delta: Union[
152
+ ChatCompletionStreamResponseDelta, ChatCompletionStreamResponseDeltaEmpty
153
+ ]
154
+ finish_reason: Optional[Literal["stop", "length", "tool_calls", "function_call"]]
155
+ logprobs: NotRequired[Optional[ChatCompletionLogprobs]]
156
+
157
+
158
+ class CreateChatCompletionStreamResponse(TypedDict):
159
+ id: str
160
+ model: str
161
+ object: Literal["chat.completion.chunk"]
162
+ created: int
163
+ choices: List[ChatCompletionStreamResponseChoice]
164
+
165
+
166
+ class ChatCompletionFunctions(TypedDict):
167
+ name: str
168
+ description: NotRequired[str]
169
+ parameters: Dict[str, JsonType] # TODO: make this more specific
170
+
171
+
172
+ class ChatCompletionFunctionCallOption(TypedDict):
173
+ name: str
174
+
175
+
176
+ class ChatCompletionRequestResponseFormat(TypedDict):
177
+ type: Literal["text", "json_object"]
178
+ schema: NotRequired[
179
+ JsonType
180
+ ] # https://docs.endpoints.anyscale.com/guides/json_mode/
181
+
182
+
183
+ class ChatCompletionRequestMessageContentPartText(TypedDict):
184
+ type: Literal["text"]
185
+ text: str
186
+
187
+
188
+ class ChatCompletionRequestMessageContentPartImageImageUrl(TypedDict):
189
+ url: str
190
+ detail: NotRequired[Literal["auto", "low", "high"]]
191
+
192
+
193
+ class ChatCompletionRequestMessageContentPartImage(TypedDict):
194
+ type: Literal["image_url"]
195
+ image_url: Union[str, ChatCompletionRequestMessageContentPartImageImageUrl]
196
+
197
+
198
+ ChatCompletionRequestMessageContentPart = Union[
199
+ ChatCompletionRequestMessageContentPartText,
200
+ ChatCompletionRequestMessageContentPartImage,
201
+ ]
202
+
203
+
204
+ class ChatCompletionRequestSystemMessage(TypedDict):
205
+ role: Literal["system"]
206
+ content: Optional[str]
207
+
208
+
209
+ class ChatCompletionRequestUserMessage(TypedDict):
210
+ role: Literal["user"]
211
+ content: Optional[Union[str, List[ChatCompletionRequestMessageContentPart]]]
212
+
213
+
214
+ class ChatCompletionMessageToolCallFunction(TypedDict):
215
+ name: str
216
+ arguments: str
217
+
218
+
219
+ class ChatCompletionMessageToolCall(TypedDict):
220
+ id: str
221
+ type: Literal["function"]
222
+ function: ChatCompletionMessageToolCallFunction
223
+
224
+
225
+ ChatCompletionMessageToolCalls = List[ChatCompletionMessageToolCall]
226
+
227
+
228
+ class ChatCompletionRequestAssistantMessageFunctionCall(TypedDict):
229
+ name: str
230
+ arguments: str
231
+
232
+
233
+ class ChatCompletionRequestAssistantMessage(TypedDict):
234
+ role: Literal["assistant"]
235
+ content: NotRequired[str]
236
+ tool_calls: NotRequired[ChatCompletionMessageToolCalls]
237
+ function_call: NotRequired[
238
+ ChatCompletionRequestAssistantMessageFunctionCall
239
+ ] # DEPRECATED
240
+
241
+
242
+ class ChatCompletionRequestToolMessage(TypedDict):
243
+ role: Literal["tool"]
244
+ content: Optional[str]
245
+ tool_call_id: str
246
+
247
+
248
+ class ChatCompletionRequestFunctionMessage(TypedDict):
249
+ role: Literal["function"]
250
+ content: Optional[str]
251
+ name: str
252
+
253
+
254
+ ChatCompletionRequestMessage = Union[
255
+ ChatCompletionRequestSystemMessage,
256
+ ChatCompletionRequestUserMessage,
257
+ ChatCompletionRequestAssistantMessage,
258
+ ChatCompletionRequestUserMessage,
259
+ ChatCompletionRequestToolMessage,
260
+ ChatCompletionRequestFunctionMessage,
261
+ ]
262
+
263
+
264
+ class ChatCompletionRequestFunctionCallOption(TypedDict):
265
+ name: str
266
+
267
+
268
+ ChatCompletionRequestFunctionCall = Union[
269
+ Literal["none", "auto"], ChatCompletionRequestFunctionCallOption
270
+ ]
271
+
272
+ ChatCompletionFunctionParameters = Dict[str, JsonType] # TODO: make this more specific
273
+
274
+
275
+ class ChatCompletionToolFunction(TypedDict):
276
+ name: str
277
+ description: NotRequired[str]
278
+ parameters: ChatCompletionFunctionParameters
279
+
280
+
281
+ class ChatCompletionTool(TypedDict):
282
+ type: Literal["function"]
283
+ function: ChatCompletionToolFunction
284
+
285
+
286
+ class ChatCompletionNamedToolChoiceFunction(TypedDict):
287
+ name: str
288
+
289
+
290
+ class ChatCompletionNamedToolChoice(TypedDict):
291
+ type: Literal["function"]
292
+ function: ChatCompletionNamedToolChoiceFunction
293
+
294
+
295
+ ChatCompletionToolChoiceOption = Union[
296
+ Literal["none", "auto", "required"], ChatCompletionNamedToolChoice
297
+ ]
298
+
299
+
300
+ # NOTE: The following type names are not part of the OpenAI OpenAPI specification
301
+ # and will be removed in a future major release.
302
+
303
+ EmbeddingData = Embedding
304
+ CompletionChunk = CreateCompletionResponse
305
+ Completion = CreateCompletionResponse
306
+ CreateCompletionStreamResponse = CreateCompletionResponse
307
+ ChatCompletionMessage = ChatCompletionResponseMessage
308
+ ChatCompletionChoice = ChatCompletionResponseChoice
309
+ ChatCompletion = CreateChatCompletionResponse
310
+ ChatCompletionChunkDeltaEmpty = ChatCompletionStreamResponseDeltaEmpty
311
+ ChatCompletionChunkChoice = ChatCompletionStreamResponseChoice
312
+ ChatCompletionChunkDelta = ChatCompletionStreamResponseDelta
313
+ ChatCompletionChunk = CreateChatCompletionStreamResponse
314
+ ChatCompletionStreamResponse = CreateChatCompletionStreamResponse
315
+ ChatCompletionResponseFunction = ChatCompletionFunction
316
+ ChatCompletionFunctionCall = ChatCompletionResponseFunctionCall
llama_cpp/llava_cpp.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from ctypes import (
5
+ c_bool,
6
+ c_char_p,
7
+ c_int,
8
+ c_uint8,
9
+ c_float,
10
+ c_void_p,
11
+ POINTER,
12
+ _Pointer, # type: ignore
13
+ Structure,
14
+ )
15
+ import pathlib
16
+ from typing import (
17
+ Union,
18
+ NewType,
19
+ Optional,
20
+ TYPE_CHECKING,
21
+ )
22
+
23
+ import llama_cpp.llama_cpp as llama_cpp
24
+
25
+ from llama_cpp._ctypes_extensions import (
26
+ load_shared_library,
27
+ ctypes_function_for_shared_library,
28
+ )
29
+
30
+ if TYPE_CHECKING:
31
+ from llama_cpp._ctypes_extensions import (
32
+ CtypesArray,
33
+ )
34
+
35
+
36
+ # Specify the base name of the shared library to load
37
+ _libllava_base_name = "llava"
38
+ _libllava_override_path = os.environ.get("LLAVA_CPP_LIB")
39
+ _libllava_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" if _libllava_override_path is None else pathlib.Path()
40
+
41
+ # Load the library
42
+ _libllava = load_shared_library(_libllava_base_name, _libllava_base_path)
43
+
44
+ ctypes_function = ctypes_function_for_shared_library(_libllava)
45
+
46
+
47
+ ################################################
48
+ # llava.h
49
+ ################################################
50
+
51
+ # struct clip_ctx;
52
+ clip_ctx_p = NewType("clip_ctx_p", int)
53
+ clip_ctx_p_ctypes = c_void_p
54
+
55
+
56
+ # struct llava_image_embed {
57
+ # float * embed;
58
+ # int n_image_pos;
59
+ # };
60
+ class llava_image_embed(Structure):
61
+ _fields_ = [
62
+ ("embed", POINTER(c_float)),
63
+ ("n_image_pos", c_int),
64
+ ]
65
+
66
+
67
+ # /** sanity check for clip <-> llava embed size match */
68
+ # LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip);
69
+ @ctypes_function(
70
+ "llava_validate_embed_size",
71
+ [llama_cpp.llama_context_p_ctypes, clip_ctx_p_ctypes],
72
+ c_bool,
73
+ )
74
+ def llava_validate_embed_size(
75
+ ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p, /
76
+ ) -> bool:
77
+ ...
78
+
79
+
80
+ # /** build an image embed from image file bytes */
81
+ # LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length);
82
+ @ctypes_function(
83
+ "llava_image_embed_make_with_bytes",
84
+ [clip_ctx_p_ctypes, c_int, POINTER(c_uint8), c_int],
85
+ POINTER(llava_image_embed),
86
+ )
87
+ def llava_image_embed_make_with_bytes(
88
+ ctx_clip: clip_ctx_p,
89
+ n_threads: Union[c_int, int],
90
+ image_bytes: CtypesArray[c_uint8],
91
+ image_bytes_length: Union[c_int, int],
92
+ /,
93
+ ) -> "_Pointer[llava_image_embed]":
94
+ ...
95
+
96
+
97
+ # /** build an image embed from a path to an image filename */
98
+ # LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path);
99
+ @ctypes_function(
100
+ "llava_image_embed_make_with_filename",
101
+ [clip_ctx_p_ctypes, c_int, c_char_p],
102
+ POINTER(llava_image_embed),
103
+ )
104
+ def llava_image_embed_make_with_filename(
105
+ ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes, /
106
+ ) -> "_Pointer[llava_image_embed]":
107
+ ...
108
+
109
+
110
+ # LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed);
111
+ # /** free an embedding made with llava_image_embed_make_* */
112
+ @ctypes_function("llava_image_embed_free", [POINTER(llava_image_embed)], None)
113
+ def llava_image_embed_free(embed: "_Pointer[llava_image_embed]", /):
114
+ ...
115
+
116
+
117
+ # /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
118
+ # LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past);
119
+ @ctypes_function(
120
+ "llava_eval_image_embed",
121
+ [
122
+ llama_cpp.llama_context_p_ctypes,
123
+ POINTER(llava_image_embed),
124
+ c_int,
125
+ POINTER(c_int),
126
+ ],
127
+ c_bool,
128
+ )
129
+ def llava_eval_image_embed(
130
+ ctx_llama: llama_cpp.llama_context_p,
131
+ embed: "_Pointer[llava_image_embed]",
132
+ n_batch: Union[c_int, int],
133
+ n_past: "_Pointer[c_int]",
134
+ /,
135
+ ) -> bool:
136
+ ...
137
+
138
+
139
+ ################################################
140
+ # clip.h
141
+ ################################################
142
+
143
+
144
+ # /** load mmproj model */
145
+ # CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity);
146
+ @ctypes_function("clip_model_load", [c_char_p, c_int], clip_ctx_p_ctypes)
147
+ def clip_model_load(
148
+ fname: bytes, verbosity: Union[c_int, int], /
149
+ ) -> Optional[clip_ctx_p]:
150
+ ...
151
+
152
+
153
+ # /** free mmproj model */
154
+ # CLIP_API void clip_free(struct clip_ctx * ctx);
155
+ @ctypes_function("clip_free", [clip_ctx_p_ctypes], None)
156
+ def clip_free(ctx: clip_ctx_p, /):
157
+ ...
158
+
llama_cpp/mtmd_cpp.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from ctypes import (
5
+ c_bool,
6
+ c_char_p,
7
+ c_int,
8
+ c_uint8,
9
+ c_uint32,
10
+ c_float,
11
+ c_void_p,
12
+ c_size_t,
13
+ POINTER,
14
+ _Pointer, # type: ignore
15
+ Structure,
16
+ byref,
17
+ )
18
+ import pathlib
19
+ from typing import (
20
+ Union,
21
+ NewType,
22
+ Optional,
23
+ TYPE_CHECKING,
24
+ )
25
+
26
+ import llama_cpp.llama_cpp as llama_cpp
27
+
28
+ from llama_cpp._ctypes_extensions import (
29
+ load_shared_library,
30
+ ctypes_function_for_shared_library,
31
+ )
32
+
33
+ if TYPE_CHECKING:
34
+ from llama_cpp._ctypes_extensions import (
35
+ CtypesArray,
36
+ )
37
+
38
+
39
+ # Specify the base name of the shared library to load
40
+ _libmtmd_base_name = "mtmd"
41
+ _libmtmd_override_path = os.environ.get("MTMD_CPP_LIB")
42
+ _libmtmd_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" if _libmtmd_override_path is None else pathlib.Path()
43
+
44
+ # Load the library
45
+ _libmtmd = load_shared_library(_libmtmd_base_name, _libmtmd_base_path)
46
+
47
+ ctypes_function = ctypes_function_for_shared_library(_libmtmd)
48
+
49
+ ################################################
50
+ # mtmd.h types
51
+ ################################################
52
+
53
+ # Opaque types
54
+ mtmd_context_p = NewType("mtmd_context_p", int)
55
+ mtmd_context_p_ctypes = c_void_p
56
+
57
+ mtmd_bitmap_p = NewType("mtmd_bitmap_p", int)
58
+ mtmd_bitmap_p_ctypes = c_void_p
59
+
60
+ mtmd_image_tokens_p = NewType("mtmd_image_tokens_p", int)
61
+ mtmd_image_tokens_p_ctypes = c_void_p
62
+
63
+ mtmd_input_chunk_p = NewType("mtmd_input_chunk_p", int)
64
+ mtmd_input_chunk_p_ctypes = c_void_p
65
+
66
+ mtmd_input_chunks_p = NewType("mtmd_input_chunks_p", int)
67
+ mtmd_input_chunks_p_ctypes = c_void_p
68
+
69
+ # Enums
70
+ MTMD_INPUT_CHUNK_TYPE_TEXT = 0
71
+ MTMD_INPUT_CHUNK_TYPE_IMAGE = 1
72
+ MTMD_INPUT_CHUNK_TYPE_AUDIO = 2
73
+
74
+ # Structures
75
+ class mtmd_context_params(Structure):
76
+ _fields_ = [
77
+ ("use_gpu", c_bool),
78
+ ("print_timings", c_bool),
79
+ ("n_threads", c_int),
80
+ ("verbosity", c_int), # ggml_log_level
81
+ ("image_marker", c_char_p),
82
+ ("media_marker", c_char_p),
83
+ ]
84
+
85
+ class mtmd_input_text(Structure):
86
+ _fields_ = [
87
+ ("text", c_char_p),
88
+ ("add_special", c_bool),
89
+ ("parse_special", c_bool),
90
+ ]
91
+
92
+ ################################################
93
+ # mtmd.h functions
94
+ ################################################
95
+
96
+ # MTMD_API const char * mtmd_default_marker(void);
97
+ @ctypes_function("mtmd_default_marker", [], c_char_p)
98
+ def mtmd_default_marker() -> bytes:
99
+ ...
100
+
101
+ # MTMD_API struct mtmd_context_params mtmd_context_params_default(void);
102
+ @ctypes_function("mtmd_context_params_default", [], mtmd_context_params)
103
+ def mtmd_context_params_default() -> mtmd_context_params:
104
+ ...
105
+
106
+ # MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
107
+ # const struct llama_model * text_model,
108
+ # const struct mtmd_context_params ctx_params);
109
+ @ctypes_function(
110
+ "mtmd_init_from_file",
111
+ [c_char_p, llama_cpp.llama_model_p_ctypes, mtmd_context_params],
112
+ mtmd_context_p_ctypes
113
+ )
114
+ def mtmd_init_from_file(
115
+ mmproj_fname: bytes,
116
+ text_model: llama_cpp.llama_model_p,
117
+ ctx_params: mtmd_context_params,
118
+ /,
119
+ ) -> Optional[mtmd_context_p]:
120
+ ...
121
+
122
+ # MTMD_API void mtmd_free(mtmd_context * ctx);
123
+ @ctypes_function("mtmd_free", [mtmd_context_p_ctypes], None)
124
+ def mtmd_free(ctx: mtmd_context_p, /):
125
+ ...
126
+
127
+ # MTMD_API bool mtmd_support_vision(mtmd_context * ctx);
128
+ @ctypes_function("mtmd_support_vision", [mtmd_context_p_ctypes], c_bool)
129
+ def mtmd_support_vision(ctx: mtmd_context_p, /) -> bool:
130
+ ...
131
+
132
+ # MTMD_API mtmd_bitmap * mtmd_bitmap_init(uint32_t nx, uint32_t ny, const unsigned char * data);
133
+ @ctypes_function(
134
+ "mtmd_bitmap_init",
135
+ [c_uint32, c_uint32, POINTER(c_uint8)],
136
+ mtmd_bitmap_p_ctypes
137
+ )
138
+ def mtmd_bitmap_init(
139
+ nx: Union[c_uint32, int],
140
+ ny: Union[c_uint32, int],
141
+ data: CtypesArray[c_uint8],
142
+ /,
143
+ ) -> Optional[mtmd_bitmap_p]:
144
+ ...
145
+
146
+ # MTMD_API void mtmd_bitmap_free(mtmd_bitmap * bitmap);
147
+ @ctypes_function("mtmd_bitmap_free", [mtmd_bitmap_p_ctypes], None)
148
+ def mtmd_bitmap_free(bitmap: mtmd_bitmap_p, /):
149
+ ...
150
+
151
+ # MTMD_API mtmd_input_chunks * mtmd_input_chunks_init(void);
152
+ @ctypes_function("mtmd_input_chunks_init", [], mtmd_input_chunks_p_ctypes)
153
+ def mtmd_input_chunks_init() -> Optional[mtmd_input_chunks_p]:
154
+ ...
155
+
156
+ # MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks);
157
+ @ctypes_function("mtmd_input_chunks_free", [mtmd_input_chunks_p_ctypes], None)
158
+ def mtmd_input_chunks_free(chunks: mtmd_input_chunks_p, /):
159
+ ...
160
+
161
+ # MTMD_API size_t mtmd_input_chunks_size(const mtmd_input_chunks * chunks);
162
+ @ctypes_function("mtmd_input_chunks_size", [mtmd_input_chunks_p_ctypes], c_size_t)
163
+ def mtmd_input_chunks_size(chunks: mtmd_input_chunks_p, /) -> int:
164
+ ...
165
+
166
+ # MTMD_API const mtmd_input_chunk * mtmd_input_chunks_get(const mtmd_input_chunks * chunks, size_t idx);
167
+ @ctypes_function(
168
+ "mtmd_input_chunks_get",
169
+ [mtmd_input_chunks_p_ctypes, c_size_t],
170
+ mtmd_input_chunk_p_ctypes
171
+ )
172
+ def mtmd_input_chunks_get(
173
+ chunks: mtmd_input_chunks_p, idx: Union[c_size_t, int], /
174
+ ) -> Optional[mtmd_input_chunk_p]:
175
+ ...
176
+
177
+ # MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx,
178
+ # mtmd_input_chunks * output,
179
+ # const mtmd_input_text * text,
180
+ # const mtmd_bitmap ** bitmaps,
181
+ # size_t n_bitmaps);
182
+ @ctypes_function(
183
+ "mtmd_tokenize",
184
+ [
185
+ mtmd_context_p_ctypes,
186
+ mtmd_input_chunks_p_ctypes,
187
+ POINTER(mtmd_input_text),
188
+ POINTER(mtmd_bitmap_p_ctypes),
189
+ c_size_t,
190
+ ],
191
+ c_int,
192
+ )
193
+ def mtmd_tokenize(
194
+ ctx: mtmd_context_p,
195
+ output: mtmd_input_chunks_p,
196
+ text: "_Pointer[mtmd_input_text]",
197
+ bitmaps: CtypesArray[mtmd_bitmap_p_ctypes],
198
+ n_bitmaps: Union[c_size_t, int],
199
+ /,
200
+ ) -> int:
201
+ ...
202
+
203
+ # MTMD_API size_t mtmd_input_chunk_get_n_tokens(const mtmd_input_chunk * chunk);
204
+ @ctypes_function("mtmd_input_chunk_get_n_tokens", [mtmd_input_chunk_p_ctypes], c_size_t)
205
+ def mtmd_input_chunk_get_n_tokens(chunk: mtmd_input_chunk_p, /) -> int:
206
+ ...
207
+
208
+ # MTMD_API enum mtmd_input_chunk_type mtmd_input_chunk_get_type(const mtmd_input_chunk * chunk);
209
+ @ctypes_function("mtmd_input_chunk_get_type", [mtmd_input_chunk_p_ctypes], c_int)
210
+ def mtmd_input_chunk_get_type(chunk: mtmd_input_chunk_p, /) -> int:
211
+ ...
212
+
213
+ # MTMD_API const llama_token * mtmd_input_chunk_get_tokens_text(const mtmd_input_chunk * chunk, size_t * n_tokens_output);
214
+ @ctypes_function(
215
+ "mtmd_input_chunk_get_tokens_text",
216
+ [mtmd_input_chunk_p_ctypes, POINTER(c_size_t)],
217
+ POINTER(llama_cpp.llama_token)
218
+ )
219
+ def mtmd_input_chunk_get_tokens_text(
220
+ chunk: mtmd_input_chunk_p, n_tokens_output: "_Pointer[c_size_t]", /
221
+ ) -> Optional["_Pointer[llama_cpp.llama_token]"]:
222
+ ...
223
+
224
+ ################################################
225
+ # mtmd-helper.h functions
226
+ ################################################
227
+
228
+ # MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(mtmd_context * ctx, const unsigned char * buf, size_t len);
229
+ @ctypes_function(
230
+ "mtmd_helper_bitmap_init_from_buf",
231
+ [mtmd_context_p_ctypes, POINTER(c_uint8), c_size_t],
232
+ mtmd_bitmap_p_ctypes
233
+ )
234
+ def mtmd_helper_bitmap_init_from_buf(
235
+ ctx: mtmd_context_p,
236
+ buf: CtypesArray[c_uint8],
237
+ length: Union[c_size_t, int],
238
+ /,
239
+ ) -> Optional[mtmd_bitmap_p]:
240
+ ...
241
+
242
+ # MTMD_API size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks);
243
+ @ctypes_function("mtmd_helper_get_n_tokens", [mtmd_input_chunks_p_ctypes], c_size_t)
244
+ def mtmd_helper_get_n_tokens(chunks: mtmd_input_chunks_p, /) -> int:
245
+ ...
246
+
247
+ # MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
248
+ # struct llama_context * lctx,
249
+ # const mtmd_input_chunk * chunk,
250
+ # llama_pos n_past,
251
+ # llama_seq_id seq_id,
252
+ # int32_t n_batch,
253
+ # bool logits_last,
254
+ # llama_pos * new_n_past);
255
+ @ctypes_function(
256
+ "mtmd_helper_eval_chunk_single",
257
+ [
258
+ mtmd_context_p_ctypes,
259
+ llama_cpp.llama_context_p_ctypes,
260
+ mtmd_input_chunk_p_ctypes,
261
+ llama_cpp.llama_pos,
262
+ llama_cpp.llama_seq_id,
263
+ c_int,
264
+ c_bool,
265
+ POINTER(llama_cpp.llama_pos),
266
+ ],
267
+ c_int,
268
+ )
269
+ def mtmd_helper_eval_chunk_single(
270
+ ctx: mtmd_context_p,
271
+ lctx: llama_cpp.llama_context_p,
272
+ chunk: mtmd_input_chunk_p,
273
+ n_past: llama_cpp.llama_pos,
274
+ seq_id: llama_cpp.llama_seq_id,
275
+ n_batch: Union[c_int, int],
276
+ logits_last: Union[c_bool, bool],
277
+ new_n_past: "_Pointer[llama_cpp.llama_pos]",
278
+ /,
279
+ ) -> int:
280
+ ...
llama_cpp/py.typed ADDED
File without changes
llama_cpp/server/__init__.py ADDED
File without changes
llama_cpp/server/__main__.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Example FastAPI server for llama.cpp.
2
+
3
+ To run this example:
4
+
5
+ ```bash
6
+ pip install fastapi uvicorn sse-starlette pydantic-settings
7
+ export MODEL=../models/7B/...
8
+ ```
9
+
10
+ Then run:
11
+ ```
12
+ uvicorn llama_cpp.server.app:create_app --reload
13
+ ```
14
+
15
+ or
16
+
17
+ ```
18
+ python3 -m llama_cpp.server
19
+ ```
20
+
21
+ Then visit http://localhost:8000/docs to see the interactive API docs.
22
+
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import os
28
+ import sys
29
+ import argparse
30
+
31
+ import uvicorn
32
+
33
+ from llama_cpp.server.app import create_app
34
+ from llama_cpp.server.settings import (
35
+ Settings,
36
+ ServerSettings,
37
+ ModelSettings,
38
+ ConfigFileSettings,
39
+ )
40
+ from llama_cpp.server.cli import add_args_from_model, parse_model_from_args
41
+
42
+
43
+ def main():
44
+ description = "🦙 Llama.cpp python server. Host your own LLMs!🚀"
45
+ parser = argparse.ArgumentParser(description=description)
46
+
47
+ add_args_from_model(parser, Settings)
48
+ parser.add_argument(
49
+ "--config_file",
50
+ type=str,
51
+ help="Path to a config file to load.",
52
+ )
53
+ server_settings: ServerSettings | None = None
54
+ model_settings: list[ModelSettings] = []
55
+ args = parser.parse_args()
56
+ try:
57
+ # Load server settings from config_file if provided
58
+ config_file = os.environ.get("CONFIG_FILE", args.config_file)
59
+ if config_file:
60
+ if not os.path.exists(config_file):
61
+ raise ValueError(f"Config file {config_file} not found!")
62
+ with open(config_file, "rb") as f:
63
+ # Check if yaml file
64
+ if config_file.endswith(".yaml") or config_file.endswith(".yml"):
65
+ import yaml
66
+ import json
67
+
68
+ config_file_settings = ConfigFileSettings.model_validate_json(
69
+ json.dumps(yaml.safe_load(f))
70
+ )
71
+ else:
72
+ config_file_settings = ConfigFileSettings.model_validate_json(
73
+ f.read()
74
+ )
75
+ server_settings = ServerSettings.model_validate(config_file_settings)
76
+ model_settings = config_file_settings.models
77
+ else:
78
+ server_settings = parse_model_from_args(ServerSettings, args)
79
+ model_settings = [parse_model_from_args(ModelSettings, args)]
80
+ except Exception as e:
81
+ print(e, file=sys.stderr)
82
+ parser.print_help()
83
+ sys.exit(1)
84
+ assert server_settings is not None
85
+ assert model_settings is not None
86
+ app = create_app(
87
+ server_settings=server_settings,
88
+ model_settings=model_settings,
89
+ )
90
+ uvicorn.run(
91
+ app,
92
+ host=os.getenv("HOST", server_settings.host),
93
+ port=int(os.getenv("PORT", server_settings.port)),
94
+ ssl_keyfile=server_settings.ssl_keyfile,
95
+ ssl_certfile=server_settings.ssl_certfile,
96
+ )
97
+
98
+
99
+ if __name__ == "__main__":
100
+ main()
llama_cpp/server/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (177 Bytes). View file
 
llama_cpp/server/__pycache__/__main__.cpython-311.pyc ADDED
Binary file (4.12 kB). View file
 
llama_cpp/server/__pycache__/app.cpython-311.pyc ADDED
Binary file (23.4 kB). View file
 
llama_cpp/server/__pycache__/cli.cpython-311.pyc ADDED
Binary file (5.44 kB). View file
 
llama_cpp/server/__pycache__/errors.cpython-311.pyc ADDED
Binary file (8.14 kB). View file
 
llama_cpp/server/__pycache__/model.cpython-311.pyc ADDED
Binary file (12.7 kB). View file
 
llama_cpp/server/__pycache__/settings.cpython-311.pyc ADDED
Binary file (11.5 kB). View file
 
llama_cpp/server/__pycache__/types.cpython-311.pyc ADDED
Binary file (15.8 kB). View file
 
llama_cpp/server/app.py ADDED
@@ -0,0 +1,597 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import json
5
+ import typing
6
+ import contextlib
7
+
8
+ from anyio import Lock
9
+ from functools import partial
10
+ from typing import List, Optional, Union, Dict
11
+
12
+ import llama_cpp
13
+
14
+ import anyio
15
+ from anyio.streams.memory import MemoryObjectSendStream
16
+ from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
17
+ from fastapi import Depends, FastAPI, APIRouter, Request, HTTPException, status, Body
18
+ from fastapi.middleware import Middleware
19
+ from fastapi.middleware.cors import CORSMiddleware
20
+ from fastapi.security import HTTPBearer
21
+ from sse_starlette.sse import EventSourceResponse
22
+ from starlette_context.plugins import RequestIdPlugin # type: ignore
23
+ from starlette_context.middleware import RawContextMiddleware
24
+
25
+ from llama_cpp.server.model import (
26
+ LlamaProxy,
27
+ )
28
+ from llama_cpp.server.settings import (
29
+ ConfigFileSettings,
30
+ Settings,
31
+ ModelSettings,
32
+ ServerSettings,
33
+ )
34
+ from llama_cpp.server.types import (
35
+ CreateCompletionRequest,
36
+ CreateEmbeddingRequest,
37
+ CreateChatCompletionRequest,
38
+ ModelList,
39
+ TokenizeInputRequest,
40
+ TokenizeInputResponse,
41
+ TokenizeInputCountResponse,
42
+ DetokenizeInputRequest,
43
+ DetokenizeInputResponse,
44
+ )
45
+ from llama_cpp.server.errors import RouteErrorHandler
46
+
47
+
48
+ router = APIRouter(route_class=RouteErrorHandler)
49
+
50
+ _server_settings: Optional[ServerSettings] = None
51
+
52
+
53
+ def set_server_settings(server_settings: ServerSettings):
54
+ global _server_settings
55
+ _server_settings = server_settings
56
+
57
+
58
+ def get_server_settings():
59
+ yield _server_settings
60
+
61
+
62
+ _llama_proxy: Optional[LlamaProxy] = None
63
+
64
+ llama_outer_lock = Lock()
65
+ llama_inner_lock = Lock()
66
+
67
+
68
+ def set_llama_proxy(model_settings: List[ModelSettings]):
69
+ global _llama_proxy
70
+ _llama_proxy = LlamaProxy(models=model_settings)
71
+
72
+
73
+ async def get_llama_proxy():
74
+ # NOTE: This double lock allows the currently streaming llama model to
75
+ # check if any other requests are pending in the same thread and cancel
76
+ # the stream if so.
77
+ await llama_outer_lock.acquire()
78
+ release_outer_lock = True
79
+ try:
80
+ await llama_inner_lock.acquire()
81
+ try:
82
+ llama_outer_lock.release()
83
+ release_outer_lock = False
84
+ yield _llama_proxy
85
+ finally:
86
+ llama_inner_lock.release()
87
+ finally:
88
+ if release_outer_lock:
89
+ llama_outer_lock.release()
90
+
91
+
92
+ _ping_message_factory: typing.Optional[typing.Callable[[], bytes]] = None
93
+
94
+
95
+ def set_ping_message_factory(factory: typing.Callable[[], bytes]):
96
+ global _ping_message_factory
97
+ _ping_message_factory = factory
98
+
99
+
100
+ def create_app(
101
+ settings: Settings | None = None,
102
+ server_settings: ServerSettings | None = None,
103
+ model_settings: List[ModelSettings] | None = None,
104
+ ):
105
+ config_file = os.environ.get("CONFIG_FILE", None)
106
+ if config_file is not None:
107
+ if not os.path.exists(config_file):
108
+ raise ValueError(f"Config file {config_file} not found!")
109
+ with open(config_file, "rb") as f:
110
+ # Check if yaml file
111
+ if config_file.endswith(".yaml") or config_file.endswith(".yml"):
112
+ import yaml
113
+
114
+ config_file_settings = ConfigFileSettings.model_validate_json(
115
+ json.dumps(yaml.safe_load(f))
116
+ )
117
+ else:
118
+ config_file_settings = ConfigFileSettings.model_validate_json(f.read())
119
+ server_settings = ServerSettings.model_validate(config_file_settings)
120
+ model_settings = config_file_settings.models
121
+
122
+ if server_settings is None and model_settings is None:
123
+ if settings is None:
124
+ settings = Settings()
125
+ server_settings = ServerSettings.model_validate(settings)
126
+ model_settings = [ModelSettings.model_validate(settings)]
127
+
128
+ assert (
129
+ server_settings is not None and model_settings is not None
130
+ ), "server_settings and model_settings must be provided together"
131
+
132
+ set_server_settings(server_settings)
133
+ middleware = [Middleware(RawContextMiddleware, plugins=(RequestIdPlugin(),))]
134
+ app = FastAPI(
135
+ middleware=middleware,
136
+ title="🦙 llama.cpp Python API",
137
+ version=llama_cpp.__version__,
138
+ root_path=server_settings.root_path,
139
+ )
140
+ app.add_middleware(
141
+ CORSMiddleware,
142
+ allow_origins=["*"],
143
+ allow_credentials=True,
144
+ allow_methods=["*"],
145
+ allow_headers=["*"],
146
+ )
147
+ app.include_router(router)
148
+
149
+ assert model_settings is not None
150
+ set_llama_proxy(model_settings=model_settings)
151
+
152
+ if server_settings.disable_ping_events:
153
+ set_ping_message_factory(lambda: bytes())
154
+
155
+ return app
156
+
157
+
158
+ def prepare_request_resources(
159
+ body: CreateCompletionRequest | CreateChatCompletionRequest,
160
+ llama_proxy: LlamaProxy,
161
+ body_model: str | None,
162
+ kwargs,
163
+ ) -> llama_cpp.Llama:
164
+ if llama_proxy is None:
165
+ raise HTTPException(
166
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
167
+ detail="Service is not available",
168
+ )
169
+ llama = llama_proxy(body_model)
170
+ if body.logit_bias is not None:
171
+ kwargs["logit_bias"] = (
172
+ _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
173
+ if body.logit_bias_type == "tokens"
174
+ else body.logit_bias
175
+ )
176
+
177
+ if body.grammar is not None:
178
+ kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
179
+
180
+ if body.min_tokens > 0:
181
+ _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
182
+ [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
183
+ )
184
+ if "logits_processor" not in kwargs:
185
+ kwargs["logits_processor"] = _min_tokens_logits_processor
186
+ else:
187
+ kwargs["logits_processor"].extend(_min_tokens_logits_processor)
188
+ return llama
189
+
190
+
191
+ async def get_event_publisher(
192
+ request: Request,
193
+ inner_send_chan: MemoryObjectSendStream[typing.Any],
194
+ body: CreateCompletionRequest | CreateChatCompletionRequest,
195
+ body_model: str | None,
196
+ llama_call,
197
+ kwargs,
198
+ ):
199
+ server_settings = next(get_server_settings())
200
+ interrupt_requests = (
201
+ server_settings.interrupt_requests if server_settings else False
202
+ )
203
+ async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
204
+ llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
205
+ async with inner_send_chan:
206
+ try:
207
+ iterator = await run_in_threadpool(llama_call, llama, **kwargs)
208
+ async for chunk in iterate_in_threadpool(iterator):
209
+ await inner_send_chan.send(dict(data=json.dumps(chunk)))
210
+ if await request.is_disconnected():
211
+ raise anyio.get_cancelled_exc_class()()
212
+ if interrupt_requests and llama_outer_lock.locked():
213
+ await inner_send_chan.send(dict(data="[DONE]"))
214
+ raise anyio.get_cancelled_exc_class()()
215
+ await inner_send_chan.send(dict(data="[DONE]"))
216
+ except anyio.get_cancelled_exc_class() as e:
217
+ print("disconnected")
218
+ with anyio.move_on_after(1, shield=True):
219
+ print(
220
+ f"Disconnected from client (via refresh/close) {request.client}"
221
+ )
222
+ raise e
223
+
224
+
225
+ def _logit_bias_tokens_to_input_ids(
226
+ llama: llama_cpp.Llama,
227
+ logit_bias: Dict[str, float],
228
+ ) -> Dict[str, float]:
229
+ to_bias: Dict[str, float] = {}
230
+ for token, score in logit_bias.items():
231
+ token = token.encode("utf-8")
232
+ for input_id in llama.tokenize(token, add_bos=False, special=True):
233
+ to_bias[str(input_id)] = score
234
+ return to_bias
235
+
236
+
237
+ # Setup Bearer authentication scheme
238
+ bearer_scheme = HTTPBearer(auto_error=False)
239
+
240
+
241
+ async def authenticate(
242
+ settings: Settings = Depends(get_server_settings),
243
+ authorization: Optional[str] = Depends(bearer_scheme),
244
+ ):
245
+ # Skip API key check if it's not set in settings
246
+ if settings.api_key is None:
247
+ return True
248
+
249
+ # check bearer credentials against the api_key
250
+ if authorization and authorization.credentials == settings.api_key:
251
+ # api key is valid
252
+ return authorization.credentials
253
+
254
+ # raise http error 401
255
+ raise HTTPException(
256
+ status_code=status.HTTP_401_UNAUTHORIZED,
257
+ detail="Invalid API key",
258
+ )
259
+
260
+
261
+ openai_v1_tag = "OpenAI V1"
262
+
263
+
264
+ @router.post(
265
+ "/v1/completions",
266
+ summary="Completion",
267
+ dependencies=[Depends(authenticate)],
268
+ response_model=Union[
269
+ llama_cpp.CreateCompletionResponse,
270
+ str,
271
+ ],
272
+ responses={
273
+ "200": {
274
+ "description": "Successful Response",
275
+ "content": {
276
+ "application/json": {
277
+ "schema": {
278
+ "anyOf": [
279
+ {"$ref": "#/components/schemas/CreateCompletionResponse"}
280
+ ],
281
+ "title": "Completion response, when stream=False",
282
+ }
283
+ },
284
+ "text/event-stream": {
285
+ "schema": {
286
+ "type": "string",
287
+ "title": "Server Side Streaming response, when stream=True. "
288
+ + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501
289
+ "example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
290
+ }
291
+ },
292
+ },
293
+ }
294
+ },
295
+ tags=[openai_v1_tag],
296
+ )
297
+ @router.post(
298
+ "/v1/engines/copilot-codex/completions",
299
+ include_in_schema=False,
300
+ dependencies=[Depends(authenticate)],
301
+ tags=[openai_v1_tag],
302
+ )
303
+ async def create_completion(
304
+ request: Request,
305
+ body: CreateCompletionRequest,
306
+ ) -> llama_cpp.Completion:
307
+ if isinstance(body.prompt, list):
308
+ assert len(body.prompt) <= 1
309
+ body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
310
+
311
+ body_model = (
312
+ body.model
313
+ if request.url.path != "/v1/engines/copilot-codex/completions"
314
+ else "copilot-codex"
315
+ )
316
+
317
+ exclude = {
318
+ "n",
319
+ "best_of",
320
+ "logit_bias_type",
321
+ "user",
322
+ "min_tokens",
323
+ }
324
+ kwargs = body.model_dump(exclude=exclude)
325
+
326
+ # handle streaming request
327
+ if kwargs.get("stream", False):
328
+ send_chan, recv_chan = anyio.create_memory_object_stream(10)
329
+ return EventSourceResponse(
330
+ recv_chan,
331
+ data_sender_callable=partial( # type: ignore
332
+ get_event_publisher,
333
+ request=request,
334
+ inner_send_chan=send_chan,
335
+ body=body,
336
+ body_model=body_model,
337
+ llama_call=llama_cpp.Llama.__call__,
338
+ kwargs=kwargs,
339
+ ),
340
+ sep="\n",
341
+ ping_message_factory=_ping_message_factory,
342
+ )
343
+
344
+ # handle regular request
345
+ async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
346
+ llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
347
+
348
+ if await request.is_disconnected():
349
+ print(
350
+ f"Disconnected from client (via refresh/close) before llm invoked {request.client}"
351
+ )
352
+ raise HTTPException(
353
+ status_code=status.HTTP_400_BAD_REQUEST,
354
+ detail="Client closed request",
355
+ )
356
+
357
+ return await run_in_threadpool(llama, **kwargs)
358
+
359
+
360
+ @router.post(
361
+ "/v1/embeddings",
362
+ summary="Embedding",
363
+ dependencies=[Depends(authenticate)],
364
+ tags=[openai_v1_tag],
365
+ )
366
+ async def create_embedding(
367
+ request: CreateEmbeddingRequest,
368
+ llama_proxy: LlamaProxy = Depends(get_llama_proxy),
369
+ ):
370
+ return await run_in_threadpool(
371
+ llama_proxy(request.model).create_embedding,
372
+ **request.model_dump(exclude={"user"}),
373
+ )
374
+
375
+
376
+ @router.post(
377
+ "/v1/chat/completions",
378
+ summary="Chat",
379
+ dependencies=[Depends(authenticate)],
380
+ response_model=Union[llama_cpp.ChatCompletion, str],
381
+ responses={
382
+ "200": {
383
+ "description": "Successful Response",
384
+ "content": {
385
+ "application/json": {
386
+ "schema": {
387
+ "anyOf": [
388
+ {
389
+ "$ref": "#/components/schemas/CreateChatCompletionResponse"
390
+ }
391
+ ],
392
+ "title": "Completion response, when stream=False",
393
+ }
394
+ },
395
+ "text/event-stream": {
396
+ "schema": {
397
+ "type": "string",
398
+ "title": "Server Side Streaming response, when stream=True"
399
+ + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501
400
+ "example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
401
+ }
402
+ },
403
+ },
404
+ }
405
+ },
406
+ tags=[openai_v1_tag],
407
+ )
408
+ async def create_chat_completion(
409
+ request: Request,
410
+ body: CreateChatCompletionRequest = Body(
411
+ openapi_examples={
412
+ "normal": {
413
+ "summary": "Chat Completion",
414
+ "value": {
415
+ "model": "gpt-3.5-turbo",
416
+ "messages": [
417
+ {"role": "system", "content": "You are a helpful assistant."},
418
+ {"role": "user", "content": "What is the capital of France?"},
419
+ ],
420
+ },
421
+ },
422
+ "json_mode": {
423
+ "summary": "JSON Mode",
424
+ "value": {
425
+ "model": "gpt-3.5-turbo",
426
+ "messages": [
427
+ {"role": "system", "content": "You are a helpful assistant."},
428
+ {"role": "user", "content": "Who won the world series in 2020"},
429
+ ],
430
+ "response_format": {"type": "json_object"},
431
+ },
432
+ },
433
+ "tool_calling": {
434
+ "summary": "Tool Calling",
435
+ "value": {
436
+ "model": "gpt-3.5-turbo",
437
+ "messages": [
438
+ {"role": "system", "content": "You are a helpful assistant."},
439
+ {"role": "user", "content": "Extract Jason is 30 years old."},
440
+ ],
441
+ "tools": [
442
+ {
443
+ "type": "function",
444
+ "function": {
445
+ "name": "User",
446
+ "description": "User record",
447
+ "parameters": {
448
+ "type": "object",
449
+ "properties": {
450
+ "name": {"type": "string"},
451
+ "age": {"type": "number"},
452
+ },
453
+ "required": ["name", "age"],
454
+ },
455
+ },
456
+ }
457
+ ],
458
+ "tool_choice": {
459
+ "type": "function",
460
+ "function": {
461
+ "name": "User",
462
+ },
463
+ },
464
+ },
465
+ },
466
+ "logprobs": {
467
+ "summary": "Logprobs",
468
+ "value": {
469
+ "model": "gpt-3.5-turbo",
470
+ "messages": [
471
+ {"role": "system", "content": "You are a helpful assistant."},
472
+ {"role": "user", "content": "What is the capital of France?"},
473
+ ],
474
+ "logprobs": True,
475
+ "top_logprobs": 10,
476
+ },
477
+ },
478
+ }
479
+ ),
480
+ ) -> llama_cpp.ChatCompletion:
481
+ # This is a workaround for an issue in FastAPI dependencies
482
+ # where the dependency is cleaned up before a StreamingResponse
483
+ # is complete.
484
+ # https://github.com/tiangolo/fastapi/issues/11143
485
+
486
+ body_model = body.model
487
+ exclude = {
488
+ "n",
489
+ "logit_bias_type",
490
+ "user",
491
+ "min_tokens",
492
+ }
493
+ kwargs = body.model_dump(exclude=exclude)
494
+
495
+ # handle streaming request
496
+ if kwargs.get("stream", False):
497
+ send_chan, recv_chan = anyio.create_memory_object_stream(10)
498
+ return EventSourceResponse(
499
+ recv_chan,
500
+ data_sender_callable=partial( # type: ignore
501
+ get_event_publisher,
502
+ request=request,
503
+ inner_send_chan=send_chan,
504
+ body=body,
505
+ body_model=body_model,
506
+ llama_call=llama_cpp.Llama.create_chat_completion,
507
+ kwargs=kwargs,
508
+ ),
509
+ sep="\n",
510
+ ping_message_factory=_ping_message_factory,
511
+ )
512
+
513
+ # handle regular request
514
+ async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
515
+ llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
516
+
517
+ if await request.is_disconnected():
518
+ print(
519
+ f"Disconnected from client (via refresh/close) before llm invoked {request.client}"
520
+ )
521
+ raise HTTPException(
522
+ status_code=status.HTTP_400_BAD_REQUEST,
523
+ detail="Client closed request",
524
+ )
525
+
526
+ return await run_in_threadpool(llama.create_chat_completion, **kwargs)
527
+
528
+
529
+ @router.get(
530
+ "/v1/models",
531
+ summary="Models",
532
+ dependencies=[Depends(authenticate)],
533
+ tags=[openai_v1_tag],
534
+ )
535
+ async def get_models(
536
+ llama_proxy: LlamaProxy = Depends(get_llama_proxy),
537
+ ) -> ModelList:
538
+ return {
539
+ "object": "list",
540
+ "data": [
541
+ {
542
+ "id": model_alias,
543
+ "object": "model",
544
+ "owned_by": "me",
545
+ "permissions": [],
546
+ }
547
+ for model_alias in llama_proxy
548
+ ],
549
+ }
550
+
551
+
552
+ extras_tag = "Extras"
553
+
554
+
555
+ @router.post(
556
+ "/extras/tokenize",
557
+ summary="Tokenize",
558
+ dependencies=[Depends(authenticate)],
559
+ tags=[extras_tag],
560
+ )
561
+ async def tokenize(
562
+ body: TokenizeInputRequest,
563
+ llama_proxy: LlamaProxy = Depends(get_llama_proxy),
564
+ ) -> TokenizeInputResponse:
565
+ tokens = llama_proxy(body.model).tokenize(body.input.encode("utf-8"), special=True)
566
+
567
+ return TokenizeInputResponse(tokens=tokens)
568
+
569
+
570
+ @router.post(
571
+ "/extras/tokenize/count",
572
+ summary="Tokenize Count",
573
+ dependencies=[Depends(authenticate)],
574
+ tags=[extras_tag],
575
+ )
576
+ async def count_query_tokens(
577
+ body: TokenizeInputRequest,
578
+ llama_proxy: LlamaProxy = Depends(get_llama_proxy),
579
+ ) -> TokenizeInputCountResponse:
580
+ tokens = llama_proxy(body.model).tokenize(body.input.encode("utf-8"), special=True)
581
+
582
+ return TokenizeInputCountResponse(count=len(tokens))
583
+
584
+
585
+ @router.post(
586
+ "/extras/detokenize",
587
+ summary="Detokenize",
588
+ dependencies=[Depends(authenticate)],
589
+ tags=[extras_tag],
590
+ )
591
+ async def detokenize(
592
+ body: DetokenizeInputRequest,
593
+ llama_proxy: LlamaProxy = Depends(get_llama_proxy),
594
+ ) -> DetokenizeInputResponse:
595
+ text = llama_proxy(body.model).detokenize(body.tokens).decode("utf-8")
596
+
597
+ return DetokenizeInputResponse(text=text)