from llm.base_llm import BaseLLM
from src.utils import load_config
from langchain_groq import ChatGroq
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationalRetrievalChain
class GroqAnswerGenerator(BaseLLM):
    """Answer generator backed by a Groq-hosted chat model.

    Wires together a ChatGroq LLM, a prompt template loaded from the
    producer config, and a windowed conversation memory into a
    LangChain ConversationalRetrievalChain.
    """

    def __init__(self, model_name: str, temperature: float, max_tokens: int, retriever=None):
        """Build the conversational retrieval QA chain.

        Args:
            model_name: Groq model identifier passed to ChatGroq.
            temperature: Sampling temperature for the model.
            max_tokens: Maximum number of tokens the model may generate.
            retriever: LangChain retriever that supplies context documents.
                NOTE(review): defaults to None, but ``from_llm`` below
                requires a retriever — constructing this class without one
                will fail when the chain is built. Confirm callers always
                pass a retriever, or validate here.
        """
        self.retriever = retriever
        # Producer-side config supplies the prompt template text and the
        # conversation-memory window size.
        self.config = load_config("./configs/llm_producer.yaml")
        self.model = ChatGroq(
            model=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        self.prompt_template = PromptTemplate.from_template(
            self.config["prompt_template"]
        )
        self.memory = ConversationBufferWindowMemory(
            memory_key="chat_history",  # required by ConversationalRetrievalChain
            return_messages=True,
            k=self.config["memory_window"],
        )
        self.qa_chain = ConversationalRetrievalChain.from_llm(
            llm=self.model,
            retriever=self.retriever,
            memory=self.memory,
            chain_type="stuff",
            combine_docs_chain_kwargs={"prompt": self.prompt_template},
        )

    def generate_answer(self, prompt: str):
        """Run the retrieval chain on *prompt* and return the model's answer."""
        # NOTE(review): Chain.run is deprecated in newer LangChain releases in
        # favor of .invoke(); kept to preserve the existing return shape.
        return self.qa_chain.run(question=prompt)