from llm.base_llm import BaseLLM
from src.utils import load_config

from langchain_groq import ChatGroq
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationalRetrievalChain


class GroqAnswerGenerator(BaseLLM):
    """Answer generator backed by a Groq-hosted chat model and a conversational retrieval chain."""

    def __init__(self, model_name: str, temperature: float, max_tokens: int, retriever=None):
        self.retriever = retriever
        self.config = load_config("./configs/llm_producer.yaml")

        # Groq-hosted chat model used to generate the final answer.
        self.model = ChatGroq(
            model=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )

        # Answer prompt is loaded from the YAML config.
        self.prompt_template = PromptTemplate.from_template(
            self.config["prompt_template"]
        )

        # Windowed conversation memory keeps only the last `memory_window` exchanges.
        self.memory = ConversationBufferWindowMemory(
            memory_key="chat_history",  # required by ConversationalRetrievalChain
            return_messages=True,
            k=self.config["memory_window"],
        )

        # Retrieval-augmented QA chain: retrieve documents, "stuff" them into the
        # prompt, and answer with the Groq model while tracking chat history.
        self.qa_chain = ConversationalRetrievalChain.from_llm(
            llm=self.model,
            retriever=self.retriever,
            memory=self.memory,
            chain_type="stuff",
            combine_docs_chain_kwargs={"prompt": self.prompt_template},
        )

    def generate_answer(self, prompt: str):
        return self.qa_chain.run(question=prompt)
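

# Example usage (a minimal sketch; `build_retriever` and the model name below are
# hypothetical placeholders for whatever retriever and config this project wires up):
#
#     retriever = build_retriever()                      # e.g. vector_store.as_retriever()
#     generator = GroqAnswerGenerator(
#         model_name="llama-3.1-8b-instant",             # any Groq-hosted chat model
#         temperature=0.0,
#         max_tokens=512,
#         retriever=retriever,
#     )
#     print(generator.generate_answer("What do the indexed documents cover?"))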