File size: 5,360 Bytes
593f0ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from langchain_community.llms import Ollama
from langchain.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import LLMChain, create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain.schema import Document

from src.utils import load_config
from src.vectorstore import VectorDB


def format_docs(docs: list[Document]):
    """Join the page contents of *docs* into one string, separated by blank lines."""
    contents = (document.page_content for document in docs)
    return '\n\n'.join(contents)


class OllamaChain:
    """Plain conversational chain around a local Ollama chat model.

    Keeps a sliding window of the last few conversation turns and injects
    them into a Llama-3-style prompt template on every call.
    """

    def __init__(self, chat_memory) -> None:
        """Wire up prompt, windowed memory, and the LLM chain.

        Args:
            chat_memory: LangChain chat-message history object that backs the
                conversation window (shared with the UI layer).
        """
        # Llama-3 chat template; {chat_history} carries the memory window,
        # {input} the current user turn.
        prompt = PromptTemplate(
            template="""<|begin_of_text|>
            <|start_header_id|>system<|end_header_id|>
            You are a honest and unbiased AI assistant
            <|eot_id|>
            <|start_header_id|>user<|end_header_id|>
            Previous conversation={chat_history}
            Question: {input} 
            Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
            input_variables=['chat_history', 'input']
        )

        # Keep only the last 3 exchanges to bound prompt size.
        self.memory = ConversationBufferWindowMemory(
            memory_key='chat_history',
            chat_memory=chat_memory,
            k=3,
            return_messages=True
        )

        # Model name / sampling parameters come from the project config file.
        config = load_config()
        llm = Ollama(**config['chat_model'])

        self.llm_chain = LLMChain(prompt=prompt, llm=llm, memory=self.memory, output_parser=StrOutputParser())

    def run(self, user_input):
        """Send *user_input* through the chain and return the model's text reply."""
        response = self.llm_chain.invoke(user_input)

        return response['text']


class OllamaRAGChain:
    """Retrieval-augmented conversational chain over a local Ollama model.

    Combines a history-aware retriever (which rewrites follow-up questions
    into standalone ones) with a stuff-documents QA chain, and threads the
    chat history through ``RunnableWithMessageHistory``.
    """

    def __init__(self, chat_memory, uploaded_file=None):
        """Build the vector store, LLM, prompts, and the full RAG chain.

        Args:
            chat_memory: LangChain chat-message history object shared across turns.
            uploaded_file: optional document to index into the knowledge base
                before the chain is first used.
        """
        # Load config once; it supplies both the vector-DB and model settings.
        config = load_config()

        # initialize vector db using config — prefer pinecone when configured,
        # otherwise fall back to a local chroma store.
        vector_db_config = config.get('vector_database', {})
        db_name = 'pinecone' if 'pinecone' in vector_db_config else 'chroma'
        index_name = 'default'
        self.vector_db = VectorDB(db_name, index_name)
        if uploaded_file:
            self.update_knowledge_base(uploaded_file)

        # initialize llm
        self.llm = Ollama(**config['chat_model'])

        # initialize memory
        self.chat_memory = chat_memory

        # Prompt that reformulates a history-dependent question into a
        # standalone one (it must not answer the question itself).
        contextual_q_system_prompt = """Given a chat history and the latest user question which might refer to context \
        in the chat history. Check if the user's question refers to the chat history or not. If does, formulate a \
        standalone question which is incorporated from the latest question and history and can be understood without \
        the chat history.
        Do NOT answer the question, just reformulate it if needed and otherwise return it as is."""

        self.contextual_q_prompt = ChatPromptTemplate.from_messages(
            [
                ('system', contextual_q_system_prompt),
                MessagesPlaceholder('chat_history'),
                ('human', '{input}'),
            ]
        )

        # QA prompt: answers strictly from the retrieved {context}.
        qa_system_prompt = """You are an assistant for question-answering tasks. Use the following pieces of retrieved\
        context to answer the question. If you don't know the answer, just say that you don't know.
        Context: {context}"""
        qa_prompt = ChatPromptTemplate.from_messages(
            [
                ('system', qa_system_prompt),
                MessagesPlaceholder('chat_history'),
                ('human', '{input}'),
            ]
        )

        self.question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)

        # Assemble retriever + RAG chain (also reused after re-indexing).
        self._rebuild_chain()

    def _rebuild_chain(self):
        """(Re)create the history-aware retriever and the conversational RAG chain.

        Called from ``__init__`` and again after the knowledge base changes, so
        the retriever always reflects the current vector-store contents.
        """
        self.history_aware_retriever = create_history_aware_retriever(
            self.llm, self.vector_db.as_retriever(), self.contextual_q_prompt
        )
        rag_chain = create_retrieval_chain(self.history_aware_retriever, self.question_answer_chain)
        self.conversation_rag_chain = RunnableWithMessageHistory(
            rag_chain,
            # Single shared history regardless of session id.
            lambda session_id: self.chat_memory,
            input_messages_key='input',
            history_messages_key='chat_history',
            output_messages_key='answer'
        )

    def run(self, user_input):
        """Run one RAG turn for *user_input* and return the answer text."""
        # Session id is required by RunnableWithMessageHistory but unused here.
        config = {"configurable": {"session_id": "any"}}
        response = self.conversation_rag_chain.invoke({'input': user_input}, config)

        return response['answer']

    def update_chain(self, uploaded_pdf):
        """Index *uploaded_pdf* and rebuild the chain over the updated store."""
        self.update_knowledge_base(uploaded_pdf)
        self._rebuild_chain()

    def update_knowledge_base(self, uploaded_pdf):
        """Add *uploaded_pdf* to the vector store's index."""
        self.vector_db.index(uploaded_pdf)