| import abc |
| from typing import Any, Generator, List, Optional |
|
|
| from injector import inject |
|
|
| from taskweaver.config.config_mgt import AppConfigSource |
| from taskweaver.config.module_config import ModuleConfig |
| from taskweaver.llm.util import ChatMessageType |
|
|
|
|
class LLMModuleConfig(ModuleConfig):
    """Configuration for the top-level `llm` module.

    Covers provider selection, endpoint/credentials, model names, the
    chat response format, and the mock-service switch.
    """

    def _configure(self) -> None:
        self._set_name("llm")

        # Provider back-ends for chat completion and embeddings.
        self.api_type = self._get_str("api_type", "openai")
        self.embedding_api_type = self._get_str(
            "embedding_api_type",
            "sentence_transformer",
        )

        # Endpoint and credentials; optional because some providers
        # resolve them from the environment instead.
        self.api_base: Optional[str] = self._get_str("api_base", None, required=False)
        self.api_key: Optional[str] = self._get_str("api_key", None, required=False)

        # Model names: primary, fallback, and embedding model.
        self.model: Optional[str] = self._get_str("model", None, required=False)
        self.backup_model: Optional[str] = self._get_str("backup_model", None, required=False)
        self.embedding_model: Optional[str] = self._get_str("embedding_model", None, required=False)

        # Chat-completion response format (JSON mode vs. plain text).
        self.response_format: Optional[str] = self._get_enum(
            "response_format",
            options=["json_object", "text"],
            default="json_object",
        )

        # When true, use a mocked LLM service (e.g. for tests).
        self.use_mock: bool = self._get_bool("use_mock", False)
|
|
|
|
class LLMServiceConfig(ModuleConfig):
    """Base config class for concrete LLM service back-ends.

    Sub-configs created from this class are namespaced under ``llm.``
    (see :meth:`_set_name`) and share the parent :class:`LLMModuleConfig`.
    """

    @inject
    def __init__(
        self,
        src: AppConfigSource,
        llm_module_config: LLMModuleConfig,
    ) -> None:
        # Store the shared module config BEFORE calling the base
        # constructor, so it is available while the base class runs
        # the subclass's `_configure` hook.
        self.llm_module_config = llm_module_config
        super().__init__(src)

    def _set_name(self, name: str) -> None:
        # Override: prefix service-specific sections with the parent
        # module name, e.g. "openai" -> "llm.openai".
        self.name = f"llm.{name}"
|
|
|
|
class CompletionService(abc.ABC):
    """Abstract interface implemented by chat-completion LLM back-ends."""

    @abc.abstractmethod
    def chat_completion(
        self,
        messages: List[ChatMessageType],
        use_backup_engine: bool = False,
        stream: bool = True,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Generator[ChatMessageType, None, None]:
        """
        Chat completion API.

        :param messages: list of messages forming the conversation so far
        :param use_backup_engine: whether to use the backup engine
        :param stream: whether to stream the response
        :param temperature: sampling temperature
        :param max_tokens: maximum number of tokens to generate
        :param top_p: nucleus sampling probability mass (top p)
        :param stop: list of stop sequences that terminate generation
        :param kwargs: other model-specific keyword arguments

        :return: generator of messages
        """


        raise NotImplementedError
|
|
|
|
class EmbeddingService(abc.ABC):
    """Abstract interface implemented by text-embedding back-ends."""

    @abc.abstractmethod
    def get_embeddings(self, strings: List[str]) -> List[List[float]]:
        """
        Embedding API.

        :param strings: list of strings to be embedded
        :return: list of embedding vectors, one per input string
        """
        raise NotImplementedError
|
|