Spaces: Runtime error
import json
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple

import gradio as gr
from gradio.components import Component  # cannot use TYPE_CHECKING here

from ..chat import ChatModel
from ..data import Role
from ..extras.misc import torch_gc
from ..hparams import GeneratingArguments
from .common import get_save_dir
from .locales import ALERTS

if TYPE_CHECKING:
    from .manager import Manager


class WebChatModel(ChatModel):
    def __init__(
        self, manager: "Manager", demo_mode: Optional[bool] = False, lazy_init: Optional[bool] = True
    ) -> None:
        self.manager = manager
        self.demo_mode = demo_mode
        self.model = None
        self.tokenizer = None
        self.generating_args = GeneratingArguments()

        if not lazy_init:  # read arguments from command line
            super().__init__()

        if demo_mode:  # load demo_config.json if it exists
            try:
                with open("demo_config.json", "r", encoding="utf-8") as f:
                    args = json.load(f)

                assert args.get("model_name_or_path", None) and args.get("template", None)
                super().__init__(args)
            except AssertionError:
                print("Please provide the model name and template in `demo_config.json`.")
            except Exception:
                print("Cannot find `demo_config.json` in the current directory.")

    @property
    def loaded(self) -> bool:
        return self.model is not None

    def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
        get = lambda name: data[self.manager.get_elem_by_name(name)]
        lang = get("top.lang")
        error = ""
        if self.loaded:
            error = ALERTS["err_exists"][lang]
        elif not get("top.model_name"):
            error = ALERTS["err_no_model"][lang]
        elif not get("top.model_path"):
            error = ALERTS["err_no_path"][lang]
        elif self.demo_mode:
            error = ALERTS["err_demo"][lang]

        if error:
            gr.Warning(error)
            yield error
            return
| if get("top.adapter_path"): | |
| adapter_name_or_path = ",".join( | |
| [ | |
| get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter) | |
| for adapter in get("top.adapter_path") | |
| ] | |
| ) | |
| else: | |
| adapter_name_or_path = None | |

        yield ALERTS["info_loading"][lang]
        args = dict(
            model_name_or_path=get("top.model_path"),
            adapter_name_or_path=adapter_name_or_path,
            finetuning_type=get("top.finetuning_type"),
            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
            template=get("top.template"),
            flash_attn=(get("top.booster") == "flash_attn"),
            use_unsloth=(get("top.booster") == "unsloth"),
            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
        )
        super().__init__(args)

        yield ALERTS["info_loaded"][lang]

    def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
        lang = data[self.manager.get_elem_by_name("top.lang")]

        if self.demo_mode:
            gr.Warning(ALERTS["err_demo"][lang])
            yield ALERTS["err_demo"][lang]
            return

        yield ALERTS["info_unloading"][lang]
        self.model = None
        self.tokenizer = None
        torch_gc()
        yield ALERTS["info_unloaded"][lang]
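
    # Note (sketch, not from this file): `torch_gc` in `extras.misc` is
    # typically a small helper that frees cached accelerator memory once the
    # model reference is dropped, roughly:
    #
    #     import gc
    #     import torch
    #
    #     gc.collect()
    #     if torch.cuda.is_available():
    #         torch.cuda.empty_cache()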

    def predict(
        self,
        chatbot: List[Tuple[str, str]],
        query: str,
        messages: Sequence[Dict[str, str]],
        system: str,
        tools: str,
        max_new_tokens: int,
        top_p: float,
        temperature: float,
    ) -> Generator[Tuple[Sequence[Tuple[str, str]], Sequence[Dict[str, str]]], None, None]:
        chatbot.append([query, ""])
        query_messages = messages + [{"role": Role.USER, "content": query}]
        response = ""
        for new_text in self.stream_chat(
            query_messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
        ):
            response += new_text
            if tools:
                result = self.template.format_tools.extract(response)
            else:
                result = response

            if isinstance(result, tuple):
                name, arguments = result
                arguments = json.loads(arguments)
                tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
                output_messages = query_messages + [{"role": Role.FUNCTION, "content": tool_call}]
                bot_text = "```json\n" + tool_call + "\n```"
            else:
                output_messages = query_messages + [{"role": Role.ASSISTANT, "content": result}]
                bot_text = result

            chatbot[-1] = [query, self.postprocess(bot_text)]
            yield chatbot, output_messages
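
    # Illustration (assumed values, not from the source): if the chat
    # template extracts the tuple ("get_weather", '{"city": "Paris"}') from
    # the streamed response, the tuple branch above renders the chatbot
    # message as
    #
    #     ```json
    #     {"name": "get_weather", "arguments": {"city": "Paris"}}
    #     ```
    #
    # and records the call under Role.FUNCTION in `output_messages`.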

    def postprocess(self, response: str) -> str:
        # Escape `<` and `>` outside ``` fenced blocks so the Gradio chatbot
        # renders them literally instead of as HTML tags; fenced code is
        # left untouched.
        blocks = response.split("```")
        for i, block in enumerate(blocks):
            if i % 2 == 0:  # even indices are the segments outside the fences
                blocks[i] = block.replace("<", "&lt;").replace(">", "&gt;")

        return "```".join(blocks)
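
    # Quick sanity check (a sketch; it runs without a loaded model because
    # `postprocess` never touches `self`):
    #
    #     sample = "a <tag> outside\n```\na <tag> inside\n```"
    #     print(WebChatModel.postprocess(None, sample))
    #     # -> the text outside the fence becomes "a &lt;tag&gt; outside",
    #     #    while the fenced segment is returned unchanged.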