File size: 1,804 Bytes
3b60800 c4a74f5 3b60800 c4a74f5 3b60800 59352f7 3b60800 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 | from typing import Optional, Any, Callable
from pydantic import ConfigDict, BaseModel
from langchain_core.messages import AIMessage, ToolMessage
from langchain_core.runnables import RunnableSerializable, RunnableConfig, Runnable
from langchain_core.runnables.utils import Input, Output
from langchain_core.tools import BaseTool
class RunnableWithTools(RunnableSerializable[Input, Output]):
    """Wrap a runnable so tool calls in its output are executed and fed back.

    ``bound`` is invoked on a message list. Each time its result is an
    ``AIMessage`` with ``tool_calls`` (and ``tools`` is non-empty), the AI
    message is appended to that list, followed by one ``ToolMessage`` per
    tool call. ``bound`` is then invoked again, for at most ``max_depth``
    invocations in total.
    """

    # The runnable invoked each round (presumably a chat model with these
    # tools bound to it -- confirm at the call site).
    bound: Runnable[Input, Output]
    # Tools looked up by the lower-cased tool-call name, so keys are expected
    # to be lower-case.
    tools: dict[str, BaseTool]

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    def invoke(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        max_depth: Optional[int] = 3,
        **kwargs: Any
    ) -> Output:
        """Invoke ``bound``, executing any requested tools between rounds.

        Args:
            input: Message list passed to ``bound``. It must support
                ``append``/``extend``: it is extended IN PLACE with every AI
                message that requested tools and with the matching tool
                results, so the caller's list holds the transcript afterwards.
            config: Runnable config, forwarded to ``bound`` and to each tool.
            max_depth: Maximum number of ``bound`` invocations. ``None``
                falls back to the default of 3.
            **kwargs: Accepted for ``Runnable.invoke`` compatibility; unused.

        Returns:
            The last output of ``bound``, or ``None`` if ``max_depth`` is not
            positive. If the limit is reached while the model still requests
            tools, that last message (with its ``tool_calls``) is returned.
            The tool results appended for it are never sent back to ``bound``.
        """
        if max_depth is None:
            max_depth = 3
        message = None
        for _ in range(max_depth):
            message = self.bound.invoke(input, config)
            wants_tools = (
                isinstance(message, AIMessage)
                and message.tool_calls
                and self.tools
            )
            if not wants_tools:
                break
            input.append(message)
            input.extend(self._run_tool_calls(message, config))
        return message

    def _run_tool_calls(
        self,
        message: AIMessage,
        config: Optional[RunnableConfig],
    ) -> list[ToolMessage]:
        """Run every tool call in ``message``; return one ToolMessage per call."""
        results: list[ToolMessage] = []
        for tool_call in message.tool_calls:
            name = tool_call["name"]
            tool = self.tools.get(name.lower())
            if tool is None:
                # Still answer this call: chat APIs typically require every
                # tool_call_id to have a matching tool message, and an error
                # text lets the model recover instead of the loop crashing.
                available = ", ".join(self.tools)
                results.append(
                    ToolMessage(
                        content=(
                            f"Error: {name} is not a valid tool, "
                            f"try one of [{available}]."
                        ),
                        tool_call_id=tool_call["id"],
                    )
                )
                continue
            output = tool.invoke(tool_call, config)
            # Invoking a BaseTool with a ToolCall dict can already return a
            # ToolMessage tied to this call's id; otherwise wrap the raw output.
            if isinstance(output, ToolMessage):
                results.append(output)
            else:
                results.append(
                    ToolMessage(content=str(output), tool_call_id=tool_call["id"])
                )
        return results
|