File size: 2,913 Bytes
633bb91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import warnings
from typing import Optional

os.environ["TOKENIZERS_PARALLELISM"] = "false"

from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.retrieval import create_retrieval_chain
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate

from config import Config
from llm.gemini_llm import GeminiLLM
from memory.chat_memory import MemoryManager
from retriever.qdrant_retriever import QdrantRetriever

warnings.filterwarnings("ignore", category=DeprecationWarning)


class RAGPipeline:
    """Retrieval-augmented question answering with chat memory.

    The pipeline combines a ``QdrantRetriever``, a ``MemoryManager`` and the
    client returned by ``GeminiLLM().get_client()``.  Retrieved documents are
    stuffed into a prompt loaded from ``Config.RAG_PROMPT`` and the previous
    turns of the conversation are prepended to the system message as text.
    """

    def __init__(self):
        """Create the retriever, memory, LLM client, prompt and chains."""
        self.retriever = QdrantRetriever()
        self.memory = MemoryManager()
        self.llm = GeminiLLM().get_client()

        self.prompt = self._load_prompt(Config.RAG_PROMPT)
        self.qa_chain = create_stuff_documents_chain(self.llm, self.prompt)
        self.chain = create_retrieval_chain(self.retriever, self.qa_chain)

    def _load_prompt(self, path: str) -> ChatPromptTemplate:
        """Build the chat prompt from the system-prompt file at *path*.

        The file content is appended to a ``{chat_history}`` placeholder to
        form the system message; the human message is the ``{input}``
        placeholder.  Because the file text becomes part of the template,
        any ``{name}`` fields inside it are treated as template variables.

        Args:
            path: Location of the UTF-8 encoded system-prompt file.

        Returns:
            A two-message (system, human) ``ChatPromptTemplate``.

        Raises:
            FileNotFoundError: If no file exists at *path*.
        """
        # Open directly instead of checking os.path.exists() first: the file
        # could disappear between the check and the open.
        try:
            with open(path, "r", encoding="utf-8") as f:
                system_prompt = f.read()
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Prompt file not found: {path}") from err

        return ChatPromptTemplate.from_messages([
            ("system", "{chat_history}\n\n" + system_prompt),
            ("human", "{input}"),
        ])

    @staticmethod
    def _role_of(msg: BaseMessage) -> str:
        """Return the transcript label used for *msg*."""
        if isinstance(msg, HumanMessage):
            return "user"
        if isinstance(msg, AIMessage):
            return "assistant"
        if isinstance(msg, SystemMessage):
            return "system"
        return "unknown"

    def messages_to_string(self, messages: list[BaseMessage]) -> str:
        """Render *messages* as a plain-text transcript.

        Each message becomes one ``"<role>: <content>"`` entry, where the role
        is ``user``, ``assistant``, ``system`` or ``unknown`` depending on the
        message class.  Entries are joined with newlines.

        Args:
            messages: Messages in chronological order.  ``None`` or an empty
                list yields an empty string.

        Returns:
            The formatted transcript.
        """
        if not messages:
            return ""
        return "\n".join(
            f"{self._role_of(msg)}: {msg.content}" for msg in messages
        )

    def ask(self, query: str, session_id: Optional[str] = None) -> str:
        """Answer *query* using retrieved context and the session's history.

        Args:
            query: The user's question.
            session_id: Conversation whose history is read and updated.
                Defaults to ``Config.SESSION_ID`` when omitted.

        Returns:
            The value stored under ``"answer"`` in the chain's response.
        """
        if session_id is None:
            session_id = Config.SESSION_ID

        # Get conversation history and format it
        history_messages = self.memory.get(session_id)
        chat_history_str = self.messages_to_string(history_messages)

        # Invoke RAG chain
        response = self.chain.invoke({
            "input": query,
            "chat_history": chat_history_str.strip(),
        })

        # Extract final answer
        answer = response["answer"]

        # Save the turn only after the chain succeeded, so a failed call does
        # not leave an unanswered question in memory.
        self.memory.add(session_id, HumanMessage(content=query))
        self.memory.add(session_id, AIMessage(content=answer))

        return answer


if __name__ == "__main__":
    # Manual smoke test: ask two questions in a row on the same pipeline and
    # print each question together with its answer.
    pipeline = RAGPipeline()
    questions = (
        "What is the full form of K12HSN?",
        "What does the abbreviation stand for?",
    )

    for number, question in enumerate(questions, start=1):
        reply = pipeline.ask(question)
        print(f"Q{number}: {question}\nA{number}: {reply}")