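"""LangChain LLM wrapper for GPTQ models served by the exllama backend.

Loads an ExLlama model, cache, tokenizer, and generator (optionally with a
LoRA applied), exposes them through LangChain's ``LLM`` interface, and streams
tokens with stop-sequence matching and broken-unicode handling.
"""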
from functools import partial
from enum import Enum
from typing import Any, Dict, Iterator, List, Optional

from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from pydantic import Field, root_validator

from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
from exllama.tokenizer import ExLlamaTokenizer
from exllama.generator import ExLlamaGenerator
from exllama.lora import ExLlamaLora

import os, glob
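# U+FFFD, the unicode replacement character: sentencepiece emits it while a
# multi-byte character is still incomplete, so the streamer below holds such
# chunks back instead of sending them to the callback.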
BROKEN_UNICODE = b'\\ufffd'.decode('unicode_escape')
class H2OExLlamaTokenizer(ExLlamaTokenizer):
    def __call__(self, text, *args, **kwargs):
        # Match the Hugging Face tokenizer calling convention expected elsewhere.
        return dict(input_ids=self.encode(text))


class H2OExLlamaGenerator(ExLlamaGenerator):
    def is_exlama(self):
        # Marker used by callers to detect exllama-based generators.
        return True
class Exllama(LLM):
    client: Any  #: :meta private:
    model_path: str = None
    """The path to the GPTQ model folder."""
    model: Any = None
    sanitize_bot_response: bool = False
    prompter: Any = None
    context: Any = ''
    iinput: Any = ''
    exllama_cache: ExLlamaCache = None  #: :meta private:
    config: ExLlamaConfig = None  #: :meta private:
    generator: ExLlamaGenerator = None  #: :meta private:
    tokenizer: ExLlamaTokenizer = None  #: :meta private:
    ##Langchain parameters
    logfunc = print
    stop_sequences: Optional[List[str]] = Field([], description="Sequences that will immediately stop the generator.")
    streaming: Optional[bool] = Field(True, description="Whether to stream the results, token by token.")

    ##Generator parameters
    disallowed_tokens: Optional[List[int]] = Field(None, description="List of tokens to disallow during generation.")
    temperature: Optional[float] = Field(None, description="Temperature for sampling diversity.")
    top_k: Optional[int] = Field(None,
                                 description="Consider the most probable top_k samples, 0 to disable top_k sampling.")
    top_p: Optional[float] = Field(None,
                                   description="Consider tokens up to a cumulative probability of top_p, 0.0 to disable top_p sampling.")
    min_p: Optional[float] = Field(None, description="Do not consider tokens with probability less than this.")
    typical: Optional[float] = Field(None,
                                     description="Locally typical sampling threshold, 0.0 to disable typical sampling.")
    token_repetition_penalty_max: Optional[float] = Field(None,
                                                          description="Repetition penalty for most recent tokens.")
    token_repetition_penalty_sustain: Optional[int] = Field(None,
                                                            description="Number of most recent tokens to apply the penalty to, -1 to apply to the whole context.")
    token_repetition_penalty_decay: Optional[int] = Field(None,
                                                          description="Gradually decrease penalty over this many tokens.")
    beams: Optional[int] = Field(None, description="Number of beams for beam search.")
    beam_length: Optional[int] = Field(None, description="Length of beams for beam search.")

    ##Config overrides
    max_seq_len: Optional[int] = Field(2048,
                                       description="Reduce to save memory. Can also be increased, ideally while also using compress_pos_emb and a compatible model/LoRA.")
    compress_pos_emb: Optional[float] = Field(1.0,
                                              description="Amount of compression to apply to the positional embedding.")
    set_auto_map: Optional[str] = Field(None,
                                        description="Comma-separated list of VRAM (in GB) to use per GPU device for model layers, e.g. 20,7,7")
    gpu_peer_fix: Optional[bool] = Field(None, description="Prevent direct copies of data between GPUs.")
    alpha_value: Optional[float] = Field(1.0, description="RoPE context extension alpha.")

    ##Tuning
    matmul_recons_thd: Optional[int] = Field(None)
    fused_mlp_thd: Optional[int] = Field(None)
    sdp_thd: Optional[int] = Field(None)
    fused_attn: Optional[bool] = Field(None)
    matmul_fused_remap: Optional[bool] = Field(None)
    rmsnorm_no_half2: Optional[bool] = Field(None)
    rope_no_half2: Optional[bool] = Field(None)
    matmul_no_half2: Optional[bool] = Field(None)
    silu_no_half2: Optional[bool] = Field(None)
    concurrent_streams: Optional[bool] = Field(None)

    ##Lora Parameters
    lora_path: Optional[str] = Field(None, description="Path to your LoRA.")
    @staticmethod
    def get_model_path_at(path):
        # Look for model weights in the folder, preferring safetensors.
        patterns = ["*.safetensors", "*.bin", "*.pt"]
        model_paths = []
        for pattern in patterns:
            full_pattern = os.path.join(path, pattern)
            model_paths = glob.glob(full_pattern)
            if model_paths:  # If there are any files matching the current pattern
                break  # Exit the loop as soon as we find a matching file
        if model_paths:  # If there are any files matching any of the patterns
            return model_paths[0]
        else:
            return None  # Return None if no matching files were found
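    # For example (hypothetical folder): a directory containing both
    # model.safetensors and model.bin resolves to the .safetensors file,
    # since that pattern is tried first.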
    @staticmethod
    def configure_object(params, values, logfunc):
        # Build a closure that copies the given pydantic values onto a target
        # object (an ExLlamaConfig or generator settings), skipping unset ones.
        obj_params = {k: values.get(k) for k in params}

        def apply_to(obj):
            for key, value in obj_params.items():
                if value:
                    if hasattr(obj, key):
                        setattr(obj, key, value)
                        logfunc(f"{key} {value}")
                    else:
                        raise AttributeError(f"{key} does not exist in {obj}")

        return apply_to
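    # e.g. Exllama.configure_object(["temperature"], values, logfunc)(generator.settings)
    # sets generator.settings.temperature whenever values["temperature"] is truthy.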
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        model_param_names = [
            "temperature",
            "top_k",
            "top_p",
            "min_p",
            "typical",
            "token_repetition_penalty_max",
            "token_repetition_penalty_sustain",
            "token_repetition_penalty_decay",
            "beams",
            "beam_length",
        ]

        config_param_names = [
            "max_seq_len",
            "compress_pos_emb",
            "gpu_peer_fix",
            "alpha_value",
        ]

        tuning_parameters = [
            "matmul_recons_thd",
            "fused_mlp_thd",
            "sdp_thd",
            "matmul_fused_remap",
            "rmsnorm_no_half2",
            "rope_no_half2",
            "matmul_no_half2",
            "silu_no_half2",
            "concurrent_streams",
            "fused_attn",
        ]

        ##Set logging function if verbose, or set to empty lambda
        verbose = values['verbose']
        if not verbose:
            values['logfunc'] = lambda *args, **kwargs: None
        logfunc = values['logfunc']

        if values['model'] is None:
            # Load everything from disk: config, tokenizer, weights, cache, generator.
            model_path = values["model_path"]
            lora_path = values["lora_path"]

            tokenizer_path = os.path.join(model_path, "tokenizer.model")
            model_config_path = os.path.join(model_path, "config.json")
            model_path = Exllama.get_model_path_at(model_path)

            config = ExLlamaConfig(model_config_path)
            tokenizer = ExLlamaTokenizer(tokenizer_path)
            config.model_path = model_path

            configure_config = Exllama.configure_object(config_param_names, values, logfunc)
            configure_config(config)
            configure_tuning = Exllama.configure_object(tuning_parameters, values, logfunc)
            configure_tuning(config)

            ##Special parameter, set auto map, it's a function
            if values['set_auto_map']:
                config.set_auto_map(values['set_auto_map'])
                logfunc(f"set_auto_map {values['set_auto_map']}")

            model = ExLlama(config)
            exllama_cache = ExLlamaCache(model)
            generator = ExLlamaGenerator(model, tokenizer, exllama_cache)

            ##Load and apply lora to generator
            if lora_path is not None:
                lora_config_path = os.path.join(lora_path, "adapter_config.json")
                lora_path = Exllama.get_model_path_at(lora_path)
                lora = ExLlamaLora(model, lora_config_path, lora_path)
                generator.lora = lora
                logfunc(f"Loaded LORA @ {lora_path}")
        else:
            # Reuse an already-loaded generator and its attached objects.
            generator = values['model']
            exllama_cache = generator.cache
            model = generator.model
            config = model.config
            tokenizer = generator.tokenizer

        # Apply generation-time parameters whether the model was just loaded or passed in.
        configure_model = Exllama.configure_object(model_param_names, values, logfunc)
        values["stop_sequences"] = [x.strip().lower() for x in values["stop_sequences"]]
        configure_model(generator.settings)

        setattr(generator.settings, "stop_sequences", values["stop_sequences"])
        logfunc(f"stop_sequences {values['stop_sequences']}")

        disallowed = values.get("disallowed_tokens")
        if disallowed:
            generator.disallow_tokens(disallowed)
            print(f"Disallowed Tokens: {generator.disallowed_tokens}")

        values["client"] = model
        values["generator"] = generator
        values["config"] = config
        values["tokenizer"] = tokenizer
        values["exllama_cache"] = exllama_cache

        return values
    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "Exllama"

    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text."""
        return self.generator.tokenizer.num_tokens(text)

    def get_token_ids(self, text: str) -> List[int]:
        return self.generator.tokenizer.encode(text)
        # avoid base method that is not aware of how to properly tokenize (uses GPT2)
        # return _get_token_ids_default_method(text)
    def _call(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> str:
        assert self.tokenizer is not None
        from h2oai_pipeline import H2OTextGenerationPipeline
        prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)

        # NOTE: TGI server does not add prompting, so must do here
        data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
        prompt = self.prompter.generate_prompt(data_point)

        text = ''
        # stream() yields the cumulative response, so the last value is the full text.
        for text1 in self.stream(prompt=prompt, stop=stop, run_manager=run_manager):
            text = text1
        return text
    class MatchStatus(Enum):
        EXACT_MATCH = 1
        PARTIAL_MATCH = 0
        NO_MATCH = 2

    def match_status(self, sequence: str, banned_sequences: List[str]):
        sequence = sequence.strip().lower()
        for banned_seq in banned_sequences:
            if banned_seq == sequence:
                return self.MatchStatus.EXACT_MATCH
            elif banned_seq.startswith(sequence):
                return self.MatchStatus.PARTIAL_MATCH
        return self.MatchStatus.NO_MATCH
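    # e.g. with banned_sequences=["</s>"]: "</s>" is an EXACT_MATCH, "</s" a
    # PARTIAL_MATCH (keep buffering), and "hello" a NO_MATCH (safe to yield).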
    def stream(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> Iterator[str]:
        config = self.config
        generator = self.generator
        beam_search = (self.beams and self.beams >= 1 and self.beam_length and self.beam_length >= 1)

        ids = generator.tokenizer.encode(prompt)
        generator.gen_begin_reuse(ids)

        if beam_search:
            generator.begin_beam_search()
            token_getter = generator.beam_search
        else:
            generator.end_beam_search()
            token_getter = generator.gen_single_token

        last_newline_pos = 0
        seq_length = len(generator.tokenizer.decode(generator.sequence_actual[0]))
        response_start = seq_length
        cursor_head = response_start

        text_callback = None
        if run_manager:
            text_callback = partial(
                run_manager.on_llm_new_token, verbose=self.verbose
            )
        # Parent handler of streamer expects to see the prompt first, else output="" and is lost if prompt=None in prompter.
        if text_callback:
            text_callback(prompt)

        text = ""
        while generator.gen_num_tokens() <= (
                self.max_seq_len - 4):  # Slight extra padding space, as we seem to occasionally get a few more than 1-2 tokens
            # Fetch a token
            token = token_getter()

            # If it's the ending token, replace it and end the generation.
            if token.item() == generator.tokenizer.eos_token_id:
                generator.replace_last_token(generator.tokenizer.newline_token_id)
                if beam_search:
                    generator.end_beam_search()
                return

            # Decode the string from the last newline; we can't just decode the last token due to how sentencepiece decodes.
            stuff = generator.tokenizer.decode(generator.sequence_actual[0][last_newline_pos:])
            cursor_tail = len(stuff)
            has_unicode_combined = cursor_tail < cursor_head
            text_chunk = stuff[cursor_head:cursor_tail]
            if has_unicode_combined:
                # Replace the broken unicode character with the combined one.
                text = text[:-2]
                text_chunk = stuff[cursor_tail - 1:cursor_tail]

            cursor_head = cursor_tail

            # Append the generated chunk to our stream buffer.
            text += text_chunk
            text = self.prompter.get_response(prompt + text, prompt=prompt,
                                              sanitize_bot_response=self.sanitize_bot_response)

            if token.item() == generator.tokenizer.newline_token_id:
                last_newline_pos = len(generator.sequence_actual[0])
                cursor_head = 0
                cursor_tail = 0

            # Check if the stream buffer is one of the stop sequences.
            status = self.match_status(text, self.stop_sequences)

            if status == self.MatchStatus.EXACT_MATCH:
                # Encountered a stop: rewind our generator to before we hit the match and end generation.
                rewind_length = generator.tokenizer.encode(text).shape[-1]
                generator.gen_rewind(rewind_length)
                # gen = generator.tokenizer.decode(generator.sequence_actual[0][response_start:])
                if beam_search:
                    generator.end_beam_search()
                return
            elif status == self.MatchStatus.PARTIAL_MATCH:
                # Partially matched a stop; continue buffering but don't yield.
                continue
            elif status == self.MatchStatus.NO_MATCH:
                if text_callback and not (text_chunk == BROKEN_UNICODE):
                    text_callback(text_chunk)
                yield text  # Not a stop, yield the match buffer.

        return
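
# A minimal usage sketch. The model folder path and the prompter instance are
# assumptions (any GPTQ folder with tokenizer.model and config.json, plus an
# h2ogpt Prompter, would be needed):
#
#   llm = Exllama(
#       model_path="/path/to/gptq-model-folder",  # hypothetical path
#       prompter=my_prompter,                     # hypothetical h2ogpt Prompter
#       temperature=0.7,
#       stop_sequences=["</s>"],
#       verbose=True,
#   )
#   print(llm("Why is the sky blue?"))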