# transcript_summary/utils/refine_summary.py
# Provenance: uploaded by XINZHANG-Geotab ("Upload 3 files", commit 364f00c, verified).
"""Definitions for refine data summarizer."""
from typing import Any, List, Dict
from langchain.chat_models.base import BaseChatModel
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain.document_loaders import TextLoader
from langchain.text_splitter import TokenTextSplitter, CharacterTextSplitter
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.summarize import load_summarize_chain
class RefineDataSummarizer:
    """Summarize long text with LangChain's "refine" summarization chain.

    The input text is split into token-sized chunks derived from the
    model's context window; the first chunk is summarized with the
    question prompt and each subsequent chunk refines the running
    summary via the refine prompt.
    """

    # Known context-window sizes (in tokens), keyed by chat model name.
    token_limit = {"gpt-3.5-turbo": 4096,
                   "gpt-4": 8192,
                   "gpt-3.5-turbo-16k": 16385,
                   "gpt-3.5-turbo-1106": 16385,
                   "gpt-4-1106-preview": 128000,
                   "gemini-pro": 32768,
                   "codechat-bison": 8192,
                   "chat-bison": 8192
                   }
    # Conservative fallback window for models missing from token_limit.
    # Previously an unknown model name raised KeyError in get_summarization.
    default_token_limit = 4096

    def __init__(
        self,
        llm: BaseChatModel,
        prompt_template: str,
        refine_template: str,
    ):
        """Initialize the data summarizer.

        Args:
            llm: A LangChain chat model instance.
            prompt_template: Prompt used to summarize the first chunk.
            refine_template: Prompt used to fold each later chunk into
                the existing summary.
        """
        self.llm = llm
        # Not every BaseChatModel subclass exposes `model_name`; some
        # providers use `model`. Fall back instead of raising AttributeError.
        self.llm_model = getattr(llm, "model_name", None) or getattr(llm, "model", None)
        self.prompt = PromptTemplate.from_template(prompt_template.strip())
        self.refine_prompt = PromptTemplate.from_template(refine_template.strip())

    def get_summarization(self,
                          text: str,
                          chunk_num: int = 5,
                          chunk_overlap: int = 30) -> Dict:
        """Run the refine summarization chain over `text`.

        Args:
            text: Raw text to summarize.
            chunk_num: Number of chunks the model's context window is
                divided into when sizing each split; must be >= 1.
            chunk_overlap: Token overlap between consecutive chunks.

        Returns:
            The chain's output dict, containing `output_text` (the final
            summary) and `intermediate_steps`.

        Raises:
            ValueError: If `chunk_num` is less than 1.
        """
        if chunk_num < 1:
            # Guards against ZeroDivisionError (chunk_num == 0) and a
            # nonsensical negative chunk_size (chunk_num < 0).
            raise ValueError("chunk_num must be >= 1")
        # Unknown models fall back to a conservative context window
        # instead of raising KeyError.
        context_window = self.token_limit.get(self.llm_model, self.default_token_limit)
        text_splitter = TokenTextSplitter(
            chunk_size=context_window // chunk_num,
            chunk_overlap=chunk_overlap,
        )
        docs = [
            Document(page_content=chunk, metadata={"source": "local"})
            for chunk in text_splitter.split_text(text)
        ]
        chain = load_summarize_chain(
            llm=self.llm,
            chain_type="refine",
            question_prompt=self.prompt,
            refine_prompt=self.refine_prompt,
            return_intermediate_steps=True,
            input_key="input_documents",
            output_key="output_text",
            verbose=True,
        )
        result = chain({"input_documents": docs}, return_only_outputs=True)
        return result