| | import sys |
| | import os |
| | from contextlib import contextmanager |
| |
|
| | from langchain_core.tools import tool |
| | from langchain_core.runnables import chain |
| | from langchain_core.runnables import RunnableParallel, RunnablePassthrough |
| | from langchain_core.runnables import RunnableLambda |
| |
|
| | from ..reranker import rerank_docs, rerank_and_sort_docs |
| | |
| | from ...knowledge.openalex import OpenAlexRetriever |
| | from .keywords_extraction import make_keywords_extraction_chain |
| | from ..utils import log_event |
| | from langchain_core.vectorstores import VectorStore |
| | from typing import List |
| | from langchain_core.documents.base import Document |
| | from ..llm import get_llm |
| | from .prompts import retrieve_chapter_prompt_template |
| | from langchain_core.prompts import ChatPromptTemplate |
| | from langchain_core.output_parsers import StrOutputParser |
| | from ..vectorstore import get_azure_search_vectorstore |
| | from ..embeddings import get_embeddings_function |
| | import ast |
| |
|
| | import asyncio |
| |
|
| | from typing import Any, Dict, List, Tuple |
| |
|
| |
|
def divide_into_parts(target, parts):
    """Split *target* into *parts* integers that sum to target, as evenly as possible.

    The first ``target % parts`` entries receive one extra unit so the pieces
    never differ by more than 1.
    """
    quotient, remainder = divmod(target, parts)
    return [quotient + 1 if idx < remainder else quotient for idx in range(parts)]
| |
|
| |
|
@contextmanager
def suppress_output():
    """Temporarily silence both stdout and stderr (e.g. chatty third-party libs).

    Restores the original streams even if the wrapped code raises.
    """
    with open(os.devnull, 'w') as sink:
        saved_streams = (sys.stdout, sys.stderr)
        sys.stdout, sys.stderr = sink, sink
        try:
            yield
        finally:
            sys.stdout, sys.stderr = saved_streams
| |
|
| |
|
@tool
def query_retriever(question):
    """Just a dummy tool to simulate the retriever query"""
    # NOTE: the docstring above doubles as the LangChain tool description shown
    # to the model — keep it in sync with the tool's intended use if it changes.
    return question
| |
|
| | def _add_sources_used_in_metadata(docs,sources,question,index): |
| | for doc in docs: |
| | doc.metadata["sources_used"] = sources |
| | doc.metadata["question_used"] = question |
| | doc.metadata["index_used"] = index |
| | return docs |
| |
|
| | def _get_k_summary_by_question(n_questions): |
| | if n_questions == 0: |
| | return 0 |
| | elif n_questions == 1: |
| | return 5 |
| | elif n_questions == 2: |
| | return 3 |
| | elif n_questions == 3: |
| | return 2 |
| | else: |
| | return 1 |
| | |
| | def _get_k_images_by_question(n_questions): |
| | if n_questions == 0: |
| | return 0 |
| | elif n_questions == 1: |
| | return 7 |
| | elif n_questions == 2: |
| | return 5 |
| | elif n_questions == 3: |
| | return 3 |
| | else: |
| | return 1 |
| | |
| | def _add_metadata_and_score(docs: List) -> Document: |
| | |
| | docs_with_metadata = [] |
| | for i,(doc,score) in enumerate(docs): |
| | doc.page_content = doc.page_content.replace("\r\n"," ") |
| | doc.metadata["similarity_score"] = score |
| | doc.metadata["content"] = doc.page_content |
| | if doc.metadata["page_number"] is not None and doc.metadata["page_number"] != "N/A": |
| | doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1 |
| | else: |
| | doc.metadata["page_number"] = 1 |
| | |
| | docs_with_metadata.append(doc) |
| | return docs_with_metadata |
| |
|
def remove_duplicates_chunks(docs):
    """Drop (doc, score) pairs whose page content repeats, keeping the best-scored copy.

    Pairs are first sorted by score (descending) so the highest-scoring
    occurrence of each content wins.
    """
    ranked = sorted(docs, key=lambda pair: pair[1], reverse=True)
    seen_contents = set()
    unique = []
    for pair in ranked:
        content = pair[0].page_content
        if content not in seen_contents:
            seen_contents.add(content)
            unique.append(pair)
    return unique
| |
|
def get_ToCs(version: str):
    """Fetch the table-of-contents chunks for a given parsed-document version.

    Args:
        version: version tag of the parsed documents (e.g. "v4").

    Returns:
        De-duplicated list of (Document, score) pairs whose chunk_type is "toc".
    """
    toc_filter = {
        "chunk_type": "toc",
        "version": version,
    }
    store = get_azure_search_vectorstore(
        embeddings=get_embeddings_function(), index_name="climateqa"
    )
    raw_tocs = store.similarity_search_with_score(query="", filter=toc_filter)

    # Several documents can share identical ToC text; keep one copy of each.
    return remove_duplicates_chunks(raw_tocs)
| |
|
async def get_POC_relevant_documents(
    query: str,
    vectorstore: VectorStore,
    sources: list = None,
    search_figures: bool = False,
    search_only: bool = False,
    k_documents: int = 10,
    threshold: float = 0.6,
    k_images: int = 5,
    reports: list = [],
    min_size: int = 200,
):
    """Retrieve POC text (and optionally image) chunks relevant to *query*.

    Args:
        query: user question used for similarity search.
        vectorstore: vector store to query.
        sources: candidate source names; defaults to
            ["Acclimaterra", "PCAET", "Plan Biodiversite"]. Currently unused by
            the search filters.
        search_figures: also retrieve image chunks when True.
        search_only: accepted for signature parity with other retrievers (unused here).
        k_documents: number of text chunks to fetch before filtering.
        threshold: minimum similarity score for a text chunk to be kept.
        k_images: number of image chunks to fetch.
        reports: report short-names (currently unused by the search filters).
        min_size: minimum character length for a text chunk to be kept.

    Returns:
        dict with "docs_question" (text Documents) and "docs_images" (image Documents).
    """
    # Avoid a mutable default argument: materialize the documented default here.
    sources = ["Acclimaterra", "PCAET", "Plan Biodiversite"] if sources is None else sources

    filters = {}
    docs_images = []

    filters_text = {
        **filters,
        "chunk_type": "text",
    }

    docs_question = vectorstore.similarity_search_with_score(
        query=query, filter=filters_text, k=k_documents
    )

    docs_question = remove_duplicates_chunks(docs_question)
    docs_question = [x for x in docs_question if x[1] > threshold]

    if search_figures:
        filters_image = {
            **filters,
            "chunk_type": "image",
        }
        docs_images = vectorstore.similarity_search_with_score(
            query=query, filter=filters_image, k=k_images
        )

    docs_question = _add_metadata_and_score(docs_question)
    docs_images = _add_metadata_and_score(docs_images)

    # Very short chunks carry little signal; drop them.
    docs_question = [x for x in docs_question if len(x.page_content) > min_size]

    return {
        "docs_question": docs_question,
        "docs_images": docs_images,
    }
| |
|
async def get_POC_documents_by_ToC_relevant_documents(
    query: str,
    tocs: list,
    vectorstore: VectorStore,
    version: str,
    sources: list = None,
    search_figures: bool = False,
    search_only: bool = False,
    k_documents: int = 10,
    threshold: float = 0.6,
    k_images: int = 5,
    reports: list = [],
    min_size: int = 200,
    proportion: float = 0.5,
):
    """Retrieve POC chunks for *query*, spending part of the budget on ToC-relevant chapters.

    Args:
        query: user question used for similarity search.
        tocs: list with the table of contents of each document.
        vectorstore: vector store to query.
        version: version of the parsed documents (e.g. "v4").
        sources: candidate source names; defaults to
            ["Acclimaterra", "PCAET", "Plan Biodiversite"]. Currently unused by
            the search filters.
        search_figures: also retrieve image chunks when True.
        search_only: accepted for signature parity with other retrievers (unused here).
        k_documents: total number of text chunks to fetch before filtering.
        threshold: minimum similarity score for a text chunk to be kept.
        k_images: number of image chunks to fetch.
        reports: report short-names (currently unused by the search filters).
        min_size: minimum character length for a text chunk to be kept.
        proportion: share of documents retrieved using ToCs.

    Returns:
        dict with "docs_question" (text Documents) and "docs_images" (image Documents).
    """
    # Avoid a mutable default argument: materialize the documented default here.
    sources = ["Acclimaterra", "PCAET", "Plan Biodiversite"] if sources is None else sources

    filters = {}
    docs_images = []

    # Split the budget: k_documents_toc chunks come from ToC-filtered chapters,
    # the remainder from an unrestricted text search.
    k_documents_toc = round(k_documents * proportion)

    relevant_tocs = await get_relevant_toc_level_for_query(query, tocs)

    print(f"Relevant ToCs : {relevant_tocs}")

    toc_filters = [toc['chapter'] for toc in relevant_tocs]

    filters_text_toc = {
        **filters,
        "chunk_type": "text",
        "toc_level0": toc_filters,
        "version": version,
    }

    docs_question = vectorstore.similarity_search_with_score(
        query=query, filter=filters_text_toc, k=k_documents_toc
    )

    filters_text = {
        **filters,
        "chunk_type": "text",
        "version": version,
    }

    docs_question += vectorstore.similarity_search_with_score(
        query=query, filter=filters_text, k=k_documents - k_documents_toc
    )

    docs_question = remove_duplicates_chunks(docs_question)
    docs_question = [x for x in docs_question if x[1] > threshold]

    if search_figures:
        filters_image = {
            **filters,
            "chunk_type": "image",
        }
        docs_images = vectorstore.similarity_search_with_score(
            query=query, filter=filters_image, k=k_images
        )

    docs_question = _add_metadata_and_score(docs_question)
    docs_images = _add_metadata_and_score(docs_images)

    # Very short chunks carry little signal; drop them.
    docs_question = [x for x in docs_question if len(x.page_content) > min_size]

    return {
        "docs_question": docs_question,
        "docs_images": docs_images,
    }
| | |
def filter_for_full_report_documents(filters: dict) -> dict:
    """Build filters selecting full-report text chunks, with summaries excluded.

    Returns a dictionary format compatible with all vectorstore providers; the
    input dict is left untouched.
    """
    return {
        **filters,
        "chunk_type": "text",
        # SPM chunks are the summaries; full reports must exclude them.
        "report_type_exclude": ["SPM"],
    }
| |
|
async def get_IPCC_relevant_documents(
    query: str,
    vectorstore: VectorStore,
    sources: list = None,
    search_figures: bool = False,
    reports: list = [],
    threshold: float = 0.6,
    k_summary: int = 3,
    k_total: int = 10,
    k_images: int = 5,
    namespace: str = "vectors",
    min_size: int = 200,
    search_only: bool = False,
):
    """Retrieve IPCC/IPBES/IPOS chunks relevant to *query*: summaries, full reports, images.

    Args:
        query: user question used for similarity search.
        vectorstore: vector store to query.
        sources: subset of ["IPCC", "IPBES", "IPOS"]; defaults to all three.
        search_figures: also retrieve image chunks when True.
        reports: specific report short-names; when given they replace the source filter.
        threshold: minimum similarity score for summary chunks.
        k_summary: number of summary (SPM) chunks to fetch.
        k_total: number of full-report chunks to fetch; must exceed k_summary.
        k_images: number of image chunks to fetch.
        namespace: currently unused; kept for call-site compatibility.
        min_size: minimum character length for a text chunk to be kept.
        search_only: when True, skip text retrieval and search figures at most.

    Returns:
        dict with "docs_summaries", "docs_full" and "docs_images" lists.

    Raises:
        AssertionError: on invalid sources or when k_total <= k_summary.
    """
    # Avoid a mutable default argument: materialize the documented default here.
    sources = ["IPCC", "IPBES", "IPOS"] if sources is None else sources

    # Input validation (kept as asserts to preserve the existing
    # AssertionError contract for callers).
    assert isinstance(sources, list)
    assert sources
    assert all([x in ["IPCC", "IPBES", "IPOS"] for x in sources])
    assert k_total > k_summary, "k_total should be greater than k_summary"

    # Explicit report names take precedence over the broad source filter.
    filters = {}
    if len(reports) > 0:
        filters["short_name"] = reports
    else:
        filters["source"] = sources

    docs_summaries = []
    docs_full = []
    docs_images = []

    if search_only:
        # Search-only mode: no text retrieval; figures at most.
        if search_figures:
            filters_image = {
                **filters,
                "chunk_type": "image",
            }
            docs_images = vectorstore.similarity_search_with_score(query=query, filter=filters_image, k=k_images)
            docs_images = _add_metadata_and_score(docs_images)
    else:
        # Search the summaries (SPM chapters) first.
        filters_summaries = {
            **filters,
            "chunk_type": "text",
            "report_type": ["SPM"],
        }

        docs_summaries = vectorstore.similarity_search_with_score(query=query, filter=filters_summaries, k=k_summary)
        docs_summaries = [x for x in docs_summaries if x[1] > threshold]

        # Then the full reports (SPM excluded). Note: deliberately no score
        # threshold here, matching the pre-existing behavior.
        filters_full = filter_for_full_report_documents(filters)
        docs_full = vectorstore.similarity_search_with_score(query=query, filter=filters_full, k=k_total)

        if search_figures:
            filters_image = {
                **filters,
                "chunk_type": "image",
            }
            docs_images = vectorstore.similarity_search_with_score(query=query, filter=filters_image, k=k_images)

        docs_summaries = _add_metadata_and_score(docs_summaries)
        docs_full = _add_metadata_and_score(docs_full)
        docs_images = _add_metadata_and_score(docs_images)

        # Filter out chunks too small to be informative.
        docs_summaries = [x for x in docs_summaries if len(x.page_content) > min_size]
        docs_full = [x for x in docs_full if len(x.page_content) > min_size]

    return {
        "docs_summaries": docs_summaries,
        "docs_full": docs_full,
        "docs_images": docs_images,
    }
| |
|
| |
|
| | |
def concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question):
    """Select the final per-question slice of documents and images.

    Args:
        index: question index (unused here; kept for call-site compatibility).
        source_type: "IPx" (summaries + full reports) or "POC" (single doc list).
        docs_question_dict: retriever output dict for this question.
        k_by_question: total number of text documents to keep.
        k_summary_by_question: of those, how many come from summaries (IPx only).
        k_images_by_question: number of images to keep.

    Returns:
        tuple: (docs_question, images_question) lists.

    Raises:
        ValueError: if source_type is neither "IPx" nor "POC".
    """
    if source_type == "IPx":
        # Summaries first, then top up with full-report chunks.
        docs_question = (
            docs_question_dict["docs_summaries"][:k_summary_by_question]
            + docs_question_dict["docs_full"][:(k_by_question - k_summary_by_question)]
        )
    elif source_type == "POC":
        docs_question = docs_question_dict["docs_question"][:k_by_question]
    else:
        # Fixed message: the accepted values are "IPx" or "POC" (the old text
        # wrongly said "Vector or POC").
        raise ValueError("source_type should be either IPx or POC")

    images_question = docs_question_dict["docs_images"][:k_images_by_question]

    return docs_question, images_question
| |
|
| |
|
| |
|
| | |
| | |
async def retrieve_documents(
    current_question: Dict[str, Any],
    config: Dict[str, Any],
    source_type: str,
    vectorstore: VectorStore,
    reranker: Any,
    version: str = "",
    search_figures: bool = False,
    search_only: bool = False,
    reports: list = [],
    rerank_by_question: bool = True,
    k_images_by_question: int = 5,
    k_before_reranking: int = 100,
    k_by_question: int = 5,
    k_summary_by_question: int = 3,
    tocs: list = [],
    by_toc=False
) -> Tuple[List[Document], List[Document]]:
    """
    Retrieve and rerank documents for a single question.

    Dispatches to the IPx (IPCC/IPBES/IPOS) or POC retriever based on the
    question's own "source_type" field, optionally reranks the results, trims
    them to the per-question budget, and stamps provenance metadata.

    Args:
        current_question (dict): question dict with "question", "sources",
            "index" and "source_type" keys.
        config (dict): configuration settings for logging and other purposes.
        source_type (str): nominal source type; NOTE it is immediately
            overridden below by current_question["source_type"].
        vectorstore (object): the vector store used to retrieve relevant documents.
        reranker (object): the reranker used to rerank the retrieved documents
            (may be None to skip reranking).
        version (str): parsed-document version (used by the ToC-based POC path).
        search_figures (bool): also retrieve image chunks.
        search_only (bool): forwarded to the underlying retrievers.
        reports (list): report short-names restricting the search.
        rerank_by_question (bool, optional): whether to rerank documents by question.
        k_images_by_question (int): number of image documents to keep.
        k_before_reranking (int, optional): number of documents to retrieve
            before reranking. Defaults to 100.
        k_by_question (int): number of text documents to keep for this question.
        k_summary_by_question (int): of those, how many come from summaries (IPx only).
        tocs (list): table-of-contents chunks (ToC-based POC path only).
        by_toc: use the ToC-guided POC retriever when True.

    Returns:
        tuple: (docs_question, images_question), both lists of Documents with
        reranking and provenance metadata attached.
    """
    sources = current_question["sources"]
    question = current_question["question"]
    index = current_question["index"]
    # The parameter of the same name is deliberately shadowed here: each
    # question carries its own source_type.
    source_type = current_question["source_type"]

    print(f"Retrieve documents for question: {question}")
    await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)

    print(f"""---- Retrieve documents from {current_question["source_type"]}----""")

    if source_type == "IPx":
        docs_question_dict = await get_IPCC_relevant_documents(
            query = question,
            vectorstore=vectorstore,
            search_figures = search_figures,
            sources = sources,
            min_size = 200,
            # Pull (almost) the whole pre-rerank budget from summaries; the
            # final trim happens in concatenate_documents below.
            k_summary = k_before_reranking-1,
            k_total = k_before_reranking,
            k_images = k_images_by_question,
            threshold = 0.5,
            search_only = search_only,
            reports = reports,
        )

    if source_type == 'POC':
        if by_toc == True:
            print("---- Retrieve documents by ToC----")
            docs_question_dict = await get_POC_documents_by_ToC_relevant_documents(
                query=question,
                tocs = tocs,
                vectorstore=vectorstore,
                version=version,
                search_figures = search_figures,
                sources = sources,
                threshold = 0.5,
                search_only = search_only,
                reports = reports,
                min_size= 200,
                k_documents= k_before_reranking,
                # NOTE(review): k_images is fed from k_by_question here rather
                # than k_images_by_question — confirm this is intentional.
                k_images= k_by_question
            )
        else :
            docs_question_dict = await get_POC_relevant_documents(
                query = question,
                vectorstore=vectorstore,
                search_figures = search_figures,
                sources = sources,
                threshold = 0.5,
                search_only = search_only,
                reports = reports,
                min_size= 200,
                k_documents= k_before_reranking,
                k_images= k_by_question
            )

    # NOTE(review): for any source_type other than "IPx"/"POC",
    # docs_question_dict is never assigned and the loop below raises
    # UnboundLocalError.
    if reranker is not None and rerank_by_question:
        # Rerankers tend to print noisily; silence them.
        with suppress_output():
            for key in docs_question_dict.keys():
                docs_question_dict[key] = rerank_and_sort_docs(reranker,docs_question_dict[key],question)
    else:
        # No reranker: reuse the similarity score so downstream sorting on
        # "reranking_score" still works.
        for key in docs_question_dict.keys():
            if isinstance(docs_question_dict[key], list) and len(docs_question_dict[key]) > 0:
                for doc in docs_question_dict[key]:
                    doc.metadata["reranking_score"] = doc.metadata["similarity_score"]

    # Trim the retrieved pools down to the per-question budget.
    docs_question, images_question = concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question)

    # Rerank the merged selection once more so the final order reflects the
    # combined (summaries + full) list rather than each pool separately.
    if reranker is not None and rerank_by_question:
        docs_question = rerank_and_sort_docs(reranker, docs_question, question)

    # Record which sources/question/index produced each document.
    docs_question = _add_sources_used_in_metadata(docs_question,sources,question,index)
    images_question = _add_sources_used_in_metadata(images_question,sources,question,index)

    return docs_question, images_question
| | |
| |
|
async def retrieve_documents_for_all_questions(
    search_figures,
    search_only,
    reports,
    questions_list,
    n_questions,
    config,
    source_type,
    to_handle_questions_index,
    vectorstore,
    reranker,
    rerank_by_question=True,
    k_final=15,
    k_before_reranking=100,
    version: str = "",
    tocs: list = None,
    by_toc: bool = False
):
    """
    Retrieve documents in parallel for all questions.

    Fans out one retrieve_documents() task per question whose index appears in
    to_handle_questions_index, awaits them together, and merges the results.

    Args:
        search_figures: also retrieve image chunks when True.
        search_only: forwarded to the underlying retrievers.
        reports: report short-names restricting the search.
        questions_list: all question dicts from the state.
        n_questions: total question count, used to split the document budget.
        config: configuration forwarded for logging.
        source_type: "IPx" or "POC".
        to_handle_questions_index: indices of the questions to process here.
        vectorstore: vector store to query.
        reranker: reranker applied to retrieved documents (may be None).
        rerank_by_question: whether to rerank documents per question.
        k_final: total number of documents to keep across all questions.
        k_before_reranking: documents retrieved per question before reranking.
        version: parsed-document version (ToC-based POC path only).
        tocs: table-of-contents chunks; defaults to an empty list.
        by_toc: use the ToC-guided POC retriever when True.

    Returns:
        dict with "documents", "related_contents" and "handled_questions_index".
    """
    # Avoid a mutable default argument for tocs.
    tocs = [] if tocs is None else tocs

    # Split the overall budget evenly across questions; guard against
    # n_questions == 0 (nothing to retrieve) instead of raising
    # ZeroDivisionError.
    k_by_question = k_final // n_questions if n_questions else 0
    k_summary_by_question = _get_k_summary_by_question(n_questions)
    k_images_by_question = _get_k_images_by_question(n_questions)
    # BUGFIX: the caller-supplied k_before_reranking is now respected instead
    # of being unconditionally reset to 100 here.

    print(f"Source type here is {source_type}")
    tasks = [
        retrieve_documents(
            current_question=question,
            config=config,
            source_type=source_type,
            vectorstore=vectorstore,
            reranker=reranker,
            search_figures=search_figures,
            search_only=search_only,
            reports=reports,
            rerank_by_question=rerank_by_question,
            k_images_by_question=k_images_by_question,
            k_before_reranking=k_before_reranking,
            k_by_question=k_by_question,
            k_summary_by_question=k_summary_by_question,
            tocs=tocs,
            version=version,
            by_toc=by_toc
        )
        for i, question in enumerate(questions_list) if i in to_handle_questions_index
    ]
    results = await asyncio.gather(*tasks)

    new_state = {"documents": [], "related_contents": [], "handled_questions_index": to_handle_questions_index}
    for docs_question, images_question in results:
        new_state["documents"].extend(docs_question)
        new_state["related_contents"].extend(images_question)
    return new_state
| |
|
| | |
async def get_relevant_toc_level_for_query(
    query: str,
    tocs: list,
) -> list:
    """Ask the LLM which ToC chapters are relevant to *query*.

    Args:
        query: the user question.
        tocs: (Document, score) tuples, each Document holding one table of contents.

    Returns:
        Parsed list of chapter dicts (as emitted by the prompt), or [] when the
        LLM response cannot be parsed as a Python literal.
    """
    doc_list = []
    for doc in tocs:
        doc_name = doc[0].metadata['name']
        toc = doc[0].page_content
        doc_list.append({'document': doc_name, 'toc': toc})

    llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)

    prompt = ChatPromptTemplate.from_template(retrieve_chapter_prompt_template)
    # Local name avoids shadowing the `chain` imported at module level.
    toc_chain = prompt | llm | StrOutputParser()
    response = toc_chain.invoke({"query": query, "doc_list": doc_list})

    # The prompt asks for a Python-literal list; default to [] so a malformed
    # response degrades gracefully instead of raising UnboundLocalError below.
    relevant_tocs = []
    try:
        relevant_tocs = ast.literal_eval(response)
    except Exception as e:
        print(f" Failed to parse the result because of : {e}")

    return relevant_tocs
| |
|
| |
|
def make_IPx_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
    """Build the graph node that retrieves IPCC/IPBES/IPOS ("IPx") documents.

    The llm and k_summary arguments are accepted for factory-signature parity
    but are not used by this node.
    """

    async def retrieve_IPx_docs(state, config):
        source_type = "IPx"
        IPx_questions_index = [
            idx for idx, q in enumerate(state["questions_list"]) if q["source_type"] == "IPx"
        ]

        return await retrieve_documents_for_all_questions(
            search_figures="Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"],
            search_only=state["search_only"],
            reports=state["reports"],
            questions_list=state["questions_list"],
            n_questions=state["n_questions"]["total"],
            config=config,
            source_type=source_type,
            to_handle_questions_index=IPx_questions_index,
            vectorstore=vectorstore,
            reranker=reranker,
            rerank_by_question=rerank_by_question,
            k_final=k_final,
            k_before_reranking=k_before_reranking,
        )

    return retrieve_IPx_docs
| |
|
| |
|
def make_POC_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
    """Build the graph node that retrieves POC documents (non-ToC path).

    The llm and k_summary arguments are accepted for factory-signature parity
    but are not used by this node.
    """

    async def retrieve_POC_docs_node(state, config):
        source_type = "POC"
        POC_questions_index = [
            idx for idx, q in enumerate(state["questions_list"]) if q["source_type"] == "POC"
        ]

        return await retrieve_documents_for_all_questions(
            search_figures="Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"],
            search_only=state["search_only"],
            reports=state["reports"],
            questions_list=state["questions_list"],
            n_questions=state["n_questions"]["total"],
            config=config,
            source_type=source_type,
            to_handle_questions_index=POC_questions_index,
            vectorstore=vectorstore,
            reranker=reranker,
            rerank_by_question=rerank_by_question,
            k_final=k_final,
            k_before_reranking=k_before_reranking,
        )

    return retrieve_POC_docs_node
| |
|
| |
|
def make_POC_by_ToC_retriever_node(
    vectorstore: VectorStore,
    reranker,
    llm,
    version: str = "",
    rerank_by_question=True,
    k_final=15,
    k_before_reranking=100,
    k_summary=5,
):
    """Build the graph node retrieving POC documents guided by tables of contents.

    Args:
        vectorstore: vector store to query.
        reranker: reranker applied to retrieved documents (may be None).
        llm: unused; accepted for factory-signature parity.
        version: parsed-document version used to fetch ToCs and filter chunks.
        rerank_by_question, k_final, k_before_reranking, k_summary: retrieval
            tuning knobs (k_summary is currently unused by this node).
    """

    async def retrieve_POC_docs_node(state, config):
        search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
        # Fixed: this assignment was previously duplicated on two lines.
        search_only = state["search_only"]
        reports = state["reports"]
        questions_list = state["questions_list"]
        n_questions = state["n_questions"]["total"]

        tocs = get_ToCs(version=version)

        source_type = "POC"
        POC_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]

        state = await retrieve_documents_for_all_questions(
            search_figures=search_figures,
            search_only=search_only,
            config=config,
            reports=reports,
            questions_list=questions_list,
            n_questions=n_questions,
            source_type=source_type,
            to_handle_questions_index=POC_questions_index,
            vectorstore=vectorstore,
            reranker=reranker,
            rerank_by_question=rerank_by_question,
            k_final=k_final,
            k_before_reranking=k_before_reranking,
            tocs=tocs,
            version=version,
            by_toc=True
        )
        return state

    return retrieve_POC_docs_node
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|