from typing import Iterator

import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer

from variables import SYSTEM, HUMAN, AI


def load_tokenizer_and_model(base_model, adapter_model, load_8bit=True):
    """
    Load the tokenizer and chatbot model.

    Args:
        base_model (str): Path to (or hub id of) the base LLaMA model.
        adapter_model (str): Path to (or hub id of) the LoRA adapter,
            or None to run the base model alone.
        load_8bit (bool): Whether to load the model in 8-bit mode (CUDA only).
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # Prefer the Apple-silicon GPU when available; older torch builds do not
    # expose torch.backends.mps, hence the guard.
    try:
        if torch.backends.mps.is_available():
            device = "mps"
    except AttributeError:
        pass

    tokenizer = LlamaTokenizer.from_pretrained(base_model)
    if device == "cuda":
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=load_8bit,
            torch_dtype=torch.float16,
            device_map="auto",  # required when loading in 8-bit
        )
        if adapter_model is not None:
            # Apply the LoRA weights on top of the base model (this was
            # missing on the CUDA path, so the adapter was silently ignored).
            model = PeftModel.from_pretrained(
                model, adapter_model, torch_dtype=torch.float16
            )
    elif device == "mps":
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            device_map={"": device},
            torch_dtype=torch.float16,  # match the adapter dtype below
        )
        if adapter_model is not None:
            model = PeftModel.from_pretrained(
                model,
                adapter_model,
                device_map={"": device},
                torch_dtype=torch.float16,
            )
    else:
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            device_map={"": device},
            low_cpu_mem_usage=True,
            torch_dtype=torch.bfloat16,
            offload_folder=".",
        )
        if adapter_model is not None:
            model = PeftModel.from_pretrained(
                model,
                adapter_model,
                torch_dtype=torch.bfloat16,
                offload_folder=".",
            )

    model.eval()
    return tokenizer, model, device
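

# A minimal usage sketch; the paths below are illustrative placeholders,
# not files shipped with this repo:
#
#     tokenizer, model, device = load_tokenizer_and_model(
#         "path/to/llama-base", "path/to/lora-adapter", load_8bit=True
#     )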


class State:
    """Mutable flag shared between the UI and the generation loop."""

    interrupted = False

    def interrupt(self):
        self.interrupted = True

    def recover(self):
        self.interrupted = False


shared_state = State()
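
# Intended interplay (sketch; the UI wiring lives outside this module): a
# "Stop" button calls shared_state.interrupt(), the loop consuming
# decode(...) sees the flag and breaks, then calls shared_state.recover()
# before serving the next request.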


def decode(
    input_ids: torch.Tensor,
    model: PeftModel,
    tokenizer: LlamaTokenizer,
    stop_words: list,
    max_length: int,
    temperature: float = 1.0,
    top_p: float = 1.0,
) -> Iterator[str]:
    """Sample up to max_length tokens, yielding the decoded text so far."""
    generated_tokens = []
    past_key_values = None

    for _ in range(max_length):
        with torch.no_grad():
            if past_key_values is None:
                # First step: run the full prompt through the model.
                outputs = model(input_ids)
            else:
                # Later steps: feed only the newest token and reuse the
                # cached attention states.
                outputs = model(input_ids[:, -1:], past_key_values=past_key_values)
            logits = outputs.logits[:, -1, :]
            past_key_values = outputs.past_key_values

        # Temperature scaling.
        logits /= temperature

        probs = torch.softmax(logits, dim=-1)

        # Top-p (nucleus) filtering: zero out every token outside the
        # smallest set whose cumulative probability reaches top_p.
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        mask = probs_sum - probs_sort > top_p
        probs_sort[mask] = 0.0

        # Renormalize the surviving mass and sample the next token.
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        next_token = torch.multinomial(probs_sort, num_samples=1)
        next_token = torch.gather(probs_idx, -1, next_token)

        input_ids = torch.cat((input_ids, next_token), dim=-1)

        generated_tokens.append(next_token[0].item())
        text = tokenizer.decode(generated_tokens)

        yield text
        if any(x in text for x in stop_words):
            return
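
# Usage sketch (the render() callback and `inputs` are hypothetical here;
# `inputs` would come from get_prompt_with_history below):
#
#     for partial in decode(
#         inputs["input_ids"].to(device), model, tokenizer,
#         stop_words=[f"\n{HUMAN}"], max_length=512,
#         temperature=0.7, top_p=0.9,
#     ):
#         render(partial)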


def get_prompt_with_history(text, history, tokenizer, max_length=2048):
    """
    Build the prompt from the newest message plus as many past exchanges as
    fit in max_length tokens. Returns (prompt, tokenized prompt), or None if
    even the newest message alone does not fit.
    """
    prompt = SYSTEM
    history = [f"\n{HUMAN} {x[0]}\n{AI} {x[1]}" for x in history]
    history.append(f"\n{HUMAN} {text}\n{AI}")
    history_text = ""
    flag = False
    # Walk the history from newest to oldest, keeping exchanges while the
    # tokenized prompt still fits.
    for x in history[::-1]:
        if (
            tokenizer(prompt + history_text + x, return_tensors="pt")[
                "input_ids"
            ].size(-1)
            <= max_length
        ):
            history_text = x + history_text
            flag = True
        else:
            break
    if flag:
        return prompt + history_text, tokenizer(
            prompt + history_text, return_tensors="pt"
        )
    else:
        return None
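
# Illustrative shape of the result (assumes SYSTEM/HUMAN/AI from variables
# are plain role markers): with history=[("Hi", "Hello!")] and
# text="How are you?", the prompt reads
#
#     {SYSTEM}
#     {HUMAN} Hi
#     {AI} Hello!
#     {HUMAN} How are you?
#     {AI}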


def is_stop_word_or_prefix(s: str, stop_words: list) -> bool:
    """Return True if s ends with a stop word or with a prefix of one."""
    for stop_word in stop_words:
        if s.endswith(stop_word):
            return True
        # A partially streamed stop word also counts, so callers can hold
        # back text until the next chunk disambiguates it.
        for i in range(1, len(stop_word)):
            if s.endswith(stop_word[:i]):
                return True
    return False
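

# Example (illustrative marker): with stop_words=["\nHuman:"], text ending
# in "\nHum" returns True (a prefix of the marker may still be streaming in),
# while text ending in "done." returns False.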