Spaces:
Runtime error
Runtime error
| """Chain that makes API calls and summarizes the responses to answer a question.""" | |
| from __future__ import annotations | |
| from typing import Any, Dict, List, Optional, Sequence, Tuple | |
| from urllib.parse import urlparse | |
| from langchain_core.language_models import BaseLanguageModel | |
| from langchain_core.prompts import BasePromptTemplate | |
| from langchain_core.pydantic_v1 import Field, root_validator | |
| from langchain.callbacks.manager import ( | |
| AsyncCallbackManagerForChainRun, | |
| CallbackManagerForChainRun, | |
| ) | |
| from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT | |
| from langchain.chains.base import Chain | |
| from langchain.chains.llm import LLMChain | |
| from langchain.utilities.requests import TextRequestsWrapper | |
| def _extract_scheme_and_domain(url: str) -> Tuple[str, str]: | |
| """Extract the scheme + domain from a given URL. | |
| Args: | |
| url (str): The input URL. | |
| Returns: | |
| return a 2-tuple of scheme and domain | |
| """ | |
| parsed_uri = urlparse(url) | |
| return parsed_uri.scheme, parsed_uri.netloc | |
def _check_in_allowed_domain(url: str, limit_to_domains: Sequence[str]) -> bool:
    """Tell whether a URL targets one of the allowed domains.

    A URL matches an allowed entry when both its scheme and its network
    location are equal to those of the entry; paths and query strings are
    ignored.

    Args:
        url: The URL to check.
        limit_to_domains: The allowed domains, each written as a URL such as
            ``"https://www.example.com"``.

    Returns:
        True if the URL matches at least one allowed entry, False otherwise
        (including when ``limit_to_domains`` is empty).
    """
    target = _extract_scheme_and_domain(url)
    # Comparing the (scheme, domain) pairs checks both components at once.
    return any(
        _extract_scheme_and_domain(candidate) == target
        for candidate in limit_to_domains
    )
class APIChain(Chain):
    """Chain that makes API calls and summarizes the responses to answer a question.

    The chain works in three steps: an LLM turns the question and the API
    documentation into a URL, the URL is fetched with a GET request, and a
    second LLM call turns the raw response into an answer.

    *Security Note*: This API chain uses the requests toolkit
        to make GET, POST, PATCH, PUT, and DELETE requests to an API.

        Exercise care in who is allowed to use this chain. If exposing
        to end users, consider that users will be able to make arbitrary
        requests on behalf of the server hosting the code. For example,
        users could ask the server to make a request to a private API
        that is only accessible from the server.

        Control access to who can submit requests using this toolkit and
        what network access it has.

        See https://python.langchain.com/docs/security for more information.
    """

    api_request_chain: LLMChain
    api_answer_chain: LLMChain
    requests_wrapper: TextRequestsWrapper = Field(exclude=True)
    api_docs: str
    question_key: str = "question"  #: :meta private:
    output_key: str = "output"  #: :meta private:
    limit_to_domains: Optional[Sequence[str]]
    """Use to limit the domains that can be accessed by the API chain.

    * For example, to limit to just the domain `https://www.example.com`, set
        `limit_to_domains=["https://www.example.com"]`.

    * `from_llm_and_api_docs` defaults this to an empty tuple, which means
        that no domains are allowed by default. By design this will raise an
        error on instantiation.
    * Use a None if you want to allow all domains by default -- this is not
        recommended for security reasons, as it would allow malicious users to
        make requests to arbitrary URLS including internal APIs accessible from
        the server.
    """

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.question_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key.

        :meta private:
        """
        return [self.output_key]

    @root_validator(pre=True)
    def validate_api_request_prompt(cls, values: Dict) -> Dict:
        """Check that api request prompt expects the right variables.

        Raises:
            ValueError: If the prompt's input variables are not exactly
                ``question`` and ``api_docs``.
        """
        input_vars = values["api_request_chain"].prompt.input_variables
        expected_vars = {"question", "api_docs"}
        if set(input_vars) != expected_vars:
            raise ValueError(
                f"Input variables should be {expected_vars}, got {input_vars}"
            )
        return values

    @root_validator(pre=True)
    def validate_limit_to_domains(cls, values: Dict) -> Dict:
        """Check that allowed domains are valid.

        Runs as a pre-validator so that an omitted ``limit_to_domains`` can be
        told apart from an explicit ``None``.

        Raises:
            ValueError: If ``limit_to_domains`` was not supplied at all, or if
                it is empty without being ``None``.
        """
        if "limit_to_domains" not in values:
            raise ValueError(
                "You must specify a list of domains to limit access using "
                "`limit_to_domains`"
            )
        # None is the explicit "allow everything" opt-in; any other falsy
        # value (empty list/tuple) would silently allow nothing, so reject it.
        if not values["limit_to_domains"] and values["limit_to_domains"] is not None:
            raise ValueError(
                "Please provide a list of domains to limit access using "
                "`limit_to_domains`."
            )
        return values

    @root_validator(pre=True)
    def validate_api_answer_prompt(cls, values: Dict) -> Dict:
        """Check that api answer prompt expects the right variables.

        Raises:
            ValueError: If the prompt's input variables are not exactly
                ``question``, ``api_docs``, ``api_url`` and ``api_response``.
        """
        input_vars = values["api_answer_chain"].prompt.input_variables
        expected_vars = {"question", "api_docs", "api_url", "api_response"}
        if set(input_vars) != expected_vars:
            raise ValueError(
                f"Input variables should be {expected_vars}, got {input_vars}"
            )
        return values

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Generate an API URL for the question, fetch it and summarize the result.

        Args:
            inputs: Chain inputs; the question is read from ``question_key``.
            run_manager: Callback manager for this run; a no-op manager is
                used when omitted.

        Returns:
            A dict mapping ``output_key`` to the generated answer.

        Raises:
            ValueError: If ``limit_to_domains`` is set and the generated URL
                does not match any allowed domain.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.question_key]
        api_url = self.api_request_chain.predict(
            question=question,
            api_docs=self.api_docs,
            callbacks=_run_manager.get_child(),
        )
        _run_manager.on_text(api_url, color="green", end="\n", verbose=self.verbose)
        api_url = api_url.strip()
        # A None allow-list disables the check; otherwise the URL is
        # validated before any request is made.
        if self.limit_to_domains and not _check_in_allowed_domain(
            api_url, self.limit_to_domains
        ):
            raise ValueError(
                f"{api_url} is not in the allowed domains: {self.limit_to_domains}"
            )
        api_response = self.requests_wrapper.get(api_url)
        _run_manager.on_text(
            api_response, color="yellow", end="\n", verbose=self.verbose
        )
        answer = self.api_answer_chain.predict(
            question=question,
            api_docs=self.api_docs,
            api_url=api_url,
            api_response=api_response,
            callbacks=_run_manager.get_child(),
        )
        return {self.output_key: answer}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Async counterpart of ``_call``.

        Args:
            inputs: Chain inputs; the question is read from ``question_key``.
            run_manager: Async callback manager for this run; a no-op manager
                is used when omitted.

        Returns:
            A dict mapping ``output_key`` to the generated answer.

        Raises:
            ValueError: If ``limit_to_domains`` is set and the generated URL
                does not match any allowed domain.
        """
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.question_key]
        api_url = await self.api_request_chain.apredict(
            question=question,
            api_docs=self.api_docs,
            callbacks=_run_manager.get_child(),
        )
        await _run_manager.on_text(
            api_url, color="green", end="\n", verbose=self.verbose
        )
        api_url = api_url.strip()
        # A None allow-list disables the check; otherwise the URL is
        # validated before any request is made.
        if self.limit_to_domains and not _check_in_allowed_domain(
            api_url, self.limit_to_domains
        ):
            raise ValueError(
                f"{api_url} is not in the allowed domains: {self.limit_to_domains}"
            )
        api_response = await self.requests_wrapper.aget(api_url)
        await _run_manager.on_text(
            api_response, color="yellow", end="\n", verbose=self.verbose
        )
        answer = await self.api_answer_chain.apredict(
            question=question,
            api_docs=self.api_docs,
            api_url=api_url,
            api_response=api_response,
            callbacks=_run_manager.get_child(),
        )
        return {self.output_key: answer}

    @classmethod
    def from_llm_and_api_docs(
        cls,
        llm: BaseLanguageModel,
        api_docs: str,
        headers: Optional[dict] = None,
        api_url_prompt: BasePromptTemplate = API_URL_PROMPT,
        api_response_prompt: BasePromptTemplate = API_RESPONSE_PROMPT,
        limit_to_domains: Optional[Sequence[str]] = tuple(),
        **kwargs: Any,
    ) -> APIChain:
        """Load chain from just an LLM and the api docs.

        Args:
            llm: Language model used for both URL generation and answering.
            api_docs: Documentation of the API, passed to both prompts.
            headers: Optional HTTP headers for the requests wrapper.
            api_url_prompt: Prompt used to produce the API URL.
            api_response_prompt: Prompt used to summarize the API response.
            limit_to_domains: Allowed domains. The empty-tuple default is
                rejected by validation, so callers must pass either a
                non-empty sequence or an explicit ``None``.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            A configured ``APIChain``.
        """
        get_request_chain = LLMChain(llm=llm, prompt=api_url_prompt)
        requests_wrapper = TextRequestsWrapper(headers=headers)
        get_answer_chain = LLMChain(llm=llm, prompt=api_response_prompt)
        return cls(
            api_request_chain=get_request_chain,
            api_answer_chain=get_answer_chain,
            requests_wrapper=requests_wrapper,
            api_docs=api_docs,
            limit_to_domains=limit_to_domains,
            **kwargs,
        )

    @property
    def _chain_type(self) -> str:
        """Identifier of this chain type."""
        return "api_chain"