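"""RAG pipelines built on LangChain: question rewriting, multi-query and HyDE
retrieval, Cohere reranking, answer generation, plus LLM-based answer and
Slovak grammar checks."""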
import datetime
import os
import traceback
from typing import Any

from dotenv import load_dotenv
from langchain.chains import LLMChain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.retrieval import create_retrieval_chain
from langchain.retrievers import (
    ContextualCompressionRetriever,
    EnsembleRetriever,
    MergerRetriever,
    MultiQueryRetriever,
)
from langchain_cohere import CohereRerank
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate

from agent.Agent import Agent
from agent.agents import chat_openai_llm, deepinfra_chat
from conversation.conversation_store import ConversationStore
from prompt.prompt_store import PromptStore
from retrieval import retrieve_with_rerank

load_dotenv()

conversation_store = ConversationStore()
prompt_store = PromptStore()

# Prompt texts loaded by name from the prompt store ("gramar_check_1" is the
# stored key, so the misspelling must stay as-is).
grammar_check_1 = prompt_store.get_by_name("gramar_check_1").text
rewrite_hyde_1 = prompt_store.get_by_name("rewrite_hyde_1").text
rewrite_hyde_2 = prompt_store.get_by_name("rewrite_hyde_2").text
rewrite_1 = prompt_store.get_by_name("rewrite_1").text
rewrite_2 = prompt_store.get_by_name("rewrite_2").text
rewrite_hyde = prompt_store.get_by_name("rewrite_hyde").text

def replace_nl(text: str) -> str:
    # Renamed the parameter from `input`, which shadowed the builtin.
    return text.replace('\r\n', '<br>').replace('\n', '<br>').replace('\r', '<br>')
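
# Rewrites the user question into several retrieval sub-queries, one per line
# of LLM output; markdown headings and blank lines are filtered out.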
def rewrite(agent: Agent, q: str, prompt: str) -> list[str]:
    prompt_template = PromptTemplate(
        input_variables=["question"],
        template=prompt
    )
    llm_chain = LLMChain(
        llm=agent.llm,
        prompt=prompt_template,
        verbose=False
    )
    questions = llm_chain.invoke(
        input={"question": q}
    )["text"].splitlines()
    return [x for x in questions if "##" not in x and len(x.strip()) > 0]
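
# Full pipeline: rewrite the question into sub-queries, retrieve context via
# HyDE + rerank, generate an answer, then run the checker over it. Returns
# (answer, check_result, context_docs); Slovak fallback strings are returned
# when rewriting or retrieval yields nothing.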
def rag_with_rerank_check_rewrite_hyde(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
                                       check_prompt: str,
                                       rewrite_prompt: str):
    rewritten_list: list[str] = rewrite(agent, q, rewrite_prompt)
    if len(rewritten_list) == 0:
        return "Neviem, nemám podklady!", "", ""  # "I don't know, I have no sources!"
    context_doc = retrieve_subqueries_hyde(agent, retrieve_document_count, rewritten_list)
    if len(context_doc) == 0:
        return "Neviem, nemám kontext!", "", ""  # "I don't know, I have no context!"
    result = answer_pipeline(agent, context_doc, prompt, q)
    answer = result["text"]
    check_result = check_pipeline(answer, check_prompt, context_doc, q)
    return answer, check_result, context_doc
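
# Variant that delegates retrieval to the merged multi-query retriever built
# by hyde_retrieval() instead of rewriting sub-queries by hand.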
def rag_with_rerank_check_multi_query_retriever(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
                                                check_prompt: str):
    # The retriever takes the query string directly; the document count is
    # already configured on the retriever itself.
    context_doc = hyde_retrieval(agent, retrieve_document_count).invoke(q)
    if len(context_doc) == 0:
        return "Neviem, nemám kontext!", "", ""  # "I don't know, I have no context!"
    result = answer_pipeline(agent, context_doc, prompt, q)
    answer = result["text"]
    check_result = check_pipeline(answer, check_prompt, context_doc, q)
    return answer, check_result, context_doc
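
# Async variant built from LangChain's create_retrieval_chain, using the
# three-branch ensemble retriever from hyde_2_retrieval().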
async def rag_chain(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
                    check_prompt: str):
    result = await create_retrieval_chain(
        retriever=hyde_2_retrieval(agent, retrieve_document_count),
        combine_docs_chain=create_stuff_documents_chain(
            llm=agent.llm,
            prompt=PromptTemplate(
                input_variables=["context", "question", "actual_date"],
                template=prompt
            ),
            # Render each retrieved document as its raw page content.
            document_prompt=PromptTemplate(
                input_variables=["page_content"],
                template="{page_content}"
            )
        )
    ).ainvoke(
        input={
            "question": q,
            "input": q,
            "actual_date": datetime.date.today().isoformat()
        }
    )
    print(result)
    check_result = check_pipeline(result["answer"], check_prompt, result["context"], q)
    print(check_result)
    return result["answer"], check_result, result["context"]
def vanilla_rag_chain(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
                      check_prompt: str):
    retriever = ContextualCompressionRetriever(
        base_compressor=CohereRerank(
            model="rerank-multilingual-v3.0",
            top_n=retrieve_document_count
        ),
        base_retriever=agent.embedding.get_vector_store().as_retriever(
            search_type="similarity",
            search_kwargs={"k": min(retrieve_document_count * 10, 500)},
        )
    )
    result = create_retrieval_chain(
        retriever=retriever,
        combine_docs_chain=create_stuff_documents_chain(
            llm=agent.llm,
            prompt=PromptTemplate(
                input_variables=["context", "question", "actual_date"],
                template=prompt
            ),
            # Render each retrieved document as its raw page content.
            document_prompt=PromptTemplate(
                input_variables=["page_content"],
                template="{page_content}"
            )
        )
    ).invoke(
        input={
            "question": q,
            "input": q,
            "actual_date": datetime.date.today().isoformat()
        }
    )
    print(result)
    check_result = check_pipeline(result["answer"], check_prompt, result["context"], q)
    print(check_result)
    return result["answer"], check_result, result["context"]
def hyde_retrieval(agent: Agent, retrieve_document_count: int):
    retriever_1 = MultiQueryRetriever.from_llm(
        llm=agent.llm,
        retriever=agent.embedding.get_vector_store().as_retriever(
            search_type="similarity",
            search_kwargs={"k": retrieve_document_count}
        ),
        prompt=PromptTemplate(
            input_variables=["question"],
            template=rewrite_hyde_1
        )
    )
    retriever_2 = MultiQueryRetriever.from_llm(
        llm=agent.llm,
        retriever=agent.embedding.get_vector_store().as_retriever(
            search_type="similarity",
            search_kwargs={"k": retrieve_document_count}
        ),
        prompt=PromptTemplate(
            input_variables=["question"],
            template=rewrite_hyde_2
        )
    )
    merge_retriever = MergerRetriever(
        retrievers=[retriever_1, retriever_2],
    )
    compressor = CohereRerank(
        model="rerank-multilingual-v3.0",
        top_n=retrieve_document_count
    )
    # The reranker's top_n controls how many documents come back;
    # ContextualCompressionRetriever itself takes no search_kwargs.
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=merge_retriever,
    )
    return compression_retriever
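
# Builds a three-branch retriever: each branch is a multi-query retriever with
# a different rewrite prompt, compressed by a shared Cohere reranker, and the
# branches are fused with equal weights by an EnsembleRetriever.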
def hyde_2_retrieval(agent: Agent, retrieve_document_count: int):
    # Integer division: CohereRerank expects an int top_n, and each branch
    # contributes roughly half of the requested document count.
    compressor = CohereRerank(
        model="rerank-multilingual-v3.0",
        top_n=retrieve_document_count // 2
    )
    retriever_1 = MultiQueryRetriever.from_llm(
        llm=agent.llm,
        retriever=agent.embedding.get_vector_store().as_retriever(
            search_type="similarity",
            search_kwargs={"k": min(retrieve_document_count * 10, 300)}
        ),
        prompt=PromptTemplate(
            input_variables=["question"],
            template=rewrite_1
        )
    )
    compression_retriever_1 = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=retriever_1
    )
    retriever_2 = MultiQueryRetriever.from_llm(
        llm=agent.llm,
        retriever=agent.embedding.get_vector_store().as_retriever(
            search_type="similarity",
            search_kwargs={"k": min(retrieve_document_count * 10, 300)}
        ),
        prompt=PromptTemplate(
            input_variables=["question"],
            template=rewrite_2
        )
    )
    compression_retriever_2 = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=retriever_2
    )
    retriever_3 = MultiQueryRetriever.from_llm(
        llm=agent.llm,
        retriever=agent.embedding.get_vector_store().as_retriever(
            search_type="similarity",
            search_kwargs={"k": min(retrieve_document_count * 10, 300)}
        ),
        prompt=PromptTemplate(
            input_variables=["question"],
            template=rewrite_hyde
        )
    )
    compression_retriever_3 = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=retriever_3
    )
    # Equal-weight fusion of the three compressed branches.
    merge_retriever = EnsembleRetriever(
        retrievers=[compression_retriever_1, compression_retriever_2, compression_retriever_3],
        weights=[1.0, 1.0, 1.0]
    )
    return merge_retriever
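
# Retrieve + rerank each rewritten sub-query, then merge: sort all hits by
# relevance, deduplicate by page content, and keep the top documents.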
def retrieve_subqueries(agent: Agent, retrieve_document_count: int, rewritten_list: list[str]) -> list[Document]:
    contexts: list[Document] = []
    for rewritten in rewritten_list:
        contexts.extend(retrieve_with_rerank(agent.embedding, rewritten, retrieve_document_count))
    contexts.sort(key=lambda doc: doc.metadata["relevance_score"], reverse=True)
    seen: set[str] = set()
    deduplicated: list[Document] = []
    for doc in contexts:
        if doc.page_content not in seen:
            seen.add(doc.page_content)
            deduplicated.append(doc)
    return deduplicated[:retrieve_document_count]
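
# HyDE variant of retrieve_subqueries(): each sub-query is first answered by
# the LLM, and the (question + hypothetical answer) pair is used for retrieval,
# which usually embeds closer to the real passages than the question alone.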
def retrieve_subqueries_hyde(agent: Agent, retrieve_document_count: int, rewritten_list: list[str]) -> list[Document]:
    contexts: list[Document] = []
    for rewritten in rewritten_list:
        answer = agent.llm.invoke(rewritten).content
        contexts.extend(retrieve_with_rerank(agent.embedding, rewritten + "\n" + answer, retrieve_document_count))
    contexts.sort(key=lambda doc: doc.metadata["relevance_score"], reverse=True)
    seen: set[str] = set()
    deduplicated: list[Document] = []
    for doc in contexts:
        if doc.page_content not in seen:
            seen.add(doc.page_content)
            deduplicated.append(doc)
    return deduplicated[:retrieve_document_count]
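
# Generates the final answer from the retrieved context with the agent's LLM.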
def answer_pipeline(agent: Agent, context_doc: list[Document], prompt: str, q: str) -> dict[str, Any]:
    # "actual_date" is declared here because invoke() below supplies it; the
    # answer prompt is expected to reference all three variables.
    prompt_template = PromptTemplate(
        input_variables=["context", "question", "actual_date"],
        template=prompt
    )
    llm_chain = LLMChain(
        llm=agent.llm,
        prompt=prompt_template,
        verbose=False
    )
    result: dict[str, Any] = llm_chain.invoke(
        input={
            "question": q,
            "context": context_doc,
            "actual_date": datetime.date.today().isoformat()
        }
    )
    return result
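
# Asks an independent judge model (Llama 3 70B via DeepInfra) whether the
# answer is supported by the retrieved context; on failure the traceback is
# returned in place of the verdict.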
def check_pipeline(answer: str, check_prompt: str, context_doc, q: str) -> str:
    prompt_template = PromptTemplate(
        input_variables=["context", "question", "answer"],
        template=check_prompt
    )
    llm_chain = LLMChain(
        llm=deepinfra_chat("meta-llama/Meta-Llama-3-70B-Instruct", "0.4"),
        prompt=prompt_template,
        verbose=False
    )
    try:
        check_result = llm_chain.invoke(
            input={
                "question": q[:2000],  # cap the question length sent to the judge
                "context": context_doc,
                "answer": answer
            }
        )["text"]
    except Exception:
        check_result = traceback.format_exc()
    return check_result
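
# Simplest entry point: single-query retrieve + rerank, answer, optional check.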
def rag_with_rerank(agent: Agent, q: str, retrieve_document_count: int, prompt: str = None, check_prompt: str = None):
    context_doc: list[Document] = retrieve_with_rerank(agent.embedding, q, retrieve_document_count)
    try:
        result: dict[str, Any] = answer_pipeline(agent, context_doc, prompt, q)
        answer = result["text"]
        check_result = ""
        if check_prompt is not None:
            check_result = check_pipeline(answer, check_prompt, context_doc, q)
        return answer, check_result, context_doc
    except Exception:
        # Keep the (answer, check_result, context) shape on failure: traceback
        # in the check slot, empty context list.
        return "", traceback.format_exc(), []
def save_conversation(answer: str, check_result: str, context_doc: list[Document], gramatika: str, question: str,
                      prompt_id: str, check_prompt_id: str, grammar_prompt_id: str):
    if len(answer) > 0:
        conversation_store.save_content(
            q=question,
            a=answer,
            sources=[doc.page_content for doc in context_doc],
            params={
                "prompt_id": prompt_id,
                "check_prompt_id": check_prompt_id,
                "grammar_prompt_id": grammar_prompt_id,
                "check_result": check_result,
                "grammar_result": gramatika,
                "temperature": os.environ["temperature"],
            }
        )
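
# Runs the Slovak grammar-check prompt over a text with the OpenAI chat model.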
def check_slovak_agent(text: str) -> str:
    prompt_template = PromptTemplate(
        input_variables=["text"],
        template=grammar_check_1
    )
    llm_chain = LLMChain(
        llm=chat_openai_llm(),
        prompt=prompt_template,
        verbose=False
    )
    result: dict[str, Any] = llm_chain.invoke(input={"text": text})
    return result["text"]