# prototype/src/summarization.py
# Uploaded via huggingface_hub (revision 799b85e).
from langchain.chains.summarize import load_summarize_chain
from langchain.chat_models import ChatOpenAI
from src.prompts import (
prompts,
prompts_parallel_summary,
)
from src.doc_loading import load_docs
from src.llm_utils import async_generate_llmchain
import time
from typing import Dict, List
import asyncio
def summarize_chain(
    file_path: str, llm: ChatOpenAI, summarization_kwargs: Dict[str, str]
) -> str:
    """Run a single summarization chain over one PDF file.

    Args:
        file_path (str): Path to the pdf file. This can either be a local path
            or a tempfile.TemporaryFileWrapper_.
        llm (ChatOpenAI): Language model to use for the summarization.
        summarization_kwargs (Dict[str, str]): Extra keyword arguments passed
            through to ``load_summarize_chain`` (e.g. map/combine prompts).

    Returns:
        str: Summarization of the pdf file.
    """
    documents = load_docs(file_path=file_path)
    summarizer = load_summarize_chain(llm=llm, **summarization_kwargs)
    return summarizer.run(documents)
def summarize_wrapper(
    file: str, llm: ChatOpenAI, summarization_type: str, summarization_kwargs: dict
) -> str:
    """Wrapper for the summarization function to make it compatible with gradio.
    This function uses a single summarization chain.

    Args:
        file (str): File object with a ``.name`` attribute, e.g. a
            tempfile.TemporaryFileWrapper_ as handed over by gradio.
        llm (ChatOpenAI): Language model. NOTE(review): indexed with ``llm[0]``
            below, so this is presumably a gradio state container holding the
            model — confirm against the caller.
        summarization_type (str): Type of summarization. Must be "short",
            "middle" or "long".
        summarization_kwargs (dict): Keyword arguments for the summarization.
            The caller's dict is no longer mutated in place.

    Returns:
        str: Summarization of the file.

    Raises:
        ValueError: If ``summarization_type`` is not supported.
    """
    # Map the UI-level type to the prompt-collection key; replaces the
    # repetitive if/elif chain.
    prompt_keys = {"short": "short_de", "middle": "middle_de", "long": "long_de"}
    if summarization_type not in prompt_keys:
        raise ValueError(f"Summarization type {summarization_type} is not supported.")
    selected = prompts[prompt_keys[summarization_type]]
    # Build a fresh kwargs dict: the original code updated the caller's dict
    # in place, leaking prompts into subsequent calls that reuse the dict.
    kwargs = dict(
        summarization_kwargs,
        map_prompt=selected["map_prompt"],
        combine_prompt=selected["combine_prompt"],
    )
    return summarize_chain(
        file_path=file.name, llm=llm[0], summarization_kwargs=kwargs
    )
async def generate_summary_concurrently(
    file_path: str, sections: List[str], llm: ChatOpenAI
) -> Dict[str, str]:
    """Parallel summarization: run a different prompt per section over the
    same docs, all sections concurrently.

    Args:
        file_path (str): Path to the pdf file. This can either be a local path
            or a tempfile.TemporaryFileWrapper_.
        sections (List[str]): Section labels selected by the user
            (human-readable names, see PARALLEL_SUMMARIZATION_MAPPING).
        llm (ChatOpenAI): Language model to use for the summarization.

    Returns:
        Dict[str, str]: Mapping of section key to its generated summary.
            (The original annotation said ``List[dict]``, but the merged
            dict is what is actually returned.)
    """
    docs = load_docs(file_path=file_path, with_pageinfo=False)
    # Launch one summarization task per selected section, in the fixed order.
    tasks = []
    for key in PARALLEL_SUMMARIZATION_ORDER:
        # Keys without a human-readable mapping entry are matched against
        # `sections` verbatim via the .get(key, key) fallback.
        if PARALLEL_SUMMARIZATION_MAPPING_INVERSE.get(key, key) in sections:
            print(f"Appending task for summary: {key}")
            tasks.append(
                async_generate_llmchain(
                    llm=llm,
                    docs=docs,
                    llm_kwargs={"prompt": prompts_parallel_summary[key]},
                    k=key,
                )
            )
    print("-------------------")
    # Execute all section coroutines concurrently.
    results = await asyncio.gather(*tasks)
    # Each task returns a dict (presumably keyed by its section `k` —
    # confirm in async_generate_llmchain); merge them into one mapping.
    merged: Dict[str, str] = {}
    for partial in results:
        merged.update(partial)
    return merged
# Fixed order in which section summaries are generated and emitted.
# Entries are the keys used to look up prompts in prompts_parallel_summary.
PARALLEL_SUMMARIZATION_ORDER = [
    "intro",
    "darstellung_des_rechtsproblems",
    # NOTE(review): this entry has no counterpart in
    # PARALLEL_SUMMARIZATION_MAPPING, so it is matched against the user's
    # `sections` verbatim (via the .get(k, k) fallback) — verify this literal
    # label is what the UI actually sends.
    "II. Die Entscheidung",
    "angaben_ueber_das_urteil",
    "sachverhalt",
    "prozessgeschichte",
    "rechtsproblem",
    "loesung_des_gerichts",
]
# Maps the human-readable (German) section labels shown in the UI to the
# internal prompt keys used in PARALLEL_SUMMARIZATION_ORDER.
PARALLEL_SUMMARIZATION_MAPPING = {
    "I. Einleitung": "intro",
    "Darstellung des Rechtsproblems": "darstellung_des_rechtsproblems",
    "Angaben über das Urteil": "angaben_ueber_das_urteil",
    "Sachverhalt": "sachverhalt",
    "Prozessgeschichte": "prozessgeschichte",
    "Rechtsproblem": "rechtsproblem",
    "Lösung des Gerichts": "loesung_des_gerichts",
}
# Inverse mapping: internal prompt key -> human-readable UI label.
PARALLEL_SUMMARIZATION_MAPPING_INVERSE = {
    v: k for k, v in PARALLEL_SUMMARIZATION_MAPPING.items()
}
def parallel_summarization(file: str, sections: List[str], llm: ChatOpenAI) -> str:
    """Wrapper for the parallel summarization function to make it compatible
    with gradio.

    Args:
        file (str): File object with a ``.name`` attribute, e.g. a
            tempfile.TemporaryFileWrapper_ as handed over by gradio.
        sections (List[str]): List of sections to summarize.
        llm (ChatOpenAI): Language model. NOTE(review): indexed with
            ``llm[0]`` below, presumably a gradio state container — confirm
            against the caller.

    Returns:
        str: All section summaries concatenated, each followed by a blank line.
    """
    now = time.time()
    summaries = asyncio.run(
        generate_summary_concurrently(
            file_path=file.name, sections=sections, llm=llm[0]
        )
    )
    print("Time taken for complete parallel summarization: ", time.time() - now)
    # Concatenate summaries in insertion order. The previous
    # `summaries.get(section, <label fallback>)` could never use its fallback,
    # because the loop iterated the dict's own keys — that dead code is gone.
    return "".join(f"{summary}\n\n" for summary in summaries.values())