# HelpDev / app/rag_pipeline.py
# Dewasheesh's picture
# Update app/rag_pipeline.py
# ad6e44b verified
import functools
import os

import requests
from dotenv import load_dotenv
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.vectorstores import Chroma
from langchain.vectorstores import Chroma
# On-disk locations of the persisted Chroma vector stores.
CHROMA_DIR = "data/chroma_db"  # store queried by generate_answer
CHROMA_IMG_DIR = "data/image_db"  # store queried by generate_answer_am

# Load variables from a .env file into the process environment, then read the
# Hugging Face token from it (None when HF_API_KEY is not set anywhere).
load_dotenv()
HUGGINGFACE_API_KEY = os.environ.get("HF_API_KEY")
# System instructions prepended to every prompt: the model must answer only
# from the supplied context, and the ###Context / ###Question markers used in
# the user message below are introduced here.
qna_system_message = """
You are an assistant whose work is to review the report and provide the appropriate answers from the context.
User input will have the context required by you to answer user questions.
This context will begin with the token: ###Context.
The context contains references to specific portions of a document relevant to the user query.
User questions will begin with the token: ###Question.
Please answer only using the context provided in the input. Do not mention anything about the context in your final answer.
If the answer is not found in the context, respond "I don't know".
"""
# User-message template; the {context} and {question} placeholders are filled
# in by the generate_answer* functions before the prompt is sent to the model.
qna_user_message_template = """
###Context
Here are some documents that are relevant to the question mentioned below.
{context}
###Question
{question}
"""
def call_huggingface_mistral(
    prompt: str,
    *,
    model: str = "mistralai/Mistral-7B-Instruct-v0.2",
    api_url: str = "https://router.huggingface.co/featherless-ai/v1/chat/completions",
    timeout: float = 120.0,
):
    """Send *prompt* to a chat-completions endpoint on the Hugging Face router.

    The prompt is sent as a single ``user`` message, so any system
    instructions must already be embedded in the prompt text.

    Args:
        prompt: Full prompt text.
        model: Model identifier placed in the request payload.
        api_url: Chat-completions endpoint to POST to.
        timeout: Seconds to wait for the server, forwarded to
            ``requests.post``. Without it a stalled connection would block
            the caller indefinitely.

    Returns:
        On HTTP 200, ``choices[0]["message"]`` from the JSON response
        (presumably a dict with ``role``/``content`` for an OpenAI-compatible
        API -- confirm). On any other status, a string of the form
        ``"[Error <status>] <response body>"``.

    Raises:
        requests.RequestException: on connection failure or timeout.
        KeyError / IndexError / ValueError: if a 200 response body is not
            JSON of the expected shape.
    """
    # NOTE(review): when HF_API_KEY is unset this header becomes "Bearer None".
    headers = {"Authorization": f"Bearer {HUGGINGFACE_API_KEY}"}
    payload = {
        "messages": [{"role": "user", "content": prompt}],
        "model": model,
    }
    response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
    if response.status_code != 200:
        return f"[Error {response.status_code}] {response.text}"
    # NOTE(review): success returns the whole message object, not just its
    # text, so the return type differs between success and error paths. Kept
    # as-is so existing callers that read the object keep working.
    return response.json()["choices"][0]["message"]
@functools.lru_cache(maxsize=None)
def _get_embeddings(model_name: str = "thenlper/gte-large"):
    """Return a shared SentenceTransformerEmbeddings instance for *model_name*.

    Constructing the embeddings object loads the sentence-transformers model,
    which is slow. Caching lets every query reuse one instance instead of
    reloading the model per call.
    """
    return SentenceTransformerEmbeddings(model_name=model_name)


def _build_prompt(context_list, question):
    """Assemble the full LLM prompt from retrieved chunk texts and the question."""
    # Chunks are joined into one context string with ". " as the separator.
    context = ". ".join(context_list)
    # str.format fills both placeholders in a single pass. With chained
    # .replace() calls, a chunk containing the literal "{question}" would
    # also have been overwritten by the query.
    user_message = qna_user_message_template.format(context=context, question=question)
    return qna_system_message + '\n' + user_message


def _answer_from_store(query, persist_directory, top_k):
    """Retrieve the *top_k* most similar chunks from a Chroma store and ask the LLM.

    Returns ``{"answer": <result of call_huggingface_mistral>}``.
    """
    db = Chroma(persist_directory=persist_directory, embedding_function=_get_embeddings())
    # k is passed per call so that top_k really controls how many chunks are
    # retrieved (a similarity-type retriever performs this same search).
    relevant_document_chunks = db.similarity_search(query, k=top_k)
    context_list = [d.page_content for d in relevant_document_chunks]
    print(f'context_list: {context_list}')
    prompt = _build_prompt(context_list, query)
    print(f'Prompt: {prompt}')
    answer = call_huggingface_mistral(prompt)
    return {"answer": answer}


def generate_answer(query, top_k=3):
    """Answer *query* from the text store at CHROMA_DIR.

    Args:
        query: User question; used both as the similarity-search query and
            as the ###Question in the prompt.
        top_k: Number of chunks to retrieve as context.

    Returns:
        ``{"answer": ...}``, where the value is whatever
        ``call_huggingface_mistral`` returned (a message object, or an error
        string on a non-200 response).
    """
    return _answer_from_store(query, CHROMA_DIR, top_k)


def generate_answer_am(query, top_k=3):
    """Answer *query* the same way as ``generate_answer``, but retrieve context
    from the store at CHROMA_IMG_DIR.

    Args:
        query: User question.
        top_k: Number of chunks to retrieve as context.

    Returns:
        ``{"answer": ...}``, as for ``generate_answer``.
    """
    return _answer_from_store(query, CHROMA_IMG_DIR, top_k)