# ScholarBot/llm/answer_generator.py
# Author: vinny4 — initial commit (9c37331)
from llm.base_llm import BaseLLM
from src.utils import load_config
from langchain_groq import ChatGroq
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationalRetrievalChain
class GroqAnswerGenerator(BaseLLM):
    """Conversational RAG answer generator backed by a Groq-hosted chat model.

    Wires together a ChatGroq LLM, a windowed conversation memory, and a
    document retriever into a LangChain ``ConversationalRetrievalChain`` so
    that each answer is grounded in retrieved documents and recent chat
    history.
    """

    def __init__(
        self,
        model_name: str,
        temperature: float,
        max_tokens: int,
        retriever=None,
        config_path: str = "./configs/llm_producer.yaml",
    ):
        """Build the LLM, memory, and retrieval chain.

        Args:
            model_name: Groq model identifier passed to ``ChatGroq``.
            temperature: Sampling temperature for the model.
            max_tokens: Maximum tokens the model may generate per response.
            retriever: LangChain retriever supplying context documents.
                NOTE(review): the chain is built even when this is ``None``;
                callers are presumably expected to always pass one — confirm.
            config_path: Path to the YAML config providing the
                ``prompt_template`` and ``memory_window`` keys. Defaults to
                the original hard-coded location, so existing callers are
                unaffected.
        """
        self.retriever = retriever
        # Config supplies the prompt template text and the memory window size.
        self.config = load_config(config_path)
        self.model = ChatGroq(
            model=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        self.prompt_template = PromptTemplate.from_template(
            self.config["prompt_template"]
        )
        # Window the memory so only the last `memory_window` exchanges are
        # replayed into the prompt, bounding prompt growth over a session.
        self.memory = ConversationBufferWindowMemory(
            memory_key="chat_history",  # required by ConversationalRetrievalChain
            return_messages=True,
            k=self.config["memory_window"],
        )
        self.qa_chain = ConversationalRetrievalChain.from_llm(
            llm=self.model,
            retriever=self.retriever,
            memory=self.memory,
            chain_type="stuff",  # concatenate all retrieved docs into one prompt
            combine_docs_chain_kwargs={"prompt": self.prompt_template},
        )

    def generate_answer(self, prompt: str) -> str:
        """Return the chain's answer to *prompt*, using retrieval and history.

        Args:
            prompt: The user's question.

        Returns:
            The generated answer string (``run`` unwraps the chain's single
            output key).
        """
        return self.qa_chain.run(question=prompt)