| from langchain.chains.summarize import load_summarize_chain |
| from langchain.chat_models import ChatOpenAI |
| from src.prompts import ( |
| prompts, |
| prompts_parallel_summary, |
| ) |
| from src.doc_loading import load_docs |
| from src.llm_utils import async_generate_llmchain |
| import time |
| from typing import Dict, List |
| import asyncio |
|
|
|
|
def summarize_chain(
    file_path: str, llm: ChatOpenAI, summarization_kwargs: Dict[str, str]
) -> str:
    """Summarize a pdf file with a single langchain summarization chain.

    Args:
        file_path (str): Path to the pdf file. This can either be a local path or a
            tempfile.TemporaryFileWrapper_.
        llm (ChatOpenAI): Language model to use for the summarization.
        summarization_kwargs (Dict[str, str]): Extra keyword arguments forwarded
            verbatim to ``load_summarize_chain`` (e.g. ``chain_type``,
            ``map_prompt``, ``combine_prompt``).

    Returns:
        str: Summarization of the pdf file.
    """
    docs = load_docs(file_path=file_path)
    chain = load_summarize_chain(
        llm=llm,
        **summarization_kwargs,
    )
    return chain.run(docs)
|
|
|
|
def summarize_wrapper(
    file: str, llm: ChatOpenAI, summarization_type: str, summarization_kwargs: dict
) -> str:
    """Wrapper for the summarization function to make it compatible with gradio.
    This function uses a single summarization chain.

    Args:
        file (str): Path to the file. This can either be a local path or a
            tempfile.TemporaryFileWrapper_.
        llm (ChatOpenAI): Language model.
        summarization_type (str): Type of summarization. Can be either "short",
            "middle" or "long".
        summarization_kwargs (dict): Keyword arguments for the summarization.
            Not mutated; a copy is extended with the selected prompts.

    Returns:
        str: Summarization of the file.

    Raises:
        ValueError: If ``summarization_type`` is not one of the supported values.
    """
    # Guard clause: reject unknown types before any lookup.
    if summarization_type not in ("short", "middle", "long"):
        raise ValueError(f"Summarization type {summarization_type} is not supported.")

    # The three original branches differed only in the prompt key "{type}_de".
    prompt_set = prompts[f"{summarization_type}_de"]
    # Copy so the caller's dict is not mutated as a side effect.
    kwargs = dict(
        summarization_kwargs,
        map_prompt=prompt_set["map_prompt"],
        combine_prompt=prompt_set["combine_prompt"],
    )

    return summarize_chain(
        file_path=file.name, llm=llm[0], summarization_kwargs=kwargs
    )
|
|
|
|
async def generate_summary_concurrently(
    file_path: str, sections: List[str], llm: ChatOpenAI
) -> Dict[str, str]:
    """Parallel summarization. This function is used to run different prompts for
    the same docs in parallel.

    Args:
        file_path (str): Path to the pdf file. This can either be a local path or a
            tempfile.TemporaryFileWrapper_.
        sections (List[str]): List of sections to summarize selected by the user.
        llm (ChatOpenAI): Language model to use for the summarization.

    Returns:
        Dict[str, str]: Mapping of section key to its generated summary, in the
        order of ``PARALLEL_SUMMARIZATION_ORDER``.
    """
    docs = load_docs(file_path=file_path, with_pageinfo=False)
    summarization_kwargs = dict()

    tasks = []
    for k in PARALLEL_SUMMARIZATION_ORDER:
        # `sections` holds user-facing labels; map the internal key back to its
        # label, falling back to the key itself when no mapping exists.
        if PARALLEL_SUMMARIZATION_MAPPING_INVERSE.get(k, k) in sections:
            sk = summarization_kwargs.copy()
            sk["prompt"] = prompts_parallel_summary[k]
            print(f"Appending task for summary: {k}")
            tasks.append(
                async_generate_llmchain(llm=llm, docs=docs, llm_kwargs=sk, k=k)
            )
    print("-------------------")

    values = await asyncio.gather(*tasks)

    # Each task resolves to a dict keyed by its section; merge into one mapping.
    values_flattened = {}
    for v in values:
        values_flattened.update(v)
    return values_flattened
|
|
|
|
# Order in which section summaries are scheduled and later concatenated.
# NOTE(review): "II. Die Entscheidung" is a user-facing label unlike the other
# snake_case internal keys, and it has no entry in the mapping below; the
# `.get(k, k)` fallback in generate_summary_concurrently handles it — confirm
# this mixed entry is intentional.
PARALLEL_SUMMARIZATION_ORDER = [
    "intro",
    "darstellung_des_rechtsproblems",
    "II. Die Entscheidung",
    "angaben_ueber_das_urteil",
    "sachverhalt",
    "prozessgeschichte",
    "rechtsproblem",
    "loesung_des_gerichts",
]
# User-facing (German) section label -> internal prompt key.
PARALLEL_SUMMARIZATION_MAPPING = {
    "I. Einleitung": "intro",
    "Darstellung des Rechtsproblems": "darstellung_des_rechtsproblems",
    "Angaben über das Urteil": "angaben_ueber_das_urteil",
    "Sachverhalt": "sachverhalt",
    "Prozessgeschichte": "prozessgeschichte",
    "Rechtsproblem": "rechtsproblem",
    "Lösung des Gerichts": "loesung_des_gerichts",
}
# Internal prompt key -> user-facing label (inverse of the mapping above).
PARALLEL_SUMMARIZATION_MAPPING_INVERSE = {
    v: k for k, v in PARALLEL_SUMMARIZATION_MAPPING.items()
}
|
|
|
|
def parallel_summarization(file: str, sections: List[str], llm: ChatOpenAI) -> str:
    """Wrapper for the parallel summarization function to make it compatible with gradio.

    Args:
        file (str): Path to the file. This can either be a local path or a
            tempfile.TemporaryFileWrapper_.
        sections (List[str]): List of sections to summarize.
        llm (ChatOpenAI): Language model.

    Returns:
        str: Summarization of the file, one section per paragraph separated by
        blank lines.
    """
    now = time.time()

    values_flattened = asyncio.run(
        generate_summary_concurrently(
            file_path=file.name, sections=sections, llm=llm[0]
        )
    )

    print("Time taken for complete parallel summarization: ", time.time() - now)

    # The original loop iterated the dict's own keys and then called
    # values_flattened.get(section, <fallback>), so the fallback was dead code.
    # Join the values directly; dict insertion order preserves the section
    # order established by the ordered task list.
    return "".join(summary + "\n\n" for summary in values_flattened.values())
|
|