Spaces:
Sleeping
Sleeping
Update app/rag_pipeline.py
Browse files- app/rag_pipeline.py +30 -0
app/rag_pipeline.py
CHANGED
|
@@ -9,6 +9,7 @@ from langchain.embeddings import HuggingFaceEmbeddings
|
|
| 9 |
from langchain.embeddings import SentenceTransformerEmbeddings
|
| 10 |
|
| 11 |
CHROMA_DIR = "data/chroma_db"
|
|
|
|
| 12 |
|
| 13 |
load_dotenv()
|
| 14 |
HUGGINGFACE_API_KEY = os.getenv("HF_API_KEY") # set this in .env or directly
|
|
@@ -83,3 +84,32 @@ def generate_answer(query, top_k=3):
|
|
| 83 |
|
| 84 |
answer = call_huggingface_mistral(prompt)
|
| 85 |
return {"answer": answer}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
from langchain.embeddings import SentenceTransformerEmbeddings
|
| 10 |
|
| 11 |
CHROMA_DIR = "data/chroma_db"
|
| 12 |
+
CHROMA_IMG_DIR = "data/image_db"
|
| 13 |
|
| 14 |
load_dotenv()
|
| 15 |
HUGGINGFACE_API_KEY = os.getenv("HF_API_KEY") # set this in .env or directly
|
|
|
|
| 84 |
|
| 85 |
answer = call_huggingface_mistral(prompt)
|
| 86 |
return {"answer": answer}
|
| 87 |
def generate_answer_am(query, top_k=3):
    """Answer *query* with RAG over the Chroma image document store.

    Retrieves the ``top_k`` most similar chunks from the persisted
    image DB, joins them into a single context, fills the Q&A prompt
    templates, and queries the Mistral model on Hugging Face.

    Parameters
    ----------
    query : str
        The user question.
    top_k : int, optional
        Number of document chunks to retrieve (default 3).

    Returns
    -------
    dict
        ``{"answer": <model response text>}``.
    """
    # NOTE(review): the embedding model must match the one used when the
    # image DB was built, otherwise similarity search is meaningless —
    # confirm 'thenlper/gte-large' was used at index time.
    embeddings = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')

    db = Chroma(persist_directory=CHROMA_IMG_DIR,
                embedding_function=embeddings)
    # Honour the caller's top_k here. Previously k was hard-coded to 4,
    # which made the top_k parameter a no-op: retrievers take k from
    # search_kwargs, not from a kwarg on get_relevant_documents.
    retriever = db.as_retriever(
        search_type='similarity',
        search_kwargs={'k': top_k}
    )
    relevant_document_chunks = retriever.get_relevant_documents(query)
    context_list = [d.page_content for d in relevant_document_chunks]
    print(f'context_list: {context_list}')

    # Combine document chunks into a single context
    context = ". ".join(context_list)
    user_message = qna_user_message_template.replace('{context}', context)
    user_message = user_message.replace('{question}', query)

    prompt = qna_system_message + '\n' + user_message
    print(f'Prompt: {prompt}')

    answer = call_huggingface_mistral(prompt)
    return {"answer": answer}