File size: 1,369 Bytes
9c37331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from llm.base_llm import BaseLLM
from src.utils import load_config
from langchain_groq import ChatGroq
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationalRetrievalChain

class GroqAnswerGenerator(BaseLLM):
    """Answer questions with a Groq-hosted chat model over retrieved documents.

    Wraps a LangChain ``ConversationalRetrievalChain`` built from a
    ``ChatGroq`` model, a caller-supplied retriever, a prompt template read
    from a YAML config file, and a ``ConversationBufferWindowMemory`` whose
    window size also comes from that config.
    """

    def __init__(
        self,
        model_name: str,
        temperature: float,
        max_tokens: int,
        retriever=None,
        config_path: str = "./configs/llm_producer.yaml",
    ):
        """Build the model, prompt, memory and retrieval chain.

        Args:
            model_name: Groq model identifier passed to ``ChatGroq``.
            temperature: Sampling temperature passed to ``ChatGroq``.
            max_tokens: Generation token limit passed to ``ChatGroq``.
            retriever: Retriever supplying context documents to the chain.
                The ``None`` default is kept only for signature compatibility;
                the chain cannot be built without a retriever.
            config_path: YAML file providing the ``prompt_template`` and
                ``memory_window`` keys. Relative paths resolve against the
                current working directory.

        Raises:
            ValueError: If ``retriever`` is ``None``.
        """
        # Fail fast, before any file I/O, with a clear message instead of a
        # validation error raised from inside LangChain's chain construction.
        # That error also subclasses ValueError, so existing handlers still work.
        if retriever is None:
            raise ValueError("GroqAnswerGenerator requires a retriever; got None.")

        self.retriever = retriever
        self.config = load_config(config_path)
        self.model = ChatGroq(
            model=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )

        # Prompt for the "stuff" combine-documents step; its text lives in config.
        self.prompt_template = PromptTemplate.from_template(
            self.config["prompt_template"]
        )

        self.memory = ConversationBufferWindowMemory(
            memory_key="chat_history",  # input key ConversationalRetrievalChain reads history from
            return_messages=True,
            k=self.config["memory_window"],
        )
        self.qa_chain = ConversationalRetrievalChain.from_llm(
            llm=self.model,
            retriever=self.retriever,
            memory=self.memory,
            chain_type="stuff",
            combine_docs_chain_kwargs={"prompt": self.prompt_template},
        )

    def generate_answer(self, prompt: str) -> str:
        """Answer ``prompt`` using retrieved context and the conversation memory.

        Args:
            prompt: The user's question.

        Returns:
            The chain's ``answer`` output.
        """
        # ``Chain.run`` is deprecated; ``invoke`` returns the full output dict,
        # from which only the answer is returned (same value ``run`` produced).
        result = self.qa_chain.invoke({"question": prompt})
        return result["answer"]