# Author: logan-codes
# Commit ea1e7dc: "changed the data dir to hf compatible"
from langchain_chroma.vectorstores import Chroma
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from dotenv import load_dotenv
import os
class Retriever:
    """Multi-query retriever over a persisted Chroma collection.

    A user question is rewritten by a Gemini chat model into alternative
    phrasings. Each phrasing is run as a similarity search against the
    ``documents_collection`` Chroma store. The de-duplicated union of the
    hits is then formatted into a single context string.
    """

    # Prompt used to fan a single question out into alternative phrasings.
    # The continuation lines are flush-left on purpose: they are part of the
    # string literal sent to the model.
    _QUERY_REWRITE_TEMPLATE = """You are an AI language model assistant. Your task is to generate three
different versions of the given user question to retrieve relevant documents from a vector
database. By generating multiple perspectives on the user question, your goal is to help
the user overcome some of the limitations of the distance-based similarity search.
Provide these alternative questions separated by newlines. Original question: {question}"""

    def __init__(self, embedding_model: "HuggingFaceEmbeddings | None" = None):
        """Set up embeddings and the Chroma vector store.

        Args:
            embedding_model: Embeddings used by the vector store. When omitted,
                ``sentence-transformers/all-MiniLM-L6-v2`` is loaded through
                ``HuggingFaceEmbeddings``.

        The store is persisted under ``$HF_HOME/data/chroma_db``. If
        ``HF_HOME`` is not set, ``/home/user/app/data/chroma_db`` is used.

        Raises:
            ValueError: if ``GOOGLE_API_KEY`` is unset or empty after loading
                ``.env``.
        """
        # Load .env first so variables defined there are visible to every
        # lookup below. Check the API key before the embedding model and
        # vector store are built, so a missing key fails fast.
        load_dotenv()
        self.GEMINI_API_KEY = os.getenv("GOOGLE_API_KEY")
        if not self.GEMINI_API_KEY:
            raise ValueError("GOOGLE_API_KEY not found in environment variables.")

        self.embed = (
            embedding_model
            if embedding_model is not None
            else HuggingFaceEmbeddings(
                model_name="sentence-transformers/all-MiniLM-L6-v2"
            )
        )
        base_dir = os.getenv("HF_HOME", "/home/user/app")
        self.DATA_DIR = os.path.join(base_dir, "data")
        self.vector_store = Chroma(
            collection_name="documents_collection",
            embedding_function=self.embed,
            persist_directory=os.path.join(self.DATA_DIR, "chroma_db"),
        )

    def _retrieve_chunks(self, query: str, k: int = 3):
        """Return the ``k`` chunks most similar to *query* from the vector store."""
        return self.vector_store.similarity_search(query, k=k)

    @staticmethod
    def _parse_queries(text: str) -> list[str]:
        """Split raw LLM output into one stripped query per non-blank line.

        Blank lines are dropped and surrounding whitespace (including ``\\r``)
        is removed, so empty strings are never sent to the similarity search.
        """
        return [line.strip() for line in text.splitlines() if line.strip()]

    def _query_transformer(self, query: str) -> list[str]:
        """Ask the Gemini model for alternative phrasings of *query*.

        Returns the parsed list of rewritten questions. The list may be empty
        if the model produced no usable lines.
        """
        prompt = ChatPromptTemplate.from_template(self._QUERY_REWRITE_TEMPLATE)
        llm = ChatGoogleGenerativeAI(
            model="gemini-2.5-flash-lite",
            temperature=0.7,
            # Use the key validated in __init__ rather than relying on the
            # client re-reading the environment.
            google_api_key=self.GEMINI_API_KEY,
        )
        chain = prompt | llm | StrOutputParser() | self._parse_queries
        return chain.invoke({"question": query})

    @staticmethod
    def _format_context(docs) -> str:
        """Render documents as numbered ``Context N:`` sections.

        Each section is the document's ``page_content`` followed by a
        50-character dashed separator line.
        """
        separator = "-" * 50
        return "".join(
            f"Context {idx}:\n{doc.page_content}\n{separator}\n"
            for idx, doc in enumerate(docs, start=1)
        )

    def retrieve_context(self, query: str) -> str:
        """Retrieve and format context for *query* using multi-query search.

        The rewritten queries are each searched. Hits are de-duplicated while
        keeping first-seen order. If the rewrite step yields no queries, the
        original *query* is searched instead. Returns an empty string when
        nothing is retrieved.
        """
        queries = self._query_transformer(query) or [query]
        unique_chunks = []
        for rewritten in queries:
            for chunk in self._retrieve_chunks(rewritten):
                # Compare by equality (not hashing): the retrieved objects are
                # not assumed to be hashable.
                if chunk not in unique_chunks:
                    unique_chunks.append(chunk)
        return self._format_context(unique_chunks)
if __name__ == "__main__":
    import sys

    # Ad-hoc smoke run: pass a question on the command line, or fall back to
    # the built-in sample question. Retriever() raises ValueError unless
    # GOOGLE_API_KEY is available (directly or via .env).
    sample_query = " ".join(sys.argv[1:]) or "how does the ocr work in docling?"
    retriever_instance = Retriever()
    context = retriever_instance.retrieve_context(sample_query)
    print(context)