import datetime
import os
import traceback
from typing import Any, Coroutine

from dotenv import load_dotenv
from langchain.chains import LLMChain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.retrieval import create_retrieval_chain
from langchain.retrievers import (
    ContextualCompressionRetriever,
    EnsembleRetriever,
    MergerRetriever,
    MultiQueryRetriever,
)
from langchain_cohere import CohereRerank
from langchain_core.documents import Document
from langchain_core.prompts import BasePromptTemplate, PromptTemplate

from agent.Agent import Agent
from agent.agents import chat_openai_llm, deepinfra_chat
from conversation.conversation_store import ConversationStore
from prompt.prompt_store import PromptStore
from retrieval import retrieve_with_rerank

load_dotenv()

conversation_store = ConversationStore()
prompt_store = PromptStore()

# Prompt templates are fetched once at import time. The "gramar_check_1"
# spelling matches the record name stored in the prompt store — do not
# "correct" it here without migrating the stored record.
grammar_check_1 = prompt_store.get_by_name("gramar_check_1").text
rewrite_hyde_1 = prompt_store.get_by_name("rewrite_hyde_1").text
rewrite_hyde_2 = prompt_store.get_by_name("rewrite_hyde_2").text
rewrite_1 = prompt_store.get_by_name("rewrite_1").text
rewrite_2 = prompt_store.get_by_name("rewrite_2").text
rewrite_hyde = prompt_store.get_by_name("rewrite_hyde").text


def replace_nl(input: str) -> str:
    """Normalize the line endings of *input* to a single ``\\n``.

    NOTE(review): the replacement literals were garbled in the original
    source (raw line breaks inside the quotes). This reconstruction assumes
    CRLF/CR -> LF normalization was intended — confirm against callers.
    """
    return input.replace('\r\n', '\n').replace('\r', '\n')


def rewrite(agent: Agent, q: str, prompt: str) -> list[str]:
    """Ask the agent's LLM to rewrite *q* into sub-questions, one per line.

    Returns the non-empty output lines, dropping markdown-heading lines
    (anything containing ``##``). May return an empty list if the model
    produced nothing usable.
    """
    prompt_template = PromptTemplate(
        input_variables=["question"],
        template=prompt,
    )
    llm_chain = LLMChain(llm=agent.llm, prompt=prompt_template, verbose=False)
    questions = llm_chain.invoke(input={"question": q})["text"].splitlines()
    return [x for x in questions if "##" not in x and len(x.strip()) > 0]


def rag_with_rerank_check_rewrite_hyde(agent: Agent, q: str, retrieve_document_count: int,
                                       prompt: str, check_prompt: str, rewrite_prompt: str):
    """RAG pipeline: rewrite the question, HyDE-retrieve per sub-question,
    answer, then run the self-check pipeline on the answer.

    Returns ``(answer, check_result, context_doc)``; the Slovak fallback
    strings are returned when rewriting or retrieval yields nothing.
    """
    rewritten_list: list[str] = rewrite(agent, q, rewrite_prompt)
    if len(rewritten_list) == 0:
        return "Neviem, nemám podklady!", "", ""
    context_doc = retrieve_subqueries_hyde(agent, retrieve_document_count, rewritten_list)
    if len(context_doc) == 0:
        return "Neviem, nemám kontext!", "", ""
    result = answer_pipeline(agent, context_doc, prompt, q)
    answer = result["text"]
    check_result = check_pipeline(answer, check_prompt, context_doc, q)
    return answer, check_result, context_doc


def rag_with_rerank_check_multi_query_retriever(agent: Agent, q: str, retrieve_document_count: int,
                                                prompt: str, check_prompt: str):
    """RAG pipeline using the multi-query HyDE retriever, with answer check.

    Returns ``(answer, check_result, context_doc)``.
    """
    # NOTE(review): ``kwargs=`` is not a recognized invoke() parameter name;
    # ``k`` is already configured on the underlying retrievers, so this extra
    # argument is most likely a no-op — kept for byte-compatible behavior.
    context_doc = hyde_retrieval(agent, retrieve_document_count).invoke(
        input=q,
        kwargs={"k": retrieve_document_count}
    )
    if len(context_doc) == 0:
        return "Neviem, nemám kontext!", "", ""
    result = answer_pipeline(agent, context_doc, prompt, q)
    answer = result["text"]
    check_result = check_pipeline(answer, check_prompt, context_doc, q)
    return answer, check_result, context_doc


async def rag_chain(agent: Agent, q: str, retrieve_document_count: int,
                    prompt: str, check_prompt: str):
    """Async retrieval chain over the ensemble HyDE retriever.

    Builds a stuff-documents answer chain and runs the check pipeline over
    the produced answer. Returns ``(answer, check_result, context)``.
    """
    result = await create_retrieval_chain(
        retriever=hyde_2_retrieval(agent, retrieve_document_count),
        combine_docs_chain=create_stuff_documents_chain(
            llm=agent.llm,
            prompt=PromptTemplate(
                input_variables=["context", "question", "actual_date"],
                template=prompt
            ),
            document_prompt=PromptTemplate(input_variables=[], template="page_content")
        )
    ).ainvoke(
        input={
            "question": q,
            "input": q,
            "actual_date": datetime.date.today().isoformat()
        }
    )
    print(result)
    check_result = check_pipeline(result["answer"], check_prompt, result["context"], q)
    print(check_result)
    return result["answer"], check_result, result["context"]


def vanilla_rag_chain(agent: Agent, q: str, retrieve_document_count: int,
                      prompt: str, check_prompt: str):
    """Plain similarity retrieval + Cohere rerank + stuff-documents answer.

    Over-fetches up to 10x the requested count (capped at 500) before
    reranking down to ``retrieve_document_count``.
    Returns ``(answer, check_result, context)``.
    """
    retriever = ContextualCompressionRetriever(
        base_compressor=CohereRerank(
            model="rerank-multilingual-v3.0",
            top_n=retrieve_document_count
        ),
        base_retriever=agent.embedding.get_vector_store().as_retriever(
            search_type="similarity",
            search_kwargs={"k": min(retrieve_document_count * 10, 500)},
        )
    )
    result = create_retrieval_chain(
        retriever=retriever,
        combine_docs_chain=create_stuff_documents_chain(
            llm=agent.llm,
            prompt=PromptTemplate(
                input_variables=["context", "question", "actual_date"],
                template=prompt
            ),
            document_prompt=PromptTemplate(input_variables=[], template="page_content")
        )
    ).invoke(
        input={
            "question": q,
            "input": q,
            "actual_date": datetime.date.today().isoformat()
        }
    )
    print(result)
    check_result = check_pipeline(result["answer"], check_prompt, result["context"], q)
    print(check_result)
    return result["answer"], check_result, result["context"]


def hyde_retrieval(agent, retrieve_document_count):
    """Two multi-query HyDE retrievers merged, then Cohere-reranked.

    Returns a ``ContextualCompressionRetriever`` producing at most
    ``retrieve_document_count`` documents.
    """
    retriever_1 = MultiQueryRetriever.from_llm(
        llm=agent.llm,
        retriever=agent.embedding.get_vector_store().as_retriever(
            search_type="similarity",
            search_kwargs={"k": retrieve_document_count}
        ),
        prompt=PromptTemplate(
            input_variables=["question"],
            template=rewrite_hyde_1
        )
    )
    retriever_2 = MultiQueryRetriever.from_llm(
        llm=agent.llm,
        retriever=agent.embedding.get_vector_store().as_retriever(
            search_type="similarity",
            search_kwargs={"k": retrieve_document_count}
        ),
        prompt=PromptTemplate(
            input_variables=["question"],
            template=rewrite_hyde_2
        )
    )
    merge_retriever = MergerRetriever(
        retrievers=[retriever_1, retriever_2],
    )
    compressor = CohereRerank(
        model="rerank-multilingual-v3.0",
        top_n=retrieve_document_count
    )
    # NOTE(review): ``search_kwargs`` is not a declared field of
    # ContextualCompressionRetriever; kept as-is for behavior parity, but it
    # is likely ignored (top_n on the compressor already limits the output).
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=merge_retriever,
        search_kwargs={"k": retrieve_document_count},
    )
    return compression_retriever


def hyde_2_retrieval(agent, retrieve_document_count):
    """Ensemble of three compressed multi-query retrievers (equal weights).

    Each branch over-fetches (10x, capped at 300) and is reranked down to
    half of ``retrieve_document_count``.
    """
    # FIX: top_n must be an integer — ``retrieve_document_count / 2`` yielded
    # a float. Floor-divide and clamp to at least 1 so small counts still work.
    compressor = CohereRerank(
        model="rerank-multilingual-v3.0",
        top_n=max(1, retrieve_document_count // 2)
    )
    retriever_1 = MultiQueryRetriever.from_llm(
        llm=agent.llm,
        retriever=agent.embedding.get_vector_store().as_retriever(
            search_type="similarity",
            search_kwargs={"k": min(retrieve_document_count * 10, 300)}
        ),
        prompt=PromptTemplate(
            input_variables=["question"],
            template=rewrite_1
        )
    )
    compression_retriever_1 = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=retriever_1
    )
    retriever_2 = MultiQueryRetriever.from_llm(
        llm=agent.llm,
        retriever=agent.embedding.get_vector_store().as_retriever(
            search_type="similarity",
            search_kwargs={"k": min(retrieve_document_count * 10, 300)}
        ),
        prompt=PromptTemplate(
            input_variables=["question"],
            template=rewrite_2
        )
    )
    compression_retriever_2 = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=retriever_2
    )
    retriever_3 = MultiQueryRetriever.from_llm(
        llm=agent.llm,
        retriever=agent.embedding.get_vector_store().as_retriever(
            search_type="similarity",
            search_kwargs={"k": min(retrieve_document_count * 10, 300)}
        ),
        prompt=PromptTemplate(
            input_variables=["question"],
            template=rewrite_hyde
        )
    )
    compression_retriever_3 = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=retriever_3
    )
    merge_retriever = EnsembleRetriever(
        retrievers=[compression_retriever_1, compression_retriever_2, compression_retriever_3],
        weights=[1.0, 1.0, 1.0]
    )
    return merge_retriever


def _dedup_by_content(contexts: list[Document], limit: int) -> list[Document]:
    """Drop documents whose page_content was already seen, keeping the first
    (highest-scored, since input is pre-sorted) occurrence; truncate to *limit*.

    Set-based replacement for the original O(n^2) nested scan — same order,
    same result.
    """
    seen: set[str] = set()
    deduplicated: list[Document] = []
    for doc in contexts:
        if doc.page_content not in seen:
            seen.add(doc.page_content)
            deduplicated.append(doc)
    return deduplicated[:limit]


def retrieve_subqueries(agent, retrieve_document_count, rewritten_list) -> list[Document]:
    """Retrieve+rerank for each rewritten sub-question, then merge the hits
    by descending relevance_score and deduplicate by page content."""
    contexts: list[Document] = []
    for rewritten in rewritten_list:
        contexts.extend(retrieve_with_rerank(agent.embedding, rewritten, retrieve_document_count))
    contexts.sort(key=lambda x: -x.metadata["relevance_score"])
    return _dedup_by_content(contexts, retrieve_document_count)


def retrieve_subqueries_hyde(agent, retrieve_document_count, rewritten_list) -> list[Document]:
    """HyDE variant of :func:`retrieve_subqueries`: for each sub-question,
    first generate a hypothetical answer with the LLM and retrieve using
    question + answer as the query."""
    contexts: list[Document] = []
    for rewritten in rewritten_list:
        answer = agent.llm.invoke(rewritten).content
        contexts.extend(
            retrieve_with_rerank(agent.embedding, rewritten + "\n" + answer, retrieve_document_count)
        )
    contexts.sort(key=lambda x: -x.metadata["relevance_score"])
    return _dedup_by_content(contexts, retrieve_document_count)


def answer_pipeline(agent, context_doc, prompt, q):
    """Run the answering LLM chain over the retrieved context.

    Returns the raw chain result dict (the answer text is under ``"text"``).
    """
    prompt_template = PromptTemplate(
        input_variables=["context", "question"],
        template=prompt
    )
    llm_chain = LLMChain(llm=agent.llm, prompt=prompt_template, verbose=False)
    result: dict[str, Any] = llm_chain.invoke(
        input={
            "question": q,
            "context": context_doc,
            "actual_date": datetime.date.today().isoformat()
        }
    )
    return result


def check_pipeline(answer, check_prompt, context_doc, q):
    """Verify *answer* against *context_doc* with a dedicated checker LLM.

    Best-effort: on any failure the formatted traceback is returned as the
    check result instead of raising. The question is truncated to 2000 chars
    to bound the prompt size.
    """
    prompt_template = PromptTemplate(
        input_variables=["context", "question", "answer"],
        template=check_prompt
    )
    llm_chain = LLMChain(
        llm=deepinfra_chat("meta-llama/Meta-Llama-3-70B-Instruct", "0.4"),
        prompt=prompt_template,
        verbose=False
    )
    try:
        check_result = llm_chain.invoke(
            input={
                "question": q[:2000],
                "context": context_doc,
                "answer": answer
            }
        )["text"]
    except Exception:
        check_result = traceback.format_exc()
    return check_result


def rag_with_rerank(agent: Agent, q: str, retrieve_document_count: int,
                    prompt: str = None, check_prompt: str = None):
    """Single-query retrieve+rerank RAG; check step only when a check prompt
    is given.

    Returns ``(answer, check_result, context_doc)``; on failure returns
    ``("", traceback_text, "")``.
    """
    context_doc: list[Document] = retrieve_with_rerank(agent.embedding, q, retrieve_document_count)
    try:
        result: dict[str, Any] = answer_pipeline(agent, context_doc, prompt, q)
        answer = result["text"]
        check_result = ""
        if check_prompt is not None:
            check_result = check_pipeline(answer, check_prompt, context_doc, q)
        return answer, check_result, context_doc
    except Exception:
        return "", traceback.format_exc(), ""


def save_conversation(answer: str, check_result: str, context_doc: list[Document], gramatika: str,
                      question: str, prompt_id: str, check_prompt_id: str, grammar_prompt_id: str):
    """Persist a Q/A exchange (with sources and run parameters) to the
    conversation store. Empty answers are not saved."""
    if len(answer) > 0:
        conversation_store.save_content(
            q=question,
            a=answer,
            sources=list(map(lambda doc: doc.page_content, context_doc)),
            params={
                "prompt_id": prompt_id,
                "check_prompt_id": check_prompt_id,
                "grammar_prompt_id": grammar_prompt_id,
                "check_result": check_result,
                "grammar_result": gramatika,
                "temperature": os.environ["temperature"],
            }
        )


def check_slovak_agent(text: str) -> str:
    """Run the grammar-check prompt over *text* with the OpenAI chat LLM and
    return the checker's raw text output."""
    prompt_template = PromptTemplate(
        input_variables=["text"],
        template=grammar_check_1
    )
    llm_chain = LLMChain(llm=chat_openai_llm(), prompt=prompt_template, verbose=False)
    result: dict[str, Any] = llm_chain.invoke(input={"text": text})
    return result["text"]