import os

from langchain_community.chat_message_histories import ChatMessageHistory
from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.document_compressors import JinaRerank
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_groq import ChatGroq

from core.services.vector_db.qdrent.upload_document import answer_query_from_existing_collection

# JinaRerank reads JINA_API_KEY from the environment; mirror it from the
# deployment's JINA_API variable. Guard against an unset value, since
# os.environ entries must be strings, not None.
jina_api_key = os.getenv("JINA_API")
if jina_api_key:
    os.environ["JINA_API_KEY"] = jina_api_key


class AnswerQuery:
    """Answers queries against an existing vector-store collection using
    hybrid retrieval, Jina reranking, and a Groq-hosted LLM."""

    def __init__(self, prompt, vector_embedding, sparse_embedding, follow_up_prompt, json_parser):
        self.chat_history_store = {}
        self.compressor = JinaRerank(model="jina-reranker-v2-base-multilingual")
        self.vector_embed = vector_embedding
        self.sparse_embed = sparse_embedding
        self.prompt = prompt
        self.follow_up_prompt = follow_up_prompt
        self.json_parser = json_parser
        # Set by format_docs on each retrieval and read back after the chain runs.
        self.sources = []
        self.temp_context = ""

    def format_docs(self, docs: list) -> str:
        """Concatenate retrieved documents into one context string and record
        their deduplicated sources on the instance."""
        sources = []
        context = ""
        for doc in docs:
            context += f"{doc.page_content}\n\n\n"
            sources.append(doc.metadata["source"])
        if not context:
            context = "No context found"
        self.sources = list(set(sources))
        self.temp_context = context
        return context
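
    # Illustrative shape (hypothetical values): a reranked document such as
    #     Document(page_content="...", metadata={"source": "report.pdf"})
    # contributes its page_content to the returned context string and
    # "report.pdf" to self.sources.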

    def answer_query(self, query: str, vectorstore: str, llmModel: str = "llama-3.3-70b-versatile"):
        vector_store_name = vectorstore
        vector_store = answer_query_from_existing_collection(
            vector_embed=self.vector_embed,
            sparse_embed=self.sparse_embed,
            vectorstore=vectorstore,
        )
        # MMR retrieval for diversity, then Jina reranking for relevance.
        retriever = vector_store.as_retriever(search_type="mmr", search_kwargs={"k": 10, "fetch_k": 20})
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=self.compressor, base_retriever=retriever
        )
        brain_chain = (
            {
                "context": RunnableLambda(lambda x: x["question"])
                | compression_retriever
                | RunnableLambda(self.format_docs),
                "question": RunnableLambda(lambda x: x["question"]),
                "chatHistory": RunnableLambda(lambda x: x["chatHistory"]),
            }
            | self.prompt
            | ChatGroq(model=llmModel, temperature=0.75, max_tokens=512)
            | StrOutputParser()
        )
        message_chain = RunnableWithMessageHistory(
            brain_chain,
            self.get_session_history,
            input_messages_key="question",
            history_messages_key="chatHistory",
        )
        # Trim each session's stored history before the main chain runs.
        chain = RunnablePassthrough.assign(messages_trimmed=self.trim_messages) | message_chain
        follow_up_chain = (
            self.follow_up_prompt
            | ChatGroq(model="llama-3.3-70b-versatile", temperature=0)
            | self.json_parser
        )
        output = chain.invoke(
            {"question": query},
            {"configurable": {"session_id": vector_store_name}},
        )
        follow_up_questions = follow_up_chain.invoke({"context": self.temp_context})
        return output, follow_up_questions, self.sources

    async def answer_query_stream(self, query: str, vectorstore: str, llmModel: str = "llama-3.3-70b-versatile"):
        vector_store_name = vectorstore
        vector_store = answer_query_from_existing_collection(
            vector_embed=self.vector_embed,
            sparse_embed=self.sparse_embed,
            vectorstore=vectorstore,
        )
        retriever = vector_store.as_retriever(search_type="mmr", search_kwargs={"k": 10, "fetch_k": 20})
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=self.compressor,
            base_retriever=retriever,
        )
        brain_chain = (
            {
                "context": RunnableLambda(lambda x: x["question"])
                | compression_retriever
                | RunnableLambda(self.format_docs),
                "question": RunnableLambda(lambda x: x["question"]),
                "chatHistory": RunnableLambda(lambda x: x["chatHistory"]),
            }
            | self.prompt
            | ChatGroq(model=llmModel, temperature=0.75, max_tokens=512, streaming=True)
            | StrOutputParser()
        )
        message_chain = RunnableWithMessageHistory(
            brain_chain,
            self.get_session_history,
            input_messages_key="question",
            history_messages_key="chatHistory",
        )
        chain = RunnablePassthrough.assign(messages_trimmed=self.trim_messages) | message_chain
        # Stream the main answer chunk by chunk, then emit follow-up
        # questions and sources as separate typed events.
        async for chunk in chain.astream(
            {"question": query},
            {"configurable": {"session_id": vector_store_name}},
        ):
            yield {"type": "main_response", "content": chunk}
        follow_up_chain = (
            self.follow_up_prompt
            | ChatGroq(model="llama-3.3-70b-versatile", temperature=0)
            | self.json_parser
        )
        follow_up_questions = await follow_up_chain.ainvoke({"context": self.temp_context})
        yield {"type": "follow_up_questions", "content": follow_up_questions}
        yield {"type": "sources", "content": self.sources}
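
    # Example (hypothetical caller, not part of this module): consuming the
    # typed event stream from async code, given an initialised instance:
    #
    #     async for event in answer.answer_query_stream("What is covered?", "my_collection"):
    #         if event["type"] == "main_response":
    #             print(event["content"], end="", flush=True)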

    def trim_messages(self, chain_input):
        """Keep only the most recent message in every session history, so each
        prompt carries a single turn of prior context."""
        for store_name in self.chat_history_store:
            messages = self.chat_history_store[store_name].messages
            if len(messages) > 1:
                self.chat_history_store[store_name].clear()
                for message in messages[-1:]:
                    self.chat_history_store[store_name].add_message(message)
        return True

    def get_session_history(self, session_id: str) -> BaseChatMessageHistory:
        if session_id not in self.chat_history_store:
            self.chat_history_store[session_id] = ChatMessageHistory()
        return self.chat_history_store[session_id]
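

# Minimal usage sketch (assumptions: qa_prompt is a chat prompt template with
# "context", "question", and "chatHistory" variables; follow_up_prompt and
# json_parser yield structured follow-up questions; the embedding objects
# match what answer_query_from_existing_collection expects; none of these
# names are defined in this module):
#
#     answer = AnswerQuery(
#         prompt=qa_prompt,
#         vector_embedding=dense_embeddings,
#         sparse_embedding=sparse_embeddings,
#         follow_up_prompt=follow_up_prompt,
#         json_parser=json_parser,
#     )
#     output, follow_ups, sources = answer.answer_query(
#         "What does the report conclude?", vectorstore="my_collection"
#     )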