Spaces:
Runtime error
Runtime error
| """Chain that implements the ReAct paper from https://arxiv.org/pdf/2210.03629.pdf.""" | |
| from typing import Any, List, Optional, Sequence | |
| from langchain_core.language_models import BaseLanguageModel | |
| from langchain_core.prompts import BasePromptTemplate | |
| from langchain_core.pydantic_v1 import Field | |
| from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser | |
| from langchain.agents.agent_types import AgentType | |
| from langchain.agents.react.output_parser import ReActOutputParser | |
| from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT | |
| from langchain.agents.react.wiki_prompt import WIKI_PROMPT | |
| from langchain.agents.tools import Tool | |
| from langchain.agents.utils import validate_tools_single_input | |
| from langchain.docstore.base import Docstore | |
| from langchain.docstore.document import Document | |
| from langchain.tools.base import BaseTool | |
class ReActDocstoreAgent(Agent):
    """Agent for the ReAct chain.

    Expects exactly two single-input tools, named ``Search`` and ``Lookup``
    (enforced by ``_validate_tools``).
    """

    # Parser used unless the caller supplies a different one.
    output_parser: AgentOutputParser = Field(default_factory=ReActOutputParser)

    @classmethod
    def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
        """Return the default output parser for this agent.

        Args:
            **kwargs: Accepted for interface compatibility; unused here.
        """
        return ReActOutputParser()

    @property
    def _agent_type(self) -> str:
        """Return Identifier of an agent type."""
        return AgentType.REACT_DOCSTORE

    @classmethod
    def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
        """Return default prompt.

        Args:
            tools: Unused; the wiki prompt does not depend on the tool list.
        """
        return WIKI_PROMPT

    @classmethod
    def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
        """Check that ``tools`` is exactly the Search/Lookup pair.

        Args:
            tools: Tools the agent is about to be built with.

        Raises:
            ValueError: If the number of tools is not two, or their names are
                not exactly ``{"Lookup", "Search"}``.
        """
        validate_tools_single_input(cls.__name__, tools)
        super()._validate_tools(tools)
        if len(tools) != 2:
            raise ValueError(f"Exactly two tools must be specified, but got {tools}")
        tool_names = {tool.name for tool in tools}
        if tool_names != {"Lookup", "Search"}:
            raise ValueError(
                f"Tool names should be Lookup and Search, got {tool_names}"
            )

    @property
    def observation_prefix(self) -> str:
        """Prefix to append the observation with."""
        return "Observation: "

    @property
    def _stop(self) -> List[str]:
        """Stop sequences: halt generation before the model writes an observation."""
        return ["\nObservation:"]

    @property
    def llm_prefix(self) -> str:
        """Prefix to append the LLM call with."""
        return "Thought:"
class DocstoreExplorer:
    """Class to assist with exploration of a document store.

    Holds the most recently found document plus a cursor (``lookup_str`` /
    ``lookup_index``) so that repeated lookups of the same term step through
    successive matching paragraphs.
    """

    def __init__(self, docstore: "Docstore"):
        """Initialize with a docstore, and set initial document to None."""
        self.docstore = docstore
        self.document: Optional[Document] = None
        self.lookup_str = ""
        self.lookup_index = 0

    def search(self, term: str) -> str:
        """Search for a term in the docstore, and if found save.

        Args:
            term: Term passed through to ``docstore.search``.

        Returns:
            The first paragraph of the document when the docstore returns a
            ``Document``; otherwise whatever the docstore returned, unchanged.
        """
        result = self.docstore.search(term)
        # A new search replaces (or clears) the current document, so a cursor
        # left over from the previous document must not carry over to it.
        self.lookup_str = ""
        self.lookup_index = 0
        if isinstance(result, Document):
            self.document = result
            return self._summary
        self.document = None
        return result

    def lookup(self, term: str) -> str:
        """Lookup a term in document (if saved).

        The match is a case-insensitive substring test per paragraph. Calling
        again with the same term (ignoring case) advances to the next match;
        a different term restarts from the first match.

        Args:
            term: Text to look for in the saved document.

        Returns:
            ``"(Result i/n) <paragraph>"`` for the current match,
            ``"No Results"`` when nothing matches, or ``"No More Results"``
            once the matches are exhausted.

        Raises:
            ValueError: If no document is currently saved.
        """
        if self.document is None:
            raise ValueError("Cannot lookup without a successful search first")
        needle = term.lower()
        if needle != self.lookup_str:
            self.lookup_str = needle
            self.lookup_index = 0
        else:
            self.lookup_index += 1
        lookups = [p for p in self._paragraphs if self.lookup_str in p.lower()]
        if not lookups:
            return "No Results"
        if self.lookup_index >= len(lookups):
            return "No More Results"
        result_prefix = f"(Result {self.lookup_index + 1}/{len(lookups)})"
        return f"{result_prefix} {lookups[self.lookup_index]}"

    # ``search`` and ``lookup`` read the two members below as plain
    # attributes, so they have to be properties rather than methods.
    @property
    def _summary(self) -> str:
        """First paragraph of the saved document."""
        return self._paragraphs[0]

    @property
    def _paragraphs(self) -> List[str]:
        """Saved document's ``page_content`` split on blank lines.

        Raises:
            ValueError: If no document is currently saved.
        """
        if self.document is None:
            raise ValueError("Cannot get paragraphs without a document")
        return self.document.page_content.split("\n\n")
class ReActTextWorldAgent(ReActDocstoreAgent):
    """Agent for the ReAct TextWorld chain.

    Expects exactly one single-input tool, named ``Play`` (enforced by
    ``_validate_tools``).
    """

    @classmethod
    def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
        """Return default prompt.

        Args:
            tools: Unused; the TextWorld prompt does not depend on the tool list.
        """
        return TEXTWORLD_PROMPT

    @classmethod
    def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
        """Check that ``tools`` is exactly the single ``Play`` tool.

        Args:
            tools: Tools the agent is about to be built with.

        Raises:
            ValueError: If the number of tools is not one, or the tool is not
                named ``Play``.
        """
        validate_tools_single_input(cls.__name__, tools)
        # Deliberately skip ReActDocstoreAgent._validate_tools: it requires
        # exactly two tools named Lookup/Search, which can never hold together
        # with the one-tool "Play" requirement below. Continue the MRO from
        # the class after ReActDocstoreAgent instead.
        super(ReActDocstoreAgent, cls)._validate_tools(tools)
        if len(tools) != 1:
            raise ValueError(f"Exactly one tool must be specified, but got {tools}")
        tool_names = {tool.name for tool in tools}
        if tool_names != {"Play"}:
            raise ValueError(f"Tool name should be Play, got {tool_names}")
class ReActChain(AgentExecutor):
    """[Deprecated] Chain that implements the ReAct paper."""

    def __init__(self, llm: BaseLanguageModel, docstore: Docstore, **kwargs: Any):
        """Set up a docstore-backed ReAct agent and its executor.

        Args:
            llm: Language model handed to the agent.
            docstore: Store wrapped by a ``DocstoreExplorer`` whose ``search``
                and ``lookup`` methods back the two tools.
            **kwargs: Forwarded unchanged to ``AgentExecutor.__init__``.
        """
        explorer = DocstoreExplorer(docstore)
        search_tool = Tool(
            name="Search",
            func=explorer.search,
            description="Search for a term in the docstore.",
        )
        lookup_tool = Tool(
            name="Lookup",
            func=explorer.lookup,
            description="Lookup a term in the docstore.",
        )
        # Both tools share one explorer so Lookup sees the document that the
        # last Search saved.
        tools = [search_tool, lookup_tool]
        super().__init__(
            agent=ReActDocstoreAgent.from_llm_and_tools(llm, tools),
            tools=tools,
            **kwargs,
        )