import pathlib
import typing

from langchain_community.chat_models.llamacpp import ChatLlamaCpp
from langchain_core.callbacks import StreamingStdOutCallbackHandler, CallbackManager
from langchain_core.messages import SystemMessage
from langchain_core.prompts import (
    ChatPromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate
)
from langchain_core.runnables import RunnableWithMessageHistory

from voice_dialogue.utils.strings import (
    remove_emojis, convert_comma_separated_numbers, convert_uppercase_words_to_lowercase
)


def create_langchain_chat_llamacpp_instance(
        local_model_path: str,
        model_params: dict | None = None
) -> ChatLlamaCpp:
    print(">>>>>>> Initializing LlamaCpp Langchain instance...")

    model_path = pathlib.Path(local_model_path)
    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
    llamacpp_langchain_instance = ChatLlamaCpp(
        model_path=str(model_path),
        streaming=model_params.get('streaming', True),
        n_gpu_layers=model_params.get('n_gpu_layers', -1),
        n_batch=model_params.get('n_batch', 512),
        n_ctx=model_params.get('n_ctx', 2048),
        f16_kv=model_params.get('f16_kv', True),
        temperature=model_params.get('temperature', 0.8),
        top_k=model_params.get('top_k', 40),
        top_p=model_params.get('top_p', 0.95),
        max_tokens=model_params.get('n_predict', 256),  # llama.cpp calls this `n_predict`
        # callback_manager=callback_manager,
        verbose=False
    )

    return llamacpp_langchain_instance
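
# Example: overriding generation parameters with `model_params` (the values
# and the model path below are illustrative placeholders, not project defaults):
#
#     llm = create_langchain_chat_llamacpp_instance(
#         '/path/to/model.gguf',
#         model_params={'temperature': 0.2, 'n_ctx': 4096},
#     )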


def create_langchain_pipeline(langchain_instance, system_prompt: str, get_session_history: typing.Callable):
    """Wrap the chat model in a prompt template with per-session message history."""
    if get_session_history is None:
        raise ValueError('A `get_session_history` callable is required.')

    prompt = ChatPromptTemplate(messages=[
        SystemMessage(content=system_prompt),
        MessagesPlaceholder(variable_name="history"),
        HumanMessagePromptTemplate.from_template("{input}")
    ])
    langchain_pipeline = prompt | langchain_instance
    chain_with_history = RunnableWithMessageHistory(
        langchain_pipeline,
        get_session_history,
        input_messages_key='input',
        history_messages_key='history',
    )
    return chain_with_history
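
# Illustrative session-history factory for `create_langchain_pipeline`: a
# minimal in-memory store. The names and the storage choice here are an
# example sketch, not part of the original module; a real application would
# likely persist histories elsewhere.
_example_session_histories = {}


def example_get_session_history(session_id: str):
    """Return the chat history for `session_id`, creating it on first use."""
    from langchain_core.chat_history import InMemoryChatMessageHistory

    if session_id not in _example_session_histories:
        _example_session_histories[session_id] = InMemoryChatMessageHistory()
    return _example_session_histories[session_id]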


def warmup_langchain_pipeline(pipeline):
    """Run one throwaway generation so the model is loaded and caches are warm."""
    print("Warming up chat pipeline...")

    user_input = 'Hello, this is a warm-up step. If you understand, output "Ok".'
    config = {"configurable": {"session_id": 'warmup'}}
    # Drain the stream; the output itself is discarded.
    for _ in pipeline.stream(input={'input': user_input}, config=config):
        pass


def preprocess_sentence_text(sentences):
    """Normalize generated text into a single utterance for speech synthesis."""
    sentence_text = ''.join(sentences)
    sentence_text = remove_emojis(sentence_text)
    sentence_text = convert_comma_separated_numbers(sentence_text)
    sentence_text = convert_uppercase_words_to_lowercase(sentence_text)
    if sentence_text:
        # Keep the final punctuation mark, but soften sentence-ending marks
        # inside the text to commas so TTS reads it as one utterance.
        sentence_mark = sentence_text[-1]
        sentence_content = sentence_text[:-1].replace('!', ',').replace('?', ',').replace('.', ',')
        sentence_text = f'{sentence_content}{sentence_mark}'
    return sentence_text
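

if __name__ == '__main__':
    # End-to-end usage sketch. The model path and prompts are placeholder
    # assumptions for illustration, not values from the original module; it
    # reuses the example in-memory history factory defined above.
    llm = create_langchain_chat_llamacpp_instance('/path/to/model.gguf')
    chat = create_langchain_pipeline(
        llm, 'You are a helpful voice assistant.', example_get_session_history
    )
    warmup_langchain_pipeline(chat)

    # Stream a reply, collect the chunks, and normalize them for TTS.
    config = {"configurable": {"session_id": "demo"}}
    chunks = [chunk.content for chunk in chat.stream(input={'input': 'Hello!'}, config=config)]
    print(preprocess_sentence_text(chunks))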