Spaces:
Runtime error
Runtime error
| """Load summarizing chains.""" | |
| from typing import Any, Mapping, Optional, Protocol | |
| from langchain_core.language_models import BaseLanguageModel | |
| from langchain_core.prompts import BasePromptTemplate | |
| from langchain.callbacks.manager import Callbacks | |
| from langchain.chains.combine_documents.base import BaseCombineDocumentsChain | |
| from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain | |
| from langchain.chains.combine_documents.reduce import ReduceDocumentsChain | |
| from langchain.chains.combine_documents.refine import RefineDocumentsChain | |
| from langchain.chains.combine_documents.stuff import StuffDocumentsChain | |
| from langchain.chains.llm import LLMChain | |
| from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt | |
class LoadingCallable(Protocol):
    """Structural type for the factories that build a combine-documents chain.

    Used to type the loader functions that ``load_summarize_chain`` selects
    by ``chain_type``: each takes a language model plus keyword options and
    returns a ``BaseCombineDocumentsChain``.
    """

    def __call__(
        self,
        llm: BaseLanguageModel,
        **kwargs: Any,
    ) -> BaseCombineDocumentsChain:
        """Build and return a combine-documents chain driven by ``llm``."""
        ...
def _load_stuff_chain(
    llm: BaseLanguageModel,
    prompt: BasePromptTemplate = stuff_prompt.PROMPT,
    document_variable_name: str = "text",
    verbose: Optional[bool] = None,
    **kwargs: Any,
) -> StuffDocumentsChain:
    """Build a "stuff" summarization chain.

    ``prompt`` is wrapped in an ``LLMChain`` around ``llm``, and that chain
    is handed to a ``StuffDocumentsChain``.

    Args:
        llm: Language model used by the inner ``LLMChain``.
        prompt: Prompt template for the inner ``LLMChain``.
        document_variable_name: Forwarded to ``StuffDocumentsChain`` as its
            ``document_variable_name`` (default ``"text"``).
        verbose: Verbosity flag given to both the inner ``LLMChain`` and the
            returned ``StuffDocumentsChain``.
        **kwargs: Extra keyword arguments forwarded to ``StuffDocumentsChain``.

    Returns:
        The configured ``StuffDocumentsChain``.
    """
    summarize_step = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
    # TODO: document prompt
    return StuffDocumentsChain(
        llm_chain=summarize_step,
        document_variable_name=document_variable_name,
        verbose=verbose,
        **kwargs,
    )
def _load_map_reduce_chain(
    llm: BaseLanguageModel,
    map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
    combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
    combine_document_variable_name: str = "text",
    map_reduce_document_variable_name: str = "text",
    collapse_prompt: Optional[BasePromptTemplate] = None,
    reduce_llm: Optional[BaseLanguageModel] = None,
    collapse_llm: Optional[BaseLanguageModel] = None,
    verbose: Optional[bool] = None,
    token_max: int = 3000,
    callbacks: Callbacks = None,
    **kwargs: Any,
) -> MapReduceDocumentsChain:
    """Build a map-reduce summarization chain.

    The returned ``MapReduceDocumentsChain`` uses an ``LLMChain`` over
    ``map_prompt`` as its map step. Its reduce step is a
    ``ReduceDocumentsChain`` that combines with ``combine_prompt`` and, only
    when ``collapse_prompt`` is given, has a collapse chain as well.

    Args:
        llm: Model for the map step. Also the fallback model for the combine
            and collapse steps.
        map_prompt: Prompt for the map-step ``LLMChain``.
        combine_prompt: Prompt for the combine-step ``LLMChain``.
        combine_document_variable_name: ``document_variable_name`` for the
            combine and collapse ``StuffDocumentsChain`` objects.
        map_reduce_document_variable_name: ``document_variable_name`` for the
            returned ``MapReduceDocumentsChain``.
        collapse_prompt: Optional prompt for the collapse step. When ``None``,
            no collapse chain is built.
        reduce_llm: Model for the combine step. Defaults to ``llm``.
        collapse_llm: Model for the collapse step. Defaults to ``llm``.
            Only allowed together with ``collapse_prompt``.
        verbose: Verbosity flag applied to every chain constructed here.
        token_max: Forwarded to ``ReduceDocumentsChain``.
        callbacks: Callbacks attached to every chain constructed here.
        **kwargs: Extra keyword arguments forwarded to
            ``MapReduceDocumentsChain``.

    Returns:
        The configured ``MapReduceDocumentsChain``.

    Raises:
        ValueError: If ``collapse_llm`` is given without ``collapse_prompt``.
    """
    # Validate first so a misconfiguration fails before any chain is built.
    if collapse_prompt is None and collapse_llm is not None:
        raise ValueError(
            "collapse_llm provided, but collapse_prompt was not: please "
            "provide one or stop providing collapse_llm."
        )

    map_chain = LLMChain(
        llm=llm, prompt=map_prompt, verbose=verbose, callbacks=callbacks
    )

    _reduce_llm = reduce_llm or llm
    reduce_chain = LLMChain(
        llm=_reduce_llm, prompt=combine_prompt, verbose=verbose, callbacks=callbacks
    )
    # TODO: document prompt
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=reduce_chain,
        document_variable_name=combine_document_variable_name,
        verbose=verbose,
        callbacks=callbacks,
    )

    collapse_chain: Optional[StuffDocumentsChain] = None
    if collapse_prompt is not None:
        _collapse_llm = collapse_llm or llm
        collapse_chain = StuffDocumentsChain(
            llm_chain=LLMChain(
                llm=_collapse_llm,
                prompt=collapse_prompt,
                verbose=verbose,
                callbacks=callbacks,
            ),
            document_variable_name=combine_document_variable_name,
            # Passed explicitly so the collapse step honours the caller's
            # verbose/callbacks settings like every other chain built here.
            verbose=verbose,
            callbacks=callbacks,
        )

    reduce_documents_chain = ReduceDocumentsChain(
        combine_documents_chain=combine_documents_chain,
        collapse_documents_chain=collapse_chain,
        token_max=token_max,
        verbose=verbose,
        callbacks=callbacks,
    )
    return MapReduceDocumentsChain(
        llm_chain=map_chain,
        reduce_documents_chain=reduce_documents_chain,
        document_variable_name=map_reduce_document_variable_name,
        verbose=verbose,
        callbacks=callbacks,
        **kwargs,
    )
def _load_refine_chain(
    llm: BaseLanguageModel,
    question_prompt: BasePromptTemplate = refine_prompts.PROMPT,
    refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT,
    document_variable_name: str = "text",
    initial_response_name: str = "existing_answer",
    refine_llm: Optional[BaseLanguageModel] = None,
    verbose: Optional[bool] = None,
    **kwargs: Any,
) -> RefineDocumentsChain:
    """Build a "refine" summarization chain.

    Two ``LLMChain`` objects are created: one for ``question_prompt`` and one
    for ``refine_prompt``. They become the initial and refine chains of a
    ``RefineDocumentsChain``.

    Args:
        llm: Model for the initial chain. Also the fallback for the refine
            chain.
        question_prompt: Prompt for the initial ``LLMChain``.
        refine_prompt: Prompt for the refine ``LLMChain``.
        document_variable_name: Forwarded to ``RefineDocumentsChain``.
        initial_response_name: Forwarded to ``RefineDocumentsChain``
            (default ``"existing_answer"``).
        refine_llm: Model for the refine chain. Defaults to ``llm``.
        verbose: Verbosity flag given to both inner chains and the returned
            chain.
        **kwargs: Extra keyword arguments forwarded to
            ``RefineDocumentsChain``.

    Returns:
        The configured ``RefineDocumentsChain``.
    """

    def make_step(
        model: BaseLanguageModel, template: BasePromptTemplate
    ) -> LLMChain:
        # Both inner chains share the same verbosity setting.
        return LLMChain(llm=model, prompt=template, verbose=verbose)

    first_pass = make_step(llm, question_prompt)
    refinement = make_step(refine_llm or llm, refine_prompt)
    return RefineDocumentsChain(
        initial_llm_chain=first_pass,
        refine_llm_chain=refinement,
        document_variable_name=document_variable_name,
        initial_response_name=initial_response_name,
        verbose=verbose,
        **kwargs,
    )
def load_summarize_chain(
    llm: BaseLanguageModel,
    chain_type: str = "stuff",
    verbose: Optional[bool] = None,
    **kwargs: Any,
) -> BaseCombineDocumentsChain:
    """Load summarizing chain.

    Args:
        llm: Language Model to use in the chain.
        chain_type: Type of document combining chain to use. Should be one of
            "stuff", "map_reduce", and "refine".
        verbose: Whether chains should be run in verbose mode or not. Note that
            this applies to all chains that make up the final chain.
        **kwargs: Extra keyword arguments forwarded to the loader selected by
            ``chain_type``. Examples are ``prompt`` for "stuff", or
            ``map_prompt``/``combine_prompt`` for "map_reduce".

    Returns:
        A chain to use for summarizing.

    Raises:
        ValueError: If ``chain_type`` is not one of the supported types.
    """
    loader_mapping: Mapping[str, LoadingCallable] = {
        "stuff": _load_stuff_chain,
        "map_reduce": _load_map_reduce_chain,
        "refine": _load_refine_chain,
    }
    loader = loader_mapping.get(chain_type)
    if loader is None:
        # list(...) renders as "['stuff', ...]" rather than leaking the
        # "dict_keys([...])" repr into a user-facing message.
        raise ValueError(
            f"Got unsupported chain type: {chain_type}. "
            f"Should be one of {list(loader_mapping)}"
        )
    return loader(llm, verbose=verbose, **kwargs)