"""Summarization helpers: a single map/combine chain and a parallel,
per-section summarization pipeline, both wrapped for gradio."""
import asyncio
import time
from typing import Dict, List

from langchain.chains.summarize import load_summarize_chain
from langchain.chat_models import ChatOpenAI

from src.doc_loading import load_docs
from src.llm_utils import async_generate_llmchain
from src.prompts import (
    prompts,
    prompts_parallel_summary,
)

# Internal section keys in the order their summaries should appear in the
# final output.
# NOTE(review): "II. Die Entscheidung" is a user-facing title while the other
# entries are snake_case keys -- confirm it is a valid key of
# prompts_parallel_summary.
PARALLEL_SUMMARIZATION_ORDER = [
    "intro",
    "darstellung_des_rechtsproblems",
    "II. Die Entscheidung",
    "angaben_ueber_das_urteil",
    "sachverhalt",
    "prozessgeschichte",
    "rechtsproblem",
    "loesung_des_gerichts",
]

# User-facing section title -> internal prompt key.
PARALLEL_SUMMARIZATION_MAPPING = {
    "I. Einleitung": "intro",
    "Darstellung des Rechtsproblems": "darstellung_des_rechtsproblems",
    "Angaben über das Urteil": "angaben_ueber_das_urteil",
    "Sachverhalt": "sachverhalt",
    "Prozessgeschichte": "prozessgeschichte",
    "Rechtsproblem": "rechtsproblem",
    "Lösung des Gerichts": "loesung_des_gerichts",
}

# Internal prompt key -> user-facing section title.
PARALLEL_SUMMARIZATION_MAPPING_INVERSE = {
    v: k for k, v in PARALLEL_SUMMARIZATION_MAPPING.items()
}

# Summarization type selected by the user -> prompt-set key in `prompts`.
_SUMMARIZATION_PROMPT_KEYS = {
    "short": "short_de",
    "middle": "middle_de",
    "long": "long_de",
}


def summarize_chain(
    file_path: str, llm: ChatOpenAI, summarization_kwargs: Dict[str, str]
) -> str:
    """Summarize a pdf file.

    The summarization is done by the language model.

    Args:
        file_path (str): Path to the pdf file. This can either be a local path
            or a tempfile.TemporaryFileWrapper_.
        llm (ChatOpenAI): Language model to use for the summarization.
        summarization_kwargs (Dict[str, str]): Extra keyword arguments passed
            through to ``load_summarize_chain`` (e.g. chain type and prompts).

    Returns:
        str: Summarization of the pdf file.
    """
    docs = load_docs(file_path=file_path)
    chain = load_summarize_chain(
        llm=llm,
        **summarization_kwargs,
    )
    # NOTE(review): Chain.run is deprecated in newer langchain releases; kept
    # for compatibility with the pinned version.
    summary = chain.run(docs)
    return summary


def summarize_wrapper(
    file: str, llm: ChatOpenAI, summarization_type: str, summarization_kwargs: dict
) -> str:
    """Wrapper for the summarization function to make it compatible with gradio.

    This function uses a single summarization chain.

    Args:
        file (str): Path to the file. This can either be a local path or a
            tempfile.TemporaryFileWrapper_.
        llm (ChatOpenAI): Language model.
        summarization_type (str): Type of summarization. Can be either
            "short", "middle" or "long".
        summarization_kwargs (dict): Keyword arguments for the summarization.
            Not modified by this function.

    Returns:
        str: Summarization of the file.

    Raises:
        ValueError: If ``summarization_type`` is not one of the supported
            values.
    """
    try:
        prompt_key = _SUMMARIZATION_PROMPT_KEYS[summarization_type]
    except KeyError:
        raise ValueError(
            f"Summarization type {summarization_type} is not supported."
        ) from None
    # Copy before updating: the previous implementation mutated the caller's
    # dict in place.
    kwargs = dict(summarization_kwargs)
    kwargs.update(
        map_prompt=prompts[prompt_key]["map_prompt"],
        combine_prompt=prompts[prompt_key]["combine_prompt"],
    )
    # NOTE(review): `llm` is indexed -- it apparently arrives as a one-element
    # sequence (e.g. gradio state) despite the ChatOpenAI annotation; confirm.
    return summarize_chain(
        file_path=file.name, llm=llm[0], summarization_kwargs=kwargs
    )


async def generate_summary_concurrently(
    file_path: str, sections: List[str], llm: ChatOpenAI
) -> Dict[str, str]:
    """Parallel summarization.

    This function is used to run different prompts for the same docs in
    parallel.

    Args:
        file_path (str): Path to the pdf file. This can either be a local path
            or a tempfile.TemporaryFileWrapper_.
        sections (List[str]): List of sections to summarize selected by the
            user.
        llm (ChatOpenAI): Language model to use for the summarization.

    Returns:
        Dict[str, str]: Section key mapped to its summarization, in
        ``PARALLEL_SUMMARIZATION_ORDER``.
    """
    docs = load_docs(file_path=file_path, with_pageinfo=False)
    # Create one task per requested section, in display order.
    tasks = []
    for key in PARALLEL_SUMMARIZATION_ORDER:
        # Sections are selected by their user-facing title; fall back to the
        # key itself for entries without a mapped title.
        if PARALLEL_SUMMARIZATION_MAPPING_INVERSE.get(key, key) in sections:
            print(f"Appending task for summary: {key}")
            tasks.append(
                async_generate_llmchain(
                    llm=llm,
                    docs=docs,
                    llm_kwargs={"prompt": prompts_parallel_summary[key]},
                    k=key,
                )
            )
    print("-------------------")
    # Execute all coroutines concurrently; gather preserves task order.
    values = await asyncio.gather(*tasks)
    # Merge the per-task single-entry dicts into one ordered dict.
    values_flattened: Dict[str, str] = {}
    for value in values:
        values_flattened.update(value)
    return values_flattened


def parallel_summarization(file: str, sections: List[str], llm: ChatOpenAI) -> str:
    """Wrapper for the parallel summarization function to make it compatible
    with gradio.

    Args:
        file (str): Path to the file. This can either be a local path or a
            tempfile.TemporaryFileWrapper_.
        sections (List[str]): List of sections to summarize.
        llm (ChatOpenAI): Language model.

    Returns:
        str: Summarization of the file, one section per paragraph.
    """
    now = time.time()
    values_flattened = asyncio.run(
        generate_summary_concurrently(
            file_path=file.name, sections=sections, llm=llm[0]
        )
    )
    print("Time taken for complete parallel summarization: ", time.time() - now)
    # The dict is built in PARALLEL_SUMMARIZATION_ORDER, so joining its values
    # preserves section order. (The previous loop looked each key up with a
    # fallback that could never trigger, since it iterated the dict's own
    # keys.)
    return "".join(value + "\n\n" for value in values_flattened.values())