File size: 4,300 Bytes
0870bc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99

from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain.memory import ChatMessageHistory
from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.document_compressors import JinaRerank
from langchain_core.chat_history import BaseChatMessageHistory
from src.services.vector_db.qdrent.upload_document import upload_document_existing_collection, \
    answer_query_from_existing_collection
from langchain_groq import ChatGroq
import os
# Jina's client reads JINA_API_KEY; mirror it from the JINA_API variable this
# deployment sets. Guard against the variable being absent — assigning None
# into os.environ raises TypeError at import time with an opaque message.
_jina_api_key = os.getenv("JINA_API")
if _jina_api_key:
    os.environ["JINA_API_KEY"] = _jina_api_key
from src import logging as logger


class AnswerQuery:
    """Answer user queries against an existing Qdrant collection.

    Wires a hybrid (dense + sparse) retriever through a Jina rerank
    compressor into a Groq-hosted LLM chain with per-collection chat
    history, and generates follow-up questions from the retrieved context.
    """

    def __init__(self, prompt, vector_embedding, sparse_embedding, follow_up_prompt, json_parser):
        # Chat histories keyed by session id (the vector-store/collection name).
        self.chat_history_store = {}
        # Reranker used to compress the retriever's candidate set.
        self.compressor = JinaRerank(model="jina-reranker-v2-base-multilingual")
        self.vector_embed = vector_embedding
        self.sparse_embed = sparse_embedding
        self.prompt = prompt
        self.follow_up_prompt = follow_up_prompt
        self.json_parser = json_parser

    def format_docs(self, docs: list) -> str:
        """Concatenate retrieved documents into a single context string.

        Side effects: stores the deduplicated list of document sources and
        the built context in the module-level ``sources`` / ``temp_context``
        globals, which ``answer_query`` reads back after the chain runs.

        NOTE(review): these module-level globals are shared across instances
        and are not thread-safe — consider instance attributes instead.

        :param docs: retrieved Document-like objects exposing
            ``page_content`` and a ``metadata["source"]`` entry.
        :returns: the joined context, or ``"No context found"`` when empty.
        """
        global sources
        global temp_context
        sources = []
        context = ""
        for doc in docs:
            context += f"{doc.page_content}\n\n\n"
            sources.append(doc.metadata["source"])
        if context == "":
            # Sentinel so the prompt never receives an empty context.
            context = "No context found"
        sources = list(set(sources))
        temp_context = context
        return context

    def answer_query(self, query: str, vectorstore: str, llmModel: str = "llama-3.1-70b-versatile"):
        """Run the full RAG pipeline for ``query`` over ``vectorstore``.

        Builds an MMR retriever over the existing collection, compresses it
        with the Jina reranker, answers through a history-aware Groq chain,
        then derives follow-up questions from the retrieved context.

        :param query: the user's question.
        :param vectorstore: name of the existing collection; also used as the
            chat-history session id.
        :param llmModel: Groq model for the answering chain.
        :returns: ``(answer, follow_up_questions, sources)``.
        """
        global sources
        global temp_context
        vector_store = answer_query_from_existing_collection(vector_embed=self.vector_embed,
                                                             sparse_embed=self.sparse_embed,
                                                             vectorstore=vectorstore)

        # Fetch 20 candidates, keep 10 via MMR, then rerank/compress.
        retriever = vector_store.as_retriever(search_type="mmr", search_kwargs={"k": 10, "fetch_k": 20})
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=self.compressor, base_retriever=retriever
        )
        brain_chain = (
                {"context": RunnableLambda(lambda x: x["question"]) | compression_retriever | RunnableLambda(self.format_docs),
                 "question": RunnableLambda(lambda x: x["question"]),
                 "chatHistory": RunnableLambda(lambda x: x["chatHistory"])}
                | self.prompt
                | ChatGroq(model=llmModel, temperature=0.75, max_tokens=512)
                | StrOutputParser()
        )
        message_chain = RunnableWithMessageHistory(
            brain_chain,
            self.get_session_history,
            input_messages_key="question",
            history_messages_key="chatHistory"
        )
        # Trim stored histories to the latest message before each invocation.
        chain = RunnablePassthrough.assign(messages_trimmed=self.trim_messages) | message_chain
        follow_up_chain = self.follow_up_prompt | ChatGroq(model="llama-3.1-70b-versatile",
                                                           temperature=0) | self.json_parser

        output = chain.invoke(
            {"question": query},
            {"configurable": {"session_id": vectorstore}}
        )
        # format_docs populated temp_context/sources as a side effect above.
        follow_up_questions = follow_up_chain.invoke({"context": temp_context})

        return output, follow_up_questions, sources

    def trim_messages(self, chain_input):
        """Trim every stored chat history down to its most recent message.

        Keeps prompt size bounded by retaining only the last exchange per
        session. Always returns True so it can sit in a chain via
        ``RunnablePassthrough.assign``.
        """
        for history in self.chat_history_store.values():
            messages = history.messages
            if len(messages) > 1:
                history.clear()
                history.add_message(messages[-1])
        return True

    def get_session_history(self, session_id: str) -> "BaseChatMessageHistory":
        """Return the chat history for ``session_id``, creating it on first use."""
        if session_id not in self.chat_history_store:
            self.chat_history_store[session_id] = ChatMessageHistory()
        return self.chat_history_store[session_id]