Spaces:
Sleeping
Sleeping
| from llm.base_llm import BaseLLM | |
| from src.utils import load_config | |
| from langchain_groq import ChatGroq | |
| from langchain.prompts import PromptTemplate | |
| from langchain.memory import ConversationBufferWindowMemory | |
| from langchain.chains import ConversationalRetrievalChain | |
class GroqAnswerGenerator(BaseLLM):
    """Answer questions with a Groq chat model over retrieved documents.

    Wraps a LangChain ``ConversationalRetrievalChain`` (``chain_type="stuff"``)
    built from a ``ChatGroq`` model, the caller's retriever, a prompt template
    read from a YAML config, and a windowed conversation memory.
    """

    def __init__(
        self,
        model_name: str,
        temperature: float,
        max_tokens: int,
        retriever=None,
        *,
        config_path: str = "./configs/llm_producer.yaml",
    ):
        """Build the Groq model, prompt, memory and retrieval chain.

        Args:
            model_name: Model identifier passed to ``ChatGroq(model=...)``.
            temperature: Sampling temperature passed to ``ChatGroq``.
            max_tokens: Generation token limit passed to ``ChatGroq``.
            retriever: LangChain retriever that supplies context documents.
                The chain cannot be built without one; the ``None`` default
                is kept only so the signature stays backward-compatible.
            config_path: Config file handed to ``load_config``. It must
                provide the ``prompt_template`` and ``memory_window`` keys.
                Defaults to the path that was previously hard-coded.

        Raises:
            ValueError: If ``retriever`` is ``None``.
        """
        # ConversationalRetrievalChain requires a retriever; fail early with a
        # clear message rather than a validation error from inside LangChain.
        if retriever is None:
            raise ValueError("GroqAnswerGenerator requires a retriever")
        self.retriever = retriever
        self.config = load_config(config_path)

        self.model = ChatGroq(
            model=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )

        # NOTE(review): this template is used by the "stuff" documents chain;
        # confirm the configured template contains the variables that chain
        # fills (by default {context} and {question}).
        self.prompt_template = PromptTemplate.from_template(
            self.config["prompt_template"]
        )

        # Keeps only the last `memory_window` exchanges.
        self.memory = ConversationBufferWindowMemory(
            memory_key="chat_history",  # required by ConversationalRetrievalChain
            return_messages=True,
            k=self.config["memory_window"],
        )

        self.qa_chain = ConversationalRetrievalChain.from_llm(
            llm=self.model,
            retriever=self.retriever,
            memory=self.memory,
            chain_type="stuff",
            combine_docs_chain_kwargs={"prompt": self.prompt_template},
        )

    def generate_answer(self, prompt: str) -> str:
        """Answer ``prompt`` using retrieved context and the chat history.

        Args:
            prompt: The user's question, passed to the chain as ``question``.

        Returns:
            The chain's answer text (its ``output_key`` entry).
        """
        # `Chain.run` is deprecated; `invoke` returns the full output dict,
        # from which we pick the same single answer value `run` returned.
        result = self.qa_chain.invoke({"question": prompt})
        return result[self.qa_chain.output_key]