| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | from threading import Thread |
| | from ten import ( |
| | Extension, |
| | TenEnv, |
| | Cmd, |
| | Data, |
| | StatusCode, |
| | CmdResult, |
| | ) |
| | from .litellm import LiteLLM, LiteLLMConfig |
| | from .log import logger |
| | from .utils import get_micro_ts, parse_sentence |
| |
|
| |
|
# Command names. The incoming and outgoing flush commands share one name.
CMD_IN_FLUSH = CMD_OUT_FLUSH = "flush"

# Property keys of incoming "text_data" messages.
DATA_IN_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL = "is_final"

# Property keys of outgoing "text_data" messages.
DATA_OUT_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT = "end_of_segment"

# Names of the extension's configuration properties: string-valued ...
PROPERTY_API_KEY = "api_key"
PROPERTY_BASE_URL = "base_url"
PROPERTY_GREETING = "greeting"
PROPERTY_MODEL = "model"
PROPERTY_PROMPT = "prompt"
PROPERTY_PROVIDER = "provider"

# ... float-valued ...
PROPERTY_FREQUENCY_PENALTY = "frequency_penalty"
PROPERTY_PRESENCE_PENALTY = "presence_penalty"
PROPERTY_TEMPERATURE = "temperature"
PROPERTY_TOP_P = "top_p"

# ... and integer-valued.
PROPERTY_MAX_MEMORY_LENGTH = "max_memory_length"
PROPERTY_MAX_TOKENS = "max_tokens"
| |
|
| |
|
class LiteLLMExtension(Extension):
    """TEN extension that answers final text input with streamed LiteLLM replies.

    Incoming ``text_data`` flagged as final is appended to a bounded chat
    history and sent to LiteLLM on a worker thread. The streamed reply is cut
    into sentences with ``parse_sentence`` and every sentence is emitted as
    outgoing ``text_data``. A ``flush`` command records a timestamp that makes
    workers started earlier stop reading their stream, and is forwarded on.
    """

    # Class-level defaults. ``memory`` is rebound to a fresh list per instance
    # in on_start so that instances never share a single history list.
    memory = []
    max_memory_length = 10  # maximum number of chat messages kept in memory
    outdate_ts = 0  # timestamp of the latest flush (from get_micro_ts)
    litellm = None  # LiteLLM client, created in on_start

    def on_start(self, ten: TenEnv) -> None:
        """Build the LiteLLM client from the extension properties and greet.

        Every property is optional: a failed lookup is logged and the value
        from ``LiteLLMConfig.default_config()`` is kept. ``on_start_done`` is
        signalled once the client exists and the greeting (if any) was sent.
        """
        logger.info("LiteLLMExtension on_start")

        # Per-instance history; the class attribute would be shared otherwise.
        self.memory = []

        litellm_config = LiteLLMConfig.default_config()
        greeting = ""

        # NOTE(review): assumes each property name is also the attribute name
        # on LiteLLMConfig (true for the max_tokens / model read below) -
        # confirm for the remaining keys.
        string_keys = [
            PROPERTY_API_KEY,
            PROPERTY_BASE_URL,
            PROPERTY_GREETING,
            PROPERTY_MODEL,
            PROPERTY_PROMPT,
            PROPERTY_PROVIDER,
        ]
        for key in string_keys:
            try:
                val = ten.get_property_string(key)
                if val:
                    # setattr targets the attribute named by ``key``; a plain
                    # ``litellm_config.key = val`` would only ever write an
                    # attribute literally called "key".
                    setattr(litellm_config, key, val)
                    if key == PROPERTY_GREETING:
                        greeting = val
            except Exception as e:
                logger.warning(f"get_property_string optional {key} failed, err: {e}")

        for key in [PROPERTY_FREQUENCY_PENALTY, PROPERTY_PRESENCE_PENALTY, PROPERTY_TEMPERATURE, PROPERTY_TOP_P]:
            try:
                setattr(litellm_config, key, float(ten.get_property_float(key)))
            except Exception as e:
                logger.warning(f"get_property_float optional {key} failed, err: {e}")

        for key in [PROPERTY_MAX_MEMORY_LENGTH, PROPERTY_MAX_TOKENS]:
            try:
                val = int(ten.get_property_int(key))
                if val <= 0:
                    # A non-positive limit is treated as "not configured".
                    continue
                setattr(litellm_config, key, val)
                if key == PROPERTY_MAX_MEMORY_LENGTH:
                    # The history bound is enforced by this extension.
                    self.max_memory_length = val
            except Exception as e:
                logger.warning(f"get_property_int optional {key} failed, err: {e}")

        self.litellm = LiteLLM(litellm_config)
        logger.info(f"newLiteLLM succeed with max_tokens: {litellm_config.max_tokens}, model: {litellm_config.model}")

        # The greeting was captured in the guarded loop above, so a failing
        # property lookup can no longer prevent on_start_done from running.
        if greeting:
            try:
                self._send_text(ten, greeting, True)
                logger.info(f"greeting [{greeting}] sent")
            except Exception as e:
                logger.error(f"greeting [{greeting}] send failed, err: {e}")

        ten.on_start_done()

    def on_stop(self, ten: TenEnv) -> None:
        """Acknowledge the stop request; no resources are released here."""
        logger.info("LiteLLMExtension on_stop")
        ten.on_stop_done()

    def on_cmd(self, ten: TenEnv, cmd: Cmd) -> None:
        """Handle ``flush``; any other command is answered with an error result.

        On ``flush`` the current timestamp is stored in ``outdate_ts`` (workers
        started before it stop reading their stream), an outgoing flush command
        is sent, and an OK result is returned for ``cmd``.
        """
        logger.info("LiteLLMExtension on_cmd")
        cmd_json = cmd.to_json()
        logger.info(f"LiteLLMExtension on_cmd json: {cmd_json}")

        cmd_name = cmd.get_name()

        if cmd_name != CMD_IN_FLUSH:
            logger.info(f"LiteLLMExtension on_cmd unknown cmd: {cmd_name}")
            cmd_result = CmdResult.create(StatusCode.ERROR)
            cmd_result.set_property_string("detail", "unknown cmd")
            ten.return_result(cmd_result, cmd)
            return

        self.outdate_ts = get_micro_ts()
        cmd_out = Cmd.create(CMD_OUT_FLUSH)
        ten.send_cmd(cmd_out, None)
        logger.info("LiteLLMExtension on_cmd sent flush")

        cmd_result = CmdResult.create(StatusCode.OK)
        cmd_result.set_property_string("detail", "success")
        ten.return_result(cmd_result, cmd)

    def on_data(self, ten: TenEnv, data: Data) -> None:
        """
        on_data receives data from ten graph.
        current supported data:
          - name: text_data
            example:
              {name: text_data, properties: {text: "hello"}}

        Non-final or empty input is ignored. Otherwise the text is stored in
        the chat history and a worker thread is started to stream the reply.
        """
        logger.info("LiteLLMExtension on_data")

        # Only final input is answered.
        try:
            is_final = data.get_property_bool(DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL)
            if not is_final:
                logger.info("ignore non-final input")
                return
        except Exception as e:
            logger.error(f"on_data get_property_bool {DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL} failed, err: {e}")
            return

        try:
            input_text = data.get_property_string(DATA_IN_TEXT_DATA_PROPERTY_TEXT)
            if not input_text:
                logger.info("ignore empty text")
                return
            logger.info(f"on_data input text: [{input_text}]")
        except Exception as e:
            logger.error(f"on_data get_property_string {DATA_IN_TEXT_DATA_PROPERTY_TEXT} failed, err: {e}")
            return

        self._remember("user", input_text)

        start_time = get_micro_ts()
        thread = Thread(
            target=self._chat_completions_stream_worker,
            # The worker gets a snapshot so that later trimming or appending
            # of self.memory cannot alter the messages of this request.
            args=(ten, start_time, input_text, list(self.memory)),
        )
        thread.start()
        logger.info("LiteLLMExtension on_data end")

    def _remember(self, role: str, content: str) -> None:
        """Append one chat message and drop the oldest ones beyond the limit."""
        self.memory.append({"role": role, "content": content})
        overflow = len(self.memory) - self.max_memory_length
        if overflow > 0:
            del self.memory[:overflow]

    @staticmethod
    def _send_text(ten: TenEnv, text: str, end_of_segment: bool) -> None:
        """Send ``text`` as outgoing ``text_data``; exceptions reach the caller."""
        output_data = Data.create("text_data")
        output_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, text)
        output_data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, end_of_segment)
        ten.send_data(output_data)

    def _chat_completions_stream_worker(self, ten: TenEnv, start_time, input_text, memory) -> None:
        """Stream one reply for ``input_text`` and emit it sentence by sentence.

        Args:
            ten: environment used to send the outgoing data.
            start_time: ``get_micro_ts()`` value taken when the request was
                accepted; a later flush timestamp interrupts the stream.
            input_text: the user text, used for logging only.
            memory: chat messages passed to LiteLLM for this request.

        Runs on its own thread, so every exception is logged and swallowed.
        """
        try:
            logger.info(f"chat_completions_stream_worker for input text: [{input_text}] memory: {memory}")

            resp = self.litellm.get_chat_completions_stream(memory)
            if resp is None:
                logger.info(f"chat_completions_stream_worker for input text: [{input_text}] failed")
                return

            sentence = ""
            full_content = ""
            first_sentence_sent = False

            for chat_completions in resp:
                # A flush newer than this request makes the reply obsolete.
                if start_time < self.outdate_ts:
                    logger.info(f"chat_completions_stream_worker recv interrupt and flushing for input text: [{input_text}], startTs: {start_time}, outdateTs: {self.outdate_ts}")
                    break

                if len(chat_completions.choices) > 0 and chat_completions.choices[0].delta.content is not None:
                    content = chat_completions.choices[0].delta.content
                else:
                    content = ""

                full_content += content

                # One chunk may complete several sentences; drain them all.
                while True:
                    sentence, content, sentence_is_final = parse_sentence(sentence, content)

                    if len(sentence) == 0 or not sentence_is_final:
                        logger.info(f"sentence {sentence} is empty or not final")
                        break

                    logger.info(f"chat_completions_stream_worker recv for input text: [{input_text}] got sentence: [{sentence}]")

                    try:
                        self._send_text(ten, sentence, False)
                        logger.info(f"chat_completions_stream_worker recv for input text: [{input_text}] sent sentence [{sentence}]")
                    except Exception as e:
                        logger.error(f"chat_completions_stream_worker recv for input text: [{input_text}] send sentence [{sentence}] failed, err: {e}")
                        break

                    sentence = ""
                    if not first_sentence_sent:
                        first_sentence_sent = True
                        # NOTE(review): the value is a get_micro_ts() difference
                        # but the message labels it "ms" - confirm the unit.
                        logger.info(f"chat_completions_stream_worker recv for input text: [{input_text}] first sentence sent, first_sentence_latency {get_micro_ts() - start_time}ms")

            # Record the reply in the shared history (bounded by _remember).
            self._remember("assistant", full_content)

            # Close the segment, carrying whatever unfinished sentence remains.
            try:
                self._send_text(ten, sentence, True)
                logger.info(f"chat_completions_stream_worker for input text: [{input_text}] end of segment with sentence [{sentence}] sent")
            except Exception as e:
                logger.error(f"chat_completions_stream_worker for input text: [{input_text}] end of segment with sentence [{sentence}] send failed, err: {e}")

        except Exception as e:
            logger.error(f"chat_completions_stream_worker for input text: [{input_text}] failed, err: {e}")
|