MebinThattil committed on
Commit
5d62acd
·
verified ·
1 Parent(s): ad0341b

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -35
  2. chatapp.py +30 -0
  3. llama_cpp/.DS_Store +0 -0
  4. llama_cpp/__init__.py +4 -0
  5. llama_cpp/__pycache__/__init__.cpython-310.pyc +0 -0
  6. llama_cpp/__pycache__/_ctypes_extensions.cpython-310.pyc +0 -0
  7. llama_cpp/__pycache__/_ggml.cpython-310.pyc +0 -0
  8. llama_cpp/__pycache__/_internals.cpython-310.pyc +0 -0
  9. llama_cpp/__pycache__/_logger.cpython-310.pyc +0 -0
  10. llama_cpp/__pycache__/_utils.cpython-310.pyc +0 -0
  11. llama_cpp/__pycache__/llama.cpython-310.pyc +0 -0
  12. llama_cpp/__pycache__/llama_cache.cpython-310.pyc +0 -0
  13. llama_cpp/__pycache__/llama_chat_format.cpython-310.pyc +0 -0
  14. llama_cpp/__pycache__/llama_cpp.cpython-310.pyc +0 -0
  15. llama_cpp/__pycache__/llama_grammar.cpython-310.pyc +0 -0
  16. llama_cpp/__pycache__/llama_speculative.cpython-310.pyc +0 -0
  17. llama_cpp/__pycache__/llama_tokenizer.cpython-310.pyc +0 -0
  18. llama_cpp/__pycache__/llama_types.cpython-310.pyc +0 -0
  19. llama_cpp/__pycache__/llava_cpp.cpython-310.pyc +0 -0
  20. llama_cpp/_ctypes_extensions.py +131 -0
  21. llama_cpp/_ggml.py +12 -0
  22. llama_cpp/_internals.py +879 -0
  23. llama_cpp/_logger.py +47 -0
  24. llama_cpp/_utils.py +78 -0
  25. llama_cpp/lib/libggml-base.dylib +3 -0
  26. llama_cpp/lib/libggml-blas.dylib +0 -0
  27. llama_cpp/lib/libggml-cpu.dylib +3 -0
  28. llama_cpp/lib/libggml-metal.dylib +3 -0
  29. llama_cpp/lib/libggml.dylib +0 -0
  30. llama_cpp/lib/libllama.dylib +3 -0
  31. llama_cpp/lib/libllava.dylib +3 -0
  32. llama_cpp/llama.py +2418 -0
  33. llama_cpp/llama_cache.py +155 -0
  34. llama_cpp/llama_chat_format.py +0 -0
  35. llama_cpp/llama_cpp.py +0 -0
  36. llama_cpp/llama_grammar.py +953 -0
  37. llama_cpp/llama_speculative.py +64 -0
  38. llama_cpp/llama_tokenizer.py +120 -0
  39. llama_cpp/llama_types.py +316 -0
  40. llama_cpp/llava_cpp.py +158 -0
  41. llama_cpp/py.typed +0 -0
  42. llama_cpp/server/__init__.py +0 -0
  43. llama_cpp/server/__main__.py +100 -0
  44. llama_cpp/server/__pycache__/__init__.cpython-310.pyc +0 -0
  45. llama_cpp/server/__pycache__/__main__.cpython-310.pyc +0 -0
  46. llama_cpp/server/__pycache__/app.cpython-310.pyc +0 -0
  47. llama_cpp/server/__pycache__/cli.cpython-310.pyc +0 -0
  48. llama_cpp/server/__pycache__/errors.cpython-310.pyc +0 -0
  49. llama_cpp/server/__pycache__/model.cpython-310.pyc +0 -0
  50. llama_cpp/server/__pycache__/settings.cpython-310.pyc +0 -0
.gitattributes CHANGED
@@ -1,35 +1,2 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.dylib filter=lfs diff=lfs merge=lfs -text
2
+ tinyllama-1.1B-q4.gguf filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
chatapp.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import sys
from llama_cpp import Llama

# Simple REPL chat client for a local GGUF model; usage: python chatapp.py model.gguf
if len(sys.argv) < 2:
    print("Model path not provided as argument")
    print("Eg. Usage: $ python chatapp.py path/to/model.gguf")
    sys.exit(1)

llm = Llama(
    model_path=sys.argv[1],
    n_ctx=512,        # context window (tokens)
    n_threads=4,
    n_gpu_layers=1,
    verbose=False
)

print("Chat with Llama (type 'exit' to quit)\n")

while True:
    # Fix: a Ctrl-D / closed stdin previously crashed with an EOFError traceback;
    # treat it like a normal quit instead.
    try:
        user_input = input("You: ")
    except (EOFError, KeyboardInterrupt):
        print()
        break
    if user_input.lower() in ["exit", "quit"]:
        break

    # Alpaca-style turn template; generation stops at the next turn marker.
    prompt = f"### Human: {user_input}\n### Assistant:"
    output = llm(
        prompt,
        max_tokens=100,
        stop=["###", "### Human:", "\n###"]
    )
    response = output["choices"][0]["text"].strip()
    print("Bot:", response)
llama_cpp/.DS_Store ADDED
Binary file (6.15 kB). View file
 
llama_cpp/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
from .llama_cpp import *  # low-level ctypes bindings (llama_cpp.py)
from .llama import *  # high-level Llama API (llama.py)

# Package version string.
__version__ = "0.3.9"
llama_cpp/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (301 Bytes). View file
 
llama_cpp/__pycache__/_ctypes_extensions.cpython-310.pyc ADDED
Binary file (3.51 kB). View file
 
llama_cpp/__pycache__/_ggml.cpython-310.pyc ADDED
Binary file (632 Bytes). View file
 
llama_cpp/__pycache__/_internals.cpython-310.pyc ADDED
Binary file (29.6 kB). View file
 
llama_cpp/__pycache__/_logger.cpython-310.pyc ADDED
Binary file (1.11 kB). View file
 
llama_cpp/__pycache__/_utils.cpython-310.pyc ADDED
Binary file (2.53 kB). View file
 
llama_cpp/__pycache__/llama.cpython-310.pyc ADDED
Binary file (56.5 kB). View file
 
llama_cpp/__pycache__/llama_cache.cpython-310.pyc ADDED
Binary file (5.84 kB). View file
 
llama_cpp/__pycache__/llama_chat_format.cpython-310.pyc ADDED
Binary file (76.7 kB). View file
 
llama_cpp/__pycache__/llama_cpp.cpython-310.pyc ADDED
Binary file (70.1 kB). View file
 
llama_cpp/__pycache__/llama_grammar.cpython-310.pyc ADDED
Binary file (25.3 kB). View file
 
llama_cpp/__pycache__/llama_speculative.cpython-310.pyc ADDED
Binary file (2.2 kB). View file
 
llama_cpp/__pycache__/llama_tokenizer.cpython-310.pyc ADDED
Binary file (4.53 kB). View file
 
llama_cpp/__pycache__/llama_types.cpython-310.pyc ADDED
Binary file (10.9 kB). View file
 
llama_cpp/__pycache__/llava_cpp.cpython-310.pyc ADDED
Binary file (2.99 kB). View file
 
llama_cpp/_ctypes_extensions.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+ import os
5
+ import ctypes
6
+ import functools
7
+ import pathlib
8
+
9
+ from typing import (
10
+ Any,
11
+ Callable,
12
+ List,
13
+ Union,
14
+ Optional,
15
+ TYPE_CHECKING,
16
+ TypeVar,
17
+ Generic,
18
+ )
19
+ from typing_extensions import TypeAlias
20
+
21
+
22
# Load the library
def load_shared_library(lib_base_name: str, base_path: pathlib.Path):
    """Platform independent shared library loader.

    Searches *base_path* for the platform-specific filename(s) derived from
    *lib_base_name* — ``lib<name>.so`` on Linux/FreeBSD, ``lib<name>.so`` or
    ``lib<name>.dylib`` on macOS, ``<name>.dll`` or ``lib<name>.dll`` on
    Windows — and returns the loaded ``ctypes.CDLL`` handle.

    Raises:
        RuntimeError: if the platform is unsupported, or a candidate file
            exists but fails to load.
        FileNotFoundError: if no candidate file exists under *base_path*.
    """
    lib_paths: List[pathlib.Path] = []
    # Determine the candidate filenames based on the platform.
    if sys.platform.startswith("linux") or sys.platform.startswith("freebsd"):
        lib_paths += [
            base_path / f"lib{lib_base_name}.so",
        ]
    elif sys.platform == "darwin":
        # macOS wheels may ship either extension, so try both.
        lib_paths += [
            base_path / f"lib{lib_base_name}.so",
            base_path / f"lib{lib_base_name}.dylib",
        ]
    elif sys.platform == "win32":
        lib_paths += [
            base_path / f"{lib_base_name}.dll",
            base_path / f"lib{lib_base_name}.dll",
        ]
    else:
        raise RuntimeError("Unsupported platform")

    cdll_args = dict()  # type: ignore

    # Add the library directory (and any CUDA/HIP runtimes) to the DLL search
    # path on Windows. Fix: the original called os.add_dll_directory(base_path)
    # twice, once behind a redundant `sys.version_info >= (3, 8)` guard;
    # collapsed into a single block.
    if sys.platform == "win32":
        os.add_dll_directory(str(base_path))
        os.environ["PATH"] = str(base_path) + os.pathsep + os.environ["PATH"]
        if "CUDA_PATH" in os.environ:
            os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "bin"))
            os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "lib"))
        if "HIP_PATH" in os.environ:
            os.add_dll_directory(os.path.join(os.environ["HIP_PATH"], "bin"))
            os.add_dll_directory(os.path.join(os.environ["HIP_PATH"], "lib"))
        cdll_args["winmode"] = ctypes.RTLD_GLOBAL

    # Try to load the shared library, handling potential errors.
    for lib_path in lib_paths:
        if lib_path.exists():
            try:
                return ctypes.CDLL(str(lib_path), **cdll_args)  # type: ignore
            except Exception as e:
                # Chain the original loader error for easier debugging.
                raise RuntimeError(
                    f"Failed to load shared library '{lib_path}': {e}"
                ) from e

    raise FileNotFoundError(
        f"Shared library with base name '{lib_base_name}' not found"
    )
74
+
75
+
76
# ctypes sane type hint helpers
#
# - Generic Pointer and Array types
# - PointerOrRef type with a type hinted byref function
#
# NOTE: Only use these for static type checking not for runtime checks
# no good will come of that

if TYPE_CHECKING:
    # Type variable ranging over any ctypes data type (c_int, structures, ...).
    CtypesCData = TypeVar("CtypesCData", bound=ctypes._CData)  # type: ignore

    # ctypes array of a given element type.
    CtypesArray: TypeAlias = ctypes.Array[CtypesCData]  # type: ignore

    # Typed pointer to a ctypes value.
    CtypesPointer: TypeAlias = ctypes._Pointer[CtypesCData]  # type: ignore

    CtypesVoidPointer: TypeAlias = ctypes.c_void_p

    # Stand-in for the object returned by ctypes.byref (which has no public type).
    class CtypesRef(Generic[CtypesCData]):
        pass

    # Accepts either a typed pointer or a byref-style reference.
    CtypesPointerOrRef: TypeAlias = Union[
        CtypesPointer[CtypesCData], CtypesRef[CtypesCData]
    ]

    CtypesFuncPointer: TypeAlias = ctypes._FuncPointer  # type: ignore

# Type variable for the callables decorated by ctypes_function_for_shared_library.
F = TypeVar("F", bound=Callable[..., Any])
103
+
104
+
105
def ctypes_function_for_shared_library(lib: ctypes.CDLL):
    """Build a decorator factory that binds type-hinted stubs to *lib*.

    The returned ``ctypes_function(name, argtypes, restype, enabled=True)``
    produces a decorator: when ``enabled`` it looks up ``name`` in *lib*,
    installs ``argtypes``/``restype`` on the foreign function, copies the
    stub's metadata onto it via ``functools.wraps``, and returns the foreign
    function; when disabled it returns the stub untouched.
    """

    def ctypes_function(
        name: str, argtypes: List[Any], restype: Any, enabled: bool = True
    ):
        def decorator(f: F) -> F:
            # Disabled bindings (e.g. symbols absent from this build) keep
            # the Python stub as-is.
            if not enabled:
                return f
            native = getattr(lib, name)
            native.argtypes = argtypes
            native.restype = restype
            functools.wraps(f)(native)
            return native

        return decorator

    return ctypes_function
124
+
125
+
126
+ def _byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCData]:
127
+ """Type-annotated version of ctypes.byref"""
128
+ ...
129
+
130
+
131
+ byref = _byref if TYPE_CHECKING else ctypes.byref
llama_cpp/_ggml.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Internal module use at your own risk

This module provides a minimal interface for working with ggml tensors from llama-cpp-python
"""
import os
import pathlib

import llama_cpp._ctypes_extensions as ctypes_ext

# Directory shipped alongside this file that holds the prebuilt shared libraries.
libggml_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib"
# Module-level handle to the ggml shared library, loaded at import time.
libggml = ctypes_ext.load_shared_library("ggml", libggml_base_path)
+
llama_cpp/_internals.py ADDED
@@ -0,0 +1,879 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import ctypes
5
+
6
+ from typing import (
7
+ Dict,
8
+ List,
9
+ Tuple,
10
+ Optional,
11
+ Sequence,
12
+ )
13
+ from dataclasses import dataclass, field
14
+ from contextlib import ExitStack
15
+
16
+ import numpy as np
17
+ import numpy.typing as npt
18
+
19
+ from .llama_types import *
20
+ from .llama_grammar import LlamaGrammar
21
+ from ._utils import suppress_stdout_stderr
22
+
23
+ import llama_cpp.llama_cpp as llama_cpp
24
+
25
+
26
+ # Python wrappers over llama.h structs
27
+
28
+
29
class LlamaModel:
    """Intermediate Python wrapper for a llama.cpp llama_model.
    NOTE: For stability it's recommended you use the Llama class instead."""

    def __init__(
        self,
        *,
        path_model: str,
        params: llama_cpp.llama_model_params,
        verbose: bool = True,
    ):
        # path_model: filesystem path to the GGUF model file.
        # params: native llama_model_params struct passed through to llama.cpp.
        # verbose: when False, native stdout/stderr output is suppressed.
        self.path_model = path_model
        self.params = params
        self.verbose = verbose
        # ExitStack owns cleanup of the native model handle (see free_model below).
        self._exit_stack = ExitStack()

        model = None

        if not os.path.exists(path_model):
            raise ValueError(f"Model path does not exist: {path_model}")

        # Loading is noisy at the C level; silence it unless verbose.
        with suppress_stdout_stderr(disable=verbose):
            model = llama_cpp.llama_load_model_from_file(
                self.path_model.encode("utf-8"), self.params
            )

        if model is None:
            raise ValueError(f"Failed to load model from file: {path_model}")

        vocab = llama_cpp.llama_model_get_vocab(model)

        if vocab is None:
            raise ValueError(f"Failed to get vocab from model: {path_model}")

        self.model = model
        self.vocab = vocab

        def free_model():
            # Idempotent: close()/__del__ may both trigger this callback path.
            if self.model is None:
                return
            llama_cpp.llama_free_model(self.model)
            self.model = None

        self._exit_stack.callback(free_model)

    def close(self):
        """Release the native model handle (safe to call more than once)."""
        self._exit_stack.close()

    def __del__(self):
        self.close()

    # Thin pass-through accessors over the native model/vocab handles.

    def vocab_type(self) -> int:
        return llama_cpp.llama_vocab_type(self.model)

    def n_vocab(self) -> int:
        return llama_cpp.llama_n_vocab(self.vocab)

    def n_ctx_train(self) -> int:
        return llama_cpp.llama_n_ctx_train(self.model)

    def n_embd(self) -> int:
        return llama_cpp.llama_n_embd(self.model)

    def rope_freq_scale_train(self) -> float:
        return llama_cpp.llama_model_rope_freq_scale_train(self.model)

    def desc(self) -> str:
        # Native side writes a human-readable model description into buf.
        buf = ctypes.create_string_buffer(1024)
        llama_cpp.llama_model_desc(self.model, buf, 1024)
        return buf.value.decode("utf-8")

    def size(self) -> int:
        return llama_cpp.llama_model_size(self.model)

    def n_params(self) -> int:
        return llama_cpp.llama_model_n_params(self.model)

    def get_tensor(self, name: str) -> ctypes.c_void_p:
        raise NotImplementedError("get_tensor is not implemented in llama.cpp")

    # Vocab

    def token_get_text(self, token: int) -> str:
        return llama_cpp.llama_token_get_text(self.vocab, token).decode("utf-8")

    def token_get_score(self, token: int) -> float:
        return llama_cpp.llama_token_get_score(self.vocab, token)

    def token_get_attr(self, token: int) -> int:
        return llama_cpp.llama_token_get_attr(self.vocab, token)

    # Special tokens

    def token_bos(self) -> int:
        return llama_cpp.llama_token_bos(self.vocab)

    def token_eos(self) -> int:
        return llama_cpp.llama_token_eos(self.vocab)

    def token_cls(self) -> int:
        return llama_cpp.llama_token_cls(self.vocab)

    def token_sep(self) -> int:
        return llama_cpp.llama_token_sep(self.vocab)

    def token_nl(self) -> int:
        return llama_cpp.llama_token_nl(self.vocab)

    def token_prefix(self) -> int:
        raise NotImplementedError("token_prefix is not implemented in llama.cpp")

    def token_middle(self) -> int:
        raise NotImplementedError("token_middle is not implemented in llama.cpp")

    def token_suffix(self) -> int:
        raise NotImplementedError("token_suffix is not implemented in llama.cpp")

    def token_eot(self) -> int:
        return llama_cpp.llama_token_eot(self.vocab)

    def add_bos_token(self) -> bool:
        return llama_cpp.llama_add_bos_token(self.vocab)

    def add_eos_token(self) -> bool:
        return llama_cpp.llama_add_eos_token(self.vocab)

    # Tokenization

    def tokenize(self, text: bytes, add_bos: bool, special: bool):
        """Tokenize *text* and return the token ids as a list of ints.

        A first pass uses a buffer sized to the training context; if the
        native call reports a larger requirement (negative return), the
        buffer is resized to the exact count and the call retried.
        """
        n_ctx = self.n_ctx_train()
        tokens = (llama_cpp.llama_token * n_ctx)()
        n_tokens = llama_cpp.llama_tokenize(
            self.vocab, text, len(text), tokens, n_ctx, add_bos, special
        )
        if n_tokens < 0:
            # Negative result encodes the required token count; retry once.
            n_tokens = abs(n_tokens)
            tokens = (llama_cpp.llama_token * n_tokens)()
            n_tokens = llama_cpp.llama_tokenize(
                self.vocab, text, len(text), tokens, n_tokens, add_bos, special
            )
            if n_tokens < 0:
                raise RuntimeError(
                    f'Failed to tokenize: text="{text}" n_tokens={n_tokens}'
                )
        return list(tokens[:n_tokens])

    def token_to_piece(self, token: int, special: bool = False) -> bytes:
        # NOTE(review): returns the whole 32-byte buffer (bytes(buf)), so the
        # result includes trailing NUL padding — confirm callers expect that.
        buf = ctypes.create_string_buffer(32)
        llama_cpp.llama_token_to_piece(self.vocab, token, buf, 32, 0, special)
        return bytes(buf)

    def detokenize(self, tokens: List[int], special: bool = False) -> bytes:
        """Convert token ids back into the raw byte string they encode."""
        output = b""
        size = 32
        # One shared scratch buffer, reused across tokens.
        buffer = (ctypes.c_char * size)()
        for token in tokens:
            n = llama_cpp.llama_token_to_piece(
                self.vocab, llama_cpp.llama_token(token), buffer, size, 0, special
            )
            assert n <= size
            output += bytes(buffer[:n])
        # NOTE: Llama1 models automatically added a space at the start of the prompt
        # this line removes a leading space if the first token is a beginning of sentence token
        return (
            output[1:]
            if len(tokens) > 0 and tokens[0] == self.token_bos() and output[0:1] == b" "
            else output
        )

    # Extra
    def metadata(self) -> Dict[str, str]:
        """Read all GGUF key/value metadata pairs into a Python dict."""
        metadata: Dict[str, str] = {}
        buffer_size = 1024
        buffer = ctypes.create_string_buffer(buffer_size)
        # zero the buffer
        buffer.value = b"\0" * buffer_size
        # iterate over model keys
        for i in range(llama_cpp.llama_model_meta_count(self.model)):
            nbytes = llama_cpp.llama_model_meta_key_by_index(
                self.model, i, buffer, buffer_size
            )
            # Grow the buffer and retry when the key didn't fit.
            if nbytes > buffer_size:
                buffer_size = nbytes + 1
                buffer = ctypes.create_string_buffer(buffer_size)
                nbytes = llama_cpp.llama_model_meta_key_by_index(
                    self.model, i, buffer, buffer_size
                )
            key = buffer.value.decode("utf-8")
            nbytes = llama_cpp.llama_model_meta_val_str_by_index(
                self.model, i, buffer, buffer_size
            )
            # Same grow-and-retry dance for the value.
            if nbytes > buffer_size:
                buffer_size = nbytes + 1
                buffer = ctypes.create_string_buffer(buffer_size)
                nbytes = llama_cpp.llama_model_meta_val_str_by_index(
                    self.model, i, buffer, buffer_size
                )
            value = buffer.value.decode("utf-8")
            metadata[key] = value
        return metadata

    @staticmethod
    def default_params():
        """Get the default llama_model_params."""
        return llama_cpp.llama_model_default_params()
234
+
235
+
236
class LlamaContext:
    """Intermediate Python wrapper for a llama.cpp llama_context.
    NOTE: For stability it's recommended you use the Llama class instead."""

    def __init__(
        self,
        *,
        model: LlamaModel,
        params: llama_cpp.llama_context_params,
        verbose: bool = True,
    ):
        # model: the LlamaModel this context decodes with (must outlive self).
        # params: native llama_context_params struct passed through to llama.cpp.
        self.model = model
        self.params = params
        self.verbose = verbose
        # ExitStack owns cleanup of the native context handle (see free_ctx below).
        self._exit_stack = ExitStack()

        ctx = llama_cpp.llama_new_context_with_model(self.model.model, self.params)

        if ctx is None:
            raise ValueError("Failed to create llama_context")

        self.ctx = ctx

        def free_ctx():
            # Idempotent: close()/__del__ may both trigger this callback path.
            if self.ctx is None:
                return
            llama_cpp.llama_free(self.ctx)
            self.ctx = None

        self._exit_stack.callback(free_ctx)

    def close(self):
        """Release the native context handle (safe to call more than once)."""
        self._exit_stack.close()

    def __del__(self):
        self.close()

    def n_ctx(self) -> int:
        return llama_cpp.llama_n_ctx(self.ctx)

    def pooling_type(self) -> int:
        return llama_cpp.llama_pooling_type(self.ctx)

    # KV-cache management: thin pass-throughs to the native cache API.

    def kv_cache_clear(self):
        llama_cpp.llama_kv_cache_clear(self.ctx)

    def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int):
        llama_cpp.llama_kv_cache_seq_rm(self.ctx, seq_id, p0, p1)

    def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int):
        llama_cpp.llama_kv_cache_seq_cp(self.ctx, seq_id_src, seq_id_dst, p0, p1)

    def kv_cache_seq_keep(self, seq_id: int):
        llama_cpp.llama_kv_cache_seq_keep(self.ctx, seq_id)

    def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int):
        # Maps onto llama_kv_cache_seq_add (position shift) in current llama.cpp.
        llama_cpp.llama_kv_cache_seq_add(self.ctx, seq_id, p0, p1, shift)

    def get_state_size(self) -> int:
        return llama_cpp.llama_get_state_size(self.ctx)

    # TODO: copy_state_data

    # TODO: set_state_data

    # TODO: llama_load_session_file

    # TODO: llama_save_session_file

    def decode(self, batch: LlamaBatch):
        """Run the model on *batch*; raises RuntimeError on a non-zero status."""
        return_code = llama_cpp.llama_decode(
            self.ctx,
            batch.batch,
        )
        if return_code != 0:
            raise RuntimeError(f"llama_decode returned {return_code}")

    def set_n_threads(self, n_threads: int, n_threads_batch: int):
        llama_cpp.llama_set_n_threads(self.ctx, n_threads, n_threads_batch)

    def get_logits(self):
        return llama_cpp.llama_get_logits(self.ctx)

    def get_logits_ith(self, i: int):
        return llama_cpp.llama_get_logits_ith(self.ctx, i)

    def get_embeddings(self):
        return llama_cpp.llama_get_embeddings(self.ctx)

    # Sampling functions
    #
    # NOTE: every sampler below raises NotImplementedError — the legacy
    # context-level sampling API was removed from llama.cpp. The old calls
    # are kept commented out for reference.

    def set_rng_seed(self, seed: int):
        # TODO: Fix
        # llama_cpp.llama_set_rng_seed(self.ctx, seed)
        raise NotImplementedError("set_rng_seed is not implemented in llama.cpp")

    def sample_repetition_penalties(
        self,
        candidates: "_LlamaTokenDataArray",
        last_tokens_data: "llama_cpp.Array[llama_cpp.llama_token]",
        penalty_last_n: int,
        penalty_repeat: float,
        penalty_freq: float,
        penalty_present: float,
    ):
        # llama_cpp.llama_sample_repetition_penalties(
        #     self.ctx,
        #     llama_cpp.byref(candidates.candidates),
        #     last_tokens_data,
        #     penalty_last_n,
        #     penalty_repeat,
        #     penalty_freq,
        #     penalty_present,
        # )
        raise NotImplementedError("sample_repetition_penalties is not implemented in llama.cpp")

    def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
        # llama_cpp.llama_sample_softmax(
        #     self.ctx,
        #     llama_cpp.byref(candidates.candidates),
        # )
        raise NotImplementedError("sample_softmax is not implemented in llama.cpp")

    def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
        # llama_cpp.llama_sample_top_k(
        #     self.ctx, llama_cpp.byref(candidates.candidates), k, min_keep
        # )
        raise NotImplementedError("sample_top_k is not implemented in llama.cpp")

    def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
        # llama_cpp.llama_sample_top_p(
        #     self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
        # )
        raise NotImplementedError("sample_top_p is not implemented in llama.cpp")

    def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
        # llama_cpp.llama_sample_min_p(
        #     self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
        # )
        raise NotImplementedError("sample_min_p is not implemented in llama.cpp")

    def sample_typical(
        self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int
    ):
        # llama_cpp.llama_sample_typical(
        #     self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
        # )
        raise NotImplementedError("sample_typical is not implemented in llama.cpp")

    def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
        # llama_cpp.llama_sample_temp(
        #     self.ctx, llama_cpp.byref(candidates.candidates), temp
        # )
        raise NotImplementedError("sample_temp is not implemented in llama.cpp")

    def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
        # llama_cpp.llama_sample_grammar(
        #     self.ctx,
        #     llama_cpp.byref(candidates.candidates),
        #     grammar.grammar,
        # )
        raise NotImplementedError("sample_grammar is not implemented in llama.cpp")

    def sample_token_mirostat(
        self,
        candidates: "_LlamaTokenDataArray",
        tau: float,
        eta: float,
        m: int,
        mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
    ) -> int:
        raise NotImplementedError("sample_token_mirostat is not implemented in llama.cpp")
        # return llama_cpp.llama_sample_token_mirostat(
        #     self.ctx,
        #     llama_cpp.byref(candidates.candidates),
        #     tau,
        #     eta,
        #     m,
        #     mu,
        # )

    def sample_token_mirostat_v2(
        self,
        candidates: "_LlamaTokenDataArray",
        tau: float,
        eta: float,
        mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
    ) -> int:
        raise NotImplementedError("sample_token_mirostat_v2 is not implemented in llama.cpp")
        # return llama_cpp.llama_sample_token_mirostat_v2(
        #     self.ctx,
        #     llama_cpp.byref(candidates.candidates),
        #     tau,
        #     eta,
        #     mu,
        # )

    def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int:
        raise NotImplementedError("sample_token_greedy is not implemented in llama.cpp")
        # return llama_cpp.llama_sample_token_greedy(
        #     self.ctx,
        #     llama_cpp.byref(candidates.candidates),
        # )

    def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
        raise NotImplementedError("sample_token is not implemented in llama.cpp")
        # return llama_cpp.llama_sample_token(
        #     self.ctx,
        #     llama_cpp.byref(candidates.candidates),
        # )

    # Grammar
    def grammar_accept_token(self, grammar: LlamaGrammar, token: int):
        raise NotImplementedError("grammar_accept_token is not implemented in llama.cpp")
        # llama_cpp.llama_grammar_accept_token(grammar.grammar, self.ctx, token)

    def reset_timings(self):
        llama_cpp.llama_perf_context_reset(self.ctx)

    def print_timings(self):
        llama_cpp.llama_perf_context_print(self.ctx)

    # Utility functions
    @staticmethod
    def default_params():
        """Get the default llama_context_params."""
        return llama_cpp.llama_context_default_params()
463
+
464
+
465
class LlamaBatch:
    """Python wrapper owning a native llama_batch allocation."""

    def __init__(
        self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True
    ):
        # n_tokens: capacity of the batch; embd/n_seq_max are forwarded to
        # llama_batch_init unchanged.
        self._n_tokens = n_tokens
        self.embd = embd
        self.n_seq_max = n_seq_max
        self.verbose = verbose
        # ExitStack owns cleanup of the native batch (see free_batch below).
        self._exit_stack = ExitStack()

        batch = llama_cpp.llama_batch_init(self._n_tokens, self.embd, self.n_seq_max)

        if batch is None:
            raise ValueError("Failed to create llama_batch")

        self.batch = batch

        def free_batch():
            # Idempotent: close()/__del__ may both trigger this callback path.
            if self.batch is None:
                return
            llama_cpp.llama_batch_free(self.batch)
            self.batch = None

        self._exit_stack.callback(free_batch)

    def close(self):
        """Release the native batch (safe to call more than once)."""
        self._exit_stack.close()

    def __del__(self):
        self.close()

    def n_tokens(self) -> int:
        return self.batch.n_tokens

    def reset(self):
        # Marks the batch empty; the native buffers stay allocated for reuse.
        self.batch.n_tokens = 0

    def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
        """Overwrite the batch with *batch* tokens at positions n_past..n_past+len-1,
        all on sequence 0; the last token always requests logits."""
        n_tokens = len(batch)
        self.batch.n_tokens = n_tokens
        for i in range(n_tokens):
            self.batch.token[i] = batch[i]
            self.batch.pos[i] = n_past + i
            self.batch.seq_id[i][0] = 0
            self.batch.n_seq_id[i] = 1
            self.batch.logits[i] = logits_all
        self.batch.logits[n_tokens - 1] = True

    def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
        """Append *batch* as sequence *seq_id* after any tokens already present."""
        n_tokens = len(batch)
        n_tokens0 = self.batch.n_tokens
        self.batch.n_tokens += n_tokens
        for i in range(n_tokens):
            j = n_tokens0 + i
            self.batch.token[j] = batch[i]
            self.batch.pos[j] = i
            self.batch.seq_id[j][0] = seq_id
            self.batch.n_seq_id[j] = 1
            self.batch.logits[j] = logits_all
        # NOTE(review): this indexes logits[n_tokens - 1] rather than
        # [n_tokens0 + n_tokens - 1]; when the batch already held tokens
        # (n_tokens0 > 0) it flags an entry of the *previous* sequence —
        # looks like an off-by-offset, confirm intent before changing.
        self.batch.logits[n_tokens - 1] = True
525
+
526
+
527
class LlamaTokenDataArray:
    """Reusable numpy-backed buffer mirroring llama.cpp's llama_token_data_array."""

    def __init__(self, *, n_vocab: int):
        self.n_vocab = n_vocab
        # One record per vocab entry: (token id, logit, probability). align=True
        # keeps the numpy layout C-compatible so the buffer can be handed to the
        # native struct by pointer (no copies).
        self.candidates_data = np.recarray(
            (self.n_vocab,),
            dtype=np.dtype(
                [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
            ),
        )
        # Native view over the same memory as candidates_data.
        self.candidates = llama_cpp.llama_token_data_array(
            data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
            size=self.n_vocab,
            sorted=False,
        )
        # Cached defaults so copy_logits can reset the id/p columns without
        # reallocating on every call.
        self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc)  # type: ignore
        self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single)

    def copy_logits(self, logits: npt.NDArray[np.single]):
        """Refill the shared buffer in place with fresh logits for all tokens."""
        self.candidates_data.id[:] = self.default_candidates_data_id
        self.candidates_data.logit[:] = logits
        self.candidates_data.p[:] = self.default_candidates_data_p
        self.candidates.sorted = False
        self.candidates.size = self.n_vocab
550
+
551
+
552
+ # Embedding functions
553
+
554
+
555
def normalize_embedding(embedding):
    """Return *embedding* scaled to unit L2 norm as a Python list.

    An all-zero (or empty) input is returned unchanged to avoid division
    by zero.
    """
    magnitude = float(np.linalg.norm(embedding))
    if magnitude == 0.0:
        return embedding
    return [component / magnitude for component in embedding]
560
+
561
+
562
+ # Python wrappers over common/sampling structs
563
+
564
+
565
@dataclass
class LlamaSamplingParams:
    """Parameter bundle consumed by LlamaSamplingContext.sample."""

    # History / candidate bookkeeping
    n_prev: int = 64
    n_probs: int = 0

    # Truncation / temperature sampling
    top_k: int = 40
    top_p: float = 0.95
    min_p: float = 0.05
    tfs_z: float = 1.00
    typical_p: float = 1.00
    temp: float = 0.80

    # Repetition penalties
    penalty_last_n: int = 64
    penalty_repeat: float = 1.0
    penalty_freq: float = 0.00
    penalty_present: float = 0.00

    # Mirostat (0 disables it)
    mirostat: int = 0
    mirostat_tau: float = 5.00
    mirostat_eta: float = 0.10
    penalize_nl: bool = True

    # GBNF grammar text ("" means no grammar constraint)
    grammar: str = ""

    cfg_negative_prompt: str = ""
    cfg_scale: float = 1.00

    # Per-token additive logit adjustments, keyed by token id.
    logit_bias: dict[int, float] = field(default_factory=dict)
590
+
591
+
592
@dataclass
class LlamaSamplingContext:
    """Mutable sampling state: params, mirostat mu, optional grammar, token history."""

    params: LlamaSamplingParams = field(default_factory=LlamaSamplingParams)
    mirostat_mu: ctypes.c_float = field(default_factory=ctypes.c_float)
    grammar: Optional[LlamaGrammar] = None
    # NOTE: Missing parsed_grammar
    prev: list[int] = field(default_factory=list)
    cur: list[llama_cpp.llama_token_data] = field(default_factory=list)

    def reset(self):
        """Clear token history and reset grammar state; params are kept."""
        self.prev = []
        self.cur = []
        if self.grammar is not None:
            self.grammar.reset()

    def cp(self):
        """Shallow copy: prev/cur lists are copied, params and grammar are shared."""
        return LlamaSamplingContext(
            params=self.params,
            mirostat_mu=self.mirostat_mu,
            grammar=self.grammar,
            prev=self.prev.copy(),
            cur=self.cur.copy(),
        )

    def last(self) -> Optional[int]:
        """Return the most recently accepted token id, or None if history is empty."""
        if len(self.prev) > 0:
            return self.prev[-1]
        else:
            return None

    def prev_str(self, ctx_main: LlamaContext, n: int) -> str:
        """Detokenize the last *n* accepted tokens into a UTF-8 string."""
        return ctx_main.model.detokenize(self.prev[-n:]).decode("utf-8")

    def sample(
        self,
        ctx_main: LlamaContext,
        idx: int = 0,
        logits_array: Optional[npt.NDArray[np.single]] = None,
    ):
        """Run the full sampling pipeline and return the chosen token id.

        Pipeline order: logit bias -> repetition penalties -> grammar ->
        (softmax-argmax | greedy | mirostat v1/v2 | top-k/typical/top-p/min-p
        + temperature). If logits_array is None, logits are fetched from
        ctx_main at position idx.
        """
        n_vocab = ctx_main.model.n_vocab()
        id: int = 0

        if logits_array is None:
            # Copy the C logits buffer for position idx into a numpy array.
            logits = ctx_main.get_logits_ith(idx)
            logits_array = np.array(
                ctypes.cast(logits, ctypes.POINTER(ctypes.c_float * n_vocab)).contents,
                dtype=np.single,
            )

        # apply logit_bias
        for token, logit_bias in self.params.logit_bias.items():
            logits_array[token] += logit_bias

        token_data_array = LlamaTokenDataArray(
            n_vocab=n_vocab
        )  # TODO: Only create this once
        token_data_array.copy_logits(logits_array)

        # apply penalties over the last penalty_last_n accepted tokens
        if len(self.prev) > 0:
            nl_token = ctx_main.model.token_nl()
            nl_logit = logits_array[nl_token]  # saved so it can be restored below
            last_tokens = self.prev[-self.params.penalty_last_n :]
            last_tokens_size = min(len(last_tokens), self.params.penalty_last_n)
            if last_tokens_size > 0:
                last_tokens_p = (llama_cpp.llama_token * len(last_tokens))(*last_tokens)
                ctx_main.sample_repetition_penalties(
                    token_data_array,
                    last_tokens_p,
                    last_tokens_size,
                    self.params.penalty_repeat,
                    self.params.penalty_freq,
                    self.params.penalty_present,
                )
            if not self.params.penalize_nl:
                # undo any penalty applied to the newline token
                token_data_array.candidates_data.logit[nl_token] = nl_logit

        if self.grammar is not None:
            ctx_main.sample_grammar(token_data_array, self.grammar)

        if self.params.temp < 0:
            # negative temperature: take the most probable token after softmax
            ctx_main.sample_softmax(token_data_array)
            id = token_data_array.candidates_data.id[0]
        elif self.params.temp == 0:
            # zero temperature: plain greedy argmax
            id = ctx_main.sample_token_greedy(token_data_array)
        else:
            if self.params.mirostat == 1:
                mirostat_m = 100
                ctx_main.sample_temp(token_data_array, self.params.temp)
                id = ctx_main.sample_token_mirostat(
                    token_data_array,
                    self.params.mirostat_tau,
                    self.params.mirostat_eta,
                    mirostat_m,
                    ctypes.pointer(self.mirostat_mu),
                )
            elif self.params.mirostat == 2:
                ctx_main.sample_temp(token_data_array, self.params.temp)
                id = ctx_main.sample_token_mirostat_v2(
                    token_data_array,
                    self.params.mirostat_tau,
                    self.params.mirostat_eta,
                    ctypes.pointer(self.mirostat_mu),
                )
            else:
                # standard filter chain; keep at least n_probs candidates
                min_keep = max(1, self.params.n_probs)
                ctx_main.sample_top_k(
                    token_data_array, self.params.top_k, min_keep=min_keep
                )
                ctx_main.sample_typical(
                    token_data_array, self.params.typical_p, min_keep=min_keep
                )
                ctx_main.sample_top_p(
                    token_data_array, self.params.top_p, min_keep=min_keep
                )
                ctx_main.sample_min_p(
                    token_data_array, self.params.min_p, min_keep=min_keep
                )
                ctx_main.sample_temp(token_data_array, self.params.temp)
                id = ctx_main.sample_token(token_data_array)
        return id

    def accept(self, ctx_main: LlamaContext, id: int, apply_grammar: bool):
        """Record *id* in the history and optionally advance the grammar state."""
        if apply_grammar and self.grammar is not None:
            ctx_main.grammar_accept_token(self.grammar, id)
        self.prev.append(id)
718
+
719
+
720
+ from typing import List, Callable, Optional, Union
721
+ import ctypes
722
+ import llama_cpp
723
+
724
+
725
class CustomSampler:
    """Adapt a Python callable into a llama.cpp llama_sampler.

    Only the 'apply' slot of the sampler interface is implemented; the
    name/accept/reset/clone/free slots are left NULL.
    """

    def __init__(
        self, apply_func: typing.Callable[[llama_cpp.llama_token_data_array], None]
    ):
        self.apply_func = apply_func

        # C-visible callback; delegates to the user-supplied apply_func.
        def apply_wrapper(
            sampler: llama_cpp.llama_sampler_p,
            cur_p: llama_cpp.llama_token_data_array_p,
        ):
            self.apply_func(cur_p)

        # Defined but intentionally not registered below (free slot is NULL);
        # there are no Python-side resources to release.
        def free_wrapper(sampler: llama_cpp.llama_sampler_p):
            pass

        sampler_i = llama_cpp.llama_sampler_i()
        sampler_i.apply = llama_cpp.llama_sampler_i_apply(apply_wrapper)
        # Keep a reference so the Python closure outlives this constructor;
        # llama.cpp may call it at any later point.
        self._apply_wrapper_ref = apply_wrapper

        # All other interface slots are NULL function pointers.
        sampler_i.name = llama_cpp.llama_sampler_i_name(0)
        sampler_i.accept = llama_cpp.llama_sampler_i_accept(0)
        sampler_i.reset = llama_cpp.llama_sampler_i_reset(0)
        sampler_i.clone = llama_cpp.llama_sampler_i_clone(0)
        sampler_i.free = llama_cpp.llama_sampler_i_free(0)

        self.sampler = llama_cpp.llama_sampler()
        # NOTE(review): ctypes.pointer() keeps sampler_i alive through the
        # structure's internal object references — presumed sufficient here;
        # confirm if samplers ever outlive this CustomSampler instance.
        self.sampler.iface = ctypes.pointer(sampler_i)
        self.sampler.ctx = None

    def get_sampler(self) -> llama_cpp.llama_sampler_p:
        """Return a pointer to the wrapped llama_sampler struct."""
        return ctypes.pointer(self.sampler)
756
+
757
+
758
class LlamaSampler:
    """Wrapper over a llama.cpp sampler chain (llama_sampler_chain_*).

    add_* methods append samplers to the chain in call order. Python
    references to every added sampler (and to CustomSampler wrappers)
    are retained so their callbacks are not garbage collected while the
    chain is alive.
    """

    def __init__(self):
        params = llama_cpp.llama_sampler_chain_params()
        self.sampler = llama_cpp.llama_sampler_chain_init(params)
        self.samplers: List[llama_cpp.llama_sampler_p] = []
        # (chain index, wrapper) pairs; needed for ordered removal in close().
        self.custom_samplers: List[Tuple[int, CustomSampler]] = []

    def add_greedy(self):
        """Append a greedy (argmax) sampler."""
        sampler = llama_cpp.llama_sampler_init_greedy()
        self._add_sampler(sampler)

    def add_dist(self, seed: int):
        """Append a seeded distribution sampler."""
        sampler = llama_cpp.llama_sampler_init_dist(seed)
        self._add_sampler(sampler)

    def add_softmax(self):
        """Append a softmax sampler."""
        sampler = llama_cpp.llama_sampler_init_softmax()
        self._add_sampler(sampler)

    def add_top_k(self, k: int):
        """Append a top-k filter."""
        sampler = llama_cpp.llama_sampler_init_top_k(k)
        self._add_sampler(sampler)

    def add_top_p(self, p: float, min_keep: int):
        """Append a top-p (nucleus) filter."""
        sampler = llama_cpp.llama_sampler_init_top_p(p, min_keep)
        self._add_sampler(sampler)

    def add_min_p(self, p: float, min_keep: int):
        """Append a min-p filter."""
        sampler = llama_cpp.llama_sampler_init_min_p(p, min_keep)
        self._add_sampler(sampler)

    def add_typical(self, p: float, min_keep: int):
        """Append a locally-typical sampling filter."""
        sampler = llama_cpp.llama_sampler_init_typical(p, min_keep)
        self._add_sampler(sampler)

    def add_temp(self, temp: float):
        """Append a temperature scaler."""
        sampler = llama_cpp.llama_sampler_init_temp(temp)
        self._add_sampler(sampler)

    def add_temp_ext(self, t: float, delta: float, exponent: float):
        """Append an extended (dynamic) temperature scaler."""
        sampler = llama_cpp.llama_sampler_init_temp_ext(t, delta, exponent)
        self._add_sampler(sampler)

    def add_mirostat(self, n_vocab: int, seed: int, tau: float, eta: float, m: int):
        """Append a mirostat v1 sampler."""
        sampler = llama_cpp.llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m)
        self._add_sampler(sampler)

    def add_mirostat_v2(self, seed: int, tau: float, eta: float):
        """Append a mirostat v2 sampler."""
        sampler = llama_cpp.llama_sampler_init_mirostat_v2(seed, tau, eta)
        self._add_sampler(sampler)

    def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
        """Append a GBNF grammar-constrained sampler built from *grammar*."""
        sampler = llama_cpp.llama_sampler_init_grammar(
            model.vocab, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8")
        )
        self._add_sampler(sampler)

    def add_penalties(
        self,
        n_vocab: int,
        special_eos_id: int,
        linefeed_id: int,
        penalty_last_n: int,
        penalty_repeat: float,
        penalty_freq: float,
        penalty_present: float,
        penalize_nl: bool,
        ignore_eos: bool,
    ):
        """Append a repetition/frequency/presence penalty sampler.

        NOTE: n_vocab, special_eos_id, linefeed_id, penalize_nl and
        ignore_eos are accepted for call-site compatibility but are not
        forwarded to llama.cpp here.
        """
        sampler = llama_cpp.llama_sampler_init_penalties(
            penalty_last_n,
            penalty_repeat,
            penalty_freq,
            penalty_present,
        )
        self._add_sampler(sampler)

    def init_logit_bias(
        self, n_vocab: int, n_logit_bias, logit_bias: llama_cpp.llama_logit_bias_p
    ):
        """Append a logit-bias sampler over *n_logit_bias* entries."""
        sampler = llama_cpp.llama_sampler_init_logit_bias(
            n_vocab, n_logit_bias, logit_bias
        )
        self._add_sampler(sampler)

    def add_custom(
        self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]
    ):
        """Append a Python-defined sampler wrapped in CustomSampler."""
        custom_sampler = CustomSampler(apply_func)
        sampler = custom_sampler.get_sampler()
        self._add_sampler(sampler)
        # NOTE: Must remove custom samplers before free or llama.cpp will try to free them
        self.custom_samplers.append(
            (llama_cpp.llama_sampler_chain_n(self.sampler) - 1, custom_sampler)
        )

    def _add_sampler(self, sampler: llama_cpp.llama_sampler_p):
        """Append *sampler* to the chain and retain a Python reference to it."""
        assert self.sampler is not None
        llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
        self.samplers.append(sampler)

    def get_seed(self) -> int:
        """Return the RNG seed of the chain's seeded sampler."""
        assert self.sampler is not None
        return llama_cpp.llama_sampler_get_seed(self.sampler)

    def sample(self, ctx: LlamaContext, idx: int) -> int:
        """Sample a token id from context position *idx* using the chain."""
        assert self.sampler is not None
        assert ctx.ctx is not None
        return llama_cpp.llama_sampler_sample(self.sampler, ctx.ctx, idx)

    def close(self):
        """Free the chain; safe to call more than once."""
        if self.sampler:
            # NOTE: Must remove custom samplers before free or llama.cpp will try to free them
            # Removal is in reverse index order so earlier removals do not
            # shift the indices of samplers still pending removal.
            for i, _ in reversed(self.custom_samplers):
                llama_cpp.llama_sampler_chain_remove(self.sampler, i)
            llama_cpp.llama_sampler_free(self.sampler)
            self.sampler = None
        self.samplers.clear()
        self.custom_samplers.clear()

    def __del__(self):
        # Best-effort cleanup when garbage collected.
        self.close()
llama_cpp/_logger.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import ctypes
3
+ import logging
4
+
5
+ import llama_cpp
6
+
7
# enum ggml_log_level {
#     GGML_LOG_LEVEL_NONE  = 0,
#     GGML_LOG_LEVEL_INFO  = 1,
#     GGML_LOG_LEVEL_WARN  = 2,
#     GGML_LOG_LEVEL_ERROR = 3,
#     GGML_LOG_LEVEL_DEBUG = 4,
#     GGML_LOG_LEVEL_CONT  = 5, // continue previous log
# };
# Map ggml log levels to Python logging levels. Level 0 (NONE) maps to
# CRITICAL so unlevelled output only appears at maximum verbosity; level 5
# (CONT) continues the previous message and maps to DEBUG as a fallback.
GGML_LOG_LEVEL_TO_LOGGING_LEVEL = {
    0: logging.CRITICAL,
    1: logging.INFO,
    2: logging.WARNING,
    3: logging.ERROR,
    4: logging.DEBUG,
    5: logging.DEBUG,
}

# Package-level logger; verbosity is toggled via set_verbose().
logger = logging.getLogger("llama-cpp-python")

# Level of the last non-continuation message, used to attribute CONT lines.
_last_log_level = GGML_LOG_LEVEL_TO_LOGGING_LEVEL[0]
27
+
28
# typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
@llama_cpp.llama_log_callback
def llama_log_callback(
    level: int,
    text: bytes,
    user_data: ctypes.c_void_p,
):
    """Forward llama.cpp/ggml log output to stderr, honoring the logger level.

    GGML_LOG_LEVEL_CONT (5) continues the previous message, so it inherits
    the severity of the last non-continuation line rather than using its
    own mapped level.
    """
    global _last_log_level
    log_level = (
        GGML_LOG_LEVEL_TO_LOGGING_LEVEL[level] if level != 5 else _last_log_level
    )
    # Fix: filter on the effective (continuation-aware) level so CONT lines
    # are shown/hidden together with the message they extend; the old code
    # computed log_level but compared against the raw mapped level.
    if logger.level <= log_level:
        # errors="replace": this runs inside a C callback, where a raised
        # UnicodeDecodeError (e.g. on a partial UTF-8 sequence) cannot
        # propagate usefully.
        print(text.decode("utf-8", errors="replace"), end="", flush=True, file=sys.stderr)
    _last_log_level = log_level


llama_cpp.llama_log_set(llama_log_callback, ctypes.c_void_p(0))
44
+
45
+
46
def set_verbose(verbose: bool):
    """Set package logger verbosity: DEBUG when verbose, ERROR otherwise."""
    if verbose:
        logger.setLevel(logging.DEBUG)
    else:
        logger.setLevel(logging.ERROR)
llama_cpp/_utils.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ from typing import Any, Dict
5
+
6
+ # Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor
7
+ outnull_file = open(os.devnull, "w")
8
+ errnull_file = open(os.devnull, "w")
9
+
10
+ STDOUT_FILENO = 1
11
+ STDERR_FILENO = 2
12
+
13
+
14
class suppress_stdout_stderr(object):
    """Context manager that redirects both stdout and stderr to os.devnull.

    Works at the file-descriptor level (dup/dup2), so output from C
    extensions is silenced as well, not just Python-level writes.
    Pass disable=False to actually suppress; the default is a no-op.
    """

    # NOTE: these must be "saved" here to avoid exceptions when using
    # this context manager inside of a __del__ method
    sys = sys
    os = os

    def __init__(self, disable: bool = True):
        # disable=True makes enter/exit no-ops (suppression off).
        self.disable = disable

    # Oddly enough this works better than the contextlib version
    def __enter__(self):
        if self.disable:
            return self

        self.old_stdout_fileno_undup = STDOUT_FILENO
        self.old_stderr_fileno_undup = STDERR_FILENO

        # Duplicate the real fds so they can be restored in __exit__.
        self.old_stdout_fileno = self.os.dup(self.old_stdout_fileno_undup)
        self.old_stderr_fileno = self.os.dup(self.old_stderr_fileno_undup)

        self.old_stdout = self.sys.stdout
        self.old_stderr = self.sys.stderr

        # Point fds 1 and 2 at the devnull sinks.
        self.os.dup2(outnull_file.fileno(), self.old_stdout_fileno_undup)
        self.os.dup2(errnull_file.fileno(), self.old_stderr_fileno_undup)

        # Also swap the Python-level stream objects.
        self.sys.stdout = outnull_file
        self.sys.stderr = errnull_file
        return self

    def __exit__(self, *_):
        if self.disable:
            return

        # Check if sys.stdout and sys.stderr have fileno method
        self.sys.stdout = self.old_stdout
        self.sys.stderr = self.old_stderr

        # Restore the original fds, then close the saved duplicates.
        self.os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        self.os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        self.os.close(self.old_stdout_fileno)
        self.os.close(self.old_stderr_fileno)
58
+
59
class MetaSingleton(type):
    """
    Metaclass for implementing the Singleton pattern.

    The first instantiation of a class using this metaclass is cached;
    all later constructor calls return that same cached instance.
    """

    _instances: Dict[type, Any] = {}

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        try:
            return cls._instances[cls]
        except KeyError:
            instance = super().__call__(*args, **kwargs)
            cls._instances[cls] = instance
            return instance
70
+
71
+
72
class Singleton(object, metaclass=MetaSingleton):
    """
    Base class for implementing the Singleton pattern.

    Subclasses go through MetaSingleton.__call__, so repeated construction
    of a subclass yields one shared instance.
    """

    def __init__(self):
        super().__init__()
llama_cpp/lib/libggml-base.dylib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b10bbd734ef61868c17ddad1e15a08bcb93f711475f02b8df167256483f6b80c
3
+ size 548128
llama_cpp/lib/libggml-blas.dylib ADDED
Binary file (54.8 kB). View file
 
llama_cpp/lib/libggml-cpu.dylib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e842f3d817bdc4366d5b0b85a888bbe75bddd77e9cb320c9fcaacc7916cba08f
3
+ size 520192
llama_cpp/lib/libggml-metal.dylib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:965467ace598c542be5b05efa0edd6f28e8d9b3bef2d10e80f547b92cb05140c
3
+ size 600160
llama_cpp/lib/libggml.dylib ADDED
Binary file (56.4 kB). View file
 
llama_cpp/lib/libllama.dylib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46ce63cb4168539abcaa517d7023aabf9f111d56ca64b5f8aa2c6fe991ba9910
3
+ size 1140160
llama_cpp/lib/libllava.dylib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:839335da6f28ab29c340a9c61e3dffbdbde54bfb9e68e52d3bb2e7255645506e
3
+ size 337424
llama_cpp/llama.py ADDED
@@ -0,0 +1,2418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import sys
5
+ import uuid
6
+ import time
7
+ import json
8
+ import ctypes
9
+ import typing
10
+ import random
11
+ import fnmatch
12
+ import warnings
13
+ import contextlib
14
+ import multiprocessing
15
+
16
+ from typing import (
17
+ Any,
18
+ List,
19
+ Literal,
20
+ Optional,
21
+ Union,
22
+ Generator,
23
+ Sequence,
24
+ Iterator,
25
+ Deque,
26
+ Callable,
27
+ Dict,
28
+ )
29
+ from collections import deque
30
+ from pathlib import Path
31
+
32
+
33
+ from .llama_types import *
34
+ from .llama_grammar import LlamaGrammar
35
+ from .llama_cache import (
36
+ BaseLlamaCache,
37
+ LlamaCache, # type: ignore
38
+ LlamaDiskCache, # type: ignore
39
+ LlamaRAMCache, # type: ignore
40
+ )
41
+ from .llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer
42
+ import llama_cpp.llama_cpp as llama_cpp
43
+ import llama_cpp.llama_chat_format as llama_chat_format
44
+
45
+ from llama_cpp.llama_speculative import LlamaDraftModel
46
+
47
+ import numpy as np
48
+ import numpy.typing as npt
49
+
50
+ import llama_cpp._internals as internals
51
+ from ._logger import set_verbose
52
+ from ._utils import suppress_stdout_stderr
53
+
54
+
55
+ class Llama:
56
+ """High-level Python wrapper for a llama.cpp model."""
57
+
58
+ __backend_initialized = False
59
+
60
+ def __init__(
61
+ self,
62
+ model_path: str,
63
+ *,
64
+ # Model Params
65
+ n_gpu_layers: int = 0,
66
+ split_mode: int = llama_cpp.LLAMA_SPLIT_MODE_LAYER,
67
+ main_gpu: int = 0,
68
+ tensor_split: Optional[List[float]] = None,
69
+ rpc_servers: Optional[str] = None,
70
+ vocab_only: bool = False,
71
+ use_mmap: bool = True,
72
+ use_mlock: bool = False,
73
+ kv_overrides: Optional[Dict[str, Union[bool, int, float, str]]] = None,
74
+ # Context Params
75
+ seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
76
+ n_ctx: int = 512,
77
+ n_batch: int = 512,
78
+ n_ubatch: int = 512,
79
+ n_threads: Optional[int] = None,
80
+ n_threads_batch: Optional[int] = None,
81
+ rope_scaling_type: Optional[
82
+ int
83
+ ] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
84
+ pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED,
85
+ rope_freq_base: float = 0.0,
86
+ rope_freq_scale: float = 0.0,
87
+ yarn_ext_factor: float = -1.0,
88
+ yarn_attn_factor: float = 1.0,
89
+ yarn_beta_fast: float = 32.0,
90
+ yarn_beta_slow: float = 1.0,
91
+ yarn_orig_ctx: int = 0,
92
+ logits_all: bool = False,
93
+ embedding: bool = False,
94
+ offload_kqv: bool = True,
95
+ flash_attn: bool = False,
96
+ # Sampling Params
97
+ no_perf: bool = False,
98
+ last_n_tokens_size: int = 64,
99
+ # LoRA Params
100
+ lora_base: Optional[str] = None,
101
+ lora_scale: float = 1.0,
102
+ lora_path: Optional[str] = None,
103
+ # Backend Params
104
+ numa: Union[bool, int] = False,
105
+ # Chat Format Params
106
+ chat_format: Optional[str] = None,
107
+ chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
108
+ # Speculative Decoding
109
+ draft_model: Optional[LlamaDraftModel] = None,
110
+ # Tokenizer Override
111
+ tokenizer: Optional[BaseLlamaTokenizer] = None,
112
+ # KV cache quantization
113
+ type_k: Optional[int] = None,
114
+ type_v: Optional[int] = None,
115
+ # Misc
116
+ spm_infill: bool = False,
117
+ verbose: bool = True,
118
+ # Extra Params
119
+ **kwargs, # type: ignore
120
+ ):
121
+ """Load a llama.cpp model from `model_path`.
122
+
123
+ Examples:
124
+ Basic usage
125
+
126
+ >>> import llama_cpp
127
+ >>> model = llama_cpp.Llama(
128
+ ... model_path="path/to/model",
129
+ ... )
130
+ >>> print(model("The quick brown fox jumps ", stop=["."])["choices"][0]["text"])
131
+ the lazy dog
132
+
133
+ Loading a chat model
134
+
135
+ >>> import llama_cpp
136
+ >>> model = llama_cpp.Llama(
137
+ ... model_path="path/to/model",
138
+ ... chat_format="llama-2",
139
+ ... )
140
+ >>> print(model.create_chat_completion(
141
+ ... messages=[{
142
+ ... "role": "user",
143
+ ... "content": "what is the meaning of life?"
144
+ ... }]
145
+ ... ))
146
+
147
+ Args:
148
+ model_path: Path to the model.
149
+ n_gpu_layers: Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded.
150
+ split_mode: How to split the model across GPUs. See llama_cpp.LLAMA_SPLIT_* for options.
151
+ main_gpu: main_gpu interpretation depends on split_mode: LLAMA_SPLIT_MODE_NONE: the GPU that is used for the entire model. LLAMA_SPLIT_MODE_ROW: the GPU that is used for small tensors and intermediate results. LLAMA_SPLIT_MODE_LAYER: ignored
152
+ tensor_split: How split tensors should be distributed across GPUs. If None, the model is not split.
153
+ rpc_servers: Comma separated list of RPC servers to use for offloading
154
+ vocab_only: Only load the vocabulary no weights.
155
+ use_mmap: Use mmap if possible.
156
+ use_mlock: Force the system to keep the model in RAM.
157
+ kv_overrides: Key-value overrides for the model.
158
+ seed: RNG seed, -1 for random
159
+ n_ctx: Text context, 0 = from model
160
+ n_batch: Prompt processing maximum batch size
161
+ n_ubatch: Physical batch size
162
+ n_threads: Number of threads to use for generation
163
+ n_threads_batch: Number of threads to use for batch processing
164
+ rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
165
+ pooling_type: Pooling type, from `enum llama_pooling_type`.
166
+ rope_freq_base: RoPE base frequency, 0 = from model
167
+ rope_freq_scale: RoPE frequency scaling factor, 0 = from model
168
+ yarn_ext_factor: YaRN extrapolation mix factor, negative = from model
169
+ yarn_attn_factor: YaRN magnitude scaling factor
170
+ yarn_beta_fast: YaRN low correction dim
171
+ yarn_beta_slow: YaRN high correction dim
172
+ yarn_orig_ctx: YaRN original context size
173
+ logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
174
+ embedding: Embedding mode only.
175
+ offload_kqv: Offload K, Q, V to GPU.
176
+ flash_attn: Use flash attention.
177
+ no_perf: Measure performance timings.
178
+ last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
179
+ lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
180
+ lora_path: Path to a LoRA file to apply to the model.
181
+ numa: numa policy
182
+ chat_format: String specifying the chat format to use when calling create_chat_completion.
183
+ chat_handler: Optional chat handler to use when calling create_chat_completion.
184
+ draft_model: Optional draft model to use for speculative decoding.
185
+ tokenizer: Optional tokenizer to override the default tokenizer from llama.cpp.
186
+ verbose: Print verbose output to stderr.
187
+ type_k: KV cache data type for K (default: f16)
188
+ type_v: KV cache data type for V (default: f16)
189
+ spm_infill: Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this.
190
+
191
+ Raises:
192
+ ValueError: If the model path does not exist.
193
+
194
+ Returns:
195
+ A Llama instance.
196
+ """
197
+ self.verbose = verbose
198
+ self._stack = contextlib.ExitStack()
199
+
200
+ set_verbose(verbose)
201
+
202
+ if not Llama.__backend_initialized:
203
+ with suppress_stdout_stderr(disable=verbose):
204
+ llama_cpp.llama_backend_init()
205
+ Llama.__backend_initialized = True
206
+
207
+ if isinstance(numa, bool):
208
+ self.numa = (
209
+ llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE
210
+ if numa
211
+ else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
212
+ )
213
+ else:
214
+ self.numa = numa
215
+
216
+ if self.numa != llama_cpp.GGML_NUMA_STRATEGY_DISABLED:
217
+ with suppress_stdout_stderr(disable=verbose):
218
+ llama_cpp.llama_numa_init(self.numa)
219
+
220
+ self.model_path = model_path
221
+
222
+ # Model Params
223
+ self.model_params = llama_cpp.llama_model_default_params()
224
+ self.model_params.n_gpu_layers = (
225
+ 0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers
226
+ ) # 0x7FFFFFFF is INT32 max, will be auto set to all layers
227
+ self.model_params.split_mode = split_mode
228
+ self.model_params.main_gpu = main_gpu
229
+ if rpc_servers is not None:
230
+ self.model_params.rpc_servers = rpc_servers.encode("utf-8")
231
+ self._rpc_servers = rpc_servers
232
+ else:
233
+ self._rpc_servers = None
234
+ self.tensor_split = tensor_split
235
+ self._c_tensor_split = None
236
+ if self.tensor_split is not None:
237
+ if len(self.tensor_split) > llama_cpp.LLAMA_MAX_DEVICES:
238
+ raise ValueError(
239
+ f"Attempt to split tensors that exceed maximum supported devices. Current LLAMA_MAX_DEVICES={llama_cpp.LLAMA_MAX_DEVICES}"
240
+ )
241
+ # Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
242
+ FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES
243
+ self._c_tensor_split = FloatArray(
244
+ *tensor_split # type: ignore
245
+ ) # keep a reference to the array so it is not gc'd
246
+ self.model_params.tensor_split = self._c_tensor_split
247
+ self.model_params.vocab_only = vocab_only
248
+ self.model_params.use_mmap = use_mmap if lora_path is None else False
249
+ self.model_params.use_mlock = use_mlock
250
+
251
+ # kv_overrides is the original python dict
252
+ self.kv_overrides = kv_overrides
253
+ if kv_overrides is not None:
254
+ # _kv_overrides_array is a ctypes.Array of llama_model_kv_override Structs
255
+ kvo_array_len = len(kv_overrides) + 1 # for sentinel element
256
+ self._kv_overrides_array = (
257
+ llama_cpp.llama_model_kv_override * kvo_array_len
258
+ )()
259
+
260
+ for i, (k, v) in enumerate(kv_overrides.items()):
261
+ self._kv_overrides_array[i].key = k.encode("utf-8")
262
+ if isinstance(v, bool):
263
+ self._kv_overrides_array[
264
+ i
265
+ ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
266
+ self._kv_overrides_array[i].value.val_bool = v
267
+ elif isinstance(v, int):
268
+ self._kv_overrides_array[
269
+ i
270
+ ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
271
+ self._kv_overrides_array[i].value.val_i64 = v
272
+ elif isinstance(v, float):
273
+ self._kv_overrides_array[
274
+ i
275
+ ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
276
+ self._kv_overrides_array[i].value.val_f64 = v
277
+ elif isinstance(v, str): # type: ignore
278
+ v_bytes = v.encode("utf-8")
279
+ if len(v_bytes) > 128: # TODO: Make this a constant
280
+ raise ValueError(f"Value for {k} is too long: {v}")
281
+ v_bytes = v_bytes.ljust(128, b"\0")
282
+ self._kv_overrides_array[
283
+ i
284
+ ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
285
+ # copy min(v_bytes, 128) to str_value
286
+ address = typing.cast(
287
+ int,
288
+ ctypes.addressof(self._kv_overrides_array[i].value)
289
+ + llama_cpp.llama_model_kv_override_value.val_str.offset,
290
+ )
291
+ buffer_start = ctypes.cast(address, ctypes.POINTER(ctypes.c_char))
292
+ ctypes.memmove(
293
+ buffer_start,
294
+ v_bytes,
295
+ 128,
296
+ )
297
+ else:
298
+ raise ValueError(f"Unknown value type for {k}: {v}")
299
+
300
+ self._kv_overrides_array[
301
+ -1
302
+ ].key = b"\0" # ensure sentinel element is zeroed
303
+ self.model_params.kv_overrides = self._kv_overrides_array
304
+
305
+ self.n_batch = min(n_ctx, n_batch) # ???
306
+ self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
307
+ self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count()
308
+
309
+ # Used by the sampler
310
+ self._seed = seed or llama_cpp.LLAMA_DEFAULT_SEED
311
+
312
+ # Context Params
313
+ self.context_params = llama_cpp.llama_context_default_params()
314
+ self.context_params.n_ctx = n_ctx
315
+ self.context_params.n_batch = self.n_batch
316
+ self.context_params.n_ubatch = min(self.n_batch, n_ubatch)
317
+ self.context_params.n_threads = self.n_threads
318
+ self.context_params.n_threads_batch = self.n_threads_batch
319
+ self.context_params.rope_scaling_type = (
320
+ rope_scaling_type
321
+ if rope_scaling_type is not None
322
+ else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
323
+ )
324
+ self.context_params.pooling_type = pooling_type
325
+ self.context_params.rope_freq_base = (
326
+ rope_freq_base if rope_freq_base != 0.0 else 0
327
+ )
328
+ self.context_params.rope_freq_scale = (
329
+ rope_freq_scale if rope_freq_scale != 0.0 else 0
330
+ )
331
+ self.context_params.yarn_ext_factor = (
332
+ yarn_ext_factor if yarn_ext_factor != 0.0 else 0
333
+ )
334
+ self.context_params.yarn_attn_factor = (
335
+ yarn_attn_factor if yarn_attn_factor != 0.0 else 0
336
+ )
337
+ self.context_params.yarn_beta_fast = (
338
+ yarn_beta_fast if yarn_beta_fast != 0.0 else 0
339
+ )
340
+ self.context_params.yarn_beta_slow = (
341
+ yarn_beta_slow if yarn_beta_slow != 0.0 else 0
342
+ )
343
+ self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
344
+ self.context_params.logits_all = (
345
+ logits_all if draft_model is None else True
346
+ ) # Must be set to True for speculative decoding
347
+ self.context_params.embeddings = embedding # TODO: Rename to embeddings
348
+ self.context_params.offload_kqv = offload_kqv
349
+ self.context_params.flash_attn = flash_attn
350
+ # KV cache quantization
351
+ if type_k is not None:
352
+ self.context_params.type_k = type_k
353
+ if type_v is not None:
354
+ self.context_params.type_v = type_v
355
+ # Sampling Params
356
+ self.context_params.no_perf = no_perf
357
+ self.last_n_tokens_size = last_n_tokens_size
358
+
359
+ self.cache: Optional[BaseLlamaCache] = None
360
+
361
+ self.lora_base = lora_base
362
+ self.lora_scale = lora_scale
363
+ self.lora_path = lora_path
364
+
365
+ self.spm_infill = spm_infill
366
+
367
+ if not os.path.exists(model_path):
368
+ raise ValueError(f"Model path does not exist: {model_path}")
369
+
370
+ self._model = self._stack.enter_context(
371
+ contextlib.closing(
372
+ internals.LlamaModel(
373
+ path_model=self.model_path,
374
+ params=self.model_params,
375
+ verbose=self.verbose,
376
+ )
377
+ )
378
+ )
379
+
380
+ # Override tokenizer
381
+ self.tokenizer_ = tokenizer or LlamaTokenizer(self)
382
+
383
+ # Set the default value for the context and correct the batch
384
+ if n_ctx == 0:
385
+ n_ctx = self._model.n_ctx_train()
386
+ self.n_batch = min(n_ctx, n_batch)
387
+ self.context_params.n_ctx = self._model.n_ctx_train()
388
+ self.context_params.n_batch = self.n_batch
389
+ self.context_params.n_ubatch = min(self.n_batch, n_ubatch)
390
+
391
+ self._ctx = self._stack.enter_context(
392
+ contextlib.closing(
393
+ internals.LlamaContext(
394
+ model=self._model,
395
+ params=self.context_params,
396
+ verbose=self.verbose,
397
+ )
398
+ )
399
+ )
400
+
401
+ self._batch = self._stack.enter_context(
402
+ contextlib.closing(
403
+ internals.LlamaBatch(
404
+ n_tokens=self.n_batch,
405
+ embd=0,
406
+ n_seq_max=self.context_params.n_ctx,
407
+ verbose=self.verbose,
408
+ )
409
+ )
410
+ )
411
+
412
+ self._lora_adapter: Optional[llama_cpp.llama_adapter_lora_p] = None
413
+
414
+ if self.lora_path:
415
+ self._lora_adapter = llama_cpp.llama_adapter_lora_init(
416
+ self._model.model,
417
+ self.lora_path.encode("utf-8"),
418
+ )
419
+ if self._lora_adapter is None:
420
+ raise RuntimeError(
421
+ f"Failed to initialize LoRA adapter from lora path: {self.lora_path}"
422
+ )
423
+
424
+ def free_lora_adapter():
425
+ if self._lora_adapter is None:
426
+ return
427
+ llama_cpp.llama_adapter_lora_free(self._lora_adapter)
428
+ self._lora_adapter = None
429
+
430
+ self._stack.callback(free_lora_adapter)
431
+
432
+ if llama_cpp.llama_set_adapter_lora(
433
+ self._ctx.ctx, self._lora_adapter, self.lora_scale
434
+ ):
435
+ raise RuntimeError(
436
+ f"Failed to set LoRA adapter from lora path: {self.lora_path}"
437
+ )
438
+
439
+ if self.verbose:
440
+ print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
441
+
442
+ self.chat_format = chat_format
443
+ self.chat_handler = chat_handler
444
+ self._chat_handlers: Dict[
445
+ str, llama_chat_format.LlamaChatCompletionHandler
446
+ ] = {}
447
+
448
+ self.draft_model = draft_model
449
+
450
+ self._n_vocab = self.n_vocab()
451
+ self._n_ctx = self.n_ctx()
452
+
453
+ self._token_nl = self.token_nl()
454
+ self._token_eos = self.token_eos()
455
+
456
+ self._candidates = internals.LlamaTokenDataArray(n_vocab=self._n_vocab)
457
+
458
+ self.n_tokens = 0
459
+ self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
460
+ self.scores: npt.NDArray[np.single] = np.ndarray(
461
+ (n_ctx if logits_all == True else n_batch, self._n_vocab), dtype=np.single
462
+ )
463
+
464
+ self._mirostat_mu = ctypes.c_float(
465
+ 2.0 * 5.0
466
+ ) # TODO: Move this to sampling context
467
+
468
+ try:
469
+ self.metadata = self._model.metadata()
470
+ except Exception as e:
471
+ self.metadata = {}
472
+ if self.verbose:
473
+ print(f"Failed to load metadata: {e}", file=sys.stderr)
474
+
475
+ if self.verbose:
476
+ print(f"Model metadata: {self.metadata}", file=sys.stderr)
477
+
478
+ eos_token_id = self.token_eos()
479
+ bos_token_id = self.token_bos()
480
+
481
+ eos_token = (
482
+ self._model.token_get_text(eos_token_id) if eos_token_id != -1 else ""
483
+ )
484
+ bos_token = (
485
+ self._model.token_get_text(bos_token_id) if bos_token_id != -1 else ""
486
+ )
487
+
488
+ # Unfortunately the llama.cpp API does not return metadata arrays, so we can't get template names from tokenizer.chat_templates
489
+ template_choices = dict(
490
+ (name[10:], template)
491
+ for name, template in self.metadata.items()
492
+ if name.startswith("tokenizer.chat_template.")
493
+ )
494
+
495
+ if "tokenizer.chat_template" in self.metadata:
496
+ template_choices["chat_template.default"] = self.metadata[
497
+ "tokenizer.chat_template"
498
+ ]
499
+
500
+ if self.verbose and template_choices:
501
+ print(
502
+ f"Available chat formats from metadata: {', '.join(template_choices.keys())}",
503
+ file=sys.stderr,
504
+ )
505
+
506
+ for name, template in template_choices.items():
507
+ self._chat_handlers[name] = llama_chat_format.Jinja2ChatFormatter(
508
+ template=template,
509
+ eos_token=eos_token,
510
+ bos_token=bos_token,
511
+ stop_token_ids=[eos_token_id],
512
+ ).to_chat_handler()
513
+
514
+ if (
515
+ self.chat_format is None
516
+ and self.chat_handler is None
517
+ and "chat_template.default" in template_choices
518
+ ):
519
+ chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
520
+ self.metadata
521
+ )
522
+
523
+ if chat_format is not None:
524
+ self.chat_format = chat_format
525
+ if self.verbose:
526
+ print(f"Guessed chat format: {chat_format}", file=sys.stderr)
527
+ else:
528
+ if self.verbose:
529
+ print(
530
+ f"Using gguf chat template: {template_choices['chat_template.default']}",
531
+ file=sys.stderr,
532
+ )
533
+ print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
534
+ print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
535
+
536
+ self.chat_format = "chat_template.default"
537
+
538
+ if self.chat_format is None and self.chat_handler is None:
539
+ self.chat_format = "llama-2"
540
+ if self.verbose:
541
+ print(
542
+ f"Using fallback chat format: {self.chat_format}", file=sys.stderr
543
+ )
544
+
545
+ self._sampler = None
546
+
547
    @property
    def ctx(self) -> llama_cpp.llama_context_p:
        """Raw llama.cpp context pointer for direct C API calls."""
        return self._ctx.ctx
550
+
551
    @property
    def model(self) -> llama_cpp.llama_model_p:
        """Raw llama.cpp model pointer for direct C API calls."""
        return self._model.model
554
+
555
    @property
    def _input_ids(self) -> npt.NDArray[np.intc]:
        # View of only the tokens evaluated so far (first n_tokens entries).
        return self.input_ids[: self.n_tokens]
558
+
559
    @property
    def _scores(self) -> npt.NDArray[np.single]:
        # Logit rows corresponding to the tokens evaluated so far.
        return self.scores[: self.n_tokens, :]
562
+
563
    @property
    def eval_tokens(self) -> Deque[int]:
        # Evaluated tokens as a deque bounded by the context length.
        return deque(self.input_ids[: self.n_tokens].tolist(), maxlen=self._n_ctx)
566
+
567
    @property
    def eval_logits(self) -> Deque[List[float]]:
        # Logits for evaluated tokens; only the last row is meaningful unless
        # the context was created with logits_all=True.
        return deque(
            self.scores[: self.n_tokens, :].tolist(),
            maxlen=self._n_ctx if self.context_params.logits_all else 1,
        )
573
+
574
+ def tokenize(
575
+ self, text: bytes, add_bos: bool = True, special: bool = False
576
+ ) -> List[int]:
577
+ """Tokenize a string.
578
+
579
+ Args:
580
+ text: The utf-8 encoded string to tokenize.
581
+ add_bos: Whether to add a beginning of sequence token.
582
+ special: Whether to tokenize special tokens.
583
+
584
+ Raises:
585
+ RuntimeError: If the tokenization failed.
586
+
587
+ Returns:
588
+ A list of tokens.
589
+ """
590
+ return self.tokenizer_.tokenize(text, add_bos, special)
591
+
592
+ def detokenize(
593
+ self,
594
+ tokens: List[int],
595
+ prev_tokens: Optional[List[int]] = None,
596
+ special: bool = False,
597
+ ) -> bytes:
598
+ """Detokenize a list of tokens.
599
+
600
+ Args:
601
+ tokens: The list of tokens to detokenize.
602
+ prev_tokens: The list of previous tokens. Offset mapping will be performed if provided.
603
+ special: Whether to detokenize special tokens.
604
+
605
+ Returns:
606
+ The detokenized string.
607
+ """
608
+ return self.tokenizer_.detokenize(
609
+ tokens, prev_tokens=prev_tokens, special=special
610
+ )
611
+
612
    def set_cache(self, cache: Optional[BaseLlamaCache]):
        """Set the completion cache used for KV-state reuse.

        Args:
            cache: The cache to set, or None to disable caching.
        """
        self.cache = cache
619
+
620
+ def set_seed(self, seed: int):
621
+ """Set the random seed.
622
+
623
+ Args:
624
+ seed: The random seed.
625
+ """
626
+ self._seed = seed
627
+
628
+ def reset(self):
629
+ """Reset the model state."""
630
+ self.n_tokens = 0
631
+
632
    def eval(self, tokens: Sequence[int]):
        """Evaluate a list of tokens, advancing the model state.

        Tokens are decoded in chunks of at most n_batch; evaluated tokens are
        recorded in `self.input_ids` and, when logits_all is enabled, their
        logits are copied into `self.scores`.

        Args:
            tokens: The list of tokens to evaluate.
        """
        # Drop any cached KV entries at or past the current position so the
        # new tokens are decoded from a consistent cache state.
        self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
        for i in range(0, len(tokens), self.n_batch):
            batch = tokens[i : min(len(tokens), i + self.n_batch)]
            n_past = self.n_tokens
            n_tokens = len(batch)
            self._batch.set_batch(
                batch=batch, n_past=n_past, logits_all=self.context_params.logits_all
            )
            self._ctx.decode(self._batch)
            # Save tokens
            self.input_ids[n_past : n_past + n_tokens] = batch
            # Save logits
            if self.context_params.logits_all:
                rows = n_tokens
                cols = self._n_vocab
                logits = np.ctypeslib.as_array(
                    self._ctx.get_logits(), shape=(rows * cols,)
                )
                self.scores[n_past : n_past + n_tokens, :].reshape(-1)[::] = logits
            else:
                # NOTE: Now that sampling is done inside the sampler, logits
                # are only needed for logprobs, which requires logits_all.
                pass
            # Update n_tokens
            self.n_tokens += n_tokens
668
+
669
    def _init_sampler(
        self,
        top_k: int = 40,
        top_p: float = 0.95,
        min_p: float = 0.05,
        typical_p: float = 1.0,
        temp: float = 0.80,
        repeat_penalty: float = 1.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
        mirostat_eta: float = 0.1,
        mirostat_tau: float = 5.0,
        penalize_nl: bool = True,
        logits_processor: Optional[LogitsProcessorList] = None,
        grammar: Optional[LlamaGrammar] = None,
    ):
        """Build a llama.cpp sampler chain from the given sampling parameters.

        The chain order matters: custom logits processors run first, then
        repetition penalties, then the optional grammar, and finally the
        temperature/truncation samplers (or mirostat / greedy variants).

        Returns:
            A configured internals.LlamaSampler instance.
        """
        sampler = internals.LlamaSampler()

        if logits_processor is not None:
            # Create and add a custom sampler
            def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
                size = token_data_array.contents.size
                data_soa = token_data_array.contents.data
                data_soa_address = ctypes.addressof(data_soa.contents)
                # NOTE: This is probably broken — wraps the C token-data
                # array in a numpy recarray without copying; verify layout.
                recarray = np.recarray(
                    shape=(size,),
                    dtype=np.dtype(
                        [("id", np.intc), ("logit", np.single), ("p", np.single)],
                        align=True,
                    ),
                    buf=(llama_cpp.llama_token_data * size).from_address(
                        data_soa_address
                    ),
                )
                for logit_processor in logits_processor:
                    recarray.logit[:] = logit_processor(self._input_ids, recarray.logit)

            sampler.add_custom(apply_func)

        sampler.add_penalties(
            n_vocab=self._n_vocab,
            special_eos_id=self._token_eos,
            linefeed_id=self._token_nl,
            penalty_last_n=self.last_n_tokens_size,
            penalty_repeat=repeat_penalty,
            penalty_freq=frequency_penalty,
            penalty_present=presence_penalty,
            penalize_nl=penalize_nl,
            ignore_eos=False,
        )

        if grammar is not None:
            sampler.add_grammar(self._model, grammar)

        if temp < 0.0:
            # Negative temperature: sample from the raw distribution.
            sampler.add_softmax()
            sampler.add_dist(self._seed)
        elif temp == 0.0:
            # Zero temperature: deterministic argmax sampling.
            sampler.add_greedy()
        else:
            if mirostat_mode == 1:
                mirostat_m = 100
                sampler.add_mirostat(
                    self._n_vocab,
                    self._seed,
                    mirostat_tau,
                    mirostat_eta,
                    mirostat_m,
                )
            elif mirostat_mode == 2:
                sampler.add_mirostat_v2(
                    self._seed,
                    mirostat_tau,
                    mirostat_eta,
                )
            else:
                # Standard truncation samplers, applied in sequence.
                n_probs = 0
                min_keep = max(1, n_probs)
                sampler.add_top_k(top_k)
                sampler.add_typical(typical_p, min_keep)
                sampler.add_top_p(top_p, min_keep)
                sampler.add_min_p(min_p, min_keep)
                sampler.add_temp(temp)
                sampler.add_dist(self._seed)
        return sampler
757
+
758
    def sample(
        self,
        top_k: int = 40,
        top_p: float = 0.95,
        min_p: float = 0.05,
        typical_p: float = 1.0,
        temp: float = 0.80,
        repeat_penalty: float = 1.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
        mirostat_eta: float = 0.1,
        mirostat_tau: float = 5.0,
        penalize_nl: bool = True,
        logits_processor: Optional[LogitsProcessorList] = None,
        grammar: Optional[LlamaGrammar] = None,
        idx: Optional[int] = None,
    ):
        """Sample a token from the model.

        Args:
            top_k: The top-k sampling parameter.
            top_p: The top-p sampling parameter.
            temp: The temperature parameter.
            repeat_penalty: The repeat penalty parameter.
            idx: Optional absolute token index to sample at; defaults to the
                last evaluated position.

        Returns:
            The sampled token.
        """
        assert self.n_tokens > 0

        # If no persistent sampler was installed (e.g. by generate()), build a
        # one-shot sampler from the given parameters and discard it afterwards.
        tmp_sampler = False

        if self._sampler is None:
            tmp_sampler = True
            self._sampler = self._init_sampler(
                top_k=top_k,
                top_p=top_p,
                min_p=min_p,
                typical_p=typical_p,
                temp=temp,
                repeat_penalty=repeat_penalty,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                tfs_z=tfs_z,
                mirostat_mode=mirostat_mode,
                mirostat_tau=mirostat_tau,
                mirostat_eta=mirostat_eta,
                penalize_nl=penalize_nl,
                logits_processor=logits_processor,
                grammar=grammar,
            )

        # The sampler expects an index relative to the end of the last batch
        # (negative), so convert the absolute idx when one is given.
        ridx = idx - self.n_tokens if idx is not None else -1

        assert self.ctx is not None
        token = self._sampler.sample(self._ctx, ridx)
        if tmp_sampler:
            self._sampler = None
        return token
819
+
820
    def generate(
        self,
        tokens: Sequence[int],
        top_k: int = 40,
        top_p: float = 0.95,
        min_p: float = 0.05,
        typical_p: float = 1.0,
        temp: float = 0.80,
        repeat_penalty: float = 1.0,
        reset: bool = True,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
        mirostat_tau: float = 5.0,
        mirostat_eta: float = 0.1,
        penalize_nl: bool = True,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        grammar: Optional[LlamaGrammar] = None,
    ) -> Generator[int, Optional[Sequence[int]], None]:
        """Create a generator of tokens from a prompt.

        Examples:
            >>> llama = Llama("models/ggml-7b.bin")
            >>> tokens = llama.tokenize(b"Hello, world!")
            >>> for token in llama.generate(tokens, top_k=40, top_p=0.95, temp=1.0, repeat_penalty=1.0):
            ...     print(llama.detokenize([token]))

        Args:
            tokens: The prompt tokens.
            top_k: The top-k sampling parameter.
            top_p: The top-p sampling parameter.
            temp: The temperature parameter.
            repeat_penalty: The repeat penalty parameter.
            reset: Whether to reset the model state.

        Yields:
            The generated tokens. The caller may `.send()` a sequence of
            tokens to splice into the stream (used for speculative decoding).
        """
        # Reset mirostat sampling
        self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau)
        # Install a persistent sampler for the whole generation; sample()
        # will reuse it instead of building a one-shot chain per token.
        self._sampler = self._init_sampler(
            top_k=top_k,
            top_p=top_p,
            min_p=min_p,
            typical_p=typical_p,
            temp=temp,
            repeat_penalty=repeat_penalty,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
            penalize_nl=penalize_nl,
            logits_processor=logits_processor,
            grammar=grammar,
        )

        # Check for kv cache prefix match: if the start of the prompt matches
        # already-evaluated tokens, skip re-evaluating that prefix.
        if reset and self.n_tokens > 0:
            longest_prefix = 0
            for a, b in zip(self._input_ids, tokens[:-1]):
                if a == b:
                    longest_prefix += 1
                else:
                    break
            if longest_prefix > 0:
                reset = False
                tokens = tokens[longest_prefix:]
                self.n_tokens = longest_prefix
                if self.verbose:
                    print(
                        f"Llama.generate: {longest_prefix} prefix-match hit, "
                        f"remaining {len(tokens)} prompt tokens to eval",
                        file=sys.stderr,
                    )

        # Reset the model state
        if reset:
            self.reset()

        sample_idx = self.n_tokens + len(tokens) - 1
        tokens = list(tokens)

        # Eval and sample
        while True:
            self.eval(tokens)
            while sample_idx < self.n_tokens:
                token = self.sample(
                    top_k=top_k,
                    top_p=top_p,
                    min_p=min_p,
                    typical_p=typical_p,
                    temp=temp,
                    repeat_penalty=repeat_penalty,
                    frequency_penalty=frequency_penalty,
                    presence_penalty=presence_penalty,
                    tfs_z=tfs_z,
                    mirostat_mode=mirostat_mode,
                    mirostat_tau=mirostat_tau,
                    mirostat_eta=mirostat_eta,
                    logits_processor=logits_processor,
                    grammar=grammar,
                    penalize_nl=penalize_nl,
                    idx=sample_idx,
                )

                sample_idx += 1
                if stopping_criteria is not None and stopping_criteria(
                    self._input_ids[: sample_idx], self._scores[sample_idx - self.n_tokens, :]
                ):
                    return
                tokens_or_none = yield token
                tokens.clear()
                tokens.append(token)
                if tokens_or_none is not None:
                    tokens.extend(tokens_or_none)

                # If a sampled token diverges from a previously-drafted token,
                # rewind state and invalidate the stale KV-cache tail.
                if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]:
                    self.n_tokens = sample_idx
                    self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
                    break

            # Speculative decoding: ask the draft model for candidate tokens
            # to evaluate alongside the accepted ones.
            if self.draft_model is not None:
                self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
                draft_tokens = self.draft_model(
                    self.input_ids[: self.n_tokens + len(tokens)]
                )
                tokens.extend(
                    draft_tokens.astype(int)[
                        : self._n_ctx - self.n_tokens - len(tokens)
                    ]
                )
959
+
960
+ def create_embedding(
961
+ self, input: Union[str, List[str]], model: Optional[str] = None
962
+ ) -> CreateEmbeddingResponse:
963
+ """Embed a string.
964
+
965
+ Args:
966
+ input: The utf-8 encoded string to embed.
967
+
968
+ Returns:
969
+ An embedding object.
970
+ """
971
+ model_name: str = model if model is not None else self.model_path
972
+
973
+ input = input if isinstance(input, list) else [input]
974
+
975
+ # get numeric embeddings
976
+ embeds: Union[List[List[float]], List[List[List[float]]]]
977
+ total_tokens: int
978
+ embeds, total_tokens = self.embed(input, return_count=True) # type: ignore
979
+
980
+ # convert to CreateEmbeddingResponse
981
+ data: List[Embedding] = [
982
+ {
983
+ "object": "embedding",
984
+ "embedding": emb,
985
+ "index": idx,
986
+ }
987
+ for idx, emb in enumerate(embeds)
988
+ ]
989
+
990
+ return {
991
+ "object": "list",
992
+ "data": data,
993
+ "model": model_name,
994
+ "usage": {
995
+ "prompt_tokens": total_tokens,
996
+ "total_tokens": total_tokens,
997
+ },
998
+ }
999
+
1000
    def embed(
        self,
        input: Union[str, List[str]],
        normalize: bool = False,
        truncate: bool = True,
        return_count: bool = False,
    ):
        """Embed a string.

        Args:
            input: The utf-8 encoded string to embed.
            normalize: Whether to normalize each returned embedding vector.
            truncate: Whether to truncate inputs longer than one batch.
            return_count: Also return the total number of prompt tokens.

        Returns:
            A list of embeddings
        """
        n_embd = self.n_embd()
        n_batch = self.n_batch

        # get pooling information; without pooling we get one embedding per
        # token, otherwise one per sequence.
        pooling_type = self.pooling_type()
        logits_all = pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE

        if self.context_params.embeddings is False:
            raise RuntimeError(
                "Llama model must be created with embedding=True to call this method"
            )

        if self.verbose:
            llama_cpp.llama_perf_context_reset(self._ctx.ctx)

        if isinstance(input, str):
            inputs = [input]
        else:
            inputs = input

        # reset batch
        self._batch.reset()

        # decode and fetch embeddings
        data: Union[List[List[float]], List[List[List[float]]]] = []

        def decode_batch(seq_sizes: List[int]):
            # Clear the KV cache so sequences do not attend across batches.
            llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
            self._ctx.decode(self._batch)
            self._batch.reset()

            # store embeddings
            if pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE:
                # Token-level embeddings: a list of vectors per sequence.
                pos: int = 0
                for i, size in enumerate(seq_sizes):
                    ptr = llama_cpp.llama_get_embeddings(self._ctx.ctx)
                    embedding: List[List[float]] = [
                        ptr[pos + j * n_embd : pos + (j + 1) * n_embd]
                        for j in range(size)
                    ]
                    if normalize:
                        embedding = [
                            internals.normalize_embedding(e) for e in embedding
                        ]
                    data.append(embedding)
                    pos += size
            else:
                # Pooled embeddings: one vector per sequence.
                for i in range(len(seq_sizes)):
                    ptr = llama_cpp.llama_get_embeddings_seq(self._ctx.ctx, i)
                    embedding: List[float] = ptr[:n_embd]
                    if normalize:
                        embedding = internals.normalize_embedding(embedding)
                    data.append(embedding)

        # init state
        total_tokens = 0
        s_batch = []
        t_batch = 0
        p_batch = 0

        # accumulate batches and encode
        for text in inputs:
            tokens = self.tokenize(text.encode("utf-8"))
            if truncate:
                tokens = tokens[:n_batch]

            n_tokens = len(tokens)
            total_tokens += n_tokens

            # check for overrun
            if n_tokens > n_batch:
                raise ValueError(
                    f"Requested tokens ({n_tokens}) exceed batch size of {n_batch}"
                )

            # time to eval batch: flush when the next sequence won't fit.
            if t_batch + n_tokens > n_batch:
                decode_batch(s_batch)
                s_batch = []
                t_batch = 0
                p_batch = 0

            # add to batch
            self._batch.add_sequence(tokens, p_batch, logits_all)

            # update batch stats
            s_batch.append(n_tokens)
            t_batch += n_tokens
            p_batch += 1

        # handle last batch
        decode_batch(s_batch)

        if self.verbose:
            llama_cpp.llama_perf_context_print(self._ctx.ctx)

        # A plain string input returns a single embedding, not a list of one.
        output = data[0] if isinstance(input, str) else data

        llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
        self.reset()

        if return_count:
            return output, total_tokens
        else:
            return output
1120
+
1121
+ def _create_completion(
1122
+ self,
1123
+ prompt: Union[str, List[int]],
1124
+ suffix: Optional[str] = None,
1125
+ max_tokens: Optional[int] = 16,
1126
+ temperature: float = 0.8,
1127
+ top_p: float = 0.95,
1128
+ min_p: float = 0.05,
1129
+ typical_p: float = 1.0,
1130
+ logprobs: Optional[int] = None,
1131
+ echo: bool = False,
1132
+ stop: Optional[Union[str, List[str]]] = [],
1133
+ frequency_penalty: float = 0.0,
1134
+ presence_penalty: float = 0.0,
1135
+ repeat_penalty: float = 1.0,
1136
+ top_k: int = 40,
1137
+ stream: bool = False,
1138
+ seed: Optional[int] = None,
1139
+ tfs_z: float = 1.0,
1140
+ mirostat_mode: int = 0,
1141
+ mirostat_tau: float = 5.0,
1142
+ mirostat_eta: float = 0.1,
1143
+ model: Optional[str] = None,
1144
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
1145
+ logits_processor: Optional[LogitsProcessorList] = None,
1146
+ grammar: Optional[LlamaGrammar] = None,
1147
+ logit_bias: Optional[Dict[int, float]] = None,
1148
+ ) -> Union[
1149
+ Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse]
1150
+ ]:
1151
+ assert suffix is None or suffix.__class__ is str
1152
+
1153
+ completion_id: str = f"cmpl-{str(uuid.uuid4())}"
1154
+ created: int = int(time.time())
1155
+ bos_token_id: int = self.token_bos()
1156
+ cls_token_id: int = self._model.token_cls()
1157
+ sep_token_id: int = self._model.token_sep()
1158
+ prefix_token_id: int = 0 # self._model.token_prefix() # TODO: Fix
1159
+ middle_token_id: int = 0 # self._model.token_middle() # TODO: Fix
1160
+ suffix_token_id: int = 0 # self._model.token_suffix() # TODO: Fix
1161
+ add_space_prefix: bool = (
1162
+ self.metadata.get("tokenizer.ggml.add_space_prefix", "true") == "true"
1163
+ )
1164
+ bos_tokens: List[int] = [cls_token_id if cls_token_id != -1 else bos_token_id]
1165
+ eos_tokens: List[int] = [
1166
+ sep_token_id if sep_token_id != -1 else self.token_eos()
1167
+ ]
1168
+
1169
+ if (
1170
+ (isinstance(prompt, list) and suffix is None)
1171
+ or not self._model.add_bos_token()
1172
+ or bos_tokens[:1] == [-1]
1173
+ ):
1174
+ bos_tokens = []
1175
+
1176
+ if (isinstance(prompt, list) and suffix is None) or (
1177
+ not self._model.add_eos_token() and sep_token_id == -1
1178
+ ):
1179
+ eos_tokens = []
1180
+
1181
+ suffix_space_prefix: int = 0
1182
+ # Tokenizer hack to remove leading space
1183
+ if add_space_prefix and suffix_token_id >= 0 and suffix:
1184
+ suffix = "☺" + suffix
1185
+ suffix_space_prefix = 2
1186
+
1187
+ # If prompt is empty, initialize completion with BOS token to avoid
1188
+ # detokenization including a space at the beginning of the completion
1189
+ completion_tokens: List[int] = [] if len(prompt) > 0 else [bos_token_id]
1190
+ # Add blank space to start of prompt to match OG llama tokenizer
1191
+ prefix_tokens: List[int] = (
1192
+ [prefix_token_id] if prefix_token_id >= 0 and suffix is not None else []
1193
+ ) + (
1194
+ (
1195
+ self.tokenize(
1196
+ prompt.encode("utf-8"),
1197
+ add_bos=False,
1198
+ special=(prefix_token_id < 0 or suffix is None),
1199
+ )
1200
+ if prompt != ""
1201
+ else []
1202
+ )
1203
+ if isinstance(prompt, str)
1204
+ else prompt
1205
+ )
1206
+ suffix_tokens: List[int] = (
1207
+ (
1208
+ [suffix_token_id]
1209
+ + (
1210
+ self.tokenize(suffix.encode("utf-8"), add_bos=False, special=False)[
1211
+ suffix_space_prefix:
1212
+ ]
1213
+ if suffix
1214
+ else []
1215
+ )
1216
+ )
1217
+ if suffix_token_id >= 0 and suffix is not None
1218
+ else []
1219
+ )
1220
+ middle_tokens: List[int] = (
1221
+ [middle_token_id] if middle_token_id >= 0 and suffix is not None else []
1222
+ )
1223
+ prompt_tokens: List[int] = (
1224
+ bos_tokens
1225
+ + (
1226
+ (suffix_tokens + prefix_tokens + middle_tokens)
1227
+ if self.spm_infill
1228
+ else (prefix_tokens + suffix_tokens + middle_tokens)
1229
+ )
1230
+ + eos_tokens
1231
+ )
1232
+ text: bytes = b""
1233
+ returned_tokens: int = 0
1234
+ stop = (
1235
+ stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
1236
+ )
1237
+ model_name: str = model if model is not None else self.model_path
1238
+
1239
+ if prompt_tokens[:2] == [self.token_bos()] * 2:
1240
+ warnings.warn(
1241
+ f'Detected duplicate leading "{self._model.token_get_text(self.token_bos())}" in prompt, this will likely reduce response quality, consider removing it...',
1242
+ RuntimeWarning,
1243
+ )
1244
+
1245
+ # NOTE: This likely doesn't work correctly for the first token in the prompt
1246
+ # because of the extra space added to the start of the prompt_tokens
1247
+ if logit_bias is not None:
1248
+ logit_bias_map = {int(k): float(v) for k, v in logit_bias.items()}
1249
+
1250
+ def logit_bias_processor(
1251
+ input_ids: npt.NDArray[np.intc],
1252
+ scores: npt.NDArray[np.single],
1253
+ ) -> npt.NDArray[np.single]:
1254
+ new_scores = np.copy(
1255
+ scores
1256
+ ) # Does it make sense to copy the whole array or can we just overwrite the original one?
1257
+ for input_id, score in logit_bias_map.items():
1258
+ new_scores[input_id] = score + scores[input_id]
1259
+ return new_scores
1260
+
1261
+ _logit_bias_processor = LogitsProcessorList([logit_bias_processor])
1262
+ if logits_processor is None:
1263
+ logits_processor = _logit_bias_processor
1264
+ else:
1265
+ logits_processor = logits_processor.extend(_logit_bias_processor)
1266
+
1267
+ if self.verbose:
1268
+ self._ctx.reset_timings()
1269
+
1270
+ if len(prompt_tokens) >= self._n_ctx:
1271
+ raise ValueError(
1272
+ f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
1273
+ )
1274
+
1275
+ if max_tokens is None or max_tokens <= 0:
1276
+ # Unlimited, depending on n_ctx.
1277
+ max_tokens = self._n_ctx - len(prompt_tokens)
1278
+
1279
+ # Truncate max_tokens if requested tokens would exceed the context window
1280
+ max_tokens = (
1281
+ max_tokens
1282
+ if max_tokens + len(prompt_tokens) < self._n_ctx
1283
+ else (self._n_ctx - len(prompt_tokens))
1284
+ )
1285
+
1286
+ if stop != []:
1287
+ stop_sequences = [s.encode("utf-8") for s in stop]
1288
+ else:
1289
+ stop_sequences = []
1290
+
1291
+ if logprobs is not None and self.context_params.logits_all is False:
1292
+ raise ValueError(
1293
+ "logprobs is not supported for models created with logits_all=False"
1294
+ )
1295
+
1296
+ if self.cache:
1297
+ try:
1298
+ cache_item = self.cache[prompt_tokens]
1299
+ cache_prefix_len = Llama.longest_token_prefix(
1300
+ cache_item.input_ids.tolist(), prompt_tokens
1301
+ )
1302
+ eval_prefix_len = Llama.longest_token_prefix(
1303
+ self._input_ids.tolist(), prompt_tokens
1304
+ )
1305
+ if cache_prefix_len > eval_prefix_len:
1306
+ self.load_state(cache_item)
1307
+ if self.verbose:
1308
+ print("Llama._create_completion: cache hit", file=sys.stderr)
1309
+ except KeyError:
1310
+ if self.verbose:
1311
+ print("Llama._create_completion: cache miss", file=sys.stderr)
1312
+
1313
+ if seed is not None:
1314
+ self.set_seed(seed)
1315
+ else:
1316
+ self.set_seed(random.Random(self._seed).randint(0, 2 ** 32))
1317
+
1318
+ finish_reason = "length"
1319
+ multibyte_fix = 0
1320
+ for token in self.generate(
1321
+ prompt_tokens,
1322
+ top_k=top_k,
1323
+ top_p=top_p,
1324
+ min_p=min_p,
1325
+ typical_p=typical_p,
1326
+ temp=temperature,
1327
+ tfs_z=tfs_z,
1328
+ mirostat_mode=mirostat_mode,
1329
+ mirostat_tau=mirostat_tau,
1330
+ mirostat_eta=mirostat_eta,
1331
+ frequency_penalty=frequency_penalty,
1332
+ presence_penalty=presence_penalty,
1333
+ repeat_penalty=repeat_penalty,
1334
+ stopping_criteria=stopping_criteria,
1335
+ logits_processor=logits_processor,
1336
+ grammar=grammar,
1337
+ ):
1338
+ if llama_cpp.llama_token_is_eog(self._model.vocab, token):
1339
+ text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
1340
+ finish_reason = "stop"
1341
+ break
1342
+
1343
+ completion_tokens.append(token)
1344
+
1345
+ all_text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
1346
+
1347
+ # Contains multi-byte UTF8
1348
+ for k, char in enumerate(all_text[-3:]):
1349
+ k = 3 - k
1350
+ for num, pattern in [(2, 192), (3, 224), (4, 240)]:
1351
+ # Bitwise AND check
1352
+ if num > k and pattern & char == pattern:
1353
+ multibyte_fix = num - k
1354
+
1355
+ # Stop incomplete bytes from passing
1356
+ if multibyte_fix > 0:
1357
+ multibyte_fix -= 1
1358
+ continue
1359
+
1360
+ any_stop = [s for s in stop_sequences if s in all_text]
1361
+ if len(any_stop) > 0:
1362
+ first_stop = any_stop[0]
1363
+ text = all_text[: all_text.index(first_stop)]
1364
+ finish_reason = "stop"
1365
+ break
1366
+
1367
+ if stream:
1368
+ remaining_tokens = completion_tokens[returned_tokens:]
1369
+ remaining_text = self.detokenize(
1370
+ remaining_tokens,
1371
+ prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
1372
+ )
1373
+ remaining_length = len(remaining_text)
1374
+
1375
+ # We want to avoid yielding any characters from
1376
+ # the generated text if they are part of a stop
1377
+ # sequence.
1378
+ first_stop_position = 0
1379
+ for s in stop_sequences:
1380
+ for i in range(min(len(s), remaining_length), 0, -1):
1381
+ if remaining_text.endswith(s[:i]):
1382
+ if i > first_stop_position:
1383
+ first_stop_position = i
1384
+ break
1385
+
1386
+ token_end_position = 0
1387
+
1388
+ if logprobs is not None:
1389
+ # not sure how to handle this branch when dealing
1390
+ # with CJK output, so keep it unchanged
1391
+ for token in remaining_tokens:
1392
+ if token == bos_token_id:
1393
+ continue
1394
+ token_end_position += len(
1395
+ self.detokenize(
1396
+ [token],
1397
+ prev_tokens=prompt_tokens
1398
+ + completion_tokens[:returned_tokens],
1399
+ )
1400
+ )
1401
+ # Check if stop sequence is in the token
1402
+ if token_end_position > (
1403
+ remaining_length - first_stop_position
1404
+ ):
1405
+ break
1406
+ token_str = self.detokenize(
1407
+ [token],
1408
+ prev_tokens=prompt_tokens
1409
+ + completion_tokens[:returned_tokens],
1410
+ ).decode("utf-8", errors="ignore")
1411
+ text_offset = len(prompt) + len(
1412
+ self.detokenize(
1413
+ completion_tokens[:returned_tokens],
1414
+ prev_tokens=prompt_tokens
1415
+ + completion_tokens[:returned_tokens],
1416
+ ).decode("utf-8", errors="ignore")
1417
+ )
1418
+ token_offset = len(prompt_tokens) + returned_tokens
1419
+ logits = self._scores[token_offset - 1, :]
1420
+ current_logprobs = Llama.logits_to_logprobs(logits).tolist()
1421
+ sorted_logprobs = list(
1422
+ sorted(
1423
+ zip(current_logprobs, range(len(current_logprobs))),
1424
+ reverse=True,
1425
+ )
1426
+ )
1427
+ top_logprob = {
1428
+ self.detokenize([i]).decode(
1429
+ "utf-8", errors="ignore"
1430
+ ): logprob
1431
+ for logprob, i in sorted_logprobs[:logprobs]
1432
+ }
1433
+ top_logprob.update({token_str: current_logprobs[int(token)]})
1434
+ logprobs_or_none = {
1435
+ "tokens": [
1436
+ self.detokenize(
1437
+ [token],
1438
+ prev_tokens=prompt_tokens
1439
+ + completion_tokens[:returned_tokens],
1440
+ ).decode("utf-8", errors="ignore")
1441
+ ],
1442
+ "text_offset": [text_offset],
1443
+ "token_logprobs": [current_logprobs[int(token)]],
1444
+ "top_logprobs": [top_logprob],
1445
+ }
1446
+ returned_tokens += 1
1447
+ yield {
1448
+ "id": completion_id,
1449
+ "object": "text_completion",
1450
+ "created": created,
1451
+ "model": model_name,
1452
+ "choices": [
1453
+ {
1454
+ "text": self.detokenize(
1455
+ [token],
1456
+ prev_tokens=prompt_tokens
1457
+ + completion_tokens[:returned_tokens],
1458
+ ).decode("utf-8", errors="ignore"),
1459
+ "index": 0,
1460
+ "logprobs": logprobs_or_none,
1461
+ "finish_reason": None,
1462
+ }
1463
+ ],
1464
+ }
1465
+ else:
1466
+ while len(remaining_tokens) > 0:
1467
+ decode_success = False
1468
+ for i in range(1, len(remaining_tokens) + 1):
1469
+ try:
1470
+ bs = self.detokenize(
1471
+ remaining_tokens[:i],
1472
+ prev_tokens=prompt_tokens
1473
+ + completion_tokens[:returned_tokens],
1474
+ )
1475
+ ts = bs.decode("utf-8")
1476
+ decode_success = True
1477
+ break
1478
+ except UnicodeError:
1479
+ pass
1480
+ else:
1481
+ break
1482
+ if not decode_success:
1483
+ # all remaining tokens cannot be decoded to a UTF-8 character
1484
+ break
1485
+ token_end_position += len(bs)
1486
+ if token_end_position > (
1487
+ remaining_length - first_stop_position
1488
+ ):
1489
+ break
1490
+ remaining_tokens = remaining_tokens[i:]
1491
+ returned_tokens += i
1492
+
1493
+ yield {
1494
+ "id": completion_id,
1495
+ "object": "text_completion",
1496
+ "created": created,
1497
+ "model": model_name,
1498
+ "choices": [
1499
+ {
1500
+ "text": ts,
1501
+ "index": 0,
1502
+ "logprobs": None,
1503
+ "finish_reason": None,
1504
+ }
1505
+ ],
1506
+ }
1507
+
1508
+ if len(completion_tokens) >= max_tokens:
1509
+ text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
1510
+ finish_reason = "length"
1511
+ break
1512
+
1513
+ if stopping_criteria is not None and stopping_criteria(
1514
+ self._input_ids, self._scores[-1, :]
1515
+ ):
1516
+ text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
1517
+ finish_reason = "stop"
1518
+
1519
+ if self.verbose:
1520
+ self._ctx.print_timings()
1521
+
1522
+ if stream:
1523
+ remaining_tokens = completion_tokens[returned_tokens:]
1524
+ remaining_text = self.detokenize(
1525
+ remaining_tokens,
1526
+ prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
1527
+ )
1528
+ any_stop = [s for s in stop_sequences if s in remaining_text]
1529
+ if len(any_stop) > 0:
1530
+ end = min(remaining_text.index(stop) for stop in any_stop)
1531
+ else:
1532
+ end = len(remaining_text)
1533
+
1534
+ token_end_position = 0
1535
+ for token in remaining_tokens:
1536
+ token_end_position += len(
1537
+ self.detokenize(
1538
+ [token],
1539
+ prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
1540
+ )
1541
+ )
1542
+
1543
+ logprobs_or_none: Optional[CompletionLogprobs] = None
1544
+ if logprobs is not None:
1545
+ if token == bos_token_id:
1546
+ continue
1547
+ token_str = self.detokenize([token]).decode(
1548
+ "utf-8", errors="ignore"
1549
+ )
1550
+ text_offset = len(prompt) + len(
1551
+ self.detokenize(
1552
+ completion_tokens[:returned_tokens],
1553
+ prev_tokens=prompt_tokens
1554
+ + completion_tokens[:returned_tokens],
1555
+ )
1556
+ )
1557
+ token_offset = len(prompt_tokens) + returned_tokens - 1
1558
+ logits = self._scores[token_offset, :]
1559
+ current_logprobs = Llama.logits_to_logprobs(logits).tolist()
1560
+ sorted_logprobs = list(
1561
+ sorted(
1562
+ zip(current_logprobs, range(len(current_logprobs))),
1563
+ reverse=True,
1564
+ )
1565
+ )
1566
+ top_logprob = {
1567
+ self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
1568
+ for logprob, i in sorted_logprobs[:logprobs]
1569
+ }
1570
+ top_logprob.update({token_str: current_logprobs[int(token)]})
1571
+ logprobs_or_none = {
1572
+ "tokens": [
1573
+ self.detokenize([token]).decode("utf-8", errors="ignore")
1574
+ ],
1575
+ "text_offset": [text_offset],
1576
+ "token_logprobs": [current_logprobs[int(token)]],
1577
+ "top_logprobs": [top_logprob],
1578
+ }
1579
+
1580
+ if token_end_position >= end:
1581
+ last_text = self.detokenize([token])
1582
+ if token_end_position == end - 1:
1583
+ break
1584
+ returned_tokens += 1
1585
+ yield {
1586
+ "id": completion_id,
1587
+ "object": "text_completion",
1588
+ "created": created,
1589
+ "model": model_name,
1590
+ "choices": [
1591
+ {
1592
+ "text": last_text[
1593
+ : len(last_text) - (token_end_position - end)
1594
+ ].decode("utf-8", errors="ignore"),
1595
+ "index": 0,
1596
+ "logprobs": logprobs_or_none,
1597
+ "finish_reason": None,
1598
+ }
1599
+ ],
1600
+ }
1601
+ break
1602
+ returned_tokens += 1
1603
+ yield {
1604
+ "id": completion_id,
1605
+ "object": "text_completion",
1606
+ "created": created,
1607
+ "model": model_name,
1608
+ "choices": [
1609
+ {
1610
+ "text": self.detokenize([token]).decode(
1611
+ "utf-8", errors="ignore"
1612
+ ),
1613
+ "index": 0,
1614
+ "logprobs": logprobs_or_none,
1615
+ "finish_reason": None,
1616
+ }
1617
+ ],
1618
+ }
1619
+ yield {
1620
+ "id": completion_id,
1621
+ "object": "text_completion",
1622
+ "created": created,
1623
+ "model": model_name,
1624
+ "choices": [
1625
+ {
1626
+ "text": "",
1627
+ "index": 0,
1628
+ "logprobs": None,
1629
+ "finish_reason": finish_reason,
1630
+ }
1631
+ ],
1632
+ }
1633
+ if self.cache:
1634
+ if self.verbose:
1635
+ print("Llama._create_completion: cache save", file=sys.stderr)
1636
+ self.cache[prompt_tokens + completion_tokens] = self.save_state()
1637
+ if self.verbose:
1638
+ print("Llama._create_completion: cache saved", file=sys.stderr)
1639
+ return
1640
+
1641
+ if self.cache:
1642
+ if self.verbose:
1643
+ print("Llama._create_completion: cache save", file=sys.stderr)
1644
+ self.cache[prompt_tokens + completion_tokens] = self.save_state()
1645
+
1646
+ text_str = text.decode("utf-8", errors="ignore")
1647
+
1648
+ if echo:
1649
+ text_str = prompt + text_str
1650
+
1651
+ if suffix_token_id < 0 and suffix is not None:
1652
+ text_str = text_str + suffix
1653
+
1654
+ logprobs_or_none: Optional[CompletionLogprobs] = None
1655
+ if logprobs is not None:
1656
+ text_offset = 0 if echo else len(prompt)
1657
+ token_offset = 0 if echo else len(prompt_tokens[1:])
1658
+ text_offsets: List[int] = []
1659
+ token_logprobs: List[Optional[float]] = []
1660
+ tokens: List[str] = []
1661
+ top_logprobs: List[Optional[Dict[str, float]]] = []
1662
+
1663
+ if echo:
1664
+ # Remove leading BOS token if exists
1665
+ all_tokens = (
1666
+ prompt_tokens[1 if prompt_tokens[0] == self.token_bos() else 0 :]
1667
+ + completion_tokens
1668
+ )
1669
+ else:
1670
+ all_tokens = completion_tokens
1671
+
1672
+ all_token_strs = [
1673
+ self.detokenize([token], prev_tokens=all_tokens[:i]).decode(
1674
+ "utf-8", errors="ignore"
1675
+ )
1676
+ for i, token in enumerate(all_tokens)
1677
+ ]
1678
+ all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
1679
+ # TODO: may be able to change this loop to use np.take_along_dim
1680
+ for idx, (token, token_str, logprobs_token) in enumerate(
1681
+ zip(all_tokens, all_token_strs, all_logprobs)
1682
+ ):
1683
+ if token == bos_token_id:
1684
+ continue
1685
+ text_offsets.append(
1686
+ text_offset
1687
+ + len(
1688
+ self.detokenize(all_tokens[:idx]).decode(
1689
+ "utf-8", errors="ignore"
1690
+ )
1691
+ )
1692
+ )
1693
+ tokens.append(token_str)
1694
+ sorted_logprobs = list(
1695
+ sorted(
1696
+ zip(logprobs_token, range(len(logprobs_token))), reverse=True
1697
+ )
1698
+ )
1699
+ token_logprobs.append(logprobs_token[int(token)])
1700
+ top_logprob: Optional[Dict[str, float]] = {
1701
+ self.detokenize([i], prev_tokens=all_tokens[:idx]).decode(
1702
+ "utf-8", errors="ignore"
1703
+ ): logprob
1704
+ for logprob, i in sorted_logprobs[:logprobs]
1705
+ }
1706
+ top_logprob.update({token_str: logprobs_token[int(token)]})
1707
+ top_logprobs.append(top_logprob)
1708
+ # Weird idosincracy of the OpenAI API where
1709
+ # token_logprobs and top_logprobs are null for
1710
+ # the first token.
1711
+ if echo and len(all_tokens) > 0:
1712
+ token_logprobs[0] = None
1713
+ top_logprobs[0] = None
1714
+ logprobs_or_none = {
1715
+ "tokens": tokens,
1716
+ "text_offset": text_offsets,
1717
+ "token_logprobs": token_logprobs,
1718
+ "top_logprobs": top_logprobs,
1719
+ }
1720
+
1721
+ yield {
1722
+ "id": completion_id,
1723
+ "object": "text_completion",
1724
+ "created": created,
1725
+ "model": model_name,
1726
+ "choices": [
1727
+ {
1728
+ "text": text_str,
1729
+ "index": 0,
1730
+ "logprobs": logprobs_or_none,
1731
+ "finish_reason": finish_reason,
1732
+ }
1733
+ ],
1734
+ "usage": {
1735
+ "prompt_tokens": len(prompt_tokens),
1736
+ "completion_tokens": len(completion_tokens),
1737
+ "total_tokens": len(prompt_tokens) + len(completion_tokens),
1738
+ },
1739
+ }
1740
+
1741
def create_completion(
    self,
    prompt: Union[str, List[int]],
    suffix: Optional[str] = None,
    max_tokens: Optional[int] = 16,
    temperature: float = 0.8,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    logprobs: Optional[int] = None,
    echo: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    repeat_penalty: float = 1.0,
    top_k: int = 40,
    stream: bool = False,
    seed: Optional[int] = None,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    logit_bias: Optional[Dict[int, float]] = None,
) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
    """Generate text from a prompt.

    Args:
        prompt: Prompt text (or pre-tokenized ids) to complete.
        suffix: Optional suffix appended to the generated text.
        max_tokens: Maximum tokens to generate; ``None`` or a value <= 0
            means "unlimited", bounded only by the context window.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cutoff (https://arxiv.org/abs/1904.09751).
        min_p: Minimum-p cutoff (https://github.com/ggerganov/llama.cpp/pull/3841).
        typical_p: Locally typical sampling parameter (https://arxiv.org/abs/2202.00666).
        logprobs: Number of logprobs to include per token, or ``None`` for none.
        echo: If True, the prompt is echoed back in the returned text.
        stop: String(s) that terminate generation when produced.
        frequency_penalty: Penalty based on token frequency in the prompt.
        presence_penalty: Penalty based on token presence in the prompt.
        repeat_penalty: Penalty applied to repeated tokens.
        top_k: Top-k sampling cutoff.
        stream: If True, return an iterator of partial-completion chunks.
        seed: RNG seed for sampling, or ``None`` to derive one.
        tfs_z: Tail-free sampling parameter
            (https://www.trentonbricken.com/Tail-Free-Sampling/).
        mirostat_mode: Mirostat sampling mode.
        mirostat_tau: Mirostat target surprise value.
        mirostat_eta: Mirostat learning rate for updating ``mu``.
        model: Name reported in the ``model`` field of the response.
        stopping_criteria: Optional stopping-criteria callbacks.
        logits_processor: Optional logits-processor callbacks.
        grammar: Optional grammar for constrained sampling.
        logit_bias: Optional per-token logit bias map.

    Raises:
        ValueError: If the requested tokens exceed the context window.
        RuntimeError: If tokenization or model evaluation fails.

    Returns:
        A completion response, or an iterator of stream chunks when
        ``stream`` is True.
    """
    # The internal generator encodes "no limit" as -1 rather than None.
    effective_max_tokens = -1 if max_tokens is None else max_tokens
    result = self._create_completion(
        prompt=prompt,
        suffix=suffix,
        max_tokens=effective_max_tokens,
        temperature=temperature,
        top_p=top_p,
        min_p=min_p,
        typical_p=typical_p,
        logprobs=logprobs,
        echo=echo,
        stop=stop,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        repeat_penalty=repeat_penalty,
        top_k=top_k,
        stream=stream,
        seed=seed,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        model=model,
        stopping_criteria=stopping_criteria,
        logits_processor=logits_processor,
        grammar=grammar,
        logit_bias=logit_bias,
    )
    if stream:
        # Hand the generator straight to the caller for lazy consumption.
        chunks: Iterator[CreateCompletionStreamResponse] = result
        return chunks
    # Non-streaming: the generator yields exactly one full completion.
    completion: Completion = next(result)  # type: ignore
    return completion
1837
+
1838
def __call__(
    self,
    prompt: str,
    suffix: Optional[str] = None,
    max_tokens: Optional[int] = 16,
    temperature: float = 0.8,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    logprobs: Optional[int] = None,
    echo: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    repeat_penalty: float = 1.0,
    top_k: int = 40,
    stream: bool = False,
    seed: Optional[int] = None,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    logit_bias: Optional[Dict[int, float]] = None,
) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
    """Generate text from a prompt.

    Convenience alias for :meth:`create_completion`; every argument is
    forwarded unchanged and the return value is passed straight through.
    See :meth:`create_completion` for the full parameter documentation.

    Raises:
        ValueError: If the requested tokens exceed the context window.
        RuntimeError: If tokenization or model evaluation fails.

    Returns:
        A completion response, or an iterator of stream chunks when
        ``stream`` is True.
    """
    # Collect the arguments once and forward them as keywords so the
    # mapping to create_completion stays one-to-one and easy to audit.
    options = dict(
        prompt=prompt,
        suffix=suffix,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        min_p=min_p,
        typical_p=typical_p,
        logprobs=logprobs,
        echo=echo,
        stop=stop,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        repeat_penalty=repeat_penalty,
        top_k=top_k,
        stream=stream,
        seed=seed,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        model=model,
        stopping_criteria=stopping_criteria,
        logits_processor=logits_processor,
        grammar=grammar,
        logit_bias=logit_bias,
    )
    return self.create_completion(**options)
1929
+
1930
def create_chat_completion(
    self,
    messages: List[ChatCompletionRequestMessage],
    functions: Optional[List[ChatCompletionFunction]] = None,
    function_call: Optional[ChatCompletionRequestFunctionCall] = None,
    tools: Optional[List[ChatCompletionTool]] = None,
    tool_choice: Optional[ChatCompletionToolChoiceOption] = None,
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    stream: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    seed: Optional[int] = None,
    response_format: Optional[ChatCompletionRequestResponseFormat] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repeat_penalty: float = 1.0,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    logit_bias: Optional[Dict[int, float]] = None,
    logprobs: Optional[bool] = None,
    top_logprobs: Optional[int] = None,
) -> Union[
    CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
]:
    """Generate a chat completion from a list of messages.

    Args:
        messages: Messages to generate a response for.
        functions: Functions available to the model (legacy API).
        function_call: Function-call directive (legacy API).
        tools: Tools available to the model.
        tool_choice: Tool-choice directive.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cutoff (https://arxiv.org/abs/1904.09751).
        top_k: Top-k sampling cutoff.
        min_p: Minimum-p cutoff (https://github.com/ggerganov/llama.cpp/pull/3841).
        typical_p: Locally typical sampling parameter (https://arxiv.org/abs/2202.00666).
        stream: If True, return an iterator of chat-completion chunks.
        stop: String(s) that terminate generation when produced.
        seed: RNG seed for sampling.
        response_format: Response format; use ``{"type": "json_object"}`` to
            constrain output to valid JSON.
        max_tokens: Maximum tokens to generate; ``None`` or <= 0 means
            unlimited, bounded only by the context window.
        presence_penalty: Penalty based on token presence in the prompt.
        frequency_penalty: Penalty based on token frequency in the prompt.
        repeat_penalty: Penalty applied to repeated tokens.
        tfs_z: Tail-free sampling parameter.
        mirostat_mode: Mirostat sampling mode.
        mirostat_tau: Mirostat tau parameter.
        mirostat_eta: Mirostat eta parameter.
        model: Name reported in the ``model`` field of the response.
        logits_processor: Optional logits-processor callbacks.
        grammar: Optional grammar for constrained sampling.
        logit_bias: Optional per-token logit bias map.
        logprobs: Whether to include logprobs in the response.
        top_logprobs: Number of top logprobs per token.

    Returns:
        Generated chat completion, or a stream of chunks when ``stream``
        is True.
    """
    # Handler resolution order: an explicit handler set on this instance,
    # then one registered on this instance for the chat format, then the
    # library-wide registry (falsy entries fall through to the next).
    handler = (
        self.chat_handler
        or self._chat_handlers.get(self.chat_format)
        or llama_chat_format.get_chat_completion_handler(self.chat_format)
    )
    options = dict(
        functions=functions,
        function_call=function_call,
        tools=tools,
        tool_choice=tool_choice,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        min_p=min_p,
        typical_p=typical_p,
        logprobs=logprobs,
        top_logprobs=top_logprobs,
        stream=stream,
        stop=stop,
        seed=seed,
        response_format=response_format,
        max_tokens=max_tokens,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        repeat_penalty=repeat_penalty,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        model=model,
        logits_processor=logits_processor,
        grammar=grammar,
        logit_bias=logit_bias,
    )
    return handler(llama=self, messages=messages, **options)
2032
+
2033
def create_chat_completion_openai_v1(
    self,
    *args: Any,
    **kwargs: Any,
):
    """Generate a chat completion with return types from the OpenAI v1 API.

    The OpenAI python package is required to use this method.

    You can install it with `pip install openai`.

    Args:
        *args: Positional arguments to pass to create_chat_completion.
        **kwargs: Keyword arguments to pass to create_chat_completion.

    Raises:
        ImportError: If the ``openai`` package is not installed.

    Returns:
        Generated chat completion or a stream of chat completion chunks.
    """
    # Keep the try block minimal: only the openai import should map to the
    # "install openai" error. Previously the whole body was wrapped, so an
    # ImportError raised from inside create_chat_completion (e.g. a chat
    # handler's optional dependency) was misattributed to the openai package.
    try:
        from openai.types.chat import ChatCompletion, ChatCompletionChunk
    except ImportError:
        raise ImportError(
            "To use create_chat_completion_openai_v1, you must install the openai package."
            "You can install it with `pip install openai`."
        ) from None

    stream = kwargs.get("stream", False)  # type: ignore
    assert isinstance(stream, bool)
    if stream:
        # Wrap each streamed chunk in the OpenAI chunk model lazily.
        return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs))  # type: ignore
    else:
        return ChatCompletion(**self.create_chat_completion(*args, **kwargs))  # type: ignore
2065
+
2066
def __getstate__(self):
    """Return a picklable dict of constructor kwargs mirroring __init__.

    Used with __setstate__ so that pickling round-trips by re-running
    __init__ with the same configuration (the native model/context are
    rebuilt, not serialized).
    """
    return dict(
        model_path=self.model_path,
        # Model Params
        n_gpu_layers=self.model_params.n_gpu_layers,
        split_mode=self.model_params.split_mode,
        main_gpu=self.model_params.main_gpu,
        tensor_split=self.tensor_split,
        vocab_only=self.model_params.vocab_only,
        use_mmap=self.model_params.use_mmap,
        use_mlock=self.model_params.use_mlock,
        kv_overrides=self.kv_overrides,
        # Context Params
        seed=self._seed,
        n_ctx=self.context_params.n_ctx,
        n_batch=self.n_batch,
        n_ubatch=self.context_params.n_ubatch,
        n_threads=self.context_params.n_threads,
        n_threads_batch=self.context_params.n_threads_batch,
        rope_scaling_type=self.context_params.rope_scaling_type,
        pooling_type=self.context_params.pooling_type,
        rope_freq_base=self.context_params.rope_freq_base,
        rope_freq_scale=self.context_params.rope_freq_scale,
        yarn_ext_factor=self.context_params.yarn_ext_factor,
        yarn_attn_factor=self.context_params.yarn_attn_factor,
        yarn_beta_fast=self.context_params.yarn_beta_fast,
        yarn_beta_slow=self.context_params.yarn_beta_slow,
        yarn_orig_ctx=self.context_params.yarn_orig_ctx,
        logits_all=self.context_params.logits_all,
        # NOTE: the context param is named `embeddings` but the __init__
        # kwarg is `embedding`.
        embedding=self.context_params.embeddings,
        offload_kqv=self.context_params.offload_kqv,
        flash_attn=self.context_params.flash_attn,
        # Sampling Params
        no_perf=self.context_params.no_perf,
        last_n_tokens_size=self.last_n_tokens_size,
        # LoRA Params
        lora_base=self.lora_base,
        lora_scale=self.lora_scale,
        lora_path=self.lora_path,
        # Backend Params
        numa=self.numa,
        # Chat Format Params
        chat_format=self.chat_format,
        chat_handler=self.chat_handler,
        # Speculative Decoding
        draft_model=self.draft_model,
        # KV cache quantization
        type_k=self.context_params.type_k,
        type_v=self.context_params.type_v,
        # Misc
        spm_infill=self.spm_infill,
        verbose=self.verbose,
    )
2119
+
2120
def __setstate__(self, state):
    """Restore from pickle by re-running __init__ with the saved kwargs.

    `state` is the dict produced by __getstate__; this rebuilds the native
    model and context from scratch rather than restoring raw memory.
    """
    self.__init__(**state)
2122
+
2123
def save_state(self) -> LlamaState:
    """Snapshot the current llama context into a LlamaState object.

    Copies the native llama.cpp state (KV cache etc.) into a compact bytes
    buffer, together with the Python-side scores, input ids, token count,
    and seed, so the session can later be restored with load_state().

    Raises:
        RuntimeError: If the native copy reports more bytes than the
            queried state size (i.e. the copy overran the buffer).
    """
    if self.verbose:
        print("Llama.save_state: saving llama state", file=sys.stderr)
    # Upper bound on the serialized state size, queried from llama.cpp.
    state_size = llama_cpp.llama_get_state_size(self._ctx.ctx)
    if self.verbose:
        print(f"Llama.save_state: got state size: {state_size}", file=sys.stderr)
    llama_state = (ctypes.c_uint8 * int(state_size))()
    if self.verbose:
        print("Llama.save_state: allocated state", file=sys.stderr)
    # Actual number of bytes written may be smaller than state_size.
    n_bytes = llama_cpp.llama_copy_state_data(self._ctx.ctx, llama_state)
    if self.verbose:
        print(f"Llama.save_state: copied llama state: {n_bytes}", file=sys.stderr)
    if int(n_bytes) > int(state_size):
        raise RuntimeError("Failed to copy llama state data")
    # Shrink to the actually-used size before keeping the bytes around.
    llama_state_compact = (ctypes.c_uint8 * int(n_bytes))()
    llama_cpp.ctypes.memmove(llama_state_compact, llama_state, int(n_bytes))
    if self.verbose:
        print(
            f"Llama.save_state: saving {n_bytes} bytes of llama state",
            file=sys.stderr,
        )
    return LlamaState(
        scores=self._scores.copy(),
        input_ids=self.input_ids.copy(),
        n_tokens=self.n_tokens,
        llama_state=bytes(llama_state_compact),
        llama_state_size=n_bytes,
        seed=self._seed,
    )
2152
+
2153
def load_state(self, state: LlamaState) -> None:
    """Restore a session previously captured with save_state().

    Restores the Python-side scores/input ids/token count/seed and then
    replays the raw native state bytes into the llama.cpp context.

    Raises:
        RuntimeError: If the native call consumes a different number of
            bytes than the saved state size.
    """
    # Only filling in up to `n_tokens` and then zero-ing out the rest
    self.scores[: state.n_tokens, :] = state.scores.copy()
    rest = self.scores[state.n_tokens :, :]
    # NOTE(review): this only zeroes *positive* leftover scores; negative
    # values beyond n_tokens survive, despite the comment above — looks
    # intentional upstream, but verify if stale-score reads matter.
    rest[rest > 0] = 0.0
    self.input_ids = state.input_ids.copy()
    self.n_tokens = state.n_tokens
    self._seed = state.seed
    state_size = state.llama_state_size
    LLamaStateArrayType = ctypes.c_uint8 * state_size
    llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)

    if llama_cpp.llama_set_state_data(self._ctx.ctx, llama_state) != state_size:
        raise RuntimeError("Failed to set llama state data")
2167
+
2168
def n_ctx(self) -> int:
    """Return the context window size (in tokens) of the underlying context."""
    return self._ctx.n_ctx()
2171
+
2172
def n_embd(self) -> int:
    """Return the embedding size of the loaded model."""
    return self._model.n_embd()
2175
+
2176
def n_vocab(self) -> int:
    """Return the vocabulary size of the loaded model."""
    return self._model.n_vocab()
2179
+
2180
def tokenizer(self) -> LlamaTokenizer:
    """Return a LlamaTokenizer wrapper bound to this model instance."""
    return LlamaTokenizer(self)
2183
+
2184
def token_eos(self) -> int:
    """Return the end-of-sequence token id."""
    return self._model.token_eos()
2187
+
2188
def token_bos(self) -> int:
    """Return the beginning-of-sequence token id."""
    return self._model.token_bos()
2191
+
2192
def token_nl(self) -> int:
    """Return the newline token id."""
    return self._model.token_nl()
2195
+
2196
def pooling_type(self) -> str:
    """Return the pooling type reported by the underlying context."""
    return self._ctx.pooling_type()
2199
+
2200
def close(self) -> None:
    """Explicitly free the model from memory."""
    # _stack is set up in __init__ (outside this view); closing it is
    # expected to release the native model/context handles it owns.
    self._stack.close()
2203
+
2204
def __del__(self) -> None:
    """Best-effort cleanup when the instance is garbage-collected."""
    # close() dereferences self._stack, which will not exist if __init__
    # raised part-way through construction; swallow that case so no
    # "Exception ignored in __del__" noise appears at interpreter shutdown.
    try:
        self.close()
    except AttributeError:
        pass
2206
+
2207
@staticmethod
def logits_to_logprobs(
    logits: Union[npt.NDArray[np.single], List], axis: int = -1
) -> npt.NDArray[np.single]:
    """Convert raw logits into log-probabilities with a numerically stable
    log-softmax (same approach as scipy.special.log_softmax)."""
    # Subtract the per-axis maximum first so exp() cannot overflow;
    # non-finite maxima are clamped to 0 to keep the subtraction sane.
    peak: np.ndarray = np.amax(logits, axis=axis, keepdims=True)
    if peak.ndim > 0:
        peak[~np.isfinite(peak)] = 0
    elif not np.isfinite(peak):
        peak = 0
    shifted = np.subtract(logits, peak, dtype=np.single)
    exps = np.exp(shifted)
    # Suppress warnings about log of zero
    with np.errstate(divide="ignore"):
        log_norm = np.log(np.sum(exps, axis=axis, keepdims=True))
    return shifted - log_norm
2224
+
2225
@staticmethod
def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
    """Return the length of the longest common prefix of two token sequences."""
    n = 0
    for left, right in zip(a, b):
        if left != right:
            break
        n += 1
    return n
2234
+
2235
@classmethod
def from_pretrained(
    cls,
    repo_id: str,
    filename: Optional[str],
    additional_files: Optional[List] = None,
    local_dir: Optional[Union[str, os.PathLike[str]]] = None,
    local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
    cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
    **kwargs: Any,
) -> "Llama":
    """Create a Llama model from a pretrained model name or path.
    This method requires the huggingface-hub package.
    You can install it with `pip install huggingface-hub`.

    Args:
        repo_id: The model repo id.
        filename: A filename or glob pattern to match the model file in the repo.
        additional_files: A list of filenames or glob patterns to match additional model files in the repo.
        local_dir: The local directory to save the model to.
        local_dir_use_symlinks: Whether to use symlinks when downloading the model.
        cache_dir: The huggingface-hub cache directory to download into.
        **kwargs: Additional keyword arguments to pass to the Llama constructor.

    Raises:
        ImportError: if huggingface-hub is not installed.
        ValueError: if zero or multiple repo files match a given pattern.

    Returns:
        A Llama model."""
    try:
        from huggingface_hub import hf_hub_download, HfFileSystem
        from huggingface_hub.utils import validate_repo_id
    except ImportError:
        raise ImportError(
            "Llama.from_pretrained requires the huggingface-hub package. "
            "You can install it with `pip install huggingface-hub`."
        )

    validate_repo_id(repo_id)

    hffs = HfFileSystem()

    files = [
        file["name"] if isinstance(file, dict) else file
        for file in hffs.ls(repo_id, recursive=True)
    ]

    # split each file into repo_id, subfolder, filename
    file_list: List[str] = []
    for file in files:
        rel_path = Path(file).relative_to(repo_id)
        file_list.append(str(rel_path))

    # find the only/first shard file:
    matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)]  # type: ignore

    # Report the pattern that failed so the caller can fix it.
    if len(matching_files) == 0:
        raise ValueError(
            f"No file found in {repo_id} that match {filename}\n\n"
            f"Available Files:\n{json.dumps(file_list)}"
        )

    if len(matching_files) > 1:
        raise ValueError(
            f"Multiple files found in {repo_id} matching {filename}\n\n"
            f"Available Files:\n{json.dumps(files)}"
        )

    (matching_file,) = matching_files

    subfolder = str(Path(matching_file).parent)
    filename = Path(matching_file).name

    # download the file
    hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
        local_dir=local_dir,
        local_dir_use_symlinks=local_dir_use_symlinks,
        cache_dir=cache_dir,
    )

    if additional_files:
        for additional_file_name in additional_files:
            # find the additional shard file:
            matching_additional_files = [
                file for file in file_list if fnmatch.fnmatch(file, additional_file_name)
            ]

            if len(matching_additional_files) == 0:
                raise ValueError(
                    f"No file found in {repo_id} that match {additional_file_name}\n\n"
                    f"Available Files:\n{json.dumps(file_list)}"
                )

            if len(matching_additional_files) > 1:
                raise ValueError(
                    f"Multiple files found in {repo_id} matching {additional_file_name}\n\n"
                    f"Available Files:\n{json.dumps(files)}"
                )

            (matching_additional_file,) = matching_additional_files

            # download the additional file
            hf_hub_download(
                repo_id=repo_id,
                filename=matching_additional_file,
                subfolder=subfolder,
                local_dir=local_dir,
                local_dir_use_symlinks=local_dir_use_symlinks,
                cache_dir=cache_dir,
            )

    if local_dir is None:
        # Resolve the cached path without re-downloading.
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=subfolder,
            local_dir=local_dir,
            local_dir_use_symlinks=local_dir_use_symlinks,
            cache_dir=cache_dir,
            local_files_only=True,
        )
    else:
        model_path = os.path.join(local_dir, filename)

    # loading the first file of a sharded GGUF loads all remaining shard files in the subfolder
    return cls(
        model_path=model_path,
        **kwargs,
    )
2361
+
2362
+
2363
class LlamaState:
    """Plain container for a snapshot produced by ``Llama.save_state``."""

    def __init__(
        self,
        input_ids: npt.NDArray[np.intc],
        scores: npt.NDArray[np.single],
        n_tokens: int,
        llama_state: bytes,
        llama_state_size: int,
        seed: int,
    ):
        # Python-side generation buffers.
        self.input_ids = input_ids
        self.scores = scores
        self.n_tokens = n_tokens
        # Serialized native llama.cpp state and its exact byte length.
        self.llama_state = llama_state
        self.llama_state_size = llama_state_size
        # Sampling seed active at snapshot time.
        self.seed = seed
2379
+
2380
+
2381
LogitsProcessor = Callable[
    [npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]
]


class LogitsProcessorList(List[LogitsProcessor]):
    """A list of logits processors, applied left-to-right."""

    def __call__(
        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
    ) -> npt.NDArray[np.single]:
        # Thread the scores through every processor in insertion order.
        for transform in self:
            scores = transform(input_ids, scores)
        return scores
2393
+
2394
+
2395
StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]


class StoppingCriteriaList(List[StoppingCriteria]):
    """A list of stopping criteria; stops when any criterion fires."""

    def __call__(
        self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
    ) -> bool:
        # Evaluate every criterion (no short-circuit, matching the original
        # list-comprehension semantics), then combine.
        results = [criterion(input_ids, logits) for criterion in self]
        return any(results)
2403
+
2404
+
2405
class MinTokensLogitsProcessor(LogitsProcessor):
    """Forbids sampling EOS until at least `min_tokens` have been generated."""

    def __init__(self, min_tokens: int, token_eos: int):
        self.min_tokens = min_tokens
        self.token_eos = token_eos
        # Prompt length, captured lazily on the first call.
        self.prompt_tokens = None

    def __call__(
        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
    ) -> npt.NDArray[np.single]:
        if self.prompt_tokens is None:
            self.prompt_tokens = len(input_ids)
        generated = len(input_ids) - self.prompt_tokens
        if generated < self.min_tokens:
            # Make EOS impossible to sample.
            scores[self.token_eos] = -np.inf
        return scores
llama_cpp/llama_cache.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from abc import ABC, abstractmethod
3
+ from typing import (
4
+ Optional,
5
+ Sequence,
6
+ Tuple,
7
+ )
8
+ from collections import OrderedDict
9
+
10
+ import diskcache
11
+
12
+ import llama_cpp.llama
13
+
14
+ from .llama_types import *
15
+
16
+
17
class BaseLlamaCache(ABC):
    """Base cache class for a llama.cpp model.

    Concrete subclasses store LlamaState snapshots keyed by token-id tuples
    and evict entries once ``cache_size`` exceeds ``capacity_bytes``.
    """

    def __init__(self, capacity_bytes: int = (2 << 30)):
        # Soft upper bound on total cached state size, in bytes (default 2 GiB).
        self.capacity_bytes = capacity_bytes

    @property
    @abstractmethod
    def cache_size(self) -> int:
        # Total bytes currently held by the cache.
        raise NotImplementedError

    def _find_longest_prefix_key(
        self,
        key: Tuple[int, ...],
    ) -> Optional[Tuple[int, ...]]:
        # Subclasses return the stored key sharing the longest token prefix
        # with `key`, or None when nothing matches. (Intentionally a no-op
        # here rather than abstract.)
        pass

    @abstractmethod
    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
        raise NotImplementedError

    @abstractmethod
    def __contains__(self, key: Sequence[int]) -> bool:
        raise NotImplementedError

    @abstractmethod
    def __setitem__(
        self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"
    ) -> None:
        raise NotImplementedError
47
+
48
+
49
class LlamaRAMCache(BaseLlamaCache):
    """In-memory cache of llama.cpp states, keyed by token-id prefixes."""

    def __init__(self, capacity_bytes: int = (2 << 30)):
        super().__init__(capacity_bytes)
        self.capacity_bytes = capacity_bytes
        # Insertion-ordered map; recently-used entries are moved to the end.
        self.cache_state: OrderedDict[
            Tuple[int, ...], "llama_cpp.llama.LlamaState"
        ] = OrderedDict()

    @property
    def cache_size(self):
        # Total bytes of serialized llama state held across all entries.
        return sum(state.llama_state_size for state in self.cache_state.values())

    def _find_longest_prefix_key(
        self,
        key: Tuple[int, ...],
    ) -> Optional[Tuple[int, ...]]:
        # Linear scan for the stored key with the longest shared token prefix.
        best_len = 0
        best_key = None
        for candidate in self.cache_state.keys():
            shared = llama_cpp.llama.Llama.longest_token_prefix(candidate, key)
            if shared > best_len:
                best_len = shared
                best_key = candidate
        return best_key

    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
        lookup = tuple(key)
        matched = self._find_longest_prefix_key(lookup)
        if matched is None:
            raise KeyError("Key not found")
        value = self.cache_state[matched]
        # Mark as most-recently-used.
        self.cache_state.move_to_end(matched)
        return value

    def __contains__(self, key: Sequence[int]) -> bool:
        return self._find_longest_prefix_key(tuple(key)) is not None

    def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
        entry_key = tuple(key)
        # Re-insert so the entry lands at the most-recently-used end.
        if entry_key in self.cache_state:
            del self.cache_state[entry_key]
        self.cache_state[entry_key] = value
        # Evict least-recently-used entries until under capacity.
        while self.cache_size > self.capacity_bytes and len(self.cache_state) > 0:
            self.cache_state.popitem(last=False)
98
+
99
+
100
+ # Alias for backwards compatibility
101
+ LlamaCache = LlamaRAMCache
102
+
103
+
104
class LlamaDiskCache(BaseLlamaCache):
    """Cache for a llama.cpp model using disk (backed by ``diskcache.Cache``)."""

    def __init__(
        self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30)
    ):
        super().__init__(capacity_bytes)
        self.cache = diskcache.Cache(cache_dir)

    @property
    def cache_size(self):
        # Bytes currently used on disk by the diskcache store.
        return int(self.cache.volume())  # type: ignore

    def _find_longest_prefix_key(
        self,
        key: Tuple[int, ...],
    ) -> Optional[Tuple[int, ...]]:
        # Linear scan over all stored keys for the longest shared prefix.
        min_len = 0
        min_key: Optional[Tuple[int, ...]] = None
        for k in self.cache.iterkeys():  # type: ignore
            prefix_len = llama_cpp.llama.Llama.longest_token_prefix(k, key)
            if prefix_len > min_len:
                min_len = prefix_len
                min_key = k  # type: ignore
        return min_key

    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
        key = tuple(key)
        _key = self._find_longest_prefix_key(key)
        if _key is None:
            raise KeyError("Key not found")
        # pop() removes the entry; it is re-inserted by the caller via __setitem__.
        value: "llama_cpp.llama.LlamaState" = self.cache.pop(_key)  # type: ignore
        # NOTE: This puts an integer as key in cache, which breaks,
        # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
        # self.cache.push(_key, side="front") # type: ignore
        return value

    def __contains__(self, key: Sequence[int]) -> bool:
        return self._find_longest_prefix_key(tuple(key)) is not None

    def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
        # NOTE(review): these stderr prints look like leftover debug output —
        # kept because removing them would change observable behavior.
        print("LlamaDiskCache.__setitem__: called", file=sys.stderr)
        key = tuple(key)
        if key in self.cache:
            print("LlamaDiskCache.__setitem__: delete", file=sys.stderr)
            del self.cache[key]
        self.cache[key] = value
        print("LlamaDiskCache.__setitem__: set", file=sys.stderr)
        # Evict oldest entries until the on-disk size is under capacity.
        while self.cache_size > self.capacity_bytes and len(self.cache) > 0:
            key_to_remove = next(iter(self.cache))
            del self.cache[key_to_remove]
        print("LlamaDiskCache.__setitem__: trim", file=sys.stderr)
llama_cpp/llama_chat_format.py ADDED
The diff for this file is too large to render. See raw diff
 
llama_cpp/llama_cpp.py ADDED
The diff for this file is too large to render. See raw diff
 
llama_cpp/llama_grammar.py ADDED
@@ -0,0 +1,953 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Python implementation of llama grammar parser directly translated from C++ source file in vendor/llama.cpp/common/grammar-parser.cpp."""
2
+
3
+ # flake8: noqa
4
+ from pathlib import Path
5
+
6
+ from itertools import groupby
7
+ from typing import (
8
+ Any,
9
+ Set,
10
+ List,
11
+ Optional,
12
+ Tuple,
13
+ Union,
14
+ )
15
+
16
LLAMA_GRAMMAR_DEFAULT_ROOT = "root"


class LlamaGrammar:
    """Lightweight holder for a GBNF grammar string and its root rule name."""

    def __init__(self, *args, _grammar: str, **kwargs):
        self._grammar = _grammar
        self._root = LLAMA_GRAMMAR_DEFAULT_ROOT

    @classmethod
    def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
        """Wrap a GBNF grammar given directly as a string."""
        return cls(_grammar=grammar)

    @classmethod
    def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar":
        """Load a GBNF grammar from a file path."""
        try:
            with open(file) as f:
                grammar = f.read()
        except Exception as err:
            raise Exception(
                f"{cls.from_file.__name__}: error reading grammar file: {err}"
            )

        if not grammar:
            raise ValueError(
                f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty"
            )
        return cls.from_string(grammar, verbose=verbose)

    @classmethod
    def from_json_schema(cls, json_schema: str, verbose: bool = True) -> "LlamaGrammar":
        """Convert a JSON schema to GBNF and wrap the result."""
        return cls.from_string(json_schema_to_gbnf(json_schema), verbose=verbose)
48
+
49
+
50
+ """llama.cpp gbnf rules from vendor/llama.cpp/grammars"""
51
+
52
+ ARITHMETIC_GBNF = r"""
53
+ root ::= (expr "=" ws term "\n")+
54
+ expr ::= term ([-+*/] term)*
55
+ term ::= ident | num | "(" ws expr ")" ws
56
+ ident ::= [a-z] [a-z0-9_]* ws
57
+ num ::= [0-9]+ ws
58
+ ws ::= [ \t\n]*
59
+ """
60
+
61
+ C_GBNF = r"""
62
+ root ::= (declaration)*
63
+
64
+ declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}"
65
+
66
+ dataType ::= "int" ws | "float" ws | "char" ws
67
+ identifier ::= [a-zA-Z_] [a-zA-Z_0-9]*
68
+
69
+ parameter ::= dataType identifier
70
+
71
+ statement ::=
72
+ ( dataType identifier ws "=" ws expression ";" ) |
73
+ ( identifier ws "=" ws expression ";" ) |
74
+ ( identifier ws "(" argList? ")" ";" ) |
75
+ ( "return" ws expression ";" ) |
76
+ ( "while" "(" condition ")" "{" statement* "}" ) |
77
+ ( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) |
78
+ ( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? ) |
79
+ ( singleLineComment ) |
80
+ ( multiLineComment )
81
+
82
+ forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression
83
+ forUpdate ::= identifier ws "=" ws expression
84
+
85
+ condition ::= expression relationOperator expression
86
+ relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">")
87
+
88
+ expression ::= term (("+" | "-") term)*
89
+ term ::= factor(("*" | "/") factor)*
90
+
91
+ factor ::= identifier | number | unaryTerm | funcCall | parenExpression
92
+ unaryTerm ::= "-" factor
93
+ funcCall ::= identifier "(" argList? ")"
94
+ parenExpression ::= "(" ws expression ws ")"
95
+
96
+ argList ::= expression ("," ws expression)*
97
+
98
+ number ::= [0-9]+
99
+
100
+ singleLineComment ::= "//" [^\n]* "\n"
101
+ multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/"
102
+
103
+ ws ::= ([ \t\n]+)
104
+ """
105
+
106
+ CHESS_GBNF = r"""
107
+ root ::= object
108
+ value ::= object | array | string | number | ("true" | "false" | "null") ws
109
+
110
+ object ::=
111
+ "{" ws (
112
+ string ":" ws value
113
+ ("," ws string ":" ws value)*
114
+ )? "}" ws
115
+
116
+ array ::=
117
+ "[" ws (
118
+ value
119
+ ("," ws value)*
120
+ )? "]" ws
121
+
122
+ string ::=
123
+ "\"" (
124
+ [^"\\] |
125
+ "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
126
+ )* "\"" ws
127
+
128
+ number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
129
+
130
+ # Optional space: by convention, applied in this grammar after literal chars when allowed
131
+ ws ::= ([ \t\n] ws)?
132
+ """
133
+
134
+ JAPANESE_GBNF = r"""
135
+ root ::= object
136
+ value ::= object | array | string | number | ("true" | "false" | "null") ws
137
+
138
+ object ::=
139
+ "{" ws (
140
+ string ":" ws value
141
+ ("," ws string ":" ws value)*
142
+ )? "}" ws
143
+
144
+ array ::=
145
+ "[" ws (
146
+ value
147
+ ("," ws value)*
148
+ )? "]" ws
149
+
150
+ string ::=
151
+ "\"" (
152
+ [^"\\] |
153
+ "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
154
+ )* "\"" ws
155
+
156
+ number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
157
+
158
+ # Optional space: by convention, applied in this grammar after literal chars when allowed
159
+ ws ::= ([ \t\n] ws)?
160
+ """
161
+
162
+ JSON_ARR_GBNF = r"""
163
+ # This is the same as json.gbnf but we restrict whitespaces at the end of the root array
164
+ # Useful for generating JSON arrays
165
+
166
+ root ::= arr
167
+ value ::= object | array | string | number | ("true" | "false" | "null") ws
168
+
169
+ arr ::=
170
+ "[\n" ws (
171
+ value
172
+ (",\n" ws value)*
173
+ )? "]"
174
+
175
+ object ::=
176
+ "{" ws (
177
+ string ":" ws value
178
+ ("," ws string ":" ws value)*
179
+ )? "}" ws
180
+
181
+ array ::=
182
+ "[" ws (
183
+ value
184
+ ("," ws value)*
185
+ )? "]" ws
186
+
187
+ string ::=
188
+ "\"" (
189
+ [^"\\\x7F\x00-\x1F] |
190
+ "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
191
+ )* "\"" ws
192
+
193
+ number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
194
+
195
+ # Optional space: by convention, applied in this grammar after literal chars when allowed
196
+ ws ::= ([ \t\n] ws)?
197
+ """
198
+
199
+
200
+ JSON_GBNF = r"""
201
+ root ::= object
202
+ value ::= object | array | string | number | ("true" | "false" | "null") ws
203
+
204
+ object ::=
205
+ "{" ws (
206
+ string ":" ws value
207
+ ("," ws string ":" ws value)*
208
+ )? "}" ws
209
+
210
+ array ::=
211
+ "[" ws (
212
+ value
213
+ ("," ws value)*
214
+ )? "]" ws
215
+
216
+ string ::=
217
+ "\"" (
218
+ [^"\\\x7F\x00-\x1F] |
219
+ "\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4}) # escapes
220
+ )* "\"" ws
221
+
222
+ number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [0-9] [1-9]{0,15})? ws
223
+
224
+ # Optional space: by convention, applied in this grammar after literal chars when allowed
225
+ ws ::= | " " | "\n" [ \t]{0,20}
226
+ """
227
+
228
+ LIST_GBNF = r"""
229
+ root ::= item+
230
+
231
+ # Excludes various line break characters
232
+ item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n"
233
+ """
234
+
235
+ """llama.cpp json-schema to grammar converter from vendor/llama.cpp/examples/json-schema-to-grammar.py"""
236
+ import json
237
+ import re
238
+ from typing import List, Optional
239
+
240
+ # whitespace is constrained to a single space char to prevent model "running away" in
241
+ # whitespace. Also maybe improves generation quality?
242
+ SPACE_RULE = '" "?'
243
+
244
+
245
+ INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+")
246
+ GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
247
+ GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'}
248
+
249
+ # whitespace is constrained to a single space char to prevent model "running away" in
250
+ # whitespace. Also maybe improves generation quality?
251
+ SPACE_RULE = '" "?'
252
+
253
+
254
def _build_repetition(
    item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False
):
    """Build a GBNF expression matching `item_rule` repeated between
    `min_items` and `max_items` times (``max_items=None`` means unbounded),
    optionally interleaved with `separator_rule`.
    """
    # Fast paths for the plain regex-like quantifiers `?` and `+`.
    if not separator_rule:
        if min_items == 0 and max_items == 1:
            return f"{item_rule}?"
        elif min_items == 1 and max_items is None:
            return f"{item_rule}+"

    result = ""

    # Mandatory part: the first `min_items` occurrences.
    if min_items > 0:
        if item_rule_is_literal and separator_rule is None:
            # Merge N copies of a quoted literal into a single literal.
            result = '"' + (item_rule[1:-1] * min_items) + '"'
        else:
            result = (f" {separator_rule} " if separator_rule else " ").join(
                [item_rule] * min_items
            )

    def opt_repetitions(up_to_n, prefix_with_sep=False):
        """
        - n=4, no sep: '(a (a (a (a)?)?)?)?'
        - n=4, sep=',', prefix: '("," a ("," a ("," a ("," a)?)?)?)?'
        - n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?'
        """

        content = (
            f"{separator_rule} {item_rule}"
            if prefix_with_sep and separator_rule
            else item_rule
        )
        if up_to_n == 0:
            return ""
        elif up_to_n == 1:
            return f"({content})?"
        elif separator_rule and not prefix_with_sep:
            # First optional item has no leading separator; the rest do.
            return f"({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?"
        else:
            # Nest `up_to_n` optional groups.
            return (f"({content} " * up_to_n).rstrip() + (")?" * up_to_n)

    if min_items > 0 and max_items != min_items:
        result += " "

    if max_items is not None:
        # Bounded tail: up to (max - min) further optional occurrences.
        result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0)
    else:
        # Unbounded tail: zero-or-more (optionally separator-prefixed) items.
        item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})'

        if min_items == 0 and separator_rule:
            result = f"({item_rule} {item_operator}*)?"
        else:
            result += f"{item_operator}*"

    return result
308
+
309
+
310
class BuiltinRule:
    """A GBNF rule body together with the names of the rules it depends on."""

    def __init__(self, content: str, deps: Optional[list] = None):
        # Annotation fixed: the default is None, not a list, so the type is
        # Optional[list]; `deps or []` also avoids a shared mutable default.
        self.content = content
        self.deps = deps or []
314
+
315
+
316
+ _up_to_15_digits = _build_repetition("[0-9]", 0, 15)
317
+
318
+ PRIMITIVE_RULES = {
319
+ "boolean": BuiltinRule('("true" | "false") space', []),
320
+ "decimal-part": BuiltinRule("[0-9] " + _up_to_15_digits, []),
321
+ "integral-part": BuiltinRule("[0-9] | [1-9] " + _up_to_15_digits, []),
322
+ "number": BuiltinRule(
323
+ '("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space',
324
+ ["integral-part", "decimal-part"],
325
+ ),
326
+ "integer": BuiltinRule('("-"? integral-part) space', ["integral-part"]),
327
+ "value": BuiltinRule(
328
+ "object | array | string | number | boolean | null",
329
+ ["object", "array", "string", "number", "boolean", "null"],
330
+ ),
331
+ "object": BuiltinRule(
332
+ '"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space',
333
+ ["string", "value"],
334
+ ),
335
+ "array": BuiltinRule(
336
+ '"[" space ( value ("," space value)* )? "]" space', ["value"]
337
+ ),
338
+ "uuid": BuiltinRule(
339
+ r'"\"" '
340
+ + ' "-" '.join("[0-9a-fA-F]" * n for n in [8, 4, 4, 4, 12])
341
+ + r' "\"" space',
342
+ [],
343
+ ),
344
+ "char": BuiltinRule(
345
+ r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])',
346
+ [],
347
+ ),
348
+ "string": BuiltinRule(r'"\"" char* "\"" space', ["char"]),
349
+ "null": BuiltinRule('"null" space', []),
350
+ }
351
+
352
+ # TODO: support "uri", "email" string formats
353
+ STRING_FORMAT_RULES = {
354
+ "date": BuiltinRule(
355
+ '[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( "0" [1-9] | [1-2] [0-9] | "3" [0-1] )',
356
+ [],
357
+ ),
358
+ "time": BuiltinRule(
359
+ '([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )',
360
+ [],
361
+ ),
362
+ "date-time": BuiltinRule('date "T" time', ["date", "time"]),
363
+ "date-string": BuiltinRule('"\\"" date "\\"" space', ["date"]),
364
+ "time-string": BuiltinRule('"\\"" time "\\"" space', ["time"]),
365
+ "date-time-string": BuiltinRule('"\\"" date-time "\\"" space', ["date-time"]),
366
+ }
367
+
368
+ DOTALL = "[\\U00000000-\\U0010FFFF]"
369
+ DOT = "[^\\x0A\\x0D]"
370
+
371
+ RESERVED_NAMES = set(
372
+ ["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()]
373
+ )
374
+
375
+
376
+ NON_LITERAL_SET = set("|.()[]{}*+?")
377
+ ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set("[]()|{}*+?")
378
+
379
+
380
+ class SchemaConverter:
381
+ def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
382
+ self._prop_order = prop_order
383
+ self._allow_fetch = allow_fetch
384
+ self._dotall = dotall
385
+ self._raw_pattern = raw_pattern
386
+ self._rules = {
387
+ "space": SPACE_RULE,
388
+ }
389
+ self._refs = {}
390
+ self._refs_being_resolved = set()
391
+
392
+ def _format_literal(self, literal):
393
+ escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
394
+ lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
395
+ )
396
+ return f'"{escaped}"'
397
+
398
+ def not_literal(
399
+ self, literal: str, dotall: bool = True, maybe_escaped_underscores=False
400
+ ) -> str:
401
+ """
402
+ not_literal('a') -> '[^a]'
403
+ not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?'
404
+ """
405
+ assert len(literal) > 0, "Empty literal not supported"
406
+
407
+ def recurse(i: int):
408
+ c = literal[i]
409
+ if maybe_escaped_underscores and c == "_":
410
+ yield f"[^{c}\\\\]"
411
+ yield " | "
412
+ yield f'"\\\\"? "{c}"'
413
+ else:
414
+ yield f"[^{c}]"
415
+ if i < len(literal) - 1:
416
+ yield " | "
417
+ yield self._format_literal(c)
418
+ yield " ("
419
+ yield from recurse(i + 1)
420
+ yield ")?"
421
+
422
+ return "".join(("(", *recurse(0), ")"))
423
+
424
+ def _add_rule(self, name, rule):
425
+ esc_name = INVALID_RULE_CHARS_RE.sub("-", name)
426
+ if esc_name not in self._rules or self._rules[esc_name] == rule:
427
+ key = esc_name
428
+ else:
429
+ i = 0
430
+ while (
431
+ f"{esc_name}{i}" in self._rules
432
+ and self._rules[f"{esc_name}{i}"] != rule
433
+ ):
434
+ i += 1
435
+ key = f"{esc_name}{i}"
436
+ self._rules[key] = rule
437
+ return key
438
+
439
+ def resolve_refs(self, schema: dict, url: str):
440
+ """
441
+ Resolves all $ref fields in the given schema, fetching any remote schemas,
442
+ replacing $ref with absolute reference URL and populating self._refs with the
443
+ respective referenced (sub)schema dictionaries.
444
+ """
445
+
446
+ def visit(n: dict):
447
+ if isinstance(n, list):
448
+ return [visit(x) for x in n]
449
+ elif isinstance(n, dict):
450
+ ref = n.get("$ref")
451
+ if ref is not None and ref not in self._refs:
452
+ if ref.startswith("https://"):
453
+ assert (
454
+ self._allow_fetch
455
+ ), "Fetching remote schemas is not allowed (use --allow-fetch for force)"
456
+ import requests
457
+
458
+ frag_split = ref.split("#")
459
+ base_url = frag_split[0]
460
+
461
+ target = self._refs.get(base_url)
462
+ if target is None:
463
+ target = self.resolve_refs(
464
+ requests.get(ref).json(), base_url
465
+ )
466
+ self._refs[base_url] = target
467
+
468
+ if len(frag_split) == 1 or frag_split[-1] == "":
469
+ return target
470
+ elif ref.startswith("#/"):
471
+ target = schema
472
+ ref = f"{url}{ref}"
473
+ n["$ref"] = ref
474
+ else:
475
+ raise ValueError(f"Unsupported ref {ref}")
476
+
477
+ for sel in ref.split("#")[-1].split("/")[1:]:
478
+ assert (
479
+ target is not None and sel in target
480
+ ), f"Error resolving ref {ref}: {sel} not in {target}"
481
+ target = target[sel]
482
+
483
+ self._refs[ref] = target
484
+ else:
485
+ for v in n.values():
486
+ visit(v)
487
+
488
+ return n
489
+
490
+ return visit(schema)
491
+
492
+ def _generate_union_rule(self, name, alt_schemas):
493
+ return " | ".join(
494
+ (
495
+ self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
496
+ for i, alt_schema in enumerate(alt_schemas)
497
+ )
498
+ )
499
+
500
    def _visit_pattern(self, pattern, name):
        """
        Transforms a regular expression pattern into a GBNF rule.

        Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
        Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md

        Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.

        Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
        we define sub-rules to keep the output lean.
        """

        # Only fully-anchored patterns are supported; the anchors are then dropped.
        assert pattern.startswith("^") and pattern.endswith(
            "$"
        ), 'Pattern must start with "^" and end with "$"'
        pattern = pattern[1:-1]
        # Maps sub-rule body text -> rule name, so identical {x,y} bodies share one rule.
        sub_rule_ids = {}

        i = 0
        length = len(pattern)

        def to_rule(s: Tuple[str, bool]) -> str:
            # Render a (text, is_literal) pair: literals become quoted GBNF strings,
            # everything else is already rule text.
            (txt, is_literal) = s
            return '"' + txt + '"' if is_literal else txt

        def transform() -> Tuple[str, bool]:
            """
            Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
            """
            nonlocal i
            nonlocal pattern
            nonlocal sub_rule_ids

            start = i
            # For each component of this sequence, store its string representation and whether it's a literal.
            # We only need a flat structure here to apply repetition operators to the last item, and
            # to merge literals at the end (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
            # (GBNF's syntax is luckily very close to regular expressions!)
            seq: list[Tuple[str, bool]] = []

            def get_dot():
                # "." becomes a shared "dot" rule; DOTALL additionally matches line breaks.
                if self._dotall:
                    rule = DOTALL
                else:
                    # Accept any character... except \n and \r line break chars (\x0A and \xOD)
                    rule = DOT
                return self._add_rule(f"dot", rule)

            def join_seq():
                # Collapse runs of adjacent literals into single literals, then join.
                nonlocal seq
                ret = []
                for is_literal, g in groupby(seq, lambda x: x[1]):
                    if is_literal:
                        ret.append(("".join(x[0] for x in g), True))
                    else:
                        ret.extend(g)
                if len(ret) == 1:
                    return ret[0]
                # NOTE(review): this joins the unmerged `seq`, not `ret`, so the
                # literal merging above only takes effect in the single-item case —
                # matches upstream; confirm before "fixing".
                return (" ".join(to_rule(x) for x in seq), False)

            while i < length:
                c = pattern[i]
                if c == ".":
                    seq.append((get_dot(), False))
                    i += 1
                elif c == "(":
                    i += 1
                    if i < length:
                        # Lookaheads / non-capturing groups "(?...)" are unsupported.
                        assert (
                            pattern[i] != "?"
                        ), f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
                    # Recurse for the group body; transform() returns at the matching ")".
                    seq.append((f"({to_rule(transform())})", False))
                elif c == ")":
                    i += 1
                    assert (
                        start > 0 and pattern[start - 1] == "("
                    ), f"Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}"
                    return join_seq()
                elif c == "[":
                    # Character classes are copied through verbatim (GBNF supports them),
                    # keeping backslash escapes intact.
                    square_brackets = c
                    i += 1
                    while i < length and pattern[i] != "]":
                        if pattern[i] == "\\":
                            square_brackets += pattern[i : i + 2]
                            i += 2
                        else:
                            square_brackets += pattern[i]
                            i += 1
                    assert (
                        i < length
                    ), f"Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}"
                    square_brackets += "]"
                    i += 1
                    seq.append((square_brackets, False))
                elif c == "|":
                    # GBNF uses the same alternation operator, so pass it through.
                    seq.append(("|", False))
                    i += 1
                elif c in ("*", "+", "?"):
                    # Simple quantifiers attach directly to the previous item.
                    seq[-1] = (to_rule(seq[-1]) + c, False)
                    i += 1
                elif c == "{":
                    # Bounded quantifier {x} / {x,} / {x,y}: collect the raw text first.
                    curly_brackets = c
                    i += 1
                    while i < length and pattern[i] != "}":
                        curly_brackets += pattern[i]
                        i += 1
                    assert (
                        i < length
                    ), f"Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}"
                    curly_brackets += "}"
                    i += 1
                    nums = [s.strip() for s in curly_brackets[1:-1].split(",")]
                    min_times = 0
                    max_times = None
                    try:
                        if len(nums) == 1:
                            min_times = int(nums[0])
                            max_times = min_times
                        else:
                            assert len(nums) == 2
                            min_times = int(nums[0]) if nums[0] else 0
                            max_times = int(nums[1]) if nums[1] else None
                    except ValueError:
                        raise ValueError(
                            f"Invalid quantifier {curly_brackets} in /{pattern}/"
                        )

                    (sub, sub_is_literal) = sub = seq[-1]

                    if not sub_is_literal:
                        # Factor the repeated body into a named sub-rule (reused if seen before)
                        # so the expanded repetition stays compact.
                        id = sub_rule_ids.get(sub)
                        if id is None:
                            id = self._add_rule(f"{name}-{len(sub_rule_ids) + 1}", sub)
                            sub_rule_ids[sub] = id
                        sub = id

                    seq[-1] = (
                        _build_repetition(
                            f'"{sub}"' if sub_is_literal else sub,
                            min_times,
                            max_times,
                            item_rule_is_literal=sub_is_literal,
                        ),
                        False,
                    )
                else:
                    # Greedily consume a run of literal characters.
                    literal = ""
                    while i < length:
                        if pattern[i] == "\\" and i < length - 1:
                            next = pattern[i + 1]
                            if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
                                # Escape needed in regex but not in a GBNF literal: unescape.
                                i += 1
                                literal += pattern[i]
                                i += 1
                            else:
                                literal += pattern[i : i + 2]
                                i += 2
                        elif pattern[i] == '"' and not self._raw_pattern:
                            literal += '\\"'
                            i += 1
                        elif pattern[i] not in NON_LITERAL_SET and (
                            i == length - 1
                            or literal == ""
                            or pattern[i + 1] == "."
                            or pattern[i + 1] not in NON_LITERAL_SET
                        ):
                            # Stop before a char that is followed by a quantifier, so the
                            # quantifier binds to that single char, not the whole literal.
                            literal += pattern[i]
                            i += 1
                        else:
                            break
                    if literal:
                        seq.append((literal, True))

            return join_seq()

        return self._add_rule(
            name,
            (
                to_rule(transform())
                if self._raw_pattern
                # Non-raw patterns match a JSON string, so wrap in escaped quotes.
                else '"\\"" ' + to_rule(transform()) + ' "\\"" space'
            ),
        )
684
+
685
+ def _resolve_ref(self, ref):
686
+ ref_name = ref.split("/")[-1]
687
+ if ref_name not in self._rules and ref not in self._refs_being_resolved:
688
+ self._refs_being_resolved.add(ref)
689
+ resolved = self._refs[ref]
690
+ ref_name = self.visit(resolved, ref_name)
691
+ self._refs_being_resolved.remove(ref)
692
+ return ref_name
693
+
694
+ def _generate_constant_rule(self, value):
695
+ return self._format_literal(json.dumps(value))
696
+
697
    def visit(self, schema, name):
        """Convert one JSON-schema node into a named GBNF rule and return the rule name.

        Dispatches on the schema's structure ($ref, unions, const/enum, object,
        array, string facets, primitives). Branch order matters: earlier branches
        take precedence when a schema matches several.
        """
        schema_type = schema.get("type")
        schema_format = schema.get("format")
        # Avoid colliding with reserved rule names; unnamed root becomes "root".
        rule_name = name + "-" if name in RESERVED_NAMES else name or "root"

        if (ref := schema.get("$ref")) is not None:
            return self._add_rule(rule_name, self._resolve_ref(ref))

        elif "oneOf" in schema or "anyOf" in schema:
            # oneOf/anyOf both become a plain alternation (exclusivity is not enforced).
            return self._add_rule(
                rule_name,
                self._generate_union_rule(name, schema.get("oneOf") or schema["anyOf"]),
            )

        elif isinstance(schema_type, list):
            # "type": ["string", "null"] etc. — union of single-type schemas.
            return self._add_rule(
                rule_name,
                self._generate_union_rule(name, [{"type": t} for t in schema_type]),
            )

        elif "const" in schema:
            return self._add_rule(
                rule_name, self._generate_constant_rule(schema["const"])
            )

        elif "enum" in schema:
            rule = " | ".join((self._generate_constant_rule(v) for v in schema["enum"]))
            return self._add_rule(rule_name, rule)

        elif schema_type in (None, "object") and (
            "properties" in schema
            or (
                "additionalProperties" in schema
                and schema["additionalProperties"] is not True
            )
        ):
            required = set(schema.get("required", []))
            properties = list(schema.get("properties", {}).items())
            return self._add_rule(
                rule_name,
                self._build_object_rule(
                    properties, required, name, schema.get("additionalProperties")
                ),
            )

        elif schema_type in (None, "object") and "allOf" in schema:
            # allOf: flatten all component properties into one object rule.
            required = set()
            properties = []
            hybrid_name = name

            def add_component(comp_schema, is_required):
                # Inline referenced components before collecting their properties.
                if (ref := comp_schema.get("$ref")) is not None:
                    comp_schema = self._refs[ref]

                if "properties" in comp_schema:
                    for prop_name, prop_schema in comp_schema["properties"].items():
                        properties.append((prop_name, prop_schema))
                        if is_required:
                            required.add(prop_name)

            for t in schema["allOf"]:
                if "anyOf" in t:
                    # Properties contributed via a nested anyOf are optional.
                    for tt in t["anyOf"]:
                        add_component(tt, is_required=False)
                else:
                    add_component(t, is_required=True)

            return self._add_rule(
                rule_name,
                self._build_object_rule(
                    properties, required, hybrid_name, additional_properties=[]
                ),
            )

        elif schema_type in (None, "array") and (
            "items" in schema or "prefixItems" in schema
        ):
            items = schema.get("items") or schema["prefixItems"]
            if isinstance(items, list):
                # Tuple validation: one sub-rule per positional item.
                return self._add_rule(
                    rule_name,
                    '"[" space '
                    + ' "," space '.join(
                        self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
                        for i, item in enumerate(items)
                    )
                    + ' "]" space',
                )
            else:
                # Homogeneous array with optional minItems/maxItems bounds.
                item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
                min_items = schema.get("minItems", 0)
                max_items = schema.get("maxItems")
                return self._add_rule(
                    rule_name,
                    '"[" space '
                    + _build_repetition(
                        item_rule_name, min_items, max_items, separator_rule='"," space'
                    )
                    + ' "]" space',
                )

        elif schema_type in (None, "string") and "pattern" in schema:
            return self._visit_pattern(schema["pattern"], rule_name)

        elif schema_type in (None, "string") and re.match(
            r"^uuid[1-5]?$", schema_format or ""
        ):
            # All uuid format variants share the generic uuid primitive.
            return self._add_primitive(
                "root" if rule_name == "root" else schema_format,
                PRIMITIVE_RULES["uuid"],
            )

        elif (
            schema_type in (None, "string")
            and f"{schema_format}-string" in STRING_FORMAT_RULES
        ):
            # e.g. "date" -> "date-string" rule.
            prim_name = f"{schema_format}-string"
            return self._add_rule(
                rule_name,
                self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]),
            )

        elif schema_type == "string" and (
            "minLength" in schema or "maxLength" in schema
        ):
            char_rule = self._add_primitive("char", PRIMITIVE_RULES["char"])
            min_len = schema.get("minLength", 0)
            max_len = schema.get("maxLength")

            return self._add_rule(
                rule_name,
                r'"\"" '
                + _build_repetition(char_rule, min_len, max_len)
                + r' "\"" space',
            )

        elif (schema_type == "object") or (len(schema) == 0):
            # Unconstrained object, or the empty schema {} which accepts anything:
            # fall back to the generic object primitive.
            return self._add_rule(
                rule_name, self._add_primitive("object", PRIMITIVE_RULES["object"])
            )

        else:
            assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}"
            # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
            return self._add_primitive(
                "root" if rule_name == "root" else schema_type,
                PRIMITIVE_RULES[schema_type],
            )
845
+
846
+ def _add_primitive(self, name: str, rule: BuiltinRule):
847
+ n = self._add_rule(name, rule.content)
848
+
849
+ for dep in rule.deps:
850
+ dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep)
851
+ assert dep_rule, f"Rule {dep} not known"
852
+ if dep not in self._rules:
853
+ self._add_primitive(dep, dep_rule)
854
+ return n
855
+
856
    def _build_object_rule(
        self,
        properties: List[Tuple[str, Any]],
        required: Set[str],
        name: str,
        additional_properties: Union[bool, Any],
    ):
        """Build the GBNF rule body for a JSON object.

        Required properties appear in order; optional properties (plus the
        additional-properties wildcard, keyed "*") are emitted as a chain of
        nested optional suffixes so any ordered subset can appear.
        """
        prop_order = self._prop_order
        # sort by position in prop_order (if specified) then by original order
        sorted_props = [
            kv[0]
            for _, kv in sorted(
                enumerate(properties),
                key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]),
            )
        ]

        # One `"name" : value` rule per property.
        prop_kv_rule_names = {}
        for prop_name, prop_schema in properties:
            prop_rule_name = self.visit(
                prop_schema, f'{name}{"-" if name else ""}{prop_name}'
            )
            prop_kv_rule_names[prop_name] = self._add_rule(
                f'{name}{"-" if name else ""}{prop_name}-kv',
                rf'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}',
            )
        required_props = [k for k in sorted_props if k in required]
        optional_props = [k for k in sorted_props if k not in required]

        # NOTE(review): `== True` (not `is True`) — also matches truthy 1; kept as-is
        # to preserve behavior for loosely-typed schemas.
        if additional_properties == True or isinstance(additional_properties, dict):
            sub_name = f'{name}{"-" if name else ""}additional'
            value_rule = self.visit(
                {} if additional_properties == True else additional_properties,
                f"{sub_name}-value",
            )
            prop_kv_rule_names["*"] = self._add_rule(
                f"{sub_name}-kv",
                self._add_primitive("string", PRIMITIVE_RULES["string"])
                + f' ":" space {value_rule}',
            )
            optional_props.append("*")

        rule = '"{" space '
        rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props)

        if optional_props:
            rule += " ("
            if required_props:
                rule += ' "," space ( '

            def get_recursive_refs(ks, first_is_optional):
                # Build "k ( , k2 ( , k3 )? )?" style chains; "*" expands to a
                # starred repetition so any number of additional kv pairs match.
                [k, *rest] = ks
                kv_rule_name = prop_kv_rule_names[k]
                if k == "*":
                    res = self._add_rule(
                        f'{name}{"-" if name else ""}additional-kvs',
                        f'{kv_rule_name} ( "," space ' + kv_rule_name + " )*",
                    )
                elif first_is_optional:
                    res = f'( "," space {kv_rule_name} )?'
                else:
                    res = kv_rule_name
                if len(rest) > 0:
                    res += " " + self._add_rule(
                        f'{name}{"-" if name else ""}{k}-rest',
                        get_recursive_refs(rest, first_is_optional=True),
                    )
                return res

            # Alternation over every possible first optional property.
            rule += " | ".join(
                get_recursive_refs(optional_props[i:], first_is_optional=False)
                for i in range(len(optional_props))
            )
            if required_props:
                rule += " )"
            rule += " )?"

        rule += ' "}" space'

        return rule
936
+
937
+ def format_grammar(self):
938
+ return "\n".join(
939
+ f"{name} ::= {rule}"
940
+ for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0])
941
+ )
942
+
943
+
944
def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None):
    """Convert a JSON-schema document (as a string) into a GBNF grammar string.

    Args:
        schema: The JSON schema, serialized as a JSON string.
        prop_order: Optional property names whose position gives them priority
            when ordering object keys in the generated grammar.
    """
    order_index = {prop: idx for idx, prop in enumerate(prop_order or [])}
    converter = SchemaConverter(
        prop_order=order_index, allow_fetch=False, dotall=False, raw_pattern=False
    )
    resolved = converter.resolve_refs(json.loads(schema), "stdin")
    converter.visit(resolved, "")
    return converter.format_grammar()
llama_cpp/llama_speculative.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+
3
+ from typing import Any
4
+
5
+ import numpy as np
6
+ import numpy.typing as npt
7
+
8
+
9
class LlamaDraftModel(abc.ABC):
    """Interface for draft models used for speculative decoding."""

    @abc.abstractmethod
    def __call__(
        self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
    ) -> npt.NDArray[np.intc]:
        raise NotImplementedError()


class LlamaPromptLookupDecoding(LlamaDraftModel):
    """Based on https://github.com/apoorvumang/prompt-lookup-decoding"""

    def __init__(self, max_ngram_size: int = 2, num_pred_tokens: int = 10):
        # Longest trailing n-gram to search for, and how many tokens to draft.
        self.max_ngram_size = max_ngram_size
        self.num_pred_tokens = num_pred_tokens

    @staticmethod
    def find_candidate_pred_tokens(
        input_ids: npt.NDArray[np.intc],
        max_ngram_size: int,
        num_pred_tokens: int,
    ):
        """Find a previous occurrence of the prompt's trailing n-gram and return
        up to num_pred_tokens tokens that followed it (empty array if none)."""
        total = input_ids.shape[0]

        # Prefer the longest n-gram that fits; fall back to shorter ones.
        for size in range(min(max_ngram_size, total - 1), 0, -1):
            # All length-`size` windows over the prompt, as a 2-D view (no copy).
            windows = np.lib.stride_tricks.sliding_window_view(input_ids, (size,))

            # The n-gram we are trying to re-find: the prompt's suffix.
            suffix = input_ids[-size:]

            # Positions whose window equals the suffix.
            hit_positions = np.nonzero(np.all(windows == suffix, axis=1))[0]

            # Earliest hit with at least one following token wins.
            for hit in hit_positions:
                begin = hit + size
                stop = min(begin + num_pred_tokens, total)
                if begin < stop:
                    return input_ids[begin:stop]

        # No usable continuation anywhere in the prompt.
        return np.array([], dtype=np.intc)

    def __call__(
        self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
    ) -> npt.NDArray[np.intc]:
        return self.find_candidate_pred_tokens(
            input_ids=input_ids,
            max_ngram_size=self.max_ngram_size,
            num_pred_tokens=self.num_pred_tokens,
        )
llama_cpp/llama_tokenizer.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import abc
4
+ from typing import (
5
+ List,
6
+ Optional,
7
+ Any,
8
+ )
9
+
10
+ import llama_cpp
11
+ from llama_cpp.llama_types import List
12
+
13
+
14
class BaseLlamaTokenizer(abc.ABC):
    """Common interface for tokenizer backends (native llama.cpp or Hugging Face)."""

    @abc.abstractmethod
    def tokenize(
        self, text: bytes, add_bos: bool = True, special: bool = True
    ) -> List[int]:
        """Tokenize the text into tokens.

        Args:
            text: The utf-8 encoded string to tokenize.
            add_bos: Whether to add a beginning of sequence token.
            special: Whether to tokenize special tokens.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def detokenize(
        self,
        tokens: List[int],
        prev_tokens: Optional[List[int]] = None,
        special: bool = False,
    ) -> bytes:
        """Detokenize the tokens into text.

        Args:
            tokens: The list of tokens to detokenize.
            prev_tokens: The list of previous tokens. Offset mapping will be performed if provided.
            special: Whether to detokenize special tokens.
        """
        raise NotImplementedError


class LlamaTokenizer(BaseLlamaTokenizer):
    """Tokenizer backed by a llama.cpp model's built-in vocabulary."""

    def __init__(self, llama: "llama_cpp.Llama"):
        self._model = llama._model  # type: ignore

    def tokenize(
        self, text: bytes, add_bos: bool = True, special: bool = True
    ) -> List[int]:
        return self._model.tokenize(text, add_bos=add_bos, special=special)

    def detokenize(
        self,
        tokens: List[int],
        prev_tokens: Optional[List[int]] = None,
        special: bool = False,
    ) -> bytes:
        # prev_tokens is accepted for interface compatibility; the native
        # detokenizer does not use it.
        return self._model.detokenize(tokens, special=special)

    def encode(
        self, text: str, add_bos: bool = True, special: bool = True
    ) -> List[int]:
        raw = text.encode("utf-8", errors="ignore")
        return self.tokenize(raw, add_bos=add_bos, special=special)

    def decode(self, tokens: List[int]) -> str:
        raw = self.detokenize(tokens)
        return raw.decode("utf-8", errors="ignore")

    @classmethod
    def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
        return cls(llama_cpp.Llama(model_path=path, vocab_only=True))


class LlamaHFTokenizer(BaseLlamaTokenizer):
    """Tokenizer backed by a Hugging Face `transformers` tokenizer object."""

    def __init__(self, hf_tokenizer: Any):
        self.hf_tokenizer = hf_tokenizer

    def tokenize(
        self, text: bytes, add_bos: bool = True, special: bool = True
    ) -> List[int]:
        # NOTE: add_bos is not forwarded; the HF tokenizer controls BOS via
        # add_special_tokens.
        decoded = text.decode("utf-8", errors="ignore")
        return self.hf_tokenizer.encode(decoded, add_special_tokens=special)

    def detokenize(
        self,
        tokens: List[int],
        prev_tokens: Optional[List[int]] = None,
        special: bool = False,
    ) -> bytes:
        skip_special_tokens = not special
        if prev_tokens is None:
            return self.hf_tokenizer.decode(
                tokens, skip_special_tokens=skip_special_tokens
            ).encode("utf-8", errors="ignore")
        # Offset mapping: decode prev+current together, then strip the prefix
        # produced by prev alone, so byte output stays consistent across merges.
        full = self.hf_tokenizer.decode(
            prev_tokens + tokens, skip_special_tokens=skip_special_tokens
        ).encode("utf-8", errors="ignore")
        prefix = self.hf_tokenizer.decode(
            prev_tokens, skip_special_tokens=skip_special_tokens
        ).encode("utf-8", errors="ignore")
        return full[len(prefix) :]

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str) -> "LlamaHFTokenizer":
        try:
            from transformers import AutoTokenizer
        except ImportError:
            raise ImportError(
                "The `transformers` library is required to use the `HFTokenizer`."
                "You can install it with `pip install transformers`."
            )
        hf_tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=pretrained_model_name_or_path
        )
        return cls(hf_tokenizer)
llama_cpp/llama_types.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Types and request signatures for OpenAI compatibility
2
+
3
+ NOTE: These types may change to match the OpenAI OpenAPI specification.
4
+
5
+ Based on the OpenAI OpenAPI specification:
6
+ https://github.com/openai/openai-openapi/blob/master/openapi.yaml
7
+
8
+ """
9
+
10
+ from typing import Any, List, Optional, Dict, Union
11
+ from typing_extensions import TypedDict, NotRequired, Literal
12
+
13
+
14
+ # NOTE: Defining this correctly using annotations seems to break pydantic validation.
15
+ # This is a workaround until we can figure out how to do this correctly
16
+ # JsonType = Union[None, int, str, bool, List["JsonType"], Dict[str, "JsonType"]]
17
+ JsonType = Union[None, int, str, bool, List[Any], Dict[str, Any]]
18
+
19
+
20
+ class EmbeddingUsage(TypedDict):
21
+ prompt_tokens: int
22
+ total_tokens: int
23
+
24
+
25
+ class Embedding(TypedDict):
26
+ index: int
27
+ object: str
28
+ embedding: Union[List[float], List[List[float]]]
29
+
30
+
31
+ class CreateEmbeddingResponse(TypedDict):
32
+ object: Literal["list"]
33
+ model: str
34
+ data: List[Embedding]
35
+ usage: EmbeddingUsage
36
+
37
+
38
+ class CompletionLogprobs(TypedDict):
39
+ text_offset: List[int]
40
+ token_logprobs: List[Optional[float]]
41
+ tokens: List[str]
42
+ top_logprobs: List[Optional[Dict[str, float]]]
43
+
44
+
45
+ class CompletionChoice(TypedDict):
46
+ text: str
47
+ index: int
48
+ logprobs: Optional[CompletionLogprobs]
49
+ finish_reason: Optional[Literal["stop", "length"]]
50
+
51
+
52
+ class CompletionUsage(TypedDict):
53
+ prompt_tokens: int
54
+ completion_tokens: int
55
+ total_tokens: int
56
+
57
+
58
+ class CreateCompletionResponse(TypedDict):
59
+ id: str
60
+ object: Literal["text_completion"]
61
+ created: int
62
+ model: str
63
+ choices: List[CompletionChoice]
64
+ usage: NotRequired[CompletionUsage]
65
+
66
+
67
+ class ChatCompletionResponseFunctionCall(TypedDict):
68
+ name: str
69
+ arguments: str
70
+
71
+
72
+ class ChatCompletionResponseMessage(TypedDict):
73
+ content: Optional[str]
74
+ tool_calls: NotRequired["ChatCompletionMessageToolCalls"]
75
+ role: Literal["assistant", "function"] # NOTE: "function" may be incorrect here
76
+ function_call: NotRequired[ChatCompletionResponseFunctionCall] # DEPRECATED
77
+
78
+
79
+ class ChatCompletionFunction(TypedDict):
80
+ name: str
81
+ description: NotRequired[str]
82
+ parameters: Dict[str, JsonType] # TODO: make this more specific
83
+
84
+
85
+ class ChatCompletionTopLogprobToken(TypedDict):
86
+ token: str
87
+ logprob: float
88
+ bytes: Optional[List[int]]
89
+
90
+
91
+ class ChatCompletionLogprobToken(ChatCompletionTopLogprobToken):
92
+ token: str
93
+ logprob: float
94
+ bytes: Optional[List[int]]
95
+ top_logprobs: List[ChatCompletionTopLogprobToken]
96
+
97
+
98
+ class ChatCompletionLogprobs(TypedDict):
99
+ content: Optional[List[ChatCompletionLogprobToken]]
100
+ refusal: Optional[List[ChatCompletionLogprobToken]]
101
+
102
+
103
+ class ChatCompletionResponseChoice(TypedDict):
104
+ index: int
105
+ message: "ChatCompletionResponseMessage"
106
+ logprobs: Optional[ChatCompletionLogprobs]
107
+ finish_reason: Optional[str]
108
+
109
+
110
+ class CreateChatCompletionResponse(TypedDict):
111
+ id: str
112
+ object: Literal["chat.completion"]
113
+ created: int
114
+ model: str
115
+ choices: List["ChatCompletionResponseChoice"]
116
+ usage: CompletionUsage
117
+
118
+
119
+ class ChatCompletionMessageToolCallChunkFunction(TypedDict):
120
+ name: Optional[str]
121
+ arguments: str
122
+
123
+
124
+ class ChatCompletionMessageToolCallChunk(TypedDict):
125
+ index: int
126
+ id: NotRequired[str]
127
+ type: Literal["function"]
128
+ function: ChatCompletionMessageToolCallChunkFunction
129
+
130
+
131
+ class ChatCompletionStreamResponseDeltaEmpty(TypedDict):
132
+ pass
133
+
134
+
135
+ class ChatCompletionStreamResponseDeltaFunctionCall(TypedDict):
136
+ name: str
137
+ arguments: str
138
+
139
+
140
+ class ChatCompletionStreamResponseDelta(TypedDict):
141
+ content: NotRequired[Optional[str]]
142
+ function_call: NotRequired[
143
+ Optional[ChatCompletionStreamResponseDeltaFunctionCall]
144
+ ] # DEPRECATED
145
+ tool_calls: NotRequired[Optional[List[ChatCompletionMessageToolCallChunk]]]
146
+ role: NotRequired[Optional[Literal["system", "user", "assistant", "tool"]]]
147
+
148
+
149
+ class ChatCompletionStreamResponseChoice(TypedDict):
150
+ index: int
151
+ delta: Union[
152
+ ChatCompletionStreamResponseDelta, ChatCompletionStreamResponseDeltaEmpty
153
+ ]
154
+ finish_reason: Optional[Literal["stop", "length", "tool_calls", "function_call"]]
155
+ logprobs: NotRequired[Optional[ChatCompletionLogprobs]]
156
+
157
+
158
+ class CreateChatCompletionStreamResponse(TypedDict):
159
+ id: str
160
+ model: str
161
+ object: Literal["chat.completion.chunk"]
162
+ created: int
163
+ choices: List[ChatCompletionStreamResponseChoice]
164
+
165
+
166
+ class ChatCompletionFunctions(TypedDict):
167
+ name: str
168
+ description: NotRequired[str]
169
+ parameters: Dict[str, JsonType] # TODO: make this more specific
170
+
171
+
172
+ class ChatCompletionFunctionCallOption(TypedDict):
173
+ name: str
174
+
175
+
176
+ class ChatCompletionRequestResponseFormat(TypedDict):
177
+ type: Literal["text", "json_object"]
178
+ schema: NotRequired[
179
+ JsonType
180
+ ] # https://docs.endpoints.anyscale.com/guides/json_mode/
181
+
182
+
183
+ class ChatCompletionRequestMessageContentPartText(TypedDict):
184
+ type: Literal["text"]
185
+ text: str
186
+
187
+
188
+ class ChatCompletionRequestMessageContentPartImageImageUrl(TypedDict):
189
+ url: str
190
+ detail: NotRequired[Literal["auto", "low", "high"]]
191
+
192
+
193
+ class ChatCompletionRequestMessageContentPartImage(TypedDict):
194
+ type: Literal["image_url"]
195
+ image_url: Union[str, ChatCompletionRequestMessageContentPartImageImageUrl]
196
+
197
+
198
+ ChatCompletionRequestMessageContentPart = Union[
199
+ ChatCompletionRequestMessageContentPartText,
200
+ ChatCompletionRequestMessageContentPartImage,
201
+ ]
202
+
203
+
204
+ class ChatCompletionRequestSystemMessage(TypedDict):
205
+ role: Literal["system"]
206
+ content: Optional[str]
207
+
208
+
209
+ class ChatCompletionRequestUserMessage(TypedDict):
210
+ role: Literal["user"]
211
+ content: Optional[Union[str, List[ChatCompletionRequestMessageContentPart]]]
212
+
213
+
214
+ class ChatCompletionMessageToolCallFunction(TypedDict):
215
+ name: str
216
+ arguments: str
217
+
218
+
219
+ class ChatCompletionMessageToolCall(TypedDict):
220
+ id: str
221
+ type: Literal["function"]
222
+ function: ChatCompletionMessageToolCallFunction
223
+
224
+
225
+ ChatCompletionMessageToolCalls = List[ChatCompletionMessageToolCall]
226
+
227
+
228
+ class ChatCompletionRequestAssistantMessageFunctionCall(TypedDict):
229
+ name: str
230
+ arguments: str
231
+
232
+
233
+ class ChatCompletionRequestAssistantMessage(TypedDict):
234
+ role: Literal["assistant"]
235
+ content: NotRequired[str]
236
+ tool_calls: NotRequired[ChatCompletionMessageToolCalls]
237
+ function_call: NotRequired[
238
+ ChatCompletionRequestAssistantMessageFunctionCall
239
+ ] # DEPRECATED
240
+
241
+
242
+ class ChatCompletionRequestToolMessage(TypedDict):
243
+ role: Literal["tool"]
244
+ content: Optional[str]
245
+ tool_call_id: str
246
+
247
+
248
+ class ChatCompletionRequestFunctionMessage(TypedDict):
249
+ role: Literal["function"]
250
+ content: Optional[str]
251
+ name: str
252
+
253
+
254
+ ChatCompletionRequestMessage = Union[
255
+ ChatCompletionRequestSystemMessage,
256
+ ChatCompletionRequestUserMessage,
257
+ ChatCompletionRequestAssistantMessage,
258
+ ChatCompletionRequestUserMessage,
259
+ ChatCompletionRequestToolMessage,
260
+ ChatCompletionRequestFunctionMessage,
261
+ ]
262
+
263
+
264
+ class ChatCompletionRequestFunctionCallOption(TypedDict):
265
+ name: str
266
+
267
+
268
+ ChatCompletionRequestFunctionCall = Union[
269
+ Literal["none", "auto"], ChatCompletionRequestFunctionCallOption
270
+ ]
271
+
272
+ ChatCompletionFunctionParameters = Dict[str, JsonType] # TODO: make this more specific
273
+
274
+
275
+ class ChatCompletionToolFunction(TypedDict):
276
+ name: str
277
+ description: NotRequired[str]
278
+ parameters: ChatCompletionFunctionParameters
279
+
280
+
281
+ class ChatCompletionTool(TypedDict):
282
+ type: Literal["function"]
283
+ function: ChatCompletionToolFunction
284
+
285
+
286
+ class ChatCompletionNamedToolChoiceFunction(TypedDict):
287
+ name: str
288
+
289
+
290
+ class ChatCompletionNamedToolChoice(TypedDict):
291
+ type: Literal["function"]
292
+ function: ChatCompletionNamedToolChoiceFunction
293
+
294
+
295
+ ChatCompletionToolChoiceOption = Union[
296
+ Literal["none", "auto", "required"], ChatCompletionNamedToolChoice
297
+ ]
298
+
299
+
300
+ # NOTE: The following type names are not part of the OpenAI OpenAPI specification
301
+ # and will be removed in a future major release.
302
+
303
+ EmbeddingData = Embedding
304
+ CompletionChunk = CreateCompletionResponse
305
+ Completion = CreateCompletionResponse
306
+ CreateCompletionStreamResponse = CreateCompletionResponse
307
+ ChatCompletionMessage = ChatCompletionResponseMessage
308
+ ChatCompletionChoice = ChatCompletionResponseChoice
309
+ ChatCompletion = CreateChatCompletionResponse
310
+ ChatCompletionChunkDeltaEmpty = ChatCompletionStreamResponseDeltaEmpty
311
+ ChatCompletionChunkChoice = ChatCompletionStreamResponseChoice
312
+ ChatCompletionChunkDelta = ChatCompletionStreamResponseDelta
313
+ ChatCompletionChunk = CreateChatCompletionStreamResponse
314
+ ChatCompletionStreamResponse = CreateChatCompletionStreamResponse
315
+ ChatCompletionResponseFunction = ChatCompletionFunction
316
+ ChatCompletionFunctionCall = ChatCompletionResponseFunctionCall
llama_cpp/llava_cpp.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from __future__ import annotations

import os
from ctypes import (
    c_bool,
    c_char_p,
    c_int,
    c_uint8,
    c_float,
    c_void_p,
    POINTER,
    _Pointer,  # type: ignore
    Structure,
)
import pathlib
from typing import (
    Union,
    NewType,
    Optional,
    TYPE_CHECKING,
)

import llama_cpp.llama_cpp as llama_cpp

from llama_cpp._ctypes_extensions import (
    load_shared_library,
    ctypes_function_for_shared_library,
)

if TYPE_CHECKING:
    from llama_cpp._ctypes_extensions import (
        CtypesArray,
    )


# Specify the base name of the shared library to load
_libllava_base_name = "llava"
_libllava_override_path = os.environ.get("LLAVA_CPP_LIB")
# Bundled ./lib directory by default; an empty path when LLAVA_CPP_LIB overrides it.
_libllava_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" if _libllava_override_path is None else pathlib.Path()

# Load the library
_libllava = load_shared_library(_libllava_base_name, _libllava_base_path)

# Decorator binding Python stubs below to symbols in the loaded library.
ctypes_function = ctypes_function_for_shared_library(_libllava)


################################################
# llava.h
################################################

# struct clip_ctx;
clip_ctx_p = NewType("clip_ctx_p", int)
clip_ctx_p_ctypes = c_void_p


# struct llava_image_embed {
# float * embed;
# int n_image_pos;
# };
class llava_image_embed(Structure):
    # ctypes mirror of llava's image-embedding struct.
    _fields_ = [
        ("embed", POINTER(c_float)),
        ("n_image_pos", c_int),
    ]


# /** sanity check for clip <-> llava embed size match */
# LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip);
@ctypes_function(
    "llava_validate_embed_size",
    [llama_cpp.llama_context_p_ctypes, clip_ctx_p_ctypes],
    c_bool,
)
def llava_validate_embed_size(
    ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p, /
) -> bool:
    ...


# /** build an image embed from image file bytes */
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length);
@ctypes_function(
    "llava_image_embed_make_with_bytes",
    [clip_ctx_p_ctypes, c_int, POINTER(c_uint8), c_int],
    POINTER(llava_image_embed),
)
def llava_image_embed_make_with_bytes(
    ctx_clip: clip_ctx_p,
    n_threads: Union[c_int, int],
    image_bytes: CtypesArray[c_uint8],
    image_bytes_length: Union[c_int, int],
    /,
) -> "_Pointer[llava_image_embed]":
    ...


# /** build an image embed from a path to an image filename */
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path);
@ctypes_function(
    "llava_image_embed_make_with_filename",
    [clip_ctx_p_ctypes, c_int, c_char_p],
    POINTER(llava_image_embed),
)
def llava_image_embed_make_with_filename(
    ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes, /
) -> "_Pointer[llava_image_embed]":
    ...


# LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed);
# /** free an embedding made with llava_image_embed_make_* */
@ctypes_function("llava_image_embed_free", [POINTER(llava_image_embed)], None)
def llava_image_embed_free(embed: "_Pointer[llava_image_embed]", /):
    ...


# /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
# LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past);
@ctypes_function(
    "llava_eval_image_embed",
    [
        llama_cpp.llama_context_p_ctypes,
        POINTER(llava_image_embed),
        c_int,
        POINTER(c_int),
    ],
    c_bool,
)
def llava_eval_image_embed(
    ctx_llama: llama_cpp.llama_context_p,
    embed: "_Pointer[llava_image_embed]",
    n_batch: Union[c_int, int],
    n_past: "_Pointer[c_int]",
    /,
) -> bool:
    ...


################################################
# clip.h
################################################


# /** load mmproj model */
# CLIP_API struct clip_ctx * clip_model_load    (const char * fname, int verbosity);
@ctypes_function("clip_model_load", [c_char_p, c_int], clip_ctx_p_ctypes)
def clip_model_load(
    fname: bytes, verbosity: Union[c_int, int], /
) -> Optional[clip_ctx_p]:
    ...


# /** free mmproj model */
# CLIP_API void clip_free(struct clip_ctx * ctx);
@ctypes_function("clip_free", [clip_ctx_p_ctypes], None)
def clip_free(ctx: clip_ctx_p, /):
    ...
158
+
llama_cpp/py.typed ADDED
File without changes
llama_cpp/server/__init__.py ADDED
File without changes
llama_cpp/server/__main__.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Example FastAPI server for llama.cpp.
2
+
3
+ To run this example:
4
+
5
+ ```bash
6
+ pip install fastapi uvicorn sse-starlette pydantic-settings
7
+ export MODEL=../models/7B/...
8
+ ```
9
+
10
+ Then run:
11
+ ```
12
+ uvicorn llama_cpp.server.app:create_app --reload
13
+ ```
14
+
15
+ or
16
+
17
+ ```
18
+ python3 -m llama_cpp.server
19
+ ```
20
+
21
+ Then visit http://localhost:8000/docs to see the interactive API docs.
22
+
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import os
28
+ import sys
29
+ import argparse
30
+
31
+ import uvicorn
32
+
33
+ from llama_cpp.server.app import create_app
34
+ from llama_cpp.server.settings import (
35
+ Settings,
36
+ ServerSettings,
37
+ ModelSettings,
38
+ ConfigFileSettings,
39
+ )
40
+ from llama_cpp.server.cli import add_args_from_model, parse_model_from_args
41
+
42
+
43
def main():
    """Entry point for the llama.cpp python server CLI.

    Builds settings either from a JSON/YAML config file (given via the
    ``--config_file`` flag or the ``CONFIG_FILE`` environment variable,
    the latter taking precedence) or from command-line arguments, then
    creates the FastAPI app and serves it with uvicorn.

    Exits with status 1 (after printing the error and usage) when the
    settings cannot be loaded.
    """
    description = "🦙 Llama.cpp python server. Host your own LLMs!🚀"
    parser = argparse.ArgumentParser(description=description)

    add_args_from_model(parser, Settings)
    parser.add_argument(
        "--config_file",
        type=str,
        help="Path to a config file to load.",
    )
    server_settings: ServerSettings | None = None
    model_settings: list[ModelSettings] = []
    args = parser.parse_args()
    try:
        # Load server settings from config_file if provided; the CONFIG_FILE
        # environment variable overrides the command-line flag.
        config_file = os.environ.get("CONFIG_FILE", args.config_file)
        if config_file:
            if not os.path.exists(config_file):
                raise ValueError(f"Config file {config_file} not found!")
            with open(config_file, "rb") as f:
                # YAML configs are round-tripped through JSON so a single
                # pydantic JSON validator handles both formats.
                if config_file.endswith(".yaml") or config_file.endswith(".yml"):
                    import yaml
                    import json

                    config_file_settings = ConfigFileSettings.model_validate_json(
                        json.dumps(yaml.safe_load(f))
                    )
                else:
                    config_file_settings = ConfigFileSettings.model_validate_json(
                        f.read()
                    )
                server_settings = ServerSettings.model_validate(config_file_settings)
                model_settings = config_file_settings.models
        else:
            server_settings = parse_model_from_args(ServerSettings, args)
            model_settings = [parse_model_from_args(ModelSettings, args)]
    except Exception as e:
        print(e, file=sys.stderr)
        parser.print_help()
        sys.exit(1)
    # Explicit validation instead of `assert`: asserts are stripped under
    # `python -O`, and the previous `model_settings is not None` check could
    # never fail (it is always a list) — the meaningful condition is that at
    # least one model was configured.
    if server_settings is None or not model_settings:
        print("No server settings or model settings were loaded.", file=sys.stderr)
        sys.exit(1)
    app = create_app(
        server_settings=server_settings,
        model_settings=model_settings,
    )
    # HOST/PORT environment variables override the configured bind address.
    uvicorn.run(
        app,
        host=os.getenv("HOST", server_settings.host),
        port=int(os.getenv("PORT", server_settings.port)),
        ssl_keyfile=server_settings.ssl_keyfile,
        ssl_certfile=server_settings.ssl_certfile,
    )


if __name__ == "__main__":
    main()
llama_cpp/server/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (239 Bytes). View file
 
llama_cpp/server/__pycache__/__main__.cpython-310.pyc ADDED
Binary file (2.37 kB). View file
 
llama_cpp/server/__pycache__/app.cpython-310.pyc ADDED
Binary file (12.9 kB). View file
 
llama_cpp/server/__pycache__/cli.cpython-310.pyc ADDED
Binary file (3.21 kB). View file
 
llama_cpp/server/__pycache__/errors.cpython-310.pyc ADDED
Binary file (5.54 kB). View file
 
llama_cpp/server/__pycache__/model.cpython-310.pyc ADDED
Binary file (6.45 kB). View file
 
llama_cpp/server/__pycache__/settings.cpython-310.pyc ADDED
Binary file (7.32 kB). View file