| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import asyncio |
| | import os |
| | from collections.abc import AsyncGenerator, Generator |
| | from threading import Thread |
| | from typing import TYPE_CHECKING, Any, Optional |
| |
|
| | from ..extras.constants import EngineName |
| | from ..extras.misc import torch_gc |
| | from ..hparams import get_infer_args |
| |
|
| |
|
| | if TYPE_CHECKING: |
| | from ..data.mm_plugin import AudioInput, ImageInput, VideoInput |
| | from .base_engine import BaseEngine, Response |
| |
|
| |
|
| | def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None: |
| | asyncio.set_event_loop(loop) |
| | loop.run_forever() |
| |
|
| |
|
class ChatModel:
    r"""General class for chat models. Backed by huggingface or vllm engines.

    Supports both sync and async methods.
    Sync methods: chat(), stream_chat() and get_scores().
    Async methods: achat(), astream_chat() and aget_scores().

    The sync methods work by submitting the corresponding coroutine to a
    private event loop that runs on a daemon thread, then blocking on the
    returned future.
    """

    def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
        r"""Build the inference engine and start the background event loop.

        Args:
            args: Optional argument dict forwarded to ``get_infer_args``;
                when ``None``, arguments are parsed from the command line.

        Raises:
            ImportError: If the selected backend's package is not installed.
            NotImplementedError: If ``infer_backend`` is unknown.
        """
        model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
        # NOTE: the annotation must be quoted — ``BaseEngine`` is imported only
        # under TYPE_CHECKING, and PEP 526 evaluates unquoted annotations on
        # attribute targets at runtime, which would raise NameError here.
        self.engine: "BaseEngine" = self._create_engine(model_args, data_args, finetuning_args, generating_args)

        # Private loop on a daemon thread; the sync wrappers block on futures
        # scheduled onto it with run_coroutine_threadsafe.
        self._loop = asyncio.new_event_loop()
        self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
        self._thread.start()

    @staticmethod
    def _create_engine(model_args, data_args, finetuning_args, generating_args) -> "BaseEngine":
        r"""Instantiate the engine selected by ``model_args.infer_backend``.

        Backend imports are deferred so that only the chosen backend's
        dependencies need to be installed.
        """
        backend = model_args.infer_backend
        if backend == EngineName.HF:
            from .hf_engine import HuggingfaceEngine

            return HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
        elif backend == EngineName.VLLM:
            try:
                from .vllm_engine import VllmEngine

                return VllmEngine(model_args, data_args, finetuning_args, generating_args)
            except ImportError as e:
                raise ImportError(
                    "vLLM not installed, you may need to run `pip install vllm`\n"
                    "or try to use HuggingFace backend: --infer_backend huggingface"
                ) from e
        elif backend == EngineName.SGLANG:
            try:
                from .sglang_engine import SGLangEngine

                return SGLangEngine(model_args, data_args, finetuning_args, generating_args)
            except ImportError as e:
                raise ImportError(
                    "SGLang not installed, you may need to run `pip install sglang[all]`\n"
                    "or try to use HuggingFace backend: --infer_backend huggingface"
                ) from e
        elif backend == EngineName.KT:
            try:
                from .kt_engine import KTransformersEngine

                return KTransformersEngine(model_args, data_args, finetuning_args, generating_args)
            except ImportError as e:
                raise ImportError(
                    "KTransformers not installed, you may need to run `pip install ktransformers`\n"
                    "or try to use HuggingFace backend: --infer_backend huggingface"
                ) from e
        else:
            raise NotImplementedError(f"Unknown backend: {backend}")

    def chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> list["Response"]:
        r"""Get a list of responses of the chat model.

        Blocking wrapper around :meth:`achat`; runs it on the background loop.
        """
        task = asyncio.run_coroutine_threadsafe(
            self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
        )
        return task.result()

    async def achat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> list["Response"]:
        r"""Asynchronously get a list of responses of the chat model."""
        return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)

    def stream_chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> Generator[str, None, None]:
        r"""Get the response token-by-token of the chat model.

        Blocking wrapper around :meth:`astream_chat`; pulls each token from
        the async generator via the background loop.
        """
        generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
        while True:
            try:
                task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
                yield task.result()
            except StopAsyncIteration:
                break

    async def astream_chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        r"""Asynchronously get the response token-by-token of the chat model."""
        async for new_token in self.engine.stream_chat(
            messages, system, tools, images, videos, audios, **input_kwargs
        ):
            yield new_token

    def get_scores(
        self,
        batch_input: list[str],
        **input_kwargs,
    ) -> list[float]:
        r"""Get a list of scores of the reward model.

        Blocking wrapper around :meth:`aget_scores`.
        """
        task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
        return task.result()

    async def aget_scores(
        self,
        batch_input: list[str],
        **input_kwargs,
    ) -> list[float]:
        r"""Asynchronously get a list of scores of the reward model."""
        return await self.engine.get_scores(batch_input, **input_kwargs)
| |
|
| |
|
def run_chat() -> None:
    r"""Run an interactive terminal chat session with the configured model.

    Special inputs: ``clear`` resets the conversation history (and frees
    cached GPU memory), ``exit`` quits the loop.
    """
    if os.name != "nt":  # readline is not available on Windows
        try:
            # Imported solely for its side effect: line editing / history for input().
            import readline  # noqa: F401
        except ImportError:
            print("Install `readline` for a better experience.")

    chat_model = ChatModel()
    messages = []
    print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")

    while True:
        try:
            query = input("\nUser: ")
        except UnicodeDecodeError:
            print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
            continue

        if query.strip() == "exit":
            break

        if query.strip() == "clear":
            messages = []
            torch_gc()  # release cached GPU memory held by the dropped history
            print("History has been removed.")
            continue

        messages.append({"role": "user", "content": query})
        print("Assistant: ", end="", flush=True)

        # Stream tokens as they arrive, accumulating the full reply for history.
        response = ""
        for new_text in chat_model.stream_chat(messages):
            print(new_text, end="", flush=True)
            response += new_text

        print()
        messages.append({"role": "assistant", "content": response})
| |
|