import datetime
import os
import traceback
from typing import Any, Coroutine
from dotenv import load_dotenv
from langchain.chains import LLMChain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.retrieval import create_retrieval_chain
from langchain.retrievers import MultiQueryRetriever, MergerRetriever, ContextualCompressionRetriever, EnsembleRetriever
from langchain_cohere import CohereRerank
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate, BasePromptTemplate
from agent.Agent import Agent
from agent.agents import chat_openai_llm, deepinfra_chat
from conversation.conversation_store import ConversationStore
from prompt.prompt_store import PromptStore
from retrieval import retrieve_with_rerank
# Load environment variables (API keys, temperature, ...) before anything else.
load_dotenv()
# Module-level singletons shared by every pipeline function below.
conversation_store = ConversationStore()
prompt_store = PromptStore()
# Prompt texts are fetched once at import time. NOTE: "gramar_check_1" is the
# persisted store key (misspelling included) — do not "fix" the spelling here.
grammar_check_1 = prompt_store.get_by_name("gramar_check_1").text
rewrite_hyde_1 = prompt_store.get_by_name("rewrite_hyde_1").text
rewrite_hyde_2 = prompt_store.get_by_name("rewrite_hyde_2").text
rewrite_1 = prompt_store.get_by_name("rewrite_1").text
rewrite_2 = prompt_store.get_by_name("rewrite_2").text
rewrite_hyde = prompt_store.get_by_name("rewrite_hyde").text
def replace_nl(input: str) -> str:
    """Flatten *input* to a single line by replacing newline variants with spaces.

    Fix: the original replacement literals were split across physical lines
    (a single-quoted string cannot span lines — SyntaxError). The evident
    intent was to normalize every newline variant to a single space.
    """
    # Handle the two-character Windows sequence first so "\r\n" becomes one
    # space rather than two.
    return input.replace('\r\n', ' ').replace('\n', ' ').replace('\r', ' ')
def rewrite(agent: Agent, q: str, prompt: str) -> list[str]:
    """Ask the agent's LLM to rewrite question *q* using *prompt*.

    Returns the response split into lines, keeping only lines that are
    non-empty and do not contain a "##" heading marker.
    """
    template = PromptTemplate(
        input_variables=["question"],
        template=prompt
    )
    chain = LLMChain(
        llm=agent.llm,
        prompt=template,
        verbose=False
    )
    response_lines = chain.invoke(input={"question": q})["text"].splitlines()
    kept: list[str] = []
    for line in response_lines:
        if "##" in line:
            continue
        if len(str(line).strip()) == 0:
            continue
        kept.append(line)
    return kept
def rag_with_rerank_check_rewrite_hyde(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
                                       check_prompt: str,
                                       rewrite_prompt: str):
    """RAG pipeline: rewrite the question into sub-queries, HyDE-retrieve
    context for them, generate an answer, then run a self-check.

    Returns (answer, check_result, context_docs); a fallback message with
    empty companions when rewriting or retrieval yields nothing.
    """
    subqueries = rewrite(agent, q, rewrite_prompt)
    if not subqueries:
        return "Neviem, nemám podklady!", "", ""
    docs = retrieve_subqueries_hyde(agent, retrieve_document_count, subqueries)
    if not docs:
        return "Neviem, nemám kontext!", "", ""
    answer = answer_pipeline(agent, docs, prompt, q)["text"]
    verdict = check_pipeline(answer, check_prompt, docs, q)
    return answer, verdict, docs
def rag_with_rerank_check_multi_query_retriever(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
                                                check_prompt: str):
    """RAG pipeline using the merged HyDE MultiQueryRetriever, then answer + self-check.

    Returns (answer, check_result, context_docs); a fallback message with
    empty companions when retrieval yields nothing.
    """
    # Fix: the original passed `kwargs={"k": retrieve_document_count}` to
    # .invoke(); BaseRetriever.invoke has no such parameter, so it was
    # silently swallowed. The retriever built by hyde_retrieval() already
    # bounds the result count via its own search_kwargs / reranker top_n.
    context_doc = hyde_retrieval(agent, retrieve_document_count).invoke(input=q)
    if len(context_doc) == 0:
        return "Neviem, nemám kontext!", "", ""
    result = answer_pipeline(agent, context_doc, prompt, q)
    answer = result["text"]
    check_result = check_pipeline(answer, check_prompt, context_doc, q)
    return answer, check_result, context_doc
async def rag_chain(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
                    check_prompt: str):
    """Async retrieval chain built on hyde_2_retrieval, followed by a self-check.

    Returns (answer, check_result, context_docs).
    """
    answer_prompt = PromptTemplate(
        input_variables=["context", "question", "actual_date"],
        template=prompt
    )
    # Stuff every retrieved document verbatim into the prompt context.
    combine_chain = create_stuff_documents_chain(
        llm=agent.llm,
        prompt=answer_prompt,
        document_prompt=PromptTemplate(input_variables=[], template="page_content")
    )
    chain = create_retrieval_chain(
        retriever=hyde_2_retrieval(agent, retrieve_document_count),
        combine_docs_chain=combine_chain
    )
    payload = {
        "question": q,
        "input": q,
        "actual_date": datetime.date.today().isoformat()
    }
    result = await chain.ainvoke(input=payload)
    print(result)
    check_result = check_pipeline(result["answer"], check_prompt, result["context"], q)
    print(check_result)
    return result["answer"], check_result, result["context"]
def vanilla_rag_chain(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
                      check_prompt: str):
    """Plain similarity retrieval + Cohere rerank, answer generation, self-check.

    Returns (answer, check_result, context_docs).
    """
    # Over-fetch (10x, capped at 500) so the reranker has candidates to choose from.
    base = agent.embedding.get_vector_store().as_retriever(
        search_type="similarity",
        search_kwargs={"k": min(retrieve_document_count * 10, 500)},
    )
    reranker = CohereRerank(
        model="rerank-multilingual-v3.0",
        top_n=retrieve_document_count
    )
    retriever = ContextualCompressionRetriever(
        base_compressor=reranker,
        base_retriever=base
    )
    combine_chain = create_stuff_documents_chain(
        llm=agent.llm,
        prompt=PromptTemplate(
            input_variables=["context", "question", "actual_date"],
            template=prompt
        ),
        document_prompt=PromptTemplate(input_variables=[], template="page_content")
    )
    result = create_retrieval_chain(
        retriever=retriever,
        combine_docs_chain=combine_chain
    ).invoke(
        input={
            "question": q,
            "input": q,
            "actual_date": datetime.date.today().isoformat()
        }
    )
    print(result)
    check_result = check_pipeline(result["answer"], check_prompt, result["context"], q)
    print(check_result)
    return result["answer"], check_result, result["context"]
def hyde_retrieval(agent, retrieve_document_count):
    """Build a retriever: two MultiQueryRetrievers (one per HyDE prompt
    variant), merged, then Cohere-reranked down to *retrieve_document_count*.
    """
    def _multi_query(template: str) -> MultiQueryRetriever:
        # One similarity retriever per HyDE prompt variant; the two calls in
        # the original were identical except for the template.
        return MultiQueryRetriever.from_llm(
            llm=agent.llm,
            retriever=agent.embedding.get_vector_store().as_retriever(
                search_type="similarity",
                search_kwargs={"k": retrieve_document_count}
            ),
            prompt=PromptTemplate(
                input_variables=["question"],
                template=template
            )
        )

    merge_retriever = MergerRetriever(
        retrievers=[_multi_query(rewrite_hyde_1), _multi_query(rewrite_hyde_2)],
    )
    compressor = CohereRerank(
        model="rerank-multilingual-v3.0",
        top_n=retrieve_document_count
    )
    # Fix: the original also passed search_kwargs={"k": ...} here, but
    # ContextualCompressionRetriever declares no such field — it was ignored;
    # the reranker's top_n already bounds the output.
    return ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=merge_retriever,
    )
def hyde_2_retrieval(agent, retrieve_document_count):
    """Build an equal-weight ensemble of three Cohere-compressed
    MultiQueryRetrievers (rewrite_1, rewrite_2, rewrite_hyde prompts).
    """
    compressor = CohereRerank(
        model="rerank-multilingual-v3.0",
        # Fix: the original used `retrieve_document_count / 2`, which is a
        # float — top_n must be an int. Floor-divide and keep at least 1.
        top_n=max(1, retrieve_document_count // 2)
    )

    def _compressed(template: str) -> ContextualCompressionRetriever:
        # Over-fetch (10x, capped at 300) so the shared reranker has
        # candidates to choose from; the three blocks in the original were
        # identical except for the template.
        base = MultiQueryRetriever.from_llm(
            llm=agent.llm,
            retriever=agent.embedding.get_vector_store().as_retriever(
                search_type="similarity",
                search_kwargs={"k": min(retrieve_document_count * 10, 300)}
            ),
            prompt=PromptTemplate(
                input_variables=["question"],
                template=template
            )
        )
        return ContextualCompressionRetriever(
            base_compressor=compressor,
            base_retriever=base
        )

    return EnsembleRetriever(
        retrievers=[_compressed(rewrite_1), _compressed(rewrite_2), _compressed(rewrite_hyde)],
        weights=[1.0, 1.0, 1.0]
    )
def retrieve_subqueries(agent, retrieve_document_count, rewritten_list) -> list[Document]:
    """Retrieve + rerank documents for each rewritten sub-query, merge all
    results, sort by descending relevance, de-duplicate by page content, and
    return at most *retrieve_document_count* documents.
    """
    contexts: list[Document] = []
    for rewritten in rewritten_list:
        contexts.extend(retrieve_with_rerank(agent.embedding, rewritten, retrieve_document_count))
    contexts.sort(key=lambda x: -x.metadata["relevance_score"])
    # Fix: the original de-duplicated with a nested scan over the output
    # list (O(n^2)); a seen-set gives the same result in O(n).
    seen: set[str] = set()
    deduplicated: list[Document] = []
    for doc in contexts:
        if doc.page_content not in seen:
            seen.add(doc.page_content)
            deduplicated.append(doc)
    return deduplicated[:retrieve_document_count]
def retrieve_subqueries_hyde(agent, retrieve_document_count, rewritten_list) -> list[Document]:
    """HyDE variant of retrieve_subqueries: for each sub-query, first ask the
    LLM for a hypothetical answer and retrieve with query + answer combined,
    then merge, sort by relevance, de-duplicate, and truncate.
    """
    contexts: list[Document] = []
    for rewritten in rewritten_list:
        # HyDE: embed the hypothetical answer alongside the query to improve recall.
        answer = agent.llm.invoke(rewritten).content
        contexts.extend(retrieve_with_rerank(agent.embedding, rewritten + "\n" + answer, retrieve_document_count))
    contexts.sort(key=lambda x: -x.metadata["relevance_score"])
    # Fix: the original de-duplicated with a nested scan over the output
    # list (O(n^2)); a seen-set gives the same result in O(n).
    seen: set[str] = set()
    deduplicated: list[Document] = []
    for doc in contexts:
        if doc.page_content not in seen:
            seen.add(doc.page_content)
            deduplicated.append(doc)
    return deduplicated[:retrieve_document_count]
def answer_pipeline(agent, context_doc, prompt, q):
    """Run the answer LLM chain: fill *prompt* with the retrieved context,
    the question, and today's date; return the raw chain result dict
    (answer text under "text").
    """
    prompt_template = PromptTemplate(
        # Fix: "actual_date" was missing from input_variables even though it
        # is always supplied at invoke time below — declared here for
        # consistency with the other prompts in this module (see rag_chain).
        input_variables=["context", "question", "actual_date"],
        template=prompt
    )
    llm_chain = LLMChain(
        llm=agent.llm,
        prompt=prompt_template,
        verbose=False
    )
    result: dict[str, Any] = llm_chain.invoke(
        input={
            "question": q,
            "context": context_doc,
            "actual_date": datetime.date.today().isoformat()
        }
    )
    return result
def check_pipeline(answer, check_prompt, context_doc, q):
    """Verify *answer* against *context_doc* with a separate checker LLM.

    On any failure the formatted traceback string is returned instead of a
    verdict (deliberate best-effort: the check must never break the pipeline).
    """
    checker_template = PromptTemplate(
        input_variables=["context", "question", "answer"],
        template=check_prompt
    )
    checker_chain = LLMChain(
        llm=deepinfra_chat("meta-llama/Meta-Llama-3-70B-Instruct", "0.4"),
        prompt=checker_template,
        verbose=False
    )
    # Cap the question length to keep the checker prompt bounded.
    payload = {
        "question": q[:2000],
        "context": context_doc,
        "answer": answer
    }
    try:
        return checker_chain.invoke(input=payload)["text"]
    except Exception:
        return traceback.format_exc()
def rag_with_rerank(agent: Agent, q: str, retrieve_document_count: int, prompt: str = None, check_prompt: str = None):
    """Plain RAG: retrieve + rerank, generate an answer, optionally self-check.

    Returns (answer, check_result, context_docs); on any error in the
    answer/check stage returns ("", traceback_text, "").
    """
    docs: list[Document] = retrieve_with_rerank(agent.embedding, q, retrieve_document_count)
    try:
        answer = answer_pipeline(agent, docs, prompt, q)["text"]
        # Self-check is optional — skipped entirely when no check prompt is given.
        verdict = "" if check_prompt is None else check_pipeline(answer, check_prompt, docs, q)
        return answer, verdict, docs
    except Exception:
        return "", traceback.format_exc(), ""
def save_conversation(answer: str, check_result: str, context_doc: list[Document], gramatika: str, question: str,
                      prompt_id: str, check_prompt_id: str, grammar_prompt_id: str):
    """Persist a finished Q/A exchange (with sources and run parameters) to
    the conversation store. Empty answers are not saved.
    """
    if not answer:
        return
    run_params = {
        "prompt_id": prompt_id,
        "check_prompt_id": check_prompt_id,
        "grammar_prompt_id": grammar_prompt_id,
        "check_result": check_result,
        "grammar_result": gramatika,
        "temperature": os.environ["temperature"],
    }
    conversation_store.save_content(
        q=question,
        a=answer,
        sources=[doc.page_content for doc in context_doc],
        params=run_params
    )
def check_slovak_agent(text: str) -> str:
    """Run the Slovak grammar-check prompt over *text* and return the
    checker LLM's response text.
    """
    grammar_chain = LLMChain(
        llm=chat_openai_llm(),
        prompt=PromptTemplate(
            input_variables=["text"],
            template=grammar_check_1
        ),
        verbose=False
    )
    response: dict[str, Any] = grammar_chain.invoke(input={"text": text})
    return response["text"]