koichi12's picture
Add files using upload-large-folder tool
762d748 verified
raw
history blame
13.3 kB
import dataclasses
from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, TypedDict, Union
import torch
from typing_extensions import Unpack
from outlines.generate.api import GenerationParameters, SamplingParameters
if TYPE_CHECKING:
import torch.LongTensor
from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler
class ExllamaV2Params(TypedDict, total=False):
    """Keyword parameters exchanged with the ExLlamaV2 generator wrapper.

    Declared with ``total=False``: every key is optional.
    """

    # Declared here, but not read by the code in this module.
    max_tokens: int
    # Token ids and/or strings at which generation stops.
    stop_conditions: Optional[List[Union[int, str]]]
    # Sampling seed; ``None`` leaves the seed unspecified.
    seed: Optional[int]
    # Sampler settings object (temperature, top-p, top-k, logits processor).
    gen_settings: "ExLlamaV2Sampler.Settings"
    # One new-token budget per prompt.
    max_new_tokens: List[int]
class OutlinesExLlamaV2Tokenizer:
    """Thin adapter around an ExLlamaV2 tokenizer.

    Copies the vocabulary, the special-token pieces and the EOS token id off
    the wrapped tokenizer and exposes ``convert_token_to_string`` and
    ``decode`` helpers on top of it.
    """

    def __init__(self, tokenizer):
        """Wrap ``tokenizer`` and cache the attributes read from it.

        Parameters
        ----------
        tokenizer
            Object providing ``get_piece_to_id_dict()``,
            ``extended_piece_to_id``, ``eos_token_id`` and ``decode``.
        """
        self.exl2_tokenizer = tokenizer
        self.vocabulary = self.exl2_tokenizer.get_piece_to_id_dict()
        # Only the piece names (the mapping's keys) are kept.
        self.special_tokens = set(self.exl2_tokenizer.extended_piece_to_id)
        self.eos_token_id = self.exl2_tokenizer.eos_token_id

    def convert_token_to_string(self, token):
        """Return ``token`` unchanged."""
        return token

    def decode(self, token_ids: "torch.LongTensor") -> List[str]:
        """Decode token ids into a list of strings.

        Parameters
        ----------
        token_ids
            Token ids, either as a tensor or as anything ``torch.as_tensor``
            accepts (e.g. a list of ints).

        Returns
        -------
        A list of decoded strings; a single decoded string is wrapped in a
        one-element list.
        """
        decoded = self.exl2_tokenizer.decode(
            # `as_tensor` reuses an existing tensor instead of copying it,
            # which avoids the copy-construct UserWarning that
            # `torch.tensor(tensor)` emits, while still accepting lists.
            torch.as_tensor(token_ids),
            decode_special_tokens=False,
        )
        if isinstance(decoded, str):
            return [decoded]
        return decoded
class ExLlamaV2Model:
    """Represents a `exl2` model."""

    def __init__(
        self,
        generator: "ExLlamaV2DynamicGenerator",
        tokenizer: "OutlinesExLlamaV2Tokenizer",
        max_seq_len: int,
    ):
        """Store the generator, the tokenizer adapter and the context size.

        Parameters
        ----------
        generator
            The ExLlamaV2 dynamic generator used for all generation calls.
        tokenizer
            The Outlines-facing tokenizer adapter.
        max_seq_len
            Sequence length used to derive a new-token budget when the
            caller does not provide `max_tokens`.
        """
        self.generator = generator
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    def prepare_generation_parameters(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: GenerationParameters,
        sampling_parameters: SamplingParameters,
        structure_logits_processor,
        **exllamav2_params: Unpack[ExllamaV2Params],
    ) -> Tuple[ExllamaV2Params, Union[str, List[str]]]:
        """Prepare the generation parameters.

        `exllamav2` uses different default values

        Parameters
        ----------
        prompts
            A single prompt or a list of prompts.
        generation_parameters
            Dataclass unpacked as `(max_tokens, stop_at, seed)`.
        sampling_parameters
            Provides `temperature`, `top_p`, `top_k` and `num_samples`.
        structure_logits_processor
            Assigned to the sampler settings' `logits_processor`.
        exllamav2_params
            Extra parameters; the keys `max_new_tokens`, `stop_conditions`,
            `seed` and `gen_settings` are always overwritten here.

        Returns
        -------
        A tuple `(params, prompts)`. `prompts` is repeated `num_samples`
        times when `num_samples > 1` and collapsed to a bare string when
        exactly one prompt remains.
        """
        from exllamav2.generator import ExLlamaV2Sampler

        if isinstance(prompts, str):
            prompts = [prompts]

        max_tokens, stop_at, seed = dataclasses.astuple(generation_parameters)

        if max_tokens is None:
            # No explicit budget: each prompt gets `max_seq_len` minus its own
            # token count.
            exllamav2_params["max_new_tokens"] = [
                self.max_seq_len
                - self.generator.tokenizer.encode(
                    prompt, encode_special_tokens=True
                ).shape[-1]
                for prompt in prompts
            ]
        else:
            exllamav2_params["max_new_tokens"] = [max_tokens] * len(prompts)

        # The EOS token always stops generation; user stops are added on top.
        stop_conditions: List[Union[int, str]] = [
            self.generator.tokenizer.eos_token_id
        ]
        if isinstance(stop_at, str):
            stop_conditions.append(stop_at)
        elif isinstance(stop_at, list):
            stop_conditions.extend(stop_at)
        exllamav2_params["stop_conditions"] = stop_conditions
        exllamav2_params["seed"] = seed

        # Start from the library defaults and override only what was set.
        gen_settings = ExLlamaV2Sampler.Settings()
        if sampling_parameters.temperature is not None:
            gen_settings.temperature = sampling_parameters.temperature
        if sampling_parameters.top_p is not None:
            gen_settings.top_p = sampling_parameters.top_p
        if sampling_parameters.top_k is not None:
            gen_settings.top_k = sampling_parameters.top_k
        gen_settings.logits_processor = structure_logits_processor
        exllamav2_params["gen_settings"] = gen_settings

        if sampling_parameters.num_samples > 1:
            # Sampling several times is done by repeating the prompt list;
            # the budgets are repeated the same way to stay aligned.
            prompts = prompts * sampling_parameters.num_samples
            exllamav2_params["max_new_tokens"] = (
                exllamav2_params["max_new_tokens"] * sampling_parameters.num_samples
            )

        if len(prompts) == 1:
            prompts = prompts[0]

        return exllamav2_params, prompts

    def reformat_output(
        self, output: Union[str, List[str]], sampling_parameters: SamplingParameters
    ):
        """
        The purpose of this function is to reformat the output from exllamav2's output format to outline's output format
        For exllamav2, it mainly accepts only a list or a string(they also do cfg sampling with tuples but we will ignore this for now)
        The exllamav2's logic is
        1. If the prompt is a string, return a string. This is the same as outlines
        2. If a prompt is a list, return a list. This is not the same as outlines output in that if the list is only one element, the string is expected to be outputted.
        3. There is no such thing as num_samples, so the prompts had to be duplicated by num_samples times. Then, we had the function output a list of lists

        Returns
        -------
        A string, a flat list of strings, or (when `num_samples > 1` and the
        output holds more than `num_samples` items) a list of `num_samples`
        lists, each holding one consecutive run of the flat output.
        """
        if isinstance(output, str):
            return output
        if len(output) == 1:
            return output[0]

        num_samples = sampling_parameters.num_samples
        if num_samples <= 1 or len(output) == num_samples:
            return output

        assert len(output) % num_samples == 0
        num_items_per_sample = len(output) // num_samples
        # Consecutive slices: slice `i` holds the outputs of the i-th
        # repetition of the prompt list.
        return [
            list(output[i * num_items_per_sample : (i + 1) * num_items_per_sample])
            for i in range(num_samples)
        ]

    def generate(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: GenerationParameters,
        structure_logits_processor,
        sampling_parameters: SamplingParameters,
        **exllamav2_params: Unpack[ExllamaV2Params],
    ) -> Union[str, List[str]]:
        """Generate completions for `prompts` in a single generator call.

        Parameters
        ----------
        prompts
            A single prompt or a list of prompts.
        generation_parameters
            Provides `max_tokens`, `stop_at` and `seed`.
        structure_logits_processor
            Logits processor attached to the sampler settings.
        sampling_parameters
            Provides `temperature`, `top_p`, `top_k` and `num_samples`.
        exllamav2_params
            Accepted for interface compatibility; not forwarded to
            `prepare_generation_parameters`.

        Returns
        -------
        The generator output passed through `reformat_output`.
        """
        params, prompts = self.prepare_generation_parameters(
            prompts,
            generation_parameters,
            sampling_parameters,
            structure_logits_processor,
        )

        # In exllamav2, it needs the max amount of new tokens generated.
        # params["max_new_tokens"] is a list because
        # prepare_generation_parameters computes, for each prompt (by encoding
        # it with the tokenizer), how many tokens the model can still generate.
        # The minimum is picked because otherwise it might be possible for one
        # of the prompts to exceed the max sequence length.
        output = self.generator.generate(
            prompt=prompts,
            gen_settings=params["gen_settings"],
            max_new_tokens=min(params["max_new_tokens"]),
            completion_only=True,
            encode_special_tokens=True,
            stop_conditions=params["stop_conditions"],
            add_bos=False,
            seed=params["seed"],
        )

        return self.reformat_output(output, sampling_parameters)

    def stream(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: GenerationParameters,
        structure_logits_processor,
        sampling_parameters: SamplingParameters,
        **exllamav2_params: Unpack[ExllamaV2Params],
    ) -> Iterator[Union[str, List[str]]]:
        """Stream completions for `prompts`, one text chunk per generator step.

        One job is enqueued per prompt. Each value yielded holds the text
        produced for every prompt during one `iterate()` round (an empty
        string for prompts that produced nothing in that round), formatted
        by `reformat_output`.

        Parameters
        ----------
        prompts
            A single prompt or a list of prompts.
        generation_parameters
            Provides `max_tokens`, `stop_at` and `seed`.
        structure_logits_processor
            Logits processor attached to the sampler settings.
        sampling_parameters
            Provides `temperature`, `top_p`, `top_k` and `num_samples`.
        exllamav2_params
            Accepted for interface compatibility; not forwarded to
            `prepare_generation_parameters`.

        Returns
        -------
        An iterator over the per-round text chunks.
        """
        from exllamav2.generator import ExLlamaV2DynamicJob

        params, prompts = self.prepare_generation_parameters(
            prompts,
            generation_parameters,
            sampling_parameters,
            structure_logits_processor,
        )

        # prepare_generation_parameters collapses a single prompt to a bare
        # string; jobs are enqueued per prompt, so restore the list form.
        if isinstance(prompts, str):
            prompts = [prompts]
        batch_size = len(prompts)

        # Maps the serial returned by `enqueue` to the prompt's position.
        order = {}
        seed = params["seed"]
        for idx, prompt in enumerate(prompts):
            input_ids = self.generator.tokenizer.encode(
                prompt, encode_special_tokens=True, add_bos=False
            )
            job = ExLlamaV2DynamicJob(
                input_ids=input_ids,
                max_new_tokens=params["max_new_tokens"][idx],
                min_new_tokens=0,
                seed=seed,
                stop_conditions=params["stop_conditions"],
                gen_settings=params["gen_settings"],
                token_healing=False,
                decode_special_tokens=False,
            )
            # Give every job a distinct seed so repeated prompts differ.
            if seed is not None:
                seed += 1

            serial = self.generator.enqueue(job)
            order[serial] = idx

        def token_generator() -> Iterator[Union[str, List[str]]]:
            # Collect outputs until all jobs finish
            while self.generator.num_remaining_jobs():
                # Start every round from an empty buffer so a chunk is
                # yielded exactly once: a job with no result in this round
                # contributes "" instead of repeating its previous chunk.
                next_text = [""] * batch_size
                for result in self.generator.iterate():
                    if result["stage"] == "streaming":
                        # The text of a result flagged `eos` is kept too;
                        # blanking it would drop the final chunk of the job.
                        next_text[order[result["serial"]]] = result.get("text", "")
                yield self.reformat_output(next_text, sampling_parameters)

        return token_generator()
def exl2(
    model_path: str,
    draft_model_path: Optional[str] = None,
    max_seq_len: Optional[int] = None,
    cache_q4: bool = False,
    paged: bool = True,
    max_chunk_size: Optional[int] = None,
) -> ExLlamaV2Model:
    """
    Load an ExLlamaV2 model.

    Parameters
    ----------
    model_path (str)
        Path to the model directory.
    draft_model_path (Optional[str], optional)
        Path to a draft model directory. When given, a draft model and its
        cache are loaded and handed to the generator. Defaults to None.
    max_seq_len (Optional[int], optional)
        Maximum sequence length used to size the caches. When None, -1 is
        passed to the cache constructor and the resulting
        `cache.max_seq_len` is used. Defaults to None.
    cache_q4 (bool, optional)
        Use Q4 cache. Defaults to False.
    paged (bool, optional)
        Forwarded to the generator. The generator's maximum batch size is 4
        when True and 1 when False. Defaults to True.
    max_chunk_size (Optional[int], optional)
        When given, sets `config.max_input_len` to this value and
        `config.max_attention_size` to its square, and is forwarded to the
        generator. Defaults to None.

    Returns
    -------
    An `ExLlamaV2Model` instance.

    Raises
    ------
    `ImportError` if the `exllamav2` library is not installed.

    """
    try:
        from exllamav2 import (
            ExLlamaV2,
            ExLlamaV2Cache,
            ExLlamaV2Cache_Q4,
            ExLlamaV2Config,
            ExLlamaV2Tokenizer,
        )
        from exllamav2.generator import ExLlamaV2DynamicGenerator
    except ImportError as exc:
        # Chain the original error so the actual missing module stays visible.
        raise ImportError(
            "The `exllamav2`, `transformers` and `torch` libraries needs to be installed in order to use `exllamav2` models. "
            "Please run `pip install transformers torch git+https://github.com/lapp0/exllamav2@sampler-logits-processor` "
            "Documentation: https://dottxt-ai.github.io/outlines/latest/reference/models/exllamav2/"
        ) from exc

    config = ExLlamaV2Config(model_path)
    if max_chunk_size is not None:
        config.max_input_len = max_chunk_size
        config.max_attention_size = max_chunk_size**2

    config.arch_compat_overrides()
    model = ExLlamaV2(config)

    if max_seq_len is None:
        max_seq_len = -1

    # The same cache class is used for the main model and the draft model.
    cache_cls = ExLlamaV2Cache_Q4 if cache_q4 else ExLlamaV2Cache

    # The cache is created lazily; `load_autosplit` loads the model with it.
    cache = cache_cls(model, max_seq_len=max_seq_len, lazy=True)
    model.load_autosplit(cache, progress=True)

    print("Loading tokenizer...")
    tokenizer = ExLlamaV2Tokenizer(config)
    max_batch_size = 4 if paged else 1

    draft_model = None
    draft_cache = None
    if draft_model_path is not None:
        draft_config = ExLlamaV2Config(draft_model_path)
        draft_config.arch_compat_overrides()
        draft_model = ExLlamaV2(draft_config)
        draft_cache = cache_cls(draft_model, max_seq_len=max_seq_len, lazy=True)
        # The draft model must be loaded like the main one; without this
        # call the generator would receive an unloaded model and a lazy,
        # never-allocated cache.
        draft_model.load_autosplit(draft_cache, progress=True)

    # Initialize the generator with all default parameters
    generator = ExLlamaV2DynamicGenerator(
        model=model,
        cache=cache,
        draft_model=draft_model,
        draft_cache=draft_cache,
        tokenizer=tokenizer,
        max_batch_size=max_batch_size,
        use_ngram_draft=False,
        max_chunk_size=max_chunk_size,
        paged=paged,
    )

    # Read the effective length back from the cache (the value passed in may
    # have been the -1 placeholder).
    max_seq_len = cache.max_seq_len

    outlines_tokenizer = OutlinesExLlamaV2Tokenizer(tokenizer)
    return ExLlamaV2Model(generator, outlines_tokenizer, max_seq_len)