liumaolin
Refactor imports for consistency in `kokoro.py` and `processor.py`. Use absolute paths for better readability and maintainability.
8630353
raw
history blame
3.01 kB
import pathlib
import typing
from langchain_community.chat_models.llamacpp import ChatLlamaCpp
from langchain_core.callbacks import StreamingStdOutCallbackHandler, CallbackManager
from langchain_core.messages import SystemMessage
from langchain_core.prompts import (
ChatPromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate
)
from langchain_core.runnables import RunnableWithMessageHistory
from voice_dialogue.utils.strings import (
remove_emojis, convert_comma_separated_numbers, convert_uppercase_words_to_lowercase
)
def create_langchain_chat_llamacpp_instance(
        local_model_path: str,
        model_params: dict | None = None
) -> ChatLlamaCpp:
    """Build a ``ChatLlamaCpp`` chat model for the local GGUF model file.

    Args:
        local_model_path: Filesystem path to the llama.cpp model file.
        model_params: Optional overrides for runtime/generation parameters.
            Recognized keys: ``streaming``, ``n_gpu_layers``, ``n_batch``,
            ``n_ctx``, ``f16_kv``, ``temperature``, ``top_k``, ``top_p``,
            ``n_predict``. Missing keys fall back to the defaults below.

    Returns:
        A configured ``ChatLlamaCpp`` instance.
    """
    print(">>>>>>> Initializing LlamaCpp Langchain instance...")
    # Fix: the parameter defaults to None, but the original code called
    # model_params.get(...) unconditionally, raising AttributeError whenever
    # the caller omitted it. Normalize to an empty dict first.
    params = model_params or {}
    model_path = pathlib.Path(local_model_path)
    # NOTE: the previously constructed CallbackManager([StreamingStdOutCallbackHandler()])
    # was never passed to the model (its usage was commented out), so the
    # dead local has been removed.
    llamacpp_langchain_instance = ChatLlamaCpp(
        model_path=str(model_path),
        streaming=params.get('streaming', True),
        n_gpu_layers=params.get('n_gpu_layers', -1),  # -1 = offload all layers to GPU
        n_batch=params.get('n_batch', 512),
        n_ctx=params.get('n_ctx', 2048),
        f16_kv=params.get('f16_kv', True),
        temperature=params.get('temperature', 0.8),
        top_k=params.get('top_k', 40),
        top_p=params.get('top_p', 0.95),
        max_tokens=params.get('n_predict', 256),
        verbose=False,
    )
    return llamacpp_langchain_instance
def create_langchain_pipeline(langchain_instance, system_prompt: str, get_session_history: typing.Callable):
    """Compose a history-aware chat pipeline around *langchain_instance*.

    Args:
        langchain_instance: The chat model the prompt is piped into.
        system_prompt: Text for the leading system message.
        get_session_history: Callable resolving a session's message history;
            must not be None.

    Returns:
        A ``RunnableWithMessageHistory`` wrapping ``prompt | model``.

    Raises:
        NotImplementedError: If ``get_session_history`` is None.
    """
    # Fail fast before doing any prompt construction.
    if get_session_history is None:
        raise NotImplementedError

    chat_prompt = ChatPromptTemplate(
        messages=[
            SystemMessage(content=system_prompt),
            MessagesPlaceholder(variable_name="history"),
            HumanMessagePromptTemplate.from_template("{input}"),
        ]
    )
    return RunnableWithMessageHistory(
        chat_prompt | langchain_instance,
        get_session_history,
        history_messages_key='history',
    )
def warmup_langchain_pipeline(pipeline):
    """Stream one throwaway prompt through *pipeline* to trigger lazy setup.

    The generated tokens are discarded; only the side effect of exercising
    the pipeline once matters.
    """
    print("Warmup chat pipeline...")
    warmup_prompt = 'Hello, this is warming up step, if you understand, output "Ok".'
    warmup_config = {"configurable": {"session_id": 'warmup'}}
    # Drain the stream completely so the full generation path runs.
    for _chunk in pipeline.stream(input={'input': warmup_prompt}, config=warmup_config):
        pass
def preprocess_sentence_text(sentences):
    """Join and normalize sentence fragments for speech-friendly output.

    Joins *sentences* into one string, runs it through the shared string
    utilities (emoji removal, comma-separated-number conversion, uppercase
    word lowering), then softens every interior '!', '?' and '.' into a
    comma while leaving the final character untouched.

    Returns the normalized string (may be empty).
    """
    text = ''.join(sentences)
    for normalize in (remove_emojis,
                      convert_comma_separated_numbers,
                      convert_uppercase_words_to_lowercase):
        text = normalize(text)

    if not text:
        return text

    # Preserve the closing mark; flatten interior sentence breaks to commas.
    body, final_mark = text[:-1], text[-1]
    for mark in ('!', '?', '.'):
        body = body.replace(mark, ',')
    return f'{body}{final_mark}'