# locus-rag-bot — app/engine.py
# Commit 2911b19 (khagu): "chore: updated engine.py with retry logic in case of failure"
import os

from dotenv import load_dotenv
from langchain_classic.chains import create_retrieval_chain
from langchain_classic.chains.combine_documents import create_stuff_documents_chain
from langchain_classic.retrievers.multi_query import MultiQueryRetriever
from langchain_core.prompts import ChatPromptTemplate
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_mistralai import ChatMistralAI
from langchain_pinecone import PineconeVectorStore
# Load variables from a local .env file at import time — PINECONE_INDEX_NAME
# is read below; presumably the Pinecone/Mistral API keys come from here too
# (picked up implicitly by their client libraries — verify against deployment).
load_dotenv()
def get_rag_chain(k: int = 5, llm_model: str = "mistral-large-latest"):
    """Build the retrieval-augmented generation (RAG) chain for LOCUS Q&A.

    Wires together a Pinecone vector store (index name taken from the
    PINECONE_INDEX_NAME env var, defaulting to "locus-rag"), MiniLM sentence
    embeddings, a Mistral chat model, and a multi-query retriever that
    rephrases the user question to improve recall.

    Args:
        k: Number of documents the base retriever fetches per (sub-)query.
        llm_model: Mistral model identifier, used both for query expansion
            and for answer generation.

    Returns:
        A runnable chain; invoke with ``{"input": <question>}`` and read the
        generated text from the ``"answer"`` key of the result.
    """
    index_name = os.getenv("PINECONE_INDEX_NAME", "locus-rag")

    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    vectorstore = PineconeVectorStore(index_name=index_name, embedding=embeddings)

    # max_retries guards against transient Mistral API failures (see commit msg).
    llm = ChatMistralAI(
        model=llm_model,
        temperature=0,
        max_retries=3,
    )

    base_retriever = vectorstore.as_retriever(search_kwargs={"k": k})
    # MultiQueryRetriever asks the LLM for alternative phrasings of the user
    # question and merges the retrieval results, improving recall.
    retriever = MultiQueryRetriever.from_llm(retriever=base_retriever, llm=llm)

    system_prompt = (
        "You are an assistant for question-answering tasks. "
        "Use the following pieces of retrieved context to answer "
        "the question. If you don't know the answer, say that you "
        "don't know. Use three sentences maximum and keep the "
        "answer concise."
        "\n\n"
        "{context}"
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("human", "{input}"),
        ]
    )

    # "Stuff" all retrieved docs into {context}, then chain retrieval -> answer.
    question_answer_chain = create_stuff_documents_chain(llm, prompt)
    return create_retrieval_chain(retriever, question_answer_chain)
if __name__ == "__main__":
    # Manual smoke test: build the chain and run one sample question end-to-end.
    rag = get_rag_chain()
    result = rag.invoke({"input": "What is LOCUS?"})
    print(result["answer"])