liumaolin
Refactor imports for consistency in `kokoro.py` and `processor.py`. Use absolute paths for better readability and maintainability.
8630353
| import pathlib | |
| import typing | |
| from langchain_community.chat_models.llamacpp import ChatLlamaCpp | |
| from langchain_core.callbacks import StreamingStdOutCallbackHandler, CallbackManager | |
| from langchain_core.messages import SystemMessage | |
| from langchain_core.prompts import ( | |
| ChatPromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate | |
| ) | |
| from langchain_core.runnables import RunnableWithMessageHistory | |
| from voice_dialogue.utils.strings import ( | |
| remove_emojis, convert_comma_separated_numbers, convert_uppercase_words_to_lowercase | |
| ) | |
def create_langchain_chat_llamacpp_instance(
        local_model_path: str,
        model_params: dict | None = None
) -> ChatLlamaCpp:
    """Build a ``ChatLlamaCpp`` chat model backed by a local llama.cpp model file.

    Args:
        local_model_path: Filesystem path to the local model weights
            (passed to llama.cpp as ``model_path``).
        model_params: Optional overrides for generation/runtime settings
            (``streaming``, ``n_gpu_layers``, ``n_batch``, ``n_ctx``,
            ``f16_kv``, ``temperature``, ``top_k``, ``top_p``,
            ``n_predict``). Missing keys fall back to the defaults below.

    Returns:
        A configured ``ChatLlamaCpp`` instance.
    """
    print(">>>>>>> Initializing LlamaCpp Langchain instance...")
    # Bug fix: the declared default of None previously crashed on
    # model_params.get(...); treat None as "no overrides".
    if model_params is None:
        model_params = {}
    model_path = pathlib.Path(local_model_path)
    # NOTE(review): stdout streaming callbacks are currently disabled; to
    # re-enable, pass
    #   callback_manager=CallbackManager([StreamingStdOutCallbackHandler()])
    # to the constructor below.
    llamacpp_langchain_instance = ChatLlamaCpp(
        model_path=str(model_path),
        streaming=model_params.get('streaming', True),
        n_gpu_layers=model_params.get('n_gpu_layers', -1),  # -1: offload all layers to GPU
        n_batch=model_params.get('n_batch', 512),
        n_ctx=model_params.get('n_ctx', 2048),
        f16_kv=model_params.get('f16_kv', True),
        temperature=model_params.get('temperature', 0.8),
        top_k=model_params.get('top_k', 40),
        top_p=model_params.get('top_p', 0.95),
        max_tokens=model_params.get('n_predict', 256),  # n_predict is llama.cpp's name for max_tokens
        verbose=False
    )
    return llamacpp_langchain_instance
def create_langchain_pipeline(langchain_instance, system_prompt: str, get_session_history: typing.Callable):
    """Wrap a chat model in a prompt template with per-session message history.

    Args:
        langchain_instance: The chat model runnable (e.g. a ``ChatLlamaCpp``).
        system_prompt: System message text prepended to every conversation.
        get_session_history: Callable mapping a session id to a chat
            message history store; required by ``RunnableWithMessageHistory``.

    Returns:
        A ``RunnableWithMessageHistory`` expecting dict input of the form
        ``{'input': <user text>}``.

    Raises:
        NotImplementedError: If ``get_session_history`` is None.
    """
    # Guard clause: fail fast before building the prompt/pipeline.
    if get_session_history is None:
        raise NotImplementedError
    prompt = ChatPromptTemplate(messages=[
        SystemMessage(content=system_prompt),
        MessagesPlaceholder(variable_name="history"),
        HumanMessagePromptTemplate.from_template("{input}")
    ])
    langchain_pipeline = prompt | langchain_instance
    # input_messages_key is required for dict input ({'input': ...}) so the
    # wrapper knows which value to record as the user turn in history.
    chain_with_history = RunnableWithMessageHistory(langchain_pipeline, get_session_history,
                                                   input_messages_key='input',
                                                   history_messages_key='history')
    return chain_with_history
def warmup_langchain_pipeline(pipeline):
    """Run one throwaway streamed exchange so the first real reply is faster.

    The generated tokens are drained and discarded; only the side effect of
    loading/compiling the model path matters.
    """
    print("Warmup chat pipeline...")
    warmup_prompt = 'Hello, this is warming up step, if you understand, output "Ok".'
    session_config = {"configurable": {"session_id": 'warmup'}}
    token_stream = pipeline.stream(input={'input': warmup_prompt}, config=session_config)
    for _token in token_stream:
        pass
def preprocess_sentence_text(sentences):
    """Join sentence fragments and normalize them for downstream speech synthesis.

    Steps: concatenate the fragments, strip emojis, rewrite comma-separated
    numbers, lowercase all-uppercase words, then soften interior sentence-ending
    punctuation ('!', '?', '.') into commas while keeping the final mark intact.

    Args:
        sentences: Iterable of string fragments forming one sentence.

    Returns:
        The normalized sentence text (possibly empty).
    """
    text = remove_emojis(''.join(sentences))
    text = convert_comma_separated_numbers(text)
    text = convert_uppercase_words_to_lowercase(text)
    if not text:
        return text
    # Split off the closing punctuation so only interior marks are rewritten.
    body, final_mark = text[:-1], text[-1]
    body = body.translate(str.maketrans('!?.', ',,,'))
    return f'{body}{final_mark}'