Spaces:
Sleeping
Sleeping
| """Definitions for refine data summarizer.""" | |
| from typing import Any, List, Dict | |
| from langchain.chat_models.base import BaseChatModel | |
| from langchain.docstore.document import Document | |
| from langchain.prompts import PromptTemplate | |
| from langchain.document_loaders import TextLoader | |
| from langchain.text_splitter import TokenTextSplitter, CharacterTextSplitter | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain.chains.summarize import load_summarize_chain | |
class RefineDataSummarizer:
    """Summarize long text with a LangChain ``refine`` summarization chain.

    The input text is split into token-based chunks whose size is derived
    from the configured model's entry in ``token_limit``. The first chunk is
    summarized with the question prompt and every following chunk is folded
    into the running summary with the refine prompt.
    """

    # Token budget per supported model name (presumably each model's context
    # window -- confirm against provider docs when adding entries). The chunk
    # size handed to the splitter is this budget divided by ``chunk_num``.
    token_limit: Dict[str, int] = {
        "gpt-3.5-turbo": 4096,
        "gpt-4": 8192,
        "gpt-3.5-turbo-16k": 16385,
        "gpt-3.5-turbo-1106": 16385,
        "gpt-4-1106-preview": 128000,
        "gemini-pro": 32768,
        "codechat-bison": 8192,
        "chat-bison": 8192,
    }

    def __init__(
        self,
        llm: "BaseChatModel",
        prompt_template: str,
        refine_template: str,
    ):
        """Initialize the data summarizer.

        Args:
            llm: Chat model used to run the summarization chain.
            prompt_template: Template text for the initial (question) prompt;
                surrounding whitespace is stripped before parsing.
            refine_template: Template text for the refine prompt; surrounding
                whitespace is stripped before parsing.
        """
        self.llm = llm
        # Prefer ``model_name``; fall back to ``model`` for chat wrappers that
        # only expose that attribute, instead of failing in the constructor.
        model = getattr(llm, "model_name", None)
        if model is None:
            model = getattr(llm, "model", None)
        self.llm_model = model
        self.prompt = PromptTemplate.from_template(prompt_template.strip())
        self.refine_prompt = PromptTemplate.from_template(refine_template.strip())

    def get_summarization(
        self,
        text: str,
        chunk_num: int = 5,
        chunk_overlap: int = 30,
    ) -> Dict[str, Any]:
        """Summarize ``text`` with the refine chain.

        Args:
            text: Text to summarize.
            chunk_num: Divisor applied to the model's token budget to obtain
                the splitter's chunk size. Must be positive.
            chunk_overlap: Number of tokens shared by consecutive chunks.

        Returns:
            The chain's output mapping, holding the final summary under
            ``"output_text"`` and the per-chunk summaries under
            ``"intermediate_steps"``. For empty ``text`` both are empty and
            the model is never called.

        Raises:
            ValueError: If ``chunk_num`` is not positive.
            KeyError: If the configured model has no entry in ``token_limit``.
        """
        if chunk_num <= 0:
            raise ValueError(f"chunk_num must be a positive integer, got {chunk_num!r}")

        try:
            model_budget = self.token_limit[self.llm_model]
        except KeyError:
            known = ", ".join(sorted(self.token_limit))
            raise KeyError(
                f"No token limit configured for model {self.llm_model!r}; "
                f"known models: {known}"
            ) from None

        # The splitter yields no chunks for empty text and the refine chain
        # cannot start from an empty document list, so short-circuit here.
        if not text:
            return {"intermediate_steps": [], "output_text": ""}

        text_splitter = TokenTextSplitter(
            chunk_size=model_budget // chunk_num,
            chunk_overlap=chunk_overlap,
        )
        docs = [
            Document(page_content=chunk, metadata={"source": "local"})
            for chunk in text_splitter.split_text(text)
        ]
        chain = load_summarize_chain(
            llm=self.llm,
            chain_type="refine",
            question_prompt=self.prompt,
            refine_prompt=self.refine_prompt,
            return_intermediate_steps=True,
            input_key="input_documents",
            output_key="output_text",
            verbose=True,
        )
        return chain({"input_documents": docs}, return_only_outputs=True)