Shroominic committed on
Commit
a14ae24
·
1 Parent(s): e1776b1

oaifunctions agent override

Browse files
Files changed (1) hide show
  1. codeinterpreterapi/functions_agent.py +293 -0
codeinterpreterapi/functions_agent.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module implements an agent that uses OpenAI's APIs function enabled API.
3
+
4
+ This file is a modified version of the original file
5
+ from langchain/agents/openai_functions_agent/base.py.
6
+ Credits go to the original authors :)
7
+ """
8
+
9
+
10
+ import json
11
+ from dataclasses import dataclass
12
+ from json import JSONDecodeError
13
+ from typing import Any, List, Optional, Sequence, Tuple, Union
14
+
15
+ from pydantic import root_validator
16
+
17
+ from langchain.agents import BaseSingleActionAgent
18
+ from langchain.base_language import BaseLanguageModel
19
+ from langchain.callbacks.base import BaseCallbackManager
20
+ from langchain.callbacks.manager import Callbacks
21
+ from langchain.chat_models.openai import ChatOpenAI
22
+ from langchain.schema import BasePromptTemplate
23
+ from langchain.prompts.chat import (
24
+ BaseMessagePromptTemplate,
25
+ ChatPromptTemplate,
26
+ HumanMessagePromptTemplate,
27
+ MessagesPlaceholder,
28
+ )
29
+ from langchain.schema import (
30
+ AgentAction,
31
+ AgentFinish,
32
+ AIMessage,
33
+ BaseMessage,
34
+ FunctionMessage,
35
+ OutputParserException,
36
+ HumanMessage,
37
+ SystemMessage,
38
+ )
39
+ from langchain.tools import BaseTool
40
+ from langchain.tools.convert_to_openai import format_tool_to_openai_function
41
+
42
+
43
@dataclass
class _FunctionsAgentAction(AgentAction):
    """AgentAction that also carries the exact AI message(s) that produced it.

    Keeping the original messages lets the agent replay the conversation
    verbatim to the model on the next turn instead of reconstructing it
    from the plain-text ``log`` field.
    """

    # Message history that led to this action; replayed to the model together
    # with the tool observation (see _convert_agent_action_to_messages).
    message_log: List[BaseMessage]
46
+
47
+
48
def _convert_agent_action_to_messages(
    agent_action: AgentAction, observation: str
) -> List[BaseMessage]:
    """Reconstruct the chat messages corresponding to a past agent action.

    Args:
        agent_action: Agent action to convert.
        observation: The tool output produced by that action.

    Returns:
        The original AI message(s) followed by a function message when the
        action carries a message log; otherwise a single AIMessage built
        from the action's plain-text log.
    """
    # Plain actions (no message log) can only be echoed back as text.
    if not isinstance(agent_action, _FunctionsAgentAction):
        return [AIMessage(content=agent_action.log)]

    function_message = _create_function_message(agent_action, observation)
    return [*agent_action.message_log, function_message]
67
+
68
+
69
def _create_function_message(
    agent_action: AgentAction, observation: str
) -> FunctionMessage:
    """Wrap a tool observation in a FunctionMessage.

    Args:
        agent_action: the tool invocation request from the agent.
        observation: the result of the tool invocation.

    Returns:
        FunctionMessage that corresponds to the original tool invocation.
    """
    if isinstance(observation, str):
        content = observation
    else:
        # Non-string observations are serialized as JSON when possible and
        # stringified as a last resort.
        try:
            content = json.dumps(observation, ensure_ascii=False)
        except Exception:
            content = str(observation)
    return FunctionMessage(name=agent_action.tool, content=content)
90
+
91
+
92
def _format_intermediate_steps(
    intermediate_steps: List[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
    """Format intermediate steps as chat messages.

    Args:
        intermediate_steps: Steps the LLM has taken to date, along with
            observations.

    Returns:
        list of messages to send to the LLM for the next prediction
    """
    # Flatten every (action, observation) pair into its message sequence.
    return [
        message
        for action, observation in intermediate_steps
        for message in _convert_agent_action_to_messages(action, observation)
    ]
108
+
109
+
110
async def _parse_ai_message(
    message: BaseMessage, llm: BaseLanguageModel
) -> Union[AgentAction, AgentFinish]:
    """Parse an AI message into an agent action or a final answer.

    Args:
        message: the model response to interpret; must be an AIMessage.
        llm: the language model in use (unused here; kept so the signature
            stays compatible with existing callers).

    Returns:
        A _FunctionsAgentAction when the message requests a function call,
        otherwise an AgentFinish wrapping the plain-text answer.

    Raises:
        TypeError: if ``message`` is not an AIMessage.
        OutputParserException: if the function arguments are not valid JSON
            and cannot be recovered as raw python code.
    """
    if not isinstance(message, AIMessage):
        raise TypeError(f"Expected an AI message got {type(message)}")

    function_call = message.additional_kwargs.get("function_call", {})

    if function_call:
        function_name = function_call["name"]
        try:
            _tool_input = json.loads(function_call["arguments"])
        except JSONDecodeError:
            if function_name == "python":
                # The model frequently emits raw (non-JSON) source for the
                # python tool; treat the whole argument string as code.
                code = function_call["arguments"]
                _tool_input = {
                    "code": code,
                }
            else:
                raise OutputParserException(
                    f"Could not parse tool input: {function_call} because "
                    f"the `arguments` is not valid JSON."
                )

        # HACK HACK HACK:
        # The code that encodes tool input into Open AI uses a special variable
        # name called `__arg1` to handle old style tools that do not expose a
        # schema and expect a single string argument as an input.
        # We unpack the argument here if it exists.
        # Open AI does not support passing in a JSON array as an argument.
        if "__arg1" in _tool_input:
            tool_input = _tool_input["__arg1"]
        else:
            tool_input = _tool_input

        # BUG FIX: the original used the plain string "responded: {content}\n",
        # so the log literally contained "{content}". Interpolate the actual
        # message content instead.
        content_msg = f"responded: {message.content}\n" if message.content else "\n"

        return _FunctionsAgentAction(
            tool=function_name,
            tool_input=tool_input,
            log=f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n",
            message_log=[message],
        )

    return AgentFinish(return_values={"output": message.content}, log=message.content)
155
+
156
+
157
class OpenAIFunctionsAgent(BaseSingleActionAgent):
    """An Agent driven by OpenAIs function powered API.

    Only the async code path (``aplan``) is implemented; the synchronous
    ``plan`` deliberately raises NotImplementedError.

    Args:
        llm: This should be an instance of ChatOpenAI, specifically a model
            that supports using `functions`.
        tools: The tools this agent has access to.
        prompt: The prompt for this agent, should support agent_scratchpad as
            one of the variables. For an easy way to construct this prompt,
            use `OpenAIFunctionsAgent.create_prompt(...)`
    """

    llm: BaseLanguageModel
    tools: Sequence[BaseTool]
    prompt: BasePromptTemplate

    def get_allowed_tools(self) -> List[str]:
        """Get allowed tools."""
        # FIX: dropped the redundant list(...) wrap around the comprehension.
        return [t.name for t in self.tools]

    @root_validator
    def validate_llm(cls, values: dict) -> dict:
        # Function calling is specific to OpenAI chat models.
        if not isinstance(values["llm"], ChatOpenAI):
            raise ValueError("Only supported with ChatOpenAI models.")
        return values

    @root_validator
    def validate_prompt(cls, values: dict) -> dict:
        # The scratchpad variable is where intermediate steps are injected.
        prompt: BasePromptTemplate = values["prompt"]
        if "agent_scratchpad" not in prompt.input_variables:
            raise ValueError(
                "`agent_scratchpad` should be one of the variables in the prompt, "
                f"got {prompt.input_variables}"
            )
        return values

    @property
    def input_keys(self) -> List[str]:
        """Get input keys. Input refers to user input here."""
        return ["input"]

    @property
    def functions(self) -> List[dict]:
        """Tools converted to OpenAI function-calling schemas."""
        return [dict(format_tool_to_openai_function(t)) for t in self.tools]

    def plan(
        self,
        intermediate_steps: Optional[List[Tuple[AgentAction, str]]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Synchronous planning is not supported; use ``aplan`` instead.

        FIX: the original override took no arguments, so an executor calling
        ``plan(intermediate_steps, callbacks=..., **inputs)`` raised TypeError
        rather than NotImplementedError. All parameters are defaulted so a
        bare ``plan()`` still raises NotImplementedError as before.
        """
        raise NotImplementedError("OpenAIFunctionsAgent only supports aplan.")

    async def aplan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decided what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations
            callbacks: Callbacks forwarded to the LLM prediction.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
        # Pull every prompt variable except the scratchpad from user inputs.
        selected_inputs = {
            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        messages = prompt.to_messages()
        predicted_message = await self.llm.apredict_messages(
            messages, functions=self.functions, callbacks=callbacks
        )
        agent_decision = await _parse_ai_message(predicted_message, self.llm)
        return agent_decision

    @classmethod
    def create_prompt(
        cls,
        # NOTE(review): a shared default SystemMessage instance; fine as long
        # as callers never mutate it — kept for interface compatibility.
        system_message: Optional[SystemMessage] = SystemMessage(
            content="You are a helpful AI assistant."
        ),
        extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
    ) -> BasePromptTemplate:
        """Create prompt for this agent.

        Args:
            system_message: Message to use as the system message that will be
                the first in the prompt.
            extra_prompt_messages: Prompt messages that will be placed between
                the system message and the new human input.

        Returns:
            A prompt template to pass into this agent.
        """
        _prompts = extra_prompt_messages or []
        messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
        if system_message:
            messages = [system_message]
        else:
            messages = []

        messages.extend(
            [
                *_prompts,
                HumanMessagePromptTemplate.from_template("{input}"),
                MessagesPlaceholder(variable_name="agent_scratchpad"),
            ]
        )
        return ChatPromptTemplate(messages=messages)  # type: ignore

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
        system_message: Optional[SystemMessage] = SystemMessage(
            content="You are a helpful AI assistant."
        ),
        **kwargs: Any,
    ) -> BaseSingleActionAgent:
        """Construct an agent from an LLM and tools."""
        # Validate eagerly so a bad LLM fails here, not at first prediction.
        if not isinstance(llm, ChatOpenAI):
            raise ValueError("Only supported with ChatOpenAI models.")
        prompt = cls.create_prompt(
            extra_prompt_messages=extra_prompt_messages,
            system_message=system_message,
        )
        return cls(
            llm=llm,
            prompt=prompt,
            tools=tools,
            callback_manager=callback_manager,  # type: ignore
            **kwargs,
        )