import re
from functools import partial

import torch

from modules import shared
from modules.callbacks import Iteratorize
from modules.logging_colors import logger

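# Import from the CUDA-specific llama_cpp_cuda wheel when an NVIDIA (non-HIP) GPU is
# available; otherwise use the standard llama_cpp package (ROCm/HIP or CPU-only setups).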
if torch.cuda.is_available() and not torch.version.hip:
    try:
        from llama_cpp_cuda import Llama, LlamaCache, LogitsProcessorList
    except Exception:
        # Fall back to the standard package if the CUDA wheel is missing or fails to load
        from llama_cpp import Llama, LlamaCache, LogitsProcessorList
else:
    from llama_cpp import Llama, LlamaCache, LogitsProcessorList


def ban_eos_logits_processor(eos_token, input_ids, logits):
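    """Logits processor that prevents the EOS token from being sampled by setting its logit to -inf."""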
    logits[eos_token] = -float('inf')
    return logits


class LlamaCppModel:
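    """Wrapper around a llama.cpp model loaded through llama-cpp-python; the same instance is used as both model and tokenizer."""
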
    def __init__(self):
        self.initialized = False

    def __del__(self):
        self.model.__del__()

    @classmethod
    def from_pretrained(cls, path):
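        """Build the llama.cpp parameters from shared.args, load the model at `path`, and return the wrapper."""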
        result = cls()

        # Parse the optional cache capacity argument (accepts values like "8GiB", "500MiB", or a raw byte count)
        cache_capacity = 0
        if shared.args.cache_capacity is not None:
            if 'GiB' in shared.args.cache_capacity:
                cache_capacity = int(re.sub('[a-zA-Z]', '', shared.args.cache_capacity)) * 1000 * 1000 * 1000
            elif 'MiB' in shared.args.cache_capacity:
                cache_capacity = int(re.sub('[a-zA-Z]', '', shared.args.cache_capacity)) * 1000 * 1000
            else:
                cache_capacity = int(shared.args.cache_capacity)

        logger.info(f"Cache capacity is {cache_capacity} bytes")
        params = {
            'model_path': str(path),
            'n_ctx': shared.args.n_ctx,
            'seed': int(shared.args.llama_cpp_seed),
            'n_threads': shared.args.threads or None,
            'n_batch': shared.args.n_batch,
            'use_mmap': not shared.args.no_mmap,
            'use_mlock': shared.args.mlock,
            'low_vram': shared.args.low_vram,
            'n_gpu_layers': shared.args.n_gpu_layers,
            # RoPE settings: alpha_value applies NTK-aware scaling to the base frequency,
            # compress_pos_emb applies linear scaling
            'rope_freq_base': 10000 * shared.args.alpha_value ** (64/63.),
            'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
            'n_gqa': shared.args.n_gqa or None,
            'rms_norm_eps': shared.args.rms_norm_eps or None,
        }

        result.model = Llama(**params)
        if cache_capacity > 0:
            result.model.set_cache(LlamaCache(capacity_bytes=cache_capacity))

        # Returned twice so callers can unpack it as (model, tokenizer); this class plays both roles
        return result, result

    def encode(self, string):
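        """Tokenize `string` (str or bytes) into a list of llama.cpp token ids."""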
        if type(string) is str:
            string = string.encode()

        return self.model.tokenize(string)

    def decode(self, tokens):
        return self.model.detokenize(tokens)

    def generate(self, prompt, state, callback=None):
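        """Generate a completion for `prompt` using the sampling settings in `state`, streaming each new piece of text to `callback`."""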
        prompt = prompt if type(prompt) is str else prompt.decode()
        completion_chunks = self.model.create_completion(
            prompt=prompt,
            max_tokens=state['max_new_tokens'],
            temperature=state['temperature'],
            top_p=state['top_p'],
            top_k=state['top_k'],
            repeat_penalty=state['repetition_penalty'],
            tfs_z=state['tfs'],
            mirostat_mode=int(state['mirostat_mode']),
            mirostat_tau=state['mirostat_tau'],
            mirostat_eta=state['mirostat_eta'],
            stream=True,
            logits_processor=LogitsProcessorList([
                partial(ban_eos_logits_processor, self.model.token_eos()),
            ]) if state['ban_eos_token'] else None,
        )

        output = ""
        for completion_chunk in completion_chunks:
            text = completion_chunk['choices'][0]['text']
            output += text
            if callback:
                callback(text)

        return output

    def generate_with_streaming(self, *args, **kwargs):
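        """Wrap generate() with Iteratorize so the growing reply can be consumed as a generator."""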
        with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
            reply = ''
            for token in generator:
                reply += token
                yield reply
|