Spaces:
Runtime error
Runtime error
| """Chain that combines documents by stuffing into context.""" | |
| from typing import Any, Dict, List, Optional, Tuple | |
| from langchain_core.prompts import BasePromptTemplate, format_document | |
| from langchain_core.prompts.prompt import PromptTemplate | |
| from langchain_core.pydantic_v1 import Extra, Field, root_validator | |
| from langchain.callbacks.manager import Callbacks | |
| from langchain.chains.combine_documents.base import ( | |
| BaseCombineDocumentsChain, | |
| ) | |
| from langchain.chains.llm import LLMChain | |
| from langchain.docstore.document import Document | |
def _get_default_document_prompt() -> PromptTemplate:
    """Return the fallback prompt that renders a document as its raw page_content."""
    template = "{page_content}"
    return PromptTemplate(
        input_variables=["page_content"],
        template=template,
    )
class StuffDocumentsChain(BaseCombineDocumentsChain):
    """Chain that combines documents by stuffing into context.

    This chain takes a list of documents and first combines them into a single string.
    It does this by formatting each document into a string with the `document_prompt`
    and then joining them together with `document_separator`. It then adds that new
    string to the inputs with the variable name set by `document_variable_name`.
    Those inputs are then passed to the `llm_chain`.

    Example:
        .. code-block:: python

            from langchain.chains import StuffDocumentsChain, LLMChain
            from langchain_core.prompts import PromptTemplate
            from langchain.llms import OpenAI

            # This controls how each document will be formatted. Specifically,
            # it will be passed to `format_document` - see that function for more
            # details.
            document_prompt = PromptTemplate(
                input_variables=["page_content"],
                template="{page_content}"
            )
            document_variable_name = "context"
            llm = OpenAI()
            # The prompt here should take as an input variable the
            # `document_variable_name`
            prompt = PromptTemplate.from_template(
                "Summarize this content: {context}"
            )
            llm_chain = LLMChain(llm=llm, prompt=prompt)
            chain = StuffDocumentsChain(
                llm_chain=llm_chain,
                document_prompt=document_prompt,
                document_variable_name=document_variable_name
            )
    """

    llm_chain: LLMChain
    """LLM chain which is called with the formatted document string,
    along with any other inputs."""
    document_prompt: BasePromptTemplate = Field(
        default_factory=_get_default_document_prompt
    )
    """Prompt to use to format each document, gets passed to `format_document`."""
    document_variable_name: str
    """The variable name in the llm_chain to put the documents in.
    If only one variable in the llm_chain, this need not be provided."""
    document_separator: str = "\n\n"
    """The string with which to join the formatted documents"""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    # `pre=True` so the default is inferred before field validation rejects a
    # missing `document_variable_name`. Without this decorator the validator is
    # never registered by pydantic and the inference never runs.
    @root_validator(pre=True)
    def get_default_document_variable_name(cls, values: Dict) -> Dict:
        """Get default document variable name, if not provided.

        If only one variable is present in the llm_chain.prompt,
        we can infer that the formatted documents should be passed in
        with this variable name.

        Raises:
            ValueError: if `document_variable_name` is missing and cannot be
                inferred, or names a variable not in the llm_chain prompt.
        """
        llm_chain_variables = values["llm_chain"].prompt.input_variables
        if "document_variable_name" not in values:
            if len(llm_chain_variables) == 1:
                values["document_variable_name"] = llm_chain_variables[0]
            else:
                raise ValueError(
                    "document_variable_name must be provided if there are "
                    "multiple llm_chain_variables"
                )
        else:
            if values["document_variable_name"] not in llm_chain_variables:
                raise ValueError(
                    f"document_variable_name {values['document_variable_name']} was "
                    f"not found in llm_chain input_variables: {llm_chain_variables}"
                )
        return values

    # Must be a property: the body accesses `super().input_keys` without calling
    # it, i.e. the base class exposes `input_keys` as a property too.
    @property
    def input_keys(self) -> List[str]:
        """Expected input keys: the base chain's keys plus any extra llm_chain
        inputs other than the document variable (which this chain supplies)."""
        extra_keys = [
            k for k in self.llm_chain.input_keys if k != self.document_variable_name
        ]
        return super().input_keys + extra_keys

    def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
        """Construct inputs from kwargs and docs.

        Format and then join all the documents together into one input with name
        `self.document_variable_name`. Also pluck any additional variables
        from **kwargs.

        Args:
            docs: List of documents to format and then join into single input
            **kwargs: additional inputs to chain, will pluck any other required
                arguments from here.

        Returns:
            dictionary of inputs to LLMChain
        """
        # Format each document according to the prompt
        doc_strings = [format_document(doc, self.document_prompt) for doc in docs]
        # Join the documents together to put them in the prompt.
        inputs = {
            k: v
            for k, v in kwargs.items()
            if k in self.llm_chain.prompt.input_variables
        }
        inputs[self.document_variable_name] = self.document_separator.join(doc_strings)
        return inputs

    def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
        """Return the prompt length given the documents passed in.

        This can be used by a caller to determine whether passing in a list
        of documents would exceed a certain prompt length. This useful when
        trying to ensure that the size of a prompt remains below a certain
        context limit.

        Args:
            docs: List[Document], a list of documents to use to calculate the
                total prompt length.
            **kwargs: additional inputs used to render the prompt.

        Returns:
            Returns None if the method does not depend on the prompt length,
            otherwise the length of the prompt in tokens.
        """
        inputs = self._get_inputs(docs, **kwargs)
        prompt = self.llm_chain.prompt.format(**inputs)
        return self.llm_chain._get_num_tokens(prompt)

    def combine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Stuff all documents into one prompt and pass to LLM.

        Args:
            docs: List of documents to join together into one variable
            callbacks: Optional callbacks to pass along
            **kwargs: additional parameters to use to get inputs to LLMChain.

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        inputs = self._get_inputs(docs, **kwargs)
        # Call predict on the LLM.
        return self.llm_chain.predict(callbacks=callbacks, **inputs), {}

    async def acombine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Async stuff all documents into one prompt and pass to LLM.

        Args:
            docs: List of documents to join together into one variable
            callbacks: Optional callbacks to pass along
            **kwargs: additional parameters to use to get inputs to LLMChain.

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        inputs = self._get_inputs(docs, **kwargs)
        # Call predict on the LLM.
        return await self.llm_chain.apredict(callbacks=callbacks, **inputs), {}

    @property
    def _chain_type(self) -> str:
        return "stuff_documents_chain"