Spaces:
Runtime error
Runtime error
| """Chain that runs an arbitrary python function.""" | |
| import functools | |
| import logging | |
| from typing import Any, Awaitable, Callable, Dict, List, Optional | |
| from langchain_core.pydantic_v1 import Field | |
| from langchain.callbacks.manager import ( | |
| AsyncCallbackManagerForChainRun, | |
| CallbackManagerForChainRun, | |
| ) | |
| from langchain.chains.base import Chain | |
| logger = logging.getLogger(__name__) | |
class TransformChain(Chain):
    """Chain that runs an arbitrary Python function over its inputs.

    The synchronous function is supplied as ``transform`` and an optional
    coroutine function as ``atransform``; both receive the input dictionary
    and return the output dictionary.

    Example:
        .. code-block:: python

            from langchain.chains import TransformChain
            transform_chain = TransformChain(
                input_variables=["text"],
                output_variables=["entities"],
                transform=func,
            )
    """

    input_variables: List[str]
    """The keys expected by the transform's input dictionary."""
    output_variables: List[str]
    """The keys returned by the transform's output dictionary."""
    transform_cb: Callable[[Dict[str, str]], Dict[str, str]] = Field(alias="transform")
    """The transform function."""
    atransform_cb: Optional[
        Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]
    ] = Field(None, alias="atransform")
    """The async coroutine transform function."""

    # Static because it takes no ``self``; memoised on ``msg`` so that each
    # distinct message is emitted at most once per process.
    @staticmethod
    @functools.lru_cache(maxsize=None)
    def _log_once(msg: str) -> None:
        """Log a warning message once.

        Repeated calls with the same ``msg`` are served from the cache and
        therefore do not log again.

        :meta private:
        """
        logger.warning(msg)

    @property
    def input_keys(self) -> List[str]:
        """Expect input keys.

        :meta private:
        """
        return self.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Return output keys.

        :meta private:
        """
        return self.output_variables

    def _call(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Run the synchronous transform on ``inputs`` and return its result.

        ``run_manager`` is accepted for interface compatibility and is not used.
        """
        return self.transform_cb(inputs)

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the async transform, falling back to the synchronous one.

        When ``atransform`` was not provided, a warning is logged (once) and
        the synchronous transform is called directly inside this coroutine.
        ``run_manager`` is accepted for interface compatibility and is not used.
        """
        if self.atransform_cb is not None:
            return await self.atransform_cb(inputs)

        self._log_once(
            "TransformChain's atransform is not provided, falling"
            " back to synchronous transform"
        )
        return self.transform_cb(inputs)