from __future__ import annotations

import gc
import logging
import os
import platform

import colorama
import torch

from ..index_func import *
from ..presets import *
from ..utils import *
from .base_model import BaseLLMModel


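# Module-level cache: the tokenizer and model are shared across client
# instances and are reset to None by ChatGLM_Client.deinitialize().
CHATGLM_TOKENIZER = None
CHATGLM_MODEL = None

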
class ChatGLM_Client(BaseLLMModel):
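    """Client for locally hosted ChatGLM / ChatGLM2 / ChatGLM3 models.

    The checkpoint is loaded from a matching folder under ``models/`` when one
    exists, and otherwise from the ``THUDM`` namespace on the Hugging Face
    Hub. The model is placed on CUDA, MPS, or CPU depending on availability.
    """
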
    def __init__(self, model_name, user_name="") -> None:
        super().__init__(model_name=model_name, user=user_name)
        # transformers is imported lazily so it is only required when a
        # ChatGLM model is actually selected.
        from transformers import AutoModel, AutoTokenizer

        global CHATGLM_TOKENIZER, CHATGLM_MODEL
        # Drop any previously loaded checkpoint before loading a new one.
        self.deinitialize()
        if CHATGLM_TOKENIZER is None or CHATGLM_MODEL is None:
            system_name = platform.system()
            # Prefer a locally downloaded copy under models/; otherwise pull
            # the checkpoint from the THUDM namespace on the Hugging Face Hub.
            model_path = None
            if os.path.exists("models"):
                model_dirs = os.listdir("models")
                if model_name in model_dirs:
                    model_path = f"models/{model_name}"
            if model_path is not None:
                model_source = model_path
            else:
                model_source = f"THUDM/{model_name}"
            CHATGLM_TOKENIZER = AutoTokenizer.from_pretrained(
                model_source, trust_remote_code=True
            )
            quantized = "int4" in model_name
            model = AutoModel.from_pretrained(
                model_source, trust_remote_code=True
            )
            if torch.cuda.is_available():
                logging.info("CUDA is available, using CUDA")
                model = model.half().cuda()
            elif system_name == "Darwin" and model_path is not None and not quantized:
                # MPS is only used for local, non-quantized checkpoints.
                logging.info("Running on macOS, using MPS")
                model = model.half().to("mps")
            else:
                logging.info("GPU is not available, using CPU")
                model = model.float()
            model = model.eval()
            CHATGLM_MODEL = model

    def _get_glm3_style_input(self):
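        """Build (history, query) in the ChatGLM3 message-dict format.

        The last user message becomes the query; the remaining role/content
        dicts are passed through as history.
        """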
        # Work on a slice so that self.history itself is left untouched
        # (mirrors the ChatGLM2 path below).
        history = self.history[:-1]
        query = self.history[-1]["content"]
        return history, query

    def _get_glm2_style_input(self):
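        """Build (history, query) in the ChatGLM/ChatGLM2 pair format.

        The flat message list is regrouped into [user, assistant] pairs and
        the trailing user message becomes the query.
        """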
        history = [x["content"] for x in self.history]
        query = history.pop()
        logging.debug(colorama.Fore.YELLOW +
                      f"{history}" + colorama.Fore.RESET)
        assert (
            len(history) % 2 == 0
        ), f"History should be of even length. Current history is: {history}"
        history = [[history[i], history[i + 1]]
                   for i in range(0, len(history), 2)]
        return history, query

    def _get_glm_style_input(self):
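        """Pick the history format matching the loaded model name.

        ChatGLM2 checkpoints use paired history; everything else falls back
        to the ChatGLM3-style role/content messages.
        """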
| if "glm2" in self.model_name: |
| return self._get_glm2_style_input() |
| else: |
| return self._get_glm3_style_input() |
|
|
    def get_answer_at_once(self):
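        """Run one blocking chat turn; returns the reply and its character length."""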
        history, query = self._get_glm_style_input()
        response, _ = CHATGLM_MODEL.chat(
            CHATGLM_TOKENIZER, query, history=history)
        return response, len(response)

    def get_answer_stream_iter(self):
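        """Stream the reply, yielding the partial response as it grows."""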
        history, query = self._get_glm_style_input()
        for response, history in CHATGLM_MODEL.stream_chat(
            CHATGLM_TOKENIZER,
            query,
            history,
            max_length=self.token_upper_limit,
            top_p=self.top_p,
            temperature=self.temperature,
        ):
            yield response

    def deinitialize(self):
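        """Drop the cached model/tokenizer and release GPU memory."""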
        global CHATGLM_MODEL, CHATGLM_TOKENIZER
        CHATGLM_MODEL = None
        CHATGLM_TOKENIZER = None
        gc.collect()
        torch.cuda.empty_cache()
        logging.info("ChatGLM model deinitialized")