Upload folder using huggingface_hub
Browse files
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -35
- chatapp.py +30 -0
- llama_cpp/.DS_Store +0 -0
- llama_cpp/__init__.py +4 -0
- llama_cpp/__pycache__/__init__.cpython-310.pyc +0 -0
- llama_cpp/__pycache__/_ctypes_extensions.cpython-310.pyc +0 -0
- llama_cpp/__pycache__/_ggml.cpython-310.pyc +0 -0
- llama_cpp/__pycache__/_internals.cpython-310.pyc +0 -0
- llama_cpp/__pycache__/_logger.cpython-310.pyc +0 -0
- llama_cpp/__pycache__/_utils.cpython-310.pyc +0 -0
- llama_cpp/__pycache__/llama.cpython-310.pyc +0 -0
- llama_cpp/__pycache__/llama_cache.cpython-310.pyc +0 -0
- llama_cpp/__pycache__/llama_chat_format.cpython-310.pyc +0 -0
- llama_cpp/__pycache__/llama_cpp.cpython-310.pyc +0 -0
- llama_cpp/__pycache__/llama_grammar.cpython-310.pyc +0 -0
- llama_cpp/__pycache__/llama_speculative.cpython-310.pyc +0 -0
- llama_cpp/__pycache__/llama_tokenizer.cpython-310.pyc +0 -0
- llama_cpp/__pycache__/llama_types.cpython-310.pyc +0 -0
- llama_cpp/__pycache__/llava_cpp.cpython-310.pyc +0 -0
- llama_cpp/_ctypes_extensions.py +131 -0
- llama_cpp/_ggml.py +12 -0
- llama_cpp/_internals.py +879 -0
- llama_cpp/_logger.py +47 -0
- llama_cpp/_utils.py +78 -0
- llama_cpp/lib/libggml-base.dylib +3 -0
- llama_cpp/lib/libggml-blas.dylib +0 -0
- llama_cpp/lib/libggml-cpu.dylib +3 -0
- llama_cpp/lib/libggml-metal.dylib +3 -0
- llama_cpp/lib/libggml.dylib +0 -0
- llama_cpp/lib/libllama.dylib +3 -0
- llama_cpp/lib/libllava.dylib +3 -0
- llama_cpp/llama.py +2418 -0
- llama_cpp/llama_cache.py +155 -0
- llama_cpp/llama_chat_format.py +0 -0
- llama_cpp/llama_cpp.py +0 -0
- llama_cpp/llama_grammar.py +953 -0
- llama_cpp/llama_speculative.py +64 -0
- llama_cpp/llama_tokenizer.py +120 -0
- llama_cpp/llama_types.py +316 -0
- llama_cpp/llava_cpp.py +158 -0
- llama_cpp/py.typed +0 -0
- llama_cpp/server/__init__.py +0 -0
- llama_cpp/server/__main__.py +100 -0
- llama_cpp/server/__pycache__/__init__.cpython-310.pyc +0 -0
- llama_cpp/server/__pycache__/__main__.cpython-310.pyc +0 -0
- llama_cpp/server/__pycache__/app.cpython-310.pyc +0 -0
- llama_cpp/server/__pycache__/cli.cpython-310.pyc +0 -0
- llama_cpp/server/__pycache__/errors.cpython-310.pyc +0 -0
- llama_cpp/server/__pycache__/model.cpython-310.pyc +0 -0
- llama_cpp/server/__pycache__/settings.cpython-310.pyc +0 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,2 @@
|
|
| 1 |
-
*.
|
| 2 |
-
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
*.dylib filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
tinyllama-1.1B-q4.gguf filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
chatapp.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
from llama_cpp import Llama
|
| 3 |
+
|
| 4 |
+
# The model path is the only command-line argument and it is mandatory.
if len(sys.argv) < 2:
    print("Model path not provided as argument")
    print("Eg. Usage: $ python chatapp.py path/to/model.gguf")
    sys.exit(1)

# Load the model once; the same instance answers every turn below.
llm = Llama(
    model_path=sys.argv[1],
    n_ctx=512,
    n_threads=4,
    n_gpu_layers=1,
    verbose=False,
)

print("Chat with Llama (type 'exit' to quit)\n")

while True:
    try:
        user_input = input("You: ").strip()
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D, Ctrl-C or a closed stdin: leave quietly instead of ending
        # the session with a traceback.
        print()
        break

    if user_input.lower() in ("exit", "quit"):
        break
    if not user_input:
        # Nothing was typed; ask again rather than spending a model call.
        continue

    # Each turn is sent on its own: the prompt contains only the current
    # message, no earlier turns.
    prompt = f"### Human: {user_input}\n### Assistant:"
    output = llm(
        prompt,
        max_tokens=100,
        stop=["###", "### Human:", "\n###"],
    )
    response = output["choices"][0]["text"].strip()
    print("Bot:", response)
|
llama_cpp/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
llama_cpp/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .llama_cpp import *
|
| 2 |
+
from .llama import *
|
| 3 |
+
|
| 4 |
+
__version__ = "0.3.9"
|
llama_cpp/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (301 Bytes). View file
|
|
|
llama_cpp/__pycache__/_ctypes_extensions.cpython-310.pyc
ADDED
|
Binary file (3.51 kB). View file
|
|
|
llama_cpp/__pycache__/_ggml.cpython-310.pyc
ADDED
|
Binary file (632 Bytes). View file
|
|
|
llama_cpp/__pycache__/_internals.cpython-310.pyc
ADDED
|
Binary file (29.6 kB). View file
|
|
|
llama_cpp/__pycache__/_logger.cpython-310.pyc
ADDED
|
Binary file (1.11 kB). View file
|
|
|
llama_cpp/__pycache__/_utils.cpython-310.pyc
ADDED
|
Binary file (2.53 kB). View file
|
|
|
llama_cpp/__pycache__/llama.cpython-310.pyc
ADDED
|
Binary file (56.5 kB). View file
|
|
|
llama_cpp/__pycache__/llama_cache.cpython-310.pyc
ADDED
|
Binary file (5.84 kB). View file
|
|
|
llama_cpp/__pycache__/llama_chat_format.cpython-310.pyc
ADDED
|
Binary file (76.7 kB). View file
|
|
|
llama_cpp/__pycache__/llama_cpp.cpython-310.pyc
ADDED
|
Binary file (70.1 kB). View file
|
|
|
llama_cpp/__pycache__/llama_grammar.cpython-310.pyc
ADDED
|
Binary file (25.3 kB). View file
|
|
|
llama_cpp/__pycache__/llama_speculative.cpython-310.pyc
ADDED
|
Binary file (2.2 kB). View file
|
|
|
llama_cpp/__pycache__/llama_tokenizer.cpython-310.pyc
ADDED
|
Binary file (4.53 kB). View file
|
|
|
llama_cpp/__pycache__/llama_types.cpython-310.pyc
ADDED
|
Binary file (10.9 kB). View file
|
|
|
llama_cpp/__pycache__/llava_cpp.cpython-310.pyc
ADDED
|
Binary file (2.99 kB). View file
|
|
|
llama_cpp/_ctypes_extensions.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
import os
|
| 5 |
+
import ctypes
|
| 6 |
+
import functools
|
| 7 |
+
import pathlib
|
| 8 |
+
|
| 9 |
+
from typing import (
|
| 10 |
+
Any,
|
| 11 |
+
Callable,
|
| 12 |
+
List,
|
| 13 |
+
Union,
|
| 14 |
+
Optional,
|
| 15 |
+
TYPE_CHECKING,
|
| 16 |
+
TypeVar,
|
| 17 |
+
Generic,
|
| 18 |
+
)
|
| 19 |
+
from typing_extensions import TypeAlias
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Load the library
|
| 23 |
+
def load_shared_library(lib_base_name: str, base_path: pathlib.Path):
    """Platform independent shared library loader.

    Looks in ``base_path`` for the platform-specific file name(s) derived
    from ``lib_base_name`` and loads the first one that exists.

    Args:
        lib_base_name: Library name without the ``lib`` prefix or the file
            extension (e.g. ``"llama"``).
        base_path: Directory expected to contain the library.

    Returns:
        The loaded ``ctypes.CDLL`` handle.

    Raises:
        RuntimeError: If the platform is not supported, or if a candidate
            file exists but cannot be loaded (the underlying error is
            chained as ``__cause__``).
        FileNotFoundError: If none of the candidate files exist.
    """
    # Candidate file names, in lookup order, for the current platform.
    lib_paths: List[pathlib.Path]
    if sys.platform.startswith(("linux", "freebsd")):
        lib_paths = [base_path / f"lib{lib_base_name}.so"]
    elif sys.platform == "darwin":
        lib_paths = [
            base_path / f"lib{lib_base_name}.so",
            base_path / f"lib{lib_base_name}.dylib",
        ]
    elif sys.platform == "win32":
        lib_paths = [
            base_path / f"{lib_base_name}.dll",
            base_path / f"lib{lib_base_name}.dll",
        ]
    else:
        raise RuntimeError("Unsupported platform")

    cdll_args = dict()  # type: ignore

    if sys.platform == "win32":
        # Put the library directory at the front of PATH. PATH may be
        # missing from the environment, so do not index it directly.
        current_path = os.environ.get("PATH")
        os.environ["PATH"] = (
            str(base_path) + os.pathsep + current_path
            if current_path
            else str(base_path)
        )

        # os.add_dll_directory only exists on Python 3.8+.
        if sys.version_info >= (3, 8):
            os.add_dll_directory(str(base_path))
            # Also expose the CUDA / HIP toolkit directories when their
            # environment variables are set.
            for toolkit_var in ("CUDA_PATH", "HIP_PATH"):
                toolkit_root = os.environ.get(toolkit_var)
                if toolkit_root is None:
                    continue
                os.add_dll_directory(os.path.join(toolkit_root, "bin"))
                os.add_dll_directory(os.path.join(toolkit_root, "lib"))
            cdll_args["winmode"] = ctypes.RTLD_GLOBAL

    # Load the first candidate that exists; a load failure is reported
    # immediately instead of falling through to the next candidate.
    for lib_path in lib_paths:
        if not lib_path.exists():
            continue
        try:
            return ctypes.CDLL(str(lib_path), **cdll_args)  # type: ignore
        except Exception as e:
            raise RuntimeError(
                f"Failed to load shared library '{lib_path}': {e}"
            ) from e

    raise FileNotFoundError(
        f"Shared library with base name '{lib_base_name}' not found"
    )
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# ctypes sane type hint helpers
|
| 77 |
+
#
|
| 78 |
+
# - Generic Pointer and Array types
|
| 79 |
+
# - PointerOrRef type with a type hinted byref function
|
| 80 |
+
#
|
| 81 |
+
# NOTE: Only use these for static type checking not for runtime checks
|
| 82 |
+
# no good will come of that
|
| 83 |
+
|
| 84 |
+
if TYPE_CHECKING:
    # Everything in this branch exists for static analysers only; none of
    # these names are defined when the module actually runs.

    # Type variable standing for any ctypes data type.
    CtypesCData = TypeVar("CtypesCData", bound=ctypes._CData)  # type: ignore

    # Generic spellings of ctypes arrays and pointers.
    CtypesArray: TypeAlias = ctypes.Array[CtypesCData]  # type: ignore
    CtypesPointer: TypeAlias = ctypes._Pointer[CtypesCData]  # type: ignore
    CtypesVoidPointer: TypeAlias = ctypes.c_void_p

    class CtypesRef(Generic[CtypesCData]):
        """Placeholder type used as the return annotation of the typed ``byref``."""

    # Accepts either a real pointer or a by-reference wrapper.
    CtypesPointerOrRef: TypeAlias = Union[
        CtypesPointer[CtypesCData], CtypesRef[CtypesCData]
    ]

    CtypesFuncPointer: TypeAlias = ctypes._FuncPointer  # type: ignore

# Type variable for "any callable"; used so decorators keep the decorated
# function's type.
F = TypeVar("F", bound=Callable[..., Any])
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def ctypes_function_for_shared_library(lib: ctypes.CDLL):
    """Build a decorator factory that binds Python stubs to symbols of ``lib``.

    The returned ``ctypes_function(name, argtypes, restype, enabled=True)``
    yields a decorator. With ``enabled`` true, the decorated stub is replaced
    by ``getattr(lib, name)``, configured with the given ``argtypes`` and
    ``restype`` and carrying the stub's metadata via ``functools.wraps``.
    With ``enabled`` false, the stub itself is handed back untouched and the
    library is never queried for ``name``.
    """

    def ctypes_function(
        name: str, argtypes: List[Any], restype: Any, enabled: bool = True
    ):
        def bind(stub: F) -> F:
            # Disabled bindings keep the pure-Python stub.
            if not enabled:
                return stub

            native = getattr(lib, name)
            native.argtypes = argtypes
            native.restype = restype
            # wraps() copies the stub's metadata onto the foreign function
            # and returns that same foreign function object.
            return functools.wraps(stub)(native)

        return bind

    return ctypes_function
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def _byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCData]:
|
| 127 |
+
"""Type-annotated version of ctypes.byref"""
|
| 128 |
+
...
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
byref = _byref if TYPE_CHECKING else ctypes.byref
|
llama_cpp/_ggml.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Internal module use at your own risk
|
| 2 |
+
|
| 3 |
+
This module provides a minimal interface for working with ggml tensors from llama-cpp-python
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import pathlib
|
| 7 |
+
|
| 8 |
+
import llama_cpp._ctypes_extensions as ctypes_ext
|
| 9 |
+
|
| 10 |
+
# Directory of the bundled native libraries: the "lib" folder that sits next
# to this module.
libggml_base_path = pathlib.Path(
    os.path.abspath(os.path.dirname(__file__))
).joinpath("lib")

# The ggml library is loaded as a side effect of importing this module.
libggml = ctypes_ext.load_shared_library("ggml", libggml_base_path)
|
| 12 |
+
|
llama_cpp/_internals.py
ADDED
|
@@ -0,0 +1,879 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import ctypes
|
| 5 |
+
|
| 6 |
+
from typing import (
|
| 7 |
+
Dict,
|
| 8 |
+
List,
|
| 9 |
+
Tuple,
|
| 10 |
+
Optional,
|
| 11 |
+
Sequence,
|
| 12 |
+
)
|
| 13 |
+
from dataclasses import dataclass, field
|
| 14 |
+
from contextlib import ExitStack
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import numpy.typing as npt
|
| 18 |
+
|
| 19 |
+
from .llama_types import *
|
| 20 |
+
from .llama_grammar import LlamaGrammar
|
| 21 |
+
from ._utils import suppress_stdout_stderr
|
| 22 |
+
|
| 23 |
+
import llama_cpp.llama_cpp as llama_cpp
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Python wrappers over llama.h structs
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class LlamaModel:
    """Intermediate Python wrapper for a llama.cpp llama_model.

    NOTE: For stability it's recommended you use the Llama class instead.

    The wrapper owns the native model handle (``self.model``) and keeps the
    vocab handle obtained from it (``self.vocab``). The model handle is freed
    by ``close()``, which is also invoked when the wrapper is garbage
    collected.
    """

    def __init__(
        self,
        *,
        path_model: str,
        params: llama_cpp.llama_model_params,
        verbose: bool = True,
    ):
        """Load the model file at ``path_model`` and fetch its vocab handle.

        Args:
            path_model: Filesystem path of the model; it must already exist.
            params: Model parameters, passed to the loader unchanged.
            verbose: When false, the load call runs inside
                ``suppress_stdout_stderr``.

        Raises:
            ValueError: If the path does not exist, the loader returns no
                model, or no vocab can be obtained from the loaded model.
        """
        self.path_model = path_model
        self.params = params
        self.verbose = verbose
        self._exit_stack = ExitStack()

        if not os.path.exists(path_model):
            raise ValueError(f"Model path does not exist: {path_model}")

        with suppress_stdout_stderr(disable=verbose):
            model = llama_cpp.llama_load_model_from_file(
                self.path_model.encode("utf-8"), self.params
            )

        if model is None:
            raise ValueError(f"Failed to load model from file: {path_model}")

        self.model = model

        def free_model():
            if self.model is None:
                return
            llama_cpp.llama_free_model(self.model)
            self.model = None

        # Register the cleanup before anything else can fail, so the native
        # model is released on every error path below.
        self._exit_stack.callback(free_model)

        vocab = llama_cpp.llama_model_get_vocab(model)

        if vocab is None:
            # Free the model that was just loaded instead of leaking it.
            self._exit_stack.close()
            raise ValueError(f"Failed to get vocab from model: {path_model}")

        self.vocab = vocab

    def close(self):
        """Run the registered cleanup; the model is freed at most once."""
        self._exit_stack.close()

    def __del__(self):
        self.close()

    # Thin pass-throughs: each method below returns the result of the
    # llama.cpp call it names.

    def vocab_type(self) -> int:
        # NOTE(review): every other vocab accessor in this class passes
        # self.vocab; confirm llama_vocab_type really expects the model handle.
        return llama_cpp.llama_vocab_type(self.model)

    def n_vocab(self) -> int:
        return llama_cpp.llama_n_vocab(self.vocab)

    def n_ctx_train(self) -> int:
        return llama_cpp.llama_n_ctx_train(self.model)

    def n_embd(self) -> int:
        return llama_cpp.llama_n_embd(self.model)

    def rope_freq_scale_train(self) -> float:
        return llama_cpp.llama_model_rope_freq_scale_train(self.model)

    def desc(self) -> str:
        """Return the model description, read through a 1024-byte buffer."""
        buf = ctypes.create_string_buffer(1024)
        llama_cpp.llama_model_desc(self.model, buf, 1024)
        return buf.value.decode("utf-8")

    def size(self) -> int:
        return llama_cpp.llama_model_size(self.model)

    def n_params(self) -> int:
        return llama_cpp.llama_model_n_params(self.model)

    def get_tensor(self, name: str) -> ctypes.c_void_p:
        """Not available; always raises ``NotImplementedError``."""
        raise NotImplementedError("get_tensor is not implemented in llama.cpp")

    # Vocab

    def token_get_text(self, token: int) -> str:
        return llama_cpp.llama_token_get_text(self.vocab, token).decode("utf-8")

    def token_get_score(self, token: int) -> float:
        return llama_cpp.llama_token_get_score(self.vocab, token)

    def token_get_attr(self, token: int) -> int:
        return llama_cpp.llama_token_get_attr(self.vocab, token)

    # Special tokens

    def token_bos(self) -> int:
        return llama_cpp.llama_token_bos(self.vocab)

    def token_eos(self) -> int:
        return llama_cpp.llama_token_eos(self.vocab)

    def token_cls(self) -> int:
        return llama_cpp.llama_token_cls(self.vocab)

    def token_sep(self) -> int:
        return llama_cpp.llama_token_sep(self.vocab)

    def token_nl(self) -> int:
        return llama_cpp.llama_token_nl(self.vocab)

    def token_prefix(self) -> int:
        """Not available; always raises ``NotImplementedError``."""
        raise NotImplementedError("token_prefix is not implemented in llama.cpp")

    def token_middle(self) -> int:
        """Not available; always raises ``NotImplementedError``."""
        raise NotImplementedError("token_middle is not implemented in llama.cpp")

    def token_suffix(self) -> int:
        """Not available; always raises ``NotImplementedError``."""
        raise NotImplementedError("token_suffix is not implemented in llama.cpp")

    def token_eot(self) -> int:
        return llama_cpp.llama_token_eot(self.vocab)

    def add_bos_token(self) -> bool:
        return llama_cpp.llama_add_bos_token(self.vocab)

    def add_eos_token(self) -> bool:
        return llama_cpp.llama_add_eos_token(self.vocab)

    # Tokenization

    def tokenize(self, text: bytes, add_bos: bool, special: bool):
        """Tokenize ``text`` and return the token ids as a list.

        The first attempt uses a buffer of ``n_ctx_train()`` tokens. A
        negative result is treated as "buffer too small": its magnitude is
        used as the new buffer size and the call is retried once.

        Raises:
            RuntimeError: If the retry still reports a negative count.
        """
        n_ctx = self.n_ctx_train()
        tokens = (llama_cpp.llama_token * n_ctx)()
        n_tokens = llama_cpp.llama_tokenize(
            self.vocab, text, len(text), tokens, n_ctx, add_bos, special
        )
        if n_tokens < 0:
            n_tokens = abs(n_tokens)
            tokens = (llama_cpp.llama_token * n_tokens)()
            n_tokens = llama_cpp.llama_tokenize(
                self.vocab, text, len(text), tokens, n_tokens, add_bos, special
            )
            if n_tokens < 0:
                raise RuntimeError(
                    f'Failed to tokenize: text="{text}" n_tokens={n_tokens}'
                )
        return list(tokens[:n_tokens])

    def token_to_piece(self, token: int, special: bool = False) -> bytes:
        """Render one token into a 32-byte buffer and return the whole buffer."""
        # NOTE(review): the length reported by llama_token_to_piece is ignored
        # and all 32 bytes are returned, so the result includes any zero bytes
        # that were not overwritten. Left unchanged in case callers depend on
        # it; detokenize() below returns exact-length pieces.
        buf = ctypes.create_string_buffer(32)
        llama_cpp.llama_token_to_piece(self.vocab, token, buf, 32, 0, special)
        return bytes(buf)

    def detokenize(self, tokens: List[int], special: bool = False) -> bytes:
        """Concatenate the byte pieces of ``tokens``.

        Args:
            tokens: Token ids to convert.
            special: Forwarded to ``llama_token_to_piece``.

        Returns:
            The joined bytes. If the first token is the BOS token and the
            output begins with a space, that single leading space is dropped.

        Raises:
            RuntimeError: If a piece cannot be rendered even after the buffer
                has been enlarged.
        """
        size = 32
        buffer = (ctypes.c_char * size)()
        pieces: List[bytes] = []
        for token in tokens:
            n = llama_cpp.llama_token_to_piece(
                self.vocab, llama_cpp.llama_token(token), buffer, size, 0, special
            )
            if n < 0:
                # The piece did not fit. As in tokenize(), the magnitude of
                # the negative result is taken as the required size; grow the
                # buffer (it stays enlarged for later tokens) and retry.
                size = -n
                buffer = (ctypes.c_char * size)()
                n = llama_cpp.llama_token_to_piece(
                    self.vocab, llama_cpp.llama_token(token), buffer, size, 0, special
                )
            if n < 0 or n > size:
                raise RuntimeError(f"Failed to detokenize: token={token} n={n}")
            pieces.append(bytes(buffer[:n]))
        output = b"".join(pieces)
        # NOTE: Llama1 models automatically added a space at the start of the prompt
        # this line removes a leading space if the first token is a beginning of sentence token
        return (
            output[1:]
            if len(tokens) > 0 and tokens[0] == self.token_bos() and output[0:1] == b" "
            else output
        )

    # Extra
    def metadata(self) -> Dict[str, str]:
        """Return the model's metadata as a ``{key: value}`` dict of strings.

        Keys and values are read by index through one shared buffer that
        starts at 1024 bytes and is enlarged whenever a reported length does
        not fit.
        """
        metadata: Dict[str, str] = {}
        buffer_size = 1024
        # create_string_buffer returns a zero-filled buffer.
        buffer = ctypes.create_string_buffer(buffer_size)
        # iterate over model keys
        for i in range(llama_cpp.llama_model_meta_count(self.model)):
            nbytes = llama_cpp.llama_model_meta_key_by_index(
                self.model, i, buffer, buffer_size
            )
            # The reported length excludes the terminator (hence the "+ 1"
            # below), so a length equal to the buffer size is truncated too.
            if nbytes >= buffer_size:
                buffer_size = nbytes + 1
                buffer = ctypes.create_string_buffer(buffer_size)
                nbytes = llama_cpp.llama_model_meta_key_by_index(
                    self.model, i, buffer, buffer_size
                )
            key = buffer.value.decode("utf-8")
            nbytes = llama_cpp.llama_model_meta_val_str_by_index(
                self.model, i, buffer, buffer_size
            )
            if nbytes >= buffer_size:
                buffer_size = nbytes + 1
                buffer = ctypes.create_string_buffer(buffer_size)
                nbytes = llama_cpp.llama_model_meta_val_str_by_index(
                    self.model, i, buffer, buffer_size
                )
            value = buffer.value.decode("utf-8")
            metadata[key] = value
        return metadata

    @staticmethod
    def default_params():
        """Get the default llama_model_params."""
        return llama_cpp.llama_model_default_params()
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class LlamaContext:
|
| 237 |
+
"""Intermediate Python wrapper for a llama.cpp llama_context.
|
| 238 |
+
NOTE: For stability it's recommended you use the Llama class instead."""
|
| 239 |
+
|
| 240 |
+
def __init__(
|
| 241 |
+
self,
|
| 242 |
+
*,
|
| 243 |
+
model: LlamaModel,
|
| 244 |
+
params: llama_cpp.llama_context_params,
|
| 245 |
+
verbose: bool = True,
|
| 246 |
+
):
|
| 247 |
+
self.model = model
|
| 248 |
+
self.params = params
|
| 249 |
+
self.verbose = verbose
|
| 250 |
+
self._exit_stack = ExitStack()
|
| 251 |
+
|
| 252 |
+
ctx = llama_cpp.llama_new_context_with_model(self.model.model, self.params)
|
| 253 |
+
|
| 254 |
+
if ctx is None:
|
| 255 |
+
raise ValueError("Failed to create llama_context")
|
| 256 |
+
|
| 257 |
+
self.ctx = ctx
|
| 258 |
+
|
| 259 |
+
def free_ctx():
|
| 260 |
+
if self.ctx is None:
|
| 261 |
+
return
|
| 262 |
+
llama_cpp.llama_free(self.ctx)
|
| 263 |
+
self.ctx = None
|
| 264 |
+
|
| 265 |
+
self._exit_stack.callback(free_ctx)
|
| 266 |
+
|
| 267 |
+
def close(self):
|
| 268 |
+
self._exit_stack.close()
|
| 269 |
+
|
| 270 |
+
def __del__(self):
|
| 271 |
+
self.close()
|
| 272 |
+
|
| 273 |
+
def n_ctx(self) -> int:
|
| 274 |
+
return llama_cpp.llama_n_ctx(self.ctx)
|
| 275 |
+
|
| 276 |
+
def pooling_type(self) -> int:
|
| 277 |
+
return llama_cpp.llama_pooling_type(self.ctx)
|
| 278 |
+
|
| 279 |
+
def kv_cache_clear(self):
|
| 280 |
+
llama_cpp.llama_kv_cache_clear(self.ctx)
|
| 281 |
+
|
| 282 |
+
def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int):
|
| 283 |
+
llama_cpp.llama_kv_cache_seq_rm(self.ctx, seq_id, p0, p1)
|
| 284 |
+
|
| 285 |
+
def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int):
|
| 286 |
+
llama_cpp.llama_kv_cache_seq_cp(self.ctx, seq_id_src, seq_id_dst, p0, p1)
|
| 287 |
+
|
| 288 |
+
def kv_cache_seq_keep(self, seq_id: int):
|
| 289 |
+
llama_cpp.llama_kv_cache_seq_keep(self.ctx, seq_id)
|
| 290 |
+
|
| 291 |
+
def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int):
|
| 292 |
+
llama_cpp.llama_kv_cache_seq_add(self.ctx, seq_id, p0, p1, shift)
|
| 293 |
+
|
| 294 |
+
def get_state_size(self) -> int:
|
| 295 |
+
return llama_cpp.llama_get_state_size(self.ctx)
|
| 296 |
+
|
| 297 |
+
# TODO: copy_state_data
|
| 298 |
+
|
| 299 |
+
# TODO: set_state_data
|
| 300 |
+
|
| 301 |
+
# TODO: llama_load_session_file
|
| 302 |
+
|
| 303 |
+
# TODO: llama_save_session_file
|
| 304 |
+
|
| 305 |
+
def decode(self, batch: LlamaBatch):
|
| 306 |
+
return_code = llama_cpp.llama_decode(
|
| 307 |
+
self.ctx,
|
| 308 |
+
batch.batch,
|
| 309 |
+
)
|
| 310 |
+
if return_code != 0:
|
| 311 |
+
raise RuntimeError(f"llama_decode returned {return_code}")
|
| 312 |
+
|
| 313 |
+
def set_n_threads(self, n_threads: int, n_threads_batch: int):
|
| 314 |
+
llama_cpp.llama_set_n_threads(self.ctx, n_threads, n_threads_batch)
|
| 315 |
+
|
| 316 |
+
def get_logits(self):
|
| 317 |
+
return llama_cpp.llama_get_logits(self.ctx)
|
| 318 |
+
|
| 319 |
+
def get_logits_ith(self, i: int):
|
| 320 |
+
return llama_cpp.llama_get_logits_ith(self.ctx, i)
|
| 321 |
+
|
| 322 |
+
def get_embeddings(self):
|
| 323 |
+
return llama_cpp.llama_get_embeddings(self.ctx)
|
| 324 |
+
|
| 325 |
+
# Sampling functions
|
| 326 |
+
|
| 327 |
+
def set_rng_seed(self, seed: int):
|
| 328 |
+
# TODO: Fix
|
| 329 |
+
# llama_cpp.llama_set_rng_seed(self.ctx, seed)
|
| 330 |
+
raise NotImplementedError("set_rng_seed is not implemented in llama.cpp")
|
| 331 |
+
|
| 332 |
+
def sample_repetition_penalties(
|
| 333 |
+
self,
|
| 334 |
+
candidates: "_LlamaTokenDataArray",
|
| 335 |
+
last_tokens_data: "llama_cpp.Array[llama_cpp.llama_token]",
|
| 336 |
+
penalty_last_n: int,
|
| 337 |
+
penalty_repeat: float,
|
| 338 |
+
penalty_freq: float,
|
| 339 |
+
penalty_present: float,
|
| 340 |
+
):
|
| 341 |
+
# llama_cpp.llama_sample_repetition_penalties(
|
| 342 |
+
# self.ctx,
|
| 343 |
+
# llama_cpp.byref(candidates.candidates),
|
| 344 |
+
# last_tokens_data,
|
| 345 |
+
# penalty_last_n,
|
| 346 |
+
# penalty_repeat,
|
| 347 |
+
# penalty_freq,
|
| 348 |
+
# penalty_present,
|
| 349 |
+
# )
|
| 350 |
+
raise NotImplementedError("sample_repetition_penalties is not implemented in llama.cpp")
|
| 351 |
+
|
| 352 |
+
def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
    """Stub that always raises; the softmax binding is disabled.

    The former call to ``llama_cpp.llama_sample_softmax`` is no longer
    made, so ``candidates`` is never touched.

    Raises:
        NotImplementedError: Unconditionally.
    """
    message = "sample_softmax is not implemented in llama.cpp"
    raise NotImplementedError(message)
|
| 358 |
+
|
| 359 |
+
def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
    """Stub that always raises; the top-k binding is disabled.

    The former call to ``llama_cpp.llama_sample_top_k`` is no longer made,
    so the arguments are unused.

    Raises:
        NotImplementedError: Unconditionally.
    """
    message = "sample_top_k is not implemented in llama.cpp"
    raise NotImplementedError(message)
|
| 364 |
+
|
| 365 |
+
def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
    """Stub that always raises; the top-p binding is disabled.

    The former call to ``llama_cpp.llama_sample_top_p`` is no longer made,
    so the arguments are unused.

    Raises:
        NotImplementedError: Unconditionally.
    """
    message = "sample_top_p is not implemented in llama.cpp"
    raise NotImplementedError(message)
|
| 370 |
+
|
| 371 |
+
def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
    """Stub that always raises; the min-p binding is disabled.

    The former call to ``llama_cpp.llama_sample_min_p`` is no longer made,
    so the arguments are unused.

    Raises:
        NotImplementedError: Unconditionally.
    """
    message = "sample_min_p is not implemented in llama.cpp"
    raise NotImplementedError(message)
|
| 376 |
+
|
| 377 |
+
def sample_typical(
    self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int
):
    """Stub that always raises; the typical-sampling binding is disabled.

    The former call to ``llama_cpp.llama_sample_typical`` is no longer
    made, so the arguments are unused.

    Raises:
        NotImplementedError: Unconditionally.
    """
    message = "sample_typical is not implemented in llama.cpp"
    raise NotImplementedError(message)
|
| 384 |
+
|
| 385 |
+
def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
    """Stub that always raises; the temperature binding is disabled.

    The former call to ``llama_cpp.llama_sample_temp`` is no longer made,
    so the arguments are unused.

    Raises:
        NotImplementedError: Unconditionally.
    """
    message = "sample_temp is not implemented in llama.cpp"
    raise NotImplementedError(message)
|
| 390 |
+
|
| 391 |
+
def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
    """Stub that always raises; the grammar-sampling binding is disabled.

    The former call to ``llama_cpp.llama_sample_grammar`` is no longer
    made, so the arguments are unused.

    Raises:
        NotImplementedError: Unconditionally.
    """
    message = "sample_grammar is not implemented in llama.cpp"
    raise NotImplementedError(message)
|
| 398 |
+
|
| 399 |
+
def sample_token_mirostat(
    self,
    candidates: "_LlamaTokenDataArray",
    tau: float,
    eta: float,
    m: int,
    mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
) -> int:
    """Stub that always raises; the Mirostat v1 binding is disabled.

    The former call to ``llama_cpp.llama_sample_token_mirostat`` is no
    longer made, so no token is ever returned and the arguments are unused.

    Raises:
        NotImplementedError: Unconditionally.
    """
    message = "sample_token_mirostat is not implemented in llama.cpp"
    raise NotImplementedError(message)
|
| 416 |
+
|
| 417 |
+
def sample_token_mirostat_v2(
    self,
    candidates: "_LlamaTokenDataArray",
    tau: float,
    eta: float,
    mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
) -> int:
    """Stub that always raises; the Mirostat v2 binding is disabled.

    The former call to ``llama_cpp.llama_sample_token_mirostat_v2`` is no
    longer made, so no token is ever returned and the arguments are unused.

    Raises:
        NotImplementedError: Unconditionally.
    """
    message = "sample_token_mirostat_v2 is not implemented in llama.cpp"
    raise NotImplementedError(message)
|
| 432 |
+
|
| 433 |
+
def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int:
    """Stub that always raises; the greedy-sampling binding is disabled.

    The former call to ``llama_cpp.llama_sample_token_greedy`` is no
    longer made, so no token is ever returned.

    Raises:
        NotImplementedError: Unconditionally.
    """
    message = "sample_token_greedy is not implemented in llama.cpp"
    raise NotImplementedError(message)
|
| 439 |
+
|
| 440 |
+
def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
    """Stub that always raises; the token-sampling binding is disabled.

    The former call to ``llama_cpp.llama_sample_token`` is no longer made,
    so no token is ever returned.

    Raises:
        NotImplementedError: Unconditionally.
    """
    message = "sample_token is not implemented in llama.cpp"
    raise NotImplementedError(message)
|
| 446 |
+
|
| 447 |
+
# Grammar
|
| 448 |
+
def grammar_accept_token(self, grammar: LlamaGrammar, token: int):
    """Stub that always raises; the grammar-accept binding is disabled.

    The former call to ``llama_cpp.llama_grammar_accept_token`` is no
    longer made, so the arguments are unused.

    Raises:
        NotImplementedError: Unconditionally.
    """
    message = "grammar_accept_token is not implemented in llama.cpp"
    raise NotImplementedError(message)
|
| 451 |
+
|
| 452 |
+
def reset_timings(self):
    """Call ``llama_perf_context_reset`` on this context."""
    context = self.ctx
    llama_cpp.llama_perf_context_reset(context)
|
| 454 |
+
|
| 455 |
+
def print_timings(self):
    """Call ``llama_perf_context_print`` on this context."""
    context = self.ctx
    llama_cpp.llama_perf_context_print(context)
|
| 457 |
+
|
| 458 |
+
# Utility functions
|
| 459 |
+
@staticmethod
def default_params():
    """Return the value of ``llama_cpp.llama_context_default_params()``."""
    params = llama_cpp.llama_context_default_params()
    return params
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
class LlamaBatch:
    """Owner of a ``llama_batch`` allocated with ``llama_batch_init``.

    The underlying batch is released with ``llama_batch_free`` when
    :meth:`close` runs; ``__del__`` calls :meth:`close` as a fallback.
    """

    def __init__(
        self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True
    ):
        """Allocate the batch.

        Args:
            n_tokens: Passed to ``llama_batch_init`` as its first argument.
            embd: Passed to ``llama_batch_init`` as its second argument.
            n_seq_max: Passed to ``llama_batch_init`` as its third argument.
            verbose: Stored on the instance; not used by this class itself.

        Raises:
            ValueError: If ``llama_batch_init`` returns ``None``.
        """
        self._n_tokens = n_tokens
        self.embd = embd
        self.n_seq_max = n_seq_max
        self.verbose = verbose
        self._exit_stack = ExitStack()

        batch = llama_cpp.llama_batch_init(self._n_tokens, self.embd, self.n_seq_max)

        if batch is None:
            raise ValueError("Failed to create llama_batch")

        self.batch = batch

        def free_batch():
            # Idempotent: a second close() must not free the batch twice.
            if self.batch is None:
                return
            llama_cpp.llama_batch_free(self.batch)
            self.batch = None

        self._exit_stack.callback(free_batch)

    def close(self):
        """Release the batch; safe to call repeatedly."""
        # __del__ also runs for instances whose __init__ never got as far as
        # creating the exit stack, so do not assume the attribute exists.
        exit_stack = getattr(self, "_exit_stack", None)
        if exit_stack is not None:
            exit_stack.close()

    def __del__(self):
        self.close()

    def n_tokens(self) -> int:
        """Return the number of tokens currently stored in the batch."""
        return self.batch.n_tokens

    def reset(self):
        """Mark the batch as empty without touching the stored entries."""
        self.batch.n_tokens = 0

    def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
        """Overwrite the batch with ``batch`` as sequence 0.

        Args:
            batch: Token ids to store.
            n_past: Position assigned to the first token; later tokens get
                consecutive positions.
            logits_all: Logits flag written for every token. The last token
                is always flagged regardless of this value.
        """
        n_tokens = len(batch)
        self.batch.n_tokens = n_tokens
        for i, token in enumerate(batch):
            self.batch.token[i] = token
            self.batch.pos[i] = n_past + i
            self.batch.seq_id[i][0] = 0
            self.batch.n_seq_id[i] = 1
            self.batch.logits[i] = logits_all
        # With an empty input there is no last token; indexing -1 on a ctypes
        # pointer would write outside the buffer, so skip the flag entirely.
        if n_tokens > 0:
            self.batch.logits[n_tokens - 1] = True

    def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
        """Append ``batch`` after the tokens already stored, tagged ``seq_id``.

        Args:
            batch: Token ids to append.
            seq_id: Sequence id written for every appended token.
            logits_all: Logits flag written for every appended token. The
                last appended token is always flagged regardless.

        Positions restart at 0 for the appended sequence.
        """
        n_tokens = len(batch)
        n_tokens0 = self.batch.n_tokens
        self.batch.n_tokens += n_tokens
        for i, token in enumerate(batch):
            j = n_tokens0 + i
            self.batch.token[j] = token
            self.batch.pos[j] = i
            self.batch.seq_id[j][0] = seq_id
            self.batch.n_seq_id[j] = 1
            self.batch.logits[j] = logits_all
        # Flag the last token of the sequence that was just appended. Its
        # slot is offset by the tokens that were already in the batch.
        if n_tokens > 0:
            self.batch.logits[n_tokens0 + n_tokens - 1] = True
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
class LlamaTokenDataArray:
    """Candidate-token buffer backed by a numpy record array.

    ``candidates_data`` holds one ``(id, logit, p)`` record per vocabulary
    entry, and ``candidates`` is a ``llama_token_data_array`` whose ``data``
    pointer refers to that same memory.
    """

    def __init__(self, *, n_vocab: int):
        self.n_vocab = n_vocab
        record_dtype = np.dtype(
            [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
        )
        self.candidates_data = np.recarray((self.n_vocab,), dtype=record_dtype)
        data_pointer = self.candidates_data.ctypes.data_as(
            llama_cpp.llama_token_data_p
        )
        self.candidates = llama_cpp.llama_token_data_array(
            data=data_pointer,
            size=self.n_vocab,
            sorted=False,
        )
        # Templates copied into the record array by copy_logits().
        self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc)  # type: ignore
        self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single)

    def copy_logits(self, logits: npt.NDArray[np.single]):
        """Load ``logits`` and reset ids, probabilities, size and sorted flag."""
        records = self.candidates_data
        records.id[:] = self.default_candidates_data_id
        records.logit[:] = logits
        records.p[:] = self.default_candidates_data_p
        array_struct = self.candidates
        array_struct.sorted = False
        array_struct.size = self.n_vocab
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
# Embedding functions
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
def normalize_embedding(embedding):
    """Scale ``embedding`` to unit Euclidean length.

    Returns a new list of components divided by the norm. When the norm is
    exactly zero the input object itself is returned unchanged.
    """
    magnitude = float(np.linalg.norm(embedding))
    if magnitude != 0.0:
        return [component / magnitude for component in embedding]
    return embedding
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
# Python wrappers over common/sampling structs
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
@dataclass
class LlamaSamplingParams:
    """Sampling settings read by ``LlamaSamplingContext.sample``.

    Field order is part of the positional constructor and must not change.
    """

    # Not read by the sampling code visible in this module.
    n_prev: int = 64
    # Lower bound (together with 1) for ``min_keep`` in the truncation samplers.
    n_probs: int = 0
    top_k: int = 40
    top_p: float = 0.95
    min_p: float = 0.05
    # Not read by the sampling code visible in this module.
    tfs_z: float = 1.0
    typical_p: float = 1.0
    # < 0 takes the softmax path, == 0 the greedy path, > 0 the full chain.
    temp: float = 0.8
    # Number of most recent tokens considered for repetition penalties.
    penalty_last_n: int = 64
    penalty_repeat: float = 1.0
    penalty_freq: float = 0.0
    penalty_present: float = 0.0
    # 1 selects sample_token_mirostat, 2 selects sample_token_mirostat_v2.
    mirostat: int = 0
    mirostat_tau: float = 5.0
    mirostat_eta: float = 0.1
    # When False the newline logit is restored after penalties are applied.
    penalize_nl: bool = True

    # Not read by the sampling code visible in this module.
    grammar: str = ""

    # Not read by the sampling code visible in this module.
    cfg_negative_prompt: str = ""
    cfg_scale: float = 1.0

    # Token id -> value added to that token's logit before sampling.
    logit_bias: dict[int, float] = field(default_factory=dict)
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
@dataclass
class LlamaSamplingContext:
    """Per-generation sampling state driven through a ``LlamaContext``.

    Holds the sampling parameters, the Mirostat ``mu`` value, an optional
    grammar, and the history of accepted tokens (``prev``).
    """

    params: LlamaSamplingParams = field(default_factory=LlamaSamplingParams)
    mirostat_mu: ctypes.c_float = field(default_factory=ctypes.c_float)
    grammar: Optional[LlamaGrammar] = None
    # NOTE: Missing parsed_grammar
    prev: list[int] = field(default_factory=list)
    cur: list[llama_cpp.llama_token_data] = field(default_factory=list)

    def reset(self):
        """Drop the token history and candidates and reset the grammar, if any."""
        self.prev, self.cur = [], []
        if self.grammar is None:
            return
        self.grammar.reset()

    def cp(self):
        """Return a copy with copied ``prev``/``cur`` lists.

        ``params``, ``mirostat_mu`` and ``grammar`` are shared with the
        original, not duplicated.
        """
        clone = LlamaSamplingContext(
            params=self.params,
            mirostat_mu=self.mirostat_mu,
            grammar=self.grammar,
            prev=self.prev.copy(),
            cur=self.cur.copy(),
        )
        return clone

    def last(self) -> Optional[int]:
        """Return the most recently accepted token, or ``None`` if there is none."""
        if len(self.prev) == 0:
            return None
        return self.prev[-1]

    def prev_str(self, ctx_main: LlamaContext, n: int) -> str:
        """Detokenize the last ``n`` accepted tokens and decode them as UTF-8."""
        recent_tokens = self.prev[-n:]
        raw_text = ctx_main.model.detokenize(recent_tokens)
        return raw_text.decode("utf-8")

    def sample(
        self,
        ctx_main: LlamaContext,
        idx: int = 0,
        logits_array: Optional[npt.NDArray[np.single]] = None,
    ):
        """Pick the next token id from the logits at ``idx``.

        Args:
            ctx_main: Context supplying the logits and the sampling helpers.
            idx: Index passed to ``get_logits_ith`` when ``logits_array`` is
                not supplied.
            logits_array: Optional logits to use directly. Note that
                ``logit_bias`` is added to this array in place.

        Returns:
            The sampled token id.
        """
        params = self.params
        n_vocab = ctx_main.model.n_vocab()

        if logits_array is None:
            raw_logits = ctx_main.get_logits_ith(idx)
            vocab_row = ctypes.cast(
                raw_logits, ctypes.POINTER(ctypes.c_float * n_vocab)
            )
            logits_array = np.array(vocab_row.contents, dtype=np.single)

        # apply logit_bias
        for token, bias in params.logit_bias.items():
            logits_array[token] += bias

        # TODO: Only create this once
        token_data_array = LlamaTokenDataArray(n_vocab=n_vocab)
        token_data_array.copy_logits(logits_array)

        # apply penalties
        if len(self.prev) > 0:
            nl_token = ctx_main.model.token_nl()
            nl_logit = logits_array[nl_token]
            last_tokens = self.prev[-params.penalty_last_n :]
            last_tokens_size = min(len(last_tokens), params.penalty_last_n)
            if last_tokens_size > 0:
                last_tokens_p = (llama_cpp.llama_token * len(last_tokens))(
                    *last_tokens
                )
                ctx_main.sample_repetition_penalties(
                    token_data_array,
                    last_tokens_p,
                    last_tokens_size,
                    params.penalty_repeat,
                    params.penalty_freq,
                    params.penalty_present,
                )
            if not params.penalize_nl:
                # Undo whatever the penalties did to the newline token.
                token_data_array.candidates_data.logit[nl_token] = nl_logit

        if self.grammar is not None:
            ctx_main.sample_grammar(token_data_array, self.grammar)

        if params.temp < 0:
            ctx_main.sample_softmax(token_data_array)
            return token_data_array.candidates_data.id[0]

        if params.temp == 0:
            return ctx_main.sample_token_greedy(token_data_array)

        if params.mirostat == 1:
            mirostat_m = 100
            ctx_main.sample_temp(token_data_array, params.temp)
            return ctx_main.sample_token_mirostat(
                token_data_array,
                params.mirostat_tau,
                params.mirostat_eta,
                mirostat_m,
                ctypes.pointer(self.mirostat_mu),
            )

        if params.mirostat == 2:
            ctx_main.sample_temp(token_data_array, params.temp)
            return ctx_main.sample_token_mirostat_v2(
                token_data_array,
                params.mirostat_tau,
                params.mirostat_eta,
                ctypes.pointer(self.mirostat_mu),
            )

        min_keep = max(1, params.n_probs)
        ctx_main.sample_top_k(token_data_array, params.top_k, min_keep=min_keep)
        ctx_main.sample_typical(token_data_array, params.typical_p, min_keep=min_keep)
        ctx_main.sample_top_p(token_data_array, params.top_p, min_keep=min_keep)
        ctx_main.sample_min_p(token_data_array, params.min_p, min_keep=min_keep)
        ctx_main.sample_temp(token_data_array, params.temp)
        return ctx_main.sample_token(token_data_array)

    def accept(self, ctx_main: LlamaContext, id: int, apply_grammar: bool):
        """Record ``id`` in the history, first feeding it to the grammar if asked."""
        grammar = self.grammar
        if apply_grammar and grammar is not None:
            ctx_main.grammar_accept_token(grammar, id)
        self.prev.append(id)
|
| 718 |
+
|
| 719 |
+
|
| 720 |
+
from typing import List, Callable, Optional, Union
|
| 721 |
+
import ctypes
|
| 722 |
+
import llama_cpp
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
class CustomSampler:
    """Wrap a Python callable as a ``llama_sampler`` usable in a sampler chain.

    Only the ``apply`` slot of the sampler interface is populated; ``name``,
    ``accept``, ``reset``, ``clone`` and ``free`` are set to null callbacks.
    """

    def __init__(
        self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]
    ):
        """Build the sampler struct around ``apply_func``.

        Args:
            apply_func: Called with the candidate-array pointer that
                llama.cpp passes to the sampler's ``apply`` callback.
        """
        self.apply_func = apply_func

        def apply_wrapper(
            sampler: llama_cpp.llama_sampler_p,
            cur_p: llama_cpp.llama_token_data_array_p,
        ):
            self.apply_func(cur_p)

        sampler_i = llama_cpp.llama_sampler_i()
        apply_callback = llama_cpp.llama_sampler_i_apply(apply_wrapper)
        sampler_i.apply = apply_callback

        # Native code holds raw pointers to these objects. Keep explicit
        # references so they live exactly as long as this instance.
        self._apply_wrapper_ref = apply_wrapper
        self._apply_callback_ref = apply_callback
        self._sampler_i_ref = sampler_i

        sampler_i.name = llama_cpp.llama_sampler_i_name(0)
        sampler_i.accept = llama_cpp.llama_sampler_i_accept(0)
        sampler_i.reset = llama_cpp.llama_sampler_i_reset(0)
        sampler_i.clone = llama_cpp.llama_sampler_i_clone(0)
        sampler_i.free = llama_cpp.llama_sampler_i_free(0)

        self.sampler = llama_cpp.llama_sampler()
        self.sampler.iface = ctypes.pointer(sampler_i)
        self.sampler.ctx = None

    def get_sampler(self) -> llama_cpp.llama_sampler_p:
        """Return a ctypes pointer to the wrapped ``llama_sampler`` struct."""
        return ctypes.pointer(self.sampler)
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
class LlamaSampler:
    """Python owner of a llama.cpp sampler chain.

    Samplers are appended with the ``add_*`` methods and the chain is freed
    by :meth:`close` (also invoked from ``__del__``).
    """

    def __init__(self):
        params = llama_cpp.llama_sampler_chain_params()
        self.sampler = llama_cpp.llama_sampler_chain_init(params)
        self.samplers: List[llama_cpp.llama_sampler_p] = []
        # (index in chain, wrapper) pairs for samplers created by add_custom().
        self.custom_samplers: List[Tuple[int, CustomSampler]] = []

    def add_greedy(self):
        """Append a greedy sampler."""
        self._add_sampler(llama_cpp.llama_sampler_init_greedy())

    def add_dist(self, seed: int):
        """Append a ``dist`` sampler initialised with ``seed``."""
        self._add_sampler(llama_cpp.llama_sampler_init_dist(seed))

    def add_softmax(self):
        """Append a softmax sampler."""
        self._add_sampler(llama_cpp.llama_sampler_init_softmax())

    def add_top_k(self, k: int):
        """Append a top-k sampler."""
        self._add_sampler(llama_cpp.llama_sampler_init_top_k(k))

    def add_top_p(self, p: float, min_keep: int):
        """Append a top-p sampler."""
        self._add_sampler(llama_cpp.llama_sampler_init_top_p(p, min_keep))

    def add_min_p(self, p: float, min_keep: int):
        """Append a min-p sampler."""
        self._add_sampler(llama_cpp.llama_sampler_init_min_p(p, min_keep))

    def add_typical(self, p: float, min_keep: int):
        """Append a typical sampler."""
        self._add_sampler(llama_cpp.llama_sampler_init_typical(p, min_keep))

    def add_temp(self, temp: float):
        """Append a temperature sampler."""
        self._add_sampler(llama_cpp.llama_sampler_init_temp(temp))

    def add_temp_ext(self, t: float, delta: float, exponent: float):
        """Append an extended temperature sampler."""
        self._add_sampler(llama_cpp.llama_sampler_init_temp_ext(t, delta, exponent))

    def add_mirostat(self, n_vocab: int, seed: int, tau: float, eta: float, m: int):
        """Append a Mirostat sampler."""
        self._add_sampler(
            llama_cpp.llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m)
        )

    def add_mirostat_v2(self, seed: int, tau: float, eta: float):
        """Append a Mirostat v2 sampler."""
        self._add_sampler(llama_cpp.llama_sampler_init_mirostat_v2(seed, tau, eta))

    def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
        """Append a grammar sampler built from ``grammar`` and ``model.vocab``."""
        sampler = llama_cpp.llama_sampler_init_grammar(
            model.vocab, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8")
        )
        self._add_sampler(sampler)

    def add_penalties(
        self,
        n_vocab: int,
        special_eos_id: int,
        linefeed_id: int,
        penalty_last_n: int,
        penalty_repeat: float,
        penalty_freq: float,
        penalty_present: float,
        penalize_nl: bool,
        ignore_eos: bool,
    ):
        """Append a penalties sampler.

        Only ``penalty_last_n``, ``penalty_repeat``, ``penalty_freq`` and
        ``penalty_present`` are forwarded to llama.cpp. ``n_vocab``,
        ``special_eos_id``, ``linefeed_id``, ``penalize_nl`` and
        ``ignore_eos`` are accepted so existing callers keep working, but
        they have no effect.
        """
        sampler = llama_cpp.llama_sampler_init_penalties(
            penalty_last_n,
            penalty_repeat,
            penalty_freq,
            penalty_present,
        )
        self._add_sampler(sampler)

    def init_logit_bias(
        self, n_vocab: int, n_logit_bias, logit_bias: llama_cpp.llama_logit_bias_p
    ):
        """Append a logit-bias sampler (despite the ``init_`` name, it adds to the chain)."""
        sampler = llama_cpp.llama_sampler_init_logit_bias(
            n_vocab, n_logit_bias, logit_bias
        )
        self._add_sampler(sampler)

    def add_custom(
        self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]
    ):
        """Append a sampler whose ``apply`` step is the Python ``apply_func``."""
        custom_sampler = CustomSampler(apply_func)
        sampler = custom_sampler.get_sampler()
        self._add_sampler(sampler)
        # NOTE: Must remove custom samplers before free or llama.cpp will try to free them
        self.custom_samplers.append(
            (llama_cpp.llama_sampler_chain_n(self.sampler) - 1, custom_sampler)
        )

    def _add_sampler(self, sampler: llama_cpp.llama_sampler_p):
        """Add ``sampler`` to the chain and remember it on the Python side."""
        assert self.sampler is not None
        llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
        self.samplers.append(sampler)

    def get_seed(self) -> int:
        """Return the seed reported by ``llama_sampler_get_seed`` for the chain."""
        assert self.sampler is not None
        return llama_cpp.llama_sampler_get_seed(self.sampler)

    def sample(self, ctx: LlamaContext, idx: int) -> int:
        """Sample a token for position ``idx`` of ``ctx`` using the chain."""
        assert self.sampler is not None
        assert ctx.ctx is not None
        return llama_cpp.llama_sampler_sample(self.sampler, ctx.ctx, idx)

    def close(self):
        """Free the chain and forget all tracked samplers; safe to repeat.

        Also safe on an instance whose ``__init__`` did not complete, which
        matters because ``__del__`` calls this unconditionally.
        """
        chain = getattr(self, "sampler", None)
        if chain:
            # NOTE: Must remove custom samplers before free or llama.cpp will try to free them
            # Remove from the highest index down so earlier indices stay valid.
            for i, _ in reversed(getattr(self, "custom_samplers", [])):
                llama_cpp.llama_sampler_chain_remove(chain, i)
            llama_cpp.llama_sampler_free(chain)
        self.sampler = None
        for attr_name in ("samplers", "custom_samplers"):
            tracked = getattr(self, attr_name, None)
            if tracked is not None:
                tracked.clear()

    def __del__(self):
        self.close()
|
llama_cpp/_logger.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import ctypes
|
| 3 |
+
import logging
|
| 4 |
+
|
| 5 |
+
import llama_cpp
|
| 6 |
+
|
| 7 |
+
# enum ggml_log_level {
|
| 8 |
+
# GGML_LOG_LEVEL_NONE = 0,
|
| 9 |
+
# GGML_LOG_LEVEL_INFO = 1,
|
| 10 |
+
# GGML_LOG_LEVEL_WARN = 2,
|
| 11 |
+
# GGML_LOG_LEVEL_ERROR = 3,
|
| 12 |
+
# GGML_LOG_LEVEL_DEBUG = 4,
|
| 13 |
+
# GGML_LOG_LEVEL_CONT = 5, // continue previous log
|
| 14 |
+
# };
|
| 15 |
+
# Maps each ggml_log_level value (see the enum above) to a Python logging level.
GGML_LOG_LEVEL_TO_LOGGING_LEVEL = {
    0: logging.CRITICAL,  # GGML_LOG_LEVEL_NONE
    1: logging.INFO,  # GGML_LOG_LEVEL_INFO
    2: logging.WARNING,  # GGML_LOG_LEVEL_WARN
    3: logging.ERROR,  # GGML_LOG_LEVEL_ERROR
    4: logging.DEBUG,  # GGML_LOG_LEVEL_DEBUG
    5: logging.DEBUG,  # GGML_LOG_LEVEL_CONT
}

logger = logging.getLogger("llama-cpp-python")

# Level of the most recent message, consulted for continuation lines.
_last_log_level = GGML_LOG_LEVEL_TO_LOGGING_LEVEL[0]
|
| 27 |
+
|
| 28 |
+
# typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
|
| 29 |
+
@llama_cpp.llama_log_callback
def llama_log_callback(
    level: int,
    text: bytes,
    user_data: ctypes.c_void_p,
):
    """Write a llama.cpp log message to stderr if the logger level allows it.

    Args:
        level: ``ggml_log_level`` value of the message.
        text: Raw message bytes.
        user_data: Opaque pointer supplied at registration; unused.

    A continuation message (level 5) is filtered with the level of the
    message it continues, so it is shown or hidden together with it.
    """
    global _last_log_level
    if level == 5:
        log_level = _last_log_level
    else:
        # An unmapped level must not raise inside a C callback; treat it as
        # the most verbose level so it only appears in verbose mode.
        log_level = GGML_LOG_LEVEL_TO_LOGGING_LEVEL.get(level, logging.DEBUG)
    if logger.level <= log_level:
        # errors="replace": a malformed byte must not abort the callback.
        print(
            text.decode("utf-8", errors="replace"),
            end="",
            flush=True,
            file=sys.stderr,
        )
    _last_log_level = log_level


llama_cpp.llama_log_set(llama_log_callback, ctypes.c_void_p(0))
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def set_verbose(verbose: bool):
    """Set the module logger to DEBUG when ``verbose`` is true, else ERROR."""
    level = logging.DEBUG if verbose else logging.ERROR
    logger.setLevel(level)
|
llama_cpp/_utils.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
from typing import Any, Dict
|
| 5 |
+
|
| 6 |
+
# Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor
|
| 7 |
+
# Opened once at import time and reused by suppress_stdout_stderr.
outnull_file = open(os.devnull, "w")
errnull_file = open(os.devnull, "w")

# File descriptor numbers that suppress_stdout_stderr redirects.
STDOUT_FILENO, STDERR_FILENO = 1, 2
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class suppress_stdout_stderr(object):
    """Context manager that redirects stdout and stderr to the null device.

    Both the process-level file descriptors and ``sys.stdout``/``sys.stderr``
    are swapped on entry and restored on exit. With ``disable=True`` (the
    default) entering and leaving the context does nothing.
    """

    # NOTE: these must be "saved" here to avoid exceptions when using
    #       this context manager inside of a __del__ method
    sys = sys
    os = os

    def __init__(self, disable: bool = True):
        self.disable = disable

    # Oddly enough this works better than the contextlib version
    def __enter__(self):
        if not self.disable:
            os_mod, sys_mod = self.os, self.sys

            self.old_stdout_fileno_undup = STDOUT_FILENO
            self.old_stderr_fileno_undup = STDERR_FILENO

            # Duplicate the real descriptors so they can be restored later.
            self.old_stdout_fileno = os_mod.dup(self.old_stdout_fileno_undup)
            self.old_stderr_fileno = os_mod.dup(self.old_stderr_fileno_undup)

            self.old_stdout = sys_mod.stdout
            self.old_stderr = sys_mod.stderr

            # Point fd 1 and fd 2 at the null device.
            os_mod.dup2(outnull_file.fileno(), self.old_stdout_fileno_undup)
            os_mod.dup2(errnull_file.fileno(), self.old_stderr_fileno_undup)

            sys_mod.stdout = outnull_file
            sys_mod.stderr = errnull_file
        return self

    def __exit__(self, *_):
        if self.disable:
            return

        os_mod, sys_mod = self.os, self.sys

        sys_mod.stdout = self.old_stdout
        sys_mod.stderr = self.old_stderr

        # Restore the original descriptors, then release the duplicates.
        os_mod.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os_mod.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        os_mod.close(self.old_stdout_fileno)
        os_mod.close(self.old_stderr_fileno)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class MetaSingleton(type):
    """
    Metaclass that hands out one shared instance per class.

    The first call to a class builds the instance with the supplied
    arguments; every later call returns that instance and ignores its
    arguments.
    """

    _instances: Dict[type, Any] = {}

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        registry = cls._instances
        try:
            return registry[cls]
        except KeyError:
            pass
        registry[cls] = super().__call__(*args, **kwargs)
        return registry[cls]
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class Singleton(object, metaclass=MetaSingleton):
    """
    Convenience base class: subclasses get MetaSingleton behaviour.
    """

    def __init__(self):
        super().__init__()
|
llama_cpp/lib/libggml-base.dylib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b10bbd734ef61868c17ddad1e15a08bcb93f711475f02b8df167256483f6b80c
|
| 3 |
+
size 548128
|
llama_cpp/lib/libggml-blas.dylib
ADDED
|
Binary file (54.8 kB). View file
|
|
|
llama_cpp/lib/libggml-cpu.dylib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e842f3d817bdc4366d5b0b85a888bbe75bddd77e9cb320c9fcaacc7916cba08f
|
| 3 |
+
size 520192
|
llama_cpp/lib/libggml-metal.dylib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:965467ace598c542be5b05efa0edd6f28e8d9b3bef2d10e80f547b92cb05140c
|
| 3 |
+
size 600160
|
llama_cpp/lib/libggml.dylib
ADDED
|
Binary file (56.4 kB). View file
|
|
|
llama_cpp/lib/libllama.dylib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:46ce63cb4168539abcaa517d7023aabf9f111d56ca64b5f8aa2c6fe991ba9910
|
| 3 |
+
size 1140160
|
llama_cpp/lib/libllava.dylib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:839335da6f28ab29c340a9c61e3dffbdbde54bfb9e68e52d3bb2e7255645506e
|
| 3 |
+
size 337424
|
llama_cpp/llama.py
ADDED
|
@@ -0,0 +1,2418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import uuid
|
| 6 |
+
import time
|
| 7 |
+
import json
|
| 8 |
+
import ctypes
|
| 9 |
+
import typing
|
| 10 |
+
import random
|
| 11 |
+
import fnmatch
|
| 12 |
+
import warnings
|
| 13 |
+
import contextlib
|
| 14 |
+
import multiprocessing
|
| 15 |
+
|
| 16 |
+
from typing import (
|
| 17 |
+
Any,
|
| 18 |
+
List,
|
| 19 |
+
Literal,
|
| 20 |
+
Optional,
|
| 21 |
+
Union,
|
| 22 |
+
Generator,
|
| 23 |
+
Sequence,
|
| 24 |
+
Iterator,
|
| 25 |
+
Deque,
|
| 26 |
+
Callable,
|
| 27 |
+
Dict,
|
| 28 |
+
)
|
| 29 |
+
from collections import deque
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
from .llama_types import *
|
| 34 |
+
from .llama_grammar import LlamaGrammar
|
| 35 |
+
from .llama_cache import (
|
| 36 |
+
BaseLlamaCache,
|
| 37 |
+
LlamaCache, # type: ignore
|
| 38 |
+
LlamaDiskCache, # type: ignore
|
| 39 |
+
LlamaRAMCache, # type: ignore
|
| 40 |
+
)
|
| 41 |
+
from .llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer
|
| 42 |
+
import llama_cpp.llama_cpp as llama_cpp
|
| 43 |
+
import llama_cpp.llama_chat_format as llama_chat_format
|
| 44 |
+
|
| 45 |
+
from llama_cpp.llama_speculative import LlamaDraftModel
|
| 46 |
+
|
| 47 |
+
import numpy as np
|
| 48 |
+
import numpy.typing as npt
|
| 49 |
+
|
| 50 |
+
import llama_cpp._internals as internals
|
| 51 |
+
from ._logger import set_verbose
|
| 52 |
+
from ._utils import suppress_stdout_stderr
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class Llama:
|
| 56 |
+
"""High-level Python wrapper for a llama.cpp model."""
|
| 57 |
+
|
| 58 |
+
__backend_initialized = False
|
| 59 |
+
|
| 60 |
+
def __init__(
    self,
    model_path: str,
    *,
    # Model Params
    n_gpu_layers: int = 0,
    split_mode: int = llama_cpp.LLAMA_SPLIT_MODE_LAYER,
    main_gpu: int = 0,
    tensor_split: Optional[List[float]] = None,
    rpc_servers: Optional[str] = None,
    vocab_only: bool = False,
    use_mmap: bool = True,
    use_mlock: bool = False,
    kv_overrides: Optional[Dict[str, Union[bool, int, float, str]]] = None,
    # Context Params
    seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
    n_ctx: int = 512,
    n_batch: int = 512,
    n_ubatch: int = 512,
    n_threads: Optional[int] = None,
    n_threads_batch: Optional[int] = None,
    rope_scaling_type: Optional[
        int
    ] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
    pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED,
    rope_freq_base: float = 0.0,
    rope_freq_scale: float = 0.0,
    yarn_ext_factor: float = -1.0,
    yarn_attn_factor: float = 1.0,
    yarn_beta_fast: float = 32.0,
    yarn_beta_slow: float = 1.0,
    yarn_orig_ctx: int = 0,
    logits_all: bool = False,
    embedding: bool = False,
    offload_kqv: bool = True,
    flash_attn: bool = False,
    # Sampling Params
    no_perf: bool = False,
    last_n_tokens_size: int = 64,
    # LoRA Params
    lora_base: Optional[str] = None,
    lora_scale: float = 1.0,
    lora_path: Optional[str] = None,
    # Backend Params
    numa: Union[bool, int] = False,
    # Chat Format Params
    chat_format: Optional[str] = None,
    chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
    # Speculative Decoding
    draft_model: Optional[LlamaDraftModel] = None,
    # Tokenizer Override
    tokenizer: Optional[BaseLlamaTokenizer] = None,
    # KV cache quantization
    type_k: Optional[int] = None,
    type_v: Optional[int] = None,
    # Misc
    spm_infill: bool = False,
    verbose: bool = True,
    # Extra Params
    **kwargs,  # type: ignore
):
    """Load a llama.cpp model from `model_path`.

    Examples:
        Basic usage

        >>> import llama_cpp
        >>> model = llama_cpp.Llama(
        ...     model_path="path/to/model",
        ... )
        >>> print(model("The quick brown fox jumps ", stop=["."])["choices"][0]["text"])
        the lazy dog

        Loading a chat model

        >>> import llama_cpp
        >>> model = llama_cpp.Llama(
        ...     model_path="path/to/model",
        ...     chat_format="llama-2",
        ... )
        >>> print(model.create_chat_completion(
        ...     messages=[{
        ...         "role": "user",
        ...         "content": "what is the meaning of life?"
        ...     }]
        ... ))

    Args:
        model_path: Path to the model.
        n_gpu_layers: Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded.
        split_mode: How to split the model across GPUs. See llama_cpp.LLAMA_SPLIT_* for options.
        main_gpu: main_gpu interpretation depends on split_mode: LLAMA_SPLIT_MODE_NONE: the GPU that is used for the entire model. LLAMA_SPLIT_MODE_ROW: the GPU that is used for small tensors and intermediate results. LLAMA_SPLIT_MODE_LAYER: ignored
        tensor_split: How split tensors should be distributed across GPUs. If None, the model is not split.
        rpc_servers: Comma separated list of RPC servers to use for offloading
        vocab_only: Only load the vocabulary, no weights.
        use_mmap: Use mmap if possible. Forced to False whenever `lora_path` is given.
        use_mlock: Force the system to keep the model in RAM.
        kv_overrides: Key-value overrides for the model. Values may be bool, int, float or str; a str value must encode to at most 127 UTF-8 bytes.
        seed: RNG seed, -1 for random. A seed of 0 is falsy and is replaced by llama_cpp.LLAMA_DEFAULT_SEED.
        n_ctx: Text context, 0 = from model
        n_batch: Prompt processing maximum batch size. Clamped to `n_ctx`.
        n_ubatch: Physical batch size. Clamped to the effective `n_batch`.
        n_threads: Number of threads to use for generation. Defaults to half the CPU count (at least 1).
        n_threads_batch: Number of threads to use for batch processing. Defaults to the CPU count.
        rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
        pooling_type: Pooling type, from `enum llama_pooling_type`.
        rope_freq_base: RoPE base frequency, 0 = from model
        rope_freq_scale: RoPE frequency scaling factor, 0 = from model
        yarn_ext_factor: YaRN extrapolation mix factor, negative = from model
        yarn_attn_factor: YaRN magnitude scaling factor
        yarn_beta_fast: YaRN low correction dim
        yarn_beta_slow: YaRN high correction dim
        yarn_orig_ctx: YaRN original context size
        logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs. Forced on in the context parameters when `draft_model` is given.
        embedding: Embedding mode only.
        offload_kqv: Offload K, Q, V to GPU.
        flash_attn: Use flash attention.
        no_perf: Forwarded unchanged to `context_params.no_perf`. NOTE(review): the name suggests that True *disables* performance timing measurement; confirm against llama.cpp.
        last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
        lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
        lora_scale: Scale passed to `llama_set_adapter_lora` when the adapter from `lora_path` is applied.
        lora_path: Path to a LoRA file to apply to the model.
        numa: numa policy. True selects GGML_NUMA_STRATEGY_DISTRIBUTE, False disables NUMA, an int is used as the strategy value directly.
        chat_format: String specifying the chat format to use when calling create_chat_completion.
        chat_handler: Optional chat handler to use when calling create_chat_completion.
        draft_model: Optional draft model to use for speculative decoding.
        tokenizer: Optional tokenizer to override the default tokenizer from llama.cpp.
        verbose: Print verbose output to stderr.
        type_k: KV cache data type for K (default: f16)
        type_v: KV cache data type for V (default: f16)
        spm_infill: Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this.
        **kwargs: Accepted for call-site compatibility; not used by this constructor.

    Raises:
        ValueError: If the model path does not exist, if `tensor_split` has more
            entries than LLAMA_MAX_DEVICES, or if a `kv_overrides` value has an
            unsupported type or is too long.
        RuntimeError: If the LoRA adapter cannot be initialized or applied.

    Returns:
        A Llama instance.
    """
    self.verbose = verbose
    # Owns every native resource created below so they can be released together.
    self._stack = contextlib.ExitStack()

    set_verbose(verbose)

    # The llama.cpp backend is initialized once per process, not once per instance.
    if not Llama.__backend_initialized:
        with suppress_stdout_stderr(disable=verbose):
            llama_cpp.llama_backend_init()
        Llama.__backend_initialized = True

    if isinstance(numa, bool):
        self.numa = (
            llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE
            if numa
            else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
        )
    else:
        self.numa = numa

    if self.numa != llama_cpp.GGML_NUMA_STRATEGY_DISABLED:
        with suppress_stdout_stderr(disable=verbose):
            llama_cpp.llama_numa_init(self.numa)

    self.model_path = model_path

    # Model Params
    self.model_params = llama_cpp.llama_model_default_params()
    self.model_params.n_gpu_layers = (
        0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers
    )  # 0x7FFFFFFF is INT32 max, will be auto set to all layers
    self.model_params.split_mode = split_mode
    self.model_params.main_gpu = main_gpu
    if rpc_servers is not None:
        self.model_params.rpc_servers = rpc_servers.encode("utf-8")
        self._rpc_servers = rpc_servers
    else:
        self._rpc_servers = None
    self.tensor_split = tensor_split
    self._c_tensor_split = None
    if self.tensor_split is not None:
        if len(self.tensor_split) > llama_cpp.LLAMA_MAX_DEVICES:
            raise ValueError(
                f"Attempt to split tensors that exceed maximum supported devices. Current LLAMA_MAX_DEVICES={llama_cpp.LLAMA_MAX_DEVICES}"
            )
        # Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
        FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES
        self._c_tensor_split = FloatArray(
            *tensor_split  # type: ignore
        )  # keep a reference to the array so it is not gc'd
        self.model_params.tensor_split = self._c_tensor_split
    self.model_params.vocab_only = vocab_only
    # mmap is switched off whenever a LoRA adapter is requested.
    self.model_params.use_mmap = use_mmap if lora_path is None else False
    self.model_params.use_mlock = use_mlock

    # kv_overrides is the original python dict
    self.kv_overrides = kv_overrides
    if kv_overrides is not None:
        # _kv_overrides_array is a ctypes.Array of llama_model_kv_override Structs
        kvo_array_len = len(kv_overrides) + 1  # for sentinel element
        self._kv_overrides_array = (
            llama_cpp.llama_model_kv_override * kvo_array_len
        )()

        # Size of the fixed string slot written below. One byte is reserved
        # for the terminating NUL so the native side always sees a
        # terminated string.
        # NOTE(review): assumes val_str is char[128] in llama.h -- confirm.
        str_slot_size = 128

        for i, (k, v) in enumerate(kv_overrides.items()):
            entry = self._kv_overrides_array[i]
            entry.key = k.encode("utf-8")
            # bool must be tested before int because bool is a subclass of int.
            if isinstance(v, bool):
                entry.tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
                entry.value.val_bool = v
            elif isinstance(v, int):
                entry.tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
                entry.value.val_i64 = v
            elif isinstance(v, float):
                entry.tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
                entry.value.val_f64 = v
            elif isinstance(v, str):  # type: ignore
                v_bytes = v.encode("utf-8")
                # Reject anything that would fill the whole slot and leave
                # no room for the NUL terminator.
                if len(v_bytes) > str_slot_size - 1:
                    raise ValueError(f"Value for {k} is too long: {v}")
                v_bytes = v_bytes.ljust(str_slot_size, b"\0")
                entry.tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
                # Copy the NUL-padded bytes straight into the val_str slot
                # of the value union.
                address = typing.cast(
                    int,
                    ctypes.addressof(entry.value)
                    + llama_cpp.llama_model_kv_override_value.val_str.offset,
                )
                buffer_start = ctypes.cast(address, ctypes.POINTER(ctypes.c_char))
                ctypes.memmove(
                    buffer_start,
                    v_bytes,
                    str_slot_size,
                )
            else:
                raise ValueError(f"Unknown value type for {k}: {v}")

        self._kv_overrides_array[
            -1
        ].key = b"\0"  # ensure sentinel element is zeroed
        self.model_params.kv_overrides = self._kv_overrides_array

    # With n_ctx == 0 this is temporarily 0; it is corrected once the model
    # has been loaded and the training context size is known (see below).
    self.n_batch = min(n_ctx, n_batch)
    self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
    self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count()

    # Used by the sampler. A seed of 0 is falsy and falls back to the default.
    self._seed = seed or llama_cpp.LLAMA_DEFAULT_SEED

    # Context Params
    self.context_params = llama_cpp.llama_context_default_params()
    self.context_params.n_ctx = n_ctx
    self.context_params.n_batch = self.n_batch
    self.context_params.n_ubatch = min(self.n_batch, n_ubatch)
    self.context_params.n_threads = self.n_threads
    self.context_params.n_threads_batch = self.n_threads_batch
    self.context_params.rope_scaling_type = (
        rope_scaling_type
        if rope_scaling_type is not None
        else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
    )
    self.context_params.pooling_type = pooling_type
    # The RoPE / YaRN values are forwarded unchanged.
    self.context_params.rope_freq_base = rope_freq_base
    self.context_params.rope_freq_scale = rope_freq_scale
    self.context_params.yarn_ext_factor = yarn_ext_factor
    self.context_params.yarn_attn_factor = yarn_attn_factor
    self.context_params.yarn_beta_fast = yarn_beta_fast
    self.context_params.yarn_beta_slow = yarn_beta_slow
    self.context_params.yarn_orig_ctx = yarn_orig_ctx
    self.context_params.logits_all = (
        logits_all if draft_model is None else True
    )  # Must be set to True for speculative decoding
    self.context_params.embeddings = embedding  # TODO: Rename to embeddings
    self.context_params.offload_kqv = offload_kqv
    self.context_params.flash_attn = flash_attn
    # KV cache quantization
    if type_k is not None:
        self.context_params.type_k = type_k
    if type_v is not None:
        self.context_params.type_v = type_v
    # Sampling Params
    self.context_params.no_perf = no_perf
    self.last_n_tokens_size = last_n_tokens_size

    self.cache: Optional[BaseLlamaCache] = None

    self.lora_base = lora_base
    self.lora_scale = lora_scale
    self.lora_path = lora_path

    self.spm_infill = spm_infill

    if not os.path.exists(model_path):
        raise ValueError(f"Model path does not exist: {model_path}")

    self._model = self._stack.enter_context(
        contextlib.closing(
            internals.LlamaModel(
                path_model=self.model_path,
                params=self.model_params,
                verbose=self.verbose,
            )
        )
    )

    # Override tokenizer
    self.tokenizer_ = tokenizer or LlamaTokenizer(self)

    # Set the default value for the context and correct the batch
    if n_ctx == 0:
        n_ctx = self._model.n_ctx_train()
        self.n_batch = min(n_ctx, n_batch)
        self.context_params.n_ctx = n_ctx
        self.context_params.n_batch = self.n_batch
        self.context_params.n_ubatch = min(self.n_batch, n_ubatch)

    self._ctx = self._stack.enter_context(
        contextlib.closing(
            internals.LlamaContext(
                model=self._model,
                params=self.context_params,
                verbose=self.verbose,
            )
        )
    )

    self._batch = self._stack.enter_context(
        contextlib.closing(
            internals.LlamaBatch(
                n_tokens=self.n_batch,
                embd=0,
                n_seq_max=self.context_params.n_ctx,
                verbose=self.verbose,
            )
        )
    )

    self._lora_adapter: Optional[llama_cpp.llama_adapter_lora_p] = None

    if self.lora_path:
        self._lora_adapter = llama_cpp.llama_adapter_lora_init(
            self._model.model,
            self.lora_path.encode("utf-8"),
        )
        if self._lora_adapter is None:
            raise RuntimeError(
                f"Failed to initialize LoRA adapter from lora path: {self.lora_path}"
            )

        def free_lora_adapter():
            # Safe to call more than once: the handle is cleared after freeing.
            if self._lora_adapter is None:
                return
            llama_cpp.llama_adapter_lora_free(self._lora_adapter)
            self._lora_adapter = None

        # Registered before the adapter is applied, so it is released with
        # the rest of the stack even if applying it fails.
        self._stack.callback(free_lora_adapter)

        # A truthy return value is treated as failure.
        if llama_cpp.llama_set_adapter_lora(
            self._ctx.ctx, self._lora_adapter, self.lora_scale
        ):
            raise RuntimeError(
                f"Failed to set LoRA adapter from lora path: {self.lora_path}"
            )

    if self.verbose:
        print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)

    self.chat_format = chat_format
    self.chat_handler = chat_handler
    self._chat_handlers: Dict[
        str, llama_chat_format.LlamaChatCompletionHandler
    ] = {}

    self.draft_model = draft_model

    self._n_vocab = self.n_vocab()
    self._n_ctx = self.n_ctx()

    self._token_nl = self.token_nl()
    self._token_eos = self.token_eos()

    self._candidates = internals.LlamaTokenDataArray(n_vocab=self._n_vocab)

    self.n_tokens = 0
    self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
    # One row per context position when all logits are requested, otherwise
    # one row per batch slot.
    self.scores: npt.NDArray[np.single] = np.ndarray(
        (n_ctx if logits_all else n_batch, self._n_vocab), dtype=np.single
    )

    self._mirostat_mu = ctypes.c_float(
        2.0 * 5.0
    )  # TODO: Move this to sampling context

    # Metadata is best effort: a failure here must not prevent model use.
    try:
        self.metadata = self._model.metadata()
    except Exception as e:
        self.metadata = {}
        if self.verbose:
            print(f"Failed to load metadata: {e}", file=sys.stderr)

    if self.verbose:
        print(f"Model metadata: {self.metadata}", file=sys.stderr)

    eos_token_id = self.token_eos()
    bos_token_id = self.token_bos()

    eos_token = (
        self._model.token_get_text(eos_token_id) if eos_token_id != -1 else ""
    )
    bos_token = (
        self._model.token_get_text(bos_token_id) if bos_token_id != -1 else ""
    )

    # Unfortunately the llama.cpp API does not return metadata arrays, so we can't get template names from tokenizer.chat_templates
    # Keys keep everything after the "tokenizer." prefix, e.g.
    # "tokenizer.chat_template.foo" -> "chat_template.foo".
    metadata_prefix = "tokenizer."
    template_choices = {
        name[len(metadata_prefix):]: template
        for name, template in self.metadata.items()
        if name.startswith("tokenizer.chat_template.")
    }

    if "tokenizer.chat_template" in self.metadata:
        template_choices["chat_template.default"] = self.metadata[
            "tokenizer.chat_template"
        ]

    if self.verbose and template_choices:
        print(
            f"Available chat formats from metadata: {', '.join(template_choices.keys())}",
            file=sys.stderr,
        )

    for name, template in template_choices.items():
        self._chat_handlers[name] = llama_chat_format.Jinja2ChatFormatter(
            template=template,
            eos_token=eos_token,
            bos_token=bos_token,
            stop_token_ids=[eos_token_id],
        ).to_chat_handler()

    # When the caller chose neither a format nor a handler, prefer a format
    # guessed from the metadata, then the model's own default template.
    if (
        self.chat_format is None
        and self.chat_handler is None
        and "chat_template.default" in template_choices
    ):
        chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
            self.metadata
        )

        if chat_format is not None:
            self.chat_format = chat_format
            if self.verbose:
                print(f"Guessed chat format: {chat_format}", file=sys.stderr)
        else:
            if self.verbose:
                print(
                    f"Using gguf chat template: {template_choices['chat_template.default']}",
                    file=sys.stderr,
                )
                print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
                print(f"Using chat bos_token: {bos_token}", file=sys.stderr)

            self.chat_format = "chat_template.default"

    # Last resort when nothing else selected a chat format.
    if self.chat_format is None and self.chat_handler is None:
        self.chat_format = "llama-2"
        if self.verbose:
            print(
                f"Using fallback chat format: {self.chat_format}", file=sys.stderr
            )

    self._sampler = None
|
| 546 |
+
|
| 547 |
+
@property
def ctx(self) -> llama_cpp.llama_context_p:
    """Return the ``ctx`` attribute of the wrapped ``self._ctx`` object."""
    context_wrapper = self._ctx
    return context_wrapper.ctx
|
| 550 |
+
|
| 551 |
+
@property
def model(self) -> llama_cpp.llama_model_p:
    """Return the ``model`` attribute of the wrapped ``self._model`` object."""
    model_wrapper = self._model
    return model_wrapper.model
|
| 554 |
+
|
| 555 |
+
@property
def _input_ids(self) -> npt.NDArray[np.intc]:
    """Return the first ``self.n_tokens`` entries of ``self.input_ids``."""
    used = self.n_tokens
    return self.input_ids[:used]
|
| 558 |
+
|
| 559 |
+
@property
def _scores(self) -> npt.NDArray[np.single]:
    """Return the first ``self.n_tokens`` rows of ``self.scores`` (all columns)."""
    used_rows = self.n_tokens
    return self.scores[:used_rows, :]
|
| 562 |
+
|
| 563 |
+
@property
def eval_tokens(self) -> Deque[int]:
    """Return the evaluated token ids as a deque bounded by ``self._n_ctx``.

    A new deque is built on every access from the first ``self.n_tokens``
    entries of ``self.input_ids``.
    """
    evaluated = self.input_ids[: self.n_tokens].tolist()
    return deque(evaluated, maxlen=self._n_ctx)
|
| 566 |
+
|
| 567 |
+
@property
def eval_logits(self) -> Deque[List[float]]:
    """Return the stored logit rows as a deque of Python lists.

    The rows are the first ``self.n_tokens`` rows of ``self.scores``. The deque
    is bounded by ``self._n_ctx`` when ``context_params.logits_all`` is set and
    by 1 otherwise, so in the latter case only the last row is retained.
    """
    rows = self.scores[: self.n_tokens, :].tolist()
    if self.context_params.logits_all:
        capacity = self._n_ctx
    else:
        capacity = 1
    return deque(rows, maxlen=capacity)
|
| 573 |
+
|
| 574 |
+
def tokenize(
    self, text: bytes, add_bos: bool = True, special: bool = False
) -> List[int]:
    """Convert encoded text into token ids via ``self.tokenizer_``.

    Args:
        text: The utf-8 encoded bytes to tokenize.
        add_bos: Whether a beginning-of-sequence token should be added.
        special: Whether special tokens should be tokenized.

    Raises:
        RuntimeError: If the tokenization failed.

    Returns:
        The token ids produced by the tokenizer.
    """
    tokenizer = self.tokenizer_
    token_ids = tokenizer.tokenize(text, add_bos, special)
    return token_ids
|
| 591 |
+
|
| 592 |
+
def detokenize(
    self,
    tokens: List[int],
    prev_tokens: Optional[List[int]] = None,
    special: bool = False,
) -> bytes:
    """Convert token ids back into bytes via ``self.tokenizer_``.

    Args:
        tokens: The token ids to detokenize.
        prev_tokens: The tokens that preceded ``tokens``. Offset mapping will
            be performed if provided.
        special: Whether special tokens should be detokenized.

    Returns:
        The detokenized bytes.
    """
    tokenizer = self.tokenizer_
    decoded = tokenizer.detokenize(tokens, prev_tokens=prev_tokens, special=special)
    return decoded
|
| 611 |
+
|
| 612 |
+
def set_cache(self, cache: Optional[BaseLlamaCache]):
    """Install ``cache`` as this instance's cache.

    Args:
        cache: The cache object to store on ``self.cache``; may be ``None``.
    """
    setattr(self, "cache", cache)
|
| 619 |
+
|
| 620 |
+
def set_seed(self, seed: int):
    """Store the random seed on ``self._seed``.

    Args:
        seed: The random seed.
    """
    setattr(self, "_seed", seed)
|
| 627 |
+
|
| 628 |
+
def reset(self):
    """Reset the model state by setting the evaluated-token count to zero."""
    setattr(self, "n_tokens", 0)
|
| 631 |
+
|
| 632 |
+
def eval(self, tokens: Sequence[int]):
    """Decode ``tokens`` in chunks of at most ``self.n_batch`` tokens.

    Each chunk is copied into ``self.input_ids`` starting at the current
    ``self.n_tokens``, and ``self.n_tokens`` grows by the chunk length. Logits
    are copied into ``self.scores`` only when ``context_params.logits_all`` is
    set.

    Args:
        tokens: The token ids to evaluate.
    """
    # NOTE(review): presumably drops cached KV entries from position n_tokens
    # onward for every sequence -- confirm against the llama.cpp KV-cache API.
    self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)

    step = self.n_batch
    for start in range(0, len(tokens), step):
        # Slicing clamps at the end of the sequence, so the last chunk may be short.
        chunk = tokens[start : start + step]
        first_pos = self.n_tokens
        chunk_len = len(chunk)

        self._batch.set_batch(
            batch=chunk, n_past=first_pos, logits_all=self.context_params.logits_all
        )
        self._ctx.decode(self._batch)

        # Save tokens
        self.input_ids[first_pos : first_pos + chunk_len] = chunk

        # Save logits
        # NOTE: Now that sampling is done inside the sampler, logits are only
        # needed for logprobs which requires logits_all
        if self.context_params.logits_all:
            flat_logits = np.ctypeslib.as_array(
                self._ctx.get_logits(), shape=(chunk_len * self._n_vocab,)
            )
            self.scores[first_pos : first_pos + chunk_len, :].reshape(-1)[::] = flat_logits

        # Update n_tokens
        self.n_tokens += chunk_len
|
| 668 |
+
|
| 669 |
+
def _init_sampler(
    self,
    top_k: int = 40,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    temp: float = 0.80,
    repeat_penalty: float = 1.0,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_eta: float = 0.1,
    mirostat_tau: float = 5.0,
    penalize_nl: bool = True,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
):
    """Build a new ``internals.LlamaSampler`` from the given settings.

    Stages are added in this order: optional custom logits-processor stage,
    penalties, optional grammar, then exactly one final strategy:

    * ``temp < 0``: softmax followed by seeded distribution sampling;
    * ``temp == 0``: greedy;
    * ``mirostat_mode == 1`` or ``2``: the matching mirostat sampler;
    * otherwise: top-k, typical, top-p, min-p, temperature, then seeded
      distribution sampling.

    ``tfs_z`` is accepted but not used by this method.

    Returns:
        The configured sampler.
    """
    chain = internals.LlamaSampler()

    if logits_processor is not None:
        # Create and add a custom sampler
        def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
            candidates = token_data_array.contents
            n_candidates = candidates.size
            base_address = ctypes.addressof(candidates.data.contents)
            record_dtype = np.dtype(
                [("id", np.intc), ("logit", np.single), ("p", np.single)],
                align=True,
            )
            # NOTE: This is probably broken
            view = np.recarray(
                shape=(n_candidates,),
                dtype=record_dtype,
                buf=(llama_cpp.llama_token_data * n_candidates).from_address(
                    base_address
                ),
            )
            # Each processor receives the evaluated token ids and the current
            # logits; its result is written back into the candidate array.
            for processor in logits_processor:
                view.logit[:] = processor(self._input_ids, view.logit)

        chain.add_custom(apply_func)

    chain.add_penalties(
        n_vocab=self._n_vocab,
        special_eos_id=self._token_eos,
        linefeed_id=self._token_nl,
        penalty_last_n=self.last_n_tokens_size,
        penalty_repeat=repeat_penalty,
        penalty_freq=frequency_penalty,
        penalty_present=presence_penalty,
        penalize_nl=penalize_nl,
        ignore_eos=False,
    )

    if grammar is not None:
        chain.add_grammar(self._model, grammar)

    # Negative temperature: sample from the softmax distribution.
    if temp < 0.0:
        chain.add_softmax()
        chain.add_dist(self._seed)
        return chain

    # Zero temperature: always pick the best candidate.
    if temp == 0.0:
        chain.add_greedy()
        return chain

    if mirostat_mode == 1:
        mirostat_m = 100
        chain.add_mirostat(
            self._n_vocab,
            self._seed,
            mirostat_tau,
            mirostat_eta,
            mirostat_m,
        )
        return chain

    if mirostat_mode == 2:
        chain.add_mirostat_v2(
            self._seed,
            mirostat_tau,
            mirostat_eta,
        )
        return chain

    n_probs = 0
    min_keep = max(1, n_probs)
    chain.add_top_k(top_k)
    chain.add_typical(typical_p, min_keep)
    chain.add_top_p(top_p, min_keep)
    chain.add_min_p(min_p, min_keep)
    chain.add_temp(temp)
    chain.add_dist(self._seed)
    return chain
|
| 757 |
+
|
| 758 |
+
def sample(
    self,
    top_k: int = 40,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    temp: float = 0.80,
    repeat_penalty: float = 1.0,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_eta: float = 0.1,
    mirostat_tau: float = 5.0,
    penalize_nl: bool = True,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    idx: Optional[int] = None,
):
    """Sample a token from the model.

    If a sampler is already installed on ``self._sampler`` it is used as-is and
    the sampling arguments below are ignored. Otherwise a temporary sampler is
    built from them with ``self._init_sampler`` and removed again before this
    method returns or raises.

    Args:
        top_k: The top-k sampling parameter.
        top_p: The top-p sampling parameter.
        min_p: The min-p sampling parameter.
        typical_p: The typical-p sampling parameter.
        temp: The temperature parameter.
        repeat_penalty: The repeat penalty parameter.
        frequency_penalty: The frequency penalty parameter.
        presence_penalty: The presence penalty parameter.
        tfs_z: Forwarded to ``_init_sampler``.
        mirostat_mode: The mirostat mode forwarded to ``_init_sampler``.
        mirostat_eta: The mirostat eta parameter.
        mirostat_tau: The mirostat tau parameter.
        penalize_nl: Whether newline tokens are penalized.
        logits_processor: Optional logits processors applied before sampling.
        grammar: Optional grammar constraining the sampled token.
        idx: Position to sample at, counted from the start of the evaluated
            tokens. It is passed to the sampler as ``idx - self.n_tokens``;
            ``None`` is passed as ``-1``.

    Returns:
        The sampled token.
    """
    assert self.n_tokens > 0

    tmp_sampler = False

    if self._sampler is None:
        tmp_sampler = True
        self._sampler = self._init_sampler(
            top_k=top_k,
            top_p=top_p,
            min_p=min_p,
            typical_p=typical_p,
            temp=temp,
            repeat_penalty=repeat_penalty,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
            penalize_nl=penalize_nl,
            logits_processor=logits_processor,
            grammar=grammar,
        )

    try:
        ridx = idx - self.n_tokens if idx is not None else -1

        assert self.ctx is not None
        return self._sampler.sample(self._ctx, ridx)
    finally:
        # Always drop the temporary sampler, even when sampling fails, so a
        # later call does not silently reuse settings from this one.
        if tmp_sampler:
            self._sampler = None
|
| 819 |
+
|
| 820 |
+
def generate(
    self,
    tokens: Sequence[int],
    top_k: int = 40,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    temp: float = 0.80,
    repeat_penalty: float = 1.0,
    reset: bool = True,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    penalize_nl: bool = True,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    grammar: Optional[LlamaGrammar] = None,
) -> Generator[int, Optional[Sequence[int]], None]:
    """Create a generator of tokens from a prompt.

    The generator never stops on its own except through ``stopping_criteria``;
    the caller decides when to stop iterating. A caller may ``send()`` a
    sequence of extra tokens into the generator; they are evaluated right after
    the token that was just yielded.

    Examples:
        >>> llama = Llama("models/ggml-7b.bin")
        >>> tokens = llama.tokenize(b"Hello, world!")
        >>> for token in llama.generate(tokens, top_k=40, top_p=0.95, temp=1.0, repeat_penalty=1.0):
        ...     print(llama.detokenize([token]))

    Args:
        tokens: The prompt tokens.
        top_k: The top-k sampling parameter.
        top_p: The top-p sampling parameter.
        min_p: The min-p sampling parameter.
        typical_p: The typical-p sampling parameter.
        temp: The temperature parameter.
        repeat_penalty: The repeat penalty parameter.
        reset: Whether to reset the model state. When the already-evaluated
            tokens share a prefix with ``tokens``, that prefix is reused
            instead of resetting.
        frequency_penalty: The frequency penalty parameter.
        presence_penalty: The presence penalty parameter.
        tfs_z: Forwarded to the sampler setup.
        mirostat_mode: The mirostat mode.
        mirostat_tau: The mirostat tau parameter.
        mirostat_eta: The mirostat eta parameter.
        penalize_nl: Whether newline tokens are penalized.
        logits_processor: Optional logits processors applied before sampling.
        stopping_criteria: Optional callable receiving the token ids so far and
            a row of scores; generation ends when it returns a truthy value.
        grammar: Optional grammar constraining the sampled tokens.

    Yields:
        The generated tokens.
    """
    # Reset mirostat sampling
    self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau)
    # Install a sampler for the whole generation so that self.sample() reuses
    # it instead of building a temporary one for every token.
    self._sampler = self._init_sampler(
        top_k=top_k,
        top_p=top_p,
        min_p=min_p,
        typical_p=typical_p,
        temp=temp,
        repeat_penalty=repeat_penalty,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        penalize_nl=penalize_nl,
        logits_processor=logits_processor,
        grammar=grammar,
    )

    # Check for kv cache prefix match
    if reset and self.n_tokens > 0:
        longest_prefix = 0
        # The last prompt token is excluded from the comparison, so at least
        # one token is always left to evaluate after a prefix hit.
        for a, b in zip(self._input_ids, tokens[:-1]):
            if a == b:
                longest_prefix += 1
            else:
                break
        if longest_prefix > 0:
            # Keep the matching prefix and evaluate only the remaining tokens.
            reset = False
            tokens = tokens[longest_prefix:]
            self.n_tokens = longest_prefix
            if self.verbose:
                print(
                    f"Llama.generate: {longest_prefix} prefix-match hit, "
                    f"remaining {len(tokens)} prompt tokens to eval",
                    file=sys.stderr,
                )

    # Reset the model state
    if reset:
        self.reset()

    # # Reset the grammar
    # if grammar is not None:
    #     grammar.reset()

    # Position of the last prompt token: the first position to sample from.
    sample_idx = self.n_tokens + len(tokens) - 1
    # Work on a private, mutable copy; it is reused as the pending-token buffer.
    tokens = list(tokens)

    # Eval and sample
    while True:
        self.eval(tokens)
        # More than one position can be pending here when draft tokens were
        # appended to the previous evaluation.
        while sample_idx < self.n_tokens:
            token = self.sample(
                top_k=top_k,
                top_p=top_p,
                min_p=min_p,
                typical_p=typical_p,
                temp=temp,
                repeat_penalty=repeat_penalty,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                tfs_z=tfs_z,
                mirostat_mode=mirostat_mode,
                mirostat_tau=mirostat_tau,
                mirostat_eta=mirostat_eta,
                logits_processor=logits_processor,
                grammar=grammar,
                penalize_nl=penalize_nl,
                idx=sample_idx,
            )

            sample_idx += 1
            if stopping_criteria is not None and stopping_criteria(
                self._input_ids[: sample_idx], self._scores[sample_idx - self.n_tokens, :]
            ):
                return
            # The caller may send() extra tokens to evaluate after this one.
            tokens_or_none = yield token
            tokens.clear()
            tokens.append(token)
            if tokens_or_none is not None:
                tokens.extend(tokens_or_none)

            # The sampled token differs from the already-evaluated token at
            # this position: discard everything from here on and re-evaluate.
            if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]:
                self.n_tokens = sample_idx
                self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
                break

        if self.draft_model is not None:
            # Stage the pending tokens after the evaluated ones so the draft
            # model sees the full sequence, then append its proposals, clipped
            # to the space left in the context window.
            self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
            draft_tokens = self.draft_model(
                self.input_ids[: self.n_tokens + len(tokens)]
            )
            tokens.extend(
                draft_tokens.astype(int)[
                    : self._n_ctx - self.n_tokens - len(tokens)
                ]
            )
|
| 959 |
+
|
| 960 |
+
def create_embedding(
    self, input: Union[str, List[str]], model: Optional[str] = None
) -> CreateEmbeddingResponse:
    """Embed one or more strings and wrap the result in a response dict.

    Args:
        input: A string, or a list of strings, to embed.
        model: Name reported in the response; defaults to ``self.model_path``.

    Returns:
        An embedding response containing one entry per input string plus token
        usage counts.
    """
    if model is None:
        model_name: str = self.model_path
    else:
        model_name = model

    texts = input if isinstance(input, list) else [input]

    # get numeric embeddings
    vectors, token_count = self.embed(texts, return_count=True)  # type: ignore

    # convert to CreateEmbeddingResponse
    data: List[Embedding] = []
    for position, vector in enumerate(vectors):
        data.append(
            {
                "object": "embedding",
                "embedding": vector,
                "index": position,
            }
        )

    usage = {
        "prompt_tokens": token_count,
        "total_tokens": token_count,
    }
    return {
        "object": "list",
        "data": data,
        "model": model_name,
        "usage": usage,
    }
|
| 999 |
+
|
| 1000 |
+
def embed(
    self,
    input: Union[str, List[str]],
    normalize: bool = False,
    truncate: bool = True,
    return_count: bool = False,
):
    """Embed a string or a list of strings.

    Inputs are tokenized and packed into batches of at most ``self.n_batch``
    tokens, with each input as its own sequence. The KV cache is cleared before
    every batch and again at the end, and the model state is reset afterwards.

    Args:
        input: A string, or a list of strings, to embed.
        normalize: Whether to pass each embedding through
            ``internals.normalize_embedding``.
        truncate: Whether to cut each input down to ``self.n_batch`` tokens.
        return_count: Whether to also return the total number of tokens.

    Raises:
        RuntimeError: If the context was not created with embeddings enabled.
        ValueError: If ``truncate`` is false and an input has more tokens than
            ``self.n_batch``.

    Returns:
        For a string input, its embedding; for a list input, one embedding per
        item (an empty list for an empty input list). With pooling type NONE
        an embedding is a list of per-token vectors, otherwise a single vector.
        When ``return_count`` is true, a ``(embeddings, total_tokens)`` tuple.
    """
    n_embd = self.n_embd()
    n_batch = self.n_batch

    # get pooling information
    pooling_type = self.pooling_type()
    logits_all = pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE

    if self.context_params.embeddings is False:
        raise RuntimeError(
            "Llama model must be created with embedding=True to call this method"
        )

    if self.verbose:
        llama_cpp.llama_perf_context_reset(self._ctx.ctx)

    if isinstance(input, str):
        inputs = [input]
    else:
        inputs = input

    # reset batch
    self._batch.reset()

    # decode and fetch embeddings
    data: Union[List[List[float]], List[List[List[float]]]] = []

    def decode_batch(seq_sizes: List[int]):
        llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
        self._ctx.decode(self._batch)
        self._batch.reset()

        # store embeddings
        if pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE:
            # llama_get_embeddings exposes one flat buffer holding n_embd
            # floats per output token, in batch order. `pos` counts tokens, so
            # it must be scaled by n_embd before being used as a float offset.
            ptr = llama_cpp.llama_get_embeddings(self._ctx.ctx)
            pos: int = 0
            for size in seq_sizes:
                embedding: List[List[float]] = [
                    ptr[(pos + j) * n_embd : (pos + j + 1) * n_embd]
                    for j in range(size)
                ]
                if normalize:
                    embedding = [
                        internals.normalize_embedding(e) for e in embedding
                    ]
                data.append(embedding)
                pos += size
        else:
            # Pooled embeddings are fetched per sequence id.
            for seq_id in range(len(seq_sizes)):
                ptr = llama_cpp.llama_get_embeddings_seq(self._ctx.ctx, seq_id)
                embedding: List[float] = ptr[:n_embd]
                if normalize:
                    embedding = internals.normalize_embedding(embedding)
                data.append(embedding)

    # init state
    total_tokens = 0
    s_batch = []  # token count of each sequence in the pending batch
    t_batch = 0  # total tokens in the pending batch
    p_batch = 0  # sequence id for the next sequence in the pending batch

    # accumulate batches and encode
    for text in inputs:
        tokens = self.tokenize(text.encode("utf-8"))
        if truncate:
            tokens = tokens[:n_batch]

        n_tokens = len(tokens)
        total_tokens += n_tokens

        # check for overrun
        if n_tokens > n_batch:
            raise ValueError(
                f"Requested tokens ({n_tokens}) exceed batch size of {n_batch}"
            )

        # time to eval batch
        if t_batch + n_tokens > n_batch:
            decode_batch(s_batch)
            s_batch = []
            t_batch = 0
            p_batch = 0

        # add to batch
        self._batch.add_sequence(tokens, p_batch, logits_all)

        # update batch stats
        s_batch.append(n_tokens)
        t_batch += n_tokens
        p_batch += 1

    # handle last batch; skipped for an empty input list, which would
    # otherwise decode an empty batch
    if s_batch:
        decode_batch(s_batch)

    if self.verbose:
        llama_cpp.llama_perf_context_print(self._ctx.ctx)

    output = data[0] if isinstance(input, str) else data

    llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
    self.reset()

    if return_count:
        return output, total_tokens
    else:
        return output
|
| 1120 |
+
|
| 1121 |
+
def _create_completion(
|
| 1122 |
+
self,
|
| 1123 |
+
prompt: Union[str, List[int]],
|
| 1124 |
+
suffix: Optional[str] = None,
|
| 1125 |
+
max_tokens: Optional[int] = 16,
|
| 1126 |
+
temperature: float = 0.8,
|
| 1127 |
+
top_p: float = 0.95,
|
| 1128 |
+
min_p: float = 0.05,
|
| 1129 |
+
typical_p: float = 1.0,
|
| 1130 |
+
logprobs: Optional[int] = None,
|
| 1131 |
+
echo: bool = False,
|
| 1132 |
+
stop: Optional[Union[str, List[str]]] = [],
|
| 1133 |
+
frequency_penalty: float = 0.0,
|
| 1134 |
+
presence_penalty: float = 0.0,
|
| 1135 |
+
repeat_penalty: float = 1.0,
|
| 1136 |
+
top_k: int = 40,
|
| 1137 |
+
stream: bool = False,
|
| 1138 |
+
seed: Optional[int] = None,
|
| 1139 |
+
tfs_z: float = 1.0,
|
| 1140 |
+
mirostat_mode: int = 0,
|
| 1141 |
+
mirostat_tau: float = 5.0,
|
| 1142 |
+
mirostat_eta: float = 0.1,
|
| 1143 |
+
model: Optional[str] = None,
|
| 1144 |
+
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
| 1145 |
+
logits_processor: Optional[LogitsProcessorList] = None,
|
| 1146 |
+
grammar: Optional[LlamaGrammar] = None,
|
| 1147 |
+
logit_bias: Optional[Dict[int, float]] = None,
|
| 1148 |
+
) -> Union[
|
| 1149 |
+
Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse]
|
| 1150 |
+
]:
|
| 1151 |
+
assert suffix is None or suffix.__class__ is str
|
| 1152 |
+
|
| 1153 |
+
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
|
| 1154 |
+
created: int = int(time.time())
|
| 1155 |
+
bos_token_id: int = self.token_bos()
|
| 1156 |
+
cls_token_id: int = self._model.token_cls()
|
| 1157 |
+
sep_token_id: int = self._model.token_sep()
|
| 1158 |
+
prefix_token_id: int = 0 # self._model.token_prefix() # TODO: Fix
|
| 1159 |
+
middle_token_id: int = 0 # self._model.token_middle() # TODO: Fix
|
| 1160 |
+
suffix_token_id: int = 0 # self._model.token_suffix() # TODO: Fix
|
| 1161 |
+
add_space_prefix: bool = (
|
| 1162 |
+
self.metadata.get("tokenizer.ggml.add_space_prefix", "true") == "true"
|
| 1163 |
+
)
|
| 1164 |
+
bos_tokens: List[int] = [cls_token_id if cls_token_id != -1 else bos_token_id]
|
| 1165 |
+
eos_tokens: List[int] = [
|
| 1166 |
+
sep_token_id if sep_token_id != -1 else self.token_eos()
|
| 1167 |
+
]
|
| 1168 |
+
|
| 1169 |
+
if (
|
| 1170 |
+
(isinstance(prompt, list) and suffix is None)
|
| 1171 |
+
or not self._model.add_bos_token()
|
| 1172 |
+
or bos_tokens[:1] == [-1]
|
| 1173 |
+
):
|
| 1174 |
+
bos_tokens = []
|
| 1175 |
+
|
| 1176 |
+
if (isinstance(prompt, list) and suffix is None) or (
|
| 1177 |
+
not self._model.add_eos_token() and sep_token_id == -1
|
| 1178 |
+
):
|
| 1179 |
+
eos_tokens = []
|
| 1180 |
+
|
| 1181 |
+
suffix_space_prefix: int = 0
|
| 1182 |
+
# Tokenizer hack to remove leading space
|
| 1183 |
+
if add_space_prefix and suffix_token_id >= 0 and suffix:
|
| 1184 |
+
suffix = "☺" + suffix
|
| 1185 |
+
suffix_space_prefix = 2
|
| 1186 |
+
|
| 1187 |
+
# If prompt is empty, initialize completion with BOS token to avoid
|
| 1188 |
+
# detokenization including a space at the beginning of the completion
|
| 1189 |
+
completion_tokens: List[int] = [] if len(prompt) > 0 else [bos_token_id]
|
| 1190 |
+
# Add blank space to start of prompt to match OG llama tokenizer
|
| 1191 |
+
prefix_tokens: List[int] = (
|
| 1192 |
+
[prefix_token_id] if prefix_token_id >= 0 and suffix is not None else []
|
| 1193 |
+
) + (
|
| 1194 |
+
(
|
| 1195 |
+
self.tokenize(
|
| 1196 |
+
prompt.encode("utf-8"),
|
| 1197 |
+
add_bos=False,
|
| 1198 |
+
special=(prefix_token_id < 0 or suffix is None),
|
| 1199 |
+
)
|
| 1200 |
+
if prompt != ""
|
| 1201 |
+
else []
|
| 1202 |
+
)
|
| 1203 |
+
if isinstance(prompt, str)
|
| 1204 |
+
else prompt
|
| 1205 |
+
)
|
| 1206 |
+
suffix_tokens: List[int] = (
|
| 1207 |
+
(
|
| 1208 |
+
[suffix_token_id]
|
| 1209 |
+
+ (
|
| 1210 |
+
self.tokenize(suffix.encode("utf-8"), add_bos=False, special=False)[
|
| 1211 |
+
suffix_space_prefix:
|
| 1212 |
+
]
|
| 1213 |
+
if suffix
|
| 1214 |
+
else []
|
| 1215 |
+
)
|
| 1216 |
+
)
|
| 1217 |
+
if suffix_token_id >= 0 and suffix is not None
|
| 1218 |
+
else []
|
| 1219 |
+
)
|
| 1220 |
+
middle_tokens: List[int] = (
|
| 1221 |
+
[middle_token_id] if middle_token_id >= 0 and suffix is not None else []
|
| 1222 |
+
)
|
| 1223 |
+
prompt_tokens: List[int] = (
|
| 1224 |
+
bos_tokens
|
| 1225 |
+
+ (
|
| 1226 |
+
(suffix_tokens + prefix_tokens + middle_tokens)
|
| 1227 |
+
if self.spm_infill
|
| 1228 |
+
else (prefix_tokens + suffix_tokens + middle_tokens)
|
| 1229 |
+
)
|
| 1230 |
+
+ eos_tokens
|
| 1231 |
+
)
|
| 1232 |
+
text: bytes = b""
|
| 1233 |
+
returned_tokens: int = 0
|
| 1234 |
+
stop = (
|
| 1235 |
+
stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
|
| 1236 |
+
)
|
| 1237 |
+
model_name: str = model if model is not None else self.model_path
|
| 1238 |
+
|
| 1239 |
+
if prompt_tokens[:2] == [self.token_bos()] * 2:
|
| 1240 |
+
warnings.warn(
|
| 1241 |
+
f'Detected duplicate leading "{self._model.token_get_text(self.token_bos())}" in prompt, this will likely reduce response quality, consider removing it...',
|
| 1242 |
+
RuntimeWarning,
|
| 1243 |
+
)
|
| 1244 |
+
|
| 1245 |
+
# NOTE: This likely doesn't work correctly for the first token in the prompt
|
| 1246 |
+
# because of the extra space added to the start of the prompt_tokens
|
| 1247 |
+
if logit_bias is not None:
|
| 1248 |
+
logit_bias_map = {int(k): float(v) for k, v in logit_bias.items()}
|
| 1249 |
+
|
| 1250 |
+
def logit_bias_processor(
|
| 1251 |
+
input_ids: npt.NDArray[np.intc],
|
| 1252 |
+
scores: npt.NDArray[np.single],
|
| 1253 |
+
) -> npt.NDArray[np.single]:
|
| 1254 |
+
new_scores = np.copy(
|
| 1255 |
+
scores
|
| 1256 |
+
) # Does it make sense to copy the whole array or can we just overwrite the original one?
|
| 1257 |
+
for input_id, score in logit_bias_map.items():
|
| 1258 |
+
new_scores[input_id] = score + scores[input_id]
|
| 1259 |
+
return new_scores
|
| 1260 |
+
|
| 1261 |
+
_logit_bias_processor = LogitsProcessorList([logit_bias_processor])
|
| 1262 |
+
if logits_processor is None:
|
| 1263 |
+
logits_processor = _logit_bias_processor
|
| 1264 |
+
else:
|
| 1265 |
+
logits_processor = logits_processor.extend(_logit_bias_processor)
|
| 1266 |
+
|
| 1267 |
+
if self.verbose:
|
| 1268 |
+
self._ctx.reset_timings()
|
| 1269 |
+
|
| 1270 |
+
if len(prompt_tokens) >= self._n_ctx:
|
| 1271 |
+
raise ValueError(
|
| 1272 |
+
f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
|
| 1273 |
+
)
|
| 1274 |
+
|
| 1275 |
+
if max_tokens is None or max_tokens <= 0:
|
| 1276 |
+
# Unlimited, depending on n_ctx.
|
| 1277 |
+
max_tokens = self._n_ctx - len(prompt_tokens)
|
| 1278 |
+
|
| 1279 |
+
# Truncate max_tokens if requested tokens would exceed the context window
|
| 1280 |
+
max_tokens = (
|
| 1281 |
+
max_tokens
|
| 1282 |
+
if max_tokens + len(prompt_tokens) < self._n_ctx
|
| 1283 |
+
else (self._n_ctx - len(prompt_tokens))
|
| 1284 |
+
)
|
| 1285 |
+
|
| 1286 |
+
if stop != []:
|
| 1287 |
+
stop_sequences = [s.encode("utf-8") for s in stop]
|
| 1288 |
+
else:
|
| 1289 |
+
stop_sequences = []
|
| 1290 |
+
|
| 1291 |
+
if logprobs is not None and self.context_params.logits_all is False:
|
| 1292 |
+
raise ValueError(
|
| 1293 |
+
"logprobs is not supported for models created with logits_all=False"
|
| 1294 |
+
)
|
| 1295 |
+
|
| 1296 |
+
if self.cache:
|
| 1297 |
+
try:
|
| 1298 |
+
cache_item = self.cache[prompt_tokens]
|
| 1299 |
+
cache_prefix_len = Llama.longest_token_prefix(
|
| 1300 |
+
cache_item.input_ids.tolist(), prompt_tokens
|
| 1301 |
+
)
|
| 1302 |
+
eval_prefix_len = Llama.longest_token_prefix(
|
| 1303 |
+
self._input_ids.tolist(), prompt_tokens
|
| 1304 |
+
)
|
| 1305 |
+
if cache_prefix_len > eval_prefix_len:
|
| 1306 |
+
self.load_state(cache_item)
|
| 1307 |
+
if self.verbose:
|
| 1308 |
+
print("Llama._create_completion: cache hit", file=sys.stderr)
|
| 1309 |
+
except KeyError:
|
| 1310 |
+
if self.verbose:
|
| 1311 |
+
print("Llama._create_completion: cache miss", file=sys.stderr)
|
| 1312 |
+
|
| 1313 |
+
if seed is not None:
|
| 1314 |
+
self.set_seed(seed)
|
| 1315 |
+
else:
|
| 1316 |
+
self.set_seed(random.Random(self._seed).randint(0, 2 ** 32))
|
| 1317 |
+
|
| 1318 |
+
finish_reason = "length"
|
| 1319 |
+
multibyte_fix = 0
|
| 1320 |
+
for token in self.generate(
|
| 1321 |
+
prompt_tokens,
|
| 1322 |
+
top_k=top_k,
|
| 1323 |
+
top_p=top_p,
|
| 1324 |
+
min_p=min_p,
|
| 1325 |
+
typical_p=typical_p,
|
| 1326 |
+
temp=temperature,
|
| 1327 |
+
tfs_z=tfs_z,
|
| 1328 |
+
mirostat_mode=mirostat_mode,
|
| 1329 |
+
mirostat_tau=mirostat_tau,
|
| 1330 |
+
mirostat_eta=mirostat_eta,
|
| 1331 |
+
frequency_penalty=frequency_penalty,
|
| 1332 |
+
presence_penalty=presence_penalty,
|
| 1333 |
+
repeat_penalty=repeat_penalty,
|
| 1334 |
+
stopping_criteria=stopping_criteria,
|
| 1335 |
+
logits_processor=logits_processor,
|
| 1336 |
+
grammar=grammar,
|
| 1337 |
+
):
|
| 1338 |
+
if llama_cpp.llama_token_is_eog(self._model.vocab, token):
|
| 1339 |
+
text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
|
| 1340 |
+
finish_reason = "stop"
|
| 1341 |
+
break
|
| 1342 |
+
|
| 1343 |
+
completion_tokens.append(token)
|
| 1344 |
+
|
| 1345 |
+
all_text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
|
| 1346 |
+
|
| 1347 |
+
# Contains multi-byte UTF8
|
| 1348 |
+
for k, char in enumerate(all_text[-3:]):
|
| 1349 |
+
k = 3 - k
|
| 1350 |
+
for num, pattern in [(2, 192), (3, 224), (4, 240)]:
|
| 1351 |
+
# Bitwise AND check
|
| 1352 |
+
if num > k and pattern & char == pattern:
|
| 1353 |
+
multibyte_fix = num - k
|
| 1354 |
+
|
| 1355 |
+
# Stop incomplete bytes from passing
|
| 1356 |
+
if multibyte_fix > 0:
|
| 1357 |
+
multibyte_fix -= 1
|
| 1358 |
+
continue
|
| 1359 |
+
|
| 1360 |
+
any_stop = [s for s in stop_sequences if s in all_text]
|
| 1361 |
+
if len(any_stop) > 0:
|
| 1362 |
+
first_stop = any_stop[0]
|
| 1363 |
+
text = all_text[: all_text.index(first_stop)]
|
| 1364 |
+
finish_reason = "stop"
|
| 1365 |
+
break
|
| 1366 |
+
|
| 1367 |
+
if stream:
|
| 1368 |
+
remaining_tokens = completion_tokens[returned_tokens:]
|
| 1369 |
+
remaining_text = self.detokenize(
|
| 1370 |
+
remaining_tokens,
|
| 1371 |
+
prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
|
| 1372 |
+
)
|
| 1373 |
+
remaining_length = len(remaining_text)
|
| 1374 |
+
|
| 1375 |
+
# We want to avoid yielding any characters from
|
| 1376 |
+
# the generated text if they are part of a stop
|
| 1377 |
+
# sequence.
|
| 1378 |
+
first_stop_position = 0
|
| 1379 |
+
for s in stop_sequences:
|
| 1380 |
+
for i in range(min(len(s), remaining_length), 0, -1):
|
| 1381 |
+
if remaining_text.endswith(s[:i]):
|
| 1382 |
+
if i > first_stop_position:
|
| 1383 |
+
first_stop_position = i
|
| 1384 |
+
break
|
| 1385 |
+
|
| 1386 |
+
token_end_position = 0
|
| 1387 |
+
|
| 1388 |
+
if logprobs is not None:
|
| 1389 |
+
# not sure how to handle this branch when dealing
|
| 1390 |
+
# with CJK output, so keep it unchanged
|
| 1391 |
+
for token in remaining_tokens:
|
| 1392 |
+
if token == bos_token_id:
|
| 1393 |
+
continue
|
| 1394 |
+
token_end_position += len(
|
| 1395 |
+
self.detokenize(
|
| 1396 |
+
[token],
|
| 1397 |
+
prev_tokens=prompt_tokens
|
| 1398 |
+
+ completion_tokens[:returned_tokens],
|
| 1399 |
+
)
|
| 1400 |
+
)
|
| 1401 |
+
# Check if stop sequence is in the token
|
| 1402 |
+
if token_end_position > (
|
| 1403 |
+
remaining_length - first_stop_position
|
| 1404 |
+
):
|
| 1405 |
+
break
|
| 1406 |
+
token_str = self.detokenize(
|
| 1407 |
+
[token],
|
| 1408 |
+
prev_tokens=prompt_tokens
|
| 1409 |
+
+ completion_tokens[:returned_tokens],
|
| 1410 |
+
).decode("utf-8", errors="ignore")
|
| 1411 |
+
text_offset = len(prompt) + len(
|
| 1412 |
+
self.detokenize(
|
| 1413 |
+
completion_tokens[:returned_tokens],
|
| 1414 |
+
prev_tokens=prompt_tokens
|
| 1415 |
+
+ completion_tokens[:returned_tokens],
|
| 1416 |
+
).decode("utf-8", errors="ignore")
|
| 1417 |
+
)
|
| 1418 |
+
token_offset = len(prompt_tokens) + returned_tokens
|
| 1419 |
+
logits = self._scores[token_offset - 1, :]
|
| 1420 |
+
current_logprobs = Llama.logits_to_logprobs(logits).tolist()
|
| 1421 |
+
sorted_logprobs = list(
|
| 1422 |
+
sorted(
|
| 1423 |
+
zip(current_logprobs, range(len(current_logprobs))),
|
| 1424 |
+
reverse=True,
|
| 1425 |
+
)
|
| 1426 |
+
)
|
| 1427 |
+
top_logprob = {
|
| 1428 |
+
self.detokenize([i]).decode(
|
| 1429 |
+
"utf-8", errors="ignore"
|
| 1430 |
+
): logprob
|
| 1431 |
+
for logprob, i in sorted_logprobs[:logprobs]
|
| 1432 |
+
}
|
| 1433 |
+
top_logprob.update({token_str: current_logprobs[int(token)]})
|
| 1434 |
+
logprobs_or_none = {
|
| 1435 |
+
"tokens": [
|
| 1436 |
+
self.detokenize(
|
| 1437 |
+
[token],
|
| 1438 |
+
prev_tokens=prompt_tokens
|
| 1439 |
+
+ completion_tokens[:returned_tokens],
|
| 1440 |
+
).decode("utf-8", errors="ignore")
|
| 1441 |
+
],
|
| 1442 |
+
"text_offset": [text_offset],
|
| 1443 |
+
"token_logprobs": [current_logprobs[int(token)]],
|
| 1444 |
+
"top_logprobs": [top_logprob],
|
| 1445 |
+
}
|
| 1446 |
+
returned_tokens += 1
|
| 1447 |
+
yield {
|
| 1448 |
+
"id": completion_id,
|
| 1449 |
+
"object": "text_completion",
|
| 1450 |
+
"created": created,
|
| 1451 |
+
"model": model_name,
|
| 1452 |
+
"choices": [
|
| 1453 |
+
{
|
| 1454 |
+
"text": self.detokenize(
|
| 1455 |
+
[token],
|
| 1456 |
+
prev_tokens=prompt_tokens
|
| 1457 |
+
+ completion_tokens[:returned_tokens],
|
| 1458 |
+
).decode("utf-8", errors="ignore"),
|
| 1459 |
+
"index": 0,
|
| 1460 |
+
"logprobs": logprobs_or_none,
|
| 1461 |
+
"finish_reason": None,
|
| 1462 |
+
}
|
| 1463 |
+
],
|
| 1464 |
+
}
|
| 1465 |
+
else:
|
| 1466 |
+
while len(remaining_tokens) > 0:
|
| 1467 |
+
decode_success = False
|
| 1468 |
+
for i in range(1, len(remaining_tokens) + 1):
|
| 1469 |
+
try:
|
| 1470 |
+
bs = self.detokenize(
|
| 1471 |
+
remaining_tokens[:i],
|
| 1472 |
+
prev_tokens=prompt_tokens
|
| 1473 |
+
+ completion_tokens[:returned_tokens],
|
| 1474 |
+
)
|
| 1475 |
+
ts = bs.decode("utf-8")
|
| 1476 |
+
decode_success = True
|
| 1477 |
+
break
|
| 1478 |
+
except UnicodeError:
|
| 1479 |
+
pass
|
| 1480 |
+
else:
|
| 1481 |
+
break
|
| 1482 |
+
if not decode_success:
|
| 1483 |
+
# all remaining tokens cannot be decoded to a UTF-8 character
|
| 1484 |
+
break
|
| 1485 |
+
token_end_position += len(bs)
|
| 1486 |
+
if token_end_position > (
|
| 1487 |
+
remaining_length - first_stop_position
|
| 1488 |
+
):
|
| 1489 |
+
break
|
| 1490 |
+
remaining_tokens = remaining_tokens[i:]
|
| 1491 |
+
returned_tokens += i
|
| 1492 |
+
|
| 1493 |
+
yield {
|
| 1494 |
+
"id": completion_id,
|
| 1495 |
+
"object": "text_completion",
|
| 1496 |
+
"created": created,
|
| 1497 |
+
"model": model_name,
|
| 1498 |
+
"choices": [
|
| 1499 |
+
{
|
| 1500 |
+
"text": ts,
|
| 1501 |
+
"index": 0,
|
| 1502 |
+
"logprobs": None,
|
| 1503 |
+
"finish_reason": None,
|
| 1504 |
+
}
|
| 1505 |
+
],
|
| 1506 |
+
}
|
| 1507 |
+
|
| 1508 |
+
if len(completion_tokens) >= max_tokens:
|
| 1509 |
+
text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
|
| 1510 |
+
finish_reason = "length"
|
| 1511 |
+
break
|
| 1512 |
+
|
| 1513 |
+
if stopping_criteria is not None and stopping_criteria(
|
| 1514 |
+
self._input_ids, self._scores[-1, :]
|
| 1515 |
+
):
|
| 1516 |
+
text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
|
| 1517 |
+
finish_reason = "stop"
|
| 1518 |
+
|
| 1519 |
+
if self.verbose:
|
| 1520 |
+
self._ctx.print_timings()
|
| 1521 |
+
|
| 1522 |
+
if stream:
|
| 1523 |
+
remaining_tokens = completion_tokens[returned_tokens:]
|
| 1524 |
+
remaining_text = self.detokenize(
|
| 1525 |
+
remaining_tokens,
|
| 1526 |
+
prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
|
| 1527 |
+
)
|
| 1528 |
+
any_stop = [s for s in stop_sequences if s in remaining_text]
|
| 1529 |
+
if len(any_stop) > 0:
|
| 1530 |
+
end = min(remaining_text.index(stop) for stop in any_stop)
|
| 1531 |
+
else:
|
| 1532 |
+
end = len(remaining_text)
|
| 1533 |
+
|
| 1534 |
+
token_end_position = 0
|
| 1535 |
+
for token in remaining_tokens:
|
| 1536 |
+
token_end_position += len(
|
| 1537 |
+
self.detokenize(
|
| 1538 |
+
[token],
|
| 1539 |
+
prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
|
| 1540 |
+
)
|
| 1541 |
+
)
|
| 1542 |
+
|
| 1543 |
+
logprobs_or_none: Optional[CompletionLogprobs] = None
|
| 1544 |
+
if logprobs is not None:
|
| 1545 |
+
if token == bos_token_id:
|
| 1546 |
+
continue
|
| 1547 |
+
token_str = self.detokenize([token]).decode(
|
| 1548 |
+
"utf-8", errors="ignore"
|
| 1549 |
+
)
|
| 1550 |
+
text_offset = len(prompt) + len(
|
| 1551 |
+
self.detokenize(
|
| 1552 |
+
completion_tokens[:returned_tokens],
|
| 1553 |
+
prev_tokens=prompt_tokens
|
| 1554 |
+
+ completion_tokens[:returned_tokens],
|
| 1555 |
+
)
|
| 1556 |
+
)
|
| 1557 |
+
token_offset = len(prompt_tokens) + returned_tokens - 1
|
| 1558 |
+
logits = self._scores[token_offset, :]
|
| 1559 |
+
current_logprobs = Llama.logits_to_logprobs(logits).tolist()
|
| 1560 |
+
sorted_logprobs = list(
|
| 1561 |
+
sorted(
|
| 1562 |
+
zip(current_logprobs, range(len(current_logprobs))),
|
| 1563 |
+
reverse=True,
|
| 1564 |
+
)
|
| 1565 |
+
)
|
| 1566 |
+
top_logprob = {
|
| 1567 |
+
self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
|
| 1568 |
+
for logprob, i in sorted_logprobs[:logprobs]
|
| 1569 |
+
}
|
| 1570 |
+
top_logprob.update({token_str: current_logprobs[int(token)]})
|
| 1571 |
+
logprobs_or_none = {
|
| 1572 |
+
"tokens": [
|
| 1573 |
+
self.detokenize([token]).decode("utf-8", errors="ignore")
|
| 1574 |
+
],
|
| 1575 |
+
"text_offset": [text_offset],
|
| 1576 |
+
"token_logprobs": [current_logprobs[int(token)]],
|
| 1577 |
+
"top_logprobs": [top_logprob],
|
| 1578 |
+
}
|
| 1579 |
+
|
| 1580 |
+
if token_end_position >= end:
|
| 1581 |
+
last_text = self.detokenize([token])
|
| 1582 |
+
if token_end_position == end - 1:
|
| 1583 |
+
break
|
| 1584 |
+
returned_tokens += 1
|
| 1585 |
+
yield {
|
| 1586 |
+
"id": completion_id,
|
| 1587 |
+
"object": "text_completion",
|
| 1588 |
+
"created": created,
|
| 1589 |
+
"model": model_name,
|
| 1590 |
+
"choices": [
|
| 1591 |
+
{
|
| 1592 |
+
"text": last_text[
|
| 1593 |
+
: len(last_text) - (token_end_position - end)
|
| 1594 |
+
].decode("utf-8", errors="ignore"),
|
| 1595 |
+
"index": 0,
|
| 1596 |
+
"logprobs": logprobs_or_none,
|
| 1597 |
+
"finish_reason": None,
|
| 1598 |
+
}
|
| 1599 |
+
],
|
| 1600 |
+
}
|
| 1601 |
+
break
|
| 1602 |
+
returned_tokens += 1
|
| 1603 |
+
yield {
|
| 1604 |
+
"id": completion_id,
|
| 1605 |
+
"object": "text_completion",
|
| 1606 |
+
"created": created,
|
| 1607 |
+
"model": model_name,
|
| 1608 |
+
"choices": [
|
| 1609 |
+
{
|
| 1610 |
+
"text": self.detokenize([token]).decode(
|
| 1611 |
+
"utf-8", errors="ignore"
|
| 1612 |
+
),
|
| 1613 |
+
"index": 0,
|
| 1614 |
+
"logprobs": logprobs_or_none,
|
| 1615 |
+
"finish_reason": None,
|
| 1616 |
+
}
|
| 1617 |
+
],
|
| 1618 |
+
}
|
| 1619 |
+
yield {
|
| 1620 |
+
"id": completion_id,
|
| 1621 |
+
"object": "text_completion",
|
| 1622 |
+
"created": created,
|
| 1623 |
+
"model": model_name,
|
| 1624 |
+
"choices": [
|
| 1625 |
+
{
|
| 1626 |
+
"text": "",
|
| 1627 |
+
"index": 0,
|
| 1628 |
+
"logprobs": None,
|
| 1629 |
+
"finish_reason": finish_reason,
|
| 1630 |
+
}
|
| 1631 |
+
],
|
| 1632 |
+
}
|
| 1633 |
+
if self.cache:
|
| 1634 |
+
if self.verbose:
|
| 1635 |
+
print("Llama._create_completion: cache save", file=sys.stderr)
|
| 1636 |
+
self.cache[prompt_tokens + completion_tokens] = self.save_state()
|
| 1637 |
+
if self.verbose:
|
| 1638 |
+
print("Llama._create_completion: cache saved", file=sys.stderr)
|
| 1639 |
+
return
|
| 1640 |
+
|
| 1641 |
+
if self.cache:
|
| 1642 |
+
if self.verbose:
|
| 1643 |
+
print("Llama._create_completion: cache save", file=sys.stderr)
|
| 1644 |
+
self.cache[prompt_tokens + completion_tokens] = self.save_state()
|
| 1645 |
+
|
| 1646 |
+
text_str = text.decode("utf-8", errors="ignore")
|
| 1647 |
+
|
| 1648 |
+
if echo:
|
| 1649 |
+
text_str = prompt + text_str
|
| 1650 |
+
|
| 1651 |
+
if suffix_token_id < 0 and suffix is not None:
|
| 1652 |
+
text_str = text_str + suffix
|
| 1653 |
+
|
| 1654 |
+
logprobs_or_none: Optional[CompletionLogprobs] = None
|
| 1655 |
+
if logprobs is not None:
|
| 1656 |
+
text_offset = 0 if echo else len(prompt)
|
| 1657 |
+
token_offset = 0 if echo else len(prompt_tokens[1:])
|
| 1658 |
+
text_offsets: List[int] = []
|
| 1659 |
+
token_logprobs: List[Optional[float]] = []
|
| 1660 |
+
tokens: List[str] = []
|
| 1661 |
+
top_logprobs: List[Optional[Dict[str, float]]] = []
|
| 1662 |
+
|
| 1663 |
+
if echo:
|
| 1664 |
+
# Remove leading BOS token if exists
|
| 1665 |
+
all_tokens = (
|
| 1666 |
+
prompt_tokens[1 if prompt_tokens[0] == self.token_bos() else 0 :]
|
| 1667 |
+
+ completion_tokens
|
| 1668 |
+
)
|
| 1669 |
+
else:
|
| 1670 |
+
all_tokens = completion_tokens
|
| 1671 |
+
|
| 1672 |
+
all_token_strs = [
|
| 1673 |
+
self.detokenize([token], prev_tokens=all_tokens[:i]).decode(
|
| 1674 |
+
"utf-8", errors="ignore"
|
| 1675 |
+
)
|
| 1676 |
+
for i, token in enumerate(all_tokens)
|
| 1677 |
+
]
|
| 1678 |
+
all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
|
| 1679 |
+
# TODO: may be able to change this loop to use np.take_along_dim
|
| 1680 |
+
for idx, (token, token_str, logprobs_token) in enumerate(
|
| 1681 |
+
zip(all_tokens, all_token_strs, all_logprobs)
|
| 1682 |
+
):
|
| 1683 |
+
if token == bos_token_id:
|
| 1684 |
+
continue
|
| 1685 |
+
text_offsets.append(
|
| 1686 |
+
text_offset
|
| 1687 |
+
+ len(
|
| 1688 |
+
self.detokenize(all_tokens[:idx]).decode(
|
| 1689 |
+
"utf-8", errors="ignore"
|
| 1690 |
+
)
|
| 1691 |
+
)
|
| 1692 |
+
)
|
| 1693 |
+
tokens.append(token_str)
|
| 1694 |
+
sorted_logprobs = list(
|
| 1695 |
+
sorted(
|
| 1696 |
+
zip(logprobs_token, range(len(logprobs_token))), reverse=True
|
| 1697 |
+
)
|
| 1698 |
+
)
|
| 1699 |
+
token_logprobs.append(logprobs_token[int(token)])
|
| 1700 |
+
top_logprob: Optional[Dict[str, float]] = {
|
| 1701 |
+
self.detokenize([i], prev_tokens=all_tokens[:idx]).decode(
|
| 1702 |
+
"utf-8", errors="ignore"
|
| 1703 |
+
): logprob
|
| 1704 |
+
for logprob, i in sorted_logprobs[:logprobs]
|
| 1705 |
+
}
|
| 1706 |
+
top_logprob.update({token_str: logprobs_token[int(token)]})
|
| 1707 |
+
top_logprobs.append(top_logprob)
|
| 1708 |
+
# Weird idiosyncrasy of the OpenAI API where
|
| 1709 |
+
# token_logprobs and top_logprobs are null for
|
| 1710 |
+
# the first token.
|
| 1711 |
+
if echo and len(all_tokens) > 0:
|
| 1712 |
+
token_logprobs[0] = None
|
| 1713 |
+
top_logprobs[0] = None
|
| 1714 |
+
logprobs_or_none = {
|
| 1715 |
+
"tokens": tokens,
|
| 1716 |
+
"text_offset": text_offsets,
|
| 1717 |
+
"token_logprobs": token_logprobs,
|
| 1718 |
+
"top_logprobs": top_logprobs,
|
| 1719 |
+
}
|
| 1720 |
+
|
| 1721 |
+
yield {
|
| 1722 |
+
"id": completion_id,
|
| 1723 |
+
"object": "text_completion",
|
| 1724 |
+
"created": created,
|
| 1725 |
+
"model": model_name,
|
| 1726 |
+
"choices": [
|
| 1727 |
+
{
|
| 1728 |
+
"text": text_str,
|
| 1729 |
+
"index": 0,
|
| 1730 |
+
"logprobs": logprobs_or_none,
|
| 1731 |
+
"finish_reason": finish_reason,
|
| 1732 |
+
}
|
| 1733 |
+
],
|
| 1734 |
+
"usage": {
|
| 1735 |
+
"prompt_tokens": len(prompt_tokens),
|
| 1736 |
+
"completion_tokens": len(completion_tokens),
|
| 1737 |
+
"total_tokens": len(prompt_tokens) + len(completion_tokens),
|
| 1738 |
+
},
|
| 1739 |
+
}
|
| 1740 |
+
|
| 1741 |
+
def create_completion(
    self,
    prompt: Union[str, List[int]],
    suffix: Optional[str] = None,
    max_tokens: Optional[int] = 16,
    temperature: float = 0.8,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    logprobs: Optional[int] = None,
    echo: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    repeat_penalty: float = 1.0,
    top_k: int = 40,
    stream: bool = False,
    seed: Optional[int] = None,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    logit_bias: Optional[Dict[int, float]] = None,
) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
    """Produce a text completion for a prompt, either whole or as a stream.

    Public wrapper around ``_create_completion``. Every argument is forwarded
    by keyword; the only translation done here is turning ``max_tokens=None``
    into ``-1``.

    Args:
        prompt: The prompt to complete, as text or as a list of token ids.
        suffix: Suffix to append to the generated text. With ``None`` no
            suffix is appended.
        max_tokens: Maximum number of tokens to generate. ``None`` or any
            value <= 0 means no explicit limit; the amount generated then
            depends on ``n_ctx``.
        temperature: Sampling temperature.
        top_p: Top-p value for nucleus sampling, as described in "The Curious
            Case of Neural Text Degeneration"
            (https://arxiv.org/abs/1904.09751).
        min_p: Min-p value for minimum-p sampling, as described in
            https://github.com/ggerganov/llama.cpp/pull/3841.
        typical_p: Typical-p value for locally typical sampling
            (https://arxiv.org/abs/2202.00666).
        logprobs: How many logprobs to return. ``None`` returns none.
        echo: Whether the prompt is echoed back in the returned text.
        stop: A string, or list of strings, that ends generation when
            encountered.
        frequency_penalty: Penalty applied to tokens based on their frequency
            in the prompt.
        presence_penalty: Penalty applied to tokens based on their presence
            in the prompt.
        repeat_penalty: Penalty applied to repeated tokens.
        top_k: Top-k value for sampling, as described in "The Curious Case of
            Neural Text Degeneration" (https://arxiv.org/abs/1904.09751).
        stream: Whether to return an iterator of chunks instead of a single
            response.
        seed: Seed used for sampling.
        tfs_z: Tail-free sampling parameter
            (https://www.trentonbricken.com/Tail-Free-Sampling/).
        mirostat_mode: Mirostat sampling mode.
        mirostat_tau: Target cross-entropy (surprise) for the generated text.
            Higher values give more surprising, less predictable text; lower
            values give more predictable text.
        mirostat_eta: Learning rate used to update ``mu`` from the error
            between target and observed surprisal. Larger values update
            ``mu`` faster, smaller values more slowly.
        model: Name to report for the model in the completion object.
        stopping_criteria: A list of stopping criteria to apply.
        logits_processor: A list of logits processors to apply.
        grammar: Grammar used for constrained sampling.
        logit_bias: Logit bias to apply.

    Raises:
        ValueError: If the requested tokens exceed the context window.
        RuntimeError: If the prompt fails to tokenize or the model fails to
            evaluate the prompt.

    Returns:
        With ``stream=False``, the first item produced by
        ``_create_completion`` (the completed response object). With
        ``stream=True``, the chunk iterator itself, handed back without
        being advanced.
    """
    # The worker expects a non-positive value, not None, to mean "no limit".
    token_limit = max_tokens if max_tokens is not None else -1

    produced = self._create_completion(
        prompt=prompt,
        suffix=suffix,
        max_tokens=token_limit,
        temperature=temperature,
        top_p=top_p,
        min_p=min_p,
        typical_p=typical_p,
        logprobs=logprobs,
        echo=echo,
        stop=stop,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        repeat_penalty=repeat_penalty,
        top_k=top_k,
        stream=stream,
        seed=seed,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        model=model,
        stopping_criteria=stopping_criteria,
        logits_processor=logits_processor,
        grammar=grammar,
        logit_bias=logit_bias,
    )

    if not stream:
        # Non-streaming callers want one object: pull the first yielded item.
        result: Completion = next(produced)  # type: ignore
        return result

    # Streaming callers consume the chunks themselves.
    chunk_iter: Iterator[CreateCompletionStreamResponse] = produced
    return chunk_iter
|
| 1837 |
+
|
| 1838 |
+
def __call__(
    self,
    prompt: str,
    suffix: Optional[str] = None,
    max_tokens: Optional[int] = 16,
    temperature: float = 0.8,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    logprobs: Optional[int] = None,
    echo: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    repeat_penalty: float = 1.0,
    top_k: int = 40,
    stream: bool = False,
    seed: Optional[int] = None,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    logit_bias: Optional[Dict[int, float]] = None,
) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
    """Generate text from a prompt by calling the instance directly.

    Calling the object is equivalent to calling ``create_completion`` with
    the same arguments: every parameter is passed through by keyword and the
    result is returned untouched.

    Args:
        prompt: The prompt to generate text from.
        suffix: Suffix to append to the generated text. With ``None`` no
            suffix is appended.
        max_tokens: Maximum number of tokens to generate. ``None`` or any
            value <= 0 means no explicit limit; the amount generated then
            depends on ``n_ctx``.
        temperature: Sampling temperature.
        top_p: Top-p value for nucleus sampling, as described in "The Curious
            Case of Neural Text Degeneration"
            (https://arxiv.org/abs/1904.09751).
        min_p: Min-p value for minimum-p sampling, as described in
            https://github.com/ggerganov/llama.cpp/pull/3841.
        typical_p: Typical-p value for locally typical sampling
            (https://arxiv.org/abs/2202.00666).
        logprobs: How many logprobs to return. ``None`` returns none.
        echo: Whether the prompt is echoed back in the returned text.
        stop: A string, or list of strings, that ends generation when
            encountered.
        frequency_penalty: Penalty applied to tokens based on their frequency
            in the prompt.
        presence_penalty: Penalty applied to tokens based on their presence
            in the prompt.
        repeat_penalty: Penalty applied to repeated tokens.
        top_k: Top-k value for sampling, as described in "The Curious Case of
            Neural Text Degeneration" (https://arxiv.org/abs/1904.09751).
        stream: Whether to stream the results.
        seed: Seed used for sampling.
        tfs_z: Tail-free sampling parameter
            (https://www.trentonbricken.com/Tail-Free-Sampling/).
        mirostat_mode: Mirostat sampling mode.
        mirostat_tau: Target cross-entropy (surprise) for the generated text.
            Higher values give more surprising, less predictable text; lower
            values give more predictable text.
        mirostat_eta: Learning rate used to update ``mu`` from the error
            between target and observed surprisal. Larger values update
            ``mu`` faster, smaller values more slowly.
        model: Name to report for the model in the completion object.
        stopping_criteria: A list of stopping criteria to apply.
        logits_processor: A list of logits processors to apply.
        grammar: Grammar used for constrained sampling.
        logit_bias: Logit bias to apply.

    Raises:
        ValueError: If the requested tokens exceed the context window.
        RuntimeError: If the prompt fails to tokenize or the model fails to
            evaluate the prompt.

    Returns:
        Whatever ``create_completion`` returns for these arguments: a
        response object, or an iterator of chunks when ``stream`` is true.
    """
    return self.create_completion(
        # What to complete and how the result is shaped.
        prompt=prompt,
        suffix=suffix,
        echo=echo,
        stream=stream,
        model=model,
        logprobs=logprobs,
        # When to stop.
        max_tokens=max_tokens,
        stop=stop,
        stopping_criteria=stopping_criteria,
        # Sampling controls.
        seed=seed,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        min_p=min_p,
        typical_p=typical_p,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        # Penalties and logit shaping.
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        repeat_penalty=repeat_penalty,
        logits_processor=logits_processor,
        logit_bias=logit_bias,
        grammar=grammar,
    )
|
| 1929 |
+
|
| 1930 |
+
def create_chat_completion(
    self,
    messages: List[ChatCompletionRequestMessage],
    functions: Optional[List[ChatCompletionFunction]] = None,
    function_call: Optional[ChatCompletionRequestFunctionCall] = None,
    tools: Optional[List[ChatCompletionTool]] = None,
    tool_choice: Optional[ChatCompletionToolChoiceOption] = None,
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    stream: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    seed: Optional[int] = None,
    response_format: Optional[ChatCompletionRequestResponseFormat] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repeat_penalty: float = 1.0,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    logit_bias: Optional[Dict[int, float]] = None,
    logprobs: Optional[bool] = None,
    top_logprobs: Optional[int] = None,
) -> Union[
    CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
]:
    """Generate a chat completion from a list of messages.

    The call is delegated to a chat handler chosen in this order:
    ``self.chat_handler``, the handler registered for ``self.chat_format`` in
    ``self._chat_handlers``, and finally the handler returned by
    ``llama_chat_format.get_chat_completion_handler(self.chat_format)``.

    Args:
        messages: A list of messages to generate a response for.
        functions: A list of functions to use for the chat completion.
        function_call: A function call to use for the chat completion.
        tools: A list of tools to use for the chat completion.
        tool_choice: A tool choice to use for the chat completion.
        temperature: The temperature to use for sampling.
        top_p: The top-p value to use for nucleus sampling. Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        top_k: The top-k value to use for sampling. Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        min_p: The min-p value to use for minimum p sampling. Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
        typical_p: The typical-p value to use for sampling. Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
        stream: Whether to stream the results.
        stop: A list of strings to stop generation when encountered.
        seed: The seed to use for sampling.
        response_format: The response format to use for the chat completion. Use { "type": "json_object" } to constrain output to only valid json.
        max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None, the maximum number of tokens to generate is unlimited and depends on n_ctx.
        presence_penalty: The penalty to apply to tokens based on their presence in the prompt.
        frequency_penalty: The penalty to apply to tokens based on their frequency in the prompt.
        repeat_penalty: The penalty to apply to repeated tokens.
        tfs_z: The tail-free sampling parameter.
        mirostat_mode: The mirostat sampling mode.
        mirostat_tau: The mirostat sampling tau parameter.
        mirostat_eta: The mirostat sampling eta parameter.
        model: The name to use for the model in the completion object.
        logits_processor: A list of logits processors to use.
        grammar: A grammar to use.
        logit_bias: A logit bias to use.
        logprobs: Forwarded unchanged to the chat handler (presumably toggles returning token log probabilities; confirm against the handler).
        top_logprobs: Forwarded unchanged to the chat handler (presumably the number of top log probabilities per token; confirm against the handler).

    Returns:
        Generated chat completion or a stream of chat completion chunks.
    """
    handler = (
        self.chat_handler
        or self._chat_handlers.get(self.chat_format)
        or llama_chat_format.get_chat_completion_handler(self.chat_format)
    )
    # The ``[]`` default in the signature is one list object shared by every
    # call. Forward a copy so a handler that mutates ``stop`` can neither leak
    # state into later calls nor modify a list owned by the caller.
    if isinstance(stop, list):
        stop = list(stop)
    return handler(
        llama=self,
        messages=messages,
        functions=functions,
        function_call=function_call,
        tools=tools,
        tool_choice=tool_choice,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        min_p=min_p,
        typical_p=typical_p,
        logprobs=logprobs,
        top_logprobs=top_logprobs,
        stream=stream,
        stop=stop,
        seed=seed,
        response_format=response_format,
        max_tokens=max_tokens,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        repeat_penalty=repeat_penalty,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        model=model,
        logits_processor=logits_processor,
        grammar=grammar,
        logit_bias=logit_bias,
    )
|
| 2032 |
+
|
| 2033 |
+
def create_chat_completion_openai_v1(
    self,
    *args: Any,
    **kwargs: Any,
):
    """Generate a chat completion with return type based on the OpenAI v1 API.

    OpenAI python package is required to use this method.

    You can install it with `pip install openai`.

    Args:
        *args: Positional arguments to pass to create_chat_completion.
        **kwargs: Keyword arguments to pass to create_chat_completion.

    Returns:
        An ``openai`` ``ChatCompletion`` built from the result of
        ``create_chat_completion``, or, when ``stream=True`` is passed as a
        keyword argument, a generator of ``ChatCompletionChunk`` objects.

    Raises:
        ImportError: If the ``openai`` package cannot be imported.
    """
    # Guard only the import: an ImportError raised while generating the
    # completion must surface as-is instead of being reported as a missing
    # ``openai`` installation.
    try:
        from openai.types.chat import ChatCompletion, ChatCompletionChunk
    except ImportError:
        raise ImportError(
            "To use create_chat_completion_openai_v1, you must install the openai package. "
            "You can install it with `pip install openai`."
        )

    # Streaming is only detected when ``stream`` is given as a keyword.
    stream = kwargs.get("stream", False)  # type: ignore
    assert isinstance(stream, bool)
    if stream:
        return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs))  # type: ignore
    return ChatCompletion(**self.create_chat_completion(*args, **kwargs))  # type: ignore
|
| 2065 |
+
|
| 2066 |
+
def __getstate__(self):
    """Return the keyword arguments used to rebuild this instance when unpickling.

    The mapping is read from attributes on ``self``, ``self.model_params`` and
    ``self.context_params``; ``__setstate__`` passes it back to ``__init__``.
    """
    return {
        "model_path": self.model_path,
        # Model Params
        "n_gpu_layers": self.model_params.n_gpu_layers,
        "split_mode": self.model_params.split_mode,
        "main_gpu": self.model_params.main_gpu,
        "tensor_split": self.tensor_split,
        "vocab_only": self.model_params.vocab_only,
        "use_mmap": self.model_params.use_mmap,
        "use_mlock": self.model_params.use_mlock,
        "kv_overrides": self.kv_overrides,
        # Context Params
        "seed": self._seed,
        "n_ctx": self.context_params.n_ctx,
        "n_batch": self.n_batch,
        "n_ubatch": self.context_params.n_ubatch,
        "n_threads": self.context_params.n_threads,
        "n_threads_batch": self.context_params.n_threads_batch,
        "rope_scaling_type": self.context_params.rope_scaling_type,
        "pooling_type": self.context_params.pooling_type,
        "rope_freq_base": self.context_params.rope_freq_base,
        "rope_freq_scale": self.context_params.rope_freq_scale,
        "yarn_ext_factor": self.context_params.yarn_ext_factor,
        "yarn_attn_factor": self.context_params.yarn_attn_factor,
        "yarn_beta_fast": self.context_params.yarn_beta_fast,
        "yarn_beta_slow": self.context_params.yarn_beta_slow,
        "yarn_orig_ctx": self.context_params.yarn_orig_ctx,
        "logits_all": self.context_params.logits_all,
        # The constructor keyword is ``embedding``; the context field is ``embeddings``.
        "embedding": self.context_params.embeddings,
        "offload_kqv": self.context_params.offload_kqv,
        "flash_attn": self.context_params.flash_attn,
        # Sampling Params
        "no_perf": self.context_params.no_perf,
        "last_n_tokens_size": self.last_n_tokens_size,
        # LoRA Params
        "lora_base": self.lora_base,
        "lora_scale": self.lora_scale,
        "lora_path": self.lora_path,
        # Backend Params
        "numa": self.numa,
        # Chat Format Params
        "chat_format": self.chat_format,
        "chat_handler": self.chat_handler,
        # Speculative Decoding
        "draft_model": self.draft_model,
        # KV cache quantization
        "type_k": self.context_params.type_k,
        "type_v": self.context_params.type_v,
        # Misc
        "spm_infill": self.spm_infill,
        "verbose": self.verbose,
    }
|
| 2119 |
+
|
| 2120 |
+
def __setstate__(self, state):
    """Rebuild the instance by re-running ``__init__`` with the pickled keyword arguments."""
    init_kwargs = state
    self.__init__(**init_kwargs)
|
| 2122 |
+
|
| 2123 |
+
def save_state(self) -> LlamaState:
    """Snapshot the current llama.cpp context into a ``LlamaState``.

    Returns:
        A ``LlamaState`` holding copies of the scores and input ids, the
        current token count, the raw state bytes copied out of the context,
        the number of bytes copied, and the current seed.

    Raises:
        RuntimeError: If more bytes were copied than the reported state size.
    """

    def _log(message: str) -> None:
        # Progress messages go to stderr and only in verbose mode.
        if self.verbose:
            print(message, file=sys.stderr)

    _log("Llama.save_state: saving llama state")
    state_size = llama_cpp.llama_get_state_size(self._ctx.ctx)
    _log(f"Llama.save_state: got state size: {state_size}")

    # Scratch buffer sized to the reported state size.
    scratch = (ctypes.c_uint8 * int(state_size))()
    _log("Llama.save_state: allocated state")

    n_bytes = llama_cpp.llama_copy_state_data(self._ctx.ctx, scratch)
    _log(f"Llama.save_state: copied llama state: {n_bytes}")
    if int(n_bytes) > int(state_size):
        raise RuntimeError("Failed to copy llama state data")

    # Keep only the bytes that were actually written.
    compact = (ctypes.c_uint8 * int(n_bytes))()
    llama_cpp.ctypes.memmove(compact, scratch, int(n_bytes))
    _log(f"Llama.save_state: saving {n_bytes} bytes of llama state")

    return LlamaState(
        scores=self._scores.copy(),
        input_ids=self.input_ids.copy(),
        n_tokens=self.n_tokens,
        llama_state=bytes(compact),
        llama_state_size=n_bytes,
        seed=self._seed,
    )
|
| 2152 |
+
|
| 2153 |
+
def load_state(self, state: LlamaState) -> None:
    """Restore scores, input ids, token count, seed and context state from ``state``.

    Args:
        state: A snapshot previously produced by ``save_state``.

    Raises:
        RuntimeError: If the number of bytes consumed by
            ``llama_set_state_data`` differs from ``state.llama_state_size``.
    """
    n_saved = state.n_tokens

    # Fill the first ``n_tokens`` rows from the snapshot.
    self.scores[:n_saved, :] = state.scores.copy()
    # NOTE(review): only strictly positive entries of the remaining rows are
    # reset to 0.0; negative values are left untouched. Confirm this is intended.
    tail = self.scores[n_saved:, :]
    tail[tail > 0] = 0.0

    self.input_ids = state.input_ids.copy()
    self.n_tokens = n_saved
    self._seed = state.seed

    # Copy the saved bytes into a ctypes array of exactly the recorded size.
    state_size = state.llama_state_size
    buffer_type = ctypes.c_uint8 * state_size
    native_state = buffer_type.from_buffer_copy(state.llama_state)

    restored = llama_cpp.llama_set_state_data(self._ctx.ctx, native_state)
    if restored != state_size:
        raise RuntimeError("Failed to set llama state data")
|
| 2167 |
+
|
| 2168 |
+
def n_ctx(self) -> int:
    """Return the context window size reported by the underlying context."""
    context = self._ctx
    return context.n_ctx()
|
| 2171 |
+
|
| 2172 |
+
def n_embd(self) -> int:
    """Return the embedding size reported by the loaded model."""
    model = self._model
    return model.n_embd()
|
| 2175 |
+
|
| 2176 |
+
def n_vocab(self) -> int:
    """Return the vocabulary size reported by the loaded model."""
    model = self._model
    return model.n_vocab()
|
| 2179 |
+
|
| 2180 |
+
def tokenizer(self) -> LlamaTokenizer:
    """Return a new ``LlamaTokenizer`` wrapping this model."""
    wrapper = LlamaTokenizer(self)
    return wrapper
|
| 2183 |
+
|
| 2184 |
+
def token_eos(self) -> int:
    """Return the end-of-sequence token reported by the loaded model."""
    model = self._model
    return model.token_eos()
|
| 2187 |
+
|
| 2188 |
+
def token_bos(self) -> int:
    """Return the beginning-of-sequence token reported by the loaded model."""
    model = self._model
    return model.token_bos()
|
| 2191 |
+
|
| 2192 |
+
def token_nl(self) -> int:
    """Return the newline token reported by the loaded model."""
    model = self._model
    return model.token_nl()
|
| 2195 |
+
|
| 2196 |
+
def pooling_type(self) -> str:
    """Return the pooling type reported by the underlying context."""
    context = self._ctx
    return context.pooling_type()
|
| 2199 |
+
|
| 2200 |
+
def close(self) -> None:
    """Explicitly free the model from memory by closing the instance's resource stack."""
    resources = self._stack
    resources.close()
|
| 2203 |
+
|
| 2204 |
+
def __del__(self) -> None:
    """Release resources when the instance is garbage collected.

    Finalizers also run for objects whose ``__init__`` raised before
    ``_stack`` was assigned (for example on invalid constructor arguments).
    Calling ``close()`` there would raise an ``AttributeError`` that hides the
    original error behind an "Exception ignored in __del__" message, so such
    partially constructed instances are skipped.
    """
    if getattr(self, "_stack", None) is not None:
        self.close()
|
| 2206 |
+
|
| 2207 |
+
@staticmethod
def logits_to_logprobs(
    logits: Union[npt.NDArray[np.single], List], axis: int = -1
) -> npt.NDArray[np.single]:
    """Convert logits to log-probabilities with a log-softmax along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating; a
    non-finite maximum is replaced by 0 first.

    Args:
        logits: Array (or nested list) of logits.
        axis: Axis along which the softmax is normalised.

    Returns:
        A single-precision array of log-probabilities.
    """
    # https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.log_softmax.html
    peak: np.ndarray = np.amax(logits, axis=axis, keepdims=True)
    if peak.ndim > 0:
        peak[~np.isfinite(peak)] = 0
    elif not np.isfinite(peak):
        peak = 0

    shifted = np.subtract(logits, peak, dtype=np.single)
    exponentials = np.exp(shifted)
    # Suppress warnings about log of zero
    with np.errstate(divide="ignore"):
        log_normaliser = np.log(np.sum(exponentials, axis=axis, keepdims=True))
    return shifted - log_normaliser
|
| 2224 |
+
|
| 2225 |
+
@staticmethod
def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
    """Return the number of leading positions at which ``a`` and ``b`` hold equal tokens."""
    matched = 0
    for left, right in zip(a, b):
        if left != right:
            # First mismatch ends the common prefix.
            return matched
        matched += 1
    return matched
|
| 2234 |
+
|
| 2235 |
+
@classmethod
|
| 2236 |
+
def from_pretrained(
|
| 2237 |
+
cls,
|
| 2238 |
+
repo_id: str,
|
| 2239 |
+
filename: Optional[str],
|
| 2240 |
+
additional_files: Optional[List] = None,
|
| 2241 |
+
local_dir: Optional[Union[str, os.PathLike[str]]] = None,
|
| 2242 |
+
local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
|
| 2243 |
+
cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
|
| 2244 |
+
**kwargs: Any,
|
| 2245 |
+
) -> "Llama":
|
| 2246 |
+
"""Create a Llama model from a pretrained model name or path.
|
| 2247 |
+
This method requires the huggingface-hub package.
|
| 2248 |
+
You can install it with `pip install huggingface-hub`.
|
| 2249 |
+
|
| 2250 |
+
Args:
|
| 2251 |
+
repo_id: The model repo id.
|
| 2252 |
+
filename: A filename or glob pattern to match the model file in the repo.
|
| 2253 |
+
additional_files: A list of filenames or glob patterns to match additional model files in the repo.
|
| 2254 |
+
local_dir: The local directory to save the model to.
|
| 2255 |
+
local_dir_use_symlinks: Whether to use symlinks when downloading the model.
|
| 2256 |
+
**kwargs: Additional keyword arguments to pass to the Llama constructor.
|
| 2257 |
+
|
| 2258 |
+
Returns:
|
| 2259 |
+
A Llama model."""
|
| 2260 |
+
try:
|
| 2261 |
+
from huggingface_hub import hf_hub_download, HfFileSystem
|
| 2262 |
+
from huggingface_hub.utils import validate_repo_id
|
| 2263 |
+
except ImportError:
|
| 2264 |
+
raise ImportError(
|
| 2265 |
+
"Llama.from_pretrained requires the huggingface-hub package. "
|
| 2266 |
+
"You can install it with `pip install huggingface-hub`."
|
| 2267 |
+
)
|
| 2268 |
+
|
| 2269 |
+
validate_repo_id(repo_id)
|
| 2270 |
+
|
| 2271 |
+
hffs = HfFileSystem()
|
| 2272 |
+
|
| 2273 |
+
files = [
|
| 2274 |
+
file["name"] if isinstance(file, dict) else file
|
| 2275 |
+
for file in hffs.ls(repo_id, recursive=True)
|
| 2276 |
+
]
|
| 2277 |
+
|
| 2278 |
+
# split each file into repo_id, subfolder, filename
|
| 2279 |
+
file_list: List[str] = []
|
| 2280 |
+
for file in files:
|
| 2281 |
+
rel_path = Path(file).relative_to(repo_id)
|
| 2282 |
+
file_list.append(str(rel_path))
|
| 2283 |
+
|
| 2284 |
+
# find the only/first shard file:
|
| 2285 |
+
matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)] # type: ignore
|
| 2286 |
+
|
| 2287 |
+
if len(matching_files) == 0:
|
| 2288 |
+
raise ValueError(
|
| 2289 |
+
f"No file found in {repo_id} that match {filename}\n\n"
|
| 2290 |
+
f"Available Files:\n{json.dumps(file_list)}"
|
| 2291 |
+
)
|
| 2292 |
+
|
| 2293 |
+
if len(matching_files) > 1:
|
| 2294 |
+
raise ValueError(
|
| 2295 |
+
f"Multiple files found in {repo_id} matching {filename}\n\n"
|
| 2296 |
+
f"Available Files:\n{json.dumps(files)}"
|
| 2297 |
+
)
|
| 2298 |
+
|
| 2299 |
+
(matching_file,) = matching_files
|
| 2300 |
+
|
| 2301 |
+
subfolder = str(Path(matching_file).parent)
|
| 2302 |
+
filename = Path(matching_file).name
|
| 2303 |
+
|
| 2304 |
+
# download the file
|
| 2305 |
+
hf_hub_download(
|
| 2306 |
+
repo_id=repo_id,
|
| 2307 |
+
filename=filename,
|
| 2308 |
+
subfolder=subfolder,
|
| 2309 |
+
local_dir=local_dir,
|
| 2310 |
+
local_dir_use_symlinks=local_dir_use_symlinks,
|
| 2311 |
+
cache_dir=cache_dir,
|
| 2312 |
+
)
|
| 2313 |
+
|
| 2314 |
+
if additional_files:
|
| 2315 |
+
for additonal_file_name in additional_files:
|
| 2316 |
+
# find the additional shard file:
|
| 2317 |
+
matching_additional_files = [file for file in file_list if fnmatch.fnmatch(file, additonal_file_name)]
|
| 2318 |
+
|
| 2319 |
+
if len(matching_additional_files) == 0:
|
| 2320 |
+
raise ValueError(
|
| 2321 |
+
f"No file found in {repo_id} that match {additonal_file_name}\n\n"
|
| 2322 |
+
f"Available Files:\n{json.dumps(file_list)}"
|
| 2323 |
+
)
|
| 2324 |
+
|
| 2325 |
+
if len(matching_additional_files) > 1:
|
| 2326 |
+
raise ValueError(
|
| 2327 |
+
f"Multiple files found in {repo_id} matching {additonal_file_name}\n\n"
|
| 2328 |
+
f"Available Files:\n{json.dumps(files)}"
|
| 2329 |
+
)
|
| 2330 |
+
|
| 2331 |
+
(matching_additional_file,) = matching_additional_files
|
| 2332 |
+
|
| 2333 |
+
# download the additional file
|
| 2334 |
+
hf_hub_download(
|
| 2335 |
+
repo_id=repo_id,
|
| 2336 |
+
filename=matching_additional_file,
|
| 2337 |
+
subfolder=subfolder,
|
| 2338 |
+
local_dir=local_dir,
|
| 2339 |
+
local_dir_use_symlinks=local_dir_use_symlinks,
|
| 2340 |
+
cache_dir=cache_dir,
|
| 2341 |
+
)
|
| 2342 |
+
|
| 2343 |
+
if local_dir is None:
|
| 2344 |
+
model_path = hf_hub_download(
|
| 2345 |
+
repo_id=repo_id,
|
| 2346 |
+
filename=filename,
|
| 2347 |
+
subfolder=subfolder,
|
| 2348 |
+
local_dir=local_dir,
|
| 2349 |
+
local_dir_use_symlinks=local_dir_use_symlinks,
|
| 2350 |
+
cache_dir=cache_dir,
|
| 2351 |
+
local_files_only=True,
|
| 2352 |
+
)
|
| 2353 |
+
else:
|
| 2354 |
+
model_path = os.path.join(local_dir, filename)
|
| 2355 |
+
|
| 2356 |
+
# loading the first file of a sharded GGUF loads all remaining shard files in the subfolder
|
| 2357 |
+
return cls(
|
| 2358 |
+
model_path=model_path,
|
| 2359 |
+
**kwargs,
|
| 2360 |
+
)
|
| 2361 |
+
|
| 2362 |
+
|
| 2363 |
+
class LlamaState:
    """Plain container for a saved model state.

    Holds the input ids, scores, token count, raw llama.cpp state bytes with
    their size, and the seed, exactly as passed to the constructor.
    """

    def __init__(
        self,
        input_ids: npt.NDArray[np.intc],
        scores: npt.NDArray[np.single],
        n_tokens: int,
        llama_state: bytes,
        llama_state_size: int,
        seed: int,
    ):
        # Token bookkeeping.
        self.input_ids = input_ids
        self.scores = scores
        self.n_tokens = n_tokens
        # Raw serialized context state and the number of bytes it contains.
        self.llama_state = llama_state
        self.llama_state_size = llama_state_size
        self.seed = seed
|
| 2379 |
+
|
| 2380 |
+
|
| 2381 |
+
# A logits processor takes (input_ids, scores) and returns the processed scores.
LogitsProcessor = Callable[
    [npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]
]


class LogitsProcessorList(List[LogitsProcessor]):
    """A list of logits processors that can itself be called as one processor."""

    def __call__(
        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
    ) -> npt.NDArray[np.single]:
        """Apply every processor in list order, feeding each the previous result."""
        current = scores
        for apply in self:
            current = apply(input_ids, current)
        return current
|
| 2393 |
+
|
| 2394 |
+
|
| 2395 |
+
# A stopping criterion takes (input_ids, logits) and returns True to stop.
StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]


class StoppingCriteriaList(List[StoppingCriteria]):
    """A list of stopping criteria that can itself be called as one criterion."""

    def __call__(
        self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
    ) -> bool:
        """Return True if any criterion in the list returns a truthy value."""
        # Every criterion is evaluated, with no short-circuit on the first hit.
        verdicts = [criterion(input_ids, logits) for criterion in self]
        return any(verdicts)
|
| 2403 |
+
|
| 2404 |
+
|
| 2405 |
+
class MinTokensLogitsProcessor(LogitsProcessor):
    """Logits processor that masks the EOS token until ``min_tokens`` tokens follow the prompt.

    The length of ``input_ids`` on the first call is recorded as the prompt
    length and is not reset afterwards.
    """

    def __init__(self, min_tokens: int, token_eos: int):
        self.min_tokens = min_tokens
        self.token_eos = token_eos
        # Set on the first call to the length of ``input_ids`` seen then.
        self.prompt_tokens = None

    def __call__(
        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
    ) -> npt.NDArray[np.single]:
        """Set the EOS score to -inf in place while too few tokens were generated."""
        if self.prompt_tokens is None:
            self.prompt_tokens = len(input_ids)
        generated = len(input_ids) - self.prompt_tokens
        if generated < self.min_tokens:
            scores[self.token_eos] = -np.inf
        return scores
|
llama_cpp/llama_cache.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
from typing import (
|
| 4 |
+
Optional,
|
| 5 |
+
Sequence,
|
| 6 |
+
Tuple,
|
| 7 |
+
)
|
| 8 |
+
from collections import OrderedDict
|
| 9 |
+
|
| 10 |
+
import diskcache
|
| 11 |
+
|
| 12 |
+
import llama_cpp.llama
|
| 13 |
+
|
| 14 |
+
from .llama_types import *
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class BaseLlamaCache(ABC):
    """Base cache class for a llama.cpp model."""

    def __init__(self, capacity_bytes: int = (2 << 30)):
        """Store the capacity limit in bytes."""
        self.capacity_bytes = capacity_bytes

    @property
    @abstractmethod
    def cache_size(self) -> int:
        """Current size of the cache; subclasses must implement this."""
        raise NotImplementedError

    def _find_longest_prefix_key(
        self,
        key: Tuple[int, ...],
    ) -> Optional[Tuple[int, ...]]:
        """Hook for subclasses to look up a stored key by prefix; the base version returns None."""
        return None

    @abstractmethod
    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
        """Return the cached state for ``key``; subclasses must implement this."""
        raise NotImplementedError

    @abstractmethod
    def __contains__(self, key: Sequence[int]) -> bool:
        """Report whether ``key`` has a cached state; subclasses must implement this."""
        raise NotImplementedError

    @abstractmethod
    def __setitem__(
        self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"
    ) -> None:
        """Store ``value`` under ``key``; subclasses must implement this."""
        raise NotImplementedError
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class LlamaRAMCache(BaseLlamaCache):
    """Cache for a llama.cpp model using RAM."""

    def __init__(self, capacity_bytes: int = (2 << 30)):
        super().__init__(capacity_bytes)
        self.capacity_bytes = capacity_bytes
        # Insertion/access-ordered mapping from token tuples to saved states;
        # the first entry is the next one to be evicted.
        self.cache_state: OrderedDict[
            Tuple[int, ...], "llama_cpp.llama.LlamaState"
        ] = OrderedDict()

    @property
    def cache_size(self):
        """Sum of ``llama_state_size`` over all cached states."""
        total = 0
        for state in self.cache_state.values():
            total += state.llama_state_size
        return total

    def _find_longest_prefix_key(
        self,
        key: Tuple[int, ...],
    ) -> Optional[Tuple[int, ...]]:
        """Return the stored key sharing the longest non-empty token prefix with ``key``, or None."""
        best_key = None
        best_len = 0
        for candidate in self.cache_state.keys():
            shared = llama_cpp.llama.Llama.longest_token_prefix(candidate, key)
            # Strict comparison: on ties the earliest stored key wins.
            if shared > best_len:
                best_len = shared
                best_key = candidate
        return best_key

    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
        """Return the state stored under the best prefix match and mark it most recently used."""
        match = self._find_longest_prefix_key(tuple(key))
        if match is None:
            raise KeyError("Key not found")
        state = self.cache_state[match]
        self.cache_state.move_to_end(match)
        return state

    def __contains__(self, key: Sequence[int]) -> bool:
        """Report whether any stored key shares a non-empty prefix with ``key``."""
        match = self._find_longest_prefix_key(tuple(key))
        return match is not None

    def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
        """Store ``value`` as the newest entry, then evict oldest entries while over capacity."""
        token_key = tuple(key)
        # Drop an existing entry first so the re-inserted one moves to the end.
        if token_key in self.cache_state:
            del self.cache_state[token_key]
        self.cache_state[token_key] = value
        while self.cache_size > self.capacity_bytes and len(self.cache_state) > 0:
            self.cache_state.popitem(last=False)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# Alias for backwards compatibility: code that refers to ``LlamaCache`` gets
# the very same class object as ``LlamaRAMCache``.
LlamaCache = LlamaRAMCache
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class LlamaDiskCache(BaseLlamaCache):
    """Cache for a llama.cpp model using disk."""

    def __init__(
        self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30)
    ):
        super().__init__(capacity_bytes)
        # Persistent key/value store rooted at ``cache_dir``.
        self.cache = diskcache.Cache(cache_dir)

    @property
    def cache_size(self):
        """Size reported by the disk cache's ``volume()``, as an int."""
        return int(self.cache.volume())  # type: ignore

    def _find_longest_prefix_key(
        self,
        key: Tuple[int, ...],
    ) -> Optional[Tuple[int, ...]]:
        """Return the stored key sharing the longest non-empty token prefix with ``key``, or None."""
        best_len = 0
        best_key: Optional[Tuple[int, ...]] = None
        for k in self.cache.iterkeys():  # type: ignore
            prefix_len = llama_cpp.llama.Llama.longest_token_prefix(k, key)
            if prefix_len > best_len:
                best_len = prefix_len
                best_key = k  # type: ignore
        return best_key

    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
        """Return the state stored under the best prefix match for ``key``.

        Raises:
            KeyError: If no stored key shares a prefix with ``key``, or the
                matched entry is no longer present when it is read.
        """
        key = tuple(key)
        _key = self._find_longest_prefix_key(key)
        if _key is None:
            raise KeyError("Key not found")
        # Read without removing the entry. The previous ``pop`` deleted the
        # state on every lookup because the follow-up re-insert had to be
        # disabled (``push`` stores the value under an integer queue key,
        # which breaks longest_token_prefix in the scan above), so a cached
        # state could only ever be used once.
        value: "llama_cpp.llama.LlamaState" = self.cache[_key]  # type: ignore
        return value

    def __contains__(self, key: Sequence[int]) -> bool:
        """Report whether any stored key shares a non-empty prefix with ``key``."""
        return self._find_longest_prefix_key(tuple(key)) is not None

    def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
        """Store ``value`` under ``key`` and trim entries while over capacity.

        Progress messages are written to stderr unconditionally.
        """
        print("LlamaDiskCache.__setitem__: called", file=sys.stderr)
        key = tuple(key)
        if key in self.cache:
            print("LlamaDiskCache.__setitem__: delete", file=sys.stderr)
            del self.cache[key]
        self.cache[key] = value
        print("LlamaDiskCache.__setitem__: set", file=sys.stderr)
        # Remove entries in the cache's iteration order until under capacity.
        while self.cache_size > self.capacity_bytes and len(self.cache) > 0:
            key_to_remove = next(iter(self.cache))
            del self.cache[key_to_remove]
        print("LlamaDiskCache.__setitem__: trim", file=sys.stderr)
|
llama_cpp/llama_chat_format.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
llama_cpp/llama_cpp.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
llama_cpp/llama_grammar.py
ADDED
|
@@ -0,0 +1,953 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Python implementation of llama grammar parser directly translated from C++ source file in vendor/llama.cpp/common/grammar-parser.cpp."""
|
| 2 |
+
|
| 3 |
+
# flake8: noqa
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
from itertools import groupby
|
| 7 |
+
from typing import (
|
| 8 |
+
Any,
|
| 9 |
+
Set,
|
| 10 |
+
List,
|
| 11 |
+
Optional,
|
| 12 |
+
Tuple,
|
| 13 |
+
Union,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
LLAMA_GRAMMAR_DEFAULT_ROOT = "root"


class LlamaGrammar:
    """Holder for a GBNF grammar string and the name of its root rule.

    Instances are normally created through one of the alternate constructors
    (``from_string``, ``from_file``, ``from_json_schema``).
    """

    def __init__(self, *args, _grammar: str, **kwargs):
        """Store the grammar text.

        Args:
            _grammar: The GBNF grammar source (keyword-only).
            *args, **kwargs: Accepted and ignored.
        """
        self._grammar = _grammar
        self._root = LLAMA_GRAMMAR_DEFAULT_ROOT

    @classmethod
    def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
        """Build a grammar from GBNF source text.

        Args:
            grammar: The GBNF grammar source.
            verbose: Accepted for API compatibility; not used here.
        """
        return cls(_grammar=grammar)

    @classmethod
    def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar":
        """Build a grammar from a GBNF file.

        Args:
            file: Path of the grammar file to read.
            verbose: Forwarded to ``from_string``.

        Raises:
            Exception: If the file cannot be read; the underlying error is
                attached as ``__cause__``.
            ValueError: If the file is empty.
        """
        try:
            # Decode explicitly as UTF-8 so the result does not depend on the
            # platform's default locale encoding.
            with open(file, encoding="utf-8") as f:
                grammar = f.read()
        except Exception as err:
            raise Exception(
                f"{cls.from_file.__name__}: error reading grammar file: {err}"
            ) from err

        if grammar:
            return cls.from_string(grammar, verbose=verbose)

        raise ValueError(
            f"{cls.from_file.__name__}: error parsing grammar file: grammar is empty"
        )

    @classmethod
    def from_json_schema(cls, json_schema: str, verbose: bool = True) -> "LlamaGrammar":
        """Build a grammar by converting a JSON schema string to GBNF.

        Args:
            json_schema: The JSON schema, passed to ``json_schema_to_gbnf``.
            verbose: Forwarded to ``from_string``.
        """
        return cls.from_string(json_schema_to_gbnf(json_schema), verbose=verbose)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
"""llama.cpp gbnf rules from vendor/llama.cpp/grammars"""
|
| 51 |
+
|
| 52 |
+
ARITHMETIC_GBNF = r"""
|
| 53 |
+
root ::= (expr "=" ws term "\n")+
|
| 54 |
+
expr ::= term ([-+*/] term)*
|
| 55 |
+
term ::= ident | num | "(" ws expr ")" ws
|
| 56 |
+
ident ::= [a-z] [a-z0-9_]* ws
|
| 57 |
+
num ::= [0-9]+ ws
|
| 58 |
+
ws ::= [ \t\n]*
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
C_GBNF = r"""
|
| 62 |
+
root ::= (declaration)*
|
| 63 |
+
|
| 64 |
+
declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}"
|
| 65 |
+
|
| 66 |
+
dataType ::= "int" ws | "float" ws | "char" ws
|
| 67 |
+
identifier ::= [a-zA-Z_] [a-zA-Z_0-9]*
|
| 68 |
+
|
| 69 |
+
parameter ::= dataType identifier
|
| 70 |
+
|
| 71 |
+
statement ::=
|
| 72 |
+
( dataType identifier ws "=" ws expression ";" ) |
|
| 73 |
+
( identifier ws "=" ws expression ";" ) |
|
| 74 |
+
( identifier ws "(" argList? ")" ";" ) |
|
| 75 |
+
( "return" ws expression ";" ) |
|
| 76 |
+
( "while" "(" condition ")" "{" statement* "}" ) |
|
| 77 |
+
( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) |
|
| 78 |
+
( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? ) |
|
| 79 |
+
( singleLineComment ) |
|
| 80 |
+
( multiLineComment )
|
| 81 |
+
|
| 82 |
+
forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression
|
| 83 |
+
forUpdate ::= identifier ws "=" ws expression
|
| 84 |
+
|
| 85 |
+
condition ::= expression relationOperator expression
|
| 86 |
+
relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">")
|
| 87 |
+
|
| 88 |
+
expression ::= term (("+" | "-") term)*
|
| 89 |
+
term ::= factor(("*" | "/") factor)*
|
| 90 |
+
|
| 91 |
+
factor ::= identifier | number | unaryTerm | funcCall | parenExpression
|
| 92 |
+
unaryTerm ::= "-" factor
|
| 93 |
+
funcCall ::= identifier "(" argList? ")"
|
| 94 |
+
parenExpression ::= "(" ws expression ws ")"
|
| 95 |
+
|
| 96 |
+
argList ::= expression ("," ws expression)*
|
| 97 |
+
|
| 98 |
+
number ::= [0-9]+
|
| 99 |
+
|
| 100 |
+
singleLineComment ::= "//" [^\n]* "\n"
|
| 101 |
+
multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/"
|
| 102 |
+
|
| 103 |
+
ws ::= ([ \t\n]+)
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
CHESS_GBNF = r"""
|
| 107 |
+
root ::= object
|
| 108 |
+
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
| 109 |
+
|
| 110 |
+
object ::=
|
| 111 |
+
"{" ws (
|
| 112 |
+
string ":" ws value
|
| 113 |
+
("," ws string ":" ws value)*
|
| 114 |
+
)? "}" ws
|
| 115 |
+
|
| 116 |
+
array ::=
|
| 117 |
+
"[" ws (
|
| 118 |
+
value
|
| 119 |
+
("," ws value)*
|
| 120 |
+
)? "]" ws
|
| 121 |
+
|
| 122 |
+
string ::=
|
| 123 |
+
"\"" (
|
| 124 |
+
[^"\\] |
|
| 125 |
+
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
| 126 |
+
)* "\"" ws
|
| 127 |
+
|
| 128 |
+
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
| 129 |
+
|
| 130 |
+
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
| 131 |
+
ws ::= ([ \t\n] ws)?
|
| 132 |
+
"""
|
| 133 |
+
|
| 134 |
+
JAPANESE_GBNF = r"""
|
| 135 |
+
root ::= object
|
| 136 |
+
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
| 137 |
+
|
| 138 |
+
object ::=
|
| 139 |
+
"{" ws (
|
| 140 |
+
string ":" ws value
|
| 141 |
+
("," ws string ":" ws value)*
|
| 142 |
+
)? "}" ws
|
| 143 |
+
|
| 144 |
+
array ::=
|
| 145 |
+
"[" ws (
|
| 146 |
+
value
|
| 147 |
+
("," ws value)*
|
| 148 |
+
)? "]" ws
|
| 149 |
+
|
| 150 |
+
string ::=
|
| 151 |
+
"\"" (
|
| 152 |
+
[^"\\] |
|
| 153 |
+
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
| 154 |
+
)* "\"" ws
|
| 155 |
+
|
| 156 |
+
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
| 157 |
+
|
| 158 |
+
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
| 159 |
+
ws ::= ([ \t\n] ws)?
|
| 160 |
+
"""
|
| 161 |
+
|
| 162 |
+
JSON_ARR_GBNF = r"""
|
| 163 |
+
# This is the same as json.gbnf but we restrict whitespaces at the end of the root array
|
| 164 |
+
# Useful for generating JSON arrays
|
| 165 |
+
|
| 166 |
+
root ::= arr
|
| 167 |
+
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
| 168 |
+
|
| 169 |
+
arr ::=
|
| 170 |
+
"[\n" ws (
|
| 171 |
+
value
|
| 172 |
+
(",\n" ws value)*
|
| 173 |
+
)? "]"
|
| 174 |
+
|
| 175 |
+
object ::=
|
| 176 |
+
"{" ws (
|
| 177 |
+
string ":" ws value
|
| 178 |
+
("," ws string ":" ws value)*
|
| 179 |
+
)? "}" ws
|
| 180 |
+
|
| 181 |
+
array ::=
|
| 182 |
+
"[" ws (
|
| 183 |
+
value
|
| 184 |
+
("," ws value)*
|
| 185 |
+
)? "]" ws
|
| 186 |
+
|
| 187 |
+
string ::=
|
| 188 |
+
"\"" (
|
| 189 |
+
[^"\\\x7F\x00-\x1F] |
|
| 190 |
+
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
| 191 |
+
)* "\"" ws
|
| 192 |
+
|
| 193 |
+
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
| 194 |
+
|
| 195 |
+
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
| 196 |
+
ws ::= ([ \t\n] ws)?
|
| 197 |
+
"""
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
JSON_GBNF = r"""
|
| 201 |
+
root ::= object
|
| 202 |
+
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
| 203 |
+
|
| 204 |
+
object ::=
|
| 205 |
+
"{" ws (
|
| 206 |
+
string ":" ws value
|
| 207 |
+
("," ws string ":" ws value)*
|
| 208 |
+
)? "}" ws
|
| 209 |
+
|
| 210 |
+
array ::=
|
| 211 |
+
"[" ws (
|
| 212 |
+
value
|
| 213 |
+
("," ws value)*
|
| 214 |
+
)? "]" ws
|
| 215 |
+
|
| 216 |
+
string ::=
|
| 217 |
+
"\"" (
|
| 218 |
+
[^"\\\x7F\x00-\x1F] |
|
| 219 |
+
"\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4}) # escapes
|
| 220 |
+
)* "\"" ws
|
| 221 |
+
|
| 222 |
+
number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [0-9] [1-9]{0,15})? ws
|
| 223 |
+
|
| 224 |
+
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
| 225 |
+
ws ::= | " " | "\n" [ \t]{0,20}
|
| 226 |
+
"""
|
| 227 |
+
|
| 228 |
+
LIST_GBNF = r"""
|
| 229 |
+
root ::= item+
|
| 230 |
+
|
| 231 |
+
# Excludes various line break characters
|
| 232 |
+
item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n"
|
| 233 |
+
"""
|
| 234 |
+
|
| 235 |
+
"""llama.cpp json-schema to grammar converter from vendor/llama.cpp/examples/json-schema-to-grammar.py"""
|
| 236 |
+
import json
|
| 237 |
+
import re
|
| 238 |
+
from typing import List, Optional
|
| 239 |
+
|
| 240 |
+
# whitespace is constrained to a single space char to prevent model "running away" in
# whitespace. Also maybe improves generation quality?
SPACE_RULE = '" "?'

# Any run of characters that is not allowed in a grammar rule name.
INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+")
# Characters that must be escaped inside a double-quoted grammar literal,
# and the escape sequence used for each.
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'}
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def _build_repetition(
|
| 255 |
+
item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False
|
| 256 |
+
):
|
| 257 |
+
if not separator_rule:
|
| 258 |
+
if min_items == 0 and max_items == 1:
|
| 259 |
+
return f"{item_rule}?"
|
| 260 |
+
elif min_items == 1 and max_items is None:
|
| 261 |
+
return f"{item_rule}+"
|
| 262 |
+
|
| 263 |
+
result = ""
|
| 264 |
+
|
| 265 |
+
if min_items > 0:
|
| 266 |
+
if item_rule_is_literal and separator_rule is None:
|
| 267 |
+
result = '"' + (item_rule[1:-1] * min_items) + '"'
|
| 268 |
+
else:
|
| 269 |
+
result = (f" {separator_rule} " if separator_rule else " ").join(
|
| 270 |
+
[item_rule] * min_items
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
def opt_repetitions(up_to_n, prefix_with_sep=False):
|
| 274 |
+
"""
|
| 275 |
+
- n=4, no sep: '(a (a (a (a)?)?)?)?'
|
| 276 |
+
- n=4, sep=',', prefix: '("," a ("," a ("," a ("," a)?)?)?)?'
|
| 277 |
+
- n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?'
|
| 278 |
+
"""
|
| 279 |
+
|
| 280 |
+
content = (
|
| 281 |
+
f"{separator_rule} {item_rule}"
|
| 282 |
+
if prefix_with_sep and separator_rule
|
| 283 |
+
else item_rule
|
| 284 |
+
)
|
| 285 |
+
if up_to_n == 0:
|
| 286 |
+
return ""
|
| 287 |
+
elif up_to_n == 1:
|
| 288 |
+
return f"({content})?"
|
| 289 |
+
elif separator_rule and not prefix_with_sep:
|
| 290 |
+
return f"({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?"
|
| 291 |
+
else:
|
| 292 |
+
return (f"({content} " * up_to_n).rstrip() + (")?" * up_to_n)
|
| 293 |
+
|
| 294 |
+
if min_items > 0 and max_items != min_items:
|
| 295 |
+
result += " "
|
| 296 |
+
|
| 297 |
+
if max_items is not None:
|
| 298 |
+
result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0)
|
| 299 |
+
else:
|
| 300 |
+
item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})'
|
| 301 |
+
|
| 302 |
+
if min_items == 0 and separator_rule:
|
| 303 |
+
result = f"({item_rule} {item_operator}*)?"
|
| 304 |
+
else:
|
| 305 |
+
result += f"{item_operator}*"
|
| 306 |
+
|
| 307 |
+
return result
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class BuiltinRule:
    """A predefined grammar rule body plus the names of rules it refers to."""

    def __init__(self, content: str, deps: list = None):
        """Store the rule text and its dependency list.

        Args:
            content: Grammar text of the rule body.
            deps: Names of rules referenced by ``content``; a falsy value
                (including ``None``) is replaced by a new empty list.
        """
        self.content = content
        self.deps = deps if deps else []
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
# Grammar fragment matching between 0 and 15 further digits.
_up_to_15_digits = _build_repetition("[0-9]", 0, 15)

# Built-in rules keyed by rule name. Each entry pairs the rule body with the
# names of the other built-in rules that body refers to.
PRIMITIVE_RULES = {
    "boolean": BuiltinRule('("true" | "false") space', []),
    "decimal-part": BuiltinRule("[0-9] " + _up_to_15_digits, []),
    "integral-part": BuiltinRule("[0-9] | [1-9] " + _up_to_15_digits, []),
    "number": BuiltinRule(
        '("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space',
        ["integral-part", "decimal-part"],
    ),
    "integer": BuiltinRule('("-"? integral-part) space', ["integral-part"]),
    "value": BuiltinRule(
        "object | array | string | number | boolean | null",
        ["object", "array", "string", "number", "boolean", "null"],
    ),
    "object": BuiltinRule(
        '"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space',
        ["string", "value"],
    ),
    "array": BuiltinRule(
        '"[" space ( value ("," space value)* )? "]" space', ["value"]
    ),
    # Five hex-digit groups of 8, 4, 4, 4 and 12 characters joined by "-",
    # wrapped in double quotes.
    "uuid": BuiltinRule(
        r'"\"" '
        + ' "-" '.join("[0-9a-fA-F]" * n for n in [8, 4, 4, 4, 12])
        + r' "\"" space',
        [],
    ),
    # One string character: anything except a quote or backslash, or a
    # backslash escape (including a \uXXXX sequence).
    "char": BuiltinRule(
        r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])',
        [],
    ),
    "string": BuiltinRule(r'"\"" char* "\"" space', ["char"]),
    "null": BuiltinRule('"null" space', []),
}
|
| 351 |
+
|
| 352 |
+
# TODO: support "uri", "email" string formats
# Built-in rules for string formats. The "*-string" entries wrap the bare
# "date", "time" and "date-time" rules in double quotes.
STRING_FORMAT_RULES = {
    "date": BuiltinRule(
        '[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( "0" [1-9] | [1-2] [0-9] | "3" [0-1] )',
        [],
    ),
    "time": BuiltinRule(
        '([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )',
        [],
    ),
    "date-time": BuiltinRule('date "T" time', ["date", "time"]),
    "date-string": BuiltinRule('"\\"" date "\\"" space', ["date"]),
    "time-string": BuiltinRule('"\\"" time "\\"" space', ["time"]),
    "date-time-string": BuiltinRule('"\\"" date-time "\\"" space', ["date-time"]),
}
|
| 367 |
+
|
| 368 |
+
# Character classes used for the regex "." metacharacter: any code point,
# or any character except line feed and carriage return.
DOTALL = "[\\U00000000-\\U0010FFFF]"
DOT = "[^\\x0A\\x0D]"

# Rule names that generated rules must not reuse.
RESERVED_NAMES = {"root", "dot", *PRIMITIVE_RULES, *STRING_FORMAT_RULES}

# Regex metacharacters that end a run of literal text.
NON_LITERAL_SET = set("|.()[]{}*+?")
# Characters that need a backslash in a regex but not in a grammar literal.
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set("[]()|{}*+?")
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class SchemaConverter:
|
| 381 |
+
def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
    """Store the conversion options and set up empty working state.

    All arguments are keyword-only and are stored unchanged on the instance.
    """
    # Options supplied by the caller.
    self._prop_order, self._allow_fetch = prop_order, allow_fetch
    self._dotall, self._raw_pattern = dotall, raw_pattern

    # Working state: resolved references, references currently being
    # resolved, and the rule table seeded with the whitespace rule.
    self._refs_being_resolved = set()
    self._refs = dict()
    self._rules = dict(space=SPACE_RULE)
|
| 391 |
+
|
| 392 |
+
def _format_literal(self, literal):
    """Return ``literal`` as a double-quoted grammar string.

    Every match of ``GRAMMAR_LITERAL_ESCAPE_RE`` is replaced by its entry in
    ``GRAMMAR_LITERAL_ESCAPES`` before the quotes are added.
    """

    def _escape(match):
        return GRAMMAR_LITERAL_ESCAPES.get(match.group(0))

    body = GRAMMAR_LITERAL_ESCAPE_RE.sub(_escape, literal)
    return '"' + body + '"'
|
| 397 |
+
|
| 398 |
+
def not_literal(
    self, literal: str, dotall: bool = True, maybe_escaped_underscores=False
) -> str:
    """
    Build a grammar expression matching text that diverges from ``literal``.

    not_literal('a') -> '([^a])'
    not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)'

    ``dotall`` is accepted but not used by this method. When
    ``maybe_escaped_underscores`` is true, an underscore in ``literal`` is
    also allowed to appear with an optional backslash in front of it.
    """
    assert len(literal) > 0, "Empty literal not supported"

    pieces = ["("]
    last_index = len(literal) - 1
    open_groups = 0

    for pos, ch in enumerate(literal):
        if maybe_escaped_underscores and ch == "_":
            pieces += [f"[^{ch}\\\\]", " | ", f'"\\\\"? "{ch}"']
        else:
            pieces.append(f"[^{ch}]")

        if pos < last_index:
            # Matching this character opens an optional group for the rest.
            pieces += [" | ", self._format_literal(ch), " ("]
            open_groups += 1

    # Close every optional group opened above, innermost first.
    pieces.append(")?" * open_groups)
    pieces.append(")")
    return "".join(pieces)
|
| 423 |
+
|
| 424 |
+
def _add_rule(self, name, rule):
    """Register ``rule`` under a sanitised form of ``name`` and return the key.

    Runs of characters outside ``[a-zA-Z0-9-]`` in ``name`` are replaced by
    "-". If that key already holds a different rule, a numeric suffix
    (0, 1, 2, ...) is appended until a key is found that is free or already
    holds the same rule.
    """
    base = INVALID_RULE_CHARS_RE.sub("-", name)
    rules = self._rules

    chosen = base
    if base in rules and rules[base] != rule:
        suffix = 0
        chosen = f"{base}{suffix}"
        while chosen in rules and rules[chosen] != rule:
            suffix += 1
            chosen = f"{base}{suffix}"

    rules[chosen] = rule
    return chosen
|
| 438 |
+
|
| 439 |
+
def resolve_refs(self, schema: dict, url: str):
    """
    Walk ``schema`` and record the target of every ``$ref`` in ``self._refs``.

    Local references ("#/...") are rewritten in place to ``url`` + reference
    and resolved against ``schema``. "https://" references are fetched with
    ``requests`` (only when ``self._allow_fetch`` is set) and resolved
    recursively. Any other reference raises ``ValueError``.

    Returns the result of walking the top-level ``schema``.
    """

    def walk(node):
        if isinstance(node, list):
            return [walk(child) for child in node]
        if not isinstance(node, dict):
            return node

        ref = node.get("$ref")
        if ref is None or ref in self._refs:
            # Nothing new to resolve here; descend into the children.
            for child in node.values():
                walk(child)
            return node

        if ref.startswith("https://"):
            assert (
                self._allow_fetch
            ), "Fetching remote schemas is not allowed (use --allow-fetch for force)"
            import requests

            parts = ref.split("#")
            base_url = parts[0]

            target = self._refs.get(base_url)
            if target is None:
                target = self.resolve_refs(requests.get(ref).json(), base_url)
                self._refs[base_url] = target

            # No fragment: the whole remote document is the target.
            if len(parts) == 1 or parts[-1] == "":
                return target
        elif ref.startswith("#/"):
            target = schema
            ref = f"{url}{ref}"
            node["$ref"] = ref
        else:
            raise ValueError(f"Unsupported ref {ref}")

        # Follow the "/"-separated fragment path into the target document.
        pointer = ref.split("#")[-1]
        for sel in pointer.split("/")[1:]:
            assert (
                target is not None and sel in target
            ), f"Error resolving ref {ref}: {sel} not in {target}"
            target = target[sel]

        self._refs[ref] = target
        return node

    return walk(schema)
|
| 491 |
+
|
| 492 |
+
def _generate_union_rule(self, name, alt_schemas):
|
| 493 |
+
return " | ".join(
|
| 494 |
+
(
|
| 495 |
+
self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
|
| 496 |
+
for i, alt_schema in enumerate(alt_schemas)
|
| 497 |
+
)
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
def _visit_pattern(self, pattern, name):
    """
    Transforms a regular expression pattern into a GBNF rule.

    Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
    Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md

    Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.

    Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
    we define sub-rules to keep the output lean.

    The pattern must start with "^" and end with "$". The resulting rule is
    registered under ``name`` via ``self._add_rule`` and the key returned by
    that call is returned. Unless ``self._raw_pattern`` is set, the rule is
    wrapped in escaped double quotes and followed by ``space``.
    """

    assert pattern.startswith("^") and pattern.endswith(
        "$"
    ), 'Pattern must start with "^" and end with "$"'
    # Drop the anchors; everything below works on the inner pattern.
    pattern = pattern[1:-1]
    # Maps the text of a quantified non-literal item to the sub-rule created for it.
    sub_rule_ids = {}

    # Parse cursor shared by all (recursive) calls of transform().
    i = 0
    length = len(pattern)

    def to_rule(s: Tuple[str, bool]) -> str:
        # Literals are stored unquoted; add the quotes when emitting them.
        (txt, is_literal) = s
        return '"' + txt + '"' if is_literal else txt

    def transform() -> Tuple[str, bool]:
        """
        Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
        """
        nonlocal i
        nonlocal pattern
        nonlocal sub_rule_ids

        start = i
        # For each component of this sequence, store its string representation and whether it's a literal.
        # We only need a flat structure here to apply repetition operators to the last item, and
        # to merge literals at the end (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
        # (GBNF's syntax is luckily very close to regular expressions!)
        seq: list[Tuple[str, bool]] = []

        def get_dot():
            if self._dotall:
                rule = DOTALL
            else:
                # Accept any character... except \n and \r line break chars (\x0A and \x0D)
                rule = DOT
            return self._add_rule(f"dot", rule)

        def join_seq():
            nonlocal seq
            ret = []
            # Merge each run of adjacent literals into a single literal.
            for is_literal, g in groupby(seq, lambda x: x[1]):
                if is_literal:
                    ret.append(("".join(x[0] for x in g), True))
                else:
                    ret.extend(g)
            if len(ret) == 1:
                return ret[0]
            # NOTE(review): this joins over `seq`, not the merged `ret`, so
            # adjacent literals are emitted as separate quoted strings here.
            return (" ".join(to_rule(x) for x in seq), False)

        while i < length:
            c = pattern[i]
            if c == ".":
                seq.append((get_dot(), False))
                i += 1
            elif c == "(":
                i += 1
                if i < length:
                    # "(?" introduces group syntax that is not handled here.
                    assert (
                        pattern[i] != "?"
                    ), f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
                # Recurse to parse the group body up to its closing ")".
                seq.append((f"({to_rule(transform())})", False))
            elif c == ")":
                i += 1
                # A ")" is only valid when this call was started right after a "(".
                assert (
                    start > 0 and pattern[start - 1] == "("
                ), f"Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}"
                return join_seq()
            elif c == "[":
                # Copy the character class through verbatim, keeping
                # backslash escapes together with the character they escape.
                square_brackets = c
                i += 1
                while i < length and pattern[i] != "]":
                    if pattern[i] == "\\":
                        square_brackets += pattern[i : i + 2]
                        i += 2
                    else:
                        square_brackets += pattern[i]
                        i += 1
                assert (
                    i < length
                ), f"Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}"
                square_brackets += "]"
                i += 1
                seq.append((square_brackets, False))
            elif c == "|":
                seq.append(("|", False))
                i += 1
            elif c in ("*", "+", "?"):
                # Apply the operator to the previous item; indexing seq[-1]
                # raises IndexError when there is no previous item.
                seq[-1] = (to_rule(seq[-1]) + c, False)
                i += 1
            elif c == "{":
                # Collect the "{...}" quantifier text.
                curly_brackets = c
                i += 1
                while i < length and pattern[i] != "}":
                    curly_brackets += pattern[i]
                    i += 1
                assert (
                    i < length
                ), f"Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}"
                curly_brackets += "}"
                i += 1
                nums = [s.strip() for s in curly_brackets[1:-1].split(",")]
                min_times = 0
                max_times = None
                try:
                    if len(nums) == 1:
                        # {x}: exactly x repetitions.
                        min_times = int(nums[0])
                        max_times = min_times
                    else:
                        # {x,y}; an empty x means 0 and an empty y means unbounded.
                        assert len(nums) == 2
                        min_times = int(nums[0]) if nums[0] else 0
                        max_times = int(nums[1]) if nums[1] else None
                except ValueError:
                    raise ValueError(
                        f"Invalid quantifier {curly_brackets} in /{pattern}/"
                    )

                (sub, sub_is_literal) = seq[-1]

                if not sub_is_literal:
                    # Give the repeated non-literal item its own rule, reusing
                    # the rule already created for identical text.
                    id = sub_rule_ids.get(sub)
                    if id is None:
                        id = self._add_rule(f"{name}-{len(sub_rule_ids) + 1}", sub)
                        sub_rule_ids[sub] = id
                    sub = id

                seq[-1] = (
                    _build_repetition(
                        f'"{sub}"' if sub_is_literal else sub,
                        min_times,
                        max_times,
                        item_rule_is_literal=sub_is_literal,
                    ),
                    False,
                )
            else:
                # Accumulate a run of literal characters.
                literal = ""
                while i < length:
                    if pattern[i] == "\\" and i < length - 1:
                        next = pattern[i + 1]
                        if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
                            # Drop the backslash and keep the escaped character.
                            i += 1
                            literal += pattern[i]
                            i += 1
                        else:
                            # Keep the backslash and the character after it.
                            literal += pattern[i : i + 2]
                            i += 2
                    elif pattern[i] == '"' and not self._raw_pattern:
                        literal += '\\"'
                        i += 1
                    # A plain character joins the run only if it is the last
                    # pattern character, the first of the run, or followed by
                    # "." or another plain character. Otherwise the run ends
                    # before it — presumably so that a following operator
                    # applies to that single character only.
                    elif pattern[i] not in NON_LITERAL_SET and (
                        i == length - 1
                        or literal == ""
                        or pattern[i + 1] == "."
                        or pattern[i + 1] not in NON_LITERAL_SET
                    ):
                        literal += pattern[i]
                        i += 1
                    else:
                        break
                if literal:
                    seq.append((literal, True))

        return join_seq()

    return self._add_rule(
        name,
        (
            to_rule(transform())
            if self._raw_pattern
            else '"\\"" ' + to_rule(transform()) + ' "\\"" space'
        ),
    )
|
| 684 |
+
|
| 685 |
+
def _resolve_ref(self, ref):
|
| 686 |
+
ref_name = ref.split("/")[-1]
|
| 687 |
+
if ref_name not in self._rules and ref not in self._refs_being_resolved:
|
| 688 |
+
self._refs_being_resolved.add(ref)
|
| 689 |
+
resolved = self._refs[ref]
|
| 690 |
+
ref_name = self.visit(resolved, ref_name)
|
| 691 |
+
self._refs_being_resolved.remove(ref)
|
| 692 |
+
return ref_name
|
| 693 |
+
|
| 694 |
+
def _generate_constant_rule(self, value):
|
| 695 |
+
return self._format_literal(json.dumps(value))
|
| 696 |
+
|
| 697 |
+
def visit(self, schema, name):
    """Convert one (sub)schema into grammar rules and return its rule name.

    The branches below are tried in order and the first one that matches
    wins: ``$ref``, ``oneOf``/``anyOf``, a list-valued ``type``, ``const``,
    ``enum``, objects with properties, ``allOf``, arrays, string patterns,
    uuid formats, date/time formats, length-bounded strings, generic objects
    and finally the primitive types.

    Args:
        schema: The schema dictionary to convert.
        name: Base name for the generated rule; an empty name maps to "root".
    """
    schema_type = schema.get("type")
    schema_format = schema.get("format")
    # Parses as: (name + "-") if name is reserved, else (name or "root").
    rule_name = name + "-" if name in RESERVED_NAMES else name or "root"

    if (ref := schema.get("$ref")) is not None:
        return self._add_rule(rule_name, self._resolve_ref(ref))

    elif "oneOf" in schema or "anyOf" in schema:
        # "oneOf" and "anyOf" are handled identically, as a union of alternatives.
        return self._add_rule(
            rule_name,
            self._generate_union_rule(name, schema.get("oneOf") or schema["anyOf"]),
        )

    elif isinstance(schema_type, list):
        # A list of types becomes a union of single-type schemas.
        return self._add_rule(
            rule_name,
            self._generate_union_rule(name, [{"type": t} for t in schema_type]),
        )

    elif "const" in schema:
        return self._add_rule(
            rule_name, self._generate_constant_rule(schema["const"])
        )

    elif "enum" in schema:
        rule = " | ".join((self._generate_constant_rule(v) for v in schema["enum"]))
        return self._add_rule(rule_name, rule)

    elif schema_type in (None, "object") and (
        "properties" in schema
        or (
            "additionalProperties" in schema
            and schema["additionalProperties"] is not True
        )
    ):
        required = set(schema.get("required", []))
        properties = list(schema.get("properties", {}).items())
        return self._add_rule(
            rule_name,
            self._build_object_rule(
                properties, required, name, schema.get("additionalProperties")
            ),
        )

    elif schema_type in (None, "object") and "allOf" in schema:
        required = set()
        properties = []
        hybrid_name = name

        def add_component(comp_schema, is_required):
            # Collect the properties of one component, following a $ref first.
            if (ref := comp_schema.get("$ref")) is not None:
                comp_schema = self._refs[ref]

            if "properties" in comp_schema:
                for prop_name, prop_schema in comp_schema["properties"].items():
                    properties.append((prop_name, prop_schema))
                    if is_required:
                        required.add(prop_name)

        for t in schema["allOf"]:
            if "anyOf" in t:
                # Properties coming from an anyOf alternative are optional.
                for tt in t["anyOf"]:
                    add_component(tt, is_required=False)
            else:
                add_component(t, is_required=True)

        return self._add_rule(
            rule_name,
            self._build_object_rule(
                properties, required, hybrid_name, additional_properties=[]
            ),
        )

    elif schema_type in (None, "array") and (
        "items" in schema or "prefixItems" in schema
    ):
        items = schema.get("items") or schema["prefixItems"]
        if isinstance(items, list):
            # List of item schemas: a fixed-length tuple, one rule per position.
            return self._add_rule(
                rule_name,
                '"[" space '
                + ' "," space '.join(
                    self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
                    for i, item in enumerate(items)
                )
                + ' "]" space',
            )
        else:
            # Single item schema: a comma-separated repetition bounded by
            # minItems / maxItems.
            item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
            min_items = schema.get("minItems", 0)
            max_items = schema.get("maxItems")
            return self._add_rule(
                rule_name,
                '"[" space '
                + _build_repetition(
                    item_rule_name, min_items, max_items, separator_rule='"," space'
                )
                + ' "]" space',
            )

    elif schema_type in (None, "string") and "pattern" in schema:
        return self._visit_pattern(schema["pattern"], rule_name)

    elif schema_type in (None, "string") and re.match(
        r"^uuid[1-5]?$", schema_format or ""
    ):
        # All uuid variants share the one built-in "uuid" rule.
        return self._add_primitive(
            "root" if rule_name == "root" else schema_format,
            PRIMITIVE_RULES["uuid"],
        )

    elif (
        schema_type in (None, "string")
        and f"{schema_format}-string" in STRING_FORMAT_RULES
    ):
        prim_name = f"{schema_format}-string"
        return self._add_rule(
            rule_name,
            self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]),
        )

    elif schema_type == "string" and (
        "minLength" in schema or "maxLength" in schema
    ):
        # Length-bounded string: a bounded repetition of the "char" rule
        # between escaped double quotes.
        char_rule = self._add_primitive("char", PRIMITIVE_RULES["char"])
        min_len = schema.get("minLength", 0)
        max_len = schema.get("maxLength")

        return self._add_rule(
            rule_name,
            r'"\"" '
            + _build_repetition(char_rule, min_len, max_len)
            + r' "\"" space',
        )

    elif (schema_type == "object") or (len(schema) == 0):
        # An untyped empty schema is treated like a generic object.
        return self._add_rule(
            rule_name, self._add_primitive("object", PRIMITIVE_RULES["object"])
        )

    else:
        assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}"
        # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
        return self._add_primitive(
            "root" if rule_name == "root" else schema_type,
            PRIMITIVE_RULES[schema_type],
        )
|
| 845 |
+
|
| 846 |
+
def _add_primitive(self, name: str, rule: BuiltinRule):
    """Register a built-in rule under ``name`` and pull in its dependencies.

    Args:
        name: Name to register the rule under.
        rule: Built-in rule; its ``content`` is registered and every entry of
            its ``deps`` is resolved and registered recursively.

    Returns:
        Whatever ``self._add_rule`` returns for ``name``.

    Raises:
        AssertionError: If a dependency is found in neither
            ``PRIMITIVE_RULES`` nor ``STRING_FORMAT_RULES``.
    """
    registered = self._add_rule(name, rule.content)

    for dependency in rule.deps:
        known = PRIMITIVE_RULES.get(dependency) or STRING_FORMAT_RULES.get(dependency)
        assert known, f"Rule {dependency} not known"
        if dependency in self._rules:
            # Already registered; skipping also stops cyclic deps from recursing forever.
            continue
        self._add_primitive(dependency, known)

    return registered
|
| 855 |
+
|
| 856 |
+
def _build_object_rule(
    self,
    properties: List[Tuple[str, Any]],
    required: Set[str],
    name: str,
    additional_properties: Union[bool, Any],
):
    """Build the grammar rule body for a JSON object schema.

    Args:
        properties: ``(property name, property schema)`` pairs in schema order.
        required: Names of the properties that must be present.
        name: Rule-name prefix for generated sub-rules (may be empty).
        additional_properties: ``True`` or a schema dict to allow extra
            key/value pairs; anything else disables them.

    Returns:
        The rule text for the object. Helper rules (``-kv``, ``-rest``,
        ``additional-*``) are registered through ``self._add_rule`` as a
        side effect.
    """
    prefix = f'{name}{"-" if name else ""}'
    order = self._prop_order

    def sort_key(indexed):
        # Position in prop_order first (unknown names go last), then original order.
        return (order.get(indexed[1][0], len(order)), indexed[0])

    ordered_names = [
        pair[0] for _, pair in sorted(enumerate(properties), key=sort_key)
    ]

    # One "<key literal> : <value rule>" rule per declared property.
    kv_rules = {}
    for prop_name, prop_schema in properties:
        value_rule_name = self.visit(prop_schema, f"{prefix}{prop_name}")
        kv_rules[prop_name] = self._add_rule(
            f"{prefix}{prop_name}-kv",
            rf'{self._format_literal(json.dumps(prop_name))} space ":" space {value_rule_name}',
        )

    mandatory = [prop for prop in ordered_names if prop in required]
    optional = [prop for prop in ordered_names if prop not in required]

    if additional_properties == True or isinstance(additional_properties, dict):
        sub_name = f"{prefix}additional"
        extra_schema = {} if additional_properties == True else additional_properties
        extra_value_rule = self.visit(extra_schema, f"{sub_name}-value")
        key_rule = self._add_primitive("string", PRIMITIVE_RULES["string"])
        kv_rules["*"] = self._add_rule(
            f"{sub_name}-kv",
            key_rule + f' ":" space {extra_value_rule}',
        )
        # "*" stands for "any additional property" and is always tried last.
        optional.append("*")

    pieces = ['"{" space ', ' "," space '.join(kv_rules[prop] for prop in mandatory)]

    if optional:
        pieces.append(" (")
        if mandatory:
            pieces.append(' "," space ( ')

        def chain(names, first_is_optional):
            # Rule text for names[0] followed by an optional tail of the remaining names.
            head, tail = names[0], names[1:]
            head_rule = kv_rules[head]
            if head == "*":
                text = self._add_rule(
                    f"{prefix}additional-kvs",
                    f'{head_rule} ( "," space ' + head_rule + " )*",
                )
            elif first_is_optional:
                text = f'( "," space {head_rule} )?'
            else:
                text = head_rule
            if tail:
                text += " " + self._add_rule(
                    f"{prefix}{head}-rest",
                    chain(tail, first_is_optional=True),
                )
            return text

        # One alternative per possible first optional property that is present.
        pieces.append(
            " | ".join(
                chain(optional[start:], first_is_optional=False)
                for start, _ in enumerate(optional)
            )
        )
        if mandatory:
            pieces.append(" )")
        pieces.append(" )?")

    pieces.append(' "}" space')
    return "".join(pieces)
|
| 936 |
+
|
| 937 |
+
def format_grammar(self):
    """Render every registered rule as ``name ::= rule``, one per line, sorted by name."""
    lines = []
    for rule_name in sorted(self._rules):
        lines.append(f"{rule_name} ::= {self._rules[rule_name]}")
    return "\n".join(lines)
|
| 942 |
+
|
| 943 |
+
|
| 944 |
+
def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None):
    """Convert a JSON schema (given as a JSON string) into grammar text.

    Args:
        schema: JSON-encoded schema document.
        prop_order: Optional property names; properties are ordered by their
            position in this list when object rules are built.

    Returns:
        The grammar produced by ``SchemaConverter.format_grammar``.
    """
    # Map each property name to its position in the requested order.
    ordering = {prop: position for position, prop in enumerate(prop_order or [])}
    parsed = json.loads(schema)

    converter = SchemaConverter(
        prop_order=ordering, allow_fetch=False, dotall=False, raw_pattern=False
    )
    resolved = converter.resolve_refs(parsed, "stdin")
    converter.visit(resolved, "")
    return converter.format_grammar()
|
llama_cpp/llama_speculative.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import numpy.typing as npt
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class LlamaDraftModel(abc.ABC):
    """Abstract interface for a draft model: a callable mapping input token ids to an array of token ids."""

    @abc.abstractmethod
    def __call__(
        self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
    ) -> npt.NDArray[np.intc]:
        """Return draft token ids for ``input_ids``; concrete subclasses must override."""
        raise NotImplementedError
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class LlamaPromptLookupDecoding(LlamaDraftModel):
    """Draft model that proposes tokens by looking up the trailing n-gram in the prompt.

    Based on https://github.com/apoorvumang/prompt-lookup-decoding
    """

    def __init__(self, max_ngram_size: int = 2, num_pred_tokens: int = 10):
        self.max_ngram_size = max_ngram_size
        self.num_pred_tokens = num_pred_tokens

    @staticmethod
    def find_candidate_pred_tokens(
        input_ids: npt.NDArray[np.intc],
        max_ngram_size: int,
        num_pred_tokens: int,
    ):
        """Return the tokens that followed an earlier occurrence of the trailing n-gram.

        N-gram sizes are tried from largest to smallest; for each size the
        earliest match with a non-empty continuation wins. At most
        ``num_pred_tokens`` tokens are returned, or an empty ``np.intc`` array
        when nothing matches.
        """
        total = input_ids.shape[0]
        largest = min(max_ngram_size, total - 1)

        for size in reversed(range(1, largest + 1)):
            # The last `size` tokens are the pattern we search for.
            suffix = input_ids[-size:]
            # Every contiguous window of `size` tokens in the input.
            candidates = np.lib.stride_tricks.sliding_window_view(input_ids, (size,))
            hits = np.nonzero(np.all(candidates == suffix, axis=1))[0]

            for hit in hits:
                begin = hit + size
                stop = min(begin + num_pred_tokens, total)
                # The suffix matching itself yields an empty continuation and is skipped.
                if begin < stop:
                    return input_ids[begin:stop]

        # No usable match for any n-gram size.
        return np.array([], dtype=np.intc)

    def __call__(
        self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
    ) -> npt.NDArray[np.intc]:
        """Propose draft tokens for ``input_ids`` using the configured sizes."""
        return self.find_candidate_pred_tokens(
            input_ids=input_ids,
            max_ngram_size=self.max_ngram_size,
            num_pred_tokens=self.num_pred_tokens,
        )
|
llama_cpp/llama_tokenizer.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import abc
|
| 4 |
+
from typing import (
|
| 5 |
+
List,
|
| 6 |
+
Optional,
|
| 7 |
+
Any,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
import llama_cpp
|
| 11 |
+
from llama_cpp.llama_types import List
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class BaseLlamaTokenizer(abc.ABC):
    """Abstract tokenizer interface: bytes to token ids and back."""

    @abc.abstractmethod
    def tokenize(
        self, text: bytes, add_bos: bool = True, special: bool = True
    ) -> List[int]:
        """Convert UTF-8 encoded text into a list of token ids.

        Args:
            text: The utf-8 encoded string to tokenize.
            add_bos: Whether to add a beginning of sequence token.
            special: Whether to tokenize special tokens.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def detokenize(
        self,
        tokens: List[int],
        prev_tokens: Optional[List[int]] = None,
        special: bool = False,
    ) -> bytes:
        """Convert a list of token ids back into bytes.

        Args:
            tokens: The list of tokens to detokenize.
            prev_tokens: The list of previous tokens. Offset mapping will be performed if provided.
            special: Whether to detokenize special tokens.
        """
        raise NotImplementedError
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class LlamaTokenizer(BaseLlamaTokenizer):
    """Tokenizer that delegates to the ``_model`` object of a ``llama_cpp.Llama`` instance."""

    def __init__(self, llama: llama_cpp.Llama):
        self._model = llama._model  # type: ignore

    def tokenize(
        self, text: bytes, add_bos: bool = True, special: bool = True
    ) -> List[int]:
        """Tokenize UTF-8 bytes via the wrapped model."""
        token_ids = self._model.tokenize(text, add_bos=add_bos, special=special)
        return token_ids

    def detokenize(
        self,
        tokens: List[int],
        prev_tokens: Optional[List[int]] = None,
        special: bool = False,
    ) -> bytes:
        """Detokenize via the wrapped model; ``prev_tokens`` is accepted but not used here."""
        raw = self._model.detokenize(tokens, special=special)
        return raw

    def encode(
        self, text: str, add_bos: bool = True, special: bool = True
    ) -> List[int]:
        """Tokenize a ``str`` by UTF-8 encoding it first (unencodable characters are dropped)."""
        payload = text.encode("utf-8", errors="ignore")
        return self.tokenize(payload, add_bos=add_bos, special=special)

    def decode(self, tokens: List[int]) -> str:
        """Detokenize to ``str``, ignoring bytes that are not valid UTF-8."""
        raw = self.detokenize(tokens)
        return raw.decode("utf-8", errors="ignore")

    @classmethod
    def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
        """Build a tokenizer from a model file loaded with ``vocab_only=True``."""
        vocab_only_model = llama_cpp.Llama(model_path=path, vocab_only=True)
        return cls(vocab_only_model)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class LlamaHFTokenizer(BaseLlamaTokenizer):
    """Tokenizer adapter around a Hugging Face style tokenizer object.

    The wrapped object must provide ``encode(text, add_special_tokens=...)``
    and ``decode(tokens, skip_special_tokens=...)``.
    """

    def __init__(self, hf_tokenizer: Any):
        self.hf_tokenizer = hf_tokenizer

    def tokenize(
        self, text: bytes, add_bos: bool = True, special: bool = True
    ) -> List[int]:
        """Decode ``text`` as UTF-8 and encode it with the wrapped tokenizer.

        ``special`` is forwarded as ``add_special_tokens``. ``add_bos`` is
        accepted for interface compatibility but is not forwarded.
        """
        return self.hf_tokenizer.encode(
            text.decode("utf-8", errors="ignore"), add_special_tokens=special
        )

    def detokenize(
        self,
        tokens: List[int],
        prev_tokens: Optional[List[int]] = None,
        special: bool = False,
    ) -> bytes:
        """Decode ``tokens`` to UTF-8 bytes.

        When ``prev_tokens`` is given, ``prev_tokens + tokens`` is decoded as
        a whole and the bytes produced by ``prev_tokens`` alone are cut off
        the front, so only the text contributed by ``tokens`` is returned.
        """
        skip_special_tokens = not special
        if prev_tokens is not None:
            text = self.hf_tokenizer.decode(
                prev_tokens + tokens, skip_special_tokens=skip_special_tokens
            ).encode("utf-8", errors="ignore")
            prev_text = self.hf_tokenizer.decode(
                prev_tokens, skip_special_tokens=skip_special_tokens
            ).encode("utf-8", errors="ignore")
            return text[len(prev_text) :]
        else:
            return self.hf_tokenizer.decode(
                tokens, skip_special_tokens=skip_special_tokens
            ).encode("utf-8", errors="ignore")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str) -> "LlamaHFTokenizer":
        """Load a tokenizer with ``transformers.AutoTokenizer.from_pretrained``.

        Raises:
            ImportError: If the optional ``transformers`` package is missing.
        """
        try:
            from transformers import AutoTokenizer
        except ImportError as err:
            # The two literals are concatenated, so the first must end with a
            # space to keep the sentences apart in the final message.
            raise ImportError(
                "The `transformers` library is required to use the `HFTokenizer`. "
                "You can install it with `pip install transformers`."
            ) from err
        hf_tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=pretrained_model_name_or_path
        )
        return cls(hf_tokenizer)
|
llama_cpp/llama_types.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Types and request signatures for OpenAI compatibility
|
| 2 |
+
|
| 3 |
+
NOTE: These types may change to match the OpenAI OpenAPI specification.
|
| 4 |
+
|
| 5 |
+
Based on the OpenAI OpenAPI specification:
|
| 6 |
+
https://github.com/openai/openai-openapi/blob/master/openapi.yaml
|
| 7 |
+
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from typing import Any, List, Optional, Dict, Union
|
| 11 |
+
from typing_extensions import TypedDict, NotRequired, Literal
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# NOTE: Defining this correctly using annotations seems to break pydantic validation.
# This is a workaround until we can figure out how to do this correctly
# JsonType = Union[None, int, str, bool, List["JsonType"], Dict[str, "JsonType"]]
# Non-recursive stand-in: nested lists/dicts are typed as Any.
# NOTE(review): `float` is not a member of this union; confirm that is intended.
JsonType = Union[None, int, str, bool, List[Any], Dict[str, Any]]


class EmbeddingUsage(TypedDict):
    """Token counts attached to an embedding response."""

    prompt_tokens: int
    total_tokens: int


class Embedding(TypedDict):
    """One embedding entry; ``embedding`` is a flat or nested list of floats."""

    index: int
    object: str
    embedding: Union[List[float], List[List[float]]]


class CreateEmbeddingResponse(TypedDict):
    """Embedding response: a list of ``Embedding`` entries plus usage."""

    object: Literal["list"]
    model: str
    data: List[Embedding]
    usage: EmbeddingUsage
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class CompletionLogprobs(TypedDict):
    """Log-probability data for a text completion, stored as four lists."""

    text_offset: List[int]
    token_logprobs: List[Optional[float]]
    tokens: List[str]
    top_logprobs: List[Optional[Dict[str, float]]]


class CompletionChoice(TypedDict):
    """One completion choice; ``logprobs`` and ``finish_reason`` may be None."""

    text: str
    index: int
    logprobs: Optional[CompletionLogprobs]
    finish_reason: Optional[Literal["stop", "length"]]


class CompletionUsage(TypedDict):
    """Prompt, completion and total token counts."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class CreateCompletionResponse(TypedDict):
    """Text-completion response; ``usage`` is the only optional key."""

    id: str
    object: Literal["text_completion"]
    created: int
    model: str
    choices: List[CompletionChoice]
    usage: NotRequired[CompletionUsage]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class ChatCompletionResponseFunctionCall(TypedDict):
    """Function call in a response message: a name and an ``arguments`` string."""

    name: str
    arguments: str


class ChatCompletionResponseMessage(TypedDict):
    """Message carried by a chat completion choice."""

    content: Optional[str]
    # Quoted because ChatCompletionMessageToolCalls is defined later in this module.
    tool_calls: NotRequired["ChatCompletionMessageToolCalls"]
    role: Literal["assistant", "function"]  # NOTE: "function" may be incorrect here
    function_call: NotRequired[ChatCompletionResponseFunctionCall]  # DEPRECATED


class ChatCompletionFunction(TypedDict):
    """Function description: name, optional description, and parameters mapping."""

    name: str
    description: NotRequired[str]
    parameters: Dict[str, JsonType]  # TODO: make this more specific


class ChatCompletionTopLogprobToken(TypedDict):
    """A token with its log-probability and optional byte values."""

    token: str
    logprob: float
    bytes: Optional[List[int]]


class ChatCompletionLogprobToken(ChatCompletionTopLogprobToken):
    """A token entry that additionally lists ``top_logprobs`` alternatives."""

    # The first three fields repeat the base class declarations unchanged.
    token: str
    logprob: float
    bytes: Optional[List[int]]
    top_logprobs: List[ChatCompletionTopLogprobToken]


class ChatCompletionLogprobs(TypedDict):
    """Log-probability token lists for ``content`` and ``refusal``; either may be None."""

    content: Optional[List[ChatCompletionLogprobToken]]
    refusal: Optional[List[ChatCompletionLogprobToken]]


class ChatCompletionResponseChoice(TypedDict):
    """One choice of a chat completion response."""

    index: int
    message: "ChatCompletionResponseMessage"
    logprobs: Optional[ChatCompletionLogprobs]
    finish_reason: Optional[str]


class CreateChatCompletionResponse(TypedDict):
    """Non-streaming chat completion response."""

    id: str
    object: Literal["chat.completion"]
    created: int
    model: str
    choices: List["ChatCompletionResponseChoice"]
    usage: CompletionUsage
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class ChatCompletionMessageToolCallChunkFunction(TypedDict):
    """Function part of a tool-call chunk; ``name`` may be None."""

    name: Optional[str]
    arguments: str


class ChatCompletionMessageToolCallChunk(TypedDict):
    """Tool-call chunk in a streamed delta; ``id`` is optional."""

    index: int
    id: NotRequired[str]
    type: Literal["function"]
    function: ChatCompletionMessageToolCallChunkFunction


class ChatCompletionStreamResponseDeltaEmpty(TypedDict):
    """Delta with no keys."""

    pass


class ChatCompletionStreamResponseDeltaFunctionCall(TypedDict):
    """Function call carried by a streamed delta."""

    name: str
    arguments: str


class ChatCompletionStreamResponseDelta(TypedDict):
    """Streamed delta; every key is optional and may also be None."""

    content: NotRequired[Optional[str]]
    function_call: NotRequired[
        Optional[ChatCompletionStreamResponseDeltaFunctionCall]
    ]  # DEPRECATED
    tool_calls: NotRequired[Optional[List[ChatCompletionMessageToolCallChunk]]]
    role: NotRequired[Optional[Literal["system", "user", "assistant", "tool"]]]


class ChatCompletionStreamResponseChoice(TypedDict):
    """One choice of a streamed chunk; ``delta`` is a populated or empty delta."""

    index: int
    delta: Union[
        ChatCompletionStreamResponseDelta, ChatCompletionStreamResponseDeltaEmpty
    ]
    finish_reason: Optional[Literal["stop", "length", "tool_calls", "function_call"]]
    logprobs: NotRequired[Optional[ChatCompletionLogprobs]]


class CreateChatCompletionStreamResponse(TypedDict):
    """One chunk of a streamed chat completion."""

    id: str
    model: str
    object: Literal["chat.completion.chunk"]
    created: int
    choices: List[ChatCompletionStreamResponseChoice]
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class ChatCompletionFunctions(TypedDict):
    """Function description with the same fields as ``ChatCompletionFunction``."""

    name: str
    description: NotRequired[str]
    parameters: Dict[str, JsonType]  # TODO: make this more specific


class ChatCompletionFunctionCallOption(TypedDict):
    """Names a single function."""

    name: str


class ChatCompletionRequestResponseFormat(TypedDict):
    """Requested response format, optionally with a JSON ``schema``."""

    type: Literal["text", "json_object"]
    schema: NotRequired[
        JsonType
    ]  # https://docs.endpoints.anyscale.com/guides/json_mode/


class ChatCompletionRequestMessageContentPartText(TypedDict):
    """Text content part."""

    type: Literal["text"]
    text: str


class ChatCompletionRequestMessageContentPartImageImageUrl(TypedDict):
    """Image URL with an optional ``detail`` setting."""

    url: str
    detail: NotRequired[Literal["auto", "low", "high"]]


class ChatCompletionRequestMessageContentPartImage(TypedDict):
    """Image content part; ``image_url`` is a plain string or a URL dict."""

    type: Literal["image_url"]
    image_url: Union[str, ChatCompletionRequestMessageContentPartImageImageUrl]


# A content part is either a text part or an image part.
ChatCompletionRequestMessageContentPart = Union[
    ChatCompletionRequestMessageContentPartText,
    ChatCompletionRequestMessageContentPartImage,
]
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class ChatCompletionRequestSystemMessage(TypedDict):
    """Request message with role ``"system"``."""

    role: Literal["system"]
    content: Optional[str]


class ChatCompletionRequestUserMessage(TypedDict):
    """Request message with role ``"user"``; content is a string, a list of content parts, or None."""

    role: Literal["user"]
    content: Optional[Union[str, List[ChatCompletionRequestMessageContentPart]]]


class ChatCompletionMessageToolCallFunction(TypedDict):
    """Function part of a tool call: a name and an ``arguments`` string."""

    name: str
    arguments: str


class ChatCompletionMessageToolCall(TypedDict):
    """A complete tool call with its id."""

    id: str
    type: Literal["function"]
    function: ChatCompletionMessageToolCallFunction


# List alias used by the `tool_calls` fields of the message types.
ChatCompletionMessageToolCalls = List[ChatCompletionMessageToolCall]


class ChatCompletionRequestAssistantMessageFunctionCall(TypedDict):
    """Function call attached to an assistant request message."""

    name: str
    arguments: str


class ChatCompletionRequestAssistantMessage(TypedDict):
    """Request message with role ``"assistant"``; only ``role`` is required."""

    role: Literal["assistant"]
    content: NotRequired[str]
    tool_calls: NotRequired[ChatCompletionMessageToolCalls]
    function_call: NotRequired[
        ChatCompletionRequestAssistantMessageFunctionCall
    ]  # DEPRECATED


class ChatCompletionRequestToolMessage(TypedDict):
    """Request message with role ``"tool"``, tied to a ``tool_call_id``."""

    role: Literal["tool"]
    content: Optional[str]
    tool_call_id: str


class ChatCompletionRequestFunctionMessage(TypedDict):
    """Request message with role ``"function"``, identified by ``name``."""

    role: Literal["function"]
    content: Optional[str]
    name: str
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# Any message shape accepted in a chat completion request. Each member type is
# listed exactly once (a repeated member is redundant in a Union).
ChatCompletionRequestMessage = Union[
    ChatCompletionRequestSystemMessage,
    ChatCompletionRequestUserMessage,
    ChatCompletionRequestAssistantMessage,
    ChatCompletionRequestToolMessage,
    ChatCompletionRequestFunctionMessage,
]
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
class ChatCompletionRequestFunctionCallOption(TypedDict):
    """Names a single function in a request."""

    name: str


# Either the literal "none"/"auto" or a dict naming one function.
ChatCompletionRequestFunctionCall = Union[
    Literal["none", "auto"], ChatCompletionRequestFunctionCallOption
]

ChatCompletionFunctionParameters = Dict[str, JsonType]  # TODO: make this more specific


class ChatCompletionToolFunction(TypedDict):
    """Function description carried by a tool definition."""

    name: str
    description: NotRequired[str]
    parameters: ChatCompletionFunctionParameters


class ChatCompletionTool(TypedDict):
    """Tool definition; ``"function"`` is the only declared type."""

    type: Literal["function"]
    function: ChatCompletionToolFunction


class ChatCompletionNamedToolChoiceFunction(TypedDict):
    """Names the function of a named tool choice."""

    name: str


class ChatCompletionNamedToolChoice(TypedDict):
    """Tool choice that names one function."""

    type: Literal["function"]
    function: ChatCompletionNamedToolChoiceFunction


# Either the literal "none"/"auto"/"required" or a named tool choice.
ChatCompletionToolChoiceOption = Union[
    Literal["none", "auto", "required"], ChatCompletionNamedToolChoice
]
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
# NOTE: The following type names are not part of the OpenAI OpenAPI specification
# and will be removed in a future major release.
# Each name below is a plain alias of a type defined above; no new type is created.

EmbeddingData = Embedding
CompletionChunk = CreateCompletionResponse
Completion = CreateCompletionResponse
CreateCompletionStreamResponse = CreateCompletionResponse
ChatCompletionMessage = ChatCompletionResponseMessage
ChatCompletionChoice = ChatCompletionResponseChoice
ChatCompletion = CreateChatCompletionResponse
ChatCompletionChunkDeltaEmpty = ChatCompletionStreamResponseDeltaEmpty
ChatCompletionChunkChoice = ChatCompletionStreamResponseChoice
ChatCompletionChunkDelta = ChatCompletionStreamResponseDelta
ChatCompletionChunk = CreateChatCompletionStreamResponse
ChatCompletionStreamResponse = CreateChatCompletionStreamResponse
ChatCompletionResponseFunction = ChatCompletionFunction
ChatCompletionFunctionCall = ChatCompletionResponseFunctionCall
|
llama_cpp/llava_cpp.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from ctypes import (
|
| 5 |
+
c_bool,
|
| 6 |
+
c_char_p,
|
| 7 |
+
c_int,
|
| 8 |
+
c_uint8,
|
| 9 |
+
c_float,
|
| 10 |
+
c_void_p,
|
| 11 |
+
POINTER,
|
| 12 |
+
_Pointer, # type: ignore
|
| 13 |
+
Structure,
|
| 14 |
+
)
|
| 15 |
+
import pathlib
|
| 16 |
+
from typing import (
|
| 17 |
+
Union,
|
| 18 |
+
NewType,
|
| 19 |
+
Optional,
|
| 20 |
+
TYPE_CHECKING,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
import llama_cpp.llama_cpp as llama_cpp
|
| 24 |
+
|
| 25 |
+
from llama_cpp._ctypes_extensions import (
|
| 26 |
+
load_shared_library,
|
| 27 |
+
ctypes_function_for_shared_library,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
if TYPE_CHECKING:
|
| 31 |
+
from llama_cpp._ctypes_extensions import (
|
| 32 |
+
CtypesArray,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Specify the base name of the shared library to load
_libllava_base_name = "llava"
# Optional override taken from the LLAVA_CPP_LIB environment variable.
_libllava_override_path = os.environ.get("LLAVA_CPP_LIB")
# Default to the "lib" directory next to this file. When the override is set,
# an empty Path() is used instead.
# NOTE(review): the override value itself is never used to build the path —
# confirm whether pathlib.Path(_libllava_override_path) was intended.
_libllava_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" if _libllava_override_path is None else pathlib.Path()

# Load the library
_libllava = load_shared_library(_libllava_base_name, _libllava_base_path)

# Decorator factory bound to the loaded library; applied to every stub below.
ctypes_function = ctypes_function_for_shared_library(_libllava)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
################################################
|
| 48 |
+
# llava.h
|
| 49 |
+
################################################
|
| 50 |
+
|
| 51 |
+
# struct clip_ctx;
# Static-typing alias for a clip context handle (a NewType over int).
clip_ctx_p = NewType("clip_ctx_p", int)
# ctypes type used for that handle in the function signatures below.
clip_ctx_p_ctypes = c_void_p
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# struct llava_image_embed {
#     float * embed;
#     int n_image_pos;
# };
class llava_image_embed(Structure):
    """ctypes mirror of the C ``struct llava_image_embed`` shown above."""

    # Field order and types follow the C declaration.
    _fields_ = [
        ("embed", POINTER(c_float)),
        ("n_image_pos", c_int),
    ]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# /** sanity check for clip <-> llava embed size match */
|
| 68 |
+
# LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip);
|
| 69 |
+
@ctypes_function(
    "llava_validate_embed_size",
    [llama_cpp.llama_context_p_ctypes, clip_ctx_p_ctypes],
    c_bool,
)
def llava_validate_embed_size(
    ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p, /
) -> bool:
    """Sanity check for clip <-> llava embed size match."""
    ...
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# /** build an image embed from image file bytes */
|
| 81 |
+
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length);
|
| 82 |
+
@ctypes_function(
    "llava_image_embed_make_with_bytes",
    [clip_ctx_p_ctypes, c_int, POINTER(c_uint8), c_int],
    POINTER(llava_image_embed),
)
def llava_image_embed_make_with_bytes(
    ctx_clip: clip_ctx_p,
    n_threads: Union[c_int, int],
    image_bytes: CtypesArray[c_uint8],
    image_bytes_length: Union[c_int, int],
    /,
) -> "_Pointer[llava_image_embed]":
    """Build an image embed from image file bytes."""
    ...
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# /** build an image embed from a path to an image filename */
|
| 98 |
+
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path);
|
| 99 |
+
@ctypes_function(
    "llava_image_embed_make_with_filename",
    [clip_ctx_p_ctypes, c_int, c_char_p],
    POINTER(llava_image_embed),
)
def llava_image_embed_make_with_filename(
    ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes, /
) -> "_Pointer[llava_image_embed]":
    """Build an image embed from a path to an image filename."""
    ...
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed);
|
| 111 |
+
# /** free an embedding made with llava_image_embed_make_* */
|
| 112 |
+
@ctypes_function("llava_image_embed_free", [POINTER(llava_image_embed)], None)
def llava_image_embed_free(embed: "_Pointer[llava_image_embed]", /):
    """Free an embedding made with ``llava_image_embed_make_*``."""
    ...
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
|
| 118 |
+
# LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past);
|
| 119 |
+
@ctypes_function(
    "llava_eval_image_embed",
    [
        llama_cpp.llama_context_p_ctypes,
        POINTER(llava_image_embed),
        c_int,
        POINTER(c_int),
    ],
    c_bool,
)
def llava_eval_image_embed(
    ctx_llama: llama_cpp.llama_context_p,
    embed: "_Pointer[llava_image_embed]",
    n_batch: Union[c_int, int],
    n_past: "_Pointer[c_int]",
    /,
) -> bool:
    """Write the image represented by ``embed`` into the llama context.

    Uses batch size ``n_batch`` and starts at context position ``n_past``. On
    completion, ``n_past`` points to the next position in the context after
    the image embed.
    """
    ...
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
################################################
|
| 140 |
+
# clip.h
|
| 141 |
+
################################################
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# /** load mmproj model */
|
| 145 |
+
# CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity);
|
| 146 |
+
@ctypes_function(
    "clip_model_load",
    [c_char_p, c_int],
    clip_ctx_p_ctypes,
)
def clip_model_load(
    fname: bytes,
    verbosity: Union[c_int, int],
    /,
) -> Optional[clip_ctx_p]:
    """Load an mmproj model.

    Binding stub for the C function ``clip_model_load``.

    Args:
        fname: Model file name as bytes (``const char *`` on the C side).
        verbosity: Passed to C as ``int verbosity``.

    Returns:
        The clip context pointer, or ``None`` (presumably when loading
        fails; confirm against clip.h).
    """
    ...
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# /** free mmproj model */
|
| 154 |
+
# CLIP_API void clip_free(struct clip_ctx * ctx);
|
| 155 |
+
@ctypes_function(
    "clip_free",
    [clip_ctx_p_ctypes],
    None,
)
def clip_free(
    ctx: clip_ctx_p,
    /,
):
    """Free an mmproj model.

    Binding stub for the C function ``clip_free``; the C function returns
    ``void``.

    Args:
        ctx: Clip context pointer to release.
    """
    ...
|
| 158 |
+
|
llama_cpp/py.typed
ADDED
|
File without changes
|
llama_cpp/server/__init__.py
ADDED
|
File without changes
|
llama_cpp/server/__main__.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Example FastAPI server for llama.cpp.
|
| 2 |
+
|
| 3 |
+
To run this example:
|
| 4 |
+
|
| 5 |
+
```bash
|
| 6 |
+
pip install fastapi uvicorn sse-starlette pydantic-settings
|
| 7 |
+
export MODEL=../models/7B/...
|
| 8 |
+
```
|
| 9 |
+
|
| 10 |
+
Then run:
|
| 11 |
+
```
|
| 12 |
+
uvicorn llama_cpp.server.app:create_app --reload
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
or
|
| 16 |
+
|
| 17 |
+
```
|
| 18 |
+
python3 -m llama_cpp.server
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
Then visit http://localhost:8000/docs to see the interactive API docs.
|
| 22 |
+
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from __future__ import annotations
|
| 26 |
+
|
| 27 |
+
import os
|
| 28 |
+
import sys
|
| 29 |
+
import argparse
|
| 30 |
+
|
| 31 |
+
import uvicorn
|
| 32 |
+
|
| 33 |
+
from llama_cpp.server.app import create_app
|
| 34 |
+
from llama_cpp.server.settings import (
|
| 35 |
+
Settings,
|
| 36 |
+
ServerSettings,
|
| 37 |
+
ModelSettings,
|
| 38 |
+
ConfigFileSettings,
|
| 39 |
+
)
|
| 40 |
+
from llama_cpp.server.cli import add_args_from_model, parse_model_from_args
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def main():
    """Parse settings from the CLI or a config file and serve the app.

    Settings come from one of two sources:

    * a config file, named by the ``CONFIG_FILE`` environment variable or,
      when that is unset, by ``--config_file``. Files ending in ``.yaml`` or
      ``.yml`` are parsed as YAML; anything else is read as JSON.
    * otherwise, the command-line arguments generated from ``Settings``.

    Any exception raised while loading settings is printed to stderr,
    followed by the argparse help text, and the process exits with status 1.

    The ``HOST`` and ``PORT`` environment variables, when set, take the place
    of the host and port from the server settings passed to ``uvicorn.run``.
    """
    parser = argparse.ArgumentParser(
        description="🦙 Llama.cpp python server. Host your own LLMs!🚀"
    )
    add_args_from_model(parser, Settings)
    parser.add_argument(
        "--config_file",
        type=str,
        help="Path to a config file to load.",
    )

    server_settings: ServerSettings | None = None
    model_settings: list[ModelSettings] = []
    args = parser.parse_args()

    try:
        # The environment variable wins over the command-line flag.
        config_file = os.environ.get("CONFIG_FILE", args.config_file)
        if not config_file:
            # No config file: build both settings objects from the CLI args.
            server_settings = parse_model_from_args(ServerSettings, args)
            model_settings = [parse_model_from_args(ModelSettings, args)]
        else:
            if not os.path.exists(config_file):
                raise ValueError(f"Config file {config_file} not found!")
            with open(config_file, "rb") as stream:
                if config_file.endswith((".yaml", ".yml")):
                    # Imported lazily so they are only needed for YAML configs.
                    import yaml
                    import json

                    # YAML is converted to JSON text so both formats go
                    # through the same JSON validation path.
                    payload = json.dumps(yaml.safe_load(stream))
                else:
                    payload = stream.read()
                config_file_settings = ConfigFileSettings.model_validate_json(
                    payload
                )
                server_settings = ServerSettings.model_validate(config_file_settings)
                model_settings = config_file_settings.models
    except Exception as err:
        print(err, file=sys.stderr)
        parser.print_help()
        sys.exit(1)

    assert server_settings is not None
    assert model_settings is not None

    app = create_app(
        server_settings=server_settings,
        model_settings=model_settings,
    )

    host = os.getenv("HOST", server_settings.host)
    port = int(os.getenv("PORT", server_settings.port))
    uvicorn.run(
        app,
        host=host,
        port=port,
        ssl_keyfile=server_settings.ssl_keyfile,
        ssl_certfile=server_settings.ssl_certfile,
    )
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# Start the server when this module is executed (e.g. `python3 -m llama_cpp.server`).
if __name__ == "__main__":
    main()
|
llama_cpp/server/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (239 Bytes). View file
|
|
|
llama_cpp/server/__pycache__/__main__.cpython-310.pyc
ADDED
|
Binary file (2.37 kB). View file
|
|
|
llama_cpp/server/__pycache__/app.cpython-310.pyc
ADDED
|
Binary file (12.9 kB). View file
|
|
|
llama_cpp/server/__pycache__/cli.cpython-310.pyc
ADDED
|
Binary file (3.21 kB). View file
|
|
|
llama_cpp/server/__pycache__/errors.cpython-310.pyc
ADDED
|
Binary file (5.54 kB). View file
|
|
|
llama_cpp/server/__pycache__/model.cpython-310.pyc
ADDED
|
Binary file (6.45 kB). View file
|
|
|
llama_cpp/server/__pycache__/settings.cpython-310.pyc
ADDED
|
Binary file (7.32 kB). View file
|
|
|